github.com/koko1123/flow-go-1@v0.29.6/cmd/dynamic_startup_test.go (about)

     1  package cmd
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/require"
     9  
    10  	"github.com/koko1123/flow-go-1/model/flow"
    11  	"github.com/koko1123/flow-go-1/state/protocol"
    12  	protocolmock "github.com/koko1123/flow-go-1/state/protocol/mock"
    13  	"github.com/koko1123/flow-go-1/utils/unittest"
    14  	"github.com/koko1123/flow-go-1/utils/unittest/mocks"
    15  )
    16  
    17  func dynamicJoinFlagsFixture() (string, string, flow.EpochPhase, uint64) {
    18  	return unittest.NetworkingPrivKeyFixture().PublicKey().String(), "access_1:9001", flow.EpochPhaseSetup, 1
    19  }
    20  
    21  func getMockSnapshot(t *testing.T, epochCounter uint64, phase flow.EpochPhase) *protocolmock.Snapshot {
    22  	currentEpoch := new(protocolmock.Epoch)
    23  	currentEpoch.On("Counter").Return(epochCounter, nil)
    24  
    25  	epochQuery := mocks.NewEpochQuery(t, epochCounter)
    26  	epochQuery.Add(currentEpoch)
    27  
    28  	snapshot := new(protocolmock.Snapshot)
    29  	snapshot.On("Epochs").Return(epochQuery)
    30  	snapshot.On("Phase").Return(phase, nil)
    31  
    32  	return snapshot
    33  }
    34  
    35  // TestValidateDynamicStartupFlags tests validation of dynamic-startup-* CLI flags
    36  func TestValidateDynamicStartupFlags(t *testing.T) {
    37  	t.Run("should return nil if all flags are valid", func(t *testing.T) {
    38  		pub, address, phase, _ := dynamicJoinFlagsFixture()
    39  		err := ValidateDynamicStartupFlags(pub, address, phase)
    40  		require.NoError(t, err)
    41  	})
    42  
    43  	t.Run("should return error if access network key is not valid ECDSA_P256 public key", func(t *testing.T) {
    44  		_, address, phase, _ := dynamicJoinFlagsFixture()
    45  		err := ValidateDynamicStartupFlags("0xKEY", address, phase)
    46  		require.Error(t, err)
    47  		require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-publickey")
    48  	})
    49  
    50  	t.Run("should return error if access address is empty", func(t *testing.T) {
    51  		pub, _, phase, _ := dynamicJoinFlagsFixture()
    52  		err := ValidateDynamicStartupFlags(pub, "", phase)
    53  		require.Error(t, err)
    54  		require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-address")
    55  	})
    56  
    57  	t.Run("should return error if startup epoch phase is invalid", func(t *testing.T) {
    58  		pub, address, _, _ := dynamicJoinFlagsFixture()
    59  		err := ValidateDynamicStartupFlags(pub, address, -1)
    60  		require.Error(t, err)
    61  		require.Contains(t, err.Error(), "invalid flag --dynamic-startup-startup-epoch-phase")
    62  	})
    63  }
    64  
    65  // TestGetSnapshotAtEpochAndPhase ensures the target start epoch and phase conditions are met before returning a snapshot
    66  // for a node to bootstrap with by asserting the expected number of warn/info log messages are output and the expected
    67  // snapshot is returned
    68  func TestGetSnapshotAtEpochAndPhase(t *testing.T) {
    69  	t.Run("should retry until a snapshot is observed with desired epoch/phase", func(t *testing.T) {
    70  		// the snapshot we will use to force GetSnapshotAtEpochAndPhase to retry
    71  		oldSnapshot := getMockSnapshot(t, 0, flow.EpochPhaseStaking)
    72  
    73  		// the snapshot that will return target counter and phase
    74  		expectedSnapshot := getMockSnapshot(t, 1, flow.EpochPhaseSetup)
    75  
    76  		// setup mock get snapshot func that will return expected snapshot after 3 invocations
    77  		counter := 0
    78  		getSnapshot := func(_ context.Context) (protocol.Snapshot, error) {
    79  			if counter < 3 {
    80  				counter++
    81  				return oldSnapshot, nil
    82  			}
    83  
    84  			return expectedSnapshot, nil
    85  		}
    86  
    87  		_, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture()
    88  
    89  		// get snapshot
    90  		actualSnapshot, err := GetSnapshotAtEpochAndPhase(
    91  			context.Background(),
    92  			unittest.Logger(),
    93  			targetEpoch,
    94  			targetPhase,
    95  			time.Nanosecond,
    96  			getSnapshot,
    97  		)
    98  		require.NoError(t, err)
    99  
   100  		require.Equal(t, expectedSnapshot, actualSnapshot)
   101  	})
   102  
   103  	t.Run("should return snapshot immediately if target epoch has passed", func(t *testing.T) {
   104  		// the snapshot that will return target counter and phase
   105  		// epoch > target epoch but phase < target phase
   106  		expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking)
   107  
   108  		// setup mock get snapshot func that will return expected snapshot after 3 invocations
   109  		getSnapshot := func(_ context.Context) (protocol.Snapshot, error) {
   110  			return expectedSnapshot, nil
   111  		}
   112  
   113  		_, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture()
   114  
   115  		// get snapshot
   116  		actualSnapshot, err := GetSnapshotAtEpochAndPhase(
   117  			context.Background(),
   118  			unittest.Logger(),
   119  			targetEpoch,
   120  			targetPhase,
   121  			time.Nanosecond,
   122  			getSnapshot,
   123  		)
   124  		require.NoError(t, err)
   125  		require.Equal(t, expectedSnapshot, actualSnapshot)
   126  	})
   127  
   128  	t.Run("should return snapshot after target phase is reached if target epoch is the same as current", func(t *testing.T) {
   129  		oldSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking)
   130  		expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseSetup)
   131  
   132  		counter := 0
   133  		// setup mock get snapshot func that will return expected snapshot after 3 invocations
   134  		getSnapshot := func(_ context.Context) (protocol.Snapshot, error) {
   135  			if counter < 3 {
   136  				counter++
   137  				return oldSnapshot, nil
   138  			}
   139  
   140  			return expectedSnapshot, nil
   141  		}
   142  
   143  		_, _, targetPhase, _ := dynamicJoinFlagsFixture()
   144  
   145  		// get snapshot
   146  		actualSnapshot, err := GetSnapshotAtEpochAndPhase(
   147  			context.Background(),
   148  			unittest.Logger(),
   149  			5,
   150  			targetPhase,
   151  			time.Nanosecond,
   152  			getSnapshot,
   153  		)
   154  		require.NoError(t, err)
   155  		require.Equal(t, expectedSnapshot, actualSnapshot)
   156  	})
   157  }