github.com/snowblossomcoin/go-ethereum@v1.9.25/p2p/nodestate/nodestate_test.go (about)

     1  // Copyright 2020 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package nodestate
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"reflect"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/ethereum/go-ethereum/common/mclock"
    27  	"github.com/ethereum/go-ethereum/core/rawdb"
    28  	"github.com/ethereum/go-ethereum/p2p/enode"
    29  	"github.com/ethereum/go-ethereum/p2p/enr"
    30  	"github.com/ethereum/go-ethereum/rlp"
    31  )
    32  
    33  func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) {
    34  	setup := &Setup{}
    35  	flags := make([]Flags, len(flagPersist))
    36  	for i, persist := range flagPersist {
    37  		if persist {
    38  			flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i))
    39  		} else {
    40  			flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i))
    41  		}
    42  	}
    43  	fields := make([]Field, len(fieldType))
    44  	for i, ftype := range fieldType {
    45  		switch ftype {
    46  		case reflect.TypeOf(uint64(0)):
    47  			fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec)
    48  		case reflect.TypeOf(""):
    49  			fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec)
    50  		default:
    51  			fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype)
    52  		}
    53  	}
    54  	return setup, flags, fields
    55  }
    56  
    57  func testNode(b byte) *enode.Node {
    58  	r := &enr.Record{}
    59  	r.SetSig(dummyIdentity{b}, []byte{42})
    60  	n, _ := enode.New(dummyIdentity{b}, r)
    61  	return n
    62  }
    63  
    64  func TestCallback(t *testing.T) {
    65  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
    66  
    67  	s, flags, _ := testSetup([]bool{false, false, false}, nil)
    68  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
    69  
    70  	set0 := make(chan struct{}, 1)
    71  	set1 := make(chan struct{}, 1)
    72  	set2 := make(chan struct{}, 1)
    73  	ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} })
    74  	ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} })
    75  	ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} })
    76  
    77  	ns.Start()
    78  
    79  	ns.SetState(testNode(1), flags[0], Flags{}, 0)
    80  	ns.SetState(testNode(1), flags[1], Flags{}, time.Second)
    81  	ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second)
    82  
    83  	for i := 0; i < 3; i++ {
    84  		select {
    85  		case <-set0:
    86  		case <-set1:
    87  		case <-set2:
    88  		case <-time.After(time.Second):
    89  			t.Fatalf("failed to invoke callback")
    90  		}
    91  	}
    92  }
    93  
    94  func TestPersistentFlags(t *testing.T) {
    95  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
    96  
    97  	s, flags, _ := testSetup([]bool{true, true, true, false}, nil)
    98  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
    99  
   100  	saveNode := make(chan *nodeInfo, 5)
   101  	ns.saveNodeHook = func(node *nodeInfo) {
   102  		saveNode <- node
   103  	}
   104  
   105  	ns.Start()
   106  
   107  	ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved
   108  	ns.SetState(testNode(2), flags[1], Flags{}, 0)
   109  	ns.SetState(testNode(3), flags[2], Flags{}, 0)
   110  	ns.SetState(testNode(4), flags[3], Flags{}, 0)
   111  	ns.SetState(testNode(5), flags[0], Flags{}, 0)
   112  	ns.Persist(testNode(5))
   113  	select {
   114  	case <-saveNode:
   115  	case <-time.After(time.Second):
   116  		t.Fatalf("Timeout")
   117  	}
   118  	ns.Stop()
   119  
   120  	for i := 0; i < 2; i++ {
   121  		select {
   122  		case <-saveNode:
   123  		case <-time.After(time.Second):
   124  			t.Fatalf("Timeout")
   125  		}
   126  	}
   127  	select {
   128  	case <-saveNode:
   129  		t.Fatalf("Unexpected saveNode")
   130  	case <-time.After(time.Millisecond * 100):
   131  	}
   132  }
   133  
   134  func TestSetField(t *testing.T) {
   135  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
   136  
   137  	s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")})
   138  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   139  
   140  	saveNode := make(chan *nodeInfo, 1)
   141  	ns.saveNodeHook = func(node *nodeInfo) {
   142  		saveNode <- node
   143  	}
   144  
   145  	ns.Start()
   146  
   147  	// Set field before setting state
   148  	ns.SetField(testNode(1), fields[0], "hello world")
   149  	field := ns.GetField(testNode(1), fields[0])
   150  	if field == nil {
   151  		t.Fatalf("Field should be set before setting states")
   152  	}
   153  	ns.SetField(testNode(1), fields[0], nil)
   154  	field = ns.GetField(testNode(1), fields[0])
   155  	if field != nil {
   156  		t.Fatalf("Field should be unset")
   157  	}
   158  	// Set field after setting state
   159  	ns.SetState(testNode(1), flags[0], Flags{}, 0)
   160  	ns.SetField(testNode(1), fields[0], "hello world")
   161  	field = ns.GetField(testNode(1), fields[0])
   162  	if field == nil {
   163  		t.Fatalf("Field should be set after setting states")
   164  	}
   165  	if err := ns.SetField(testNode(1), fields[0], 123); err == nil {
   166  		t.Fatalf("Invalid field should be rejected")
   167  	}
   168  	// Dirty node should be written back
   169  	ns.Stop()
   170  	select {
   171  	case <-saveNode:
   172  	case <-time.After(time.Second):
   173  		t.Fatalf("Timeout")
   174  	}
   175  }
   176  
   177  func TestSetState(t *testing.T) {
   178  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
   179  
   180  	s, flags, _ := testSetup([]bool{false, false, false}, nil)
   181  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   182  
   183  	type change struct{ old, new Flags }
   184  	set := make(chan change, 1)
   185  	ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) {
   186  		set <- change{
   187  			old: oldState,
   188  			new: newState,
   189  		}
   190  	})
   191  
   192  	ns.Start()
   193  
   194  	check := func(expectOld, expectNew Flags, expectChange bool) {
   195  		if expectChange {
   196  			select {
   197  			case c := <-set:
   198  				if !c.old.Equals(expectOld) {
   199  					t.Fatalf("Old state mismatch")
   200  				}
   201  				if !c.new.Equals(expectNew) {
   202  					t.Fatalf("New state mismatch")
   203  				}
   204  			case <-time.After(time.Second):
   205  			}
   206  			return
   207  		}
   208  		select {
   209  		case <-set:
   210  			t.Fatalf("Unexpected change")
   211  		case <-time.After(time.Millisecond * 100):
   212  			return
   213  		}
   214  	}
   215  	ns.SetState(testNode(1), flags[0], Flags{}, 0)
   216  	check(Flags{}, flags[0], true)
   217  
   218  	ns.SetState(testNode(1), flags[1], Flags{}, 0)
   219  	check(flags[0], flags[0].Or(flags[1]), true)
   220  
   221  	ns.SetState(testNode(1), flags[2], Flags{}, 0)
   222  	check(Flags{}, Flags{}, false)
   223  
   224  	ns.SetState(testNode(1), Flags{}, flags[0], 0)
   225  	check(flags[0].Or(flags[1]), flags[1], true)
   226  
   227  	ns.SetState(testNode(1), Flags{}, flags[1], 0)
   228  	check(flags[1], Flags{}, true)
   229  
   230  	ns.SetState(testNode(1), Flags{}, flags[2], 0)
   231  	check(Flags{}, Flags{}, false)
   232  
   233  	ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second)
   234  	check(Flags{}, flags[0].Or(flags[1]), true)
   235  	clock.Run(time.Second)
   236  	check(flags[0].Or(flags[1]), Flags{}, true)
   237  }
   238  
   239  func uint64FieldEnc(field interface{}) ([]byte, error) {
   240  	if u, ok := field.(uint64); ok {
   241  		enc, err := rlp.EncodeToBytes(&u)
   242  		return enc, err
   243  	} else {
   244  		return nil, errors.New("invalid field type")
   245  	}
   246  }
   247  
   248  func uint64FieldDec(enc []byte) (interface{}, error) {
   249  	var u uint64
   250  	err := rlp.DecodeBytes(enc, &u)
   251  	return u, err
   252  }
   253  
   254  func stringFieldEnc(field interface{}) ([]byte, error) {
   255  	if s, ok := field.(string); ok {
   256  		return []byte(s), nil
   257  	} else {
   258  		return nil, errors.New("invalid field type")
   259  	}
   260  }
   261  
   262  func stringFieldDec(enc []byte) (interface{}, error) {
   263  	return string(enc), nil
   264  }
   265  
   266  func TestPersistentFields(t *testing.T) {
   267  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
   268  
   269  	s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")})
   270  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   271  
   272  	ns.Start()
   273  	ns.SetState(testNode(1), flags[0], Flags{}, 0)
   274  	ns.SetField(testNode(1), fields[0], uint64(100))
   275  	ns.SetField(testNode(1), fields[1], "hello world")
   276  	ns.Stop()
   277  
   278  	ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   279  
   280  	ns2.Start()
   281  	field0 := ns2.GetField(testNode(1), fields[0])
   282  	if !reflect.DeepEqual(field0, uint64(100)) {
   283  		t.Fatalf("Field changed")
   284  	}
   285  	field1 := ns2.GetField(testNode(1), fields[1])
   286  	if !reflect.DeepEqual(field1, "hello world") {
   287  		t.Fatalf("Field changed")
   288  	}
   289  
   290  	s.Version++
   291  	ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   292  	ns3.Start()
   293  	if ns3.GetField(testNode(1), fields[0]) != nil {
   294  		t.Fatalf("Old field version should have been discarded")
   295  	}
   296  }
   297  
   298  func TestFieldSub(t *testing.T) {
   299  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
   300  
   301  	s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))})
   302  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   303  
   304  	var (
   305  		lastState                  Flags
   306  		lastOldValue, lastNewValue interface{}
   307  	)
   308  	ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
   309  		lastState, lastOldValue, lastNewValue = state, oldValue, newValue
   310  	})
   311  	check := func(state Flags, oldValue, newValue interface{}) {
   312  		if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue {
   313  			t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue)
   314  		}
   315  	}
   316  	ns.Start()
   317  	ns.SetState(testNode(1), flags[0], Flags{}, 0)
   318  	ns.SetField(testNode(1), fields[0], uint64(100))
   319  	check(flags[0], nil, uint64(100))
   320  	ns.Stop()
   321  	check(s.OfflineFlag(), uint64(100), nil)
   322  
   323  	ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   324  	ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) {
   325  		lastState, lastOldValue, lastNewValue = state, oldValue, newValue
   326  	})
   327  	ns2.Start()
   328  	check(s.OfflineFlag(), nil, uint64(100))
   329  	ns2.SetState(testNode(1), Flags{}, flags[0], 0)
   330  	ns2.SetField(testNode(1), fields[0], nil)
   331  	check(Flags{}, uint64(100), nil)
   332  	ns2.Stop()
   333  }
   334  
   335  func TestDuplicatedFlags(t *testing.T) {
   336  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
   337  
   338  	s, flags, _ := testSetup([]bool{true}, nil)
   339  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   340  
   341  	type change struct{ old, new Flags }
   342  	set := make(chan change, 1)
   343  	ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
   344  		set <- change{oldState, newState}
   345  	})
   346  
   347  	ns.Start()
   348  	defer ns.Stop()
   349  
   350  	check := func(expectOld, expectNew Flags, expectChange bool) {
   351  		if expectChange {
   352  			select {
   353  			case c := <-set:
   354  				if !c.old.Equals(expectOld) {
   355  					t.Fatalf("Old state mismatch")
   356  				}
   357  				if !c.new.Equals(expectNew) {
   358  					t.Fatalf("New state mismatch")
   359  				}
   360  			case <-time.After(time.Second):
   361  			}
   362  			return
   363  		}
   364  		select {
   365  		case <-set:
   366  			t.Fatalf("Unexpected change")
   367  		case <-time.After(time.Millisecond * 100):
   368  			return
   369  		}
   370  	}
   371  	ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
   372  	check(Flags{}, flags[0], true)
   373  	ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s
   374  	check(Flags{}, flags[0], false)
   375  
   376  	clock.Run(2 * time.Second)
   377  	check(flags[0], Flags{}, true)
   378  }
   379  
   380  func TestCallbackOrder(t *testing.T) {
   381  	mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
   382  
   383  	s, flags, _ := testSetup([]bool{false, false, false, false}, nil)
   384  	ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
   385  
   386  	ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
   387  		if newState.Equals(flags[0]) {
   388  			ns.SetStateSub(n, flags[1], Flags{}, 0)
   389  			ns.SetStateSub(n, flags[2], Flags{}, 0)
   390  		}
   391  	})
   392  	ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) {
   393  		if newState.Equals(flags[1]) {
   394  			ns.SetStateSub(n, flags[3], Flags{}, 0)
   395  		}
   396  	})
   397  	lastState := Flags{}
   398  	ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) {
   399  		if !oldState.Equals(lastState) {
   400  			t.Fatalf("Wrong callback order")
   401  		}
   402  		lastState = newState
   403  	})
   404  
   405  	ns.Start()
   406  	defer ns.Stop()
   407  
   408  	ns.SetState(testNode(1), flags[0], Flags{}, 0)
   409  }