github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/cmd/dynamic_startup_test.go (about)

     1  package cmd
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/stretchr/testify/require"
     9  
    10  	"github.com/onflow/flow-go/cmd/util/cmd/common"
    11  	"github.com/onflow/flow-go/model/flow"
    12  	"github.com/onflow/flow-go/state/protocol"
    13  	protocolmock "github.com/onflow/flow-go/state/protocol/mock"
    14  	"github.com/onflow/flow-go/utils/unittest"
    15  	"github.com/onflow/flow-go/utils/unittest/mocks"
    16  )
    17  
    18  func dynamicJoinFlagsFixture() (string, string, flow.EpochPhase, uint64) {
    19  	return unittest.NetworkingPrivKeyFixture().PublicKey().String(), "access_1:9001", flow.EpochPhaseSetup, 1
    20  }
    21  
    22  func getMockSnapshot(t *testing.T, epochCounter uint64, phase flow.EpochPhase) *protocolmock.Snapshot {
    23  	currentEpoch := new(protocolmock.Epoch)
    24  	currentEpoch.On("Counter").Return(epochCounter, nil)
    25  
    26  	epochQuery := mocks.NewEpochQuery(t, epochCounter)
    27  	epochQuery.Add(currentEpoch)
    28  
    29  	snapshot := new(protocolmock.Snapshot)
    30  	snapshot.On("Epochs").Return(epochQuery)
    31  	snapshot.On("Phase").Return(phase, nil)
    32  
    33  	return snapshot
    34  }
    35  
    36  // TestValidateDynamicStartupFlags tests validation of dynamic-startup-* CLI flags
    37  func TestValidateDynamicStartupFlags(t *testing.T) {
    38  	t.Run("should return nil if all flags are valid", func(t *testing.T) {
    39  		pub, address, phase, _ := dynamicJoinFlagsFixture()
    40  		err := ValidateDynamicStartupFlags(pub, address, phase)
    41  		require.NoError(t, err)
    42  	})
    43  
    44  	t.Run("should return error if access network key is not valid ECDSA_P256 public key", func(t *testing.T) {
    45  		_, address, phase, _ := dynamicJoinFlagsFixture()
    46  		err := ValidateDynamicStartupFlags("0xKEY", address, phase)
    47  		require.Error(t, err)
    48  		require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-publickey")
    49  	})
    50  
    51  	t.Run("should return error if access address is empty", func(t *testing.T) {
    52  		pub, _, phase, _ := dynamicJoinFlagsFixture()
    53  		err := ValidateDynamicStartupFlags(pub, "", phase)
    54  		require.Error(t, err)
    55  		require.Contains(t, err.Error(), "invalid flag --dynamic-startup-access-address")
    56  	})
    57  
    58  	t.Run("should return error if startup epoch phase is invalid", func(t *testing.T) {
    59  		pub, address, _, _ := dynamicJoinFlagsFixture()
    60  		err := ValidateDynamicStartupFlags(pub, address, -1)
    61  		require.Error(t, err)
    62  		require.Contains(t, err.Error(), "invalid flag --dynamic-startup-startup-epoch-phase")
    63  	})
    64  }
    65  
    66  // TestGetSnapshotAtEpochAndPhase ensures the target start epoch and phase conditions are met before returning a snapshot
    67  // for a node to bootstrap with by asserting the expected number of warn/info log messages are output and the expected
    68  // snapshot is returned
    69  func TestGetSnapshotAtEpochAndPhase(t *testing.T) {
    70  	t.Run("should retry until a snapshot is observed with desired epoch/phase", func(t *testing.T) {
    71  		// the snapshot we will use to force GetSnapshotAtEpochAndPhase to retry
    72  		oldSnapshot := getMockSnapshot(t, 0, flow.EpochPhaseStaking)
    73  
    74  		// the snapshot that will return target counter and phase
    75  		expectedSnapshot := getMockSnapshot(t, 1, flow.EpochPhaseSetup)
    76  
    77  		// setup mock get snapshot func that will return expected snapshot after 3 invocations
    78  		counter := 0
    79  		getSnapshot := func(_ context.Context) (protocol.Snapshot, error) {
    80  			if counter < 3 {
    81  				counter++
    82  				return oldSnapshot, nil
    83  			}
    84  
    85  			return expectedSnapshot, nil
    86  		}
    87  
    88  		_, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture()
    89  
    90  		// get snapshot
    91  		actualSnapshot, err := common.GetSnapshotAtEpochAndPhase(
    92  			context.Background(),
    93  			unittest.Logger(),
    94  			targetEpoch,
    95  			targetPhase,
    96  			time.Nanosecond,
    97  			getSnapshot,
    98  		)
    99  		require.NoError(t, err)
   100  
   101  		require.Equal(t, expectedSnapshot, actualSnapshot)
   102  	})
   103  
   104  	t.Run("should return snapshot immediately if target epoch has passed", func(t *testing.T) {
   105  		// the snapshot that will return target counter and phase
   106  		// epoch > target epoch but phase < target phase
   107  		expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking)
   108  
   109  		// setup mock get snapshot func that will return expected snapshot after 3 invocations
   110  		getSnapshot := func(_ context.Context) (protocol.Snapshot, error) {
   111  			return expectedSnapshot, nil
   112  		}
   113  
   114  		_, _, targetPhase, targetEpoch := dynamicJoinFlagsFixture()
   115  
   116  		// get snapshot
   117  		actualSnapshot, err := common.GetSnapshotAtEpochAndPhase(
   118  			context.Background(),
   119  			unittest.Logger(),
   120  			targetEpoch,
   121  			targetPhase,
   122  			time.Nanosecond,
   123  			getSnapshot,
   124  		)
   125  		require.NoError(t, err)
   126  		require.Equal(t, expectedSnapshot, actualSnapshot)
   127  	})
   128  
   129  	t.Run("should return snapshot after target phase is reached if target epoch is the same as current", func(t *testing.T) {
   130  		oldSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseStaking)
   131  		expectedSnapshot := getMockSnapshot(t, 5, flow.EpochPhaseSetup)
   132  
   133  		counter := 0
   134  		// setup mock get snapshot func that will return expected snapshot after 3 invocations
   135  		getSnapshot := func(_ context.Context) (protocol.Snapshot, error) {
   136  			if counter < 3 {
   137  				counter++
   138  				return oldSnapshot, nil
   139  			}
   140  
   141  			return expectedSnapshot, nil
   142  		}
   143  
   144  		_, _, targetPhase, _ := dynamicJoinFlagsFixture()
   145  
   146  		// get snapshot
   147  		actualSnapshot, err := common.GetSnapshotAtEpochAndPhase(
   148  			context.Background(),
   149  			unittest.Logger(),
   150  			5,
   151  			targetPhase,
   152  			time.Nanosecond,
   153  			getSnapshot,
   154  		)
   155  		require.NoError(t, err)
   156  		require.Equal(t, expectedSnapshot, actualSnapshot)
   157  	})
   158  }