github.com/jimmyx0x/go-ethereum@v1.10.28/p2p/discover/table_util_test.go (about)

     1  // Copyright 2018 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discover
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/ecdsa"
    22  	"encoding/hex"
    23  	"errors"
    24  	"fmt"
    25  	"math/rand"
    26  	"net"
    27  	"sort"
    28  	"sync"
    29  
    30  	"github.com/ethereum/go-ethereum/crypto"
    31  	"github.com/ethereum/go-ethereum/log"
    32  	"github.com/ethereum/go-ethereum/p2p/enode"
    33  	"github.com/ethereum/go-ethereum/p2p/enr"
    34  )
    35  
    36  var nullNode *enode.Node
    37  
    38  func init() {
    39  	var r enr.Record
    40  	r.Set(enr.IP{0, 0, 0, 0})
    41  	nullNode = enode.SignNull(&r, enode.ID{})
    42  }
    43  
    44  func newTestTable(t transport) (*Table, *enode.DB) {
    45  	db, _ := enode.OpenDB("")
    46  	tab, _ := newTable(t, db, nil, log.Root())
    47  	go tab.loop()
    48  	return tab, db
    49  }
    50  
    51  // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld.
    52  func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node {
    53  	var r enr.Record
    54  	r.Set(enr.IP(ip))
    55  	return wrapNode(enode.SignNull(&r, idAtDistance(base, ld)))
    56  }
    57  
    58  // nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld.
    59  func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node {
    60  	results := make([]*enode.Node, n)
    61  	for i := range results {
    62  		results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i)))
    63  	}
    64  	return results
    65  }
    66  
    67  func nodesToRecords(nodes []*enode.Node) []*enr.Record {
    68  	records := make([]*enr.Record, len(nodes))
    69  	for i := range nodes {
    70  		records[i] = nodes[i].Record()
    71  	}
    72  	return records
    73  }
    74  
    75  // idAtDistance returns a random hash such that enode.LogDist(a, b) == n
    76  func idAtDistance(a enode.ID, n int) (b enode.ID) {
    77  	if n == 0 {
    78  		return a
    79  	}
    80  	// flip bit at position n, fill the rest with random bits
    81  	b = a
    82  	pos := len(a) - n/8 - 1
    83  	bit := byte(0x01) << (byte(n%8) - 1)
    84  	if bit == 0 {
    85  		pos++
    86  		bit = 0x80
    87  	}
    88  	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
    89  	for i := pos + 1; i < len(a); i++ {
    90  		b[i] = byte(rand.Intn(255))
    91  	}
    92  	return b
    93  }
    94  
    95  func intIP(i int) net.IP {
    96  	return net.IP{byte(i), 0, 2, byte(i)}
    97  }
    98  
    99  // fillBucket inserts nodes into the given bucket until it is full.
   100  func fillBucket(tab *Table, n *node) (last *node) {
   101  	ld := enode.LogDist(tab.self().ID(), n.ID())
   102  	b := tab.bucket(n.ID())
   103  	for len(b.entries) < bucketSize {
   104  		b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld)))
   105  	}
   106  	return b.entries[bucketSize-1]
   107  }
   108  
   109  // fillTable adds nodes the table to the end of their corresponding bucket
   110  // if the bucket is not full. The caller must not hold tab.mutex.
   111  func fillTable(tab *Table, nodes []*node) {
   112  	for _, n := range nodes {
   113  		tab.addSeenNode(n)
   114  	}
   115  }
   116  
   117  type pingRecorder struct {
   118  	mu           sync.Mutex
   119  	dead, pinged map[enode.ID]bool
   120  	records      map[enode.ID]*enode.Node
   121  	n            *enode.Node
   122  }
   123  
   124  func newPingRecorder() *pingRecorder {
   125  	var r enr.Record
   126  	r.Set(enr.IP{0, 0, 0, 0})
   127  	n := enode.SignNull(&r, enode.ID{})
   128  
   129  	return &pingRecorder{
   130  		dead:    make(map[enode.ID]bool),
   131  		pinged:  make(map[enode.ID]bool),
   132  		records: make(map[enode.ID]*enode.Node),
   133  		n:       n,
   134  	}
   135  }
   136  
   137  // updateRecord updates a node record. Future calls to ping and
   138  // RequestENR will return this record.
   139  func (t *pingRecorder) updateRecord(n *enode.Node) {
   140  	t.mu.Lock()
   141  	defer t.mu.Unlock()
   142  	t.records[n.ID()] = n
   143  }
   144  
   145  // Stubs to satisfy the transport interface.
   146  func (t *pingRecorder) Self() *enode.Node           { return nullNode }
   147  func (t *pingRecorder) lookupSelf() []*enode.Node   { return nil }
   148  func (t *pingRecorder) lookupRandom() []*enode.Node { return nil }
   149  
   150  // ping simulates a ping request.
   151  func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) {
   152  	t.mu.Lock()
   153  	defer t.mu.Unlock()
   154  
   155  	t.pinged[n.ID()] = true
   156  	if t.dead[n.ID()] {
   157  		return 0, errTimeout
   158  	}
   159  	if t.records[n.ID()] != nil {
   160  		seq = t.records[n.ID()].Seq()
   161  	}
   162  	return seq, nil
   163  }
   164  
   165  // RequestENR simulates an ENR request.
   166  func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) {
   167  	t.mu.Lock()
   168  	defer t.mu.Unlock()
   169  
   170  	if t.dead[n.ID()] || t.records[n.ID()] == nil {
   171  		return nil, errTimeout
   172  	}
   173  	return t.records[n.ID()], nil
   174  }
   175  
   176  func hasDuplicates(slice []*node) bool {
   177  	seen := make(map[enode.ID]bool)
   178  	for i, e := range slice {
   179  		if e == nil {
   180  			panic(fmt.Sprintf("nil *Node at %d", i))
   181  		}
   182  		if seen[e.ID()] {
   183  			return true
   184  		}
   185  		seen[e.ID()] = true
   186  	}
   187  	return false
   188  }
   189  
   190  // checkNodesEqual checks whether the two given node lists contain the same nodes.
   191  func checkNodesEqual(got, want []*enode.Node) error {
   192  	if len(got) == len(want) {
   193  		for i := range got {
   194  			if !nodeEqual(got[i], want[i]) {
   195  				goto NotEqual
   196  			}
   197  		}
   198  	}
   199  	return nil
   200  
   201  NotEqual:
   202  	output := new(bytes.Buffer)
   203  	fmt.Fprintf(output, "got %d nodes:\n", len(got))
   204  	for _, n := range got {
   205  		fmt.Fprintf(output, "  %v %v\n", n.ID(), n)
   206  	}
   207  	fmt.Fprintf(output, "want %d:\n", len(want))
   208  	for _, n := range want {
   209  		fmt.Fprintf(output, "  %v %v\n", n.ID(), n)
   210  	}
   211  	return errors.New(output.String())
   212  }
   213  
   214  func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
   215  	return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
   216  }
   217  
   218  func sortByID(nodes []*enode.Node) {
   219  	sort.Slice(nodes, func(i, j int) bool {
   220  		return string(nodes[i].ID().Bytes()) < string(nodes[j].ID().Bytes())
   221  	})
   222  }
   223  
   224  func sortedByDistanceTo(distbase enode.ID, slice []*node) bool {
   225  	return sort.SliceIsSorted(slice, func(i, j int) bool {
   226  		return enode.DistCmp(distbase, slice[i].ID(), slice[j].ID()) < 0
   227  	})
   228  }
   229  
   230  // hexEncPrivkey decodes h as a private key.
   231  func hexEncPrivkey(h string) *ecdsa.PrivateKey {
   232  	b, err := hex.DecodeString(h)
   233  	if err != nil {
   234  		panic(err)
   235  	}
   236  	key, err := crypto.ToECDSA(b)
   237  	if err != nil {
   238  		panic(err)
   239  	}
   240  	return key
   241  }
   242  
   243  // hexEncPubkey decodes h as a public key.
   244  func hexEncPubkey(h string) (ret encPubkey) {
   245  	b, err := hex.DecodeString(h)
   246  	if err != nil {
   247  		panic(err)
   248  	}
   249  	if len(b) != len(ret) {
   250  		panic("invalid length")
   251  	}
   252  	copy(ret[:], b)
   253  	return ret
   254  }