github.com/MetalBlockchain/metalgo@v1.11.9/snow/engine/snowman/block/test_vm.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package block
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/MetalBlockchain/metalgo/ids"
    13  	"github.com/MetalBlockchain/metalgo/snow/consensus/snowman"
    14  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    15  )
    16  
    17  var (
    18  	errBuildBlock         = errors.New("unexpectedly called BuildBlock")
    19  	errParseBlock         = errors.New("unexpectedly called ParseBlock")
    20  	errGetBlock           = errors.New("unexpectedly called GetBlock")
    21  	errLastAccepted       = errors.New("unexpectedly called LastAccepted")
    22  	errGetBlockIDAtHeight = errors.New("unexpectedly called GetBlockIDAtHeight")
    23  
    24  	_ ChainVM = (*TestVM)(nil)
    25  )
    26  
    27  // TestVM is a ChainVM that is useful for testing.
    28  type TestVM struct {
    29  	common.TestVM
    30  
    31  	CantBuildBlock,
    32  	CantParseBlock,
    33  	CantGetBlock,
    34  	CantSetPreference,
    35  	CantLastAccepted,
    36  	CantGetBlockIDAtHeight bool
    37  
    38  	BuildBlockF         func(context.Context) (snowman.Block, error)
    39  	ParseBlockF         func(context.Context, []byte) (snowman.Block, error)
    40  	GetBlockF           func(context.Context, ids.ID) (snowman.Block, error)
    41  	SetPreferenceF      func(context.Context, ids.ID) error
    42  	LastAcceptedF       func(context.Context) (ids.ID, error)
    43  	GetBlockIDAtHeightF func(ctx context.Context, height uint64) (ids.ID, error)
    44  }
    45  
    46  func (vm *TestVM) Default(cant bool) {
    47  	vm.TestVM.Default(cant)
    48  
    49  	vm.CantBuildBlock = cant
    50  	vm.CantParseBlock = cant
    51  	vm.CantGetBlock = cant
    52  	vm.CantSetPreference = cant
    53  	vm.CantLastAccepted = cant
    54  }
    55  
    56  func (vm *TestVM) BuildBlock(ctx context.Context) (snowman.Block, error) {
    57  	if vm.BuildBlockF != nil {
    58  		return vm.BuildBlockF(ctx)
    59  	}
    60  	if vm.CantBuildBlock && vm.T != nil {
    61  		require.FailNow(vm.T, errBuildBlock.Error())
    62  	}
    63  	return nil, errBuildBlock
    64  }
    65  
    66  func (vm *TestVM) ParseBlock(ctx context.Context, b []byte) (snowman.Block, error) {
    67  	if vm.ParseBlockF != nil {
    68  		return vm.ParseBlockF(ctx, b)
    69  	}
    70  	if vm.CantParseBlock && vm.T != nil {
    71  		require.FailNow(vm.T, errParseBlock.Error())
    72  	}
    73  	return nil, errParseBlock
    74  }
    75  
    76  func (vm *TestVM) GetBlock(ctx context.Context, id ids.ID) (snowman.Block, error) {
    77  	if vm.GetBlockF != nil {
    78  		return vm.GetBlockF(ctx, id)
    79  	}
    80  	if vm.CantGetBlock && vm.T != nil {
    81  		require.FailNow(vm.T, errGetBlock.Error())
    82  	}
    83  	return nil, errGetBlock
    84  }
    85  
    86  func (vm *TestVM) SetPreference(ctx context.Context, id ids.ID) error {
    87  	if vm.SetPreferenceF != nil {
    88  		return vm.SetPreferenceF(ctx, id)
    89  	}
    90  	if vm.CantSetPreference && vm.T != nil {
    91  		require.FailNow(vm.T, "Unexpectedly called SetPreference")
    92  	}
    93  	return nil
    94  }
    95  
    96  func (vm *TestVM) LastAccepted(ctx context.Context) (ids.ID, error) {
    97  	if vm.LastAcceptedF != nil {
    98  		return vm.LastAcceptedF(ctx)
    99  	}
   100  	if vm.CantLastAccepted && vm.T != nil {
   101  		require.FailNow(vm.T, errLastAccepted.Error())
   102  	}
   103  	return ids.Empty, errLastAccepted
   104  }
   105  
   106  func (vm *TestVM) GetBlockIDAtHeight(ctx context.Context, height uint64) (ids.ID, error) {
   107  	if vm.GetBlockIDAtHeightF != nil {
   108  		return vm.GetBlockIDAtHeightF(ctx, height)
   109  	}
   110  	if vm.CantGetBlockIDAtHeight && vm.T != nil {
   111  		require.FailNow(vm.T, errGetAncestor.Error())
   112  	}
   113  	return ids.Empty, errGetBlockIDAtHeight
   114  }