github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/p2p/discover/dht/net.go

package dht

import (
	"bytes"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"time"

	log "github.com/sirupsen/logrus"
	"github.com/tendermint/go-wire"
	"golang.org/x/crypto/sha3"

	"github.com/bytom/bytom/common"
	"github.com/bytom/bytom/crypto/ed25519"
	"github.com/bytom/bytom/p2p/netutil"
)

var (
	errInvalidEvent = errors.New("invalid in current state")
	errNoQuery      = errors.New("no pending query")
	errWrongAddress = errors.New("unknown sender address")
)

const (
	autoRefreshInterval   = 1 * time.Hour
	bucketRefreshInterval = 1 * time.Minute
	seedCount             = 30
	seedMaxAge            = 5 * 24 * time.Hour
	lowPort               = 1024
)

const (
	printTestImgLogs = false
)

// Network manages the table and all protocol interaction.
type Network struct {
	db          *nodeDB // database of known nodes
	conn        transport
	netrestrict *netutil.Netlist

	closed           chan struct{}          // closed when loop is done
	closeReq         chan struct{}          // 'request to close'
	refreshReq       chan []*Node           // lookups ask for refresh on this channel
	refreshResp      chan (<-chan struct{}) // ...and get the channel to block on from this one
	read             chan ingressPacket     // ingress packets arrive here
	timeout          chan timeoutEvent
	queryReq         chan *findnodeQuery // lookups submit findnode queries on this channel
	tableOpReq       chan func()
	tableOpResp      chan struct{}
	topicRegisterReq chan topicRegisterReq
	topicSearchReq   chan topicSearchReq

	// State of the main loop.
	tab           *Table
	topictab      *topicTable
	ticketStore   *ticketStore
	nursery       []*Node
	nodes         map[NodeID]*Node // tracks active nodes with state != known
	timeoutTimers map[timeoutEvent]*time.Timer

	// Revalidation queues.
	// Nodes put on these queues will be pinged eventually.
	slowRevalidateQueue []*Node
	fastRevalidateQueue []*Node

	// Buffers for state transition.
	sendBuf []*ingressPacket
}
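// A minimal lifecycle sketch (illustrative only; the transport value, the
// seed list and the error handling are assumptions, not code from this
// package):
//
//	ntab, err := newNetwork(conn, pubkey, "nodes.db", nil)
//	if err != nil {
//		return err
//	}
//	defer ntab.Close()
//	if err := ntab.SetFallbackNodes(seeds); err != nil {
//		return err
//	}
//	closest := ntab.Lookup(targetID)
//
// newNetwork starts the loop goroutine; Close shuts it down and flushes
// the node database.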
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
	sendPing(remote *Node, remoteAddr *net.UDPAddr, topics []Topic) (hash []byte)
	sendNeighbours(remote *Node, nodes []*Node)
	sendFindnodeHash(remote *Node, target common.Hash)
	sendTopicRegister(remote *Node, topics []Topic, topicIdx int, pong []byte)
	sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)

	send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte)

	localAddr() *net.UDPAddr
	Close()
}

type findnodeQuery struct {
	remote   *Node
	target   common.Hash
	reply    chan<- []*Node
	nresults int // counter for received nodes
}

type topicRegisterReq struct {
	add   bool
	topic Topic
}

type topicSearchReq struct {
	topic  Topic
	found  chan<- *Node
	lookup chan<- bool
	delay  time.Duration
}

type topicSearchResult struct {
	target lookupInfo
	nodes  []*Node
}

type timeoutEvent struct {
	ev   nodeEvent
	node *Node
}

func newNetwork(conn transport, ourPubkey ed25519.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
	var ourID NodeID
	copy(ourID[:], ourPubkey[:nodeIDBits])

	var db *nodeDB
	if dbPath != "<no database>" {
		var err error
		if db, err = newNodeDB(dbPath, Version, ourID); err != nil {
			return nil, err
		}
	}

	tab := newTable(ourID, conn.localAddr())
	net := &Network{
		db:               db,
		conn:             conn,
		netrestrict:      netrestrict,
		tab:              tab,
		topictab:         newTopicTable(db, tab.self),
		ticketStore:      newTicketStore(),
		refreshReq:       make(chan []*Node),
		refreshResp:      make(chan (<-chan struct{})),
		closed:           make(chan struct{}),
		closeReq:         make(chan struct{}),
		read:             make(chan ingressPacket, 100),
		timeout:          make(chan timeoutEvent),
		timeoutTimers:    make(map[timeoutEvent]*time.Timer),
		tableOpReq:       make(chan func()),
		tableOpResp:      make(chan struct{}),
		queryReq:         make(chan *findnodeQuery),
		topicRegisterReq: make(chan topicRegisterReq),
		topicSearchReq:   make(chan topicSearchReq),
		nodes:            make(map[NodeID]*Node),
	}
	go net.loop()
	return net, nil
}

// Close terminates the network listener and flushes the node database.
func (net *Network) Close() {
	net.conn.Close()
	select {
	case <-net.closed:
	case net.closeReq <- struct{}{}:
		<-net.closed
	}
}

// Self returns the local node.
// The returned node should not be modified by the caller.
func (net *Network) Self() *Node {
	return net.tab.self
}

func (net *Network) selfIP() net.IP {
	return net.tab.self.IP
}

// ReadRandomNodes fills the given slice with random nodes from the
// table. It will not write the same node more than once. The nodes in
// the slice are copies and can be modified by the caller.
func (net *Network) ReadRandomNodes(buf []*Node) (n int) {
	net.reqTableOp(func() { n = net.tab.readRandomNodes(buf) })
	return n
}
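// Usage sketch for ReadRandomNodes (illustrative):
//
//	buf := make([]*Node, 16)
//	n := ntab.ReadRandomNodes(buf)
//	for _, node := range buf[:n] {
//		// node is a copy and safe to modify
//	}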
// SetFallbackNodes sets the initial points of contact. These nodes
// are used to connect to the network if the table is empty and there
// are no known nodes in the database.
func (net *Network) SetFallbackNodes(nodes []*Node) error {
	nursery := make([]*Node, 0, len(nodes))
	for _, n := range nodes {
		if err := n.validateComplete(); err != nil {
			return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
		}
		// Recompute cpy.sha because the node might not have been
		// created by NewNode or ParseNode.
		cpy := *n
		cpy.sha = common.BytesToHash(n.ID[:])
		nursery = append(nursery, &cpy)
	}
	net.reqRefresh(nursery)
	return nil
}

// Resolve searches for a specific node with the given ID.
// It returns nil if the node could not be found.
func (net *Network) Resolve(targetID NodeID) *Node {
	result := net.lookup(common.BytesToHash(targetID[:]), true)
	for _, n := range result {
		if n.ID == targetID {
			return n
		}
	}
	return nil
}

// Lookup performs a network search for nodes close
// to the given target. It approaches the target by querying
// nodes that are closer to it on each iteration.
// The given target does not need to be an actual node
// identifier.
//
// The local node may be included in the result.
func (net *Network) Lookup(targetID NodeID) []*Node {
	return net.lookup(common.BytesToHash(targetID[:]), false)
}

func (net *Network) lookup(target common.Hash, stopOnMatch bool) []*Node {
	var (
		asked          = make(map[NodeID]bool)
		seen           = make(map[NodeID]bool)
		reply          = make(chan []*Node, alpha)
		result         = nodesByDistance{target: target}
		pendingQueries = 0
	)
	// Get initial answers from the local node.
	result.push(net.tab.self, bucketSize)
	for {
		// Ask the α closest nodes that we haven't asked yet.
		for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
			n := result.entries[i]
			if !asked[n.ID] {
				asked[n.ID] = true
				pendingQueries++
				net.reqQueryFindnode(n, target, reply)
			}
		}
		if pendingQueries == 0 {
			// We have asked all closest nodes, stop the search.
			break
		}
		// Wait for the next reply.
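		// A reply decrements pendingQueries; a timeout abandons every
		// in-flight query by resetting the counter and swapping in a
		// fresh reply channel, so late answers are silently dropped.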
		select {
		case nodes := <-reply:
			for _, n := range nodes {
				if n != nil && !seen[n.ID] {
					seen[n.ID] = true
					result.push(n, bucketSize)
					if stopOnMatch && n.sha == target {
						return result.entries
					}
				}
			}
			pendingQueries--
		case <-time.After(respTimeout):
			// forget all pending requests, start new ones
			pendingQueries = 0
			reply = make(chan []*Node, alpha)
		}
	}
	return result.entries
}

func (net *Network) RegisterTopic(topic Topic, stop <-chan struct{}) {
	select {
	case net.topicRegisterReq <- topicRegisterReq{true, topic}:
	case <-net.closed:
		return
	}
	select {
	case <-net.closed:
	case <-stop:
		select {
		case net.topicRegisterReq <- topicRegisterReq{false, topic}:
		case <-net.closed:
		}
	}
}

func (net *Network) SearchTopic(topic Topic, setPeriod <-chan time.Duration, found chan<- *Node, lookup chan<- bool) {
	for {
		select {
		case <-net.closed:
			return
		case delay, ok := <-setPeriod:
			select {
			case net.topicSearchReq <- topicSearchReq{topic: topic, found: found, lookup: lookup, delay: delay}:
			case <-net.closed:
				return
			}
			if !ok {
				return
			}
		}
	}
}

func (net *Network) reqRefresh(nursery []*Node) <-chan struct{} {
	select {
	case net.refreshReq <- nursery:
		return <-net.refreshResp
	case <-net.closed:
		return net.closed
	}
}

func (net *Network) reqQueryFindnode(n *Node, target common.Hash, reply chan []*Node) bool {
	q := &findnodeQuery{remote: n, target: target, reply: reply}
	select {
	case net.queryReq <- q:
		return true
	case <-net.closed:
		return false
	}
}

func (net *Network) reqReadPacket(pkt ingressPacket) {
	select {
	case net.read <- pkt:
	case <-net.closed:
	}
}

func (net *Network) reqTableOp(f func()) (called bool) {
	select {
	case net.tableOpReq <- f:
		<-net.tableOpResp
		return true
	case <-net.closed:
		return false
	}
}

// TODO: external address handling.

type topicSearchInfo struct {
	lookupChn chan<- bool
	period    time.Duration
}

const maxSearchCount = 5

func (net *Network) loop() {
	var (
		refreshTimer       = time.NewTicker(autoRefreshInterval)
		bucketRefreshTimer = time.NewTimer(bucketRefreshInterval)
		refreshDone        chan struct{} // closed when the 'refresh' lookup has ended
	)

	// Tracking the next ticket to register.
	var (
		nextTicket        *ticketRef
		nextRegisterTimer *time.Timer
		nextRegisterTime  <-chan time.Time
	)
	defer func() {
		if nextRegisterTimer != nil {
			nextRegisterTimer.Stop()
		}
		refreshTimer.Stop()
		bucketRefreshTimer.Stop()
	}()
	resetNextTicket := func() {
		ticket, timeout := net.ticketStore.nextFilteredTicket()
		if nextTicket != ticket {
			nextTicket = ticket
			if nextRegisterTimer != nil {
				nextRegisterTimer.Stop()
				nextRegisterTime = nil
			}
			if ticket != nil {
				nextRegisterTimer = time.NewTimer(timeout)
				nextRegisterTime = nextRegisterTimer.C
			}
		}
	}

	// Tracking registration and search lookups.
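	// topicRegisterLookupTarget is the lookup currently feeding the
	// ticket store; topicRegisterLookupTick paces successive register
	// lookups. searchReqWhenRefreshDone parks incoming search requests
	// until an in-progress table refresh has finished.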
	var (
		topicRegisterLookupTarget lookupInfo
		topicRegisterLookupDone   chan []*Node
		topicRegisterLookupTick   = time.NewTimer(0)
		searchReqWhenRefreshDone  []topicSearchReq
		searchInfo                = make(map[Topic]topicSearchInfo)
		activeSearchCount         int
	)
	topicSearchLookupDone := make(chan topicSearchResult, 100)
	topicSearch := make(chan Topic, 100)
	<-topicRegisterLookupTick.C

	statsDump := time.NewTicker(10 * time.Second)
	defer statsDump.Stop()

loop:
	for {
		resetNextTicket()

		select {
		case <-net.closeReq:
			log.WithFields(log.Fields{"module": logModule}).Debug("close request")
			break loop

		// Ingress packet handling.
		case pkt := <-net.read:
			log.WithFields(log.Fields{"module": logModule}).Debug("read from net")
			n := net.internNode(&pkt)
			prestate := n.state
			status := "ok"
			if err := net.handle(n, pkt.ev, &pkt); err != nil {
				status = err.Error()
			}
			log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": pkt.ev, "remote id": hex.EncodeToString(pkt.remoteID[:8]), "remote addr": pkt.remoteAddr, "pre state": prestate, "node state": n.state, "status": status}).Debug("handle ingress msg")

			// TODO: persist state if n.state goes >= known, delete if it goes <= known

		// State transition timeouts.
		case timeout := <-net.timeout:
			log.WithFields(log.Fields{"module": logModule}).Debug("net timeout")
			if net.timeoutTimers[timeout] == nil {
				// Stale timer (was aborted).
				continue
			}
			delete(net.timeoutTimers, timeout)
			prestate := timeout.node.state
			status := "ok"
			if err := net.handle(timeout.node, timeout.ev, nil); err != nil {
				status = err.Error()
			}
			log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": timeout.ev, "node id": hex.EncodeToString(timeout.node.ID[:8]), "node addr": timeout.node.addr(), "pre state": prestate, "node state": timeout.node.state, "status": status}).Debug("handle timeout")

		// Querying.
		case q := <-net.queryReq:
			log.WithFields(log.Fields{"module": logModule}).Debug("net query request")
			if !q.start(net) {
				q.remote.deferQuery(q)
			}

		// Interacting with the table.
		case f := <-net.tableOpReq:
			log.WithFields(log.Fields{"module": logModule}).Debug("net table operate request")
			f()
			net.tableOpResp <- struct{}{}

		// Topic registration stuff.
		case req := <-net.topicRegisterReq:
			log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic register request")
			if !req.add {
				net.ticketStore.removeRegisterTopic(req.topic)
				continue
			}
			net.ticketStore.addTopic(req.topic, true)
			// If we're currently waiting idle (nothing to look up), give the ticket store a
			// chance to start it sooner. This should speed up convergence of the radius
			// determination for new topics.
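			// A zero-valued lookup target means the register-lookup
			// machinery is idle, so we restart the timer immediately
			// instead of waiting out the previous delay.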
			// if topicRegisterLookupDone == nil {
			if topicRegisterLookupTarget.target == (common.Hash{}) {
				log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("topic register lookup target null")
				if topicRegisterLookupTick.Stop() {
					<-topicRegisterLookupTick.C
				}
				target, delay := net.ticketStore.nextRegisterLookup()
				topicRegisterLookupTarget = target
				topicRegisterLookupTick.Reset(delay)
			}

		case nodes := <-topicRegisterLookupDone:
			log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup done")
			net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte {
				net.ping(n, n.addr())
				return n.pingEcho
			})
			target, delay := net.ticketStore.nextRegisterLookup()
			topicRegisterLookupTarget = target
			topicRegisterLookupTick.Reset(delay)
			topicRegisterLookupDone = nil

		case <-topicRegisterLookupTick.C:
			log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup tick")
			if (topicRegisterLookupTarget.target == common.Hash{}) {
				target, delay := net.ticketStore.nextRegisterLookup()
				topicRegisterLookupTarget = target
				topicRegisterLookupTick.Reset(delay)
				topicRegisterLookupDone = nil
			} else {
				topicRegisterLookupDone = make(chan []*Node)
				target := topicRegisterLookupTarget.target
				go func() { topicRegisterLookupDone <- net.lookup(target, false) }()
			}

		case <-nextRegisterTime:
			log.WithFields(log.Fields{"module": logModule}).Debug("next register time")
			net.ticketStore.ticketRegistered(*nextTicket)
			net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)

		case req := <-net.topicSearchReq:
			if refreshDone == nil {
				log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic search req")
				info, ok := searchInfo[req.topic]
				if ok {
					if req.delay == time.Duration(0) {
						delete(searchInfo, req.topic)
						net.ticketStore.removeSearchTopic(req.topic)
					} else {
						info.period = req.delay
						searchInfo[req.topic] = info
					}
					continue
				}
				if req.delay != time.Duration(0) {
					var info topicSearchInfo
					info.period = req.delay
					info.lookupChn = req.lookup
					searchInfo[req.topic] = info
					net.ticketStore.addSearchTopic(req.topic, req.found)
					topicSearch <- req.topic
				}
			} else {
				searchReqWhenRefreshDone = append(searchReqWhenRefreshDone, req)
			}

		case topic := <-topicSearch:
			if activeSearchCount < maxSearchCount {
				activeSearchCount++
				target := net.ticketStore.nextSearchLookup(topic)
				go func() {
					nodes := net.lookup(target.target, false)
					topicSearchLookupDone <- topicSearchResult{target: target, nodes: nodes}
				}()
			}
			period := searchInfo[topic].period
			if period != time.Duration(0) {
				go func() {
					time.Sleep(period)
					topicSearch <- topic
				}()
			}

		case res := <-topicSearchLookupDone:
			activeSearchCount--
			if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil {
				lookupChn <- net.ticketStore.radius[res.target.topic].converged
			}
			net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte {
				if n.state != nil && n.state.canQuery {
					return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration
				}
				if n.state == unknown {
					net.ping(n, n.addr())
				}
				return nil
			})
		case <-statsDump.C:
			log.WithFields(log.Fields{"module": logModule}).Debug("stats dump clock")
			/*r, ok := net.ticketStore.radius[testTopic]
			if !ok {
				fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now())
			} else {
				topics := len(net.ticketStore.tickets)
				tickets := len(net.ticketStore.nodes)
				rad := r.radius / (maxRadius/10000+1)
				fmt.Printf("(%x) topics:%d radius:%d tickets:%d @ %v\n", net.tab.self.ID[:8], topics, rad, tickets, time.Now())
			}*/

			tm := Now()
			for topic, r := range net.ticketStore.radius {
				if printTestImgLogs {
					rad := r.radius / (maxRadius/1000000 + 1)
					minrad := r.minRadius / (maxRadius/1000000 + 1)
					log.WithFields(log.Fields{"module": logModule}).Debugf("*R %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], rad)
					log.WithFields(log.Fields{"module": logModule}).Debugf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad)
				}
			}
			for topic, t := range net.topictab.topics {
				wp := t.wcl.nextWaitPeriod(tm)
				if printTestImgLogs {
					log.WithFields(log.Fields{"module": logModule}).Debugf("*W %d %v %016x %d\n", tm/1000000, topic, net.tab.self.sha[:8], wp/1000000)
				}
			}

		// Periodic / lookup-initiated bucket refresh.
		case <-refreshTimer.C:
			log.WithFields(log.Fields{"module": logModule}).Debug("refresh timer clock")
			// TODO: ideally we would start the refresh timer after
			// fallback nodes have been set for the first time.
			if refreshDone == nil {
				refreshDone = make(chan struct{})
				net.refresh(refreshDone)
			}
		case <-bucketRefreshTimer.C:
			target := net.tab.chooseBucketRefreshTarget()
			go func() {
				net.lookup(target, false)
				bucketRefreshTimer.Reset(bucketRefreshInterval)
			}()
		case newNursery := <-net.refreshReq:
			log.WithFields(log.Fields{"module": logModule}).Debug("net refresh request")
			if newNursery != nil {
				net.nursery = newNursery
			}
			if refreshDone == nil {
				refreshDone = make(chan struct{})
				net.refresh(refreshDone)
			}
			net.refreshResp <- refreshDone
		case <-refreshDone:
			log.WithFields(log.Fields{"module": logModule, "table size": net.tab.count}).Debug("net refresh done")
			if net.tab.count != 0 {
				refreshDone = nil
				list := searchReqWhenRefreshDone
				searchReqWhenRefreshDone = nil
				go func() {
					for _, req := range list {
						net.topicSearchReq <- req
					}
				}()
			} else {
				refreshDone = make(chan struct{})
				net.refresh(refreshDone)
			}
		}
	}
	log.WithFields(log.Fields{"module": logModule}).Debug("loop stopped, shutting down")
	if net.conn != nil {
		net.conn.Close()
	}
	if refreshDone != nil {
		// TODO: wait for pending refresh.
		//<-refreshResults
	}
	// Cancel all pending timeouts.
	for _, timer := range net.timeoutTimers {
		timer.Stop()
	}
	if net.db != nil {
		net.db.close()
	}
	close(net.closed)
}

// Everything below runs on the Network.loop goroutine
// and can modify Node, Table and Network at any time without locking.
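// refresh seeds the table from the node database, falling back to the
// nursery (the nodes set via SetFallbackNodes). With no seeds available at
// all, it closes done after a 10s pause so the caller's refresh cycle can
// retry later.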
func (net *Network) refresh(done chan<- struct{}) {
	var seeds []*Node
	if net.db != nil {
		seeds = net.db.querySeeds(seedCount, seedMaxAge)
	}
	if len(seeds) == 0 {
		seeds = net.nursery
	}
	if len(seeds) == 0 {
		log.WithFields(log.Fields{"module": logModule}).Debug("no seed nodes found")
		time.AfterFunc(time.Second*10, func() { close(done) })
		return
	}
	for _, n := range seeds {
		n = net.internNodeFromDB(n)
		if n.state == unknown {
			net.transition(n, verifyinit)
		}
		// Force-add the seed node so Lookup does something.
		// It will be deleted again if verification fails.
		net.tab.add(n)
	}
	// Start self lookup to fill up the buckets.
	go func() {
		net.Lookup(net.tab.self.ID)
		close(done)
	}()
}

// Node Interning.

func (net *Network) internNode(pkt *ingressPacket) *Node {
	if n := net.nodes[pkt.remoteID]; n != nil {
		n.IP = pkt.remoteAddr.IP
		n.UDP = uint16(pkt.remoteAddr.Port)
		n.TCP = uint16(pkt.remoteAddr.Port)
		return n
	}
	n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port))
	n.state = unknown
	net.nodes[pkt.remoteID] = n
	return n
}

func (net *Network) internNodeFromDB(dbn *Node) *Node {
	if n := net.nodes[dbn.ID]; n != nil {
		return n
	}
	n := NewNode(dbn.ID, dbn.IP, dbn.UDP, dbn.TCP)
	n.state = unknown
	net.nodes[n.ID] = n
	return n
}

func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
	if rn.ID == net.tab.self.ID {
		return nil, errors.New("is self")
	}
	if rn.UDP <= lowPort {
		return nil, errors.New("low port")
	}
	n = net.nodes[rn.ID]
	if n == nil {
		// We haven't seen this node before.
		n, err = nodeFromRPC(sender, rn)
		if err != nil {
			return n, err
		}
		if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
			return n, errors.New("not contained in netrestrict whitelist")
		}
		n.state = unknown
		net.nodes[n.ID] = n
		return n, nil
	}
	if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP {
		if n.state == known {
			// reject address change if node is known by us
			err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
		} else {
			// accept otherwise; this will be handled nicer with signed ENRs
			n.IP = rn.IP
			n.UDP = rn.UDP
			n.TCP = rn.TCP
		}
	}
	return n, err
}

// nodeNetGuts is embedded in Node and contains the protocol state fields.
type nodeNetGuts struct {
	// This is a cached copy of sha3(ID) which is used for node
	// distance calculations. This is part of Node in order to make it
	// possible to write tests that need a node at a certain distance.
	// In those tests, the content of sha will not actually correspond
	// with ID.
	sha common.Hash

	// State machine fields. Access to these fields
	// is restricted to the Network.loop goroutine.
	state             *nodeState
	pingEcho          []byte           // hash of last ping sent by us
	pingTopics        []Topic          // topic set sent by us in last ping
	deferredQueries   []*findnodeQuery // queries that can't be sent yet
	pendingNeighbours *findnodeQuery   // current query, waiting for reply
	queryTimeouts     int
}

func (n *nodeNetGuts) deferQuery(q *findnodeQuery) {
	n.deferredQueries = append(n.deferredQueries, q)
}

func (n *nodeNetGuts) startNextQuery(net *Network) {
	if len(n.deferredQueries) == 0 {
		return
	}
	nextq := n.deferredQueries[0]
	if nextq.start(net) {
		n.deferredQueries = append(n.deferredQueries[:0], n.deferredQueries[1:]...)
	}
}

func (q *findnodeQuery) start(net *Network) bool {
	// Satisfy queries against the local node directly.
	if q.remote == net.tab.self {
		log.WithFields(log.Fields{"module": logModule}).Debug("findnodeQuery self")
		closest := net.tab.closest(common.BytesToHash(q.target[:]), bucketSize)

		q.reply <- closest.entries
		return true
	}
	if q.remote.state.canQuery && q.remote.pendingNeighbours == nil {
		log.WithFields(log.Fields{"module": logModule, "remote peer": q.remote.ID, "targetID": q.target}).Debug("find node query")
		net.conn.sendFindnodeHash(q.remote, q.target)
		net.timedEvent(respTimeout, q.remote, neighboursTimeout)
		q.remote.pendingNeighbours = q
		return true
	}
	// If the node is not known yet, it won't accept queries.
	// Initiate the transition to known.
	// The request will be sent later when the node reaches known state.
	if q.remote.state == unknown {
		log.WithFields(log.Fields{"module": logModule, "id": q.remote.ID, "status": "unknown->verify init"}).Debug("find node query")
		net.transition(q.remote, verifyinit)
	}
	return false
}

// Node Events (the input to the state machine).

type nodeEvent uint

//go:generate stringer -type=nodeEvent

const (
	invalidEvent nodeEvent = iota // zero is reserved

	// Packet type events.
	// These correspond to packet types in the UDP protocol.
	pingPacket
	pongPacket
	findnodePacket
	neighborsPacket
	findnodeHashPacket
	topicRegisterPacket
	topicQueryPacket
	topicNodesPacket

	// Non-packet events.
	// Event values in this category are allocated outside
	// the packet type range (packet types are encoded as a single byte).
	pongTimeout nodeEvent = iota + 256
	pingTimeout
	neighboursTimeout
)

// Node State Machine.

type nodeState struct {
	name     string
	handle   func(*Network, *Node, nodeEvent, *ingressPacket) (next *nodeState, err error)
	enter    func(*Network, *Node)
	canQuery bool
}

func (s *nodeState) String() string {
	return s.name
}

var (
	unknown          *nodeState
	verifyinit       *nodeState
	verifywait       *nodeState
	remoteverifywait *nodeState
	known            *nodeState
	contested        *nodeState
	unresponsive     *nodeState
)

func init() {
	unknown = &nodeState{
		name: "unknown",
		enter: func(net *Network, n *Node) {
			net.tab.delete(n)
			n.pingEcho = nil
			// Abort active queries.
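			// Each aborted query receives a nil reply so that lookup
			// goroutines blocked on the reply channel do not hang.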
			for _, q := range n.deferredQueries {
				q.reply <- nil
			}
			n.deferredQueries = nil
			if n.pendingNeighbours != nil {
				n.pendingNeighbours.reply <- nil
				n.pendingNeighbours = nil
			}
			n.queryTimeouts = 0
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				net.ping(n, pkt.remoteAddr)
				return verifywait, nil
			default:
				return unknown, errInvalidEvent
			}
		},
	}

	verifyinit = &nodeState{
		name: "verifyinit",
		enter: func(net *Network, n *Node) {
			net.ping(n, n.addr())
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return verifywait, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return remoteverifywait, err
			case pongTimeout:
				return unknown, nil
			default:
				return verifyinit, errInvalidEvent
			}
		},
	}

	verifywait = &nodeState{
		name: "verifywait",
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return verifywait, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return known, err
			case pongTimeout:
				return unknown, nil
			default:
				return verifywait, errInvalidEvent
			}
		},
	}

	remoteverifywait = &nodeState{
		name: "remoteverifywait",
		enter: func(net *Network, n *Node) {
			net.timedEvent(respTimeout, n, pingTimeout)
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return remoteverifywait, nil
			case pingTimeout:
				return known, nil
			default:
				return remoteverifywait, errInvalidEvent
			}
		},
	}

	known = &nodeState{
		name:     "known",
		canQuery: true,
		enter: func(net *Network, n *Node) {
			n.queryTimeouts = 0
			n.startNextQuery(net)
			// Insert into the table and start revalidation of the last node
			// in the bucket if it is full.
			last := net.tab.add(n)
			if last != nil && last.state == known {
				// TODO: do this asynchronously
				net.transition(last, contested)
			}
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			if err := net.db.updateNode(n); err != nil {
				return known, err
			}

			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return known, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return known, err
			default:
				return net.handleQueryEvent(n, ev, pkt)
			}
		},
	}

	contested = &nodeState{
		name:     "contested",
		canQuery: true,
		enter: func(net *Network, n *Node) {
			n.pingEcho = nil
			net.ping(n, n.addr())
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pongPacket:
				// Node is still alive.
				err := net.handleKnownPong(n, pkt)
				return known, err
			case pongTimeout:
				net.tab.deleteReplace(n)
				return unresponsive, nil
			case pingPacket:
				net.handlePing(n, pkt)
				return contested, nil
			default:
				return net.handleQueryEvent(n, ev, pkt)
			}
		},
	}

	unresponsive = &nodeState{
		name:     "unresponsive",
		canQuery: true,
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			net.db.deleteNode(n.ID)

			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return known, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return known, err
			default:
				return net.handleQueryEvent(n, ev, pkt)
			}
		},
	}
}

// handle processes packets sent by n and events related to n.
func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
	//fmt.Println("handle", n.addr().String(), n.state, ev)
	if pkt != nil {
		if err := net.checkPacket(n, ev, pkt); err != nil {
			//fmt.Println("check err:", err)
			return err
		}
		// Start the background expiration goroutine after the first
		// successful communication. Subsequent calls have no effect if it
		// is already running. We do this here instead of somewhere else
		// so that the search for seed nodes also considers older nodes
		// that would otherwise be removed by the expirer.
		if net.db != nil {
			net.db.ensureExpirer()
		}
	}
	if n.state == nil {
		n.state = unknown //???
	}
	next, err := n.state.handle(net, n, ev, pkt)
	net.transition(n, next)
	//fmt.Println("new state:", n.state)
	return err
}

func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error {
	// Replay prevention checks.
	switch ev {
	case pingPacket, findnodeHashPacket, neighborsPacket:
		// TODO: check date is > last date seen
		// TODO: check ping version
	case pongPacket:
		if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) {
			// fmt.Println("pong reply token mismatch")
			return fmt.Errorf("pong reply token mismatch")
		}
		n.pingEcho = nil
	}
	// Address validation.
	// TODO: Ideally we would do the following:
	// - reject all packets with wrong address except ping.
	// - for ping with new address, transition to verifywait but keep the
	//   previous node (with old address) around. if the new one reaches known,
	//   swap it out.
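	// For now every sender address is accepted; internNode simply
	// overwrites the stored endpoint with the packet's source address.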
	return nil
}

func (net *Network) transition(n *Node, next *nodeState) {
	if n.state != next {
		n.state = next
		if next.enter != nil {
			next.enter(net, n)
		}
	}

	// TODO: persist/unpersist node
}

func (net *Network) timedEvent(d time.Duration, n *Node, ev nodeEvent) {
	timeout := timeoutEvent{ev, n}
	net.timeoutTimers[timeout] = time.AfterFunc(d, func() {
		select {
		case net.timeout <- timeout:
		case <-net.closed:
		}
	})
}

func (net *Network) abortTimedEvent(n *Node, ev nodeEvent) {
	timer := net.timeoutTimers[timeoutEvent{ev, n}]
	if timer != nil {
		timer.Stop()
		delete(net.timeoutTimers, timeoutEvent{ev, n})
	}
}

func (net *Network) ping(n *Node, addr *net.UDPAddr) {
	//fmt.Println("ping", n.addr().String(), n.ID.String(), n.sha.Hex())
	if n.pingEcho != nil || n.ID == net.tab.self.ID {
		//fmt.Println(" not sent")
		return
	}
	log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Pinging remote node")
	n.pingTopics = net.ticketStore.regTopicSet()
	n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics)
	net.timedEvent(respTimeout, n, pongTimeout)
}

func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
	log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling remote ping")
	ping := pkt.data.(*ping)
	n.TCP = ping.From.TCP
	t := net.topictab.getTicket(n, ping.Topics)

	pong := &pong{
		To:         makeEndpoint(n.addr(), n.TCP), // TODO: maybe use known TCP port from DB
		ReplyTok:   pkt.hash,
		Expiration: uint64(time.Now().Add(expiration).Unix()),
	}
	ticketToPong(t, pong)
	net.conn.send(n, pongPacket, pong)
}

func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
	log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling known pong")
	net.abortTimedEvent(n, pongTimeout)
	now := Now()
	ticket, err := pongToTicket(now, n.pingTopics, n, pkt)
	if err == nil {
		// fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data)
		net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket)
	} else {
		log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Failed to convert pong to ticket")
	}
	n.pingEcho = nil
	n.pingTopics = nil
	net.db.updateLastPong(n.ID, time.Now())
	return err
}

func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
	switch ev {
	case findnodePacket:
		target := common.BytesToHash(pkt.data.(*findnode).Target[:])
		results := net.tab.closest(target, bucketSize).entries
		net.conn.sendNeighbours(n, results)
		return n.state, nil
	case neighborsPacket:
		err := net.handleNeighboursPacket(n, pkt)
		return n.state, err
	case neighboursTimeout:
		if n.pendingNeighbours != nil {
			n.pendingNeighbours.reply <- nil
			n.pendingNeighbours = nil
		}
		n.queryTimeouts++
		if n.queryTimeouts > maxFindnodeFailures && n.state == known {
			return contested, errors.New("too many timeouts")
		}
		return n.state, nil

	// v5

	case findnodeHashPacket:
		results := net.tab.closest(pkt.data.(*findnodeHash).Target, bucketSize).entries
		net.conn.sendNeighbours(n, results)
		return n.state, nil
	case topicRegisterPacket:
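		// The register packet must carry a ticket: the pong we signed
		// earlier. checkTopicRegister verifies that signature, the topic
		// hash and the index before the registration is accepted.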
topicRegisterPacket") 1164 regdata := pkt.data.(*topicRegister) 1165 pong, err := net.checkTopicRegister(regdata) 1166 if err != nil { 1167 //fmt.Println(err) 1168 return n.state, fmt.Errorf("bad waiting ticket: %v", err) 1169 } 1170 net.topictab.useTicket(n, pong.TicketSerial, regdata.Topics, int(regdata.Idx), pong.Expiration, pong.WaitPeriods) 1171 return n.state, nil 1172 case topicQueryPacket: 1173 // TODO: handle expiration 1174 topic := pkt.data.(*topicQuery).Topic 1175 results := net.topictab.getEntries(topic) 1176 if _, ok := net.ticketStore.tickets[topic]; ok { 1177 results = append(results, net.tab.self) // we're not registering in our own table but if we're advertising, return ourselves too 1178 } 1179 if len(results) > 10 { 1180 results = results[:10] 1181 } 1182 var hash common.Hash 1183 copy(hash[:], pkt.hash) 1184 net.conn.sendTopicNodes(n, hash, results) 1185 return n.state, nil 1186 case topicNodesPacket: 1187 p := pkt.data.(*topicNodes) 1188 if net.ticketStore.gotTopicNodes(n, p.Echo, p.Nodes) { 1189 n.queryTimeouts++ 1190 if n.queryTimeouts > maxFindnodeFailures && n.state == known { 1191 return contested, errors.New("too many timeouts") 1192 } 1193 } 1194 return n.state, nil 1195 1196 default: 1197 return n.state, errInvalidEvent 1198 } 1199 } 1200 1201 func (net *Network) checkTopicRegister(data *topicRegister) (*pong, error) { 1202 var pongpkt ingressPacket 1203 if err := decodePacket(data.Pong, &pongpkt); err != nil { 1204 return nil, err 1205 } 1206 if pongpkt.ev != pongPacket { 1207 return nil, errors.New("is not pong packet") 1208 } 1209 if pongpkt.remoteID != net.tab.self.ID { 1210 return nil, errors.New("not signed by us") 1211 } 1212 // check that we previously authorised all topics 1213 // that the other side is trying to register. 1214 hash, _, _ := wireHash(data.Topics) 1215 if hash != pongpkt.data.(*pong).TopicHash { 1216 return nil, errors.New("topic hash mismatch") 1217 } 1218 if int(data.Idx) < 0 || int(data.Idx) >= len(data.Topics) { 1219 return nil, errors.New("topic index out of range") 1220 } 1221 return pongpkt.data.(*pong), nil 1222 } 1223 1224 func wireHash(x interface{}) (h common.Hash, n int, err error) { 1225 hw := sha3.New256() 1226 wire.WriteBinary(x, hw, &n, &err) 1227 hw.Sum(h[:0]) 1228 return h, n, err 1229 } 1230 1231 func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { 1232 if n.pendingNeighbours == nil { 1233 return errNoQuery 1234 } 1235 net.abortTimedEvent(n, neighboursTimeout) 1236 1237 req := pkt.data.(*neighbors) 1238 nodes := make([]*Node, len(req.Nodes)) 1239 for i, rn := range req.Nodes { 1240 nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn) 1241 if err != nil { 1242 log.WithFields(log.Fields{"module": logModule, "ip": rn.IP, "id:": n.ID[:8], "addr:": pkt.remoteAddr, "error": err}).Debug("invalid neighbour") 1243 continue 1244 } 1245 nodes[i] = nn 1246 // Start validation of query results immediately. 1247 // This fills the table quickly. 1248 // TODO: generates way too many packets, maybe do it via queue. 1249 if nn.state == unknown { 1250 net.transition(nn, verifyinit) 1251 } 1252 } 1253 // TODO: don't ignore second packet 1254 n.pendingNeighbours.reply <- nodes 1255 n.pendingNeighbours = nil 1256 // Now that this query is done, start the next one. 1257 n.startNextQuery(net) 1258 return nil 1259 }