github.com/annchain/OG@v0.0.9/p2p/enr/enr_test.go (about) 1 // Copyright 2017 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package enr 18 19 import ( 20 "bytes" 21 "encoding/binary" 22 "fmt" 23 "github.com/annchain/OG/types/msg" 24 "math/rand" 25 "testing" 26 "time" 27 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 ) 31 32 var rnd = rand.New(rand.NewSource(time.Now().UnixNano())) 33 34 func randomString(strlen int) string { 35 b := make([]byte, strlen) 36 rnd.Read(b) 37 return string(b) 38 } 39 40 // TestGetSetID tests encoding/decoding and setting/getting of the ID key. 41 func TestGetSetID(t *testing.T) { 42 id := ID("someid") 43 var r Record 44 r.Set(&id) 45 46 var id2 ID 47 require.NoError(t, r.Load(&id2)) 48 assert.Equal(t, id, id2) 49 } 50 51 // TestGetSetIP4 tests encoding/decoding and setting/getting of the IP key. 52 func TestGetSetIP4(t *testing.T) { 53 ip := IP{192, 168, 0, 3} 54 var r Record 55 r.Set(&ip) 56 57 var ip2 IP 58 require.NoError(t, r.Load(&ip2)) 59 assert.Equal(t, ip, ip2) 60 } 61 62 // TestGetSetIP6 tests encoding/decoding and setting/getting of the IP key. 63 func TestGetSetIP6(t *testing.T) { 64 ip := IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68} 65 var r Record 66 r.Set(&ip) 67 68 var ip2 IP 69 require.NoError(t, r.Load(&ip2)) 70 assert.Equal(t, ip, ip2) 71 } 72 73 // TestGetSetDiscPort tests encoding/decoding and setting/getting of the DiscPort key. 74 func TestGetSetUDP(t *testing.T) { 75 port := UDP(30309) 76 var r Record 77 r.Set(&port) 78 79 var port2 UDP 80 require.NoError(t, r.Load(&port2)) 81 assert.Equal(t, port, port2) 82 } 83 84 func TestLoadErrors(t *testing.T) { 85 var r Record 86 ip4 := IP{127, 0, 0, 1} 87 r.Set(&ip4) 88 89 // Check error for missing keys. 90 var udp UDP 91 err := r.Load(&udp) 92 if !IsNotFound(err) { 93 t.Error("IsNotFound should return true for missing key") 94 } 95 assert.Equal(t, &KeyError{Key: udp.ENRKey(), Err: errNotFound}, err) 96 97 // Check error for invalid keys. 98 var list msg.Uints 99 err = r.Load(WithEntry(ip4.ENRKey(), &list)) 100 kerr, ok := err.(*KeyError) 101 if !ok { 102 t.Fatalf("expected KeyError, got %T", err) 103 } 104 assert.Equal(t, kerr.Key, ip4.ENRKey()) 105 assert.Error(t, kerr.Err) 106 if IsNotFound(err) { 107 t.Error("IsNotFound should return false for decoding errors") 108 } 109 } 110 111 // TestSortedGetAndSet tests that Set produced a sorted Pairs slice. 112 func TestSortedGetAndSet(t *testing.T) { 113 type Pair struct { 114 K string 115 V uint32 116 } 117 118 for _, tt := range []struct { 119 input []Pair 120 want []Pair 121 }{ 122 { 123 input: []Pair{{"a", 1}, {"c", 2}, {"b", 3}}, 124 want: []Pair{{"a", 1}, {"b", 3}, {"c", 2}}, 125 }, 126 { 127 input: []Pair{{"a", 1}, {"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, 128 want: []Pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, 129 }, 130 { 131 input: []Pair{{"c", 2}, {"b", 3}, {"d", 4}, {"a", 5}, {"bb", 6}}, 132 want: []Pair{{"a", 5}, {"b", 3}, {"bb", 6}, {"c", 2}, {"d", 4}}, 133 }, 134 } { 135 var r Record 136 for _, i := range tt.input { 137 val := msg.Uint(i.V) 138 r.Set(WithEntry(i.K, &val)) 139 } 140 for i, w := range tt.want { 141 // set got's key from r.Pair[i], so that we preserve order of Pairs 142 got := Pair{K: r.Pairs[i].K} 143 val := msg.Uint(got.V) 144 assert.NoError(t, r.Load(WithEntry(w.K, &val))) 145 got.V = uint32(val) 146 assert.Equal(t, w, got) 147 } 148 } 149 } 150 151 // TestDirty tests record Signature removal on setting of new key/value Pair in record. 152 func TestDirty(t *testing.T) { 153 var r Record 154 155 if _, err := r.Encode(nil); err != errEncodeUnsigned { 156 t.Errorf("expected errEncodeUnsigned, got %#v", err) 157 } 158 159 require.NoError(t, signTest([]byte{5}, &r)) 160 if len(r.Signature) == 0 { 161 t.Error("record is not signed") 162 } 163 _, err := r.Encode(nil) 164 assert.NoError(t, err) 165 166 r.SetSeq(3) 167 if len(r.Signature) != 0 { 168 t.Error("Signature still set after modification") 169 } 170 if _, err := r.Encode(nil); err != errEncodeUnsigned { 171 t.Errorf("expected errEncodeUnsigned, got %#v", err) 172 } 173 } 174 175 func TestSeq(t *testing.T) { 176 var r Record 177 178 assert.Equal(t, uint64(0), r.GetSeq()) 179 u := UDP(1) 180 r.Set(&u) 181 assert.Equal(t, uint64(0), r.GetSeq()) 182 signTest([]byte{5}, &r) 183 assert.Equal(t, uint64(0), r.GetSeq()) 184 u2 := UDP(2) 185 r.Set(&u2) 186 assert.Equal(t, uint64(1), r.GetSeq()) 187 } 188 189 // TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record. 190 func TestGetSetOverwrite(t *testing.T) { 191 var r Record 192 193 ip := IP{192, 168, 0, 3} 194 r.Set(&ip) 195 196 ip2 := IP{192, 168, 0, 4} 197 r.Set(&ip2) 198 199 var ip3 IP 200 require.NoError(t, r.Load(&ip3)) 201 assert.Equal(t, ip2, ip3) 202 } 203 204 // TestSignEncodeAndDecode tests signing, RLP encoding and RLP decoding of a record. 205 func TestSignEncodeAndDecode(t *testing.T) { 206 var r Record 207 u := UDP(30303) 208 r.Set(&u) 209 ip := IP{127, 0, 0, 1} 210 r.Set(&ip) 211 require.NoError(t, signTest([]byte{5}, &r)) 212 213 blob, err := r.Encode(nil) 214 require.NoError(t, err) 215 216 var r2 Record 217 _, err = r2.Decode(blob) 218 require.NoError(t, err) 219 assert.Equal(t, r, r2) 220 221 blob2, err := r2.Encode(nil) 222 require.NoError(t, err) 223 assert.Equal(t, blob, blob2) 224 } 225 226 // TestRecordTooBig tests that records bigger than SizeLimit bytes cannot be signed. 227 func TestRecordTooBig(t *testing.T) { 228 var r Record 229 key := randomString(10) 230 231 // set a big value for random key, expect error 232 str := msg.String(randomString(SizeLimit)) 233 r.Set(WithEntry(key, &str)) 234 if err := signTest([]byte{5}, &r); err != errTooBig { 235 t.Fatalf("expected to get errTooBig, got %#v", err) 236 } 237 str2 := msg.String(randomString(100)) 238 // set an acceptable value for random key, expect no error 239 r.Set(WithEntry(key, &str2)) 240 require.NoError(t, signTest([]byte{5}, &r)) 241 } 242 243 // TestSignEncodeAndDecodeRandom tests encoding/decoding of records containing random key/value Pairs. 244 func TestSignEncodeAndDecodeRandom(t *testing.T) { 245 var r Record 246 247 // random key/value Pairs for testing 248 Pairs := map[string]uint32{} 249 for i := 0; i < 10; i++ { 250 key := randomString(7) 251 value := rnd.Uint32() 252 Pairs[string(key)] = value 253 v := msg.Uint(value) 254 r.Set(WithEntry(string(key), &v)) 255 } 256 257 require.NoError(t, signTest([]byte{5}, &r)) 258 _, err := r.MarshalMsg(nil) 259 require.NoError(t, err) 260 261 for k, v := range Pairs { 262 desc := fmt.Sprintf("key %q", k) 263 var got msg.Uint32 264 buf := WithEntry(k, &got) 265 require.NoError(t, r.Load(buf), desc) 266 require.Equal(t, v, uint32(got), desc) 267 } 268 } 269 270 type testSig struct{} 271 272 type testID struct { 273 msg.Bytes 274 } 275 276 func newTestId(b []byte) *testID { 277 return &testID{b} 278 } 279 280 func (id testID) ENRKey() string { return "testid" } 281 282 func signTest(id []byte, r *Record) error { 283 i := ID("test") 284 r.Set(&i) 285 v := newTestId(id) 286 r.Set(v) 287 return r.SetSig(testSig{}, makeTestSig(id, r.GetSeq())) 288 } 289 290 func makeTestSig(id []byte, seq uint64) []byte { 291 sig := make([]byte, 8, len(id)+8) 292 binary.BigEndian.PutUint64(sig[:8], seq) 293 sig = append(sig, id...) 294 return sig 295 } 296 297 func (testSig) Verify(r *Record, sig []byte) error { 298 var id testID 299 if err := r.Load(&id); err != nil { 300 return err 301 } 302 if !bytes.Equal(sig, makeTestSig(id.Bytes, r.GetSeq())) { 303 return ErrInvalidSig 304 } 305 return nil 306 } 307 308 func (testSig) NodeAddr(r *Record) []byte { 309 var id testID 310 if err := r.Load(&id); err != nil { 311 return nil 312 } 313 return id.Bytes 314 }