github.com/koko1123/flow-go-1@v0.29.6/cmd/dynamic_startup_test.go (about) 1 package cmd 2 3 import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/stretchr/testify/require" 9 10 "github.com/koko1123/flow-go-1/model/flow" 11 "github.com/koko1123/flow-go-1/state/protocol" 12 protocolmock "github.com/koko1123/flow-go-1/state/protocol/mock" 13 "github.com/koko1123/flow-go-1/utils/unittest" 14 "github.com/koko1123/flow-go-1/utils/unittest/mocks" 15 ) 16 17 func dynamicJoinFlagsFixture() (string, string, flow.EpochPhase, uint64) { 18 return unittest.NetworkingPrivKeyFixture().PublicKey().String(), "access_1:9001", flow.EpochPhaseSetup, 1 19 } 20 21 func getMockSnapshot(t *testing.T, epochCounter uint64, phase flow.EpochPhase) *protocolmock.Snapshot { 22 currentEpoch := new(protocolmock.Epoch) 23 currentEpoch.On("Counter").Return(epochCounter, nil) 24 25 epochQuery := mocks.NewEpochQuery(t, epochCounter) 26 epochQuery.Add(currentEpoch) 27 28 snapshot := new(protocolmock.Snapshot) 29 snapshot.On("Epochs").Return(epochQuery) 30 snapshot.On("Phase").Return(phase, nil) 31 32 return snapshot 33 } 34 35 // TestValidateDynamicStartupFlags tests validation of dynamic-startup-* CLI flags 36 func TestValidateDynamicStartupFlags(t *testing.T) { 37 t.Run("should return nil if all flags are valid", func(t *testing.T) { 38 pub, address, phase, _ := dynamicJoinFlagsFixture() 39 err := ValidateDynamicStartupFlags(pub, address, phase) 40 require.NoError(t, err) 41 }) 42 43 t.Run("should return error if access network key is not valid ECDSA_P256 public key", func(t *testing.T) { 44 _, address, phase, _ := dynamicJoinFlagsFixture() 45 err := ValidateDynamicStartupFlags("0xKEY", address, phase) 46 require.Error(t, err) 47 require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-publickey") 48 }) 49 50 t.Run("should return error if access address is empty", func(t *testing.T) { 51 pub, _, phase, _ := dynamicJoinFlagsFixture() 52 err := ValidateDynamicStartupFlags(pub, "", phase) 53 require.Error(t, err) 54 require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-address") 55 }) 56 57 t.Run("should return error if startup epoch phase is invalid", func(t *testing.T) { 58 pub, address, _, _ := dynamicJoinFlagsFixture() 59 err := ValidateDynamicStartupFlags(pub, address, -1) 60 require.Error(t, err) 61 require.Contains(t, err.Error(), "invalid flag --dynamic-startup-startup-epoch-phase") 62 }) 63 } 64 65 // TestGetSnapshotAtEpochAndPhase ensures the target start epoch and phase conditions are met before returning a snapshot 66 // for a node to bootstrap with by asserting the expected number of warn/info log messages are output and the expected 67 // snapshot is returned 68 func TestGetSnapshotAtEpochAndPhase(t *testing.T) { 69 t.Run("should retry until a snapshot is observed with desired epoch/phase", func(t *testing.T) { 70 // the snapshot we will use to force GetSnapshotAtEpochAndPhase to retry 71 oldSnapshot := getMockSnapshot(t, 0, flow.EpochPhaseStaking) 72 73 // the snapshot that will return target counter and phase 74 expectedSnapshot := getMockSnapshot(t, 1, flow.EpochPhaseSetup) 75 76 // setup mock get snapshot func that will return expected snapshot after 3 invocations 77 counter := 0 78 getSnapshot := func(_ context.Context) (protocol.Snapshot, error) { 79 if counter < 3 { 80 counter++ 81 return oldSnapshot, nil 82 } 83 84 return expectedSnapshot, nil 85 } 86 87 _, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture() 88 89 // get snapshot 90 actualSnapshot, err := GetSnapshotAtEpochAndPhase( 91 context.Background(), 92 unittest.Logger(), 93 targetEpoch, 94 targetPhase, 95 time.Nanosecond, 96 getSnapshot, 97 ) 98 require.NoError(t, err) 99 100 require.Equal(t, expectedSnapshot, actualSnapshot) 101 }) 102 103 t.Run("should return snapshot immediately if target epoch has passed", func(t *testing.T) { 104 // the snapshot that will return target counter and phase 105 // epoch > target epoch but phase < target phase 106 expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking) 107 108 // setup mock get snapshot func that will return expected snapshot after 3 invocations 109 getSnapshot := func(_ context.Context) (protocol.Snapshot, error) { 110 return expectedSnapshot, nil 111 } 112 113 _, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture() 114 115 // get snapshot 116 actualSnapshot, err := GetSnapshotAtEpochAndPhase( 117 context.Background(), 118 unittest.Logger(), 119 targetEpoch, 120 targetPhase, 121 time.Nanosecond, 122 getSnapshot, 123 ) 124 require.NoError(t, err) 125 require.Equal(t, expectedSnapshot, actualSnapshot) 126 }) 127 128 t.Run("should return snapshot after target phase is reached if target epoch is the same as current", func(t *testing.T) { 129 oldSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking) 130 expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseSetup) 131 132 counter := 0 133 // setup mock get snapshot func that will return expected snapshot after 3 invocations 134 getSnapshot := func(_ context.Context) (protocol.Snapshot, error) { 135 if counter < 3 { 136 counter++ 137 return oldSnapshot, nil 138 } 139 140 return expectedSnapshot, nil 141 } 142 143 _, _, targetPhase, _ := dynamicJoinFlagsFixture() 144 145 // get snapshot 146 actualSnapshot, err := GetSnapshotAtEpochAndPhase( 147 context.Background(), 148 unittest.Logger(), 149 5, 150 targetPhase, 151 time.Nanosecond, 152 getSnapshot, 153 ) 154 require.NoError(t, err) 155 require.Equal(t, expectedSnapshot, actualSnapshot) 156 }) 157 }