github.com/sunjiahui/go-ethereum@v1.10.31/p2p/enr/enr_test.go (about)

     1  // Copyright 2017 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package enr
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"github.com/sunjiahui/go-ethereum/rlp"
    30  )
    31  
    32  var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
    33  
    34  func randomString(strlen int) string {
    35  	b := make([]byte, strlen)
    36  	rnd.Read(b)
    37  	return string(b)
    38  }
    39  
    40  // TestGetSetID tests encoding/decoding and setting/getting of the ID key.
    41  func TestGetSetID(t *testing.T) {
    42  	id := ID("someid")
    43  	var r Record
    44  	r.Set(id)
    45  
    46  	var id2 ID
    47  	require.NoError(t, r.Load(&id2))
    48  	assert.Equal(t, id, id2)
    49  }
    50  
    51  // TestGetSetIP4 tests encoding/decoding and setting/getting of the IP key.
    52  func TestGetSetIPv4(t *testing.T) {
    53  	ip := IPv4{192, 168, 0, 3}
    54  	var r Record
    55  	r.Set(ip)
    56  
    57  	var ip2 IPv4
    58  	require.NoError(t, r.Load(&ip2))
    59  	assert.Equal(t, ip, ip2)
    60  }
    61  
    62  // TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
    63  func TestGetSetIPv6(t *testing.T) {
    64  	ip := IPv6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
    65  	var r Record
    66  	r.Set(ip)
    67  
    68  	var ip2 IPv6
    69  	require.NoError(t, r.Load(&ip2))
    70  	assert.Equal(t, ip, ip2)
    71  }
    72  
    73  // TestGetSetUDP tests encoding/decoding and setting/getting of the UDP key.
    74  func TestGetSetUDP(t *testing.T) {
    75  	port := UDP(30309)
    76  	var r Record
    77  	r.Set(port)
    78  
    79  	var port2 UDP
    80  	require.NoError(t, r.Load(&port2))
    81  	assert.Equal(t, port, port2)
    82  }
    83  
    84  func TestLoadErrors(t *testing.T) {
    85  	var r Record
    86  	ip4 := IPv4{127, 0, 0, 1}
    87  	r.Set(ip4)
    88  
    89  	// Check error for missing keys.
    90  	var udp UDP
    91  	err := r.Load(&udp)
    92  	if !IsNotFound(err) {
    93  		t.Error("IsNotFound should return true for missing key")
    94  	}
    95  	assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err)
    96  
    97  	// Check error for invalid keys.
    98  	var list []uint
    99  	err = r.Load(WithEntry(ip4.ENRKey(), &list))
   100  	kerr, ok := err.(*KeyError)
   101  	if !ok {
   102  		t.Fatalf("expected KeyError, got %T", err)
   103  	}
   104  	assert.Equal(t, kerr.Key, ip4.ENRKey())
   105  	assert.Error(t, kerr.Err)
   106  	if IsNotFound(err) {
   107  		t.Error("IsNotFound should return false for decoding errors")
   108  	}
   109  }
   110  
   111  // TestSortedGetAndSet tests that Set produced a sorted pairs slice.
   112  func TestSortedGetAndSet(t *testing.T) {
   113  	type pair struct {
   114  		k string
   115  		v uint32
   116  	}
   117  
   118  	for _, tt := range []struct {
   119  		input []pair
   120  		want  []pair
   121  	}{
   122  		{
   123  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}},
   124  			want:  []pair{{"a", 1}, {"b", 3}, {"c", 2}},
   125  		},
   126  		{
   127  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   128  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   129  		},
   130  		{
   131  			input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   132  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   133  		},
   134  	} {
   135  		var r Record
   136  		for _, i := range tt.input {
   137  			r.Set(WithEntry(i.k, &i.v))
   138  		}
   139  		for i, w := range tt.want {
   140  			// set got's key from r.pair[i], so that we preserve order of pairs
   141  			got := pair{k: r.pairs[i].k}
   142  			assert.NoError(t, r.Load(WithEntry(w.k, &got.v)))
   143  			assert.Equal(t, w, got)
   144  		}
   145  	}
   146  }
   147  
   148  // TestDirty tests record signature removal on setting of new key/value pair in record.
   149  func TestDirty(t *testing.T) {
   150  	var r Record
   151  
   152  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   153  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   154  	}
   155  
   156  	require.NoError(t, signTest([]byte{5}, &r))
   157  	if len(r.signature) == 0 {
   158  		t.Error("record is not signed")
   159  	}
   160  	_, err := rlp.EncodeToBytes(r)
   161  	assert.NoError(t, err)
   162  
   163  	r.SetSeq(3)
   164  	if len(r.signature) != 0 {
   165  		t.Error("signature still set after modification")
   166  	}
   167  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   168  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   169  	}
   170  }
   171  
   172  func TestSize(t *testing.T) {
   173  	var r Record
   174  
   175  	// Empty record size is 3 bytes.
   176  	// Unsigned records cannot be encoded, but they could, the encoding
   177  	// would be [ 0, 0 ] -> 0xC28080.
   178  	assert.Equal(t, uint64(3), r.Size())
   179  
   180  	// Add one attribute. The size increases to 5, the encoding
   181  	// would be [ 0, 0, "k", "v" ] -> 0xC58080C26B76.
   182  	r.Set(WithEntry("k", "v"))
   183  	assert.Equal(t, uint64(5), r.Size())
   184  
   185  	// Now add a signature.
   186  	nodeid := []byte{1, 2, 3, 4, 5, 6, 7, 8}
   187  	signTest(nodeid, &r)
   188  	assert.Equal(t, uint64(45), r.Size())
   189  	enc, _ := rlp.EncodeToBytes(&r)
   190  	if r.Size() != uint64(len(enc)) {
   191  		t.Error("Size() not equal encoded length", len(enc))
   192  	}
   193  	if r.Size() != computeSize(&r) {
   194  		t.Error("Size() not equal computed size", computeSize(&r))
   195  	}
   196  }
   197  
   198  func TestSeq(t *testing.T) {
   199  	var r Record
   200  
   201  	assert.Equal(t, uint64(0), r.Seq())
   202  	r.Set(UDP(1))
   203  	assert.Equal(t, uint64(0), r.Seq())
   204  	signTest([]byte{5}, &r)
   205  	assert.Equal(t, uint64(0), r.Seq())
   206  	r.Set(UDP(2))
   207  	assert.Equal(t, uint64(1), r.Seq())
   208  }
   209  
   210  // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
   211  func TestGetSetOverwrite(t *testing.T) {
   212  	var r Record
   213  
   214  	ip := IPv4{192, 168, 0, 3}
   215  	r.Set(ip)
   216  
   217  	ip2 := IPv4{192, 168, 0, 4}
   218  	r.Set(ip2)
   219  
   220  	var ip3 IPv4
   221  	require.NoError(t, r.Load(&ip3))
   222  	assert.Equal(t, ip2, ip3)
   223  }
   224  
   225  // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
   226  func TestSignEncodeAndDecode(t *testing.T) {
   227  	var r Record
   228  	r.Set(UDP(30303))
   229  	r.Set(IPv4{127, 0, 0, 1})
   230  	require.NoError(t, signTest([]byte{5}, &r))
   231  
   232  	blob, err := rlp.EncodeToBytes(r)
   233  	require.NoError(t, err)
   234  
   235  	var r2 Record
   236  	require.NoError(t, rlp.DecodeBytes(blob, &r2))
   237  	assert.Equal(t, r, r2)
   238  
   239  	blob2, err := rlp.EncodeToBytes(r2)
   240  	require.NoError(t, err)
   241  	assert.Equal(t, blob, blob2)
   242  }
   243  
   244  // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
   245  func TestRecordTooBig(t *testing.T) {
   246  	var r Record
   247  	key := randomString(10)
   248  
   249  	// set a big value for random key, expect error
   250  	r.Set(WithEntry(key, randomString(SizeLimit)))
   251  	if err := signTest([]byte{5}, &r); err != errTooBig {
   252  		t.Fatalf("expected to get errTooBig, got %#v", err)
   253  	}
   254  
   255  	// set an acceptable value for random key, expect no error
   256  	r.Set(WithEntry(key, randomString(100)))
   257  	require.NoError(t, signTest([]byte{5}, &r))
   258  }
   259  
   260  // This checks that incomplete RLP inputs are handled correctly.
   261  func TestDecodeIncomplete(t *testing.T) {
   262  	type decTest struct {
   263  		input []byte
   264  		err   error
   265  	}
   266  	tests := []decTest{
   267  		{[]byte{0xC0}, errIncompleteList},
   268  		{[]byte{0xC1, 0x1}, errIncompleteList},
   269  		{[]byte{0xC2, 0x1, 0x2}, nil},
   270  		{[]byte{0xC3, 0x1, 0x2, 0x3}, errIncompletePair},
   271  		{[]byte{0xC4, 0x1, 0x2, 0x3, 0x4}, nil},
   272  		{[]byte{0xC5, 0x1, 0x2, 0x3, 0x4, 0x5}, errIncompletePair},
   273  	}
   274  	for _, test := range tests {
   275  		var r Record
   276  		err := rlp.DecodeBytes(test.input, &r)
   277  		if err != test.err {
   278  			t.Errorf("wrong error for %X: %v", test.input, err)
   279  		}
   280  	}
   281  }
   282  
   283  // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
   284  func TestSignEncodeAndDecodeRandom(t *testing.T) {
   285  	var r Record
   286  
   287  	// random key/value pairs for testing
   288  	pairs := map[string]uint32{}
   289  	for i := 0; i < 10; i++ {
   290  		key := randomString(7)
   291  		value := rnd.Uint32()
   292  		pairs[key] = value
   293  		r.Set(WithEntry(key, &value))
   294  	}
   295  
   296  	require.NoError(t, signTest([]byte{5}, &r))
   297  
   298  	enc, err := rlp.EncodeToBytes(r)
   299  	require.NoError(t, err)
   300  	require.Equal(t, uint64(len(enc)), r.Size())
   301  	require.Equal(t, uint64(len(enc)), computeSize(&r))
   302  
   303  	for k, v := range pairs {
   304  		desc := fmt.Sprintf("key %q", k)
   305  		var got uint32
   306  		buf := WithEntry(k, &got)
   307  		require.NoError(t, r.Load(buf), desc)
   308  		require.Equal(t, v, got, desc)
   309  	}
   310  }
   311  
   312  type testSig struct{}
   313  
   314  type testID []byte
   315  
   316  func (id testID) ENRKey() string { return "testid" }
   317  
   318  func signTest(id []byte, r *Record) error {
   319  	r.Set(ID("test"))
   320  	r.Set(testID(id))
   321  	return r.SetSig(testSig{}, makeTestSig(id, r.Seq()))
   322  }
   323  
   324  func makeTestSig(id []byte, seq uint64) []byte {
   325  	sig := make([]byte, 8, len(id)+8)
   326  	binary.BigEndian.PutUint64(sig[:8], seq)
   327  	sig = append(sig, id...)
   328  	return sig
   329  }
   330  
   331  func (testSig) Verify(r *Record, sig []byte) error {
   332  	var id []byte
   333  	if err := r.Load((*testID)(&id)); err != nil {
   334  		return err
   335  	}
   336  	if !bytes.Equal(sig, makeTestSig(id, r.Seq())) {
   337  		return ErrInvalidSig
   338  	}
   339  	return nil
   340  }
   341  
   342  func (testSig) NodeAddr(r *Record) []byte {
   343  	var id []byte
   344  	if err := r.Load((*testID)(&id)); err != nil {
   345  		return nil
   346  	}
   347  	return id
   348  }