github.com/MetalBlockchain/metalgo@v1.11.9/snow/validators/test_state.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package validators
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/MetalBlockchain/metalgo/ids"
    14  )
    15  
    16  var (
    17  	errMinimumHeight   = errors.New("unexpectedly called GetMinimumHeight")
    18  	errCurrentHeight   = errors.New("unexpectedly called GetCurrentHeight")
    19  	errSubnetID        = errors.New("unexpectedly called GetSubnetID")
    20  	errGetValidatorSet = errors.New("unexpectedly called GetValidatorSet")
    21  )
    22  
    23  var _ State = (*TestState)(nil)
    24  
    25  type TestState struct {
    26  	T testing.TB
    27  
    28  	CantGetMinimumHeight,
    29  	CantGetCurrentHeight,
    30  	CantGetSubnetID,
    31  	CantGetValidatorSet bool
    32  
    33  	GetMinimumHeightF func(ctx context.Context) (uint64, error)
    34  	GetCurrentHeightF func(ctx context.Context) (uint64, error)
    35  	GetSubnetIDF      func(ctx context.Context, chainID ids.ID) (ids.ID, error)
    36  	GetValidatorSetF  func(ctx context.Context, height uint64, subnetID ids.ID) (map[ids.NodeID]*GetValidatorOutput, error)
    37  }
    38  
    39  func (vm *TestState) GetMinimumHeight(ctx context.Context) (uint64, error) {
    40  	if vm.GetMinimumHeightF != nil {
    41  		return vm.GetMinimumHeightF(ctx)
    42  	}
    43  	if vm.CantGetMinimumHeight && vm.T != nil {
    44  		require.FailNow(vm.T, errMinimumHeight.Error())
    45  	}
    46  	return 0, errMinimumHeight
    47  }
    48  
    49  func (vm *TestState) GetCurrentHeight(ctx context.Context) (uint64, error) {
    50  	if vm.GetCurrentHeightF != nil {
    51  		return vm.GetCurrentHeightF(ctx)
    52  	}
    53  	if vm.CantGetCurrentHeight && vm.T != nil {
    54  		require.FailNow(vm.T, errCurrentHeight.Error())
    55  	}
    56  	return 0, errCurrentHeight
    57  }
    58  
    59  func (vm *TestState) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) {
    60  	if vm.GetSubnetIDF != nil {
    61  		return vm.GetSubnetIDF(ctx, chainID)
    62  	}
    63  	if vm.CantGetSubnetID && vm.T != nil {
    64  		require.FailNow(vm.T, errSubnetID.Error())
    65  	}
    66  	return ids.Empty, errSubnetID
    67  }
    68  
    69  func (vm *TestState) GetValidatorSet(
    70  	ctx context.Context,
    71  	height uint64,
    72  	subnetID ids.ID,
    73  ) (map[ids.NodeID]*GetValidatorOutput, error) {
    74  	if vm.GetValidatorSetF != nil {
    75  		return vm.GetValidatorSetF(ctx, height, subnetID)
    76  	}
    77  	if vm.CantGetValidatorSet && vm.T != nil {
    78  		require.FailNow(vm.T, errGetValidatorSet.Error())
    79  	}
    80  	return nil, errGetValidatorSet
    81  }