github.com/samgwo/go-ethereum@v1.8.2-0.20180302101319-49bcb5fbd55e/p2p/discover/udp.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package discover 18 19 import ( 20 "bytes" 21 "container/list" 22 "crypto/ecdsa" 23 "errors" 24 "fmt" 25 "net" 26 "time" 27 28 "github.com/ethereum/go-ethereum/crypto" 29 "github.com/ethereum/go-ethereum/log" 30 "github.com/ethereum/go-ethereum/p2p/nat" 31 "github.com/ethereum/go-ethereum/p2p/netutil" 32 "github.com/ethereum/go-ethereum/rlp" 33 ) 34 35 const Version = 4 36 37 // Errors 38 var ( 39 errPacketTooSmall = errors.New("too small") 40 errBadHash = errors.New("bad hash") 41 errExpired = errors.New("expired") 42 errUnsolicitedReply = errors.New("unsolicited reply") 43 errUnknownNode = errors.New("unknown node") 44 errTimeout = errors.New("RPC timeout") 45 errClockWarp = errors.New("reply deadline too far in the future") 46 errClosed = errors.New("socket closed") 47 ) 48 49 // Timeouts 50 const ( 51 respTimeout = 500 * time.Millisecond 52 sendTimeout = 500 * time.Millisecond 53 expiration = 20 * time.Second 54 55 ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP 56 ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning 57 driftThreshold = 10 * time.Second // Allowed clock drift before warning user 58 ) 59 60 // RPC packet types 61 const ( 62 pingPacket = iota + 1 // zero is 'reserved' 63 pongPacket 64 findnodePacket 65 neighborsPacket 66 ) 67 68 // RPC request structures 69 type ( 70 ping struct { 71 Version uint 72 From, To rpcEndpoint 73 Expiration uint64 74 // Ignore additional fields (for forward compatibility). 75 Rest []rlp.RawValue `rlp:"tail"` 76 } 77 78 // pong is the reply to ping. 79 pong struct { 80 // This field should mirror the UDP envelope address 81 // of the ping packet, which provides a way to discover the 82 // the external address (after NAT). 83 To rpcEndpoint 84 85 ReplyTok []byte // This contains the hash of the ping packet. 86 Expiration uint64 // Absolute timestamp at which the packet becomes invalid. 87 // Ignore additional fields (for forward compatibility). 88 Rest []rlp.RawValue `rlp:"tail"` 89 } 90 91 // findnode is a query for nodes close to the given target. 92 findnode struct { 93 Target NodeID // doesn't need to be an actual public key 94 Expiration uint64 95 // Ignore additional fields (for forward compatibility). 96 Rest []rlp.RawValue `rlp:"tail"` 97 } 98 99 // reply to findnode 100 neighbors struct { 101 Nodes []rpcNode 102 Expiration uint64 103 // Ignore additional fields (for forward compatibility). 104 Rest []rlp.RawValue `rlp:"tail"` 105 } 106 107 rpcNode struct { 108 IP net.IP // len 4 for IPv4 or 16 for IPv6 109 UDP uint16 // for discovery protocol 110 TCP uint16 // for RLPx protocol 111 ID NodeID 112 } 113 114 rpcEndpoint struct { 115 IP net.IP // len 4 for IPv4 or 16 for IPv6 116 UDP uint16 // for discovery protocol 117 TCP uint16 // for RLPx protocol 118 } 119 ) 120 121 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 122 ip := addr.IP.To4() 123 if ip == nil { 124 ip = addr.IP.To16() 125 } 126 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 127 } 128 129 func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { 130 if rn.UDP <= 1024 { 131 return nil, errors.New("low port") 132 } 133 if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { 134 return nil, err 135 } 136 if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { 137 return nil, errors.New("not contained in netrestrict whitelist") 138 } 139 n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) 140 err := n.validateComplete() 141 return n, err 142 } 143 144 func nodeToRPC(n *Node) rpcNode { 145 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 146 } 147 148 type packet interface { 149 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 150 name() string 151 } 152 153 type conn interface { 154 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 155 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 156 Close() error 157 LocalAddr() net.Addr 158 } 159 160 // udp implements the RPC protocol. 161 type udp struct { 162 conn conn 163 netrestrict *netutil.Netlist 164 priv *ecdsa.PrivateKey 165 ourEndpoint rpcEndpoint 166 167 addpending chan *pending 168 gotreply chan reply 169 170 closing chan struct{} 171 nat nat.Interface 172 173 *Table 174 } 175 176 // pending represents a pending reply. 177 // 178 // some implementations of the protocol wish to send more than one 179 // reply packet to findnode. in general, any neighbors packet cannot 180 // be matched up with a specific findnode packet. 181 // 182 // our implementation handles this by storing a callback function for 183 // each pending reply. incoming packets from a node are dispatched 184 // to all the callback functions for that node. 185 type pending struct { 186 // these fields must match in the reply. 187 from NodeID 188 ptype byte 189 190 // time when the request must complete 191 deadline time.Time 192 193 // callback is called when a matching reply arrives. if it returns 194 // true, the callback is removed from the pending reply queue. 195 // if it returns false, the reply is considered incomplete and 196 // the callback will be invoked again for the next matching reply. 197 callback func(resp interface{}) (done bool) 198 199 // errc receives nil when the callback indicates completion or an 200 // error if no further reply is received within the timeout. 201 errc chan<- error 202 } 203 204 type reply struct { 205 from NodeID 206 ptype byte 207 data interface{} 208 // loop indicates whether there was 209 // a matching request by sending on this channel. 210 matched chan<- bool 211 } 212 213 // ReadPacket is sent to the unhandled channel when it could not be processed 214 type ReadPacket struct { 215 Data []byte 216 Addr *net.UDPAddr 217 } 218 219 // Config holds Table-related settings. 220 type Config struct { 221 // These settings are required and configure the UDP listener: 222 PrivateKey *ecdsa.PrivateKey 223 224 // These settings are optional: 225 AnnounceAddr *net.UDPAddr // local address announced in the DHT 226 NodeDBPath string // if set, the node database is stored at this filesystem location 227 NetRestrict *netutil.Netlist // network whitelist 228 Bootnodes []*Node // list of bootstrap nodes 229 Unhandled chan<- ReadPacket // unhandled packets are sent on this channel 230 } 231 232 // ListenUDP returns a new table that listens for UDP packets on laddr. 233 func ListenUDP(c conn, cfg Config) (*Table, error) { 234 tab, _, err := newUDP(c, cfg) 235 if err != nil { 236 return nil, err 237 } 238 log.Info("UDP listener up", "self", tab.self) 239 return tab, nil 240 } 241 242 func newUDP(c conn, cfg Config) (*Table, *udp, error) { 243 udp := &udp{ 244 conn: c, 245 priv: cfg.PrivateKey, 246 netrestrict: cfg.NetRestrict, 247 closing: make(chan struct{}), 248 gotreply: make(chan reply), 249 addpending: make(chan *pending), 250 } 251 realaddr := c.LocalAddr().(*net.UDPAddr) 252 if cfg.AnnounceAddr != nil { 253 realaddr = cfg.AnnounceAddr 254 } 255 // TODO: separate TCP port 256 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 257 tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) 258 if err != nil { 259 return nil, nil, err 260 } 261 udp.Table = tab 262 263 go udp.loop() 264 go udp.readLoop(cfg.Unhandled) 265 return udp.Table, udp, nil 266 } 267 268 func (t *udp) close() { 269 close(t.closing) 270 t.conn.Close() 271 // TODO: wait for the loops to end. 272 } 273 274 // ping sends a ping message to the given node and waits for a reply. 275 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 276 req := &ping{ 277 Version: Version, 278 From: t.ourEndpoint, 279 To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB 280 Expiration: uint64(time.Now().Add(expiration).Unix()), 281 } 282 packet, hash, err := encodePacket(t.priv, pingPacket, req) 283 if err != nil { 284 return err 285 } 286 errc := t.pending(toid, pongPacket, func(p interface{}) bool { 287 return bytes.Equal(p.(*pong).ReplyTok, hash) 288 }) 289 t.write(toaddr, req.name(), packet) 290 return <-errc 291 } 292 293 func (t *udp) waitping(from NodeID) error { 294 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 295 } 296 297 // findnode sends a findnode request to the given node and waits until 298 // the node has sent up to k neighbors. 299 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 300 nodes := make([]*Node, 0, bucketSize) 301 nreceived := 0 302 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 303 reply := r.(*neighbors) 304 for _, rn := range reply.Nodes { 305 nreceived++ 306 n, err := t.nodeFromRPC(toaddr, rn) 307 if err != nil { 308 log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) 309 continue 310 } 311 nodes = append(nodes, n) 312 } 313 return nreceived >= bucketSize 314 }) 315 t.send(toaddr, findnodePacket, &findnode{ 316 Target: target, 317 Expiration: uint64(time.Now().Add(expiration).Unix()), 318 }) 319 err := <-errc 320 return nodes, err 321 } 322 323 // pending adds a reply callback to the pending reply queue. 324 // see the documentation of type pending for a detailed explanation. 325 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 326 ch := make(chan error, 1) 327 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 328 select { 329 case t.addpending <- p: 330 // loop will handle it 331 case <-t.closing: 332 ch <- errClosed 333 } 334 return ch 335 } 336 337 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 338 matched := make(chan bool, 1) 339 select { 340 case t.gotreply <- reply{from, ptype, req, matched}: 341 // loop will handle it 342 return <-matched 343 case <-t.closing: 344 return false 345 } 346 } 347 348 // loop runs in its own goroutine. it keeps track of 349 // the refresh timer and the pending reply queue. 350 func (t *udp) loop() { 351 var ( 352 plist = list.New() 353 timeout = time.NewTimer(0) 354 nextTimeout *pending // head of plist when timeout was last reset 355 contTimeouts = 0 // number of continuous timeouts to do NTP checks 356 ntpWarnTime = time.Unix(0, 0) 357 ) 358 <-timeout.C // ignore first timeout 359 defer timeout.Stop() 360 361 resetTimeout := func() { 362 if plist.Front() == nil || nextTimeout == plist.Front().Value { 363 return 364 } 365 // Start the timer so it fires when the next pending reply has expired. 366 now := time.Now() 367 for el := plist.Front(); el != nil; el = el.Next() { 368 nextTimeout = el.Value.(*pending) 369 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 370 timeout.Reset(dist) 371 return 372 } 373 // Remove pending replies whose deadline is too far in the 374 // future. These can occur if the system clock jumped 375 // backwards after the deadline was assigned. 376 nextTimeout.errc <- errClockWarp 377 plist.Remove(el) 378 } 379 nextTimeout = nil 380 timeout.Stop() 381 } 382 383 for { 384 resetTimeout() 385 386 select { 387 case <-t.closing: 388 for el := plist.Front(); el != nil; el = el.Next() { 389 el.Value.(*pending).errc <- errClosed 390 } 391 return 392 393 case p := <-t.addpending: 394 p.deadline = time.Now().Add(respTimeout) 395 plist.PushBack(p) 396 397 case r := <-t.gotreply: 398 var matched bool 399 for el := plist.Front(); el != nil; el = el.Next() { 400 p := el.Value.(*pending) 401 if p.from == r.from && p.ptype == r.ptype { 402 matched = true 403 // Remove the matcher if its callback indicates 404 // that all replies have been received. This is 405 // required for packet types that expect multiple 406 // reply packets. 407 if p.callback(r.data) { 408 p.errc <- nil 409 plist.Remove(el) 410 } 411 // Reset the continuous timeout counter (time drift detection) 412 contTimeouts = 0 413 } 414 } 415 r.matched <- matched 416 417 case now := <-timeout.C: 418 nextTimeout = nil 419 420 // Notify and remove callbacks whose deadline is in the past. 421 for el := plist.Front(); el != nil; el = el.Next() { 422 p := el.Value.(*pending) 423 if now.After(p.deadline) || now.Equal(p.deadline) { 424 p.errc <- errTimeout 425 plist.Remove(el) 426 contTimeouts++ 427 } 428 } 429 // If we've accumulated too many timeouts, do an NTP time sync check 430 if contTimeouts > ntpFailureThreshold { 431 if time.Since(ntpWarnTime) >= ntpWarningCooldown { 432 ntpWarnTime = time.Now() 433 go checkClockDrift() 434 } 435 contTimeouts = 0 436 } 437 } 438 } 439 } 440 441 const ( 442 macSize = 256 / 8 443 sigSize = 520 / 8 444 headSize = macSize + sigSize // space of packet frame data 445 ) 446 447 var ( 448 headSpace = make([]byte, headSize) 449 450 // Neighbors replies are sent across multiple packets to 451 // stay below the 1280 byte limit. We compute the maximum number 452 // of entries by stuffing a packet until it grows too large. 453 maxNeighbors int 454 ) 455 456 func init() { 457 p := neighbors{Expiration: ^uint64(0)} 458 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} 459 for n := 0; ; n++ { 460 p.Nodes = append(p.Nodes, maxSizeNode) 461 size, _, err := rlp.EncodeToReader(p) 462 if err != nil { 463 // If this ever happens, it will be caught by the unit tests. 464 panic("cannot encode: " + err.Error()) 465 } 466 if headSize+size+1 >= 1280 { 467 maxNeighbors = n 468 break 469 } 470 } 471 } 472 473 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { 474 packet, hash, err := encodePacket(t.priv, ptype, req) 475 if err != nil { 476 return hash, err 477 } 478 return hash, t.write(toaddr, req.name(), packet) 479 } 480 481 func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { 482 _, err := t.conn.WriteToUDP(packet, toaddr) 483 log.Trace(">> "+what, "addr", toaddr, "err", err) 484 return err 485 } 486 487 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { 488 b := new(bytes.Buffer) 489 b.Write(headSpace) 490 b.WriteByte(ptype) 491 if err := rlp.Encode(b, req); err != nil { 492 log.Error("Can't encode discv4 packet", "err", err) 493 return nil, nil, err 494 } 495 packet = b.Bytes() 496 sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) 497 if err != nil { 498 log.Error("Can't sign discv4 packet", "err", err) 499 return nil, nil, err 500 } 501 copy(packet[macSize:], sig) 502 // add the hash to the front. Note: this doesn't protect the 503 // packet in any way. Our public key will be part of this hash in 504 // The future. 505 hash = crypto.Keccak256(packet[macSize:]) 506 copy(packet, hash) 507 return packet, hash, nil 508 } 509 510 // readLoop runs in its own goroutine. it handles incoming UDP packets. 511 func (t *udp) readLoop(unhandled chan<- ReadPacket) { 512 defer t.conn.Close() 513 if unhandled != nil { 514 defer close(unhandled) 515 } 516 // Discovery packets are defined to be no larger than 1280 bytes. 517 // Packets larger than this size will be cut at the end and treated 518 // as invalid because their hash won't match. 519 buf := make([]byte, 1280) 520 for { 521 nbytes, from, err := t.conn.ReadFromUDP(buf) 522 if netutil.IsTemporaryError(err) { 523 // Ignore temporary read errors. 524 log.Debug("Temporary UDP read error", "err", err) 525 continue 526 } else if err != nil { 527 // Shut down the loop for permament errors. 528 log.Debug("UDP read error", "err", err) 529 return 530 } 531 if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { 532 select { 533 case unhandled <- ReadPacket{buf[:nbytes], from}: 534 default: 535 } 536 } 537 } 538 } 539 540 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 541 packet, fromID, hash, err := decodePacket(buf) 542 if err != nil { 543 log.Debug("Bad discv4 packet", "addr", from, "err", err) 544 return err 545 } 546 err = packet.handle(t, from, fromID, hash) 547 log.Trace("<< "+packet.name(), "addr", from, "err", err) 548 return err 549 } 550 551 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 552 if len(buf) < headSize+1 { 553 return nil, NodeID{}, nil, errPacketTooSmall 554 } 555 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 556 shouldhash := crypto.Keccak256(buf[macSize:]) 557 if !bytes.Equal(hash, shouldhash) { 558 return nil, NodeID{}, nil, errBadHash 559 } 560 fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) 561 if err != nil { 562 return nil, NodeID{}, hash, err 563 } 564 var req packet 565 switch ptype := sigdata[0]; ptype { 566 case pingPacket: 567 req = new(ping) 568 case pongPacket: 569 req = new(pong) 570 case findnodePacket: 571 req = new(findnode) 572 case neighborsPacket: 573 req = new(neighbors) 574 default: 575 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 576 } 577 s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) 578 err = s.Decode(req) 579 return req, fromID, hash, err 580 } 581 582 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 583 if expired(req.Expiration) { 584 return errExpired 585 } 586 t.send(from, pongPacket, &pong{ 587 To: makeEndpoint(from, req.From.TCP), 588 ReplyTok: mac, 589 Expiration: uint64(time.Now().Add(expiration).Unix()), 590 }) 591 if !t.handleReply(fromID, pingPacket, req) { 592 // Note: we're ignoring the provided IP address right now 593 go t.bond(true, fromID, from, req.From.TCP) 594 } 595 return nil 596 } 597 598 func (req *ping) name() string { return "PING/v4" } 599 600 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 601 if expired(req.Expiration) { 602 return errExpired 603 } 604 if !t.handleReply(fromID, pongPacket, req) { 605 return errUnsolicitedReply 606 } 607 return nil 608 } 609 610 func (req *pong) name() string { return "PONG/v4" } 611 612 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 613 if expired(req.Expiration) { 614 return errExpired 615 } 616 if !t.db.hasBond(fromID) { 617 // No bond exists, we don't process the packet. This prevents 618 // an attack vector where the discovery protocol could be used 619 // to amplify traffic in a DDOS attack. A malicious actor 620 // would send a findnode request with the IP address and UDP 621 // port of the target as the source address. The recipient of 622 // the findnode packet would then send a neighbors packet 623 // (which is a much bigger packet than findnode) to the victim. 624 return errUnknownNode 625 } 626 target := crypto.Keccak256Hash(req.Target[:]) 627 t.mutex.Lock() 628 closest := t.closest(target, bucketSize).entries 629 t.mutex.Unlock() 630 631 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 632 var sent bool 633 // Send neighbors in chunks with at most maxNeighbors per packet 634 // to stay below the 1280 byte limit. 635 for _, n := range closest { 636 if netutil.CheckRelayIP(from.IP, n.IP) == nil { 637 p.Nodes = append(p.Nodes, nodeToRPC(n)) 638 } 639 if len(p.Nodes) == maxNeighbors { 640 t.send(from, neighborsPacket, &p) 641 p.Nodes = p.Nodes[:0] 642 sent = true 643 } 644 } 645 if len(p.Nodes) > 0 || !sent { 646 t.send(from, neighborsPacket, &p) 647 } 648 return nil 649 } 650 651 func (req *findnode) name() string { return "FINDNODE/v4" } 652 653 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 654 if expired(req.Expiration) { 655 return errExpired 656 } 657 if !t.handleReply(fromID, neighborsPacket, req) { 658 return errUnsolicitedReply 659 } 660 return nil 661 } 662 663 func (req *neighbors) name() string { return "NEIGHBORS/v4" } 664 665 func expired(ts uint64) bool { 666 return time.Unix(int64(ts), 0).Before(time.Now()) 667 }