github.com/halybang/go-ethereum@v1.0.5-0.20180325041310-3b262bc1367c/p2p/discover/udp.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package discover 18 19 import ( 20 "bytes" 21 "container/list" 22 "crypto/ecdsa" 23 "errors" 24 "fmt" 25 "net" 26 "time" 27 28 "github.com/wanchain/go-wanchain/crypto" 29 "github.com/wanchain/go-wanchain/log" 30 "github.com/wanchain/go-wanchain/p2p/nat" 31 "github.com/wanchain/go-wanchain/p2p/netutil" 32 "github.com/wanchain/go-wanchain/rlp" 33 34 //"github.com/wanchain/go-wanchain/params" 35 ) 36 37 const Version = 5 38 39 // Errors 40 var ( 41 errPacketTooSmall = errors.New("too small") 42 errBadHash = errors.New("bad hash") 43 errExpired = errors.New("expired") 44 errUnsolicitedReply = errors.New("unsolicited reply") 45 errUnknownNode = errors.New("unknown node") 46 errTimeout = errors.New("RPC timeout") 47 errClockWarp = errors.New("reply deadline too far in the future") 48 errClosed = errors.New("socket closed") 49 ) 50 51 // Timeouts 52 const ( 53 respTimeout = 500 * time.Millisecond 54 sendTimeout = 500 * time.Millisecond 55 expiration = 20 * time.Second 56 57 ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP 58 ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning 59 driftThreshold = 10 * time.Second // Allowed clock drift before warning user 60 ) 61 62 // RPC packet types 63 const ( 64 pingPacket = iota + 10 // zero is 'reserved' 65 pongPacket 66 findnodePacket 67 neighborsPacket 68 ) 69 70 // RPC request structures 71 type ( 72 ping struct { 73 Version uint 74 From, To rpcEndpoint 75 Expiration uint64 76 // Ignore additional fields (for forward compatibility). 77 Rest []rlp.RawValue `rlp:"tail"` 78 } 79 80 // pong is the reply to ping. 81 pong struct { 82 // This field should mirror the UDP envelope address 83 // of the ping packet, which provides a way to discover the 84 // the external address (after NAT). 85 To rpcEndpoint 86 87 ReplyTok []byte // This contains the hash of the ping packet. 88 Expiration uint64 // Absolute timestamp at which the packet becomes invalid. 89 // Ignore additional fields (for forward compatibility). 90 Rest []rlp.RawValue `rlp:"tail"` 91 } 92 93 // findnode is a query for nodes close to the given target. 94 findnode struct { 95 Target NodeID // doesn't need to be an actual public key 96 Expiration uint64 97 // Ignore additional fields (for forward compatibility). 98 Rest []rlp.RawValue `rlp:"tail"` 99 } 100 101 // reply to findnode 102 neighbors struct { 103 Nodes []rpcNode 104 Expiration uint64 105 // Ignore additional fields (for forward compatibility). 106 Rest []rlp.RawValue `rlp:"tail"` 107 } 108 109 rpcNode struct { 110 IP net.IP // len 4 for IPv4 or 16 for IPv6 111 UDP uint16 // for discovery protocol 112 TCP uint16 // for RLPx protocol 113 ID NodeID 114 } 115 116 rpcEndpoint struct { 117 IP net.IP // len 4 for IPv4 or 16 for IPv6 118 UDP uint16 // for discovery protocol 119 TCP uint16 // for RLPx protocol 120 } 121 ) 122 123 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 124 ip := addr.IP.To4() 125 if ip == nil { 126 ip = addr.IP.To16() 127 } 128 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 129 } 130 131 func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { 132 if rn.UDP <= 1024 { 133 return nil, errors.New("low port") 134 } 135 136 //if rn.TCP != params.WanTcpPort { 137 // return nil, errors.New("wrong wanchain tcp port") 138 //} 139 140 if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { 141 return nil, err 142 } 143 if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { 144 return nil, errors.New("not contained in netrestrict whitelist") 145 } 146 n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) 147 err := n.validateComplete() 148 return n, err 149 } 150 151 func nodeToRPC(n *Node) rpcNode { 152 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 153 } 154 155 type packet interface { 156 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 157 name() string 158 } 159 160 type conn interface { 161 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 162 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 163 Close() error 164 LocalAddr() net.Addr 165 } 166 167 // udp implements the RPC protocol. 168 type udp struct { 169 conn conn 170 netrestrict *netutil.Netlist 171 priv *ecdsa.PrivateKey 172 ourEndpoint rpcEndpoint 173 174 addpending chan *pending 175 gotreply chan reply 176 177 closing chan struct{} 178 nat nat.Interface 179 180 *Table 181 } 182 183 // pending represents a pending reply. 184 // 185 // some implementations of the protocol wish to send more than one 186 // reply packet to findnode. in general, any neighbors packet cannot 187 // be matched up with a specific findnode packet. 188 // 189 // our implementation handles this by storing a callback function for 190 // each pending reply. incoming packets from a node are dispatched 191 // to all the callback functions for that node. 192 type pending struct { 193 // these fields must match in the reply. 194 from NodeID 195 ptype byte 196 197 // time when the request must complete 198 deadline time.Time 199 200 // callback is called when a matching reply arrives. if it returns 201 // true, the callback is removed from the pending reply queue. 202 // if it returns false, the reply is considered incomplete and 203 // the callback will be invoked again for the next matching reply. 204 callback func(resp interface{}) (done bool) 205 206 // errc receives nil when the callback indicates completion or an 207 // error if no further reply is received within the timeout. 208 errc chan<- error 209 } 210 211 type reply struct { 212 from NodeID 213 ptype byte 214 data interface{} 215 // loop indicates whether there was 216 // a matching request by sending on this channel. 217 matched chan<- bool 218 } 219 220 // ListenUDP returns a new table that listens for UDP packets on laddr. 221 func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { 222 addr, err := net.ResolveUDPAddr("udp", laddr) 223 if err != nil { 224 return nil, err 225 } 226 conn, err := net.ListenUDP("udp", addr) 227 if err != nil { 228 return nil, err 229 } 230 tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) 231 if err != nil { 232 return nil, err 233 } 234 log.Info("UDP listener up", "self", tab.self) 235 return tab, nil 236 } 237 238 func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { 239 240 //realaddr := c.LocalAddr().(*net.UDPAddr) 241 // 242 //if realaddr.Port != params.WanUdpPort { 243 // return nil, nil, errors.New("port is not wan port") 244 //} 245 246 udp := &udp{ 247 conn: c, 248 priv: priv, 249 netrestrict: netrestrict, 250 closing: make(chan struct{}), 251 gotreply: make(chan reply), 252 addpending: make(chan *pending), 253 } 254 255 realaddr := c.LocalAddr().(*net.UDPAddr) 256 257 if natm != nil { 258 if !realaddr.IP.IsLoopback() { 259 go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") 260 } 261 // TODO: react to external IP changes over time. 262 if ext, err := natm.ExternalIP(); err == nil { 263 realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} 264 } 265 } 266 // TODO: separate TCP port 267 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 268 tab, err := newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) 269 if err != nil { 270 return nil, nil, err 271 } 272 udp.Table = tab 273 274 go udp.loop() 275 go udp.readLoop() 276 return udp.Table, udp, nil 277 } 278 279 func (t *udp) close() { 280 close(t.closing) 281 t.conn.Close() 282 // TODO: wait for the loops to end. 283 } 284 285 // ping sends a ping message to the given node and waits for a reply. 286 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 287 // TODO: maybe check for ReplyTo field in callback to measure RTT 288 errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) 289 t.send(toaddr, pingPacket, &ping{ 290 Version: Version, 291 From: t.ourEndpoint, 292 To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB 293 Expiration: uint64(time.Now().Add(expiration).Unix()), 294 }) 295 return <-errc 296 } 297 298 299 func (t *udp) waitping(from NodeID) error { 300 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 301 } 302 303 // findnode sends a findnode request to the given node and waits until 304 // the node has sent up to k neighbors. 305 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 306 nodes := make([]*Node, 0, bucketSize) 307 nreceived := 0 308 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 309 reply := r.(*neighbors) 310 for _, rn := range reply.Nodes { 311 nreceived++ 312 n, err := t.nodeFromRPC(toaddr, rn) 313 if err != nil { 314 log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) 315 continue 316 } 317 nodes = append(nodes, n) 318 } 319 return nreceived >= bucketSize 320 }) 321 t.send(toaddr, findnodePacket, &findnode{ 322 Target: target, 323 Expiration: uint64(time.Now().Add(expiration).Unix()), 324 }) 325 err := <-errc 326 return nodes, err 327 } 328 329 // pending adds a reply callback to the pending reply queue. 330 // see the documentation of type pending for a detailed explanation. 331 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 332 ch := make(chan error, 1) 333 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 334 select { 335 case t.addpending <- p: 336 // loop will handle it 337 case <-t.closing: 338 ch <- errClosed 339 } 340 return ch 341 } 342 343 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 344 matched := make(chan bool, 1) 345 select { 346 case t.gotreply <- reply{from, ptype, req, matched}: 347 // loop will handle it 348 return <-matched 349 case <-t.closing: 350 return false 351 } 352 } 353 354 // loop runs in its own goroutine. it keeps track of 355 // the refresh timer and the pending reply queue. 356 func (t *udp) loop() { 357 var ( 358 plist = list.New() 359 timeout = time.NewTimer(0) 360 nextTimeout *pending // head of plist when timeout was last reset 361 contTimeouts = 0 // number of continuous timeouts to do NTP checks 362 ntpWarnTime = time.Unix(0, 0) 363 ) 364 <-timeout.C // ignore first timeout 365 defer timeout.Stop() 366 367 resetTimeout := func() { 368 if plist.Front() == nil || nextTimeout == plist.Front().Value { 369 return 370 } 371 // Start the timer so it fires when the next pending reply has expired. 372 now := time.Now() 373 for el := plist.Front(); el != nil; el = el.Next() { 374 nextTimeout = el.Value.(*pending) 375 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 376 timeout.Reset(dist) 377 return 378 } 379 // Remove pending replies whose deadline is too far in the 380 // future. These can occur if the system clock jumped 381 // backwards after the deadline was assigned. 382 nextTimeout.errc <- errClockWarp 383 plist.Remove(el) 384 } 385 nextTimeout = nil 386 timeout.Stop() 387 } 388 389 for { 390 resetTimeout() 391 392 select { 393 case <-t.closing: 394 for el := plist.Front(); el != nil; el = el.Next() { 395 el.Value.(*pending).errc <- errClosed 396 } 397 return 398 399 case p := <-t.addpending: 400 p.deadline = time.Now().Add(respTimeout) 401 plist.PushBack(p) 402 403 case r := <-t.gotreply: 404 var matched bool 405 for el := plist.Front(); el != nil; el = el.Next() { 406 p := el.Value.(*pending) 407 if p.from == r.from && p.ptype == r.ptype { 408 matched = true 409 // Remove the matcher if its callback indicates 410 // that all replies have been received. This is 411 // required for packet types that expect multiple 412 // reply packets. 413 if p.callback(r.data) { 414 p.errc <- nil 415 plist.Remove(el) 416 } 417 // Reset the continuous timeout counter (time drift detection) 418 contTimeouts = 0 419 } 420 } 421 r.matched <- matched 422 423 case now := <-timeout.C: 424 nextTimeout = nil 425 426 // Notify and remove callbacks whose deadline is in the past. 427 for el := plist.Front(); el != nil; el = el.Next() { 428 p := el.Value.(*pending) 429 if now.After(p.deadline) || now.Equal(p.deadline) { 430 p.errc <- errTimeout 431 plist.Remove(el) 432 contTimeouts++ 433 } 434 } 435 // If we've accumulated too many timeouts, do an NTP time sync check 436 if contTimeouts > ntpFailureThreshold { 437 if time.Since(ntpWarnTime) >= ntpWarningCooldown { 438 ntpWarnTime = time.Now() 439 go checkClockDrift() 440 } 441 contTimeouts = 0 442 } 443 } 444 } 445 } 446 447 const ( 448 macSize = 256 / 8 449 sigSize = 520 / 8 450 headSize = macSize + sigSize // space of packet frame data 451 ) 452 453 var ( 454 headSpace = make([]byte, headSize) 455 456 // Neighbors replies are sent across multiple packets to 457 // stay below the 1280 byte limit. We compute the maximum number 458 // of entries by stuffing a packet until it grows too large. 459 maxNeighbors int 460 ) 461 462 func init() { 463 p := neighbors{Expiration: ^uint64(0)} 464 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0) - 1, TCP: ^uint16(0) - 1} 465 for n := 0; ; n++ { 466 p.Nodes = append(p.Nodes, maxSizeNode) 467 size, _, err := rlp.EncodeToReader(p) 468 if err != nil { 469 // If this ever happens, it will be caught by the unit tests. 470 panic("cannot encode: " + err.Error()) 471 } 472 if headSize+size+1 >= 1280 { 473 maxNeighbors = n 474 break 475 } 476 } 477 } 478 479 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) error { 480 packet, err := encodePacket(t.priv, ptype, req) 481 if err != nil { 482 return err 483 } 484 485 //if toaddr.Port != params.WanUdpPort { 486 // return errors.New("not specified port") 487 //} 488 _, err = t.conn.WriteToUDP(packet, toaddr) 489 log.Trace(">> "+req.name(), "addr", toaddr, "err", err) 490 return err 491 } 492 493 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { 494 b := new(bytes.Buffer) 495 b.Write(headSpace) 496 b.WriteByte(ptype) 497 if err := rlp.Encode(b, req); err != nil { 498 log.Error("Can't encode discv4 packet", "err", err) 499 return nil, err 500 } 501 packet := b.Bytes() 502 sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) 503 if err != nil { 504 log.Error("Can't sign discv4 packet", "err", err) 505 return nil, err 506 } 507 copy(packet[macSize:], sig) 508 // add the hash to the front. Note: this doesn't protect the 509 // packet in any way. Our public key will be part of this hash in 510 // The future. 511 copy(packet, crypto.Keccak256(packet[macSize:])) 512 return packet, nil 513 } 514 515 // readLoop runs in its own goroutine. it handles incoming UDP packets. 516 func (t *udp) readLoop() { 517 defer t.conn.Close() 518 // Discovery packets are defined to be no larger than 1280 bytes. 519 // Packets larger than this size will be cut at the end and treated 520 // as invalid because their hash won't match. 521 buf := make([]byte, 1280) 522 for { 523 nbytes, from, err := t.conn.ReadFromUDP(buf) 524 if netutil.IsTemporaryError(err) { 525 // Ignore temporary read errors. 526 log.Debug("Temporary UDP read error", "err", err) 527 continue 528 } else if err != nil { 529 // Shut down the loop for permament errors. 530 log.Debug("UDP read error", "err", err) 531 return 532 } 533 534 t.handlePacket(from, buf[:nbytes]) 535 } 536 } 537 538 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 539 packet, fromID, hash, err := decodePacket(buf) 540 if err != nil { 541 log.Debug("Bad discv4 packet", "addr", from, "err", err) 542 return err 543 } 544 545 //if from.Port != params.WanUdpPort { 546 // return errors.New("the port is not specified") 547 //} 548 549 err = packet.handle(t, from, fromID, hash) 550 log.Trace("<< "+packet.name(), "addr", from, "err", err) 551 return err 552 } 553 554 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 555 if len(buf) < headSize+1 { 556 return nil, NodeID{}, nil, errPacketTooSmall 557 } 558 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 559 shouldhash := crypto.Keccak256(buf[macSize:]) 560 if !bytes.Equal(hash, shouldhash) { 561 return nil, NodeID{}, nil, errBadHash 562 } 563 fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) 564 if err != nil { 565 return nil, NodeID{}, hash, err 566 } 567 var req packet 568 switch ptype := sigdata[0]; ptype { 569 case pingPacket: 570 req = new(ping) 571 case pongPacket: 572 req = new(pong) 573 case findnodePacket: 574 req = new(findnode) 575 case neighborsPacket: 576 req = new(neighbors) 577 default: 578 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 579 } 580 s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) 581 err = s.Decode(req) 582 return req, fromID, hash, err 583 } 584 585 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 586 if expired(req.Expiration) { 587 return errExpired 588 } 589 t.send(from, pongPacket, &pong{ 590 To: makeEndpoint(from, req.From.TCP), 591 ReplyTok: mac, 592 Expiration: uint64(time.Now().Add(expiration).Unix()), 593 }) 594 if !t.handleReply(fromID, pingPacket, req) { 595 // Note: we're ignoring the provided IP address right now 596 go t.bond(true, fromID, from, req.From.TCP) 597 } 598 return nil 599 } 600 601 func (req *ping) name() string { 602 return "PING/v4" 603 } 604 605 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 606 if expired(req.Expiration) { 607 return errExpired 608 } 609 if !t.handleReply(fromID, pongPacket, req) { 610 return errUnsolicitedReply 611 } 612 return nil 613 } 614 615 func (req *pong) name() string { 616 return "PONG/v4" 617 618 } 619 620 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 621 if expired(req.Expiration) { 622 return errExpired 623 } 624 if t.db.node(fromID) == nil { 625 // No bond exists, we don't process the packet. This prevents 626 // an attack vector where the discovery protocol could be used 627 // to amplify traffic in a DDOS attack. A malicious actor 628 // would send a findnode request with the IP address and UDP 629 // port of the target as the source address. The recipient of 630 // the findnode packet would then send a neighbors packet 631 // (which is a much bigger packet than findnode) to the victim. 632 return errUnknownNode 633 } 634 target := crypto.Keccak256Hash(req.Target[:]) 635 t.mutex.Lock() 636 closest := t.closest(target, bucketSize).entries 637 t.mutex.Unlock() 638 639 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 640 // Send neighbors in chunks with at most maxNeighbors per packet 641 // to stay below the 1280 byte limit. 642 for i, n := range closest { 643 if netutil.CheckRelayIP(from.IP, n.IP) != nil { 644 continue 645 } 646 647 p.Nodes = append(p.Nodes, nodeToRPC(n)) 648 if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { 649 t.send(from, neighborsPacket, &p) 650 p.Nodes = p.Nodes[:0] 651 } 652 } 653 return nil 654 } 655 656 func (req *findnode) name() string { 657 return "FINDNODE/v4" 658 } 659 660 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 661 if expired(req.Expiration) { 662 return errExpired 663 } 664 if !t.handleReply(fromID, neighborsPacket, req) { 665 return errUnsolicitedReply 666 } 667 return nil 668 } 669 670 func (req *neighbors) name() string { 671 return "NEIGHBORS/v4" 672 } 673 674 func expired(ts uint64) bool { 675 return time.Unix(int64(ts), 0).Before(time.Now()) 676 }