github.com/aquanetwork/aquachain@v1.7.8/p2p/discv5/net.go (about) 1 // Copyright 2016 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package discv5 18 19 import ( 20 "bytes" 21 "crypto/ecdsa" 22 "errors" 23 "fmt" 24 "net" 25 "time" 26 27 "gitlab.com/aquachain/aquachain/common" 28 "gitlab.com/aquachain/aquachain/common/log" 29 "gitlab.com/aquachain/aquachain/common/mclock" 30 "gitlab.com/aquachain/aquachain/crypto" 31 "gitlab.com/aquachain/aquachain/crypto/sha3" 32 "gitlab.com/aquachain/aquachain/p2p/netutil" 33 "gitlab.com/aquachain/aquachain/rlp" 34 ) 35 36 var ( 37 errInvalidEvent = errors.New("invalid in current state") 38 errNoQuery = errors.New("no pending query") 39 ) 40 41 const ( 42 autoRefreshInterval = 1 * time.Hour 43 bucketRefreshInterval = 1 * time.Minute 44 seedCount = 30 45 seedMaxAge = 5 * 24 * time.Hour 46 lowPort = 1024 47 ) 48 49 const testTopic = "foo" 50 51 const ( 52 printTestImgLogs = false 53 ) 54 55 // Network manages the table and all protocol interaction. 
type Network struct {
	db          *nodeDB   // database of known nodes
	conn        transport // packet transport; closed on shutdown
	netrestrict *netutil.Netlist

	closed           chan struct{}          // closed when loop is done
	closeReq         chan struct{}          // 'request to close'
	refreshReq       chan []*Node           // lookups ask for refresh on this channel
	refreshResp      chan (<-chan struct{}) // ...and get the channel to block on from this one
	read             chan ingressPacket     // ingress packets arrive here
	timeout          chan timeoutEvent      // fired timed events arrive here
	queryReq         chan *findnodeQuery    // lookups submit findnode queries on this channel
	tableOpReq       chan func()            // functions to be run by the main loop
	tableOpResp      chan struct{}          // signalled after a tableOpReq function has run
	topicRegisterReq chan topicRegisterReq
	topicSearchReq   chan topicSearchReq

	// State of the main loop.
	tab           *Table
	topictab      *topicTable
	ticketStore   *ticketStore
	nursery       []*Node
	nodes         map[NodeID]*Node // tracks active nodes with state != known
	timeoutTimers map[timeoutEvent]*time.Timer
}

// transport is implemented by the UDP transport.
// It is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
85 type transport interface { 86 sendPing(remote *Node, remoteAddr *net.UDPAddr, topics []Topic) (hash []byte) 87 sendNeighbours(remote *Node, nodes []*Node) 88 sendFindnodeHash(remote *Node, target common.Hash) 89 sendTopicRegister(remote *Node, topics []Topic, topicIdx int, pong []byte) 90 sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) 91 92 send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte) 93 94 localAddr() *net.UDPAddr 95 Close() 96 } 97 98 type findnodeQuery struct { 99 remote *Node 100 target common.Hash 101 reply chan<- []*Node 102 } 103 104 type topicRegisterReq struct { 105 add bool 106 topic Topic 107 } 108 109 type topicSearchReq struct { 110 topic Topic 111 found chan<- *Node 112 lookup chan<- bool 113 delay time.Duration 114 } 115 116 type topicSearchResult struct { 117 target lookupInfo 118 nodes []*Node 119 } 120 121 type timeoutEvent struct { 122 ev nodeEvent 123 node *Node 124 } 125 126 func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { 127 ourID := PubkeyID(&ourPubkey) 128 129 var db *nodeDB 130 if dbPath != "<no database>" { 131 var err error 132 if db, err = newNodeDB(dbPath, Version, ourID); err != nil { 133 return nil, err 134 } 135 } 136 137 tab := newTable(ourID, conn.localAddr()) 138 net := &Network{ 139 db: db, 140 conn: conn, 141 netrestrict: netrestrict, 142 tab: tab, 143 topictab: newTopicTable(db, tab.self), 144 ticketStore: newTicketStore(), 145 refreshReq: make(chan []*Node), 146 refreshResp: make(chan (<-chan struct{})), 147 closed: make(chan struct{}), 148 closeReq: make(chan struct{}), 149 read: make(chan ingressPacket, 100), 150 timeout: make(chan timeoutEvent), 151 timeoutTimers: make(map[timeoutEvent]*time.Timer), 152 tableOpReq: make(chan func()), 153 tableOpResp: make(chan struct{}), 154 queryReq: make(chan *findnodeQuery), 155 topicRegisterReq: make(chan topicRegisterReq), 156 topicSearchReq: make(chan topicSearchReq), 
157 nodes: make(map[NodeID]*Node), 158 } 159 go net.loop() 160 return net, nil 161 } 162 163 // Close terminates the network listener and flushes the node database. 164 func (net *Network) Close() { 165 net.conn.Close() 166 select { 167 case <-net.closed: 168 case net.closeReq <- struct{}{}: 169 <-net.closed 170 } 171 } 172 173 // Self returns the local node. 174 // The returned node should not be modified by the caller. 175 func (net *Network) Self() *Node { 176 return net.tab.self 177 } 178 179 // ReadRandomNodes fills the given slice with random nodes from the 180 // table. It will not write the same node more than once. The nodes in 181 // the slice are copies and can be modified by the caller. 182 func (net *Network) ReadRandomNodes(buf []*Node) (n int) { 183 net.reqTableOp(func() { n = net.tab.readRandomNodes(buf) }) 184 return n 185 } 186 187 // SetFallbackNodes sets the initial points of contact. These nodes 188 // are used to connect to the network if the table is empty and there 189 // are no known nodes in the database. 190 func (net *Network) SetFallbackNodes(nodes []*Node) error { 191 nursery := make([]*Node, 0, len(nodes)) 192 for _, n := range nodes { 193 if err := n.validateComplete(); err != nil { 194 return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err) 195 } 196 // Recompute cpy.sha because the node might not have been 197 // created by NewNode or ParseNode. 198 cpy := *n 199 cpy.sha = crypto.Keccak256Hash(n.ID[:]) 200 nursery = append(nursery, &cpy) 201 } 202 net.reqRefresh(nursery) 203 return nil 204 } 205 206 // Resolve searches for a specific node with the given ID. 207 // It returns nil if the node could not be found. 208 func (net *Network) Resolve(targetID NodeID) *Node { 209 result := net.lookup(crypto.Keccak256Hash(targetID[:]), true) 210 for _, n := range result { 211 if n.ID == targetID { 212 return n 213 } 214 } 215 return nil 216 } 217 218 // Lookup performs a network search for nodes close 219 // to the given target. 
It approaches the target by querying 220 // nodes that are closer to it on each iteration. 221 // The given target does not need to be an actual node 222 // identifier. 223 // 224 // The local node may be included in the result. 225 func (net *Network) Lookup(targetID NodeID) []*Node { 226 return net.lookup(crypto.Keccak256Hash(targetID[:]), false) 227 } 228 229 func (net *Network) lookup(target common.Hash, stopOnMatch bool) []*Node { 230 var ( 231 asked = make(map[NodeID]bool) 232 seen = make(map[NodeID]bool) 233 reply = make(chan []*Node, alpha) 234 result = nodesByDistance{target: target} 235 pendingQueries = 0 236 ) 237 // Get initial answers from the local node. 238 result.push(net.tab.self, bucketSize) 239 for { 240 // Ask the α closest nodes that we haven't asked yet. 241 for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ { 242 n := result.entries[i] 243 if !asked[n.ID] { 244 asked[n.ID] = true 245 pendingQueries++ 246 net.reqQueryFindnode(n, target, reply) 247 } 248 } 249 if pendingQueries == 0 { 250 // We have asked all closest nodes, stop the search. 251 break 252 } 253 // Wait for the next reply. 
254 select { 255 case nodes := <-reply: 256 for _, n := range nodes { 257 if n != nil && !seen[n.ID] { 258 seen[n.ID] = true 259 result.push(n, bucketSize) 260 if stopOnMatch && n.sha == target { 261 return result.entries 262 } 263 } 264 } 265 pendingQueries-- 266 case <-time.After(respTimeout): 267 // forget all pending requests, start new ones 268 pendingQueries = 0 269 reply = make(chan []*Node, alpha) 270 } 271 } 272 return result.entries 273 } 274 275 func (net *Network) RegisterTopic(topic Topic, stop <-chan struct{}) { 276 select { 277 case net.topicRegisterReq <- topicRegisterReq{true, topic}: 278 case <-net.closed: 279 return 280 } 281 select { 282 case <-net.closed: 283 case <-stop: 284 select { 285 case net.topicRegisterReq <- topicRegisterReq{false, topic}: 286 case <-net.closed: 287 } 288 } 289 } 290 291 func (net *Network) SearchTopic(topic Topic, setPeriod <-chan time.Duration, found chan<- *Node, lookup chan<- bool) { 292 for { 293 select { 294 case <-net.closed: 295 return 296 case delay, ok := <-setPeriod: 297 select { 298 case net.topicSearchReq <- topicSearchReq{topic: topic, found: found, lookup: lookup, delay: delay}: 299 case <-net.closed: 300 return 301 } 302 if !ok { 303 return 304 } 305 } 306 } 307 } 308 309 func (net *Network) reqRefresh(nursery []*Node) <-chan struct{} { 310 select { 311 case net.refreshReq <- nursery: 312 return <-net.refreshResp 313 case <-net.closed: 314 return net.closed 315 } 316 } 317 318 func (net *Network) reqQueryFindnode(n *Node, target common.Hash, reply chan []*Node) bool { 319 q := &findnodeQuery{remote: n, target: target, reply: reply} 320 select { 321 case net.queryReq <- q: 322 return true 323 case <-net.closed: 324 return false 325 } 326 } 327 328 func (net *Network) reqReadPacket(pkt ingressPacket) { 329 select { 330 case net.read <- pkt: 331 case <-net.closed: 332 } 333 } 334 335 func (net *Network) reqTableOp(f func()) (called bool) { 336 select { 337 case net.tableOpReq <- f: 338 <-net.tableOpResp 
339 return true 340 case <-net.closed: 341 return false 342 } 343 } 344 345 // TODO: external address handling. 346 347 type topicSearchInfo struct { 348 lookupChn chan<- bool 349 period time.Duration 350 } 351 352 const maxSearchCount = 5 353 354 func (net *Network) loop() { 355 var ( 356 refreshTimer = time.NewTicker(autoRefreshInterval) 357 bucketRefreshTimer = time.NewTimer(bucketRefreshInterval) 358 refreshDone chan struct{} // closed when the 'refresh' lookup has ended 359 ) 360 361 // Tracking the next ticket to register. 362 var ( 363 nextTicket *ticketRef 364 nextRegisterTimer *time.Timer 365 nextRegisterTime <-chan time.Time 366 ) 367 defer func() { 368 if nextRegisterTimer != nil { 369 nextRegisterTimer.Stop() 370 } 371 }() 372 resetNextTicket := func() { 373 ticket, timeout := net.ticketStore.nextFilteredTicket() 374 if nextTicket != ticket { 375 nextTicket = ticket 376 if nextRegisterTimer != nil { 377 nextRegisterTimer.Stop() 378 nextRegisterTime = nil 379 } 380 if ticket != nil { 381 nextRegisterTimer = time.NewTimer(timeout) 382 nextRegisterTime = nextRegisterTimer.C 383 } 384 } 385 } 386 387 // Tracking registration and search lookups. 388 var ( 389 topicRegisterLookupTarget lookupInfo 390 topicRegisterLookupDone chan []*Node 391 topicRegisterLookupTick = time.NewTimer(0) 392 searchReqWhenRefreshDone []topicSearchReq 393 searchInfo = make(map[Topic]topicSearchInfo) 394 activeSearchCount int 395 ) 396 topicSearchLookupDone := make(chan topicSearchResult, 100) 397 topicSearch := make(chan Topic, 100) 398 <-topicRegisterLookupTick.C 399 400 statsDump := time.NewTicker(10 * time.Second) 401 402 loop: 403 for { 404 resetNextTicket() 405 406 select { 407 case <-net.closeReq: 408 log.Trace("<-net.closeReq") 409 break loop 410 411 // Ingress packet handling. 
412 case pkt := <-net.read: 413 //fmt.Println("read", pkt.ev) 414 log.Trace("<-net.read") 415 n := net.internNode(&pkt) 416 prestate := n.state 417 status := "ok" 418 if err := net.handle(n, pkt.ev, &pkt); err != nil { 419 status = err.Error() 420 } 421 log.Trace("", "msg", log.Lazy{Fn: func() string { 422 return fmt.Sprintf("<<< (%d) %v from %x@%v: %v -> %v (%v)", 423 net.tab.count, pkt.ev, pkt.remoteID[:8], pkt.remoteAddr, prestate, n.state, status) 424 }}) 425 // TODO: persist state if n.state goes >= known, delete if it goes <= known 426 427 // State transition timeouts. 428 case timeout := <-net.timeout: 429 log.Trace("<-net.timeout") 430 if net.timeoutTimers[timeout] == nil { 431 // Stale timer (was aborted). 432 continue 433 } 434 delete(net.timeoutTimers, timeout) 435 prestate := timeout.node.state 436 status := "ok" 437 if err := net.handle(timeout.node, timeout.ev, nil); err != nil { 438 status = err.Error() 439 } 440 log.Trace("", "msg", log.Lazy{Fn: func() string { 441 return fmt.Sprintf("--- (%d) %v for %x@%v: %v -> %v (%v)", 442 net.tab.count, timeout.ev, timeout.node.ID[:8], timeout.node.addr(), prestate, timeout.node.state, status) 443 }}) 444 445 // Querying. 446 case q := <-net.queryReq: 447 log.Trace("<-net.queryReq") 448 if !q.start(net) { 449 q.remote.deferQuery(q) 450 } 451 452 // Interacting with the table. 453 case f := <-net.tableOpReq: 454 log.Trace("<-net.tableOpReq") 455 f() 456 net.tableOpResp <- struct{}{} 457 458 // Topic registration stuff. 459 case req := <-net.topicRegisterReq: 460 log.Trace("<-net.topicRegisterReq") 461 if !req.add { 462 net.ticketStore.removeRegisterTopic(req.topic) 463 continue 464 } 465 net.ticketStore.addTopic(req.topic, true) 466 // If we're currently waiting idle (nothing to look up), give the ticket store a 467 // chance to start it sooner. This should speed up convergence of the radius 468 // determination for new topics. 
469 // if topicRegisterLookupDone == nil { 470 if topicRegisterLookupTarget.target == (common.Hash{}) { 471 log.Trace("topicRegisterLookupTarget == null") 472 if topicRegisterLookupTick.Stop() { 473 <-topicRegisterLookupTick.C 474 } 475 target, delay := net.ticketStore.nextRegisterLookup() 476 topicRegisterLookupTarget = target 477 topicRegisterLookupTick.Reset(delay) 478 } 479 480 case nodes := <-topicRegisterLookupDone: 481 log.Trace("<-topicRegisterLookupDone") 482 net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte { 483 net.ping(n, n.addr()) 484 return n.pingEcho 485 }) 486 target, delay := net.ticketStore.nextRegisterLookup() 487 topicRegisterLookupTarget = target 488 topicRegisterLookupTick.Reset(delay) 489 topicRegisterLookupDone = nil 490 491 case <-topicRegisterLookupTick.C: 492 log.Trace("<-topicRegisterLookupTick") 493 if (topicRegisterLookupTarget.target == common.Hash{}) { 494 target, delay := net.ticketStore.nextRegisterLookup() 495 topicRegisterLookupTarget = target 496 topicRegisterLookupTick.Reset(delay) 497 topicRegisterLookupDone = nil 498 } else { 499 topicRegisterLookupDone = make(chan []*Node) 500 target := topicRegisterLookupTarget.target 501 go func() { topicRegisterLookupDone <- net.lookup(target, false) }() 502 } 503 504 case <-nextRegisterTime: 505 log.Trace("<-nextRegisterTime") 506 net.ticketStore.ticketRegistered(*nextTicket) 507 //fmt.Println("sendTopicRegister", nextTicket.t.node.addr().String(), nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong) 508 net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong) 509 510 case req := <-net.topicSearchReq: 511 if refreshDone == nil { 512 log.Trace("<-net.topicSearchReq") 513 info, ok := searchInfo[req.topic] 514 if ok { 515 if req.delay == time.Duration(0) { 516 delete(searchInfo, req.topic) 517 net.ticketStore.removeSearchTopic(req.topic) 518 } else { 519 info.period = req.delay 520 searchInfo[req.topic] 
= info 521 } 522 continue 523 } 524 if req.delay != time.Duration(0) { 525 var info topicSearchInfo 526 info.period = req.delay 527 info.lookupChn = req.lookup 528 searchInfo[req.topic] = info 529 net.ticketStore.addSearchTopic(req.topic, req.found) 530 topicSearch <- req.topic 531 } 532 } else { 533 searchReqWhenRefreshDone = append(searchReqWhenRefreshDone, req) 534 } 535 536 case topic := <-topicSearch: 537 if activeSearchCount < maxSearchCount { 538 activeSearchCount++ 539 target := net.ticketStore.nextSearchLookup(topic) 540 go func() { 541 nodes := net.lookup(target.target, false) 542 topicSearchLookupDone <- topicSearchResult{target: target, nodes: nodes} 543 }() 544 } 545 period := searchInfo[topic].period 546 if period != time.Duration(0) { 547 go func() { 548 time.Sleep(period) 549 topicSearch <- topic 550 }() 551 } 552 553 case res := <-topicSearchLookupDone: 554 activeSearchCount-- 555 if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil { 556 lookupChn <- net.ticketStore.radius[res.target.topic].converged 557 } 558 net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte { 559 if n.state != nil && n.state.canQuery { 560 return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration 561 } else { 562 if n.state == unknown { 563 net.ping(n, n.addr()) 564 } 565 return nil 566 } 567 }) 568 569 case <-statsDump.C: 570 log.Trace("<-statsDump.C") 571 /*r, ok := net.ticketStore.radius[testTopic] 572 if !ok { 573 fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now()) 574 } else { 575 topics := len(net.ticketStore.tickets) 576 tickets := len(net.ticketStore.nodes) 577 rad := r.radius / (maxRadius/10000+1) 578 fmt.Printf("(%x) topics:%d radius:%d tickets:%d @ %v\n", net.tab.self.ID[:8], topics, rad, tickets, time.Now()) 579 }*/ 580 581 tm := mclock.Now() 582 for topic, r := range net.ticketStore.radius { 583 if printTestImgLogs { 584 rad := r.radius / (maxRadius/1000000 
+ 1) 585 minrad := r.minRadius / (maxRadius/1000000 + 1) 586 fmt.Printf("*R %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], rad) 587 fmt.Printf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad) 588 } 589 } 590 for topic, t := range net.topictab.topics { 591 wp := t.wcl.nextWaitPeriod(tm) 592 if printTestImgLogs { 593 fmt.Printf("*W %d %v %016x %d\n", tm/1000000, topic, net.tab.self.sha[:8], wp/1000000) 594 } 595 } 596 597 // Periodic / lookup-initiated bucket refresh. 598 case <-refreshTimer.C: 599 log.Trace("<-refreshTimer.C") 600 // TODO: ideally we would start the refresh timer after 601 // fallback nodes have been set for the first time. 602 if refreshDone == nil { 603 refreshDone = make(chan struct{}) 604 net.refresh(refreshDone) 605 } 606 case <-bucketRefreshTimer.C: 607 target := net.tab.chooseBucketRefreshTarget() 608 go func() { 609 net.lookup(target, false) 610 bucketRefreshTimer.Reset(bucketRefreshInterval) 611 }() 612 case newNursery := <-net.refreshReq: 613 log.Trace("<-net.refreshReq") 614 if newNursery != nil { 615 net.nursery = newNursery 616 } 617 if refreshDone == nil { 618 refreshDone = make(chan struct{}) 619 net.refresh(refreshDone) 620 } 621 net.refreshResp <- refreshDone 622 case <-refreshDone: 623 log.Trace("<-net.refreshDone", "table size", net.tab.count) 624 if net.tab.count != 0 { 625 refreshDone = nil 626 list := searchReqWhenRefreshDone 627 searchReqWhenRefreshDone = nil 628 go func() { 629 for _, req := range list { 630 net.topicSearchReq <- req 631 } 632 }() 633 } else { 634 refreshDone = make(chan struct{}) 635 net.refresh(refreshDone) 636 } 637 } 638 } 639 log.Trace("loop stopped") 640 641 log.Debug(fmt.Sprintf("shutting down")) 642 if net.conn != nil { 643 net.conn.Close() 644 } 645 if refreshDone != nil { 646 // TODO: wait for pending refresh. 647 //<-refreshResults 648 } 649 // Cancel all pending timeouts. 
650 for _, timer := range net.timeoutTimers { 651 timer.Stop() 652 } 653 if net.db != nil { 654 net.db.close() 655 } 656 close(net.closed) 657 } 658 659 // Everything below runs on the Network.loop goroutine 660 // and can modify Node, Table and Network at any time without locking. 661 662 func (net *Network) refresh(done chan<- struct{}) { 663 var seeds []*Node 664 if net.db != nil { 665 seeds = net.db.querySeeds(seedCount, seedMaxAge) 666 } 667 if len(seeds) == 0 { 668 seeds = net.nursery 669 } 670 if len(seeds) == 0 { 671 log.Trace("no seed nodes found") 672 close(done) 673 return 674 } 675 for _, n := range seeds { 676 log.Debug("", "msg", log.Lazy{Fn: func() string { 677 var age string 678 if net.db != nil { 679 age = time.Since(net.db.lastPong(n.ID)).String() 680 } else { 681 age = "unknown" 682 } 683 return fmt.Sprintf("seed node (age %s): %v", age, n) 684 }}) 685 n = net.internNodeFromDB(n) 686 if n.state == unknown { 687 net.transition(n, verifyinit) 688 } 689 // Force-add the seed node so Lookup does something. 690 // It will be deleted again if verification fails. 691 net.tab.add(n) 692 } 693 // Start self lookup to fill up the buckets. 694 go func() { 695 net.Lookup(net.tab.self.ID) 696 close(done) 697 }() 698 } 699 700 // Node Interning. 
701 702 func (net *Network) internNode(pkt *ingressPacket) *Node { 703 if n := net.nodes[pkt.remoteID]; n != nil { 704 n.IP = pkt.remoteAddr.IP 705 n.UDP = uint16(pkt.remoteAddr.Port) 706 n.TCP = uint16(pkt.remoteAddr.Port) 707 return n 708 } 709 n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port)) 710 n.state = unknown 711 net.nodes[pkt.remoteID] = n 712 return n 713 } 714 715 func (net *Network) internNodeFromDB(dbn *Node) *Node { 716 if n := net.nodes[dbn.ID]; n != nil { 717 return n 718 } 719 n := NewNode(dbn.ID, dbn.IP, dbn.UDP, dbn.TCP) 720 n.state = unknown 721 net.nodes[n.ID] = n 722 return n 723 } 724 725 func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) { 726 if rn.ID == net.tab.self.ID { 727 return nil, errors.New("is self") 728 } 729 if rn.UDP <= lowPort { 730 return nil, errors.New("low port") 731 } 732 n = net.nodes[rn.ID] 733 if n == nil { 734 // We haven't seen this node before. 735 n, err = nodeFromRPC(sender, rn) 736 if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) { 737 return n, errors.New("not contained in netrestrict whitelist") 738 } 739 if err == nil { 740 n.state = unknown 741 net.nodes[n.ID] = n 742 } 743 return n, err 744 } 745 if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP { 746 if n.state == known { 747 // reject address change if node is known by us 748 err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n) 749 } else { 750 // accept otherwise; this will be handled nicer with signed ENRs 751 n.IP = rn.IP 752 n.UDP = rn.UDP 753 n.TCP = rn.TCP 754 } 755 } 756 return n, err 757 } 758 759 // nodeNetGuts is embedded in Node and contains fields. 760 type nodeNetGuts struct { 761 // This is a cached copy of sha3(ID) which is used for node 762 // distance calculations. This is part of Node in order to make it 763 // possible to write tests that need a node at a certain distance. 
764 // In those tests, the content of sha will not actually correspond 765 // with ID. 766 sha common.Hash 767 768 // State machine fields. Access to these fields 769 // is restricted to the Network.loop goroutine. 770 state *nodeState 771 pingEcho []byte // hash of last ping sent by us 772 pingTopics []Topic // topic set sent by us in last ping 773 deferredQueries []*findnodeQuery // queries that can't be sent yet 774 pendingNeighbours *findnodeQuery // current query, waiting for reply 775 queryTimeouts int 776 } 777 778 func (n *nodeNetGuts) deferQuery(q *findnodeQuery) { 779 n.deferredQueries = append(n.deferredQueries, q) 780 } 781 782 func (n *nodeNetGuts) startNextQuery(net *Network) { 783 if len(n.deferredQueries) == 0 { 784 return 785 } 786 nextq := n.deferredQueries[0] 787 if nextq.start(net) { 788 n.deferredQueries = append(n.deferredQueries[:0], n.deferredQueries[1:]...) 789 } 790 } 791 792 func (q *findnodeQuery) start(net *Network) bool { 793 // Satisfy queries against the local node directly. 794 if q.remote == net.tab.self { 795 closest := net.tab.closest(crypto.Keccak256Hash(q.target[:]), bucketSize) 796 q.reply <- closest.entries 797 return true 798 } 799 if q.remote.state.canQuery && q.remote.pendingNeighbours == nil { 800 net.conn.sendFindnodeHash(q.remote, q.target) 801 net.timedEvent(respTimeout, q.remote, neighboursTimeout) 802 q.remote.pendingNeighbours = q 803 return true 804 } 805 // If the node is not known yet, it won't accept queries. 806 // Initiate the transition to known. 807 // The request will be sent later when the node reaches known state. 808 if q.remote.state == unknown { 809 net.transition(q.remote, verifyinit) 810 } 811 return false 812 } 813 814 // Node Events (the input to the state machine). 815 816 type nodeEvent uint 817 818 //go:generate stringer -type=nodeEvent 819 820 const ( 821 _ nodeEvent = iota 822 // Packet type events. 823 // These correspond to packet types in the UDP protocol. 
824 pingPacket 825 pongPacket 826 findnodePacket 827 neighborsPacket 828 findnodeHashPacket 829 topicRegisterPacket 830 topicQueryPacket 831 topicNodesPacket 832 833 // Non-packet events. 834 // Event values in this category are allocated outside 835 // the packet type range (packet types are encoded as a single byte). 836 pongTimeout nodeEvent = iota + 256 837 pingTimeout 838 neighboursTimeout 839 ) 840 841 // Node State Machine. 842 843 type nodeState struct { 844 name string 845 handle func(*Network, *Node, nodeEvent, *ingressPacket) (next *nodeState, err error) 846 enter func(*Network, *Node) 847 canQuery bool 848 } 849 850 func (s *nodeState) String() string { 851 return s.name 852 } 853 854 var ( 855 unknown *nodeState 856 verifyinit *nodeState 857 verifywait *nodeState 858 remoteverifywait *nodeState 859 known *nodeState 860 contested *nodeState 861 unresponsive *nodeState 862 ) 863 864 func init() { 865 unknown = &nodeState{ 866 name: "unknown", 867 enter: func(net *Network, n *Node) { 868 net.tab.delete(n) 869 n.pingEcho = nil 870 // Abort active queries. 
871 for _, q := range n.deferredQueries { 872 q.reply <- nil 873 } 874 n.deferredQueries = nil 875 if n.pendingNeighbours != nil { 876 n.pendingNeighbours.reply <- nil 877 n.pendingNeighbours = nil 878 } 879 n.queryTimeouts = 0 880 }, 881 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 882 switch ev { 883 case pingPacket: 884 net.handlePing(n, pkt) 885 net.ping(n, pkt.remoteAddr) 886 return verifywait, nil 887 default: 888 return unknown, errInvalidEvent 889 } 890 }, 891 } 892 893 verifyinit = &nodeState{ 894 name: "verifyinit", 895 enter: func(net *Network, n *Node) { 896 net.ping(n, n.addr()) 897 }, 898 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 899 switch ev { 900 case pingPacket: 901 net.handlePing(n, pkt) 902 return verifywait, nil 903 case pongPacket: 904 err := net.handleKnownPong(n, pkt) 905 return remoteverifywait, err 906 case pongTimeout: 907 return unknown, nil 908 default: 909 return verifyinit, errInvalidEvent 910 } 911 }, 912 } 913 914 verifywait = &nodeState{ 915 name: "verifywait", 916 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 917 switch ev { 918 case pingPacket: 919 net.handlePing(n, pkt) 920 return verifywait, nil 921 case pongPacket: 922 err := net.handleKnownPong(n, pkt) 923 return known, err 924 case pongTimeout: 925 return unknown, nil 926 default: 927 return verifywait, errInvalidEvent 928 } 929 }, 930 } 931 932 remoteverifywait = &nodeState{ 933 name: "remoteverifywait", 934 enter: func(net *Network, n *Node) { 935 net.timedEvent(respTimeout, n, pingTimeout) 936 }, 937 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 938 switch ev { 939 case pingPacket: 940 net.handlePing(n, pkt) 941 return remoteverifywait, nil 942 case pingTimeout: 943 return known, nil 944 default: 945 return remoteverifywait, errInvalidEvent 946 } 947 }, 948 } 949 950 known = &nodeState{ 
951 name: "known", 952 canQuery: true, 953 enter: func(net *Network, n *Node) { 954 n.queryTimeouts = 0 955 n.startNextQuery(net) 956 // Insert into the table and start revalidation of the last node 957 // in the bucket if it is full. 958 last := net.tab.add(n) 959 if last != nil && last.state == known { 960 // TODO: do this asynchronously 961 net.transition(last, contested) 962 } 963 }, 964 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 965 switch ev { 966 case pingPacket: 967 net.handlePing(n, pkt) 968 return known, nil 969 case pongPacket: 970 err := net.handleKnownPong(n, pkt) 971 return known, err 972 default: 973 return net.handleQueryEvent(n, ev, pkt) 974 } 975 }, 976 } 977 978 contested = &nodeState{ 979 name: "contested", 980 canQuery: true, 981 enter: func(net *Network, n *Node) { 982 net.ping(n, n.addr()) 983 }, 984 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 985 switch ev { 986 case pongPacket: 987 // Node is still alive. 988 err := net.handleKnownPong(n, pkt) 989 return known, err 990 case pongTimeout: 991 net.tab.deleteReplace(n) 992 return unresponsive, nil 993 case pingPacket: 994 net.handlePing(n, pkt) 995 return contested, nil 996 default: 997 return net.handleQueryEvent(n, ev, pkt) 998 } 999 }, 1000 } 1001 1002 unresponsive = &nodeState{ 1003 name: "unresponsive", 1004 canQuery: true, 1005 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 1006 switch ev { 1007 case pingPacket: 1008 net.handlePing(n, pkt) 1009 return known, nil 1010 case pongPacket: 1011 err := net.handleKnownPong(n, pkt) 1012 return known, err 1013 default: 1014 return net.handleQueryEvent(n, ev, pkt) 1015 } 1016 }, 1017 } 1018 } 1019 1020 // handle processes packets sent by n and events related to n. 
1021 func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error { 1022 //fmt.Println("handle", n.addr().String(), n.state, ev) 1023 if pkt != nil { 1024 if err := net.checkPacket(n, ev, pkt); err != nil { 1025 //fmt.Println("check err:", err) 1026 return err 1027 } 1028 // Start the background expiration goroutine after the first 1029 // successful communication. Subsequent calls have no effect if it 1030 // is already running. We do this here instead of somewhere else 1031 // so that the search for seed nodes also considers older nodes 1032 // that would otherwise be removed by the expirer. 1033 if net.db != nil { 1034 net.db.ensureExpirer() 1035 } 1036 } 1037 if n.state == nil { 1038 n.state = unknown //??? 1039 } 1040 next, err := n.state.handle(net, n, ev, pkt) 1041 net.transition(n, next) 1042 //fmt.Println("new state:", n.state) 1043 return err 1044 } 1045 1046 func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error { 1047 // Replay prevention checks. 1048 switch ev { 1049 case pingPacket, findnodeHashPacket, neighborsPacket: 1050 // TODO: check date is > last date seen 1051 // TODO: check ping version 1052 case pongPacket: 1053 if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) { 1054 // fmt.Println("pong reply token mismatch") 1055 return fmt.Errorf("pong reply token mismatch") 1056 } 1057 n.pingEcho = nil 1058 } 1059 // Address validation. 1060 // TODO: Ideally we would do the following: 1061 // - reject all packets with wrong address except ping. 1062 // - for ping with new address, transition to verifywait but keep the 1063 // previous node (with old address) around. if the new one reaches known, 1064 // swap it out. 
1065 return nil 1066 } 1067 1068 func (net *Network) transition(n *Node, next *nodeState) { 1069 if n.state != next { 1070 n.state = next 1071 if next.enter != nil { 1072 next.enter(net, n) 1073 } 1074 } 1075 1076 // TODO: persist/unpersist node 1077 } 1078 1079 func (net *Network) timedEvent(d time.Duration, n *Node, ev nodeEvent) { 1080 timeout := timeoutEvent{ev, n} 1081 net.timeoutTimers[timeout] = time.AfterFunc(d, func() { 1082 select { 1083 case net.timeout <- timeout: 1084 case <-net.closed: 1085 } 1086 }) 1087 } 1088 1089 func (net *Network) abortTimedEvent(n *Node, ev nodeEvent) { 1090 timer := net.timeoutTimers[timeoutEvent{ev, n}] 1091 if timer != nil { 1092 timer.Stop() 1093 delete(net.timeoutTimers, timeoutEvent{ev, n}) 1094 } 1095 } 1096 1097 func (net *Network) ping(n *Node, addr *net.UDPAddr) { 1098 //fmt.Println("ping", n.addr().String(), n.ID.String(), n.sha.Hex()) 1099 if n.pingEcho != nil || n.ID == net.tab.self.ID { 1100 //fmt.Println(" not sent") 1101 return 1102 } 1103 log.Trace("Pinging remote node", "node", n.ID) 1104 n.pingTopics = net.ticketStore.regTopicSet() 1105 n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics) 1106 net.timedEvent(respTimeout, n, pongTimeout) 1107 } 1108 1109 func (net *Network) handlePing(n *Node, pkt *ingressPacket) { 1110 log.Trace("Handling remote ping", "node", n.ID) 1111 ping := pkt.data.(*ping) 1112 n.TCP = ping.From.TCP 1113 t := net.topictab.getTicket(n, ping.Topics) 1114 1115 pong := &pong{ 1116 To: makeEndpoint(n.addr(), n.TCP), // TODO: maybe use known TCP port from DB 1117 ReplyTok: pkt.hash, 1118 Expiration: uint64(time.Now().Add(expiration).Unix()), 1119 } 1120 ticketToPong(t, pong) 1121 net.conn.send(n, pongPacket, pong) 1122 } 1123 1124 func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error { 1125 log.Trace("Handling known pong", "node", n.ID) 1126 net.abortTimedEvent(n, pongTimeout) 1127 now := mclock.Now() 1128 ticket, err := pongToTicket(now, n.pingTopics, n, pkt) 1129 if err == 
nil { 1130 // fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data) 1131 net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket) 1132 } else { 1133 log.Trace("Failed to convert pong to ticket", "err", err) 1134 } 1135 n.pingEcho = nil 1136 n.pingTopics = nil 1137 return err 1138 } 1139 1140 func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) { 1141 switch ev { 1142 case findnodePacket: 1143 target := crypto.Keccak256Hash(pkt.data.(*findnode).Target[:]) 1144 results := net.tab.closest(target, bucketSize).entries 1145 net.conn.sendNeighbours(n, results) 1146 return n.state, nil 1147 case neighborsPacket: 1148 err := net.handleNeighboursPacket(n, pkt) 1149 return n.state, err 1150 case neighboursTimeout: 1151 if n.pendingNeighbours != nil { 1152 n.pendingNeighbours.reply <- nil 1153 n.pendingNeighbours = nil 1154 } 1155 n.queryTimeouts++ 1156 if n.queryTimeouts > maxFindnodeFailures && n.state == known { 1157 return contested, errors.New("too many timeouts") 1158 } 1159 return n.state, nil 1160 1161 // v5 1162 1163 case findnodeHashPacket: 1164 results := net.tab.closest(pkt.data.(*findnodeHash).Target, bucketSize).entries 1165 net.conn.sendNeighbours(n, results) 1166 return n.state, nil 1167 case topicRegisterPacket: 1168 //fmt.Println("got topicRegisterPacket") 1169 regdata := pkt.data.(*topicRegister) 1170 pong, err := net.checkTopicRegister(regdata) 1171 if err != nil { 1172 //fmt.Println(err) 1173 return n.state, fmt.Errorf("bad waiting ticket: %v", err) 1174 } 1175 net.topictab.useTicket(n, pong.TicketSerial, regdata.Topics, int(regdata.Idx), pong.Expiration, pong.WaitPeriods) 1176 return n.state, nil 1177 case topicQueryPacket: 1178 // TODO: handle expiration 1179 topic := pkt.data.(*topicQuery).Topic 1180 results := net.topictab.getEntries(topic) 1181 if _, ok := net.ticketStore.tickets[topic]; ok { 1182 results = append(results, net.tab.self) // we're not registering in our own table but if 
we're advertising, return ourselves too 1183 } 1184 if len(results) > 10 { 1185 results = results[:10] 1186 } 1187 var hash common.Hash 1188 copy(hash[:], pkt.hash) 1189 net.conn.sendTopicNodes(n, hash, results) 1190 return n.state, nil 1191 case topicNodesPacket: 1192 p := pkt.data.(*topicNodes) 1193 if net.ticketStore.gotTopicNodes(n, p.Echo, p.Nodes) { 1194 n.queryTimeouts++ 1195 if n.queryTimeouts > maxFindnodeFailures && n.state == known { 1196 return contested, errors.New("too many timeouts") 1197 } 1198 } 1199 return n.state, nil 1200 1201 default: 1202 return n.state, errInvalidEvent 1203 } 1204 } 1205 1206 func (net *Network) checkTopicRegister(data *topicRegister) (*pong, error) { 1207 var pongpkt ingressPacket 1208 if err := decodePacket(data.Pong, &pongpkt); err != nil { 1209 return nil, err 1210 } 1211 if pongpkt.ev != pongPacket { 1212 return nil, errors.New("is not pong packet") 1213 } 1214 if pongpkt.remoteID != net.tab.self.ID { 1215 return nil, errors.New("not signed by us") 1216 } 1217 // check that we previously authorised all topics 1218 // that the other side is trying to register. 
1219 if rlpHash(data.Topics) != pongpkt.data.(*pong).TopicHash { 1220 return nil, errors.New("topic hash mismatch") 1221 } 1222 if data.Idx < 0 || int(data.Idx) >= len(data.Topics) { 1223 return nil, errors.New("topic index out of range") 1224 } 1225 return pongpkt.data.(*pong), nil 1226 } 1227 1228 func rlpHash(x interface{}) (h common.Hash) { 1229 hw := sha3.NewKeccak256() 1230 rlp.Encode(hw, x) 1231 hw.Sum(h[:0]) 1232 return h 1233 } 1234 1235 func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { 1236 if n.pendingNeighbours == nil { 1237 return errNoQuery 1238 } 1239 net.abortTimedEvent(n, neighboursTimeout) 1240 1241 req := pkt.data.(*neighbors) 1242 nodes := make([]*Node, len(req.Nodes)) 1243 for i, rn := range req.Nodes { 1244 nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn) 1245 if err != nil { 1246 log.Debug(fmt.Sprintf("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err)) 1247 continue 1248 } 1249 nodes[i] = nn 1250 // Start validation of query results immediately. 1251 // This fills the table quickly. 1252 // TODO: generates way too many packets, maybe do it via queue. 1253 if nn.state == unknown { 1254 net.transition(nn, verifyinit) 1255 } 1256 } 1257 // TODO: don't ignore second packet 1258 n.pendingNeighbours.reply <- nodes 1259 n.pendingNeighbours = nil 1260 // Now that this query is done, start the next one. 1261 n.startNextQuery(net) 1262 return nil 1263 }