github.com/luckypickle/go-ethereum-vet@v1.14.2/p2p/discover/udp.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package discover 18 19 import ( 20 "bytes" 21 "container/list" 22 "crypto/ecdsa" 23 "errors" 24 "fmt" 25 "net" 26 "time" 27 28 "github.com/luckypickle/go-ethereum-vet/crypto" 29 "github.com/luckypickle/go-ethereum-vet/log" 30 "github.com/luckypickle/go-ethereum-vet/p2p/nat" 31 "github.com/luckypickle/go-ethereum-vet/p2p/netutil" 32 "github.com/luckypickle/go-ethereum-vet/rlp" 33 ) 34 35 // Errors 36 var ( 37 errPacketTooSmall = errors.New("too small") 38 errBadHash = errors.New("bad hash") 39 errExpired = errors.New("expired") 40 errUnsolicitedReply = errors.New("unsolicited reply") 41 errUnknownNode = errors.New("unknown node") 42 errTimeout = errors.New("RPC timeout") 43 errClockWarp = errors.New("reply deadline too far in the future") 44 errClosed = errors.New("socket closed") 45 ) 46 47 // Timeouts 48 const ( 49 respTimeout = 500 * time.Millisecond 50 expiration = 20 * time.Second 51 52 ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP 53 ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning 54 driftThreshold = 10 * time.Second // Allowed clock drift before warning user 55 ) 56 57 // RPC packet types 58 const ( 59 pingPacket = iota + 1 // zero is 'reserved' 60 pongPacket 61 findnodePacket 62 neighborsPacket 63 ) 64 65 // RPC request structures 66 type ( 67 ping struct { 68 Version uint 69 From, To rpcEndpoint 70 Expiration uint64 71 // Ignore additional fields (for forward compatibility). 72 Rest []rlp.RawValue `rlp:"tail"` 73 } 74 75 // pong is the reply to ping. 76 pong struct { 77 // This field should mirror the UDP envelope address 78 // of the ping packet, which provides a way to discover the 79 // the external address (after NAT). 80 To rpcEndpoint 81 82 ReplyTok []byte // This contains the hash of the ping packet. 83 Expiration uint64 // Absolute timestamp at which the packet becomes invalid. 84 // Ignore additional fields (for forward compatibility). 85 Rest []rlp.RawValue `rlp:"tail"` 86 } 87 88 // findnode is a query for nodes close to the given target. 89 findnode struct { 90 Target NodeID // doesn't need to be an actual public key 91 Expiration uint64 92 // Ignore additional fields (for forward compatibility). 93 Rest []rlp.RawValue `rlp:"tail"` 94 } 95 96 // reply to findnode 97 neighbors struct { 98 Nodes []rpcNode 99 Expiration uint64 100 // Ignore additional fields (for forward compatibility). 101 Rest []rlp.RawValue `rlp:"tail"` 102 } 103 104 rpcNode struct { 105 IP net.IP // len 4 for IPv4 or 16 for IPv6 106 UDP uint16 // for discovery protocol 107 TCP uint16 // for RLPx protocol 108 ID NodeID 109 } 110 111 rpcEndpoint struct { 112 IP net.IP // len 4 for IPv4 or 16 for IPv6 113 UDP uint16 // for discovery protocol 114 TCP uint16 // for RLPx protocol 115 } 116 ) 117 118 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 119 ip := addr.IP.To4() 120 if ip == nil { 121 ip = addr.IP.To16() 122 } 123 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 124 } 125 126 func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { 127 if rn.UDP <= 1024 { 128 return nil, errors.New("low port") 129 } 130 if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { 131 return nil, err 132 } 133 if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { 134 return nil, errors.New("not contained in netrestrict whitelist") 135 } 136 n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) 137 err := n.validateComplete() 138 return n, err 139 } 140 141 func nodeToRPC(n *Node) rpcNode { 142 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 143 } 144 145 type packet interface { 146 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 147 name() string 148 } 149 150 type conn interface { 151 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 152 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 153 Close() error 154 LocalAddr() net.Addr 155 } 156 157 // udp implements the RPC protocol. 158 type udp struct { 159 conn conn 160 netrestrict *netutil.Netlist 161 priv *ecdsa.PrivateKey 162 ourEndpoint rpcEndpoint 163 164 addpending chan *pending 165 gotreply chan reply 166 167 closing chan struct{} 168 nat nat.Interface 169 170 *Table 171 } 172 173 // pending represents a pending reply. 174 // 175 // some implementations of the protocol wish to send more than one 176 // reply packet to findnode. in general, any neighbors packet cannot 177 // be matched up with a specific findnode packet. 178 // 179 // our implementation handles this by storing a callback function for 180 // each pending reply. incoming packets from a node are dispatched 181 // to all the callback functions for that node. 182 type pending struct { 183 // these fields must match in the reply. 184 from NodeID 185 ptype byte 186 187 // time when the request must complete 188 deadline time.Time 189 190 // callback is called when a matching reply arrives. if it returns 191 // true, the callback is removed from the pending reply queue. 192 // if it returns false, the reply is considered incomplete and 193 // the callback will be invoked again for the next matching reply. 194 callback func(resp interface{}) (done bool) 195 196 // errc receives nil when the callback indicates completion or an 197 // error if no further reply is received within the timeout. 198 errc chan<- error 199 } 200 201 type reply struct { 202 from NodeID 203 ptype byte 204 data interface{} 205 // loop indicates whether there was 206 // a matching request by sending on this channel. 207 matched chan<- bool 208 } 209 210 // ReadPacket is sent to the unhandled channel when it could not be processed 211 type ReadPacket struct { 212 Data []byte 213 Addr *net.UDPAddr 214 } 215 216 // Config holds Table-related settings. 217 type Config struct { 218 // These settings are required and configure the UDP listener: 219 PrivateKey *ecdsa.PrivateKey 220 221 // These settings are optional: 222 AnnounceAddr *net.UDPAddr // local address announced in the DHT 223 NodeDBPath string // if set, the node database is stored at this filesystem location 224 NetRestrict *netutil.Netlist // network whitelist 225 Bootnodes []*Node // list of bootstrap nodes 226 Unhandled chan<- ReadPacket // unhandled packets are sent on this channel 227 } 228 229 // ListenUDP returns a new table that listens for UDP packets on laddr. 230 func ListenUDP(c conn, cfg Config) (*Table, error) { 231 tab, _, err := newUDP(c, cfg) 232 if err != nil { 233 return nil, err 234 } 235 log.Info("UDP listener up", "self", tab.self) 236 return tab, nil 237 } 238 239 func newUDP(c conn, cfg Config) (*Table, *udp, error) { 240 udp := &udp{ 241 conn: c, 242 priv: cfg.PrivateKey, 243 netrestrict: cfg.NetRestrict, 244 closing: make(chan struct{}), 245 gotreply: make(chan reply), 246 addpending: make(chan *pending), 247 } 248 realaddr := c.LocalAddr().(*net.UDPAddr) 249 if cfg.AnnounceAddr != nil { 250 realaddr = cfg.AnnounceAddr 251 } 252 // TODO: separate TCP port 253 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 254 tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) 255 if err != nil { 256 return nil, nil, err 257 } 258 udp.Table = tab 259 260 go udp.loop() 261 go udp.readLoop(cfg.Unhandled) 262 return udp.Table, udp, nil 263 } 264 265 func (t *udp) close() { 266 close(t.closing) 267 t.conn.Close() 268 // TODO: wait for the loops to end. 269 } 270 271 // ping sends a ping message to the given node and waits for a reply. 272 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 273 return <-t.sendPing(toid, toaddr, nil) 274 } 275 276 // sendPing sends a ping message to the given node and invokes the callback 277 // when the reply arrives. 278 func (t *udp) sendPing(toid NodeID, toaddr *net.UDPAddr, callback func()) <-chan error { 279 req := &ping{ 280 Version: 4, 281 From: t.ourEndpoint, 282 To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB 283 Expiration: uint64(time.Now().Add(expiration).Unix()), 284 } 285 packet, hash, err := encodePacket(t.priv, pingPacket, req) 286 if err != nil { 287 errc := make(chan error, 1) 288 errc <- err 289 return errc 290 } 291 errc := t.pending(toid, pongPacket, func(p interface{}) bool { 292 ok := bytes.Equal(p.(*pong).ReplyTok, hash) 293 if ok && callback != nil { 294 callback() 295 } 296 return ok 297 }) 298 t.write(toaddr, req.name(), packet) 299 return errc 300 } 301 302 func (t *udp) waitping(from NodeID) error { 303 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 304 } 305 306 // findnode sends a findnode request to the given node and waits until 307 // the node has sent up to k neighbors. 308 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 309 // If we haven't seen a ping from the destination node for a while, it won't remember 310 // our endpoint proof and reject findnode. Solicit a ping first. 311 if time.Since(t.db.lastPingReceived(toid)) > nodeDBNodeExpiration { 312 t.ping(toid, toaddr) 313 t.waitping(toid) 314 } 315 316 nodes := make([]*Node, 0, bucketSize) 317 nreceived := 0 318 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 319 reply := r.(*neighbors) 320 for _, rn := range reply.Nodes { 321 nreceived++ 322 n, err := t.nodeFromRPC(toaddr, rn) 323 if err != nil { 324 log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) 325 continue 326 } 327 nodes = append(nodes, n) 328 } 329 return nreceived >= bucketSize 330 }) 331 t.send(toaddr, findnodePacket, &findnode{ 332 Target: target, 333 Expiration: uint64(time.Now().Add(expiration).Unix()), 334 }) 335 return nodes, <-errc 336 } 337 338 // pending adds a reply callback to the pending reply queue. 339 // see the documentation of type pending for a detailed explanation. 340 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 341 ch := make(chan error, 1) 342 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 343 select { 344 case t.addpending <- p: 345 // loop will handle it 346 case <-t.closing: 347 ch <- errClosed 348 } 349 return ch 350 } 351 352 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 353 matched := make(chan bool, 1) 354 select { 355 case t.gotreply <- reply{from, ptype, req, matched}: 356 // loop will handle it 357 return <-matched 358 case <-t.closing: 359 return false 360 } 361 } 362 363 // loop runs in its own goroutine. it keeps track of 364 // the refresh timer and the pending reply queue. 365 func (t *udp) loop() { 366 var ( 367 plist = list.New() 368 timeout = time.NewTimer(0) 369 nextTimeout *pending // head of plist when timeout was last reset 370 contTimeouts = 0 // number of continuous timeouts to do NTP checks 371 ntpWarnTime = time.Unix(0, 0) 372 ) 373 <-timeout.C // ignore first timeout 374 defer timeout.Stop() 375 376 resetTimeout := func() { 377 if plist.Front() == nil || nextTimeout == plist.Front().Value { 378 return 379 } 380 // Start the timer so it fires when the next pending reply has expired. 381 now := time.Now() 382 for el := plist.Front(); el != nil; el = el.Next() { 383 nextTimeout = el.Value.(*pending) 384 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 385 timeout.Reset(dist) 386 return 387 } 388 // Remove pending replies whose deadline is too far in the 389 // future. These can occur if the system clock jumped 390 // backwards after the deadline was assigned. 391 nextTimeout.errc <- errClockWarp 392 plist.Remove(el) 393 } 394 nextTimeout = nil 395 timeout.Stop() 396 } 397 398 for { 399 resetTimeout() 400 401 select { 402 case <-t.closing: 403 for el := plist.Front(); el != nil; el = el.Next() { 404 el.Value.(*pending).errc <- errClosed 405 } 406 return 407 408 case p := <-t.addpending: 409 p.deadline = time.Now().Add(respTimeout) 410 plist.PushBack(p) 411 412 case r := <-t.gotreply: 413 var matched bool 414 for el := plist.Front(); el != nil; el = el.Next() { 415 p := el.Value.(*pending) 416 if p.from == r.from && p.ptype == r.ptype { 417 matched = true 418 // Remove the matcher if its callback indicates 419 // that all replies have been received. This is 420 // required for packet types that expect multiple 421 // reply packets. 422 if p.callback(r.data) { 423 p.errc <- nil 424 plist.Remove(el) 425 } 426 // Reset the continuous timeout counter (time drift detection) 427 contTimeouts = 0 428 } 429 } 430 r.matched <- matched 431 432 case now := <-timeout.C: 433 nextTimeout = nil 434 435 // Notify and remove callbacks whose deadline is in the past. 436 for el := plist.Front(); el != nil; el = el.Next() { 437 p := el.Value.(*pending) 438 if now.After(p.deadline) || now.Equal(p.deadline) { 439 p.errc <- errTimeout 440 plist.Remove(el) 441 contTimeouts++ 442 } 443 } 444 // If we've accumulated too many timeouts, do an NTP time sync check 445 if contTimeouts > ntpFailureThreshold { 446 if time.Since(ntpWarnTime) >= ntpWarningCooldown { 447 ntpWarnTime = time.Now() 448 go checkClockDrift() 449 } 450 contTimeouts = 0 451 } 452 } 453 } 454 } 455 456 const ( 457 macSize = 256 / 8 458 sigSize = 520 / 8 459 headSize = macSize + sigSize // space of packet frame data 460 ) 461 462 var ( 463 headSpace = make([]byte, headSize) 464 465 // Neighbors replies are sent across multiple packets to 466 // stay below the 1280 byte limit. We compute the maximum number 467 // of entries by stuffing a packet until it grows too large. 468 maxNeighbors int 469 ) 470 471 func init() { 472 p := neighbors{Expiration: ^uint64(0)} 473 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} 474 for n := 0; ; n++ { 475 p.Nodes = append(p.Nodes, maxSizeNode) 476 size, _, err := rlp.EncodeToReader(p) 477 if err != nil { 478 // If this ever happens, it will be caught by the unit tests. 479 panic("cannot encode: " + err.Error()) 480 } 481 if headSize+size+1 >= 1280 { 482 maxNeighbors = n 483 break 484 } 485 } 486 } 487 488 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { 489 packet, hash, err := encodePacket(t.priv, ptype, req) 490 if err != nil { 491 return hash, err 492 } 493 return hash, t.write(toaddr, req.name(), packet) 494 } 495 496 func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { 497 _, err := t.conn.WriteToUDP(packet, toaddr) 498 log.Trace(">> "+what, "addr", toaddr, "err", err) 499 return err 500 } 501 502 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { 503 b := new(bytes.Buffer) 504 b.Write(headSpace) 505 b.WriteByte(ptype) 506 if err := rlp.Encode(b, req); err != nil { 507 log.Error("Can't encode discv4 packet", "err", err) 508 return nil, nil, err 509 } 510 packet = b.Bytes() 511 sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) 512 if err != nil { 513 log.Error("Can't sign discv4 packet", "err", err) 514 return nil, nil, err 515 } 516 copy(packet[macSize:], sig) 517 // add the hash to the front. Note: this doesn't protect the 518 // packet in any way. Our public key will be part of this hash in 519 // The future. 520 hash = crypto.Keccak256(packet[macSize:]) 521 copy(packet, hash) 522 return packet, hash, nil 523 } 524 525 // readLoop runs in its own goroutine. it handles incoming UDP packets. 526 func (t *udp) readLoop(unhandled chan<- ReadPacket) { 527 defer t.conn.Close() 528 if unhandled != nil { 529 defer close(unhandled) 530 } 531 // Discovery packets are defined to be no larger than 1280 bytes. 532 // Packets larger than this size will be cut at the end and treated 533 // as invalid because their hash won't match. 534 buf := make([]byte, 1280) 535 for { 536 nbytes, from, err := t.conn.ReadFromUDP(buf) 537 if netutil.IsTemporaryError(err) { 538 // Ignore temporary read errors. 539 log.Debug("Temporary UDP read error", "err", err) 540 continue 541 } else if err != nil { 542 // Shut down the loop for permament errors. 543 log.Debug("UDP read error", "err", err) 544 return 545 } 546 if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { 547 select { 548 case unhandled <- ReadPacket{buf[:nbytes], from}: 549 default: 550 } 551 } 552 } 553 } 554 555 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 556 packet, fromID, hash, err := decodePacket(buf) 557 if err != nil { 558 log.Debug("Bad discv4 packet", "addr", from, "err", err) 559 return err 560 } 561 err = packet.handle(t, from, fromID, hash) 562 log.Trace("<< "+packet.name(), "addr", from, "err", err) 563 return err 564 } 565 566 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 567 if len(buf) < headSize+1 { 568 return nil, NodeID{}, nil, errPacketTooSmall 569 } 570 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 571 shouldhash := crypto.Keccak256(buf[macSize:]) 572 if !bytes.Equal(hash, shouldhash) { 573 return nil, NodeID{}, nil, errBadHash 574 } 575 fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) 576 if err != nil { 577 return nil, NodeID{}, hash, err 578 } 579 var req packet 580 switch ptype := sigdata[0]; ptype { 581 case pingPacket: 582 req = new(ping) 583 case pongPacket: 584 req = new(pong) 585 case findnodePacket: 586 req = new(findnode) 587 case neighborsPacket: 588 req = new(neighbors) 589 default: 590 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 591 } 592 s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) 593 err = s.Decode(req) 594 return req, fromID, hash, err 595 } 596 597 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 598 if expired(req.Expiration) { 599 return errExpired 600 } 601 t.send(from, pongPacket, &pong{ 602 To: makeEndpoint(from, req.From.TCP), 603 ReplyTok: mac, 604 Expiration: uint64(time.Now().Add(expiration).Unix()), 605 }) 606 t.handleReply(fromID, pingPacket, req) 607 608 // Add the node to the table. Before doing so, ensure that we have a recent enough pong 609 // recorded in the database so their findnode requests will be accepted later. 610 n := NewNode(fromID, from.IP, uint16(from.Port), req.From.TCP) 611 if time.Since(t.db.lastPongReceived(fromID)) > nodeDBNodeExpiration { 612 t.sendPing(fromID, from, func() { t.addThroughPing(n) }) 613 } else { 614 t.addThroughPing(n) 615 } 616 t.db.updateLastPingReceived(fromID, time.Now()) 617 return nil 618 } 619 620 func (req *ping) name() string { return "PING/v4" } 621 622 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 623 if expired(req.Expiration) { 624 return errExpired 625 } 626 if !t.handleReply(fromID, pongPacket, req) { 627 return errUnsolicitedReply 628 } 629 t.db.updateLastPongReceived(fromID, time.Now()) 630 return nil 631 } 632 633 func (req *pong) name() string { return "PONG/v4" } 634 635 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 636 if expired(req.Expiration) { 637 return errExpired 638 } 639 if !t.db.hasBond(fromID) { 640 // No endpoint proof pong exists, we don't process the packet. This prevents an 641 // attack vector where the discovery protocol could be used to amplify traffic in a 642 // DDOS attack. A malicious actor would send a findnode request with the IP address 643 // and UDP port of the target as the source address. The recipient of the findnode 644 // packet would then send a neighbors packet (which is a much bigger packet than 645 // findnode) to the victim. 646 return errUnknownNode 647 } 648 target := crypto.Keccak256Hash(req.Target[:]) 649 t.mutex.Lock() 650 closest := t.closest(target, bucketSize).entries 651 t.mutex.Unlock() 652 653 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 654 var sent bool 655 // Send neighbors in chunks with at most maxNeighbors per packet 656 // to stay below the 1280 byte limit. 657 for _, n := range closest { 658 if netutil.CheckRelayIP(from.IP, n.IP) == nil { 659 p.Nodes = append(p.Nodes, nodeToRPC(n)) 660 } 661 if len(p.Nodes) == maxNeighbors { 662 t.send(from, neighborsPacket, &p) 663 p.Nodes = p.Nodes[:0] 664 sent = true 665 } 666 } 667 if len(p.Nodes) > 0 || !sent { 668 t.send(from, neighborsPacket, &p) 669 } 670 return nil 671 } 672 673 func (req *findnode) name() string { return "FINDNODE/v4" } 674 675 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 676 if expired(req.Expiration) { 677 return errExpired 678 } 679 if !t.handleReply(fromID, neighborsPacket, req) { 680 return errUnsolicitedReply 681 } 682 return nil 683 } 684 685 func (req *neighbors) name() string { return "NEIGHBORS/v4" } 686 687 func expired(ts uint64) bool { 688 return time.Unix(int64(ts), 0).Before(time.Now()) 689 }