github.com/jeffallen/go-ethereum@v1.1.4-0.20150910155051-571d3236c49c/p2p/discover/udp.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package discover 18 19 import ( 20 "bytes" 21 "container/list" 22 "crypto/ecdsa" 23 "errors" 24 "fmt" 25 "net" 26 "time" 27 28 "github.com/ethereum/go-ethereum/crypto" 29 "github.com/ethereum/go-ethereum/logger" 30 "github.com/ethereum/go-ethereum/logger/glog" 31 "github.com/ethereum/go-ethereum/p2p/nat" 32 "github.com/ethereum/go-ethereum/rlp" 33 ) 34 35 const Version = 4 36 37 // Errors 38 var ( 39 errPacketTooSmall = errors.New("too small") 40 errBadHash = errors.New("bad hash") 41 errExpired = errors.New("expired") 42 errBadVersion = errors.New("version mismatch") 43 errUnsolicitedReply = errors.New("unsolicited reply") 44 errUnknownNode = errors.New("unknown node") 45 errTimeout = errors.New("RPC timeout") 46 errClockWarp = errors.New("reply deadline too far in the future") 47 errClosed = errors.New("socket closed") 48 ) 49 50 // Timeouts 51 const ( 52 respTimeout = 500 * time.Millisecond 53 sendTimeout = 500 * time.Millisecond 54 expiration = 20 * time.Second 55 56 refreshInterval = 1 * time.Hour 57 ) 58 59 // RPC packet types 60 const ( 61 pingPacket = iota + 1 // zero is 'reserved' 62 pongPacket 63 findnodePacket 64 neighborsPacket 65 ) 66 67 // RPC request structures 68 type ( 69 ping struct { 70 Version uint 71 From, To rpcEndpoint 72 Expiration uint64 73 } 74 75 // pong is the reply to ping. 76 pong struct { 77 // This field should mirror the UDP envelope address 78 // of the ping packet, which provides a way to discover the 79 // the external address (after NAT). 80 To rpcEndpoint 81 82 ReplyTok []byte // This contains the hash of the ping packet. 83 Expiration uint64 // Absolute timestamp at which the packet becomes invalid. 84 } 85 86 // findnode is a query for nodes close to the given target. 87 findnode struct { 88 Target NodeID // doesn't need to be an actual public key 89 Expiration uint64 90 } 91 92 // reply to findnode 93 neighbors struct { 94 Nodes []rpcNode 95 Expiration uint64 96 } 97 98 rpcNode struct { 99 IP net.IP // len 4 for IPv4 or 16 for IPv6 100 UDP uint16 // for discovery protocol 101 TCP uint16 // for RLPx protocol 102 ID NodeID 103 } 104 105 rpcEndpoint struct { 106 IP net.IP // len 4 for IPv4 or 16 for IPv6 107 UDP uint16 // for discovery protocol 108 TCP uint16 // for RLPx protocol 109 } 110 ) 111 112 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 113 ip := addr.IP.To4() 114 if ip == nil { 115 ip = addr.IP.To16() 116 } 117 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 118 } 119 120 func nodeFromRPC(rn rpcNode) (n *Node, valid bool) { 121 // TODO: don't accept localhost, LAN addresses from internet hosts 122 // TODO: check public key is on secp256k1 curve 123 if rn.IP.IsMulticast() || rn.IP.IsUnspecified() || rn.UDP == 0 { 124 return nil, false 125 } 126 return newNode(rn.ID, rn.IP, rn.UDP, rn.TCP), true 127 } 128 129 func nodeToRPC(n *Node) rpcNode { 130 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 131 } 132 133 type packet interface { 134 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 135 } 136 137 type conn interface { 138 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 139 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 140 Close() error 141 LocalAddr() net.Addr 142 } 143 144 // udp implements the RPC protocol. 145 type udp struct { 146 conn conn 147 priv *ecdsa.PrivateKey 148 ourEndpoint rpcEndpoint 149 150 addpending chan *pending 151 gotreply chan reply 152 153 closing chan struct{} 154 nat nat.Interface 155 156 *Table 157 } 158 159 // pending represents a pending reply. 160 // 161 // some implementations of the protocol wish to send more than one 162 // reply packet to findnode. in general, any neighbors packet cannot 163 // be matched up with a specific findnode packet. 164 // 165 // our implementation handles this by storing a callback function for 166 // each pending reply. incoming packets from a node are dispatched 167 // to all the callback functions for that node. 168 type pending struct { 169 // these fields must match in the reply. 170 from NodeID 171 ptype byte 172 173 // time when the request must complete 174 deadline time.Time 175 176 // callback is called when a matching reply arrives. if it returns 177 // true, the callback is removed from the pending reply queue. 178 // if it returns false, the reply is considered incomplete and 179 // the callback will be invoked again for the next matching reply. 180 callback func(resp interface{}) (done bool) 181 182 // errc receives nil when the callback indicates completion or an 183 // error if no further reply is received within the timeout. 184 errc chan<- error 185 } 186 187 type reply struct { 188 from NodeID 189 ptype byte 190 data interface{} 191 // loop indicates whether there was 192 // a matching request by sending on this channel. 193 matched chan<- bool 194 } 195 196 // ListenUDP returns a new table that listens for UDP packets on laddr. 197 func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { 198 addr, err := net.ResolveUDPAddr("udp", laddr) 199 if err != nil { 200 return nil, err 201 } 202 conn, err := net.ListenUDP("udp", addr) 203 if err != nil { 204 return nil, err 205 } 206 tab, _ := newUDP(priv, conn, natm, nodeDBPath) 207 glog.V(logger.Info).Infoln("Listening,", tab.self) 208 return tab, nil 209 } 210 211 func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp) { 212 udp := &udp{ 213 conn: c, 214 priv: priv, 215 closing: make(chan struct{}), 216 gotreply: make(chan reply), 217 addpending: make(chan *pending), 218 } 219 realaddr := c.LocalAddr().(*net.UDPAddr) 220 if natm != nil { 221 if !realaddr.IP.IsLoopback() { 222 go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery") 223 } 224 // TODO: react to external IP changes over time. 225 if ext, err := natm.ExternalIP(); err == nil { 226 realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port} 227 } 228 } 229 // TODO: separate TCP port 230 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 231 udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath) 232 go udp.loop() 233 go udp.readLoop() 234 return udp.Table, udp 235 } 236 237 func (t *udp) close() { 238 close(t.closing) 239 t.conn.Close() 240 // TODO: wait for the loops to end. 241 } 242 243 // ping sends a ping message to the given node and waits for a reply. 244 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 245 // TODO: maybe check for ReplyTo field in callback to measure RTT 246 errc := t.pending(toid, pongPacket, func(interface{}) bool { return true }) 247 t.send(toaddr, pingPacket, ping{ 248 Version: Version, 249 From: t.ourEndpoint, 250 To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB 251 Expiration: uint64(time.Now().Add(expiration).Unix()), 252 }) 253 return <-errc 254 } 255 256 func (t *udp) waitping(from NodeID) error { 257 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 258 } 259 260 // findnode sends a findnode request to the given node and waits until 261 // the node has sent up to k neighbors. 262 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 263 nodes := make([]*Node, 0, bucketSize) 264 nreceived := 0 265 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 266 reply := r.(*neighbors) 267 for _, rn := range reply.Nodes { 268 nreceived++ 269 if n, valid := nodeFromRPC(rn); valid { 270 nodes = append(nodes, n) 271 } 272 } 273 return nreceived >= bucketSize 274 }) 275 t.send(toaddr, findnodePacket, findnode{ 276 Target: target, 277 Expiration: uint64(time.Now().Add(expiration).Unix()), 278 }) 279 err := <-errc 280 return nodes, err 281 } 282 283 // pending adds a reply callback to the pending reply queue. 284 // see the documentation of type pending for a detailed explanation. 285 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 286 ch := make(chan error, 1) 287 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 288 select { 289 case t.addpending <- p: 290 // loop will handle it 291 case <-t.closing: 292 ch <- errClosed 293 } 294 return ch 295 } 296 297 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 298 matched := make(chan bool, 1) 299 select { 300 case t.gotreply <- reply{from, ptype, req, matched}: 301 // loop will handle it 302 return <-matched 303 case <-t.closing: 304 return false 305 } 306 } 307 308 // loop runs in its own goroutin. it keeps track of 309 // the refresh timer and the pending reply queue. 310 func (t *udp) loop() { 311 var ( 312 plist = list.New() 313 timeout = time.NewTimer(0) 314 nextTimeout *pending // head of plist when timeout was last reset 315 refresh = time.NewTicker(refreshInterval) 316 ) 317 <-timeout.C // ignore first timeout 318 defer refresh.Stop() 319 defer timeout.Stop() 320 321 resetTimeout := func() { 322 if plist.Front() == nil || nextTimeout == plist.Front().Value { 323 return 324 } 325 // Start the timer so it fires when the next pending reply has expired. 326 now := time.Now() 327 for el := plist.Front(); el != nil; el = el.Next() { 328 nextTimeout = el.Value.(*pending) 329 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 330 timeout.Reset(dist) 331 return 332 } 333 // Remove pending replies whose deadline is too far in the 334 // future. These can occur if the system clock jumped 335 // backwards after the deadline was assigned. 336 nextTimeout.errc <- errClockWarp 337 plist.Remove(el) 338 } 339 nextTimeout = nil 340 timeout.Stop() 341 } 342 343 for { 344 resetTimeout() 345 346 select { 347 case <-refresh.C: 348 go t.refresh() 349 350 case <-t.closing: 351 for el := plist.Front(); el != nil; el = el.Next() { 352 el.Value.(*pending).errc <- errClosed 353 } 354 return 355 356 case p := <-t.addpending: 357 p.deadline = time.Now().Add(respTimeout) 358 plist.PushBack(p) 359 360 case r := <-t.gotreply: 361 var matched bool 362 for el := plist.Front(); el != nil; el = el.Next() { 363 p := el.Value.(*pending) 364 if p.from == r.from && p.ptype == r.ptype { 365 matched = true 366 // Remove the matcher if its callback indicates 367 // that all replies have been received. This is 368 // required for packet types that expect multiple 369 // reply packets. 370 if p.callback(r.data) { 371 p.errc <- nil 372 plist.Remove(el) 373 } 374 } 375 } 376 r.matched <- matched 377 378 case now := <-timeout.C: 379 nextTimeout = nil 380 // Notify and remove callbacks whose deadline is in the past. 381 for el := plist.Front(); el != nil; el = el.Next() { 382 p := el.Value.(*pending) 383 if now.After(p.deadline) || now.Equal(p.deadline) { 384 p.errc <- errTimeout 385 plist.Remove(el) 386 } 387 } 388 } 389 } 390 } 391 392 const ( 393 macSize = 256 / 8 394 sigSize = 520 / 8 395 headSize = macSize + sigSize // space of packet frame data 396 ) 397 398 var ( 399 headSpace = make([]byte, headSize) 400 401 // Neighbors replies are sent across multiple packets to 402 // stay below the 1280 byte limit. We compute the maximum number 403 // of entries by stuffing a packet until it grows too large. 404 maxNeighbors int 405 ) 406 407 func init() { 408 p := neighbors{Expiration: ^uint64(0)} 409 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} 410 for n := 0; ; n++ { 411 p.Nodes = append(p.Nodes, maxSizeNode) 412 size, _, err := rlp.EncodeToReader(p) 413 if err != nil { 414 // If this ever happens, it will be caught by the unit tests. 415 panic("cannot encode: " + err.Error()) 416 } 417 if headSize+size+1 >= 1280 { 418 maxNeighbors = n 419 break 420 } 421 } 422 } 423 424 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error { 425 packet, err := encodePacket(t.priv, ptype, req) 426 if err != nil { 427 return err 428 } 429 glog.V(logger.Detail).Infof(">>> %v %T\n", toaddr, req) 430 if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil { 431 glog.V(logger.Detail).Infoln("UDP send failed:", err) 432 } 433 return err 434 } 435 436 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) { 437 b := new(bytes.Buffer) 438 b.Write(headSpace) 439 b.WriteByte(ptype) 440 if err := rlp.Encode(b, req); err != nil { 441 glog.V(logger.Error).Infoln("error encoding packet:", err) 442 return nil, err 443 } 444 packet := b.Bytes() 445 sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv) 446 if err != nil { 447 glog.V(logger.Error).Infoln("could not sign packet:", err) 448 return nil, err 449 } 450 copy(packet[macSize:], sig) 451 // add the hash to the front. Note: this doesn't protect the 452 // packet in any way. Our public key will be part of this hash in 453 // The future. 454 copy(packet, crypto.Sha3(packet[macSize:])) 455 return packet, nil 456 } 457 458 type tempError interface { 459 Temporary() bool 460 } 461 462 // readLoop runs in its own goroutine. it handles incoming UDP packets. 463 func (t *udp) readLoop() { 464 defer t.conn.Close() 465 // Discovery packets are defined to be no larger than 1280 bytes. 466 // Packets larger than this size will be cut at the end and treated 467 // as invalid because their hash won't match. 468 buf := make([]byte, 1280) 469 for { 470 nbytes, from, err := t.conn.ReadFromUDP(buf) 471 if tempErr, ok := err.(tempError); ok && tempErr.Temporary() { 472 // Ignore temporary read errors. 473 glog.V(logger.Debug).Infof("Temporary read error: %v", err) 474 continue 475 } else if err != nil { 476 // Shut down the loop for permament errors. 477 glog.V(logger.Debug).Infof("Read error: %v", err) 478 return 479 } 480 t.handlePacket(from, buf[:nbytes]) 481 } 482 } 483 484 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 485 packet, fromID, hash, err := decodePacket(buf) 486 if err != nil { 487 glog.V(logger.Debug).Infof("Bad packet from %v: %v\n", from, err) 488 return err 489 } 490 status := "ok" 491 if err = packet.handle(t, from, fromID, hash); err != nil { 492 status = err.Error() 493 } 494 glog.V(logger.Detail).Infof("<<< %v %T: %s\n", from, packet, status) 495 return err 496 } 497 498 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 499 if len(buf) < headSize+1 { 500 return nil, NodeID{}, nil, errPacketTooSmall 501 } 502 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 503 shouldhash := crypto.Sha3(buf[macSize:]) 504 if !bytes.Equal(hash, shouldhash) { 505 return nil, NodeID{}, nil, errBadHash 506 } 507 fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig) 508 if err != nil { 509 return nil, NodeID{}, hash, err 510 } 511 var req packet 512 switch ptype := sigdata[0]; ptype { 513 case pingPacket: 514 req = new(ping) 515 case pongPacket: 516 req = new(pong) 517 case findnodePacket: 518 req = new(findnode) 519 case neighborsPacket: 520 req = new(neighbors) 521 default: 522 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 523 } 524 err = rlp.DecodeBytes(sigdata[1:], req) 525 return req, fromID, hash, err 526 } 527 528 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 529 if expired(req.Expiration) { 530 return errExpired 531 } 532 if req.Version != Version { 533 return errBadVersion 534 } 535 t.send(from, pongPacket, pong{ 536 To: makeEndpoint(from, req.From.TCP), 537 ReplyTok: mac, 538 Expiration: uint64(time.Now().Add(expiration).Unix()), 539 }) 540 if !t.handleReply(fromID, pingPacket, req) { 541 // Note: we're ignoring the provided IP address right now 542 go t.bond(true, fromID, from, req.From.TCP) 543 } 544 return nil 545 } 546 547 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 548 if expired(req.Expiration) { 549 return errExpired 550 } 551 if !t.handleReply(fromID, pongPacket, req) { 552 return errUnsolicitedReply 553 } 554 return nil 555 } 556 557 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 558 if expired(req.Expiration) { 559 return errExpired 560 } 561 if t.db.node(fromID) == nil { 562 // No bond exists, we don't process the packet. This prevents 563 // an attack vector where the discovery protocol could be used 564 // to amplify traffic in a DDOS attack. A malicious actor 565 // would send a findnode request with the IP address and UDP 566 // port of the target as the source address. The recipient of 567 // the findnode packet would then send a neighbors packet 568 // (which is a much bigger packet than findnode) to the victim. 569 return errUnknownNode 570 } 571 target := crypto.Sha3Hash(req.Target[:]) 572 t.mutex.Lock() 573 closest := t.closest(target, bucketSize).entries 574 t.mutex.Unlock() 575 576 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 577 // Send neighbors in chunks with at most maxNeighbors per packet 578 // to stay below the 1280 byte limit. 579 for i, n := range closest { 580 p.Nodes = append(p.Nodes, nodeToRPC(n)) 581 if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { 582 t.send(from, neighborsPacket, p) 583 p.Nodes = p.Nodes[:0] 584 } 585 } 586 return nil 587 } 588 589 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 590 if expired(req.Expiration) { 591 return errExpired 592 } 593 if !t.handleReply(fromID, neighborsPacket, req) { 594 return errUnsolicitedReply 595 } 596 return nil 597 } 598 599 func expired(ts uint64) bool { 600 return time.Unix(int64(ts), 0).Before(time.Now()) 601 }