github.com/ethereum/go-ethereum@v1.16.1/p2p/discover/table_util_test.go (about)

     1  // Copyright 2018 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discover
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/ecdsa"
    22  	"encoding/hex"
    23  	"errors"
    24  	"fmt"
    25  	"math/rand"
    26  	"net"
    27  	"slices"
    28  	"sync"
    29  	"sync/atomic"
    30  	"time"
    31  
    32  	"github.com/ethereum/go-ethereum/crypto"
    33  	"github.com/ethereum/go-ethereum/p2p/discover/v4wire"
    34  	"github.com/ethereum/go-ethereum/p2p/enode"
    35  	"github.com/ethereum/go-ethereum/p2p/enr"
    36  )
    37  
    38  var nullNode *enode.Node
    39  
    40  func init() {
    41  	var r enr.Record
    42  	r.Set(enr.IP{0, 0, 0, 0})
    43  	nullNode = enode.SignNull(&r, enode.ID{})
    44  }
    45  
    46  func newTestTable(t transport, cfg Config) (*Table, *enode.DB) {
    47  	tab, db := newInactiveTestTable(t, cfg)
    48  	go tab.loop()
    49  	return tab, db
    50  }
    51  
    52  // newInactiveTestTable creates a Table without running the main loop.
    53  func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) {
    54  	db, _ := enode.OpenDB("")
    55  	tab, _ := newTable(t, db, cfg)
    56  	return tab, db
    57  }
    58  
    59  // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld.
    60  func nodeAtDistance(base enode.ID, ld int, ip net.IP) *enode.Node {
    61  	var r enr.Record
    62  	r.Set(enr.IP(ip))
    63  	r.Set(enr.UDP(30303))
    64  	return enode.SignNull(&r, idAtDistance(base, ld))
    65  }
    66  
    67  // nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld.
    68  func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node {
    69  	results := make([]*enode.Node, n)
    70  	for i := range results {
    71  		results[i] = nodeAtDistance(base, ld, intIP(i))
    72  	}
    73  	return results
    74  }
    75  
    76  func nodesToRecords(nodes []*enode.Node) []*enr.Record {
    77  	records := make([]*enr.Record, len(nodes))
    78  	for i := range nodes {
    79  		records[i] = nodes[i].Record()
    80  	}
    81  	return records
    82  }
    83  
    84  // idAtDistance returns a random hash such that enode.LogDist(a, b) == n
    85  func idAtDistance(a enode.ID, n int) (b enode.ID) {
    86  	if n == 0 {
    87  		return a
    88  	}
    89  	// flip bit at position n, fill the rest with random bits
    90  	b = a
    91  	pos := len(a) - n/8 - 1
    92  	bit := byte(0x01) << (byte(n%8) - 1)
    93  	if bit == 0 {
    94  		pos++
    95  		bit = 0x80
    96  	}
    97  	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
    98  	for i := pos + 1; i < len(a); i++ {
    99  		b[i] = byte(rand.Intn(255))
   100  	}
   101  	return b
   102  }
   103  
   104  // intIP returns a LAN IP address based on i.
   105  func intIP(i int) net.IP {
   106  	return net.IP{10, 0, byte(i >> 8), byte(i & 0xFF)}
   107  }
   108  
   109  // fillBucket inserts nodes into the given bucket until it is full.
   110  func fillBucket(tab *Table, id enode.ID) (last *tableNode) {
   111  	ld := enode.LogDist(tab.self().ID(), id)
   112  	b := tab.bucket(id)
   113  	for len(b.entries) < bucketSize {
   114  		node := nodeAtDistance(tab.self().ID(), ld, intIP(ld))
   115  		if !tab.addFoundNode(node, false) {
   116  			panic("node not added")
   117  		}
   118  	}
   119  	return b.entries[bucketSize-1]
   120  }
   121  
   122  // fillTable adds nodes the table to the end of their corresponding bucket
   123  // if the bucket is not full. The caller must not hold tab.mutex.
   124  func fillTable(tab *Table, nodes []*enode.Node, setLive bool) {
   125  	for _, n := range nodes {
   126  		tab.addFoundNode(n, setLive)
   127  	}
   128  }
   129  
// pingRecorder is a test double whose ping method records every node it is
// asked to ping, so tests can wait for pings with waitPing.
type pingRecorder struct {
	// mu is locked by the methods that touch the fields below.
	mu sync.Mutex
	// cond is bound to mu; ping broadcasts on it and waitPing waits on it.
	cond *sync.Cond
	// dead marks node IDs for which ping and RequestENR return errTimeout.
	dead map[enode.ID]bool
	// records holds the node stored per ID by updateRecord.
	records map[enode.ID]*enode.Node
	// pinged is the FIFO queue of ping targets, consumed by waitPing.
	pinged []*enode.Node
	// n is set by newPingRecorder to a null-signed node with the zero ID.
	n *enode.Node
}
   138  
   139  func newPingRecorder() *pingRecorder {
   140  	var r enr.Record
   141  	r.Set(enr.IP{0, 0, 0, 0})
   142  	n := enode.SignNull(&r, enode.ID{})
   143  
   144  	t := &pingRecorder{
   145  		dead:    make(map[enode.ID]bool),
   146  		records: make(map[enode.ID]*enode.Node),
   147  		n:       n,
   148  	}
   149  	t.cond = sync.NewCond(&t.mu)
   150  	return t
   151  }
   152  
   153  // updateRecord updates a node record. Future calls to ping and
   154  // RequestENR will return this record.
   155  func (t *pingRecorder) updateRecord(n *enode.Node) {
   156  	t.mu.Lock()
   157  	defer t.mu.Unlock()
   158  	t.records[n.ID()] = n
   159  }
   160  
   161  // Stubs to satisfy the transport interface.
   162  func (t *pingRecorder) Self() *enode.Node           { return nullNode }
   163  func (t *pingRecorder) lookupSelf() []*enode.Node   { return nil }
   164  func (t *pingRecorder) lookupRandom() []*enode.Node { return nil }
   165  
   166  func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node {
   167  	t.mu.Lock()
   168  	defer t.mu.Unlock()
   169  
   170  	// Wake up the loop on timeout.
   171  	var timedout atomic.Bool
   172  	timer := time.AfterFunc(timeout, func() {
   173  		timedout.Store(true)
   174  		t.cond.Broadcast()
   175  	})
   176  	defer timer.Stop()
   177  
   178  	// Wait for a ping.
   179  	for {
   180  		if timedout.Load() {
   181  			return nil
   182  		}
   183  		if len(t.pinged) > 0 {
   184  			n := t.pinged[0]
   185  			t.pinged = append(t.pinged[:0], t.pinged[1:]...)
   186  			return n
   187  		}
   188  		t.cond.Wait()
   189  	}
   190  }
   191  
   192  // ping simulates a ping request.
   193  func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) {
   194  	t.mu.Lock()
   195  	defer t.mu.Unlock()
   196  
   197  	t.pinged = append(t.pinged, n)
   198  	t.cond.Broadcast()
   199  
   200  	if t.dead[n.ID()] {
   201  		return 0, errTimeout
   202  	}
   203  	if t.records[n.ID()] != nil {
   204  		seq = t.records[n.ID()].Seq()
   205  	}
   206  	return seq, nil
   207  }
   208  
   209  // RequestENR simulates an ENR request.
   210  func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) {
   211  	t.mu.Lock()
   212  	defer t.mu.Unlock()
   213  
   214  	if t.dead[n.ID()] || t.records[n.ID()] == nil {
   215  		return nil, errTimeout
   216  	}
   217  	return t.records[n.ID()], nil
   218  }
   219  
   220  func hasDuplicates(slice []*enode.Node) bool {
   221  	seen := make(map[enode.ID]bool, len(slice))
   222  	for i, e := range slice {
   223  		if e == nil {
   224  			panic(fmt.Sprintf("nil *Node at %d", i))
   225  		}
   226  		if seen[e.ID()] {
   227  			return true
   228  		}
   229  		seen[e.ID()] = true
   230  	}
   231  	return false
   232  }
   233  
   234  // checkNodesEqual checks whether the two given node lists contain the same nodes.
   235  func checkNodesEqual(got, want []*enode.Node) error {
   236  	if len(got) == len(want) {
   237  		for i := range got {
   238  			if !nodeEqual(got[i], want[i]) {
   239  				goto NotEqual
   240  			}
   241  		}
   242  	}
   243  	return nil
   244  
   245  NotEqual:
   246  	output := new(bytes.Buffer)
   247  	fmt.Fprintf(output, "got %d nodes:\n", len(got))
   248  	for _, n := range got {
   249  		fmt.Fprintf(output, "  %v %v\n", n.ID(), n)
   250  	}
   251  	fmt.Fprintf(output, "want %d:\n", len(want))
   252  	for _, n := range want {
   253  		fmt.Fprintf(output, "  %v %v\n", n.ID(), n)
   254  	}
   255  	return errors.New(output.String())
   256  }
   257  
   258  func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
   259  	return n1.ID() == n2.ID() && n1.IPAddr() == n2.IPAddr()
   260  }
   261  
   262  func sortByID[N nodeType](nodes []N) {
   263  	slices.SortFunc(nodes, func(a, b N) int {
   264  		return bytes.Compare(a.ID().Bytes(), b.ID().Bytes())
   265  	})
   266  }
   267  
   268  func sortedByDistanceTo(distbase enode.ID, slice []*enode.Node) bool {
   269  	return slices.IsSortedFunc(slice, func(a, b *enode.Node) int {
   270  		return enode.DistCmp(distbase, a.ID(), b.ID())
   271  	})
   272  }
   273  
   274  // hexEncPrivkey decodes h as a private key.
   275  func hexEncPrivkey(h string) *ecdsa.PrivateKey {
   276  	b, err := hex.DecodeString(h)
   277  	if err != nil {
   278  		panic(err)
   279  	}
   280  	key, err := crypto.ToECDSA(b)
   281  	if err != nil {
   282  		panic(err)
   283  	}
   284  	return key
   285  }
   286  
   287  // hexEncPubkey decodes h as a public key.
   288  func hexEncPubkey(h string) (ret v4wire.Pubkey) {
   289  	b, err := hex.DecodeString(h)
   290  	if err != nil {
   291  		panic(err)
   292  	}
   293  	if len(b) != len(ret) {
   294  		panic("invalid length")
   295  	}
   296  	copy(ret[:], b)
   297  	return ret
   298  }
   299  
// nodeEventRecorder collects node added/removed notifications so tests can
// wait for them.
type nodeEventRecorder struct {
	// evc receives one event per nodeAdded or nodeRemoved call; it is read by
	// waitNodeEvent.
	evc chan recordedNodeEvent
}
   303  
// recordedNodeEvent is a single notification captured by nodeEventRecorder.
type recordedNodeEvent struct {
	// node is the table node the event refers to.
	node *tableNode
	// added is true for events sent by nodeAdded and false for nodeRemoved.
	added bool
}
   308  
   309  func newNodeEventRecorder(buffer int) *nodeEventRecorder {
   310  	return &nodeEventRecorder{
   311  		evc: make(chan recordedNodeEvent, buffer),
   312  	}
   313  }
   314  
   315  func (set *nodeEventRecorder) nodeAdded(b *bucket, n *tableNode) {
   316  	select {
   317  	case set.evc <- recordedNodeEvent{n, true}:
   318  	default:
   319  		panic("no space in event buffer")
   320  	}
   321  }
   322  
   323  func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *tableNode) {
   324  	select {
   325  	case set.evc <- recordedNodeEvent{n, false}:
   326  	default:
   327  		panic("no space in event buffer")
   328  	}
   329  }
   330  
   331  func (set *nodeEventRecorder) waitNodePresent(id enode.ID, timeout time.Duration) bool {
   332  	return set.waitNodeEvent(id, timeout, true)
   333  }
   334  
   335  func (set *nodeEventRecorder) waitNodeAbsent(id enode.ID, timeout time.Duration) bool {
   336  	return set.waitNodeEvent(id, timeout, false)
   337  }
   338  
   339  func (set *nodeEventRecorder) waitNodeEvent(id enode.ID, timeout time.Duration, added bool) bool {
   340  	timer := time.NewTimer(timeout)
   341  	defer timer.Stop()
   342  	for {
   343  		select {
   344  		case ev := <-set.evc:
   345  			if ev.node.ID() == id && ev.added == added {
   346  				return true
   347  			}
   348  		case <-timer.C:
   349  			return false
   350  		}
   351  	}
   352  }