github.com/MetalBlockchain/metalgo@v1.11.9/snow/engine/common/test_vm.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package common
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"net/http"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/MetalBlockchain/metalgo/database"
    16  	"github.com/MetalBlockchain/metalgo/ids"
    17  	"github.com/MetalBlockchain/metalgo/snow"
    18  	"github.com/MetalBlockchain/metalgo/version"
    19  )
    20  
    21  var (
    22  	errInitialize                 = errors.New("unexpectedly called Initialize")
    23  	errSetState                   = errors.New("unexpectedly called SetState")
    24  	errShutdown                   = errors.New("unexpectedly called Shutdown")
    25  	errCreateHandlers             = errors.New("unexpectedly called CreateHandlers")
    26  	errHealthCheck                = errors.New("unexpectedly called HealthCheck")
    27  	errConnected                  = errors.New("unexpectedly called Connected")
    28  	errDisconnected               = errors.New("unexpectedly called Disconnected")
    29  	errVersion                    = errors.New("unexpectedly called Version")
    30  	errAppRequest                 = errors.New("unexpectedly called AppRequest")
    31  	errAppResponse                = errors.New("unexpectedly called AppResponse")
    32  	errAppRequestFailed           = errors.New("unexpectedly called AppRequestFailed")
    33  	errAppGossip                  = errors.New("unexpectedly called AppGossip")
    34  	errCrossChainAppRequest       = errors.New("unexpectedly called CrossChainAppRequest")
    35  	errCrossChainAppResponse      = errors.New("unexpectedly called CrossChainAppResponse")
    36  	errCrossChainAppRequestFailed = errors.New("unexpectedly called CrossChainAppRequestFailed")
    37  
    38  	_ VM = (*TestVM)(nil)
    39  )
    40  
    41  // TestVM is a test vm
    42  type TestVM struct {
    43  	T *testing.T
    44  
    45  	CantInitialize, CantSetState,
    46  	CantShutdown, CantCreateHandlers,
    47  	CantHealthCheck, CantConnected, CantDisconnected, CantVersion,
    48  	CantAppRequest, CantAppResponse, CantAppGossip, CantAppRequestFailed,
    49  	CantCrossChainAppRequest, CantCrossChainAppResponse, CantCrossChainAppRequestFailed bool
    50  
    51  	InitializeF                 func(ctx context.Context, chainCtx *snow.Context, db database.Database, genesisBytes []byte, upgradeBytes []byte, configBytes []byte, msgChan chan<- Message, fxs []*Fx, appSender AppSender) error
    52  	SetStateF                   func(ctx context.Context, state snow.State) error
    53  	ShutdownF                   func(context.Context) error
    54  	CreateHandlersF             func(context.Context) (map[string]http.Handler, error)
    55  	ConnectedF                  func(ctx context.Context, nodeID ids.NodeID, nodeVersion *version.Application) error
    56  	DisconnectedF               func(ctx context.Context, nodeID ids.NodeID) error
    57  	HealthCheckF                func(context.Context) (interface{}, error)
    58  	AppRequestF                 func(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, msg []byte) error
    59  	AppResponseF                func(ctx context.Context, nodeID ids.NodeID, requestID uint32, msg []byte) error
    60  	AppGossipF                  func(ctx context.Context, nodeID ids.NodeID, msg []byte) error
    61  	AppRequestFailedF           func(ctx context.Context, nodeID ids.NodeID, requestID uint32, appErr *AppError) error
    62  	VersionF                    func(context.Context) (string, error)
    63  	CrossChainAppRequestF       func(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, msg []byte) error
    64  	CrossChainAppResponseF      func(ctx context.Context, chainID ids.ID, requestID uint32, msg []byte) error
    65  	CrossChainAppRequestFailedF func(ctx context.Context, chainID ids.ID, requestID uint32, appErr *AppError) error
    66  }
    67  
    68  func (vm *TestVM) Default(cant bool) {
    69  	vm.CantInitialize = cant
    70  	vm.CantSetState = cant
    71  	vm.CantShutdown = cant
    72  	vm.CantCreateHandlers = cant
    73  	vm.CantHealthCheck = cant
    74  	vm.CantAppRequest = cant
    75  	vm.CantAppRequestFailed = cant
    76  	vm.CantAppResponse = cant
    77  	vm.CantAppGossip = cant
    78  	vm.CantVersion = cant
    79  	vm.CantConnected = cant
    80  	vm.CantDisconnected = cant
    81  	vm.CantCrossChainAppRequest = cant
    82  	vm.CantCrossChainAppRequestFailed = cant
    83  	vm.CantCrossChainAppResponse = cant
    84  }
    85  
    86  func (vm *TestVM) Initialize(
    87  	ctx context.Context,
    88  	chainCtx *snow.Context,
    89  	db database.Database,
    90  	genesisBytes,
    91  	upgradeBytes,
    92  	configBytes []byte,
    93  	msgChan chan<- Message,
    94  	fxs []*Fx,
    95  	appSender AppSender,
    96  ) error {
    97  	if vm.InitializeF != nil {
    98  		return vm.InitializeF(
    99  			ctx,
   100  			chainCtx,
   101  			db,
   102  			genesisBytes,
   103  			upgradeBytes,
   104  			configBytes,
   105  			msgChan,
   106  			fxs,
   107  			appSender,
   108  		)
   109  	}
   110  	if vm.CantInitialize && vm.T != nil {
   111  		require.FailNow(vm.T, errInitialize.Error())
   112  	}
   113  	return errInitialize
   114  }
   115  
   116  func (vm *TestVM) SetState(ctx context.Context, state snow.State) error {
   117  	if vm.SetStateF != nil {
   118  		return vm.SetStateF(ctx, state)
   119  	}
   120  	if vm.CantSetState {
   121  		if vm.T != nil {
   122  			require.FailNow(vm.T, errSetState.Error())
   123  		}
   124  		return errSetState
   125  	}
   126  	return nil
   127  }
   128  
   129  func (vm *TestVM) Shutdown(ctx context.Context) error {
   130  	if vm.ShutdownF != nil {
   131  		return vm.ShutdownF(ctx)
   132  	}
   133  	if vm.CantShutdown {
   134  		if vm.T != nil {
   135  			require.FailNow(vm.T, errShutdown.Error())
   136  		}
   137  		return errShutdown
   138  	}
   139  	return nil
   140  }
   141  
   142  func (vm *TestVM) CreateHandlers(ctx context.Context) (map[string]http.Handler, error) {
   143  	if vm.CreateHandlersF != nil {
   144  		return vm.CreateHandlersF(ctx)
   145  	}
   146  	if vm.CantCreateHandlers && vm.T != nil {
   147  		require.FailNow(vm.T, errCreateHandlers.Error())
   148  	}
   149  	return nil, nil
   150  }
   151  
   152  func (vm *TestVM) HealthCheck(ctx context.Context) (interface{}, error) {
   153  	if vm.HealthCheckF != nil {
   154  		return vm.HealthCheckF(ctx)
   155  	}
   156  	if vm.CantHealthCheck && vm.T != nil {
   157  		require.FailNow(vm.T, errHealthCheck.Error())
   158  	}
   159  	return nil, errHealthCheck
   160  }
   161  
   162  func (vm *TestVM) AppRequest(ctx context.Context, nodeID ids.NodeID, requestID uint32, deadline time.Time, request []byte) error {
   163  	if vm.AppRequestF != nil {
   164  		return vm.AppRequestF(ctx, nodeID, requestID, deadline, request)
   165  	}
   166  	if !vm.CantAppRequest {
   167  		return nil
   168  	}
   169  	if vm.T != nil {
   170  		require.FailNow(vm.T, errAppRequest.Error())
   171  	}
   172  	return errAppRequest
   173  }
   174  
   175  func (vm *TestVM) AppRequestFailed(ctx context.Context, nodeID ids.NodeID, requestID uint32, appErr *AppError) error {
   176  	if vm.AppRequestFailedF != nil {
   177  		return vm.AppRequestFailedF(ctx, nodeID, requestID, appErr)
   178  	}
   179  	if !vm.CantAppRequestFailed {
   180  		return nil
   181  	}
   182  	if vm.T != nil {
   183  		require.FailNow(vm.T, errAppRequestFailed.Error())
   184  	}
   185  	return errAppRequestFailed
   186  }
   187  
   188  func (vm *TestVM) AppResponse(ctx context.Context, nodeID ids.NodeID, requestID uint32, response []byte) error {
   189  	if vm.AppResponseF != nil {
   190  		return vm.AppResponseF(ctx, nodeID, requestID, response)
   191  	}
   192  	if !vm.CantAppResponse {
   193  		return nil
   194  	}
   195  	if vm.T != nil {
   196  		require.FailNow(vm.T, errAppResponse.Error())
   197  	}
   198  	return errAppResponse
   199  }
   200  
   201  func (vm *TestVM) AppGossip(ctx context.Context, nodeID ids.NodeID, msg []byte) error {
   202  	if vm.AppGossipF != nil {
   203  		return vm.AppGossipF(ctx, nodeID, msg)
   204  	}
   205  	if !vm.CantAppGossip {
   206  		return nil
   207  	}
   208  	if vm.T != nil {
   209  		require.FailNow(vm.T, errAppGossip.Error())
   210  	}
   211  	return errAppGossip
   212  }
   213  
   214  func (vm *TestVM) CrossChainAppRequest(ctx context.Context, chainID ids.ID, requestID uint32, deadline time.Time, request []byte) error {
   215  	if vm.CrossChainAppRequestF != nil {
   216  		return vm.CrossChainAppRequestF(ctx, chainID, requestID, deadline, request)
   217  	}
   218  	if !vm.CantCrossChainAppRequest {
   219  		return nil
   220  	}
   221  	if vm.T != nil {
   222  		require.FailNow(vm.T, errCrossChainAppRequest.Error())
   223  	}
   224  	return errCrossChainAppRequest
   225  }
   226  
   227  func (vm *TestVM) CrossChainAppRequestFailed(ctx context.Context, chainID ids.ID, requestID uint32, appErr *AppError) error {
   228  	if vm.CrossChainAppRequestFailedF != nil {
   229  		return vm.CrossChainAppRequestFailedF(ctx, chainID, requestID, appErr)
   230  	}
   231  	if !vm.CantCrossChainAppRequestFailed {
   232  		return nil
   233  	}
   234  	if vm.T != nil {
   235  		require.FailNow(vm.T, errCrossChainAppRequestFailed.Error())
   236  	}
   237  	return errCrossChainAppRequestFailed
   238  }
   239  
   240  func (vm *TestVM) CrossChainAppResponse(ctx context.Context, chainID ids.ID, requestID uint32, response []byte) error {
   241  	if vm.CrossChainAppResponseF != nil {
   242  		return vm.CrossChainAppResponseF(ctx, chainID, requestID, response)
   243  	}
   244  	if !vm.CantCrossChainAppResponse {
   245  		return nil
   246  	}
   247  	if vm.T != nil {
   248  		require.FailNow(vm.T, errCrossChainAppResponse.Error())
   249  	}
   250  	return errCrossChainAppResponse
   251  }
   252  
   253  func (vm *TestVM) Connected(ctx context.Context, id ids.NodeID, nodeVersion *version.Application) error {
   254  	if vm.ConnectedF != nil {
   255  		return vm.ConnectedF(ctx, id, nodeVersion)
   256  	}
   257  	if vm.CantConnected && vm.T != nil {
   258  		require.FailNow(vm.T, errConnected.Error())
   259  	}
   260  	return nil
   261  }
   262  
   263  func (vm *TestVM) Disconnected(ctx context.Context, id ids.NodeID) error {
   264  	if vm.DisconnectedF != nil {
   265  		return vm.DisconnectedF(ctx, id)
   266  	}
   267  	if vm.CantDisconnected && vm.T != nil {
   268  		require.FailNow(vm.T, errDisconnected.Error())
   269  	}
   270  	return nil
   271  }
   272  
   273  func (vm *TestVM) Version(ctx context.Context) (string, error) {
   274  	if vm.VersionF != nil {
   275  		return vm.VersionF(ctx)
   276  	}
   277  	if vm.CantVersion && vm.T != nil {
   278  		require.FailNow(vm.T, errVersion.Error())
   279  	}
   280  	return "", nil
   281  }