github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/enr/enr_test.go (about)

     1  package enr
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"math/rand"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/quickchainproject/quickchain/crypto"
    12  	"github.com/quickchainproject/quickchain/rlp"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  var (
    18  	privkey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
    19  	pubkey     = &privkey.PublicKey
    20  )
    21  
    22  var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))
    23  
    24  func randomString(strlen int) string {
    25  	b := make([]byte, strlen)
    26  	rnd.Read(b)
    27  	return string(b)
    28  }
    29  
    30  // TestGetSetID tests encoding/decoding and setting/getting of the ID key.
    31  func TestGetSetID(t *testing.T) {
    32  	id := ID("someid")
    33  	var r Record
    34  	r.Set(id)
    35  
    36  	var id2 ID
    37  	require.NoError(t, r.Load(&id2))
    38  	assert.Equal(t, id, id2)
    39  }
    40  
    41  // TestGetSetIP4 tests encoding/decoding and setting/getting of the IP4 key.
    42  func TestGetSetIP4(t *testing.T) {
    43  	ip := IP4{192, 168, 0, 3}
    44  	var r Record
    45  	r.Set(ip)
    46  
    47  	var ip2 IP4
    48  	require.NoError(t, r.Load(&ip2))
    49  	assert.Equal(t, ip, ip2)
    50  }
    51  
    52  // TestGetSetIP6 tests encoding/decoding and setting/getting of the IP6 key.
    53  func TestGetSetIP6(t *testing.T) {
    54  	ip := IP6{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}
    55  	var r Record
    56  	r.Set(ip)
    57  
    58  	var ip2 IP6
    59  	require.NoError(t, r.Load(&ip2))
    60  	assert.Equal(t, ip, ip2)
    61  }
    62  
    63  // TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key.
    64  func TestGetSetDiscPort(t *testing.T) {
    65  	port := DiscPort(30309)
    66  	var r Record
    67  	r.Set(port)
    68  
    69  	var port2 DiscPort
    70  	require.NoError(t, r.Load(&port2))
    71  	assert.Equal(t, port, port2)
    72  }
    73  
    74  // TestGetSetSecp256k1 tests encoding/decoding and setting/getting of the Secp256k1 key.
    75  func TestGetSetSecp256k1(t *testing.T) {
    76  	var r Record
    77  	if err := r.Sign(privkey); err != nil {
    78  		t.Fatal(err)
    79  	}
    80  
    81  	var pk Secp256k1
    82  	require.NoError(t, r.Load(&pk))
    83  	assert.EqualValues(t, pubkey, &pk)
    84  }
    85  
    86  func TestLoadErrors(t *testing.T) {
    87  	var r Record
    88  	ip4 := IP4{127, 0, 0, 1}
    89  	r.Set(ip4)
    90  
    91  	// Check error for missing keys.
    92  	var ip6 IP6
    93  	err := r.Load(&ip6)
    94  	if !IsNotFound(err) {
    95  		t.Error("IsNotFound should return true for missing key")
    96  	}
    97  	assert.Equal(t, &KeyError{Key: ip6.ENRKey(), Err: errNotFound}, err)
    98  
    99  	// Check error for invalid keys.
   100  	var list []uint
   101  	err = r.Load(WithEntry(ip4.ENRKey(), &list))
   102  	kerr, ok := err.(*KeyError)
   103  	if !ok {
   104  		t.Fatalf("expected KeyError, got %T", err)
   105  	}
   106  	assert.Equal(t, kerr.Key, ip4.ENRKey())
   107  	assert.Error(t, kerr.Err)
   108  	if IsNotFound(err) {
   109  		t.Error("IsNotFound should return false for decoding errors")
   110  	}
   111  }
   112  
   113  // TestSortedGetAndSet tests that Set produced a sorted pairs slice.
   114  func TestSortedGetAndSet(t *testing.T) {
   115  	type pair struct {
   116  		k string
   117  		v uint32
   118  	}
   119  
   120  	for _, tt := range []struct {
   121  		input []pair
   122  		want  []pair
   123  	}{
   124  		{
   125  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}},
   126  			want:  []pair{{"a", 1}, {"b", 3}, {"c", 2}},
   127  		},
   128  		{
   129  			input: []pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   130  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   131  		},
   132  		{
   133  			input: []pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}},
   134  			want:  []pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}},
   135  		},
   136  	} {
   137  		var r Record
   138  		for _, i := range tt.input {
   139  			r.Set(WithEntry(i.k, &i.v))
   140  		}
   141  		for i, w := range tt.want {
   142  			// set got's key from r.pair[i], so that we preserve order of pairs
   143  			got := pair{k: r.pairs[i].k}
   144  			assert.NoError(t, r.Load(WithEntry(w.k, &got.v)))
   145  			assert.Equal(t, w, got)
   146  		}
   147  	}
   148  }
   149  
   150  // TestDirty tests record signature removal on setting of new key/value pair in record.
   151  func TestDirty(t *testing.T) {
   152  	var r Record
   153  
   154  	if r.Signed() {
   155  		t.Error("Signed returned true for zero record")
   156  	}
   157  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   158  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   159  	}
   160  
   161  	require.NoError(t, r.Sign(privkey))
   162  	if !r.Signed() {
   163  		t.Error("Signed return false for signed record")
   164  	}
   165  	_, err := rlp.EncodeToBytes(r)
   166  	assert.NoError(t, err)
   167  
   168  	r.SetSeq(3)
   169  	if r.Signed() {
   170  		t.Error("Signed returned true for modified record")
   171  	}
   172  	if _, err := rlp.EncodeToBytes(r); err != errEncodeUnsigned {
   173  		t.Errorf("expected errEncodeUnsigned, got %#v", err)
   174  	}
   175  }
   176  
   177  // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
   178  func TestGetSetOverwrite(t *testing.T) {
   179  	var r Record
   180  
   181  	ip := IP4{192, 168, 0, 3}
   182  	r.Set(ip)
   183  
   184  	ip2 := IP4{192, 168, 0, 4}
   185  	r.Set(ip2)
   186  
   187  	var ip3 IP4
   188  	require.NoError(t, r.Load(&ip3))
   189  	assert.Equal(t, ip2, ip3)
   190  }
   191  
   192  // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record.
   193  func TestSignEncodeAndDecode(t *testing.T) {
   194  	var r Record
   195  	r.Set(DiscPort(36663))
   196  	r.Set(IP4{127, 0, 0, 1})
   197  	require.NoError(t, r.Sign(privkey))
   198  
   199  	blob, err := rlp.EncodeToBytes(r)
   200  	require.NoError(t, err)
   201  
   202  	var r2 Record
   203  	require.NoError(t, rlp.DecodeBytes(blob, &r2))
   204  	assert.Equal(t, r, r2)
   205  
   206  	blob2, err := rlp.EncodeToBytes(r2)
   207  	require.NoError(t, err)
   208  	assert.Equal(t, blob, blob2)
   209  }
   210  
   211  func TestNodeAddr(t *testing.T) {
   212  	var r Record
   213  	if addr := r.NodeAddr(); addr != nil {
   214  		t.Errorf("wrong address on empty record: got %v, want %v", addr, nil)
   215  	}
   216  
   217  	require.NoError(t, r.Sign(privkey))
   218  	expected := "caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726"
   219  	assert.Equal(t, expected, hex.EncodeToString(r.NodeAddr()))
   220  }
   221  
   222  var pyRecord, _ = hex.DecodeString("f896b840954dc36583c1f4b69ab59b1375f362f06ee99f3723cd77e64b6de6d211c27d7870642a79d4516997f94091325d2a7ca6215376971455fb221d34f35b277149a1018664697363763582765f82696490736563703235366b312d6b656363616b83697034847f00000189736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138")
   223  
   224  // TestPythonInterop checks that we can decode and verify a record produced by the Python
   225  // implementation.
   226  func TestPythonInterop(t *testing.T) {
   227  	var r Record
   228  	if err := rlp.DecodeBytes(pyRecord, &r); err != nil {
   229  		t.Fatalf("can't decode: %v", err)
   230  	}
   231  
   232  	var (
   233  		wantAddr, _  = hex.DecodeString("caaa1485d83b18b32ed9ad666026151bf0cae8a0a88c857ae2d4c5be2daa6726")
   234  		wantSeq      = uint64(1)
   235  		wantIP       = IP4{127, 0, 0, 1}
   236  		wantDiscport = DiscPort(36663)
   237  	)
   238  	if r.Seq() != wantSeq {
   239  		t.Errorf("wrong seq: got %d, want %d", r.Seq(), wantSeq)
   240  	}
   241  	if addr := r.NodeAddr(); !bytes.Equal(addr, wantAddr) {
   242  		t.Errorf("wrong addr: got %x, want %x", addr, wantAddr)
   243  	}
   244  	want := map[Entry]interface{}{new(IP4): &wantIP, new(DiscPort): &wantDiscport}
   245  	for k, v := range want {
   246  		desc := fmt.Sprintf("loading key %q", k.ENRKey())
   247  		if assert.NoError(t, r.Load(k), desc) {
   248  			assert.Equal(t, k, v, desc)
   249  		}
   250  	}
   251  }
   252  
   253  // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed.
   254  func TestRecordTooBig(t *testing.T) {
   255  	var r Record
   256  	key := randomString(10)
   257  
   258  	// set a big value for random key, expect error
   259  	r.Set(WithEntry(key, randomString(300)))
   260  	if err := r.Sign(privkey); err != errTooBig {
   261  		t.Fatalf("expected to get errTooBig, got %#v", err)
   262  	}
   263  
   264  	// set an acceptable value for random key, expect no error
   265  	r.Set(WithEntry(key, randomString(100)))
   266  	require.NoError(t, r.Sign(privkey))
   267  }
   268  
   269  // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value pairs.
   270  func TestSignEncodeAndDecodeRandom(t *testing.T) {
   271  	var r Record
   272  
   273  	// random key/value pairs for testing
   274  	pairs := map[string]uint32{}
   275  	for i := 0; i < 10; i++ {
   276  		key := randomString(7)
   277  		value := rnd.Uint32()
   278  		pairs[key] = value
   279  		r.Set(WithEntry(key, &value))
   280  	}
   281  
   282  	require.NoError(t, r.Sign(privkey))
   283  	_, err := rlp.EncodeToBytes(r)
   284  	require.NoError(t, err)
   285  
   286  	for k, v := range pairs {
   287  		desc := fmt.Sprintf("key %q", k)
   288  		var got uint32
   289  		buf := WithEntry(k, &got)
   290  		require.NoError(t, r.Load(buf), desc)
   291  		require.Equal(t, v, got, desc)
   292  	}
   293  }
   294  
   295  func BenchmarkDecode(b *testing.B) {
   296  	var r Record
   297  	for i := 0; i < b.N; i++ {
   298  		rlp.DecodeBytes(pyRecord, &r)
   299  	}
   300  	b.StopTimer()
   301  	r.NodeAddr()
   302  }