github.com/jimmyx0x/go-ethereum@v1.10.28/p2p/discover/table_test.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discover
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"fmt"
    22  	"math/rand"
    23  
    24  	"net"
    25  	"reflect"
    26  	"testing"
    27  	"testing/quick"
    28  	"time"
    29  
    30  	"github.com/ethereum/go-ethereum/crypto"
    31  	"github.com/ethereum/go-ethereum/p2p/enode"
    32  	"github.com/ethereum/go-ethereum/p2p/enr"
    33  	"github.com/ethereum/go-ethereum/p2p/netutil"
    34  )
    35  
    36  func TestTable_pingReplace(t *testing.T) {
    37  	run := func(newNodeResponding, lastInBucketResponding bool) {
    38  		name := fmt.Sprintf("newNodeResponding=%t/lastInBucketResponding=%t", newNodeResponding, lastInBucketResponding)
    39  		t.Run(name, func(t *testing.T) {
    40  			t.Parallel()
    41  			testPingReplace(t, newNodeResponding, lastInBucketResponding)
    42  		})
    43  	}
    44  
    45  	run(true, true)
    46  	run(false, true)
    47  	run(true, false)
    48  	run(false, false)
    49  }
    50  
    51  func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
    52  	transport := newPingRecorder()
    53  	tab, db := newTestTable(transport)
    54  	defer db.Close()
    55  	defer tab.close()
    56  
    57  	<-tab.initDone
    58  
    59  	// Fill up the sender's bucket.
    60  	pingKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
    61  	pingSender := wrapNode(enode.NewV4(&pingKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99))
    62  	last := fillBucket(tab, pingSender)
    63  
    64  	// Add the sender as if it just pinged us. Revalidate should replace the last node in
    65  	// its bucket if it is unresponsive. Revalidate again to ensure that
    66  	transport.dead[last.ID()] = !lastInBucketIsResponding
    67  	transport.dead[pingSender.ID()] = !newNodeIsResponding
    68  	tab.addSeenNode(pingSender)
    69  	tab.doRevalidate(make(chan struct{}, 1))
    70  	tab.doRevalidate(make(chan struct{}, 1))
    71  
    72  	if !transport.pinged[last.ID()] {
    73  		// Oldest node in bucket is pinged to see whether it is still alive.
    74  		t.Error("table did not ping last node in bucket")
    75  	}
    76  
    77  	tab.mutex.Lock()
    78  	defer tab.mutex.Unlock()
    79  	wantSize := bucketSize
    80  	if !lastInBucketIsResponding && !newNodeIsResponding {
    81  		wantSize--
    82  	}
    83  	if l := len(tab.bucket(pingSender.ID()).entries); l != wantSize {
    84  		t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize)
    85  	}
    86  	if found := contains(tab.bucket(pingSender.ID()).entries, last.ID()); found != lastInBucketIsResponding {
    87  		t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding)
    88  	}
    89  	wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
    90  	if found := contains(tab.bucket(pingSender.ID()).entries, pingSender.ID()); found != wantNewEntry {
    91  		t.Errorf("new entry found: %t, want: %t", found, wantNewEntry)
    92  	}
    93  }
    94  
    95  func TestBucket_bumpNoDuplicates(t *testing.T) {
    96  	t.Parallel()
    97  	cfg := &quick.Config{
    98  		MaxCount: 1000,
    99  		Rand:     rand.New(rand.NewSource(time.Now().Unix())),
   100  		Values: func(args []reflect.Value, rand *rand.Rand) {
   101  			// generate a random list of nodes. this will be the content of the bucket.
   102  			n := rand.Intn(bucketSize-1) + 1
   103  			nodes := make([]*node, n)
   104  			for i := range nodes {
   105  				nodes[i] = nodeAtDistance(enode.ID{}, 200, intIP(200))
   106  			}
   107  			args[0] = reflect.ValueOf(nodes)
   108  			// generate random bump positions.
   109  			bumps := make([]int, rand.Intn(100))
   110  			for i := range bumps {
   111  				bumps[i] = rand.Intn(len(nodes))
   112  			}
   113  			args[1] = reflect.ValueOf(bumps)
   114  		},
   115  	}
   116  
   117  	prop := func(nodes []*node, bumps []int) (ok bool) {
   118  		tab, db := newTestTable(newPingRecorder())
   119  		defer db.Close()
   120  		defer tab.close()
   121  
   122  		b := &bucket{entries: make([]*node, len(nodes))}
   123  		copy(b.entries, nodes)
   124  		for i, pos := range bumps {
   125  			tab.bumpInBucket(b, b.entries[pos])
   126  			if hasDuplicates(b.entries) {
   127  				t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps))
   128  				for _, n := range b.entries {
   129  					t.Logf("  %p", n)
   130  				}
   131  				return false
   132  			}
   133  		}
   134  		checkIPLimitInvariant(t, tab)
   135  		return true
   136  	}
   137  	if err := quick.Check(prop, cfg); err != nil {
   138  		t.Error(err)
   139  	}
   140  }
   141  
   142  // This checks that the table-wide IP limit is applied correctly.
   143  func TestTable_IPLimit(t *testing.T) {
   144  	transport := newPingRecorder()
   145  	tab, db := newTestTable(transport)
   146  	defer db.Close()
   147  	defer tab.close()
   148  
   149  	for i := 0; i < tableIPLimit+1; i++ {
   150  		n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
   151  		tab.addSeenNode(n)
   152  	}
   153  	if tab.len() > tableIPLimit {
   154  		t.Errorf("too many nodes in table")
   155  	}
   156  	checkIPLimitInvariant(t, tab)
   157  }
   158  
   159  // This checks that the per-bucket IP limit is applied correctly.
   160  func TestTable_BucketIPLimit(t *testing.T) {
   161  	transport := newPingRecorder()
   162  	tab, db := newTestTable(transport)
   163  	defer db.Close()
   164  	defer tab.close()
   165  
   166  	d := 3
   167  	for i := 0; i < bucketIPLimit+1; i++ {
   168  		n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
   169  		tab.addSeenNode(n)
   170  	}
   171  	if tab.len() > bucketIPLimit {
   172  		t.Errorf("too many nodes in table")
   173  	}
   174  	checkIPLimitInvariant(t, tab)
   175  }
   176  
   177  // checkIPLimitInvariant checks that ip limit sets contain an entry for every
   178  // node in the table and no extra entries.
   179  func checkIPLimitInvariant(t *testing.T, tab *Table) {
   180  	t.Helper()
   181  
   182  	tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}
   183  	for _, b := range tab.buckets {
   184  		for _, n := range b.entries {
   185  			tabset.Add(n.IP())
   186  		}
   187  	}
   188  	if tabset.String() != tab.ips.String() {
   189  		t.Errorf("table IP set is incorrect:\nhave: %v\nwant: %v", tab.ips, tabset)
   190  	}
   191  }
   192  
   193  func TestTable_findnodeByID(t *testing.T) {
   194  	t.Parallel()
   195  
   196  	test := func(test *closeTest) bool {
   197  		// for any node table, Target and N
   198  		transport := newPingRecorder()
   199  		tab, db := newTestTable(transport)
   200  		defer db.Close()
   201  		defer tab.close()
   202  		fillTable(tab, test.All)
   203  
   204  		// check that closest(Target, N) returns nodes
   205  		result := tab.findnodeByID(test.Target, test.N, false).entries
   206  		if hasDuplicates(result) {
   207  			t.Errorf("result contains duplicates")
   208  			return false
   209  		}
   210  		if !sortedByDistanceTo(test.Target, result) {
   211  			t.Errorf("result is not sorted by distance to target")
   212  			return false
   213  		}
   214  
   215  		// check that the number of results is min(N, tablen)
   216  		wantN := test.N
   217  		if tlen := tab.len(); tlen < test.N {
   218  			wantN = tlen
   219  		}
   220  		if len(result) != wantN {
   221  			t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
   222  			return false
   223  		} else if len(result) == 0 {
   224  			return true // no need to check distance
   225  		}
   226  
   227  		// check that the result nodes have minimum distance to target.
   228  		for _, b := range tab.buckets {
   229  			for _, n := range b.entries {
   230  				if contains(result, n.ID()) {
   231  					continue // don't run the check below for nodes in result
   232  				}
   233  				farthestResult := result[len(result)-1].ID()
   234  				if enode.DistCmp(test.Target, n.ID(), farthestResult) < 0 {
   235  					t.Errorf("table contains node that is closer to target but it's not in result")
   236  					t.Logf("  Target:          %v", test.Target)
   237  					t.Logf("  Farthest Result: %v", farthestResult)
   238  					t.Logf("  ID:              %v", n.ID())
   239  					return false
   240  				}
   241  			}
   242  		}
   243  		return true
   244  	}
   245  	if err := quick.Check(test, quickcfg()); err != nil {
   246  		t.Error(err)
   247  	}
   248  }
   249  
   250  func TestTable_ReadRandomNodesGetAll(t *testing.T) {
   251  	cfg := &quick.Config{
   252  		MaxCount: 200,
   253  		Rand:     rand.New(rand.NewSource(time.Now().Unix())),
   254  		Values: func(args []reflect.Value, rand *rand.Rand) {
   255  			args[0] = reflect.ValueOf(make([]*enode.Node, rand.Intn(1000)))
   256  		},
   257  	}
   258  	test := func(buf []*enode.Node) bool {
   259  		transport := newPingRecorder()
   260  		tab, db := newTestTable(transport)
   261  		defer db.Close()
   262  		defer tab.close()
   263  		<-tab.initDone
   264  
   265  		for i := 0; i < len(buf); i++ {
   266  			ld := cfg.Rand.Intn(len(tab.buckets))
   267  			fillTable(tab, []*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
   268  		}
   269  		gotN := tab.ReadRandomNodes(buf)
   270  		if gotN != tab.len() {
   271  			t.Errorf("wrong number of nodes, got %d, want %d", gotN, tab.len())
   272  			return false
   273  		}
   274  		if hasDuplicates(wrapNodes(buf[:gotN])) {
   275  			t.Errorf("result contains duplicates")
   276  			return false
   277  		}
   278  		return true
   279  	}
   280  	if err := quick.Check(test, cfg); err != nil {
   281  		t.Error(err)
   282  	}
   283  }
   284  
   285  type closeTest struct {
   286  	Self   enode.ID
   287  	Target enode.ID
   288  	All    []*node
   289  	N      int
   290  }
   291  
   292  func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
   293  	t := &closeTest{
   294  		Self:   gen(enode.ID{}, rand).(enode.ID),
   295  		Target: gen(enode.ID{}, rand).(enode.ID),
   296  		N:      rand.Intn(bucketSize),
   297  	}
   298  	for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
   299  		r := new(enr.Record)
   300  		r.Set(enr.IP(genIP(rand)))
   301  		n := wrapNode(enode.SignNull(r, id))
   302  		n.livenessChecks = 1
   303  		t.All = append(t.All, n)
   304  	}
   305  	return reflect.ValueOf(t)
   306  }
   307  
   308  func TestTable_addVerifiedNode(t *testing.T) {
   309  	tab, db := newTestTable(newPingRecorder())
   310  	<-tab.initDone
   311  	defer db.Close()
   312  	defer tab.close()
   313  
   314  	// Insert two nodes.
   315  	n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
   316  	n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
   317  	tab.addSeenNode(n1)
   318  	tab.addSeenNode(n2)
   319  
   320  	// Verify bucket content:
   321  	bcontent := []*node{n1, n2}
   322  	if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
   323  		t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries)
   324  	}
   325  
   326  	// Add a changed version of n2.
   327  	newrec := n2.Record()
   328  	newrec.Set(enr.IP{99, 99, 99, 99})
   329  	newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
   330  	tab.addVerifiedNode(newn2)
   331  
   332  	// Check that bucket is updated correctly.
   333  	newBcontent := []*node{newn2, n1}
   334  	if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, newBcontent) {
   335  		t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries)
   336  	}
   337  	checkIPLimitInvariant(t, tab)
   338  }
   339  
   340  func TestTable_addSeenNode(t *testing.T) {
   341  	tab, db := newTestTable(newPingRecorder())
   342  	<-tab.initDone
   343  	defer db.Close()
   344  	defer tab.close()
   345  
   346  	// Insert two nodes.
   347  	n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
   348  	n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
   349  	tab.addSeenNode(n1)
   350  	tab.addSeenNode(n2)
   351  
   352  	// Verify bucket content:
   353  	bcontent := []*node{n1, n2}
   354  	if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
   355  		t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries)
   356  	}
   357  
   358  	// Add a changed version of n2.
   359  	newrec := n2.Record()
   360  	newrec.Set(enr.IP{99, 99, 99, 99})
   361  	newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
   362  	tab.addSeenNode(newn2)
   363  
   364  	// Check that bucket content is unchanged.
   365  	if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
   366  		t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries)
   367  	}
   368  	checkIPLimitInvariant(t, tab)
   369  }
   370  
   371  // This test checks that ENR updates happen during revalidation. If a node in the table
   372  // announces a new sequence number, the new record should be pulled.
   373  func TestTable_revalidateSyncRecord(t *testing.T) {
   374  	transport := newPingRecorder()
   375  	tab, db := newTestTable(transport)
   376  	<-tab.initDone
   377  	defer db.Close()
   378  	defer tab.close()
   379  
   380  	// Insert a node.
   381  	var r enr.Record
   382  	r.Set(enr.IP(net.IP{127, 0, 0, 1}))
   383  	id := enode.ID{1}
   384  	n1 := wrapNode(enode.SignNull(&r, id))
   385  	tab.addSeenNode(n1)
   386  
   387  	// Update the node record.
   388  	r.Set(enr.WithEntry("foo", "bar"))
   389  	n2 := enode.SignNull(&r, id)
   390  	transport.updateRecord(n2)
   391  
   392  	tab.doRevalidate(make(chan struct{}, 1))
   393  	intable := tab.getNode(id)
   394  	if !reflect.DeepEqual(intable, n2) {
   395  		t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq())
   396  	}
   397  }
   398  
   399  func TestNodesPush(t *testing.T) {
   400  	var target enode.ID
   401  	n1 := nodeAtDistance(target, 255, intIP(1))
   402  	n2 := nodeAtDistance(target, 254, intIP(2))
   403  	n3 := nodeAtDistance(target, 253, intIP(3))
   404  	perm := [][]*node{
   405  		{n3, n2, n1},
   406  		{n3, n1, n2},
   407  		{n2, n3, n1},
   408  		{n2, n1, n3},
   409  		{n1, n3, n2},
   410  		{n1, n2, n3},
   411  	}
   412  
   413  	// Insert all permutations into lists with size limit 3.
   414  	for _, nodes := range perm {
   415  		list := nodesByDistance{target: target}
   416  		for _, n := range nodes {
   417  			list.push(n, 3)
   418  		}
   419  		if !slicesEqual(list.entries, perm[0], nodeIDEqual) {
   420  			t.Fatal("not equal")
   421  		}
   422  	}
   423  
   424  	// Insert all permutations into lists with size limit 2.
   425  	for _, nodes := range perm {
   426  		list := nodesByDistance{target: target}
   427  		for _, n := range nodes {
   428  			list.push(n, 2)
   429  		}
   430  		if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) {
   431  			t.Fatal("not equal")
   432  		}
   433  	}
   434  }
   435  
   436  func nodeIDEqual(n1, n2 *node) bool {
   437  	return n1.ID() == n2.ID()
   438  }
   439  
   440  func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool {
   441  	if len(s1) != len(s2) {
   442  		return false
   443  	}
   444  	for i := range s1 {
   445  		if !check(s1[i], s2[i]) {
   446  			return false
   447  		}
   448  	}
   449  	return true
   450  }
   451  
   452  // gen wraps quick.Value so it's easier to use.
   453  // it generates a random value of the given value's type.
   454  func gen(typ interface{}, rand *rand.Rand) interface{} {
   455  	v, ok := quick.Value(reflect.TypeOf(typ), rand)
   456  	if !ok {
   457  		panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
   458  	}
   459  	return v.Interface()
   460  }
   461  
   462  func genIP(rand *rand.Rand) net.IP {
   463  	ip := make(net.IP, 4)
   464  	rand.Read(ip)
   465  	return ip
   466  }
   467  
   468  func quickcfg() *quick.Config {
   469  	return &quick.Config{
   470  		MaxCount: 5000,
   471  		Rand:     rand.New(rand.NewSource(time.Now().Unix())),
   472  	}
   473  }
   474  
   475  func newkey() *ecdsa.PrivateKey {
   476  	key, err := crypto.GenerateKey()
   477  	if err != nil {
   478  		panic("couldn't generate key: " + err.Error())
   479  	}
   480  	return key
   481  }