github.com/pranksteess/go-ethereum@v1.9.7/p2p/enr/enr_test.go (about)

     1  // Copyright 2017 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package enr
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/ethereum/go-ethereum/rlp"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
    33  
    34  func randomString(strlen int) string {
    35  	b := make([]byte, strlen)
    36  	rnd.Read(b)
    37  	return string(b)
    38  }
    39  
    40  // TestGetSetID tests encoding/decoding and setting/getting of the ID key.
    41  func TestGetSetID(t *testing.T) {
    42  	id := ID("someid")
    43  	var r Record
    44  	r.Set(id)
    45  
    46  	var id2 ID
    47  	require.NoError(t, r.Load(&id2))
    48  	assert.Equal(t, id, id2)
    49  }
    50  
    51  // TestGetSetIP4 tests encoding/decoding and setting/getting of the IP key.
    52  func TestGetSetIPv4(t *testing.T) {
    53  	ip := IPv4{192, 168, 0, 3}
    54  	var r Record
    55  	r.Set(ip)
    56  
    57  	var ip2 IPv4
    58  	require.NoError(t, r.Load(&ip2))
    59  	assert.Equal(t, ip, ip2)
    60  }
    61  
    62  // TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
    63  func TestGetSetIPv6(t *testing.T) {
    64  	ip := IPv6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
    65  	var r Record
    66  	r.Set(ip)
    67  
    68  	var ip2 IPv6
    69  	require.NoError(t, r.Load(&ip2))
    70  	assert.Equal(t, ip, ip2)
    71  }
    72  
    73  // TestGetSetUDP tests encoding/decoding and setting/getting of the UDP key.
    74  func TestGetSetUDP(t *testing.T) {
    75  	port := UDP(30309)
    76  	var r Record
    77  	r.Set(port)
    78  
    79  	var port2 UDP
    80  	require.NoError(t, r.Load(&port2))
    81  	assert.Equal(t, port, port2)
    82  }
    83  
    84  func TestLoadErrors(t *testing.T) {
    85  	var r Record
    86  	ip4 := IPv4{127, 0, 0, 1}
    87  	r.Set(ip4)
    88  
    89  	// Check error for missing keys.
    90  	var udp UDP
    91  	err := r.Load(&udp)
    92  	if !IsNotFound(err) {
    93  		t.Error("IsNotFound should return true for missing key")
    94  	}
    95  	assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err)
    96  
    97  	// Check error for invalid keys.
    98  	var list []uint
    99  	err = r.Load(WithEntry(ip4.ENRKey(), &list))
   100  	kerr, ok := err.(*KeyError)
   101  	if !ok {
   102  		t.Fatalf("expected KeyError, got %T", err)
   103  	}
   104  	assert.Equal(t, kerr.Key, ip4.ENRKey())
   105  	assert.Error(t, kerr.Err)
   106  	if IsNotFound(err) {
   107  		t.Error("IsNotFound should return false for decoding errors")
   108  	}
   109  }
   110  
   111  // TestSortedGetAndSet tests that Set produced a sorted pairs slice.
   112  func TestSortedGetAndSet(t *testing.T) {
   113  	type pair struct {
   114  		k string
   115  		v uint32
   116  	}
   117  
   118  	for _, tt := range []struct {
   119  		input []pair
   120  		want  []pair
   121  	}{
   122  		{
   123  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}},
   124  			want:  []pair{{"a", 1}, {"b", 3}, {"c", 2}},
   125  		},
   126  		{
   127  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   128  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   129  		},
   130  		{
   131  			input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   132  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   133  		},
   134  	} {
   135  		var r Record
   136  		for _, i := range tt.input {
   137  			r.Set(WithEntry(i.k, &i.v))
   138  		}
   139  		for i, w := range tt.want {
   140  			// set got's key from r.pair[i], so that we preserve order of pairs
   141  			got := pair{k: r.pairs[i].k}
   142  			assert.NoError(t, r.Load(WithEntry(w.k, &got.v)))
   143  			assert.Equal(t, w, got)
   144  		}
   145  	}
   146  }
   147  
   148  // TestDirty tests record signature removal on setting of new key/value pair in record.
   149  func TestDirty(t *testing.T) {
   150  	var r Record
   151  
   152  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   153  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   154  	}
   155  
   156  	require.NoError(t, signTest([]byte{5}, &r))
   157  	if len(r.signature) == 0 {
   158  		t.Error("record is not signed")
   159  	}
   160  	_, err := rlp.EncodeToBytes(r)
   161  	assert.NoError(t, err)
   162  
   163  	r.SetSeq(3)
   164  	if len(r.signature) != 0 {
   165  		t.Error("signature still set after modification")
   166  	}
   167  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   168  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   169  	}
   170  }
   171  
   172  func TestSeq(t *testing.T) {
   173  	var r Record
   174  
   175  	assert.Equal(t, uint64(0), r.Seq())
   176  	r.Set(UDP(1))
   177  	assert.Equal(t, uint64(0), r.Seq())
   178  	signTest([]byte{5}, &r)
   179  	assert.Equal(t, uint64(0), r.Seq())
   180  	r.Set(UDP(2))
   181  	assert.Equal(t, uint64(1), r.Seq())
   182  }
   183  
   184  // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
   185  func TestGetSetOverwrite(t *testing.T) {
   186  	var r Record
   187  
   188  	ip := IPv4{192, 168, 0, 3}
   189  	r.Set(ip)
   190  
   191  	ip2 := IPv4{192, 168, 0, 4}
   192  	r.Set(ip2)
   193  
   194  	var ip3 IPv4
   195  	require.NoError(t, r.Load(&ip3))
   196  	assert.Equal(t, ip2, ip3)
   197  }
   198  
   199  // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
   200  func TestSignEncodeAndDecode(t *testing.T) {
   201  	var r Record
   202  	r.Set(UDP(30303))
   203  	r.Set(IPv4{127, 0, 0, 1})
   204  	require.NoError(t, signTest([]byte{5}, &r))
   205  
   206  	blob, err := rlp.EncodeToBytes(r)
   207  	require.NoError(t, err)
   208  
   209  	var r2 Record
   210  	require.NoError(t, rlp.DecodeBytes(blob, &r2))
   211  	assert.Equal(t, r, r2)
   212  
   213  	blob2, err := rlp.EncodeToBytes(r2)
   214  	require.NoError(t, err)
   215  	assert.Equal(t, blob, blob2)
   216  }
   217  
   218  // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
   219  func TestRecordTooBig(t *testing.T) {
   220  	var r Record
   221  	key := randomString(10)
   222  
   223  	// set a big value for random key, expect error
   224  	r.Set(WithEntry(key, randomString(SizeLimit)))
   225  	if err := signTest([]byte{5}, &r); err != errTooBig {
   226  		t.Fatalf("expected to get errTooBig, got %#v", err)
   227  	}
   228  
   229  	// set an acceptable value for random key, expect no error
   230  	r.Set(WithEntry(key, randomString(100)))
   231  	require.NoError(t, signTest([]byte{5}, &r))
   232  }
   233  
   234  // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
   235  func TestSignEncodeAndDecodeRandom(t *testing.T) {
   236  	var r Record
   237  
   238  	// random key/value pairs for testing
   239  	pairs := map[string]uint32{}
   240  	for i := 0; i < 10; i++ {
   241  		key := randomString(7)
   242  		value := rnd.Uint32()
   243  		pairs[key] = value
   244  		r.Set(WithEntry(key, &value))
   245  	}
   246  
   247  	require.NoError(t, signTest([]byte{5}, &r))
   248  	_, err := rlp.EncodeToBytes(r)
   249  	require.NoError(t, err)
   250  
   251  	for k, v := range pairs {
   252  		desc := fmt.Sprintf("key %q", k)
   253  		var got uint32
   254  		buf := WithEntry(k, &got)
   255  		require.NoError(t, r.Load(buf), desc)
   256  		require.Equal(t, v, got, desc)
   257  	}
   258  }
   259  
   260  type testSig struct{}
   261  
   262  type testID []byte
   263  
   264  func (id testID) ENRKey() string { return "testid" }
   265  
   266  func signTest(id []byte, r *Record) error {
   267  	r.Set(ID("test"))
   268  	r.Set(testID(id))
   269  	return r.SetSig(testSig{}, makeTestSig(id, r.Seq()))
   270  }
   271  
   272  func makeTestSig(id []byte, seq uint64) []byte {
   273  	sig := make([]byte, 8, len(id)+8)
   274  	binary.BigEndian.PutUint64(sig[:8], seq)
   275  	sig = append(sig, id...)
   276  	return sig
   277  }
   278  
   279  func (testSig) Verify(r *Record, sig []byte) error {
   280  	var id []byte
   281  	if err := r.Load((*testID)(&id)); err != nil {
   282  		return err
   283  	}
   284  	if !bytes.Equal(sig, makeTestSig(id, r.Seq())) {
   285  		return ErrInvalidSig
   286  	}
   287  	return nil
   288  }
   289  
   290  func (testSig) NodeAddr(r *Record) []byte {
   291  	var id []byte
   292  	if err := r.Load((*testID)(&id)); err != nil {
   293  		return nil
   294  	}
   295  	return id
   296  }