github.com/aquanetwork/aquachain@v1.7.8/p2p/discover/udp.go (about) 1 // Copyright 2015 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package discover 18 19 import ( 20 "bytes" 21 "container/list" 22 "crypto/ecdsa" 23 "errors" 24 "fmt" 25 "net" 26 "time" 27 28 "gitlab.com/aquachain/aquachain/common/log" 29 "gitlab.com/aquachain/aquachain/crypto" 30 "gitlab.com/aquachain/aquachain/p2p/netutil" 31 "gitlab.com/aquachain/aquachain/rlp" 32 ) 33 34 const Version = 4 35 36 // Errors 37 var ( 38 errPacketTooSmall = errors.New("too small") 39 errBadHash = errors.New("bad hash") 40 errExpired = errors.New("expired") 41 errUnsolicitedReply = errors.New("unsolicited reply") 42 errUnknownNode = errors.New("unknown node") 43 errTimeout = errors.New("RPC timeout") 44 errClockWarp = errors.New("reply deadline too far in the future") 45 errClosed = errors.New("socket closed") 46 ) 47 48 // Timeouts 49 const ( 50 respTimeout = 500 * time.Millisecond 51 expiration = 20 * time.Second 52 53 ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP 54 ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning 55 driftThreshold = 10 * time.Second // Allowed clock drift before warning user 56 ) 57 58 // RPC packet types 59 const ( 60 pingPacket = iota + 134 // zero is 'reserved' 61 //pingPacket = iota + 1 // zero is 'reserved' 62 pongPacket 63 findnodePacket 64 neighborsPacket 65 ) 66 67 // RPC request structures 68 type ( 69 ping struct { 70 Version uint 71 From, To rpcEndpoint 72 Expiration uint64 73 // Ignore additional fields (for forward compatibility). 74 Rest []rlp.RawValue `rlp:"tail"` 75 } 76 77 // pong is the reply to ping. 78 pong struct { 79 // This field should mirror the UDP envelope address 80 // of the ping packet, which provides a way to discover the 81 // the external address (after NAT). 82 To rpcEndpoint 83 84 ReplyTok []byte // This contains the hash of the ping packet. 85 Expiration uint64 // Absolute timestamp at which the packet becomes invalid. 86 // Ignore additional fields (for forward compatibility). 87 Rest []rlp.RawValue `rlp:"tail"` 88 } 89 90 // findnode is a query for nodes close to the given target. 91 findnode struct { 92 Target NodeID // doesn't need to be an actual public key 93 Expiration uint64 94 // Ignore additional fields (for forward compatibility). 95 Rest []rlp.RawValue `rlp:"tail"` 96 } 97 98 // reply to findnode 99 neighbors struct { 100 Nodes []rpcNode 101 Expiration uint64 102 // Ignore additional fields (for forward compatibility). 103 Rest []rlp.RawValue `rlp:"tail"` 104 } 105 106 rpcNode struct { 107 IP net.IP // len 4 for IPv4 or 16 for IPv6 108 UDP uint16 // for discovery protocol 109 TCP uint16 // for RLPx protocol 110 ID NodeID 111 } 112 113 rpcEndpoint struct { 114 IP net.IP // len 4 for IPv4 or 16 for IPv6 115 UDP uint16 // for discovery protocol 116 TCP uint16 // for RLPx protocol 117 } 118 ) 119 120 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 121 ip := addr.IP.To4() 122 if ip == nil { 123 ip = addr.IP.To16() 124 } 125 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 126 } 127 128 func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { 129 if rn.UDP <= 1024 { 130 return nil, errors.New("low port") 131 } 132 // if 30000 <= rn.UDP && rn.UDP <= 39999 { 133 // return nil, errors.New("not aqua") 134 // } 135 if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { 136 return nil, err 137 } 138 if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { 139 return nil, errors.New("not contained in netrestrict whitelist") 140 } 141 n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) 142 err := n.validateComplete() 143 return n, err 144 } 145 146 func nodeToRPC(n *Node) rpcNode { 147 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 148 } 149 150 type packet interface { 151 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 152 name() string 153 } 154 155 type conn interface { 156 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 157 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 158 Close() error 159 LocalAddr() net.Addr 160 } 161 162 // udp implements the RPC protocol. 163 type udp struct { 164 conn conn 165 netrestrict *netutil.Netlist 166 priv *ecdsa.PrivateKey 167 ourEndpoint rpcEndpoint 168 169 addpending chan *pending 170 gotreply chan reply 171 172 closing chan struct{} 173 174 *Table 175 } 176 177 // pending represents a pending reply. 178 // 179 // some implementations of the protocol wish to send more than one 180 // reply packet to findnode. in general, any neighbors packet cannot 181 // be matched up with a specific findnode packet. 182 // 183 // our implementation handles this by storing a callback function for 184 // each pending reply. incoming packets from a node are dispatched 185 // to all the callback functions for that node. 186 type pending struct { 187 // these fields must match in the reply. 188 from NodeID 189 ptype byte 190 191 // time when the request must complete 192 deadline time.Time 193 194 // callback is called when a matching reply arrives. if it returns 195 // true, the callback is removed from the pending reply queue. 196 // if it returns false, the reply is considered incomplete and 197 // the callback will be invoked again for the next matching reply. 198 callback func(resp interface{}) (done bool) 199 200 // errc receives nil when the callback indicates completion or an 201 // error if no further reply is received within the timeout. 202 errc chan<- error 203 } 204 205 type reply struct { 206 from NodeID 207 ptype byte 208 data interface{} 209 // loop indicates whether there was 210 // a matching request by sending on this channel. 211 matched chan<- bool 212 } 213 214 // ReadPacket is sent to the unhandled channel when it could not be processed 215 type ReadPacket struct { 216 Data []byte 217 Addr *net.UDPAddr 218 } 219 220 // Config holds Table-related settings. 221 type Config struct { 222 // These settings are required and configure the UDP listener: 223 PrivateKey *ecdsa.PrivateKey 224 225 // These settings are optional: 226 AnnounceAddr *net.UDPAddr // local address announced in the DHT 227 NodeDBPath string // if set, the node database is stored at this filesystem location 228 NetRestrict *netutil.Netlist // network whitelist 229 Bootnodes []*Node // list of bootstrap nodes 230 Unhandled chan<- ReadPacket // unhandled packets are sent on this channel 231 } 232 233 // ListenUDP returns a new table that listens for UDP packets on laddr. 234 func ListenUDP(c conn, cfg Config) (*Table, error) { 235 tab, _, err := newUDP(c, cfg) 236 if err != nil { 237 return nil, err 238 } 239 log.Info("UDP listener up", "self", tab.self) 240 return tab, nil 241 } 242 243 func newUDP(c conn, cfg Config) (*Table, *udp, error) { 244 udp := &udp{ 245 conn: c, 246 priv: cfg.PrivateKey, 247 netrestrict: cfg.NetRestrict, 248 closing: make(chan struct{}), 249 gotreply: make(chan reply), 250 addpending: make(chan *pending), 251 } 252 realaddr := c.LocalAddr().(*net.UDPAddr) 253 if cfg.AnnounceAddr != nil { 254 realaddr = cfg.AnnounceAddr 255 } 256 // TODO: separate TCP port 257 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 258 tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) 259 if err != nil { 260 return nil, nil, err 261 } 262 udp.Table = tab 263 264 go udp.loop() 265 go udp.readLoop(cfg.Unhandled) 266 return udp.Table, udp, nil 267 } 268 269 func (t *udp) close() { 270 close(t.closing) 271 t.conn.Close() 272 // TODO: wait for the loops to end. 273 } 274 275 // ping sends a ping message to the given node and waits for a reply. 276 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 277 // if 30300 < toaddr.Port && toaddr.Port < 30309 { 278 // return fmt.Errorf("maybe not aqua") 279 // } 280 req := &ping{ 281 Version: Version, 282 From: t.ourEndpoint, 283 To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB 284 Expiration: uint64(time.Now().Add(expiration).Unix()), 285 } 286 packet, hash, err := encodePacket(t.priv, pingPacket, req) 287 if err != nil { 288 return err 289 } 290 errc := t.pending(toid, pongPacket, func(p interface{}) bool { 291 return bytes.Equal(p.(*pong).ReplyTok, hash) 292 }) 293 t.write(toaddr, req.name(), packet) 294 return <-errc 295 } 296 297 func (t *udp) waitping(from NodeID) error { 298 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 299 } 300 301 // findnode sends a findnode request to the given node and waits until 302 // the node has sent up to k neighbors. 303 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 304 nodes := make([]*Node, 0, bucketSize) 305 nreceived := 0 306 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 307 reply := r.(*neighbors) 308 for _, rn := range reply.Nodes { 309 nreceived++ 310 n, err := t.nodeFromRPC(toaddr, rn) 311 if err != nil { 312 log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) 313 continue 314 } 315 nodes = append(nodes, n) 316 } 317 return nreceived >= bucketSize 318 }) 319 t.send(toaddr, findnodePacket, &findnode{ 320 Target: target, 321 Expiration: uint64(time.Now().Add(expiration).Unix()), 322 }) 323 err := <-errc 324 return nodes, err 325 } 326 327 // pending adds a reply callback to the pending reply queue. 328 // see the documentation of type pending for a detailed explanation. 329 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 330 ch := make(chan error, 1) 331 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 332 select { 333 case t.addpending <- p: 334 // loop will handle it 335 case <-t.closing: 336 ch <- errClosed 337 } 338 return ch 339 } 340 341 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 342 matched := make(chan bool, 1) 343 select { 344 case t.gotreply <- reply{from, ptype, req, matched}: 345 // loop will handle it 346 return <-matched 347 case <-t.closing: 348 return false 349 } 350 } 351 352 // loop runs in its own goroutine. it keeps track of 353 // the refresh timer and the pending reply queue. 354 func (t *udp) loop() { 355 var ( 356 plist = list.New() 357 timeout = time.NewTimer(0) 358 nextTimeout *pending // head of plist when timeout was last reset 359 contTimeouts = 0 // number of continuous timeouts to do NTP checks 360 ntpWarnTime = time.Unix(0, 0) 361 ) 362 <-timeout.C // ignore first timeout 363 defer timeout.Stop() 364 365 resetTimeout := func() { 366 if plist.Front() == nil || nextTimeout == plist.Front().Value { 367 return 368 } 369 // Start the timer so it fires when the next pending reply has expired. 370 now := time.Now() 371 for el := plist.Front(); el != nil; el = el.Next() { 372 nextTimeout = el.Value.(*pending) 373 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 374 timeout.Reset(dist) 375 return 376 } 377 // Remove pending replies whose deadline is too far in the 378 // future. These can occur if the system clock jumped 379 // backwards after the deadline was assigned. 380 nextTimeout.errc <- errClockWarp 381 plist.Remove(el) 382 } 383 nextTimeout = nil 384 timeout.Stop() 385 } 386 387 for { 388 resetTimeout() 389 390 select { 391 case <-t.closing: 392 for el := plist.Front(); el != nil; el = el.Next() { 393 el.Value.(*pending).errc <- errClosed 394 } 395 return 396 397 case p := <-t.addpending: 398 p.deadline = time.Now().Add(respTimeout) 399 plist.PushBack(p) 400 401 case r := <-t.gotreply: 402 var matched bool 403 for el := plist.Front(); el != nil; el = el.Next() { 404 p := el.Value.(*pending) 405 if p.from == r.from && p.ptype == r.ptype { 406 matched = true 407 // Remove the matcher if its callback indicates 408 // that all replies have been received. This is 409 // required for packet types that expect multiple 410 // reply packets. 411 if p.callback(r.data) { 412 p.errc <- nil 413 plist.Remove(el) 414 } 415 // Reset the continuous timeout counter (time drift detection) 416 contTimeouts = 0 417 } 418 } 419 r.matched <- matched 420 421 case now := <-timeout.C: 422 nextTimeout = nil 423 424 // Notify and remove callbacks whose deadline is in the past. 425 for el := plist.Front(); el != nil; el = el.Next() { 426 p := el.Value.(*pending) 427 if now.After(p.deadline) || now.Equal(p.deadline) { 428 p.errc <- errTimeout 429 plist.Remove(el) 430 contTimeouts++ 431 } 432 } 433 // If we've accumulated too many timeouts, do an NTP time sync check 434 if contTimeouts > ntpFailureThreshold { 435 if time.Since(ntpWarnTime) >= ntpWarningCooldown { 436 ntpWarnTime = time.Now() 437 go checkClockDrift() 438 } 439 contTimeouts = 0 440 } 441 } 442 } 443 } 444 445 const ( 446 macSize = 256 / 8 447 sigSize = 520 / 8 448 headSize = macSize + sigSize // space of packet frame data 449 ) 450 451 var ( 452 headSpace = make([]byte, headSize) 453 454 // Neighbors replies are sent across multiple packets to 455 // stay below the 1280 byte limit. We compute the maximum number 456 // of entries by stuffing a packet until it grows too large. 457 maxNeighbors int 458 ) 459 460 func init() { 461 p := neighbors{Expiration: ^uint64(0)} 462 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} 463 for n := 0; ; n++ { 464 p.Nodes = append(p.Nodes, maxSizeNode) 465 size, _, err := rlp.EncodeToReader(p) 466 if err != nil { 467 // If this ever happens, it will be caught by the unit tests. 468 panic("cannot encode: " + err.Error()) 469 } 470 if headSize+size+1 >= 1280 { 471 maxNeighbors = n 472 break 473 } 474 } 475 } 476 477 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { 478 packet, hash, err := encodePacket(t.priv, ptype, req) 479 if err != nil { 480 return hash, err 481 } 482 return hash, t.write(toaddr, req.name(), packet) 483 } 484 485 func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { 486 _, err := t.conn.WriteToUDP(packet, toaddr) 487 log.Trace(">> "+what, "addr", toaddr, "err", err) 488 return err 489 } 490 491 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { 492 b := new(bytes.Buffer) 493 b.Write(headSpace) 494 b.WriteByte(ptype) 495 b.WriteString("aqua") 496 if err := rlp.Encode(b, req); err != nil { 497 log.Error("Can't encode discv4 packet", "err", err) 498 return nil, nil, err 499 } 500 packet = b.Bytes() 501 sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) 502 if err != nil { 503 log.Error("Can't sign discv4 packet", "err", err) 504 return nil, nil, err 505 } 506 copy(packet[macSize:], sig) 507 // add the hash to the front. Note: this doesn't protect the 508 // packet in any way. Our public key will be part of this hash in 509 // The future. 510 hash = crypto.Keccak256(packet[macSize:]) 511 copy(packet, hash) 512 return packet, hash, nil 513 } 514 515 // readLoop runs in its own goroutine. it handles incoming UDP packets. 516 func (t *udp) readLoop(unhandled chan<- ReadPacket) { 517 defer t.conn.Close() 518 if unhandled != nil { 519 defer close(unhandled) 520 } 521 // Discovery packets are defined to be no larger than 1280 bytes. 522 // Packets larger than this size will be cut at the end and treated 523 // as invalid because their hash won't match. 524 buf := make([]byte, 1280) 525 for { 526 nbytes, from, err := t.conn.ReadFromUDP(buf) 527 if netutil.IsTemporaryError(err) { 528 // Ignore temporary read errors. 529 log.Debug("Temporary UDP read error", "err", err) 530 continue 531 } else if err != nil { 532 // Shut down the loop for permament errors. 533 log.Debug("UDP read error", "err", err) 534 return 535 } 536 if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { 537 select { 538 case unhandled <- ReadPacket{buf[:nbytes], from}: 539 default: 540 } 541 } 542 } 543 } 544 545 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 546 packet, fromID, hash, err := decodePacket(buf) 547 if err != nil { 548 log.Debug("Bad discv4 packet", "addr", from, "err", err) 549 return err 550 } 551 err = packet.handle(t, from, fromID, hash) 552 log.Trace("<< "+packet.name(), "addr", from, "err", err) 553 return err 554 } 555 556 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 557 if len(buf) < headSize+1 { 558 return nil, NodeID{}, nil, errPacketTooSmall 559 } 560 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 561 shouldhash := crypto.Keccak256(buf[macSize:]) 562 if !bytes.Equal(hash, shouldhash) { 563 return nil, NodeID{}, nil, errBadHash 564 } 565 fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) 566 if err != nil { 567 return nil, NodeID{}, hash, err 568 } 569 var req packet 570 switch ptype := sigdata[0]; ptype { 571 case pingPacket: 572 req = new(ping) 573 case pongPacket: 574 req = new(pong) 575 case findnodePacket: 576 req = new(findnode) 577 case neighborsPacket: 578 req = new(neighbors) 579 default: 580 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 581 } 582 s := rlp.NewStream(bytes.NewReader(sigdata[1+len("aqua"):]), 0) 583 err = s.Decode(req) 584 return req, fromID, hash, err 585 } 586 587 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 588 if expired(req.Expiration) { 589 return errExpired 590 } 591 t.send(from, pongPacket, &pong{ 592 To: makeEndpoint(from, req.From.TCP), 593 ReplyTok: mac, 594 Expiration: uint64(time.Now().Add(expiration).Unix()), 595 }) 596 if !t.handleReply(fromID, pingPacket, req) { 597 // Note: we're ignoring the provided IP address right now 598 go t.bond(true, fromID, from, req.From.TCP) 599 } 600 return nil 601 } 602 603 func (req *ping) name() string { return "PING/v4" } 604 605 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 606 if expired(req.Expiration) { 607 return errExpired 608 } 609 if !t.handleReply(fromID, pongPacket, req) { 610 return errUnsolicitedReply 611 } 612 return nil 613 } 614 615 func (req *pong) name() string { return "PONG/v4" } 616 617 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 618 if expired(req.Expiration) { 619 return errExpired 620 } 621 if !t.db.hasBond(fromID) { 622 // No bond exists, we don't process the packet. This prevents 623 // an attack vector where the discovery protocol could be used 624 // to amplify traffic in a DDOS attack. A malicious actor 625 // would send a findnode request with the IP address and UDP 626 // port of the target as the source address. The recipient of 627 // the findnode packet would then send a neighbors packet 628 // (which is a much bigger packet than findnode) to the victim. 629 return errUnknownNode 630 } 631 target := crypto.Keccak256Hash(req.Target[:]) 632 t.mutex.Lock() 633 closest := t.closest(target, bucketSize).entries 634 t.mutex.Unlock() 635 636 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 637 var sent bool 638 // Send neighbors in chunks with at most maxNeighbors per packet 639 // to stay below the 1280 byte limit. 640 for _, n := range closest { 641 if netutil.CheckRelayIP(from.IP, n.IP) == nil { 642 p.Nodes = append(p.Nodes, nodeToRPC(n)) 643 } 644 if len(p.Nodes) == maxNeighbors { 645 t.send(from, neighborsPacket, &p) 646 p.Nodes = p.Nodes[:0] 647 sent = true 648 } 649 } 650 if len(p.Nodes) > 0 || !sent { 651 t.send(from, neighborsPacket, &p) 652 } 653 return nil 654 } 655 656 func (req *findnode) name() string { return "FINDNODE/v4" } 657 658 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 659 if expired(req.Expiration) { 660 return errExpired 661 } 662 if !t.handleReply(fromID, neighborsPacket, req) { 663 return errUnsolicitedReply 664 } 665 return nil 666 } 667 668 func (req *neighbors) name() string { return "NEIGHBORS/v4" } 669 670 func expired(ts uint64) bool { 671 return time.Unix(int64(ts), 0).Before(time.Now()) 672 }