github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/cmd/dynamic_startup_test.go (about) 1 package cmd 2 3 import ( 4 "context" 5 "testing" 6 "time" 7 8 "github.com/stretchr/testify/require" 9 10 "github.com/onflow/flow-go/cmd/util/cmd/common" 11 "github.com/onflow/flow-go/model/flow" 12 "github.com/onflow/flow-go/state/protocol" 13 protocolmock "github.com/onflow/flow-go/state/protocol/mock" 14 "github.com/onflow/flow-go/utils/unittest" 15 "github.com/onflow/flow-go/utils/unittest/mocks" 16 ) 17 18 func dynamicJoinFlagsFixture() (string, string, flow.EpochPhase, uint64) { 19 return unittest.NetworkingPrivKeyFixture().PublicKey().String(), "access_1:9001", flow.EpochPhaseSetup, 1 20 } 21 22 func getMockSnapshot(t *testing.T, epochCounter uint64, phase flow.EpochPhase) *protocolmock.Snapshot { 23 currentEpoch := new(protocolmock.Epoch) 24 currentEpoch.On("Counter").Return(epochCounter, nil) 25 26 epochQuery := mocks.NewEpochQuery(t, epochCounter) 27 epochQuery.Add(currentEpoch) 28 29 snapshot := new(protocolmock.Snapshot) 30 snapshot.On("Epochs").Return(epochQuery) 31 snapshot.On("Phase").Return(phase, nil) 32 33 return snapshot 34 } 35 36 // TestValidateDynamicStartupFlags tests validation of dynamic-startup-* CLI flags 37 func TestValidateDynamicStartupFlags(t *testing.T) { 38 t.Run("should return nil if all flags are valid", func(t *testing.T) { 39 pub, address, phase, _ := dynamicJoinFlagsFixture() 40 err := ValidateDynamicStartupFlags(pub, address, phase) 41 require.NoError(t, err) 42 }) 43 44 t.Run("should return error if access network key is not valid ECDSA_P256 public key", func(t *testing.T) { 45 _, address, phase, _ := dynamicJoinFlagsFixture() 46 err := ValidateDynamicStartupFlags("0xKEY", address, phase) 47 require.Error(t, err) 48 require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-publickey") 49 }) 50 51 t.Run("should return error if access address is empty", func(t *testing.T) { 52 pub, _, phase, _ := dynamicJoinFlagsFixture() 53 err := ValidateDynamicStartupFlags(pub, "", phase) 54 require.Error(t, err) 55 require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-address") 56 }) 57 58 t.Run("should return error if startup epoch phase is invalid", func(t *testing.T) { 59 pub, address, _, _ := dynamicJoinFlagsFixture() 60 err := ValidateDynamicStartupFlags(pub, address, -1) 61 require.Error(t, err) 62 require.Contains(t, err.Error(), "invalid flag --dynamic-startup-startup-epoch-phase") 63 }) 64 } 65 66 // TestGetSnapshotAtEpochAndPhase ensures the target start epoch and phase conditions are met before returning a snapshot 67 // for a node to bootstrap with by asserting the expected number of warn/info log messages are output and the expected 68 // snapshot is returned 69 func TestGetSnapshotAtEpochAndPhase(t *testing.T) { 70 t.Run("should retry until a snapshot is observed with desired epoch/phase", func(t *testing.T) { 71 // the snapshot we will use to force GetSnapshotAtEpochAndPhase to retry 72 oldSnapshot := getMockSnapshot(t, 0, flow.EpochPhaseStaking) 73 74 // the snapshot that will return target counter and phase 75 expectedSnapshot := getMockSnapshot(t, 1, flow.EpochPhaseSetup) 76 77 // setup mock get snapshot func that will return expected snapshot after 3 invocations 78 counter := 0 79 getSnapshot := func(_ context.Context) (protocol.Snapshot, error) { 80 if counter < 3 { 81 counter++ 82 return oldSnapshot, nil 83 } 84 85 return expectedSnapshot, nil 86 } 87 88 _, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture() 89 90 // get snapshot 91 actualSnapshot, err := common.GetSnapshotAtEpochAndPhase( 92 context.Background(), 93 unittest.Logger(), 94 targetEpoch, 95 targetPhase, 96 time.Nanosecond, 97 getSnapshot, 98 ) 99 require.NoError(t, err) 100 101 require.Equal(t, expectedSnapshot, actualSnapshot) 102 }) 103 104 t.Run("should return snapshot immediately if target epoch has passed", func(t *testing.T) { 105 // the snapshot that will return target counter and phase 106 // epoch > target epoch but phase < target phase 107 expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking) 108 109 // setup mock get snapshot func that will return expected snapshot after 3 invocations 110 getSnapshot := func(_ context.Context) (protocol.Snapshot, error) { 111 return expectedSnapshot, nil 112 } 113 114 _, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture() 115 116 // get snapshot 117 actualSnapshot, err := common.GetSnapshotAtEpochAndPhase( 118 context.Background(), 119 unittest.Logger(), 120 targetEpoch, 121 targetPhase, 122 time.Nanosecond, 123 getSnapshot, 124 ) 125 require.NoError(t, err) 126 require.Equal(t, expectedSnapshot, actualSnapshot) 127 }) 128 129 t.Run("should return snapshot after target phase is reached if target epoch is the same as current", func(t *testing.T) { 130 oldSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking) 131 expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseSetup) 132 133 counter := 0 134 // setup mock get snapshot func that will return expected snapshot after 3 invocations 135 getSnapshot := func(_ context.Context) (protocol.Snapshot, error) { 136 if counter < 3 { 137 counter++ 138 return oldSnapshot, nil 139 } 140 141 return expectedSnapshot, nil 142 } 143 144 _, _, targetPhase, _ := dynamicJoinFlagsFixture() 145 146 // get snapshot 147 actualSnapshot, err := common.GetSnapshotAtEpochAndPhase( 148 context.Background(), 149 unittest.Logger(), 150 5, 151 targetPhase, 152 time.Nanosecond, 153 getSnapshot, 154 ) 155 require.NoError(t, err) 156 require.Equal(t, expectedSnapshot, actualSnapshot) 157 }) 158 }