github.com/codysnider/go-ethereum@v1.10.18-0.20220420071915-14f4ae99222a/p2p/nodestate/nodestate_test.go (about) 1 // Copyright 2020 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package nodestate 18 19 import ( 20 "errors" 21 "fmt" 22 "reflect" 23 "testing" 24 "time" 25 26 "github.com/ethereum/go-ethereum/common/mclock" 27 "github.com/ethereum/go-ethereum/core/rawdb" 28 "github.com/ethereum/go-ethereum/p2p/enode" 29 "github.com/ethereum/go-ethereum/p2p/enr" 30 "github.com/ethereum/go-ethereum/rlp" 31 ) 32 33 func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) { 34 setup := &Setup{} 35 flags := make([]Flags, len(flagPersist)) 36 for i, persist := range flagPersist { 37 if persist { 38 flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i)) 39 } else { 40 flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i)) 41 } 42 } 43 fields := make([]Field, len(fieldType)) 44 for i, ftype := range fieldType { 45 switch ftype { 46 case reflect.TypeOf(uint64(0)): 47 fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec) 48 case reflect.TypeOf(""): 49 fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec) 50 default: 51 fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype) 52 } 53 } 54 return setup, flags, fields 55 } 56 57 func testNode(b byte) *enode.Node { 58 r := &enr.Record{} 59 r.SetSig(dummyIdentity{b}, []byte{42}) 60 n, _ := enode.New(dummyIdentity{b}, r) 61 return n 62 } 63 64 func TestCallback(t *testing.T) { 65 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 66 67 s, flags, _ := testSetup([]bool{false, false, false}, nil) 68 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 69 70 set0 := make(chan struct{}, 1) 71 set1 := make(chan struct{}, 1) 72 set2 := make(chan struct{}, 1) 73 ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) 74 ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) 75 ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) 76 77 ns.Start() 78 79 ns.SetState(testNode(1), flags[0], Flags{}, 0) 80 ns.SetState(testNode(1), flags[1], Flags{}, time.Second) 81 ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) 82 83 for i := 0; i < 3; i++ { 84 select { 85 case <-set0: 86 case <-set1: 87 case <-set2: 88 case <-time.After(time.Second): 89 t.Fatalf("failed to invoke callback") 90 } 91 } 92 } 93 94 func TestPersistentFlags(t *testing.T) { 95 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 96 97 s, flags, _ := testSetup([]bool{true, true, true, false}, nil) 98 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 99 100 saveNode := make(chan *nodeInfo, 5) 101 ns.saveNodeHook = func(node *nodeInfo) { 102 saveNode <- node 103 } 104 105 ns.Start() 106 107 ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved 108 ns.SetState(testNode(2), flags[1], Flags{}, 0) 109 ns.SetState(testNode(3), flags[2], Flags{}, 0) 110 ns.SetState(testNode(4), flags[3], Flags{}, 0) 111 ns.SetState(testNode(5), flags[0], Flags{}, 0) 112 ns.Persist(testNode(5)) 113 select { 114 case <-saveNode: 115 case <-time.After(time.Second): 116 t.Fatalf("Timeout") 117 } 118 ns.Stop() 119 120 for i := 0; i < 2; i++ { 121 select { 122 case <-saveNode: 123 case <-time.After(time.Second): 124 t.Fatalf("Timeout") 125 } 126 } 127 select { 128 case <-saveNode: 129 t.Fatalf("Unexpected saveNode") 130 case <-time.After(time.Millisecond * 100): 131 } 132 } 133 134 func TestSetField(t *testing.T) { 135 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 136 137 s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")}) 138 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 139 140 saveNode := make(chan *nodeInfo, 1) 141 ns.saveNodeHook = func(node *nodeInfo) { 142 saveNode <- node 143 } 144 145 ns.Start() 146 147 // Set field before setting state 148 ns.SetField(testNode(1), fields[0], "hello world") 149 field := ns.GetField(testNode(1), fields[0]) 150 if field == nil { 151 t.Fatalf("Field should be set before setting states") 152 } 153 ns.SetField(testNode(1), fields[0], nil) 154 field = ns.GetField(testNode(1), fields[0]) 155 if field != nil { 156 t.Fatalf("Field should be unset") 157 } 158 // Set field after setting state 159 ns.SetState(testNode(1), flags[0], Flags{}, 0) 160 ns.SetField(testNode(1), fields[0], "hello world") 161 field = ns.GetField(testNode(1), fields[0]) 162 if field == nil { 163 t.Fatalf("Field should be set after setting states") 164 } 165 if err := ns.SetField(testNode(1), fields[0], 123); err == nil { 166 t.Fatalf("Invalid field should be rejected") 167 } 168 // Dirty node should be written back 169 ns.Stop() 170 select { 171 case <-saveNode: 172 case <-time.After(time.Second): 173 t.Fatalf("Timeout") 174 } 175 } 176 177 func TestSetState(t *testing.T) { 178 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 179 180 s, flags, _ := testSetup([]bool{false, false, false}, nil) 181 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 182 183 type change struct{ old, new Flags } 184 set := make(chan change, 1) 185 ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { 186 set <- change{ 187 old: oldState, 188 new: newState, 189 } 190 }) 191 192 ns.Start() 193 194 check := func(expectOld, expectNew Flags, expectChange bool) { 195 if expectChange { 196 select { 197 case c := <-set: 198 if !c.old.Equals(expectOld) { 199 t.Fatalf("Old state mismatch") 200 } 201 if !c.new.Equals(expectNew) { 202 t.Fatalf("New state mismatch") 203 } 204 case <-time.After(time.Second): 205 } 206 return 207 } 208 select { 209 case <-set: 210 t.Fatalf("Unexpected change") 211 case <-time.After(time.Millisecond * 100): 212 return 213 } 214 } 215 ns.SetState(testNode(1), flags[0], Flags{}, 0) 216 check(Flags{}, flags[0], true) 217 218 ns.SetState(testNode(1), flags[1], Flags{}, 0) 219 check(flags[0], flags[0].Or(flags[1]), true) 220 221 ns.SetState(testNode(1), flags[2], Flags{}, 0) 222 check(Flags{}, Flags{}, false) 223 224 ns.SetState(testNode(1), Flags{}, flags[0], 0) 225 check(flags[0].Or(flags[1]), flags[1], true) 226 227 ns.SetState(testNode(1), Flags{}, flags[1], 0) 228 check(flags[1], Flags{}, true) 229 230 ns.SetState(testNode(1), Flags{}, flags[2], 0) 231 check(Flags{}, Flags{}, false) 232 233 ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) 234 check(Flags{}, flags[0].Or(flags[1]), true) 235 clock.Run(time.Second) 236 check(flags[0].Or(flags[1]), Flags{}, true) 237 } 238 239 func uint64FieldEnc(field interface{}) ([]byte, error) { 240 if u, ok := field.(uint64); ok { 241 enc, err := rlp.EncodeToBytes(&u) 242 return enc, err 243 } 244 return nil, errors.New("invalid field type") 245 } 246 247 func uint64FieldDec(enc []byte) (interface{}, error) { 248 var u uint64 249 err := rlp.DecodeBytes(enc, &u) 250 return u, err 251 } 252 253 func stringFieldEnc(field interface{}) ([]byte, error) { 254 if s, ok := field.(string); ok { 255 return []byte(s), nil 256 } 257 return nil, errors.New("invalid field type") 258 } 259 260 func stringFieldDec(enc []byte) (interface{}, error) { 261 return string(enc), nil 262 } 263 264 func TestPersistentFields(t *testing.T) { 265 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 266 267 s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")}) 268 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 269 270 ns.Start() 271 ns.SetState(testNode(1), flags[0], Flags{}, 0) 272 ns.SetField(testNode(1), fields[0], uint64(100)) 273 ns.SetField(testNode(1), fields[1], "hello world") 274 ns.Stop() 275 276 ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 277 278 ns2.Start() 279 field0 := ns2.GetField(testNode(1), fields[0]) 280 if !reflect.DeepEqual(field0, uint64(100)) { 281 t.Fatalf("Field changed") 282 } 283 field1 := ns2.GetField(testNode(1), fields[1]) 284 if !reflect.DeepEqual(field1, "hello world") { 285 t.Fatalf("Field changed") 286 } 287 288 s.Version++ 289 ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 290 ns3.Start() 291 if ns3.GetField(testNode(1), fields[0]) != nil { 292 t.Fatalf("Old field version should have been discarded") 293 } 294 } 295 296 func TestFieldSub(t *testing.T) { 297 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 298 299 s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))}) 300 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 301 302 var ( 303 lastState Flags 304 lastOldValue, lastNewValue interface{} 305 ) 306 ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { 307 lastState, lastOldValue, lastNewValue = state, oldValue, newValue 308 }) 309 check := func(state Flags, oldValue, newValue interface{}) { 310 if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue { 311 t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue) 312 } 313 } 314 ns.Start() 315 ns.SetState(testNode(1), flags[0], Flags{}, 0) 316 ns.SetField(testNode(1), fields[0], uint64(100)) 317 check(flags[0], nil, uint64(100)) 318 ns.Stop() 319 check(s.OfflineFlag(), uint64(100), nil) 320 321 ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 322 ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { 323 lastState, lastOldValue, lastNewValue = state, oldValue, newValue 324 }) 325 ns2.Start() 326 check(s.OfflineFlag(), nil, uint64(100)) 327 ns2.SetState(testNode(1), Flags{}, flags[0], 0) 328 ns2.SetField(testNode(1), fields[0], nil) 329 check(Flags{}, uint64(100), nil) 330 ns2.Stop() 331 } 332 333 func TestDuplicatedFlags(t *testing.T) { 334 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 335 336 s, flags, _ := testSetup([]bool{true}, nil) 337 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 338 339 type change struct{ old, new Flags } 340 set := make(chan change, 1) 341 ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { 342 set <- change{oldState, newState} 343 }) 344 345 ns.Start() 346 defer ns.Stop() 347 348 check := func(expectOld, expectNew Flags, expectChange bool) { 349 if expectChange { 350 select { 351 case c := <-set: 352 if !c.old.Equals(expectOld) { 353 t.Fatalf("Old state mismatch") 354 } 355 if !c.new.Equals(expectNew) { 356 t.Fatalf("New state mismatch") 357 } 358 case <-time.After(time.Second): 359 } 360 return 361 } 362 select { 363 case <-set: 364 t.Fatalf("Unexpected change") 365 case <-time.After(time.Millisecond * 100): 366 return 367 } 368 } 369 ns.SetState(testNode(1), flags[0], Flags{}, time.Second) 370 check(Flags{}, flags[0], true) 371 ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s 372 check(Flags{}, flags[0], false) 373 374 clock.Run(2 * time.Second) 375 check(flags[0], Flags{}, true) 376 } 377 378 func TestCallbackOrder(t *testing.T) { 379 mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} 380 381 s, flags, _ := testSetup([]bool{false, false, false, false}, nil) 382 ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) 383 384 ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { 385 if newState.Equals(flags[0]) { 386 ns.SetStateSub(n, flags[1], Flags{}, 0) 387 ns.SetStateSub(n, flags[2], Flags{}, 0) 388 } 389 }) 390 ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { 391 if newState.Equals(flags[1]) { 392 ns.SetStateSub(n, flags[3], Flags{}, 0) 393 } 394 }) 395 lastState := Flags{} 396 ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) { 397 if !oldState.Equals(lastState) { 398 t.Fatalf("Wrong callback order") 399 } 400 lastState = newState 401 }) 402 403 ns.Start() 404 defer ns.Stop() 405 406 ns.SetState(testNode(1), flags[0], Flags{}, 0) 407 }