github.com/MetalBlockchain/metalgo@v1.11.9/snow/validators/test_state.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package validators 5 6 import ( 7 "context" 8 "errors" 9 "testing" 10 11 "github.com/stretchr/testify/require" 12 13 "github.com/MetalBlockchain/metalgo/ids" 14 ) 15 16 var ( 17 errMinimumHeight = errors.New("unexpectedly called GetMinimumHeight") 18 errCurrentHeight = errors.New("unexpectedly called GetCurrentHeight") 19 errSubnetID = errors.New("unexpectedly called GetSubnetID") 20 errGetValidatorSet = errors.New("unexpectedly called GetValidatorSet") 21 ) 22 23 var _ State = (*TestState)(nil) 24 25 type TestState struct { 26 T testing.TB 27 28 CantGetMinimumHeight, 29 CantGetCurrentHeight, 30 CantGetSubnetID, 31 CantGetValidatorSet bool 32 33 GetMinimumHeightF func(ctx context.Context) (uint64, error) 34 GetCurrentHeightF func(ctx context.Context) (uint64, error) 35 GetSubnetIDF func(ctx context.Context, chainID ids.ID) (ids.ID, error) 36 GetValidatorSetF func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*GetValidatorOutput, error) 37 } 38 39 func (vm *TestState) GetMinimumHeight(ctx context.Context) (uint64, error) { 40 if vm.GetMinimumHeightF != nil { 41 return vm.GetMinimumHeightF(ctx) 42 } 43 if vm.CantGetMinimumHeight && vm.T != nil { 44 require.FailNow(vm.T, errMinimumHeight.Error()) 45 } 46 return 0, errMinimumHeight 47 } 48 49 func (vm *TestState) GetCurrentHeight(ctx context.Context) (uint64, error) { 50 if vm.GetCurrentHeightF != nil { 51 return vm.GetCurrentHeightF(ctx) 52 } 53 if vm.CantGetCurrentHeight && vm.T != nil { 54 require.FailNow(vm.T, errCurrentHeight.Error()) 55 } 56 return 0, errCurrentHeight 57 } 58 59 func (vm *TestState) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { 60 if vm.GetSubnetIDF != nil { 61 return vm.GetSubnetIDF(ctx, chainID) 62 } 63 if vm.CantGetSubnetID && vm.T != nil { 64 require.FailNow(vm.T, errSubnetID.Error()) 65 } 66 return ids.Empty, errSubnetID 67 } 68 69 func (vm *TestState) GetValidatorSet( 70 ctx context.Context, 71 height uint64, 72 subnetID ids.ID, 73 ) (map[ids.NodeID]*GetValidatorOutput, error) { 74 if vm.GetValidatorSetF != nil { 75 return vm.GetValidatorSetF(ctx, height, subnetID) 76 } 77 if vm.CantGetValidatorSet && vm.T != nil { 78 require.FailNow(vm.T, errGetValidatorSet.Error()) 79 } 80 return nil, errGetValidatorSet 81 }