github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/discover/table.go (about)

     1  package discover
     2  
     3  import (
     4  	crand "crypto/rand"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	mrand "math/rand"
     9  	"net"
    10  	"sort"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/neatio-net/neatio/chain/log"
    15  	"github.com/neatio-net/neatio/network/p2p/netutil"
    16  	"github.com/neatio-net/neatio/utilities/common"
    17  	"github.com/neatio-net/neatio/utilities/crypto"
    18  )
    19  
    20  const (
    21  	alpha           = 3
    22  	bucketSize      = 16
    23  	maxReplacements = 10
    24  
    25  	hashBits          = len(common.Hash{}) * 8
    26  	nBuckets          = hashBits / 15
    27  	bucketMinDistance = hashBits - nBuckets
    28  
    29  	bucketIPLimit, bucketSubnet = 2, 24
    30  	tableIPLimit, tableSubnet   = 10, 24
    31  
    32  	maxBondingPingPongs = 16
    33  	maxFindnodeFailures = 5
    34  
    35  	refreshInterval    = 30 * time.Minute
    36  	revalidateInterval = 10 * time.Second
    37  	copyNodesInterval  = 30 * time.Second
    38  	seedMinTableTime   = 5 * time.Minute
    39  	seedCount          = 30
    40  	seedMaxAge         = 5 * 24 * time.Hour
    41  )
    42  
    43  type Table struct {
    44  	mutex   sync.Mutex
    45  	buckets [nBuckets]*bucket
    46  	nursery []*Node
    47  	rand    *mrand.Rand
    48  	ips     netutil.DistinctNetSet
    49  
    50  	db         *nodeDB
    51  	refreshReq chan chan struct{}
    52  	initDone   chan struct{}
    53  	closeReq   chan struct{}
    54  	closed     chan struct{}
    55  
    56  	bondmu    sync.Mutex
    57  	bonding   map[NodeID]*bondproc
    58  	bondslots chan struct{}
    59  
    60  	nodeAddedHook func(*Node)
    61  
    62  	net  transport
    63  	self *Node
    64  }
    65  
    66  type bondproc struct {
    67  	err  error
    68  	n    *Node
    69  	done chan struct{}
    70  }
    71  
    72  type transport interface {
    73  	ping(NodeID, *net.UDPAddr) error
    74  	waitping(NodeID) error
    75  	findnode(toid NodeID, addr *net.UDPAddr, target NodeID) ([]*Node, error)
    76  	close()
    77  }
    78  
    79  type bucket struct {
    80  	entries      []*Node
    81  	replacements []*Node
    82  	ips          netutil.DistinctNetSet
    83  }
    84  
    85  func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string, bootnodes []*Node) (*Table, error) {
    86  
    87  	db, err := newNodeDB(nodeDBPath, Version, ourID)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	tab := &Table{
    92  		net:        t,
    93  		db:         db,
    94  		self:       NewNode(ourID, ourAddr.IP, uint16(ourAddr.Port), uint16(ourAddr.Port)),
    95  		bonding:    make(map[NodeID]*bondproc),
    96  		bondslots:  make(chan struct{}, maxBondingPingPongs),
    97  		refreshReq: make(chan chan struct{}),
    98  		initDone:   make(chan struct{}),
    99  		closeReq:   make(chan struct{}),
   100  		closed:     make(chan struct{}),
   101  		rand:       mrand.New(mrand.NewSource(0)),
   102  		ips:        netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
   103  	}
   104  	if err := tab.setFallbackNodes(bootnodes); err != nil {
   105  		return nil, err
   106  	}
   107  	for i := 0; i < cap(tab.bondslots); i++ {
   108  		tab.bondslots <- struct{}{}
   109  	}
   110  	for i := range tab.buckets {
   111  		tab.buckets[i] = &bucket{
   112  			ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
   113  		}
   114  	}
   115  	tab.seedRand()
   116  	tab.loadSeedNodes(false)
   117  
   118  	tab.db.ensureExpirer()
   119  	go tab.loop()
   120  	return tab, nil
   121  }
   122  
   123  func (tab *Table) seedRand() {
   124  	var b [8]byte
   125  	crand.Read(b[:])
   126  
   127  	tab.mutex.Lock()
   128  	tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:])))
   129  	tab.mutex.Unlock()
   130  }
   131  
   132  func (tab *Table) Self() *Node {
   133  	return tab.self
   134  }
   135  
   136  func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
   137  	if !tab.isInitDone() {
   138  		return 0
   139  	}
   140  	tab.mutex.Lock()
   141  	defer tab.mutex.Unlock()
   142  
   143  	var buckets [][]*Node
   144  	for _, b := range tab.buckets {
   145  		if len(b.entries) > 0 {
   146  			buckets = append(buckets, b.entries[:])
   147  		}
   148  	}
   149  	if len(buckets) == 0 {
   150  		return 0
   151  	}
   152  
   153  	for i := len(buckets) - 1; i > 0; i-- {
   154  		j := tab.rand.Intn(len(buckets))
   155  		buckets[i], buckets[j] = buckets[j], buckets[i]
   156  	}
   157  
   158  	var i, j int
   159  	for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) {
   160  		b := buckets[j]
   161  		buf[i] = &(*b[0])
   162  		buckets[j] = b[1:]
   163  		if len(b) == 1 {
   164  			buckets = append(buckets[:j], buckets[j+1:]...)
   165  		}
   166  		if len(buckets) == 0 {
   167  			break
   168  		}
   169  	}
   170  	return i + 1
   171  }
   172  
   173  func (tab *Table) Close() {
   174  	select {
   175  	case <-tab.closed:
   176  
   177  	case tab.closeReq <- struct{}{}:
   178  		<-tab.closed
   179  	}
   180  }
   181  
   182  func (tab *Table) setFallbackNodes(nodes []*Node) error {
   183  	for _, n := range nodes {
   184  		if err := n.validateComplete(); err != nil {
   185  			return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
   186  		}
   187  	}
   188  	tab.nursery = make([]*Node, 0, len(nodes))
   189  	for _, n := range nodes {
   190  		cpy := *n
   191  
   192  		cpy.sha = crypto.Keccak256Hash(n.ID[:])
   193  		tab.nursery = append(tab.nursery, &cpy)
   194  	}
   195  	return nil
   196  }
   197  
   198  func (tab *Table) isInitDone() bool {
   199  	select {
   200  	case <-tab.initDone:
   201  		return true
   202  	default:
   203  		return false
   204  	}
   205  }
   206  
   207  func (tab *Table) Resolve(targetID NodeID) *Node {
   208  
   209  	hash := crypto.Keccak256Hash(targetID[:])
   210  	tab.mutex.Lock()
   211  	cl := tab.closest(hash, 1)
   212  	tab.mutex.Unlock()
   213  	if len(cl.entries) > 0 && cl.entries[0].ID == targetID {
   214  		return cl.entries[0]
   215  	}
   216  
   217  	result := tab.Lookup(targetID)
   218  	for _, n := range result {
   219  		if n.ID == targetID {
   220  			return n
   221  		}
   222  	}
   223  	return nil
   224  }
   225  
   226  func (tab *Table) Lookup(targetID NodeID) []*Node {
   227  	return tab.lookup(targetID, true)
   228  }
   229  
   230  func (tab *Table) lookup(targetID NodeID, refreshIfEmpty bool) []*Node {
   231  	var (
   232  		target         = crypto.Keccak256Hash(targetID[:])
   233  		asked          = make(map[NodeID]bool)
   234  		seen           = make(map[NodeID]bool)
   235  		reply          = make(chan []*Node, alpha)
   236  		pendingQueries = 0
   237  		result         *nodesByDistance
   238  	)
   239  
   240  	asked[tab.self.ID] = true
   241  
   242  	for {
   243  		tab.mutex.Lock()
   244  
   245  		result = tab.closest(target, bucketSize)
   246  		tab.mutex.Unlock()
   247  		if len(result.entries) > 0 || !refreshIfEmpty {
   248  			break
   249  		}
   250  
   251  		<-tab.refresh()
   252  		refreshIfEmpty = false
   253  	}
   254  
   255  	for {
   256  
   257  		for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
   258  			n := result.entries[i]
   259  			if !asked[n.ID] {
   260  				asked[n.ID] = true
   261  				pendingQueries++
   262  				go func() {
   263  
   264  					r, err := tab.net.findnode(n.ID, n.addr(), targetID)
   265  					if err != nil {
   266  
   267  						fails := tab.db.findFails(n.ID) + 1
   268  						tab.db.updateFindFails(n.ID, fails)
   269  						log.Trace("Bumping findnode failure counter", "id", n.ID, "failcount", fails)
   270  
   271  						if fails >= maxFindnodeFailures {
   272  							log.Trace("Too many findnode failures, dropping", "id", n.ID, "failcount", fails)
   273  							tab.delete(n)
   274  						}
   275  					}
   276  					reply <- tab.bondall(r)
   277  				}()
   278  			}
   279  		}
   280  		if pendingQueries == 0 {
   281  
   282  			break
   283  		}
   284  
   285  		for _, n := range <-reply {
   286  			if n != nil && !seen[n.ID] {
   287  				seen[n.ID] = true
   288  				result.push(n, bucketSize)
   289  			}
   290  		}
   291  		pendingQueries--
   292  	}
   293  	return result.entries
   294  }
   295  
   296  func (tab *Table) refresh() <-chan struct{} {
   297  	done := make(chan struct{})
   298  	select {
   299  	case tab.refreshReq <- done:
   300  	case <-tab.closed:
   301  		close(done)
   302  	}
   303  	return done
   304  }
   305  
   306  func (tab *Table) loop() {
   307  	var (
   308  		revalidate     = time.NewTimer(tab.nextRevalidateTime())
   309  		refresh        = time.NewTicker(refreshInterval)
   310  		copyNodes      = time.NewTicker(copyNodesInterval)
   311  		revalidateDone = make(chan struct{})
   312  		refreshDone    = make(chan struct{})
   313  		waiting        = []chan struct{}{tab.initDone}
   314  	)
   315  	defer refresh.Stop()
   316  	defer revalidate.Stop()
   317  	defer copyNodes.Stop()
   318  
   319  	go tab.doRefresh(refreshDone)
   320  
   321  loop:
   322  	for {
   323  		select {
   324  		case <-refresh.C:
   325  			tab.seedRand()
   326  			if refreshDone == nil {
   327  				refreshDone = make(chan struct{})
   328  				go tab.doRefresh(refreshDone)
   329  			}
   330  		case req := <-tab.refreshReq:
   331  			waiting = append(waiting, req)
   332  			if refreshDone == nil {
   333  				refreshDone = make(chan struct{})
   334  				go tab.doRefresh(refreshDone)
   335  			}
   336  		case <-refreshDone:
   337  			for _, ch := range waiting {
   338  				close(ch)
   339  			}
   340  			waiting, refreshDone = nil, nil
   341  		case <-revalidate.C:
   342  			go tab.doRevalidate(revalidateDone)
   343  		case <-revalidateDone:
   344  			revalidate.Reset(tab.nextRevalidateTime())
   345  		case <-copyNodes.C:
   346  			go tab.copyBondedNodes()
   347  		case <-tab.closeReq:
   348  			break loop
   349  		}
   350  	}
   351  
   352  	if tab.net != nil {
   353  		tab.net.close()
   354  	}
   355  	if refreshDone != nil {
   356  		<-refreshDone
   357  	}
   358  	for _, ch := range waiting {
   359  		close(ch)
   360  	}
   361  	tab.db.close()
   362  	close(tab.closed)
   363  }
   364  
   365  func (tab *Table) doRefresh(done chan struct{}) {
   366  	defer close(done)
   367  
   368  	tab.loadSeedNodes(true)
   369  
   370  	tab.lookup(tab.self.ID, false)
   371  
   372  	for i := 0; i < 3; i++ {
   373  		var target NodeID
   374  		crand.Read(target[:])
   375  		tab.lookup(target, false)
   376  	}
   377  }
   378  
   379  func (tab *Table) loadSeedNodes(bond bool) {
   380  	seeds := tab.db.querySeeds(seedCount, seedMaxAge)
   381  	seeds = append(seeds, tab.nursery...)
   382  	if bond {
   383  		seeds = tab.bondall(seeds)
   384  	}
   385  	for i := range seeds {
   386  		seed := seeds[i]
   387  		age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.bondTime(seed.ID)) }}
   388  		log.Debug("Found seed node in database", "id", seed.ID, "addr", seed.addr(), "age", age)
   389  		tab.add(seed)
   390  	}
   391  }
   392  
   393  func (tab *Table) doRevalidate(done chan<- struct{}) {
   394  	defer func() { done <- struct{}{} }()
   395  
   396  	last, bi := tab.nodeToRevalidate()
   397  	if last == nil {
   398  
   399  		return
   400  	}
   401  
   402  	err := tab.ping(last.ID, last.addr())
   403  
   404  	tab.mutex.Lock()
   405  	defer tab.mutex.Unlock()
   406  	b := tab.buckets[bi]
   407  	if err == nil {
   408  
   409  		log.Debug("Revalidated node", "b", bi, "id", last.ID)
   410  		b.bump(last)
   411  		return
   412  	}
   413  
   414  	if r := tab.replace(b, last); r != nil {
   415  		log.Debug("Replaced dead node", "b", bi, "id", last.ID, "ip", last.IP, "r", r.ID, "rip", r.IP)
   416  	} else {
   417  		log.Debug("Removed dead node", "b", bi, "id", last.ID, "ip", last.IP)
   418  	}
   419  }
   420  
   421  func (tab *Table) nodeToRevalidate() (n *Node, bi int) {
   422  	tab.mutex.Lock()
   423  	defer tab.mutex.Unlock()
   424  
   425  	for _, bi = range tab.rand.Perm(len(tab.buckets)) {
   426  		b := tab.buckets[bi]
   427  		if len(b.entries) > 0 {
   428  			last := b.entries[len(b.entries)-1]
   429  			return last, bi
   430  		}
   431  	}
   432  	return nil, 0
   433  }
   434  
   435  func (tab *Table) nextRevalidateTime() time.Duration {
   436  	tab.mutex.Lock()
   437  	defer tab.mutex.Unlock()
   438  
   439  	return time.Duration(tab.rand.Int63n(int64(revalidateInterval)))
   440  }
   441  
   442  func (tab *Table) copyBondedNodes() {
   443  	tab.mutex.Lock()
   444  	defer tab.mutex.Unlock()
   445  
   446  	now := time.Now()
   447  	for _, b := range tab.buckets {
   448  		for _, n := range b.entries {
   449  			if now.Sub(n.addedAt) >= seedMinTableTime {
   450  				tab.db.updateNode(n)
   451  			}
   452  		}
   453  	}
   454  }
   455  
   456  func (tab *Table) closest(target common.Hash, nresults int) *nodesByDistance {
   457  
   458  	close := &nodesByDistance{target: target}
   459  	for _, b := range tab.buckets {
   460  		for _, n := range b.entries {
   461  			close.push(n, nresults)
   462  		}
   463  	}
   464  	return close
   465  }
   466  
   467  func (tab *Table) len() (n int) {
   468  	for _, b := range tab.buckets {
   469  		n += len(b.entries)
   470  	}
   471  	return n
   472  }
   473  
   474  func (tab *Table) bondall(nodes []*Node) (result []*Node) {
   475  	rc := make(chan *Node, len(nodes))
   476  	for i := range nodes {
   477  		go func(n *Node) {
   478  			nn, _ := tab.bond(false, n.ID, n.addr(), n.TCP)
   479  			rc <- nn
   480  		}(nodes[i])
   481  	}
   482  	for range nodes {
   483  		if n := <-rc; n != nil {
   484  			result = append(result, n)
   485  		}
   486  	}
   487  	return result
   488  }
   489  
   490  func (tab *Table) bond(pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) (*Node, error) {
   491  	if id == tab.self.ID {
   492  		return nil, errors.New("is self")
   493  	}
   494  	if pinged && !tab.isInitDone() {
   495  		return nil, errors.New("still initializing")
   496  	}
   497  
   498  	node, fails := tab.db.node(id), tab.db.findFails(id)
   499  	age := time.Since(tab.db.bondTime(id))
   500  	var result error
   501  	if fails > 0 || age > nodeDBNodeExpiration {
   502  		log.Trace("Starting bonding ping/pong", "id", id, "known", node != nil, "failcount", fails, "age", age)
   503  
   504  		tab.bondmu.Lock()
   505  		w := tab.bonding[id]
   506  		if w != nil {
   507  
   508  			tab.bondmu.Unlock()
   509  			<-w.done
   510  		} else {
   511  
   512  			w = &bondproc{done: make(chan struct{})}
   513  			tab.bonding[id] = w
   514  			tab.bondmu.Unlock()
   515  
   516  			tab.pingpong(w, pinged, id, addr, tcpPort)
   517  
   518  			tab.bondmu.Lock()
   519  			delete(tab.bonding, id)
   520  			tab.bondmu.Unlock()
   521  		}
   522  
   523  		result = w.err
   524  		if result == nil {
   525  			node = w.n
   526  		}
   527  	}
   528  
   529  	if node != nil {
   530  		tab.add(node)
   531  		tab.db.updateFindFails(id, 0)
   532  	}
   533  	return node, result
   534  }
   535  
   536  func (tab *Table) pingpong(w *bondproc, pinged bool, id NodeID, addr *net.UDPAddr, tcpPort uint16) {
   537  
   538  	<-tab.bondslots
   539  	defer func() { tab.bondslots <- struct{}{} }()
   540  
   541  	if w.err = tab.ping(id, addr); w.err != nil {
   542  		close(w.done)
   543  		return
   544  	}
   545  	if !pinged {
   546  
   547  		tab.net.waitping(id)
   548  	}
   549  
   550  	w.n = NewNode(id, addr.IP, uint16(addr.Port), tcpPort)
   551  	close(w.done)
   552  }
   553  
   554  func (tab *Table) ping(id NodeID, addr *net.UDPAddr) error {
   555  	tab.db.updateLastPing(id, time.Now())
   556  	if err := tab.net.ping(id, addr); err != nil {
   557  		return err
   558  	}
   559  	tab.db.updateBondTime(id, time.Now())
   560  	return nil
   561  }
   562  
   563  func (tab *Table) bucket(sha common.Hash) *bucket {
   564  	d := logdist(tab.self.sha, sha)
   565  	if d <= bucketMinDistance {
   566  		return tab.buckets[0]
   567  	}
   568  	return tab.buckets[d-bucketMinDistance-1]
   569  }
   570  
   571  func (tab *Table) add(new *Node) {
   572  	tab.mutex.Lock()
   573  	defer tab.mutex.Unlock()
   574  
   575  	b := tab.bucket(new.sha)
   576  	if !tab.bumpOrAdd(b, new) {
   577  
   578  		tab.addReplacement(b, new)
   579  	}
   580  }
   581  
   582  func (tab *Table) stuff(nodes []*Node) {
   583  	tab.mutex.Lock()
   584  	defer tab.mutex.Unlock()
   585  
   586  	for _, n := range nodes {
   587  		if n.ID == tab.self.ID {
   588  			continue
   589  		}
   590  		b := tab.bucket(n.sha)
   591  		if len(b.entries) < bucketSize {
   592  			tab.bumpOrAdd(b, n)
   593  		}
   594  	}
   595  }
   596  
   597  func (tab *Table) delete(node *Node) {
   598  	tab.mutex.Lock()
   599  	defer tab.mutex.Unlock()
   600  
   601  	tab.deleteInBucket(tab.bucket(node.sha), node)
   602  }
   603  
   604  func (tab *Table) addIP(b *bucket, ip net.IP) bool {
   605  	if netutil.IsLAN(ip) {
   606  		return true
   607  	}
   608  	if !tab.ips.Add(ip) {
   609  		log.Debug("IP exceeds table limit", "ip", ip)
   610  		return false
   611  	}
   612  	if !b.ips.Add(ip) {
   613  		log.Debug("IP exceeds bucket limit", "ip", ip)
   614  		tab.ips.Remove(ip)
   615  		return false
   616  	}
   617  	return true
   618  }
   619  
   620  func (tab *Table) removeIP(b *bucket, ip net.IP) {
   621  	if netutil.IsLAN(ip) {
   622  		return
   623  	}
   624  	tab.ips.Remove(ip)
   625  	b.ips.Remove(ip)
   626  }
   627  
   628  func (tab *Table) addReplacement(b *bucket, n *Node) {
   629  	for _, e := range b.replacements {
   630  		if e.ID == n.ID {
   631  			return
   632  		}
   633  	}
   634  	if !tab.addIP(b, n.IP) {
   635  		return
   636  	}
   637  	var removed *Node
   638  	b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
   639  	if removed != nil {
   640  		tab.removeIP(b, removed.IP)
   641  	}
   642  }
   643  
   644  func (tab *Table) replace(b *bucket, last *Node) *Node {
   645  	if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID != last.ID {
   646  
   647  		return nil
   648  	}
   649  
   650  	if len(b.replacements) == 0 {
   651  		tab.deleteInBucket(b, last)
   652  		return nil
   653  	}
   654  	r := b.replacements[tab.rand.Intn(len(b.replacements))]
   655  	b.replacements = deleteNode(b.replacements, r)
   656  	b.entries[len(b.entries)-1] = r
   657  	tab.removeIP(b, last.IP)
   658  	return r
   659  }
   660  
   661  func (b *bucket) bump(n *Node) bool {
   662  	for i := range b.entries {
   663  		if b.entries[i].ID == n.ID {
   664  
   665  			copy(b.entries[1:], b.entries[:i])
   666  			b.entries[0] = n
   667  			return true
   668  		}
   669  	}
   670  	return false
   671  }
   672  
   673  func (tab *Table) bumpOrAdd(b *bucket, n *Node) bool {
   674  	if b.bump(n) {
   675  		return true
   676  	}
   677  	if len(b.entries) >= bucketSize || !tab.addIP(b, n.IP) {
   678  		return false
   679  	}
   680  	b.entries, _ = pushNode(b.entries, n, bucketSize)
   681  	b.replacements = deleteNode(b.replacements, n)
   682  	n.addedAt = time.Now()
   683  	if tab.nodeAddedHook != nil {
   684  		tab.nodeAddedHook(n)
   685  	}
   686  	return true
   687  }
   688  
   689  func (tab *Table) deleteInBucket(b *bucket, n *Node) {
   690  	b.entries = deleteNode(b.entries, n)
   691  	tab.removeIP(b, n.IP)
   692  }
   693  
   694  func pushNode(list []*Node, n *Node, max int) ([]*Node, *Node) {
   695  	if len(list) < max {
   696  		list = append(list, nil)
   697  	}
   698  	removed := list[len(list)-1]
   699  	copy(list[1:], list)
   700  	list[0] = n
   701  	return list, removed
   702  }
   703  
   704  func deleteNode(list []*Node, n *Node) []*Node {
   705  	for i := range list {
   706  		if list[i].ID == n.ID {
   707  			return append(list[:i], list[i+1:]...)
   708  		}
   709  	}
   710  	return list
   711  }
   712  
   713  type nodesByDistance struct {
   714  	entries []*Node
   715  	target  common.Hash
   716  }
   717  
   718  func (h *nodesByDistance) push(n *Node, maxElems int) {
   719  	ix := sort.Search(len(h.entries), func(i int) bool {
   720  		return distcmp(h.target, h.entries[i].sha, n.sha) > 0
   721  	})
   722  	if len(h.entries) < maxElems {
   723  		h.entries = append(h.entries, n)
   724  	}
   725  	if ix == len(h.entries) {
   726  
   727  	} else {
   728  
   729  		copy(h.entries[ix+1:], h.entries[ix:])
   730  		h.entries[ix] = n
   731  	}
   732  }