github.com/aquanetwork/aquachain@v1.7.8/p2p/enr/enr_test.go (about)

     1  // Copyright 2017 The aquachain Authors
     2  // This file is part of the aquachain library.
     3  //
     4  // The aquachain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The aquachain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package enr
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/hex"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"gitlab.com/aquachain/aquachain/crypto"
    30  	"gitlab.com/aquachain/aquachain/rlp"
    31  )
    32  
    33  var (
    34  	privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
    35  	pubkey     = &privkey.PublicKey
    36  )
    37  
    38  var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
    39  
    40  func randomString(strlen int) string {
    41  	b := make([]byte, strlen)
    42  	rnd.Read(b)
    43  	return string(b)
    44  }
    45  
    46  // TestGetSetID tests encoding/decoding and setting/getting of the ID key.
    47  func TestGetSetID(t *testing.T) {
    48  	id := ID("someid")
    49  	var r Record
    50  	r.Set(id)
    51  
    52  	var id2 ID
    53  	require.NoError(t, r.Load(&id2))
    54  	assert.Equal(t, id, id2)
    55  }
    56  
    57  // TestGetSetIP4 tests encoding/decoding and setting/getting of the IP4 key.
    58  func TestGetSetIP4(t *testing.T) {
    59  	ip := IP4{192, 168, 0, 3}
    60  	var r Record
    61  	r.Set(ip)
    62  
    63  	var ip2 IP4
    64  	require.NoError(t, r.Load(&ip2))
    65  	assert.Equal(t, ip, ip2)
    66  }
    67  
    68  // TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
    69  func TestGetSetIP6(t *testing.T) {
    70  	ip := IP6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
    71  	var r Record
    72  	r.Set(ip)
    73  
    74  	var ip2 IP6
    75  	require.NoError(t, r.Load(&ip2))
    76  	assert.Equal(t, ip, ip2)
    77  }
    78  
    79  // TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key.
    80  func TestGetSetDiscPort(t *testing.T) {
    81  	port := DiscPort(30309)
    82  	var r Record
    83  	r.Set(port)
    84  
    85  	var port2 DiscPort
    86  	require.NoError(t, r.Load(&port2))
    87  	assert.Equal(t, port, port2)
    88  }
    89  
    90  // TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key.
    91  func TestGetSetSecp256k1(t *testing.T) {
    92  	var r Record
    93  	if err := r.Sign(privkey); err != nil {
    94  		t.Fatal(err)
    95  	}
    96  
    97  	var pk Secp256k1
    98  	require.NoError(t, r.Load(&pk))
    99  	assert.EqualValues(t, pubkey, &pk)
   100  }
   101  
   102  func TestLoadErrors(t *testing.T) {
   103  	var r Record
   104  	ip4 := IP4{127, 0, 0, 1}
   105  	r.Set(ip4)
   106  
   107  	// Check error for missing keys.
   108  	var ip6 IP6
   109  	err := r.Load(&ip6)
   110  	if !IsNotFound(err) {
   111  		t.Error("IsNotFound should return true for missing key")
   112  	}
   113  	assert.Equal(t, &KeyError{Key: ip6.ENRKey(), Err: errNotFound}, err)
   114  
   115  	// Check error for invalid keys.
   116  	var list []uint
   117  	err = r.Load(WithEntry(ip4.ENRKey(), &list))
   118  	kerr, ok := err.(*KeyError)
   119  	if !ok {
   120  		t.Fatalf("expected KeyError, got %T", err)
   121  	}
   122  	assert.Equal(t, kerr.Key, ip4.ENRKey())
   123  	assert.Error(t, kerr.Err)
   124  	if IsNotFound(err) {
   125  		t.Error("IsNotFound should return false for decoding errors")
   126  	}
   127  }
   128  
   129  // TestSortedGetAndSet tests that Set produced a sorted pairs slice.
   130  func TestSortedGetAndSet(t *testing.T) {
   131  	type pair struct {
   132  		k string
   133  		v uint32
   134  	}
   135  
   136  	for _, tt := range []struct {
   137  		input []pair
   138  		want  []pair
   139  	}{
   140  		{
   141  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}},
   142  			want:  []pair{{"a", 1}, {"b", 3}, {"c", 2}},
   143  		},
   144  		{
   145  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   146  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   147  		},
   148  		{
   149  			input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   150  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   151  		},
   152  	} {
   153  		var r Record
   154  		for _, i := range tt.input {
   155  			r.Set(WithEntry(i.k, &i.v))
   156  		}
   157  		for i, w := range tt.want {
   158  			// set got's key from r.pair[i], so that we preserve order of pairs
   159  			got := pair{k: r.pairs[i].k}
   160  			assert.NoError(t, r.Load(WithEntry(w.k, &got.v)))
   161  			assert.Equal(t, w, got)
   162  		}
   163  	}
   164  }
   165  
   166  // TestDirty tests record signature removal on setting of new key/value pair in record.
   167  func TestDirty(t *testing.T) {
   168  	var r Record
   169  
   170  	if r.Signed() {
   171  		t.Error("Signed returned true for zero record")
   172  	}
   173  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   174  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   175  	}
   176  
   177  	require.NoError(t, r.Sign(privkey))
   178  	if !r.Signed() {
   179  		t.Error("Signed return false for signed record")
   180  	}
   181  	_, err := rlp.EncodeToBytes(r)
   182  	assert.NoError(t, err)
   183  
   184  	r.SetSeq(3)
   185  	if r.Signed() {
   186  		t.Error("Signed returned true for modified record")
   187  	}
   188  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   189  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   190  	}
   191  }
   192  
   193  // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
   194  func TestGetSetOverwrite(t *testing.T) {
   195  	var r Record
   196  
   197  	ip := IP4{192, 168, 0, 3}
   198  	r.Set(ip)
   199  
   200  	ip2 := IP4{192, 168, 0, 4}
   201  	r.Set(ip2)
   202  
   203  	var ip3 IP4
   204  	require.NoError(t, r.Load(&ip3))
   205  	assert.Equal(t, ip2, ip3)
   206  }
   207  
   208  // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
   209  func TestSignEncodeAndDecode(t *testing.T) {
   210  	var r Record
   211  	r.Set(DiscPort(30303))
   212  	r.Set(IP4{127, 0, 0, 1})
   213  	require.NoError(t, r.Sign(privkey))
   214  
   215  	blob, err := rlp.EncodeToBytes(r)
   216  	require.NoError(t, err)
   217  
   218  	var r2 Record
   219  	require.NoError(t, rlp.DecodeBytes(blob, &r2))
   220  	assert.Equal(t, r, r2)
   221  
   222  	blob2, err := rlp.EncodeToBytes(r2)
   223  	require.NoError(t, err)
   224  	assert.Equal(t, blob, blob2)
   225  }
   226  
   227  func TestNodeAddr(t *testing.T) {
   228  	var r Record
   229  	if addr := r.NodeAddr(); addr != nil {
   230  		t.Errorf("wrong address on empty record: got %v, want %v", addr, nil)
   231  	}
   232  
   233  	require.NoError(t, r.Sign(privkey))
   234  	expected := "caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726"
   235  	assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr()))
   236  }
   237  
   238  var pyRecord, _ = hex.DecodeString("f896b840954dc36583c1f4b69ab59b1375f362f06ee99f3723cd77e64b6de6d211c27d7870642a79d4516997f94091325d2a7ca6215376971455fb221d34f35b277149a1018664697363763582765f82696490736563703235366b312d6b656363616b83697034847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138")
   239  
   240  // TestPythonInterop checks that we can decode and verify a record produced by the Python
   241  // implementation.
   242  func TestPythonInterop(t *testing.T) {
   243  	var r Record
   244  	if err := rlp.DecodeBytes(pyRecord, &r); err != nil {
   245  		t.Fatalf("can't decode: %v", err)
   246  	}
   247  
   248  	var (
   249  		wantAddr, _  = hex.DecodeString("caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726")
   250  		wantSeq      = uint64(1)
   251  		wantIP       = IP4{127, 0, 0, 1}
   252  		wantDiscport = DiscPort(30303)
   253  	)
   254  	if r.Seq() != wantSeq {
   255  		t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq)
   256  	}
   257  	if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) {
   258  		t.Errorf("wrong addr: got %x, want %x", addr, wantAddr)
   259  	}
   260  	want := map[Entry]interface{}{new(IP4): &wantIP, new(DiscPort): &wantDiscport}
   261  	for k, v := range want {
   262  		desc := fmt.Sprintf("loading key %q", k.ENRKey())
   263  		if assert.NoError(t, r.Load(k), desc) {
   264  			assert.Equal(t, k, v, desc)
   265  		}
   266  	}
   267  }
   268  
   269  // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
   270  func TestRecordTooBig(t *testing.T) {
   271  	var r Record
   272  	key := randomString(10)
   273  
   274  	// set a big value for random key, expect error
   275  	r.Set(WithEntry(key, randomString(300)))
   276  	if err := r.Sign(privkey); err != errTooBig {
   277  		t.Fatalf("expected to get errTooBig, got %#v", err)
   278  	}
   279  
   280  	// set an acceptable value for random key, expect no error
   281  	r.Set(WithEntry(key, randomString(100)))
   282  	require.NoError(t, r.Sign(privkey))
   283  }
   284  
   285  // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
   286  func TestSignEncodeAndDecodeRandom(t *testing.T) {
   287  	var r Record
   288  
   289  	// random key/value pairs for testing
   290  	pairs := map[string]uint32{}
   291  	for i := 0; i < 10; i++ {
   292  		key := randomString(7)
   293  		value := rnd.Uint32()
   294  		pairs[key] = value
   295  		r.Set(WithEntry(key, &value))
   296  	}
   297  
   298  	require.NoError(t, r.Sign(privkey))
   299  	_, err := rlp.EncodeToBytes(r)
   300  	require.NoError(t, err)
   301  
   302  	for k, v := range pairs {
   303  		desc := fmt.Sprintf("key %q", k)
   304  		var got uint32
   305  		buf := WithEntry(k, &got)
   306  		require.NoError(t, r.Load(buf), desc)
   307  		require.Equal(t, v, got, desc)
   308  	}
   309  }
   310  
   311  func BenchmarkDecode(b *testing.B) {
   312  	var r Record
   313  	for i := 0; i < b.N; i++ {
   314  		rlp.DecodeBytes(pyRecord, &r)
   315  	}
   316  	b.StopTimer()
   317  	r.NodeAddr()
   318  }