github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/discover/udp.go (about) 1 package discover 2 3 import ( 4 "bytes" 5 "container/list" 6 "crypto/ecdsa" 7 "errors" 8 "fmt" 9 "net" 10 "time" 11 12 "github.com/neatio-net/neatio/chain/log" 13 "github.com/neatio-net/neatio/network/p2p/nat" 14 "github.com/neatio-net/neatio/network/p2p/netutil" 15 "github.com/neatio-net/neatio/utilities/crypto" 16 "github.com/neatio-net/neatio/utilities/rlp" 17 ) 18 19 const Version = 4 20 21 var ( 22 errPacketTooSmall = errors.New("too small") 23 errBadHash = errors.New("bad hash") 24 errExpired = errors.New("expired") 25 errUnsolicitedReply = errors.New("unsolicited reply") 26 errUnknownNode = errors.New("unknown node") 27 errTimeout = errors.New("RPC timeout") 28 errClockWarp = errors.New("reply deadline too far in the future") 29 errClosed = errors.New("socket closed") 30 ) 31 32 const ( 33 respTimeout = 500 * time.Millisecond 34 sendTimeout = 500 * time.Millisecond 35 expiration = 20 * time.Second 36 37 ntpFailureThreshold = 32 38 ntpWarningCooldown = 10 * time.Minute 39 driftThreshold = 10 * time.Second 40 ) 41 42 const ( 43 pingPacket = iota + 1 44 pongPacket 45 findnodePacket 46 neighborsPacket 47 ) 48 49 type ( 50 ping struct { 51 Version uint 52 From, To rpcEndpoint 53 Expiration uint64 54 55 Rest []rlp.RawValue `rlp:"tail"` 56 } 57 58 pong struct { 59 To rpcEndpoint 60 61 ReplyTok []byte 62 Expiration uint64 63 64 Rest []rlp.RawValue `rlp:"tail"` 65 } 66 67 findnode struct { 68 Target NodeID 69 Expiration uint64 70 71 Rest []rlp.RawValue `rlp:"tail"` 72 } 73 74 neighbors struct { 75 Nodes []rpcNode 76 Expiration uint64 77 78 Rest []rlp.RawValue `rlp:"tail"` 79 } 80 81 rpcNode struct { 82 IP net.IP 83 UDP uint16 84 TCP uint16 85 ID NodeID 86 } 87 88 rpcEndpoint struct { 89 IP net.IP 90 UDP uint16 91 TCP uint16 92 } 93 ) 94 95 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { 96 ip := addr.IP.To4() 97 if ip == nil { 98 ip = addr.IP.To16() 99 } 100 return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} 101 } 102 103 func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { 104 if rn.UDP <= 1024 { 105 return nil, errors.New("low port") 106 } 107 if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { 108 return nil, err 109 } 110 if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { 111 return nil, errors.New("not contained in netrestrict whitelist") 112 } 113 n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) 114 err := n.validateComplete() 115 return n, err 116 } 117 118 func nodeToRPC(n *Node) rpcNode { 119 return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP} 120 } 121 122 type packet interface { 123 handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error 124 name() string 125 } 126 127 type conn interface { 128 ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) 129 WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) 130 Close() error 131 LocalAddr() net.Addr 132 } 133 134 type udp struct { 135 conn conn 136 netrestrict *netutil.Netlist 137 priv *ecdsa.PrivateKey 138 ourEndpoint rpcEndpoint 139 140 addpending chan *pending 141 gotreply chan reply 142 143 closing chan struct{} 144 nat nat.Interface 145 146 *Table 147 } 148 149 type pending struct { 150 from NodeID 151 ptype byte 152 153 deadline time.Time 154 155 callback func(resp interface{}) (done bool) 156 157 errc chan<- error 158 } 159 160 type reply struct { 161 from NodeID 162 ptype byte 163 data interface{} 164 165 matched chan<- bool 166 } 167 168 type ReadPacket struct { 169 Data []byte 170 Addr *net.UDPAddr 171 } 172 173 type Config struct { 174 PrivateKey *ecdsa.PrivateKey 175 176 AnnounceAddr *net.UDPAddr 177 NodeDBPath string 178 NetRestrict *netutil.Netlist 179 Bootnodes []*Node 180 Unhandled chan<- ReadPacket 181 } 182 183 func ListenUDP(c conn, cfg Config) (*Table, error) { 184 tab, _, err := newUDP(c, cfg) 185 if err != nil { 186 return nil, err 187 } 188 189 return tab, nil 190 } 191 192 func newUDP(c conn, cfg Config) (*Table, *udp, error) { 193 udp := &udp{ 194 conn: c, 195 priv: cfg.PrivateKey, 196 netrestrict: cfg.NetRestrict, 197 closing: make(chan struct{}), 198 gotreply: make(chan reply), 199 addpending: make(chan *pending), 200 } 201 realaddr := c.LocalAddr().(*net.UDPAddr) 202 if cfg.AnnounceAddr != nil { 203 realaddr = cfg.AnnounceAddr 204 } 205 206 udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port)) 207 tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes) 208 if err != nil { 209 return nil, nil, err 210 } 211 udp.Table = tab 212 213 go udp.loop() 214 go udp.readLoop(cfg.Unhandled) 215 return udp.Table, udp, nil 216 } 217 218 func (t *udp) close() { 219 close(t.closing) 220 t.conn.Close() 221 222 } 223 224 func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error { 225 req := &ping{ 226 Version: Version, 227 From: t.ourEndpoint, 228 To: makeEndpoint(toaddr, 0), 229 Expiration: uint64(time.Now().Add(expiration).Unix()), 230 } 231 packet, hash, err := encodePacket(t.priv, pingPacket, req) 232 if err != nil { 233 return err 234 } 235 errc := t.pending(toid, pongPacket, func(p interface{}) bool { 236 return bytes.Equal(p.(*pong).ReplyTok, hash) 237 }) 238 t.write(toaddr, req.name(), packet) 239 return <-errc 240 } 241 242 func (t *udp) waitping(from NodeID) error { 243 return <-t.pending(from, pingPacket, func(interface{}) bool { return true }) 244 } 245 246 func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) { 247 nodes := make([]*Node, 0, bucketSize) 248 nreceived := 0 249 errc := t.pending(toid, neighborsPacket, func(r interface{}) bool { 250 reply := r.(*neighbors) 251 for _, rn := range reply.Nodes { 252 nreceived++ 253 n, err := t.nodeFromRPC(toaddr, rn) 254 if err != nil { 255 log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) 256 continue 257 } 258 nodes = append(nodes, n) 259 } 260 return nreceived >= bucketSize 261 }) 262 t.send(toaddr, findnodePacket, &findnode{ 263 Target: target, 264 Expiration: uint64(time.Now().Add(expiration).Unix()), 265 }) 266 err := <-errc 267 return nodes, err 268 } 269 270 func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error { 271 ch := make(chan error, 1) 272 p := &pending{from: id, ptype: ptype, callback: callback, errc: ch} 273 select { 274 case t.addpending <- p: 275 276 case <-t.closing: 277 ch <- errClosed 278 } 279 return ch 280 } 281 282 func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool { 283 matched := make(chan bool, 1) 284 select { 285 case t.gotreply <- reply{from, ptype, req, matched}: 286 287 return <-matched 288 case <-t.closing: 289 return false 290 } 291 } 292 293 func (t *udp) loop() { 294 var ( 295 plist = list.New() 296 timeout = time.NewTimer(0) 297 nextTimeout *pending 298 contTimeouts = 0 299 ntpWarnTime = time.Unix(0, 0) 300 ) 301 <-timeout.C 302 defer timeout.Stop() 303 304 resetTimeout := func() { 305 if plist.Front() == nil || nextTimeout == plist.Front().Value { 306 return 307 } 308 309 now := time.Now() 310 for el := plist.Front(); el != nil; el = el.Next() { 311 nextTimeout = el.Value.(*pending) 312 if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout { 313 timeout.Reset(dist) 314 return 315 } 316 317 nextTimeout.errc <- errClockWarp 318 plist.Remove(el) 319 } 320 nextTimeout = nil 321 timeout.Stop() 322 } 323 324 for { 325 resetTimeout() 326 327 select { 328 case <-t.closing: 329 for el := plist.Front(); el != nil; el = el.Next() { 330 el.Value.(*pending).errc <- errClosed 331 } 332 return 333 334 case p := <-t.addpending: 335 p.deadline = time.Now().Add(respTimeout) 336 plist.PushBack(p) 337 338 case r := <-t.gotreply: 339 var matched bool 340 for el := plist.Front(); el != nil; el = el.Next() { 341 p := el.Value.(*pending) 342 if p.from == r.from && p.ptype == r.ptype { 343 matched = true 344 345 if p.callback(r.data) { 346 p.errc <- nil 347 plist.Remove(el) 348 } 349 350 contTimeouts = 0 351 } 352 } 353 r.matched <- matched 354 355 case now := <-timeout.C: 356 nextTimeout = nil 357 358 for el := plist.Front(); el != nil; el = el.Next() { 359 p := el.Value.(*pending) 360 if now.After(p.deadline) || now.Equal(p.deadline) { 361 p.errc <- errTimeout 362 plist.Remove(el) 363 contTimeouts++ 364 } 365 } 366 367 if contTimeouts > ntpFailureThreshold { 368 if time.Since(ntpWarnTime) >= ntpWarningCooldown { 369 ntpWarnTime = time.Now() 370 go checkClockDrift() 371 } 372 contTimeouts = 0 373 } 374 } 375 } 376 } 377 378 const ( 379 macSize = 256 / 8 380 sigSize = 520 / 8 381 headSize = macSize + sigSize 382 ) 383 384 var ( 385 headSpace = make([]byte, headSize) 386 387 maxNeighbors int 388 ) 389 390 func init() { 391 p := neighbors{Expiration: ^uint64(0)} 392 maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)} 393 for n := 0; ; n++ { 394 p.Nodes = append(p.Nodes, maxSizeNode) 395 size, _, err := rlp.EncodeToReader(p) 396 if err != nil { 397 398 panic("cannot encode: " + err.Error()) 399 } 400 if headSize+size+1 >= 1280 { 401 maxNeighbors = n 402 break 403 } 404 } 405 } 406 407 func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) { 408 packet, hash, err := encodePacket(t.priv, ptype, req) 409 if err != nil { 410 return hash, err 411 } 412 return hash, t.write(toaddr, req.name(), packet) 413 } 414 415 func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error { 416 _, err := t.conn.WriteToUDP(packet, toaddr) 417 log.Trace(">> "+what, "addr", toaddr, "err", err) 418 return err 419 } 420 421 func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) { 422 b := new(bytes.Buffer) 423 b.Write(headSpace) 424 b.WriteByte(ptype) 425 if err := rlp.Encode(b, req); err != nil { 426 log.Error("Can't encode discv4 packet", "err", err) 427 return nil, nil, err 428 } 429 packet = b.Bytes() 430 sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv) 431 if err != nil { 432 log.Error("Can't sign discv4 packet", "err", err) 433 return nil, nil, err 434 } 435 copy(packet[macSize:], sig) 436 437 hash = crypto.Keccak256(packet[macSize:]) 438 copy(packet, hash) 439 return packet, hash, nil 440 } 441 442 func (t *udp) readLoop(unhandled chan<- ReadPacket) { 443 defer t.conn.Close() 444 if unhandled != nil { 445 defer close(unhandled) 446 } 447 448 buf := make([]byte, 1280) 449 for { 450 nbytes, from, err := t.conn.ReadFromUDP(buf) 451 if netutil.IsTemporaryError(err) { 452 453 log.Debug("Temporary UDP read error", "err", err) 454 continue 455 } else if err != nil { 456 457 log.Debug("UDP read error", "err", err) 458 return 459 } 460 if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil { 461 select { 462 case unhandled <- ReadPacket{buf[:nbytes], from}: 463 default: 464 } 465 } 466 } 467 } 468 469 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { 470 packet, fromID, hash, err := decodePacket(buf) 471 if err != nil { 472 log.Debug("Bad discv4 packet", "addr", from, "err", err) 473 return err 474 } 475 err = packet.handle(t, from, fromID, hash) 476 log.Trace("<< "+packet.name(), "addr", from, "err", err) 477 return err 478 } 479 480 func decodePacket(buf []byte) (packet, NodeID, []byte, error) { 481 if len(buf) < headSize+1 { 482 return nil, NodeID{}, nil, errPacketTooSmall 483 } 484 hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:] 485 shouldhash := crypto.Keccak256(buf[macSize:]) 486 if !bytes.Equal(hash, shouldhash) { 487 return nil, NodeID{}, nil, errBadHash 488 } 489 fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig) 490 if err != nil { 491 return nil, NodeID{}, hash, err 492 } 493 var req packet 494 switch ptype := sigdata[0]; ptype { 495 case pingPacket: 496 req = new(ping) 497 case pongPacket: 498 req = new(pong) 499 case findnodePacket: 500 req = new(findnode) 501 case neighborsPacket: 502 req = new(neighbors) 503 default: 504 return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype) 505 } 506 s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0) 507 err = s.Decode(req) 508 return req, fromID, hash, err 509 } 510 511 func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 512 if expired(req.Expiration) { 513 return errExpired 514 } 515 t.send(from, pongPacket, &pong{ 516 To: makeEndpoint(from, req.From.TCP), 517 ReplyTok: mac, 518 Expiration: uint64(time.Now().Add(expiration).Unix()), 519 }) 520 if !t.handleReply(fromID, pingPacket, req) { 521 522 go t.bond(true, fromID, from, req.From.TCP) 523 } 524 return nil 525 } 526 527 func (req *ping) name() string { return "PING/v4" } 528 529 func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 530 if expired(req.Expiration) { 531 return errExpired 532 } 533 if !t.handleReply(fromID, pongPacket, req) { 534 return errUnsolicitedReply 535 } 536 return nil 537 } 538 539 func (req *pong) name() string { return "PONG/v4" } 540 541 func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 542 if expired(req.Expiration) { 543 return errExpired 544 } 545 if !t.db.hasBond(fromID) { 546 547 return errUnknownNode 548 } 549 target := crypto.Keccak256Hash(req.Target[:]) 550 t.mutex.Lock() 551 closest := t.closest(target, bucketSize).entries 552 t.mutex.Unlock() 553 554 p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} 555 var sent bool 556 557 for _, n := range closest { 558 if netutil.CheckRelayIP(from.IP, n.IP) == nil { 559 p.Nodes = append(p.Nodes, nodeToRPC(n)) 560 } 561 if len(p.Nodes) == maxNeighbors { 562 t.send(from, neighborsPacket, &p) 563 p.Nodes = p.Nodes[:0] 564 sent = true 565 } 566 } 567 if len(p.Nodes) > 0 || !sent { 568 t.send(from, neighborsPacket, &p) 569 } 570 return nil 571 } 572 573 func (req *findnode) name() string { return "FINDNODE/v4" } 574 575 func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error { 576 if expired(req.Expiration) { 577 return errExpired 578 } 579 if !t.handleReply(fromID, neighborsPacket, req) { 580 return errUnsolicitedReply 581 } 582 return nil 583 } 584 585 func (req *neighbors) name() string { return "NEIGHBORS/v4" } 586 587 func expired(ts uint64) bool { 588 return time.Unix(int64(ts), 0).Before(time.Now()) 589 }