github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/discv5/net.go (about)

     1  package discv5
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdsa"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/neatlab/neatio/chain/log"
    12  	"github.com/neatlab/neatio/network/p2p/netutil"
    13  	"github.com/neatlab/neatio/utilities/common"
    14  	"github.com/neatlab/neatio/utilities/common/mclock"
    15  	"github.com/neatlab/neatio/utilities/crypto"
    16  	"github.com/neatlab/neatio/utilities/rlp"
    17  	"golang.org/x/crypto/sha3"
    18  )
    19  
    20  var (
    21  	errInvalidEvent = errors.New("invalid in current state")
    22  	errNoQuery      = errors.New("no pending query")
    23  	errWrongAddress = errors.New("unknown sender address")
    24  )
    25  
    26  const (
    27  	autoRefreshInterval   = 1 * time.Hour
    28  	bucketRefreshInterval = 1 * time.Minute
    29  	seedCount             = 30
    30  	seedMaxAge            = 5 * 24 * time.Hour
    31  	lowPort               = 1024
    32  )
    33  
    34  const testTopic = "foo"
    35  
    36  const (
    37  	printTestImgLogs = false
    38  )
    39  
    40  type Network struct {
    41  	db          *nodeDB
    42  	conn        transport
    43  	netrestrict *netutil.Netlist
    44  
    45  	closed           chan struct{}
    46  	closeReq         chan struct{}
    47  	refreshReq       chan []*Node
    48  	refreshResp      chan (<-chan struct{})
    49  	read             chan ingressPacket
    50  	timeout          chan timeoutEvent
    51  	queryReq         chan *findnodeQuery
    52  	tableOpReq       chan func()
    53  	tableOpResp      chan struct{}
    54  	topicRegisterReq chan topicRegisterReq
    55  	topicSearchReq   chan topicSearchReq
    56  
    57  	tab           *Table
    58  	topictab      *topicTable
    59  	ticketStore   *ticketStore
    60  	nursery       []*Node
    61  	nodes         map[NodeID]*Node
    62  	timeoutTimers map[timeoutEvent]*time.Timer
    63  
    64  	slowRevalidateQueue []*Node
    65  	fastRevalidateQueue []*Node
    66  
    67  	sendBuf []*ingressPacket
    68  }
    69  
    70  type transport interface {
    71  	sendPing(remote *Node, remoteAddr *net.UDPAddr, topics []Topic) (hash []byte)
    72  	sendNeighbours(remote *Node, nodes []*Node)
    73  	sendFindnodeHash(remote *Node, target common.Hash)
    74  	sendTopicRegister(remote *Node, topics []Topic, topicIdx int, pong []byte)
    75  	sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
    76  
    77  	send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte)
    78  
    79  	localAddr() *net.UDPAddr
    80  	Close()
    81  }
    82  
// findnodeQuery is a findnode request issued by lookup through
// reqQueryFindnode. Its result is delivered on reply; a nil slice is sent when
// the query is abandoned (neighboursTimeout, or the remote entering state
// unknown).
type findnodeQuery struct {
	remote   *Node
	target   common.Hash
	reply    chan<- []*Node
	nresults int // NOTE(review): not set or read in the visible code
}

// topicRegisterReq asks loop to start (add == true) or stop registering topic.
// It is built with positional literals, so field order matters.
type topicRegisterReq struct {
	add   bool
	topic Topic
}

// topicSearchReq starts a search for topic, updates its period, or (delay == 0)
// stops it. found is handed to the ticket store when the search is added;
// lookup receives a convergence flag after each search lookup.
type topicSearchReq struct {
	topic  Topic
	found  chan<- *Node
	lookup chan<- bool
	delay  time.Duration
}

// topicSearchResult is the outcome of one topic-search lookup, reported back
// to loop.
type topicSearchResult struct {
	target lookupInfo
	nodes  []*Node
}

// timeoutEvent identifies a scheduled timeout (event ev for node). It is the
// key of Network.timeoutTimers and is built with positional literals, so field
// order matters.
type timeoutEvent struct {
	ev   nodeEvent
	node *Node
}
   111  
   112  func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
   113  	ourID := PubkeyID(&ourPubkey)
   114  
   115  	var db *nodeDB
   116  	if dbPath != "<no database>" {
   117  		var err error
   118  		if db, err = newNodeDB(dbPath, Version, ourID); err != nil {
   119  			return nil, err
   120  		}
   121  	}
   122  
   123  	tab := newTable(ourID, conn.localAddr())
   124  	net := &Network{
   125  		db:               db,
   126  		conn:             conn,
   127  		netrestrict:      netrestrict,
   128  		tab:              tab,
   129  		topictab:         newTopicTable(db, tab.self),
   130  		ticketStore:      newTicketStore(),
   131  		refreshReq:       make(chan []*Node),
   132  		refreshResp:      make(chan (<-chan struct{})),
   133  		closed:           make(chan struct{}),
   134  		closeReq:         make(chan struct{}),
   135  		read:             make(chan ingressPacket, 100),
   136  		timeout:          make(chan timeoutEvent),
   137  		timeoutTimers:    make(map[timeoutEvent]*time.Timer),
   138  		tableOpReq:       make(chan func()),
   139  		tableOpResp:      make(chan struct{}),
   140  		queryReq:         make(chan *findnodeQuery),
   141  		topicRegisterReq: make(chan topicRegisterReq),
   142  		topicSearchReq:   make(chan topicSearchReq),
   143  		nodes:            make(map[NodeID]*Node),
   144  	}
   145  	go net.loop()
   146  	return net, nil
   147  }
   148  
   149  func (net *Network) Close() {
   150  	net.conn.Close()
   151  	select {
   152  	case <-net.closed:
   153  	case net.closeReq <- struct{}{}:
   154  		<-net.closed
   155  	}
   156  }
   157  
   158  func (net *Network) Self() *Node {
   159  	return net.tab.self
   160  }
   161  
   162  func (net *Network) ReadRandomNodes(buf []*Node) (n int) {
   163  	net.reqTableOp(func() { n = net.tab.readRandomNodes(buf) })
   164  	return n
   165  }
   166  
   167  func (net *Network) SetFallbackNodes(nodes []*Node) error {
   168  	nursery := make([]*Node, 0, len(nodes))
   169  	for _, n := range nodes {
   170  		if err := n.validateComplete(); err != nil {
   171  			return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
   172  		}
   173  
   174  		cpy := *n
   175  		cpy.sha = crypto.Keccak256Hash(n.ID[:])
   176  		nursery = append(nursery, &cpy)
   177  	}
   178  	net.reqRefresh(nursery)
   179  	return nil
   180  }
   181  
   182  func (net *Network) Resolve(targetID NodeID) *Node {
   183  	result := net.lookup(crypto.Keccak256Hash(targetID[:]), true)
   184  	for _, n := range result {
   185  		if n.ID == targetID {
   186  			return n
   187  		}
   188  	}
   189  	return nil
   190  }
   191  
   192  func (net *Network) Lookup(targetID NodeID) []*Node {
   193  	return net.lookup(crypto.Keccak256Hash(targetID[:]), false)
   194  }
   195  
   196  func (net *Network) lookup(target common.Hash, stopOnMatch bool) []*Node {
   197  	var (
   198  		asked          = make(map[NodeID]bool)
   199  		seen           = make(map[NodeID]bool)
   200  		reply          = make(chan []*Node, alpha)
   201  		result         = nodesByDistance{target: target}
   202  		pendingQueries = 0
   203  	)
   204  
   205  	result.push(net.tab.self, bucketSize)
   206  	for {
   207  
   208  		for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
   209  			n := result.entries[i]
   210  			if !asked[n.ID] {
   211  				asked[n.ID] = true
   212  				pendingQueries++
   213  				net.reqQueryFindnode(n, target, reply)
   214  			}
   215  		}
   216  		if pendingQueries == 0 {
   217  
   218  			break
   219  		}
   220  
   221  		select {
   222  		case nodes := <-reply:
   223  			for _, n := range nodes {
   224  				if n != nil && !seen[n.ID] {
   225  					seen[n.ID] = true
   226  					result.push(n, bucketSize)
   227  					if stopOnMatch && n.sha == target {
   228  						return result.entries
   229  					}
   230  				}
   231  			}
   232  			pendingQueries--
   233  		case <-time.After(respTimeout):
   234  
   235  			pendingQueries = 0
   236  			reply = make(chan []*Node, alpha)
   237  		}
   238  	}
   239  	return result.entries
   240  }
   241  
   242  func (net *Network) RegisterTopic(topic Topic, stop <-chan struct{}) {
   243  	select {
   244  	case net.topicRegisterReq <- topicRegisterReq{true, topic}:
   245  	case <-net.closed:
   246  		return
   247  	}
   248  	select {
   249  	case <-net.closed:
   250  	case <-stop:
   251  		select {
   252  		case net.topicRegisterReq <- topicRegisterReq{false, topic}:
   253  		case <-net.closed:
   254  		}
   255  	}
   256  }
   257  
   258  func (net *Network) SearchTopic(topic Topic, setPeriod <-chan time.Duration, found chan<- *Node, lookup chan<- bool) {
   259  	for {
   260  		select {
   261  		case <-net.closed:
   262  			return
   263  		case delay, ok := <-setPeriod:
   264  			select {
   265  			case net.topicSearchReq <- topicSearchReq{topic: topic, found: found, lookup: lookup, delay: delay}:
   266  			case <-net.closed:
   267  				return
   268  			}
   269  			if !ok {
   270  				return
   271  			}
   272  		}
   273  	}
   274  }
   275  
   276  func (net *Network) reqRefresh(nursery []*Node) <-chan struct{} {
   277  	select {
   278  	case net.refreshReq <- nursery:
   279  		return <-net.refreshResp
   280  	case <-net.closed:
   281  		return net.closed
   282  	}
   283  }
   284  
   285  func (net *Network) reqQueryFindnode(n *Node, target common.Hash, reply chan []*Node) bool {
   286  	q := &findnodeQuery{remote: n, target: target, reply: reply}
   287  	select {
   288  	case net.queryReq <- q:
   289  		return true
   290  	case <-net.closed:
   291  		return false
   292  	}
   293  }
   294  
   295  func (net *Network) reqReadPacket(pkt ingressPacket) {
   296  	select {
   297  	case net.read <- pkt:
   298  	case <-net.closed:
   299  	}
   300  }
   301  
   302  func (net *Network) reqTableOp(f func()) (called bool) {
   303  	select {
   304  	case net.tableOpReq <- f:
   305  		<-net.tableOpResp
   306  		return true
   307  	case <-net.closed:
   308  		return false
   309  	}
   310  }
   311  
   312  type topicSearchInfo struct {
   313  	lookupChn chan<- bool
   314  	period    time.Duration
   315  }
   316  
   317  const maxSearchCount = 5
   318  
   319  func (net *Network) loop() {
   320  	var (
   321  		refreshTimer       = time.NewTicker(autoRefreshInterval)
   322  		bucketRefreshTimer = time.NewTimer(bucketRefreshInterval)
   323  		refreshDone        chan struct{}
   324  	)
   325  
   326  	var (
   327  		nextTicket        *ticketRef
   328  		nextRegisterTimer *time.Timer
   329  		nextRegisterTime  <-chan time.Time
   330  	)
   331  	defer func() {
   332  		if nextRegisterTimer != nil {
   333  			nextRegisterTimer.Stop()
   334  		}
   335  	}()
   336  	resetNextTicket := func() {
   337  		ticket, timeout := net.ticketStore.nextFilteredTicket()
   338  		if nextTicket != ticket {
   339  			nextTicket = ticket
   340  			if nextRegisterTimer != nil {
   341  				nextRegisterTimer.Stop()
   342  				nextRegisterTime = nil
   343  			}
   344  			if ticket != nil {
   345  				nextRegisterTimer = time.NewTimer(timeout)
   346  				nextRegisterTime = nextRegisterTimer.C
   347  			}
   348  		}
   349  	}
   350  
   351  	var (
   352  		topicRegisterLookupTarget lookupInfo
   353  		topicRegisterLookupDone   chan []*Node
   354  		topicRegisterLookupTick   = time.NewTimer(0)
   355  		searchReqWhenRefreshDone  []topicSearchReq
   356  		searchInfo                = make(map[Topic]topicSearchInfo)
   357  		activeSearchCount         int
   358  	)
   359  	topicSearchLookupDone := make(chan topicSearchResult, 100)
   360  	topicSearch := make(chan Topic, 100)
   361  	<-topicRegisterLookupTick.C
   362  
   363  	statsDump := time.NewTicker(10 * time.Second)
   364  
   365  loop:
   366  	for {
   367  		resetNextTicket()
   368  
   369  		select {
   370  		case <-net.closeReq:
   371  			log.Trace("<-net.closeReq")
   372  			break loop
   373  
   374  		case pkt := <-net.read:
   375  
   376  			log.Trace("<-net.read")
   377  			n := net.internNode(&pkt)
   378  			prestate := n.state
   379  			status := "ok"
   380  			if err := net.handle(n, pkt.ev, &pkt); err != nil {
   381  				status = err.Error()
   382  			}
   383  			log.Trace("", "msg", log.Lazy{Fn: func() string {
   384  				return fmt.Sprintf("<<< (%d) %v from %x@%v: %v -> %v (%v)",
   385  					net.tab.count, pkt.ev, pkt.remoteID[:8], pkt.remoteAddr, prestate, n.state, status)
   386  			}})
   387  
   388  		case timeout := <-net.timeout:
   389  			log.Trace("<-net.timeout")
   390  			if net.timeoutTimers[timeout] == nil {
   391  
   392  				continue
   393  			}
   394  			delete(net.timeoutTimers, timeout)
   395  			prestate := timeout.node.state
   396  			status := "ok"
   397  			if err := net.handle(timeout.node, timeout.ev, nil); err != nil {
   398  				status = err.Error()
   399  			}
   400  			log.Trace("", "msg", log.Lazy{Fn: func() string {
   401  				return fmt.Sprintf("--- (%d) %v for %x@%v: %v -> %v (%v)",
   402  					net.tab.count, timeout.ev, timeout.node.ID[:8], timeout.node.addr(), prestate, timeout.node.state, status)
   403  			}})
   404  
   405  		case q := <-net.queryReq:
   406  			log.Trace("<-net.queryReq")
   407  			if !q.start(net) {
   408  				q.remote.deferQuery(q)
   409  			}
   410  
   411  		case f := <-net.tableOpReq:
   412  			log.Trace("<-net.tableOpReq")
   413  			f()
   414  			net.tableOpResp <- struct{}{}
   415  
   416  		case req := <-net.topicRegisterReq:
   417  			log.Trace("<-net.topicRegisterReq")
   418  			if !req.add {
   419  				net.ticketStore.removeRegisterTopic(req.topic)
   420  				continue
   421  			}
   422  			net.ticketStore.addTopic(req.topic, true)
   423  
   424  			if topicRegisterLookupTarget.target == (common.Hash{}) {
   425  				log.Trace("topicRegisterLookupTarget == null")
   426  				if topicRegisterLookupTick.Stop() {
   427  					<-topicRegisterLookupTick.C
   428  				}
   429  				target, delay := net.ticketStore.nextRegisterLookup()
   430  				topicRegisterLookupTarget = target
   431  				topicRegisterLookupTick.Reset(delay)
   432  			}
   433  
   434  		case nodes := <-topicRegisterLookupDone:
   435  			log.Trace("<-topicRegisterLookupDone")
   436  			net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte {
   437  				net.ping(n, n.addr())
   438  				return n.pingEcho
   439  			})
   440  			target, delay := net.ticketStore.nextRegisterLookup()
   441  			topicRegisterLookupTarget = target
   442  			topicRegisterLookupTick.Reset(delay)
   443  			topicRegisterLookupDone = nil
   444  
   445  		case <-topicRegisterLookupTick.C:
   446  			log.Trace("<-topicRegisterLookupTick")
   447  			if (topicRegisterLookupTarget.target == common.Hash{}) {
   448  				target, delay := net.ticketStore.nextRegisterLookup()
   449  				topicRegisterLookupTarget = target
   450  				topicRegisterLookupTick.Reset(delay)
   451  				topicRegisterLookupDone = nil
   452  			} else {
   453  				topicRegisterLookupDone = make(chan []*Node)
   454  				target := topicRegisterLookupTarget.target
   455  				go func() { topicRegisterLookupDone <- net.lookup(target, false) }()
   456  			}
   457  
   458  		case <-nextRegisterTime:
   459  			log.Trace("<-nextRegisterTime")
   460  			net.ticketStore.ticketRegistered(*nextTicket)
   461  
   462  			net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)
   463  
   464  		case req := <-net.topicSearchReq:
   465  			if refreshDone == nil {
   466  				log.Trace("<-net.topicSearchReq")
   467  				info, ok := searchInfo[req.topic]
   468  				if ok {
   469  					if req.delay == time.Duration(0) {
   470  						delete(searchInfo, req.topic)
   471  						net.ticketStore.removeSearchTopic(req.topic)
   472  					} else {
   473  						info.period = req.delay
   474  						searchInfo[req.topic] = info
   475  					}
   476  					continue
   477  				}
   478  				if req.delay != time.Duration(0) {
   479  					var info topicSearchInfo
   480  					info.period = req.delay
   481  					info.lookupChn = req.lookup
   482  					searchInfo[req.topic] = info
   483  					net.ticketStore.addSearchTopic(req.topic, req.found)
   484  					topicSearch <- req.topic
   485  				}
   486  			} else {
   487  				searchReqWhenRefreshDone = append(searchReqWhenRefreshDone, req)
   488  			}
   489  
   490  		case topic := <-topicSearch:
   491  			if activeSearchCount < maxSearchCount {
   492  				activeSearchCount++
   493  				target := net.ticketStore.nextSearchLookup(topic)
   494  				go func() {
   495  					nodes := net.lookup(target.target, false)
   496  					topicSearchLookupDone <- topicSearchResult{target: target, nodes: nodes}
   497  				}()
   498  			}
   499  			period := searchInfo[topic].period
   500  			if period != time.Duration(0) {
   501  				go func() {
   502  					time.Sleep(period)
   503  					topicSearch <- topic
   504  				}()
   505  			}
   506  
   507  		case res := <-topicSearchLookupDone:
   508  			activeSearchCount--
   509  			if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil {
   510  				lookupChn <- net.ticketStore.radius[res.target.topic].converged
   511  			}
   512  			net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte {
   513  				if n.state != nil && n.state.canQuery {
   514  					return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic})
   515  				} else {
   516  					if n.state == unknown {
   517  						net.ping(n, n.addr())
   518  					}
   519  					return nil
   520  				}
   521  			})
   522  
   523  		case <-statsDump.C:
   524  			log.Trace("<-statsDump.C")
   525  
   526  			tm := mclock.Now()
   527  			for topic, r := range net.ticketStore.radius {
   528  				if printTestImgLogs {
   529  					rad := r.radius / (maxRadius/1000000 + 1)
   530  					minrad := r.minRadius / (maxRadius/1000000 + 1)
   531  					fmt.Printf("*R %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], rad)
   532  					fmt.Printf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad)
   533  				}
   534  			}
   535  			for topic, t := range net.topictab.topics {
   536  				wp := t.wcl.nextWaitPeriod(tm)
   537  				if printTestImgLogs {
   538  					fmt.Printf("*W %d %v %016x %d\n", tm/1000000, topic, net.tab.self.sha[:8], wp/1000000)
   539  				}
   540  			}
   541  
   542  		case <-refreshTimer.C:
   543  			log.Trace("<-refreshTimer.C")
   544  
   545  			if refreshDone == nil {
   546  				refreshDone = make(chan struct{})
   547  				net.refresh(refreshDone)
   548  			}
   549  		case <-bucketRefreshTimer.C:
   550  			target := net.tab.chooseBucketRefreshTarget()
   551  			go func() {
   552  				net.lookup(target, false)
   553  				bucketRefreshTimer.Reset(bucketRefreshInterval)
   554  			}()
   555  		case newNursery := <-net.refreshReq:
   556  			log.Trace("<-net.refreshReq")
   557  			if newNursery != nil {
   558  				net.nursery = newNursery
   559  			}
   560  			if refreshDone == nil {
   561  				refreshDone = make(chan struct{})
   562  				net.refresh(refreshDone)
   563  			}
   564  			net.refreshResp <- refreshDone
   565  		case <-refreshDone:
   566  			log.Trace("<-net.refreshDone", "table size", net.tab.count)
   567  			if net.tab.count != 0 {
   568  				refreshDone = nil
   569  				list := searchReqWhenRefreshDone
   570  				searchReqWhenRefreshDone = nil
   571  				go func() {
   572  					for _, req := range list {
   573  						net.topicSearchReq <- req
   574  					}
   575  				}()
   576  			} else {
   577  				refreshDone = make(chan struct{})
   578  				net.refresh(refreshDone)
   579  			}
   580  		}
   581  	}
   582  	log.Trace("loop stopped")
   583  
   584  	log.Debug(fmt.Sprintf("shutting down"))
   585  	if net.conn != nil {
   586  		net.conn.Close()
   587  	}
   588  	if refreshDone != nil {
   589  
   590  	}
   591  
   592  	for _, timer := range net.timeoutTimers {
   593  		timer.Stop()
   594  	}
   595  	if net.db != nil {
   596  		net.db.close()
   597  	}
   598  	close(net.closed)
   599  }
   600  
   601  func (net *Network) refresh(done chan<- struct{}) {
   602  	var seeds []*Node
   603  	if net.db != nil {
   604  		seeds = net.db.querySeeds(seedCount, seedMaxAge)
   605  	}
   606  	if len(seeds) == 0 {
   607  		seeds = net.nursery
   608  	}
   609  	if len(seeds) == 0 {
   610  		log.Trace("no seed nodes found")
   611  		close(done)
   612  		return
   613  	}
   614  	for _, n := range seeds {
   615  		log.Debug("", "msg", log.Lazy{Fn: func() string {
   616  			var age string
   617  			if net.db != nil {
   618  				age = time.Since(net.db.lastPong(n.ID)).String()
   619  			} else {
   620  				age = "unknown"
   621  			}
   622  			return fmt.Sprintf("seed node (age %s): %v", age, n)
   623  		}})
   624  		n = net.internNodeFromDB(n)
   625  		if n.state == unknown {
   626  			net.transition(n, verifyinit)
   627  		}
   628  
   629  		net.tab.add(n)
   630  	}
   631  
   632  	go func() {
   633  		net.Lookup(net.tab.self.ID)
   634  		close(done)
   635  	}()
   636  }
   637  
   638  func (net *Network) internNode(pkt *ingressPacket) *Node {
   639  	if n := net.nodes[pkt.remoteID]; n != nil {
   640  		n.IP = pkt.remoteAddr.IP
   641  		n.UDP = uint16(pkt.remoteAddr.Port)
   642  		n.TCP = uint16(pkt.remoteAddr.Port)
   643  		return n
   644  	}
   645  	n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port))
   646  	n.state = unknown
   647  	net.nodes[pkt.remoteID] = n
   648  	return n
   649  }
   650  
   651  func (net *Network) internNodeFromDB(dbn *Node) *Node {
   652  	if n := net.nodes[dbn.ID]; n != nil {
   653  		return n
   654  	}
   655  	n := NewNode(dbn.ID, dbn.IP, dbn.UDP, dbn.TCP)
   656  	n.state = unknown
   657  	net.nodes[n.ID] = n
   658  	return n
   659  }
   660  
   661  func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
   662  	if rn.ID == net.tab.self.ID {
   663  		return nil, errors.New("is self")
   664  	}
   665  	if rn.UDP <= lowPort {
   666  		return nil, errors.New("low port")
   667  	}
   668  	n = net.nodes[rn.ID]
   669  	if n == nil {
   670  
   671  		n, err = nodeFromRPC(sender, rn)
   672  		if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
   673  			return n, errors.New("not contained in netrestrict whitelist")
   674  		}
   675  		if err == nil {
   676  			n.state = unknown
   677  			net.nodes[n.ID] = n
   678  		}
   679  		return n, err
   680  	}
   681  	if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP {
   682  		if n.state == known {
   683  
   684  			err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
   685  		} else {
   686  
   687  			n.IP = rn.IP
   688  			n.UDP = rn.UDP
   689  			n.TCP = rn.TCP
   690  		}
   691  	}
   692  	return n, err
   693  }
   694  
   695  type nodeNetGuts struct {
   696  	sha common.Hash
   697  
   698  	state             *nodeState
   699  	pingEcho          []byte
   700  	pingTopics        []Topic
   701  	deferredQueries   []*findnodeQuery
   702  	pendingNeighbours *findnodeQuery
   703  	queryTimeouts     int
   704  }
   705  
   706  func (n *nodeNetGuts) deferQuery(q *findnodeQuery) {
   707  	n.deferredQueries = append(n.deferredQueries, q)
   708  }
   709  
   710  func (n *nodeNetGuts) startNextQuery(net *Network) {
   711  	if len(n.deferredQueries) == 0 {
   712  		return
   713  	}
   714  	nextq := n.deferredQueries[0]
   715  	if nextq.start(net) {
   716  		n.deferredQueries = append(n.deferredQueries[:0], n.deferredQueries[1:]...)
   717  	}
   718  }
   719  
   720  func (q *findnodeQuery) start(net *Network) bool {
   721  
   722  	if q.remote == net.tab.self {
   723  		closest := net.tab.closest(crypto.Keccak256Hash(q.target[:]), bucketSize)
   724  		q.reply <- closest.entries
   725  		return true
   726  	}
   727  	if q.remote.state.canQuery && q.remote.pendingNeighbours == nil {
   728  		net.conn.sendFindnodeHash(q.remote, q.target)
   729  		net.timedEvent(respTimeout, q.remote, neighboursTimeout)
   730  		q.remote.pendingNeighbours = q
   731  		return true
   732  	}
   733  
   734  	if q.remote.state == unknown {
   735  		net.transition(q.remote, verifyinit)
   736  	}
   737  	return false
   738  }
   739  
// nodeEvent is an input to the node state machine: a received packet type or a
// timeout.
type nodeEvent uint

// The numeric values follow from declaration order (iota), so do not reorder.
// NOTE(review): the packet values may double as wire packet codes elsewhere in
// the package — confirm before changing them.
const (
	invalidEvent nodeEvent = iota

	// Packet events: pingPacket == 1 through topicNodesPacket == 8.
	pingPacket
	pongPacket
	findnodePacket
	neighborsPacket
	findnodeHashPacket
	topicRegisterPacket
	topicQueryPacket
	topicNodesPacket

	// Timeout events. iota is 9 on this line, so pongTimeout == 265,
	// pingTimeout == 266 and neighboursTimeout == 267, presumably to keep them
	// numerically apart from packet events.
	pongTimeout nodeEvent = iota + 256
	pingTimeout
	neighboursTimeout
)
   758  
   759  type nodeState struct {
   760  	name     string
   761  	handle   func(*Network, *Node, nodeEvent, *ingressPacket) (next *nodeState, err error)
   762  	enter    func(*Network, *Node)
   763  	canQuery bool
   764  }
   765  
   766  func (s *nodeState) String() string {
   767  	return s.name
   768  }
   769  
// The states of the per-node discovery state machine. They are assigned in init
// rather than with initializers because their handlers reference one another
// (e.g. unknown -> verifywait -> unknown), which Go would reject as an
// initialization cycle.
var (
	unknown          *nodeState
	verifyinit       *nodeState
	verifywait       *nodeState
	remoteverifywait *nodeState
	known            *nodeState
	contested        *nodeState
	unresponsive     *nodeState
)

func init() {
	// unknown: not verified. Entering it drops the node from the table and fails
	// all of its queued and pending findnode queries with nil replies. Only an
	// incoming ping is accepted: it is answered, pinged back, and the node moves
	// to verifywait.
	unknown = &nodeState{
		name: "unknown",
		enter: func(net *Network, n *Node) {
			net.tab.delete(n)
			n.pingEcho = nil
			// Release everyone waiting on this node.
			for _, q := range n.deferredQueries {
				q.reply <- nil
			}
			n.deferredQueries = nil
			if n.pendingNeighbours != nil {
				n.pendingNeighbours.reply <- nil
				n.pendingNeighbours = nil
			}
			n.queryTimeouts = 0
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				net.ping(n, pkt.remoteAddr)
				return verifywait, nil
			default:
				return unknown, errInvalidEvent
			}
		},
	}

	// verifyinit: we initiated verification by pinging the node. A matching
	// pong moves it to remoteverifywait; an incoming ping moves it to
	// verifywait; a pong timeout sends it back to unknown.
	verifyinit = &nodeState{
		name: "verifyinit",
		enter: func(net *Network, n *Node) {
			net.ping(n, n.addr())
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return verifywait, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return remoteverifywait, err
			case pongTimeout:
				return unknown, nil
			default:
				return verifyinit, errInvalidEvent
			}
		},
	}

	// verifywait: the node has pinged us and we are waiting for the pong to our
	// own ping. The pong makes it known; a pong timeout sends it back to unknown.
	verifywait = &nodeState{
		name: "verifywait",
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return verifywait, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return known, err
			case pongTimeout:
				return unknown, nil
			default:
				return verifywait, errInvalidEvent
			}
		},
	}

	// remoteverifywait: the node answered our ping. Entering it arms a
	// pingTimeout (respTimeout); pings are answered meanwhile, and when the
	// timeout fires the node becomes known. Presumably this gives the remote a
	// chance to verify us — TODO confirm intent.
	remoteverifywait = &nodeState{
		name: "remoteverifywait",
		enter: func(net *Network, n *Node) {
			net.timedEvent(respTimeout, n, pingTimeout)
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return remoteverifywait, nil
			case pingTimeout:
				return known, nil
			default:
				return remoteverifywait, errInvalidEvent
			}
		},
	}

	// known: verified and queryable. Entering it resets the query-timeout count,
	// starts the next deferred query and adds the node to the table. If tab.add
	// returns another node that is itself known, that node becomes contested
	// (presumably the bucket entry the newcomer competes with — confirm in table.go).
	known = &nodeState{
		name:     "known",
		canQuery: true,
		enter: func(net *Network, n *Node) {
			n.queryTimeouts = 0
			n.startNextQuery(net)

			last := net.tab.add(n)
			if last != nil && last.state == known {
				// Re-check the displaced/competing node with a ping.
				net.transition(last, contested)
			}
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return known, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return known, err
			default:
				return net.handleQueryEvent(n, ev, pkt)
			}
		},
	}

	// contested: a known node being re-checked. Entering it sends a ping. A pong
	// restores it to known; a pong timeout removes it from the table
	// (deleteReplace) and marks it unresponsive. It remains queryable meanwhile.
	contested = &nodeState{
		name:     "contested",
		canQuery: true,
		enter: func(net *Network, n *Node) {
			net.ping(n, n.addr())
		},
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pongPacket:
				// The node answered: it stays in the table.
				err := net.handleKnownPong(n, pkt)
				return known, err
			case pongTimeout:
				net.tab.deleteReplace(n)
				return unresponsive, nil
			case pingPacket:
				net.handlePing(n, pkt)
				return contested, nil
			default:
				return net.handleQueryEvent(n, ev, pkt)
			}
		},
	}

	// unresponsive: failed a contest. Any ping or pong from it makes it known
	// again; other events go to the query handler.
	unresponsive = &nodeState{
		name:     "unresponsive",
		canQuery: true,
		handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
			switch ev {
			case pingPacket:
				net.handlePing(n, pkt)
				return known, nil
			case pongPacket:
				err := net.handleKnownPong(n, pkt)
				return known, err
			default:
				return net.handleQueryEvent(n, ev, pkt)
			}
		},
	}
}
   934  
   935  func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
   936  
   937  	if pkt != nil {
   938  		if err := net.checkPacket(n, ev, pkt); err != nil {
   939  
   940  			return err
   941  		}
   942  
   943  		if net.db != nil {
   944  			net.db.ensureExpirer()
   945  		}
   946  	}
   947  	if n.state == nil {
   948  		n.state = unknown
   949  	}
   950  	next, err := n.state.handle(net, n, ev, pkt)
   951  	net.transition(n, next)
   952  
   953  	return err
   954  }
   955  
   956  func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error {
   957  
   958  	switch ev {
   959  	case pingPacket, findnodeHashPacket, neighborsPacket:
   960  
   961  	case pongPacket:
   962  		if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) {
   963  
   964  			return fmt.Errorf("pong reply token mismatch")
   965  		}
   966  		n.pingEcho = nil
   967  	}
   968  
   969  	return nil
   970  }
   971  
   972  func (net *Network) transition(n *Node, next *nodeState) {
   973  	if n.state != next {
   974  		n.state = next
   975  		if next.enter != nil {
   976  			next.enter(net, n)
   977  		}
   978  	}
   979  
   980  }
   981  
   982  func (net *Network) timedEvent(d time.Duration, n *Node, ev nodeEvent) {
   983  	timeout := timeoutEvent{ev, n}
   984  	net.timeoutTimers[timeout] = time.AfterFunc(d, func() {
   985  		select {
   986  		case net.timeout <- timeout:
   987  		case <-net.closed:
   988  		}
   989  	})
   990  }
   991  
   992  func (net *Network) abortTimedEvent(n *Node, ev nodeEvent) {
   993  	timer := net.timeoutTimers[timeoutEvent{ev, n}]
   994  	if timer != nil {
   995  		timer.Stop()
   996  		delete(net.timeoutTimers, timeoutEvent{ev, n})
   997  	}
   998  }
   999  
  1000  func (net *Network) ping(n *Node, addr *net.UDPAddr) {
  1001  
  1002  	if n.pingEcho != nil || n.ID == net.tab.self.ID {
  1003  
  1004  		return
  1005  	}
  1006  	log.Trace("Pinging remote node", "node", n.ID)
  1007  	n.pingTopics = net.ticketStore.regTopicSet()
  1008  	n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics)
  1009  	net.timedEvent(respTimeout, n, pongTimeout)
  1010  }
  1011  
  1012  func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
  1013  	log.Trace("Handling remote ping", "node", n.ID)
  1014  	ping := pkt.data.(*ping)
  1015  	n.TCP = ping.From.TCP
  1016  	t := net.topictab.getTicket(n, ping.Topics)
  1017  
  1018  	pong := &pong{
  1019  		To:         makeEndpoint(n.addr(), n.TCP),
  1020  		ReplyTok:   pkt.hash,
  1021  		Expiration: uint64(time.Now().Add(expiration).Unix()),
  1022  	}
  1023  	ticketToPong(t, pong)
  1024  	net.conn.send(n, pongPacket, pong)
  1025  }
  1026  
  1027  func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
  1028  	log.Trace("Handling known pong", "node", n.ID)
  1029  	net.abortTimedEvent(n, pongTimeout)
  1030  	now := mclock.Now()
  1031  	ticket, err := pongToTicket(now, n.pingTopics, n, pkt)
  1032  	if err == nil {
  1033  
  1034  		net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket)
  1035  	} else {
  1036  		log.Trace("Failed to convert pong to ticket", "err", err)
  1037  	}
  1038  	n.pingEcho = nil
  1039  	n.pingTopics = nil
  1040  	return err
  1041  }
  1042  
  1043  func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
  1044  	switch ev {
  1045  	case findnodePacket:
  1046  		target := crypto.Keccak256Hash(pkt.data.(*findnode).Target[:])
  1047  		results := net.tab.closest(target, bucketSize).entries
  1048  		net.conn.sendNeighbours(n, results)
  1049  		return n.state, nil
  1050  	case neighborsPacket:
  1051  		err := net.handleNeighboursPacket(n, pkt)
  1052  		return n.state, err
  1053  	case neighboursTimeout:
  1054  		if n.pendingNeighbours != nil {
  1055  			n.pendingNeighbours.reply <- nil
  1056  			n.pendingNeighbours = nil
  1057  		}
  1058  		n.queryTimeouts++
  1059  		if n.queryTimeouts > maxFindnodeFailures && n.state == known {
  1060  			return contested, errors.New("too many timeouts")
  1061  		}
  1062  		return n.state, nil
  1063  
  1064  	case findnodeHashPacket:
  1065  		results := net.tab.closest(pkt.data.(*findnodeHash).Target, bucketSize).entries
  1066  		net.conn.sendNeighbours(n, results)
  1067  		return n.state, nil
  1068  	case topicRegisterPacket:
  1069  
  1070  		regdata := pkt.data.(*topicRegister)
  1071  		pong, err := net.checkTopicRegister(regdata)
  1072  		if err != nil {
  1073  
  1074  			return n.state, fmt.Errorf("bad waiting ticket: %v", err)
  1075  		}
  1076  		net.topictab.useTicket(n, pong.TicketSerial, regdata.Topics, int(regdata.Idx), pong.Expiration, pong.WaitPeriods)
  1077  		return n.state, nil
  1078  	case topicQueryPacket:
  1079  
  1080  		topic := pkt.data.(*topicQuery).Topic
  1081  		results := net.topictab.getEntries(topic)
  1082  		if _, ok := net.ticketStore.tickets[topic]; ok {
  1083  			results = append(results, net.tab.self)
  1084  		}
  1085  		if len(results) > 10 {
  1086  			results = results[:10]
  1087  		}
  1088  		var hash common.Hash
  1089  		copy(hash[:], pkt.hash)
  1090  		net.conn.sendTopicNodes(n, hash, results)
  1091  		return n.state, nil
  1092  	case topicNodesPacket:
  1093  		p := pkt.data.(*topicNodes)
  1094  		if net.ticketStore.gotTopicNodes(n, p.Echo, p.Nodes) {
  1095  			n.queryTimeouts++
  1096  			if n.queryTimeouts > maxFindnodeFailures && n.state == known {
  1097  				return contested, errors.New("too many timeouts")
  1098  			}
  1099  		}
  1100  		return n.state, nil
  1101  
  1102  	default:
  1103  		return n.state, errInvalidEvent
  1104  	}
  1105  }
  1106  
  1107  func (net *Network) checkTopicRegister(data *topicRegister) (*pong, error) {
  1108  	var pongpkt ingressPacket
  1109  	if err := decodePacket(data.Pong, &pongpkt); err != nil {
  1110  		return nil, err
  1111  	}
  1112  	if pongpkt.ev != pongPacket {
  1113  		return nil, errors.New("is not pong packet")
  1114  	}
  1115  	if pongpkt.remoteID != net.tab.self.ID {
  1116  		return nil, errors.New("not signed by us")
  1117  	}
  1118  
  1119  	if rlpHash(data.Topics) != pongpkt.data.(*pong).TopicHash {
  1120  		return nil, errors.New("topic hash mismatch")
  1121  	}
  1122  	if data.Idx < 0 || int(data.Idx) >= len(data.Topics) {
  1123  		return nil, errors.New("topic index out of range")
  1124  	}
  1125  	return pongpkt.data.(*pong), nil
  1126  }
  1127  
  1128  func rlpHash(x interface{}) (h common.Hash) {
  1129  	hw := sha3.NewLegacyKeccak256()
  1130  	rlp.Encode(hw, x)
  1131  	hw.Sum(h[:0])
  1132  	return h
  1133  }
  1134  
  1135  func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
  1136  	if n.pendingNeighbours == nil {
  1137  		return errNoQuery
  1138  	}
  1139  	net.abortTimedEvent(n, neighboursTimeout)
  1140  
  1141  	req := pkt.data.(*neighbors)
  1142  	nodes := make([]*Node, len(req.Nodes))
  1143  	for i, rn := range req.Nodes {
  1144  		nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
  1145  		if err != nil {
  1146  			log.Debug(fmt.Sprintf("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err))
  1147  			continue
  1148  		}
  1149  		nodes[i] = nn
  1150  
  1151  		if nn.state == unknown {
  1152  			net.transition(nn, verifyinit)
  1153  		}
  1154  	}
  1155  
  1156  	n.pendingNeighbours.reply <- nodes
  1157  	n.pendingNeighbours = nil
  1158  
  1159  	n.startNextQuery(net)
  1160  	return nil
  1161  }