github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/discover/udp.go (about) 1 package discover 2 3 import ( 4 "bytes" 5 "container/list" 6 "crypto/ecdsa" 7 "errors" 8 "fmt" 9 "net" 10 "time" 11 12 "github.com/quickchainproject/quickchain/crypto" 13 "github.com/quickchainproject/quickchain/log" 14 "github.com/quickchainproject/quickchain/p2p/nat" 15 "github.com/quickchainproject/quickchain/p2p/netutil" 16 "github.com/quickchainproject/quickchain/rlp" 17 ) 18 19 const Version = 4 20 21 // Errors 22 var ( 23 errPacketTooSmall = errors.New("too small") 24 errBadHash = errors.New("bad hash") 25 errExpired = errors.New("expired") 26 errUnsolicitedReply = errors.New("unsolicited reply") 27 errUnknownNode = errors.New("unknown node") 28 errTimeout = errors.New("RPC timeout") 29 errClockWarp = errors.New("reply deadline too far in the future") 30 errClosed = errors.New("socket closed") 31 ) 32 33 // Timeouts 34 const ( 35 respTimeout = 500 * time.Millisecond 36 expiration = 20 * time.Second 37 38 ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP 39 ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning 40 driftThreshold = 10 * time.Second // Allowed clock drift before warning user 41 ) 42 43 // RPC packet types 44 const ( 45 //pingPacket = iota + 1 // zero is 'reserved' 46 pingPacket = iota + 100 // insulate from quickchain network 47 pongPacket 48 findnodePacket 49 neighborsPacket 50 ) 51 52 // RPC request structures 53 type ( 54 ping struct { 55 Version uint 56 From, To rpcEndpoint 57 Expiration uint64 58 // Ignore additional fields (for forward compatibility). 59 Rest []rlp.RawValue `rlp:"tail"` 60 } 61 62 // pong is the reply to ping. 63 pong struct { 64 // This field should mirror the UDP envelope address 65 // of the ping packet, which provides a way to discover the 66 // the external address (after NAT). 67 To rpcEndpoint 68 69 ReplyTok []byte // This contains the hash of the ping packet. 70 Expiration uint64 // Absolute timestamp at which the packet becomes invalid. 71 // Ignore additional fields (for forward compatibility). 72 Rest []rlp.RawValue `rlp:"tail"` 73 } 74 75 // findnode is a query for nodes close to the given target. 76 findnode struct { 77 Target NodeID // doesn't need to be an actual public key 78 Expiration uint64 79 // Ignore additional fields (for forward compatibility). 80 Rest []rlp.RawValue `rlp:"tail"` 81 } 82 83 // reply to findnode 84 neighbors struct { 85 Nodes []rpcNode 86 Expiration uint64 87 // Ignore additional fields (for forward compatibility). 88 Rest []rlp.RawValue `rlp:"tail"` 89 } 90 91 rpcNode struct { 92 IP net.IP // len 4 for IPv4 or 16 for IPv6 93 UDP uint16 // for discovery protocol 94 TCP uint16 // for RLPx protocol 95 ID NodeID 96 } 97 98 rpcEndpoint struct { 99 IP net.IP // len 4 for IPv4 or 16 for IPv6 100 UDP uint16 // for discovery protocol 101 TCP uint16 // for RLPx protocol 102 } 103 ) 104 105 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 106 ip := addr.IP.To4() 107 if ip == nil { 108 ip = addr.IP.To16() 109 } 110 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 111 } 112 113 func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { 114 if rn.UDP <= 1024 { 115 return nil, errors.New("low port") 116 } 117 if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { 118 return nil, err 119 } 120 if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { 121 return nil, errors.New("not contained in netrestrict whitelist") 122 } 123 n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) 124 err := n.validateComplete() 125 return n, err 126 } 127 128 func nodeToRPC(n *Node) rpcNode { 129 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 130 } 131 132 type packet interface { 133 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 134 name() string 135 } 136 137 type conn interface { 138 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 139 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 140 Close() error 141 LocalAddr() net.Addr 142 } 143 144 // udp implements the RPC protocol. 145 type udp struct { 146 conn conn 147 netrestrict *netutil.Netlist 148 priv *ecdsa.PrivateKey 149 ourEndpoint rpcEndpoint 150 151 addpending chan *pending 152 gotreply chan reply 153 154 closing chan struct{} 155 nat nat.Interface 156 157 *Table 158 } 159 160 // pending represents a pending reply. 161 // 162 // some implementations of the protocol wish to send more than one 163 // reply packet to findnode. in general, any neighbors packet cannot 164 // be matched up with a specific findnode packet. 165 // 166 // our implementation handles this by storing a callback function for 167 // each pending reply. incoming packets from a node are dispatched 168 // to all the callback functions for that node. 169 type pending struct { 170 // these fields must match in the reply. 171 from NodeID 172 ptype byte 173 174 // time when the request must complete 175 deadline time.Time 176 177 // callback is called when a matching reply arrives. if it returns 178 // true, the callback is removed from the pending reply queue. 179 // if it returns false, the reply is considered incomplete and 180 // the callback will be invoked again for the next matching reply. 181 callback func(resp interface{}) (done bool) 182 183 // errc receives nil when the callback indicates completion or an 184 // error if no further reply is received within the timeout. 185 errc chan<- error 186 } 187 188 type reply struct { 189 from NodeID 190 ptype byte 191 data interface{} 192 // loop indicates whether there was 193 // a matching request by sending on this channel. 194 matched chan<- bool 195 } 196 197 // ReadPacket is sent to the unhandled channel when it could not be processed 198 type ReadPacket struct { 199 Data []byte 200 Addr *net.UDPAddr 201 } 202 203 // Config holds Table-related settings. 204 type Config struct { 205 // These settings are required and configure the UDP listener: 206 PrivateKey *ecdsa.PrivateKey 207 208 // These settings are optional: 209 AnnounceAddr *net.UDPAddr // local address announced in the DHT 210 NodeDBPath string // if set, the node database is stored at this filesystem location 211 NetRestrict *netutil.Netlist // network whitelist 212 Bootnodes []*Node // list of bootstrap nodes 213 Unhandled chan<- ReadPacket // unhandled packets are sent on this channel 214 } 215 216 // ListenUDP returns a new table that listens for UDP packets on laddr. 217 func ListenUDP(c conn, cfg Config) (*Table, error) { 218 tab, _, err := newUDP(c, cfg) 219 if err != nil { 220 return nil, err 221 } 222 log.Info("UDP listener up", "self", tab.self) 223 return tab, nil 224 } 225 226 func newUDP(c conn, cfg Config) (*Table, *udp, error) { 227 udp := &udp{ 228 conn: c, 229 priv: cfg.PrivateKey, 230 netrestrict: cfg.NetRestrict, 231 closing: make(chan struct{}), 232 gotreply: make(chan reply), 233 addpending: make(chan *pending), 234 } 235 realaddr := c.LocalAddr().(*net.UDPAddr) 236 if cfg.AnnounceAddr != nil { 237 realaddr = cfg.AnnounceAddr 238 } 239 // TODO: separate TCP port 240 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 241 tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) 242 if err != nil { 243 return nil, nil, err 244 } 245 udp.Table = tab 246 247 go udp.loop() 248 go udp.readLoop(cfg.Unhandled) 249 return udp.Table, udp, nil 250 } 251 252 func (t *udp) close() { 253 close(t.closing) 254 t.conn.Close() 255 // TODO: wait for the loops to end. 256 } 257 258 // ping sends a ping message to the given node and waits for a reply. 259 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 260 req := &ping{ 261 Version: Version, 262 From: t.ourEndpoint, 263 To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB 264 Expiration: uint64(time.Now().Add(expiration).Unix()), 265 } 266 packet, hash, err := encodePacket(t.priv, pingPacket, req) 267 if err != nil { 268 return err 269 } 270 errc := t.pending(toid, pongPacket, func(p interface{}) bool { 271 return bytes.Equal(p.(*pong).ReplyTok, hash) 272 }) 273 t.write(toaddr, req.name(), packet) 274 return <-errc 275 } 276 277 func (t *udp) waitping(from NodeID) error { 278 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 279 } 280 281 // findnode sends a findnode request to the given node and waits until 282 // the node has sent up to k neighbors. 283 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 284 nodes := make([]*Node, 0, bucketSize) 285 nreceived := 0 286 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 287 reply := r.(*neighbors) 288 for _, rn := range reply.Nodes { 289 nreceived++ 290 n, err := t.nodeFromRPC(toaddr, rn) 291 if err != nil { 292 log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) 293 continue 294 } 295 nodes = append(nodes, n) 296 } 297 return nreceived >= bucketSize 298 }) 299 t.send(toaddr, findnodePacket, &findnode{ 300 Target: target, 301 Expiration: uint64(time.Now().Add(expiration).Unix()), 302 }) 303 err := <-errc 304 return nodes, err 305 } 306 307 // pending adds a reply callback to the pending reply queue. 308 // see the documentation of type pending for a detailed explanation. 309 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 310 ch := make(chan error, 1) 311 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 312 select { 313 case t.addpending <- p: 314 // loop will handle it 315 case <-t.closing: 316 ch <- errClosed 317 } 318 return ch 319 } 320 321 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 322 matched := make(chan bool, 1) 323 select { 324 case t.gotreply <- reply{from, ptype, req, matched}: 325 // loop will handle it 326 return <-matched 327 case <-t.closing: 328 return false 329 } 330 } 331 332 // loop runs in its own goroutine. it keeps track of 333 // the refresh timer and the pending reply queue. 334 func (t *udp) loop() { 335 var ( 336 plist = list.New() 337 timeout = time.NewTimer(0) 338 nextTimeout *pending // head of plist when timeout was last reset 339 contTimeouts = 0 // number of continuous timeouts to do NTP checks 340 ntpWarnTime = time.Unix(0, 0) 341 ) 342 <-timeout.C // ignore first timeout 343 defer timeout.Stop() 344 345 resetTimeout := func() { 346 if plist.Front() == nil || nextTimeout == plist.Front().Value { 347 return 348 } 349 // Start the timer so it fires when the next pending reply has expired. 350 now := time.Now() 351 for el := plist.Front(); el != nil; el = el.Next() { 352 nextTimeout = el.Value.(*pending) 353 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 354 timeout.Reset(dist) 355 return 356 } 357 // Remove pending replies whose deadline is too far in the 358 // future. These can occur if the system clock jumped 359 // backwards after the deadline was assigned. 360 nextTimeout.errc <- errClockWarp 361 plist.Remove(el) 362 } 363 nextTimeout = nil 364 timeout.Stop() 365 } 366 367 for { 368 resetTimeout() 369 370 select { 371 case <-t.closing: 372 for el := plist.Front(); el != nil; el = el.Next() { 373 el.Value.(*pending).errc <- errClosed 374 } 375 return 376 377 case p := <-t.addpending: 378 p.deadline = time.Now().Add(respTimeout) 379 plist.PushBack(p) 380 381 case r := <-t.gotreply: 382 var matched bool 383 for el := plist.Front(); el != nil; el = el.Next() { 384 p := el.Value.(*pending) 385 if p.from == r.from && p.ptype == r.ptype { 386 matched = true 387 // Remove the matcher if its callback indicates 388 // that all replies have been received. This is 389 // required for packet types that expect multiple 390 // reply packets. 391 if p.callback(r.data) { 392 p.errc <- nil 393 plist.Remove(el) 394 } 395 // Reset the continuous timeout counter (time drift detection) 396 contTimeouts = 0 397 } 398 } 399 r.matched <- matched 400 401 case now := <-timeout.C: 402 nextTimeout = nil 403 404 // Notify and remove callbacks whose deadline is in the past. 405 for el := plist.Front(); el != nil; el = el.Next() { 406 p := el.Value.(*pending) 407 if now.After(p.deadline) || now.Equal(p.deadline) { 408 p.errc <- errTimeout 409 plist.Remove(el) 410 contTimeouts++ 411 } 412 } 413 // If we've accumulated too many timeouts, do an NTP time sync check 414 if contTimeouts > ntpFailureThreshold { 415 if time.Since(ntpWarnTime) >= ntpWarningCooldown { 416 ntpWarnTime = time.Now() 417 go checkClockDrift() 418 } 419 contTimeouts = 0 420 } 421 } 422 } 423 } 424 425 const ( 426 macSize = 256 / 8 427 sigSize = 520 / 8 428 headSize = macSize + sigSize // space of packet frame data 429 ) 430 431 var ( 432 headSpace = make([]byte, headSize) 433 434 // Neighbors replies are sent across multiple packets to 435 // stay below the 1280 byte limit. We compute the maximum number 436 // of entries by stuffing a packet until it grows too large. 437 maxNeighbors int 438 ) 439 440 func init() { 441 p := neighbors{Expiration: ^uint64(0)} 442 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} 443 for n := 0; ; n++ { 444 p.Nodes = append(p.Nodes, maxSizeNode) 445 size, _, err := rlp.EncodeToReader(p) 446 if err != nil { 447 // If this ever happens, it will be caught by the unit tests. 448 panic("cannot encode: " + err.Error()) 449 } 450 if headSize+size+1 >= 1280 { 451 maxNeighbors = n 452 break 453 } 454 } 455 } 456 457 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { 458 packet, hash, err := encodePacket(t.priv, ptype, req) 459 if err != nil { 460 return hash, err 461 } 462 return hash, t.write(toaddr, req.name(), packet) 463 } 464 465 func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { 466 _, err := t.conn.WriteToUDP(packet, toaddr) 467 log.Trace(">> "+what, "addr", toaddr, "err", err) 468 return err 469 } 470 471 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { 472 b := new(bytes.Buffer) 473 b.Write(headSpace) 474 b.WriteByte(ptype) 475 if err := rlp.Encode(b, req); err != nil { 476 log.Error("Can't encode discv4 packet", "err", err) 477 return nil, nil, err 478 } 479 packet = b.Bytes() 480 sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) 481 if err != nil { 482 log.Error("Can't sign discv4 packet", "err", err) 483 return nil, nil, err 484 } 485 copy(packet[macSize:], sig) 486 // add the hash to the front. Note: this doesn't protect the 487 // packet in any way. Our public key will be part of this hash in 488 // The future. 489 hash = crypto.Keccak256(packet[macSize:]) 490 copy(packet, hash) 491 return packet, hash, nil 492 } 493 494 // readLoop runs in its own goroutine. it handles incoming UDP packets. 495 func (t *udp) readLoop(unhandled chan<- ReadPacket) { 496 defer t.conn.Close() 497 if unhandled != nil { 498 defer close(unhandled) 499 } 500 // Discovery packets are defined to be no larger than 1280 bytes. 501 // Packets larger than this size will be cut at the end and treated 502 // as invalid because their hash won't match. 503 buf := make([]byte, 1280) 504 for { 505 nbytes, from, err := t.conn.ReadFromUDP(buf) 506 if netutil.IsTemporaryError(err) { 507 // Ignore temporary read errors. 508 log.Debug("Temporary UDP read error", "err", err) 509 continue 510 } else if err != nil { 511 // Shut down the loop for permament errors. 512 log.Debug("UDP read error", "err", err) 513 return 514 } 515 if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { 516 select { 517 case unhandled <- ReadPacket{buf[:nbytes], from}: 518 default: 519 } 520 } 521 } 522 } 523 524 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 525 packet, fromID, hash, err := decodePacket(buf) 526 if err != nil { 527 log.Debug("Bad discv4 packet", "addr", from, "err", err) 528 return err 529 } 530 err = packet.handle(t, from, fromID, hash) 531 log.Trace("<< "+packet.name(), "addr", from, "err", err) 532 return err 533 } 534 535 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 536 if len(buf) < headSize+1 { 537 return nil, NodeID{}, nil, errPacketTooSmall 538 } 539 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 540 shouldhash := crypto.Keccak256(buf[macSize:]) 541 if !bytes.Equal(hash, shouldhash) { 542 return nil, NodeID{}, nil, errBadHash 543 } 544 fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) 545 if err != nil { 546 return nil, NodeID{}, hash, err 547 } 548 var req packet 549 switch ptype := sigdata[0]; ptype { 550 case pingPacket: 551 req = new(ping) 552 case pongPacket: 553 req = new(pong) 554 case findnodePacket: 555 req = new(findnode) 556 case neighborsPacket: 557 req = new(neighbors) 558 default: 559 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 560 } 561 s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) 562 err = s.Decode(req) 563 return req, fromID, hash, err 564 } 565 566 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 567 if expired(req.Expiration) { 568 return errExpired 569 } 570 t.send(from, pongPacket, &pong{ 571 To: makeEndpoint(from, req.From.TCP), 572 ReplyTok: mac, 573 Expiration: uint64(time.Now().Add(expiration).Unix()), 574 }) 575 if !t.handleReply(fromID, pingPacket, req) { 576 // Note: we're ignoring the provided IP address right now 577 go t.bond(true, fromID, from, req.From.TCP) 578 } 579 return nil 580 } 581 582 func (req *ping) name() string { return "PING/v4" } 583 584 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 585 if expired(req.Expiration) { 586 return errExpired 587 } 588 if !t.handleReply(fromID, pongPacket, req) { 589 return errUnsolicitedReply 590 } 591 return nil 592 } 593 594 func (req *pong) name() string { return "PONG/v4" } 595 596 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 597 if expired(req.Expiration) { 598 return errExpired 599 } 600 if !t.db.hasBond(fromID) { 601 // No bond exists, we don't process the packet. This prevents 602 // an attack vector where the discovery protocol could be used 603 // to amplify traffic in a DDOS attack. A malicious actor 604 // would send a findnode request with the IP address and UDP 605 // port of the target as the source address. The recipient of 606 // the findnode packet would then send a neighbors packet 607 // (which is a much bigger packet than findnode) to the victim. 608 return errUnknownNode 609 } 610 target := crypto.Keccak256Hash(req.Target[:]) 611 t.mutex.Lock() 612 closest := t.closest(target, bucketSize).entries 613 t.mutex.Unlock() 614 615 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 616 var sent bool 617 // Send neighbors in chunks with at most maxNeighbors per packet 618 // to stay below the 1280 byte limit. 619 for _, n := range closest { 620 if netutil.CheckRelayIP(from.IP, n.IP) == nil { 621 p.Nodes = append(p.Nodes, nodeToRPC(n)) 622 } 623 if len(p.Nodes) == maxNeighbors { 624 t.send(from, neighborsPacket, &p) 625 p.Nodes = p.Nodes[:0] 626 sent = true 627 } 628 } 629 if len(p.Nodes) > 0 || !sent { 630 t.send(from, neighborsPacket, &p) 631 } 632 return nil 633 } 634 635 func (req *findnode) name() string { return "FINDNODE/v4" } 636 637 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 638 if expired(req.Expiration) { 639 return errExpired 640 } 641 if !t.handleReply(fromID, neighborsPacket, req) { 642 return errUnsolicitedReply 643 } 644 return nil 645 } 646 647 func (req *neighbors) name() string { return "NEIGHBORS/v4" } 648 649 func expired(ts uint64) bool { 650 return time.Unix(int64(ts), 0).Before(time.Now()) 651 }