github.com/lmittmann/w3@v0.20.0/w3vm/vm.go (about)

     1  /*
     2  Package w3vm provides a VM for executing EVM messages.
     3  */
     4  package w3vm
     5  
     6  import (
     7  	"cmp"
     8  	"crypto/rand"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"math/big"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/ethereum/go-ethereum/accounts/abi"
    17  	"github.com/ethereum/go-ethereum/common"
    18  	"github.com/ethereum/go-ethereum/consensus/misc/eip4844"
    19  	"github.com/ethereum/go-ethereum/core"
    20  	"github.com/ethereum/go-ethereum/core/state"
    21  	"github.com/ethereum/go-ethereum/core/tracing"
    22  	"github.com/ethereum/go-ethereum/core/types"
    23  	"github.com/ethereum/go-ethereum/core/vm"
    24  	"github.com/ethereum/go-ethereum/crypto"
    25  	"github.com/ethereum/go-ethereum/params"
    26  	"github.com/holiman/uint256"
    27  	"github.com/lmittmann/w3"
    28  	"github.com/lmittmann/w3/module/eth"
    29  	"github.com/lmittmann/w3/w3types"
    30  )
    31  
    32  var (
    33  	pendingBlockNumber = big.NewInt(-1)
    34  
    35  	ErrFetch  = errors.New("fetching failed")
    36  	ErrRevert = errors.New("execution reverted")
    37  )
    38  
// VM applies EVM messages to a state DB. Create it with [New].
type VM struct {
	opts *options // configuration, finalized in New via options.Init

	txIndex uint64         // index of the next applied message; also used to derive its pseudo tx hash
	db      *state.StateDB // state that messages are applied to
}
    45  
    46  // New creates a new VM, that is configured with the given options.
    47  func New(opts ...Option) (*VM, error) {
    48  	vm := &VM{opts: new(options)}
    49  	for _, opt := range opts {
    50  		if opt == nil {
    51  			continue
    52  		}
    53  		opt(vm)
    54  	}
    55  
    56  	if err := vm.opts.Init(); err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	// set DB
    61  	db := newDB(vm.opts.fetcher)
    62  	if vm.db == nil {
    63  		vm.db, _ = state.New(w3.Hash0, db)
    64  	}
    65  	for addr, acc := range vm.opts.preState {
    66  		vm.db.SetNonce(addr, acc.Nonce, tracing.NonceChangeGenesis)
    67  		if acc.Balance != nil {
    68  			vm.db.SetBalance(addr, uint256.MustFromBig(acc.Balance), tracing.BalanceIncreaseGenesisBalance)
    69  		}
    70  		if acc.Code != nil {
    71  			vm.db.SetCode(addr, acc.Code)
    72  		}
    73  		for slot, val := range acc.Storage {
    74  			vm.db.SetState(addr, slot, val)
    75  		}
    76  	}
    77  	return vm, nil
    78  }
    79  
    80  // Apply the given message to the VM, and return its receipt. Multiple tracing hooks
    81  // may be given to trace the execution of the message.
    82  func (vm *VM) Apply(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) {
    83  	return vm.apply(msg, false, joinHooks(hooks))
    84  }
    85  
    86  // ApplyTx is like [VM.Apply], but takes a transaction instead of a message.
    87  func (vm *VM) ApplyTx(tx *types.Transaction, hooks ...*tracing.Hooks) (*Receipt, error) {
    88  	msg, err := new(w3types.Message).SetTx(tx, vm.opts.Signer())
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	return vm.Apply(msg, hooks...)
    93  }
    94  
    95  func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Receipt, error) {
    96  	if v.db.Error() != nil {
    97  		return nil, ErrFetch
    98  	}
    99  
   100  	var db vm.StateDB
   101  	if hooks != nil {
   102  		db = state.NewHookedState(v.db, hooks)
   103  	} else {
   104  		db = v.db
   105  	}
   106  
   107  	coreMsg, err := v.buildMessage(msg, isCall)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	var txHash common.Hash
   113  	binary.BigEndian.PutUint64(txHash[:], v.txIndex)
   114  	v.db.SetTxContext(txHash, int(v.txIndex))
   115  	v.txIndex++
   116  
   117  	gp := new(core.GasPool).AddGas(coreMsg.GasLimit)
   118  	evm := vm.NewEVM(*v.opts.blockCtx, db, v.opts.chainConfig, vm.Config{
   119  		Tracer:    hooks,
   120  		NoBaseFee: v.opts.noBaseFee || isCall,
   121  	})
   122  
   123  	if len(v.opts.precompiles) > 0 {
   124  		evm.SetPrecompiles(v.opts.precompiles)
   125  	}
   126  
   127  	snap := v.db.Snapshot()
   128  
   129  	// apply the message to the evm
   130  	result, err := core.ApplyMessage(evm, coreMsg, gp)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	// build receipt
   136  	receipt := &Receipt{
   137  		f:          msg.Func,
   138  		GasUsed:    result.UsedGas,
   139  		MaxGasUsed: result.MaxUsedGas,
   140  		Output:     result.ReturnData,
   141  		Logs:       v.db.GetLogs(txHash, 0, w3.Hash0, 0),
   142  	}
   143  
   144  	// zero out the log tx hashes, indices and normalize the log indices
   145  	for i, log := range receipt.Logs {
   146  		log.Index = uint(i)
   147  		log.TxHash = w3.Hash0
   148  		log.TxIndex = 0
   149  	}
   150  
   151  	if err := result.Err; err != nil {
   152  		if reason, unpackErr := abi.UnpackRevert(result.ReturnData); unpackErr != nil {
   153  			receipt.Err = ErrRevert
   154  		} else {
   155  			receipt.Err = fmt.Errorf("%w: %s", ErrRevert, reason)
   156  		}
   157  	}
   158  	if msg.To == nil {
   159  		contractAddr := crypto.CreateAddress(msg.From, coreMsg.Nonce)
   160  		receipt.ContractAddress = &contractAddr
   161  	}
   162  
   163  	if isCall && !result.Failed() {
   164  		v.db.RevertToSnapshot(snap)
   165  	}
   166  	v.db.Finalise(false)
   167  
   168  	return receipt, receipt.Err
   169  }
   170  
   171  // Call the given message on the VM, and returns its receipt. Any state changes
   172  // of a call are reverted. Multiple tracing hooks may be given to trace the execution
   173  // of the message.
   174  func (vm *VM) Call(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) {
   175  	return vm.apply(msg, true, joinHooks(hooks))
   176  }
   177  
   178  // CallFunc is a utility function for [VM.Call] that calls the given function
   179  // on the given contract address with the given arguments and decodes the
   180  // output into the given returns.
   181  //
   182  // Example:
   183  //
   184  //	funcBalanceOf := w3.MustNewFunc("balanceOf(address)", "uint256")
   185  //
   186  //	var balance *big.Int
   187  //	err := vm.CallFunc(contractAddr, funcBalanceOf, addr).Returns(&balance)
   188  //	if err != nil {
   189  //		// ...
   190  //	}
   191  func (vm *VM) CallFunc(contract common.Address, f w3types.Func, args ...any) *CallFuncFactory {
   192  	receipt, err := vm.Call(&w3types.Message{
   193  		To:   &contract,
   194  		Func: f,
   195  		Args: args,
   196  	})
   197  	return &CallFuncFactory{receipt, err}
   198  }
   199  
// CallFuncFactory holds the result of [VM.CallFunc]. Use
// [CallFuncFactory.Returns] to decode the output of the call.
type CallFuncFactory struct {
	receipt *Receipt // receipt of the call; nil if the call could not be applied
	err     error    // error returned by VM.Call
}
   204  
   205  func (cff *CallFuncFactory) Returns(returns ...any) error {
   206  	if err := cff.err; err != nil {
   207  		return err
   208  	}
   209  	return cff.receipt.DecodeReturns(returns...)
   210  }
   211  
   212  // Nonce returns the nonce of the given address.
   213  func (vm *VM) Nonce(addr common.Address) (uint64, error) {
   214  	nonce := vm.db.GetNonce(addr)
   215  	if vm.db.Error() != nil {
   216  		return 0, fmt.Errorf("%w: failed to fetch nonce of %s", ErrFetch, addr)
   217  	}
   218  	return nonce, nil
   219  }
   220  
   221  // SetNonce sets the nonce of the given address.
   222  func (vm *VM) SetNonce(addr common.Address, nonce uint64) {
   223  	vm.db.SetNonce(addr, nonce, tracing.NonceChangeUnspecified)
   224  }
   225  
   226  // Balance returns the balance of the given address.
   227  func (vm *VM) Balance(addr common.Address) (*big.Int, error) {
   228  	balance := vm.db.GetBalance(addr)
   229  	if vm.db.Error() != nil {
   230  		return nil, fmt.Errorf("%w: failed to fetch balance of %s", ErrFetch, addr)
   231  	}
   232  	return balance.ToBig(), nil
   233  }
   234  
   235  // SetBalance sets the balance of the given address.
   236  func (vm *VM) SetBalance(addr common.Address, balance *big.Int) {
   237  	vm.db.SetBalance(addr, uint256.MustFromBig(balance), tracing.BalanceChangeUnspecified)
   238  }
   239  
   240  // Code returns the code of the given address.
   241  func (vm *VM) Code(addr common.Address) ([]byte, error) {
   242  	code := vm.db.GetCode(addr)
   243  	if vm.db.Error() != nil {
   244  		return nil, fmt.Errorf("%w: failed to fetch code of %s", ErrFetch, addr)
   245  	}
   246  	return code, nil
   247  }
   248  
   249  // SetCode sets the code of the given address.
   250  func (vm *VM) SetCode(addr common.Address, code []byte) {
   251  	vm.db.SetCode(addr, code)
   252  }
   253  
   254  // StorageAt returns the state of the given address at the give storage slot.
   255  func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, error) {
   256  	val := vm.db.GetState(addr, slot)
   257  	if vm.db.Error() != nil {
   258  		return w3.Hash0, fmt.Errorf("%w: failed to fetch storage of %s at %s", ErrFetch, addr, slot)
   259  	}
   260  	return val, nil
   261  }
   262  
   263  // SetStorageAt sets the state of the given address at the given storage slot.
   264  func (vm *VM) SetStorageAt(addr common.Address, slot, val common.Hash) {
   265  	vm.db.SetState(addr, slot, val)
   266  }
   267  
   268  // Snapshot the current state of the VM. The returned state can only be rolled
   269  // back to once. Use [state.StateDB.Copy] if you need to rollback multiple times.
   270  func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() }
   271  
   272  // Rollback the state of the VM to the given snapshot.
   273  func (vm *VM) Rollback(snapshot *state.StateDB) {
   274  	vm.db = snapshot
   275  	vm.txIndex = uint64(snapshot.TxIndex()) + 1
   276  }
   277  
   278  func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, error) {
   279  	nonce := msg.Nonce
   280  	if !skipAccChecks && nonce == 0 {
   281  		var err error
   282  		nonce, err = v.Nonce(msg.From)
   283  		if err != nil {
   284  			return nil, err
   285  		}
   286  	}
   287  
   288  	gasLimit := msg.Gas
   289  	if maxGasLimit := v.opts.blockCtx.GasLimit; gasLimit == 0 {
   290  		gasLimit = maxGasLimit
   291  	} else if gasLimit > maxGasLimit {
   292  		gasLimit = maxGasLimit
   293  	}
   294  	if gasLimit == 0 {
   295  		gasLimit = 15_000_000
   296  	}
   297  
   298  	var input []byte
   299  	if msg.Input == nil && msg.Func != nil {
   300  		var err error
   301  		input, err = msg.Func.EncodeArgs(msg.Args...)
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  	} else {
   306  		input = msg.Input
   307  	}
   308  
   309  	var gasPrice, gasFeeCap, gasTipCap *big.Int
   310  	if baseFee := v.opts.blockCtx.BaseFee; baseFee == nil {
   311  		gasPrice = new(big.Int).Set(cmp.Or(msg.GasPrice, w3.Big0))
   312  		gasFeeCap, gasTipCap = gasPrice, gasPrice
   313  	} else {
   314  		if msg.GasPrice != nil && msg.GasFeeCap == nil && msg.GasTipCap == nil {
   315  			gasPrice = msg.GasPrice
   316  			gasFeeCap, gasTipCap = gasPrice, gasPrice
   317  		} else {
   318  			gasFeeCap = new(big.Int).Set(cmp.Or(msg.GasFeeCap, w3.Big0))
   319  			gasTipCap = new(big.Int).Set(cmp.Or(msg.GasTipCap, w3.Big0))
   320  			gasPrice = new(big.Int).Add(baseFee, gasTipCap)
   321  			if gasPrice.Cmp(gasFeeCap) > 0 {
   322  				gasPrice = gasFeeCap
   323  			}
   324  		}
   325  	}
   326  
   327  	if v.opts.noBaseFee {
   328  		gasFeeCap.SetInt64(0)
   329  		gasTipCap.SetInt64(0)
   330  	}
   331  
   332  	value := new(big.Int).Set(cmp.Or(msg.Value, w3.Big0))
   333  
   334  	return &core.Message{
   335  		To:                    msg.To,
   336  		From:                  msg.From,
   337  		Nonce:                 nonce,
   338  		Value:                 value,
   339  		GasLimit:              gasLimit,
   340  		GasPrice:              gasPrice,
   341  		GasFeeCap:             gasFeeCap,
   342  		GasTipCap:             gasTipCap,
   343  		Data:                  input,
   344  		AccessList:            msg.AccessList,
   345  		BlobGasFeeCap:         msg.BlobGasFeeCap,
   346  		BlobHashes:            msg.BlobHashes,
   347  		SetCodeAuthorizations: msg.SetCodeAuthorizations,
   348  		SkipNonceChecks:       skipAccChecks,
   349  		SkipFromEOACheck:      skipAccChecks,
   350  	}, nil
   351  }
   352  
   353  func newBlockContext(config *params.ChainConfig, h *types.Header, getHash vm.GetHashFunc) *vm.BlockContext {
   354  	var random *common.Hash
   355  	if h.Difficulty == nil || h.Difficulty.Sign() == 0 {
   356  		random = &h.MixDigest
   357  	}
   358  
   359  	blockNumber := h.Number
   360  	if blockNumber == nil {
   361  		blockNumber = new(big.Int)
   362  	}
   363  	difficulty := h.Difficulty
   364  	if difficulty == nil {
   365  		difficulty = new(big.Int)
   366  	}
   367  	baseFee := h.BaseFee
   368  	if baseFee == nil {
   369  		baseFee = new(big.Int)
   370  	}
   371  	var blobBaseFee *big.Int
   372  	if h.ExcessBlobGas != nil {
   373  		blobBaseFee = eip4844.CalcBlobFee(config, h)
   374  	}
   375  
   376  	return &vm.BlockContext{
   377  		CanTransfer: core.CanTransfer,
   378  		Transfer:    core.Transfer,
   379  		GetHash:     getHash,
   380  		Coinbase:    h.Coinbase,
   381  		BlockNumber: blockNumber,
   382  		Time:        h.Time,
   383  		Difficulty:  difficulty,
   384  		BaseFee:     baseFee,
   385  		BlobBaseFee: blobBaseFee,
   386  		GasLimit:    h.GasLimit,
   387  		Random:      random,
   388  	}
   389  }
   390  
   391  func defaultBlockContext() *vm.BlockContext {
   392  	var coinbase common.Address
   393  	rand.Read(coinbase[:])
   394  
   395  	var random common.Hash
   396  	rand.Read(random[:])
   397  
   398  	return &vm.BlockContext{
   399  		CanTransfer: core.CanTransfer,
   400  		Transfer:    core.Transfer,
   401  		GetHash:     zeroHashFunc,
   402  		Coinbase:    coinbase,
   403  		BlockNumber: new(big.Int),
   404  		Time:        uint64(time.Now().Unix()),
   405  		Difficulty:  new(big.Int),
   406  		BaseFee:     new(big.Int),
   407  		GasLimit:    params.MaxGasLimit,
   408  		Random:      &random,
   409  	}
   410  }
   411  
   412  ////////////////////////////////////////////////////////////////////////////////////////////////////
   413  // VM Option ///////////////////////////////////////////////////////////////////////////////////////
   414  ////////////////////////////////////////////////////////////////////////////////////////////////////
   415  
// options holds the VM configuration collected from [Option]s. It is finalized
// by options.Init.
type options struct {
	chainConfig *params.ChainConfig // set by WithChainConfig; defaulted in Init
	preState    w3types.State       // accounts written to the state DB in New
	noBaseFee   bool                // set by WithNoBaseFee

	blockCtx *vm.BlockContext // set by WithBlockContext or derived in Init
	header   *types.Header    // set by WithHeader or fetched in Init when forking

	forkClient      *w3.Client // set by WithFork
	forkBlockNumber *big.Int   // set by WithFork; nil means latest (resolved in Init)
	fetcher         Fetcher    // set by WithFetcher or created in Init when forking
	tb              testing.TB // set by WithTB

	precompiles vm.PrecompiledContracts // custom precompiles; merged over the active defaults in Init
}
   431  
   432  func (opt *options) Signer() types.Signer {
   433  	if opt.fetcher == nil {
   434  		return types.LatestSigner(opt.chainConfig)
   435  	}
   436  	return types.MakeSigner(opt.chainConfig, opt.header.Number, opt.header.Time)
   437  }
   438  
// Init finalizes the options after all [Option]s have been applied.
//
// It performs the following steps, in order:
//  1. If no chain config is set, [params.MergedTestChainConfig] is used for now.
//  2. If a fork client is set but no fetcher, it makes a single batch call.
//     The call fetches the fork block number if it is nil ("latest"). Unless a
//     header or block context is already set, it also fetches the fork header:
//     the pending header for "latest", otherwise the header of the fork block.
//     An RPC fetcher is then created. It reads state at the latest block number
//     for "latest", and at forkBlockNumber-1 otherwise. A testing fetcher (see
//     [WithTB]) is only used for an explicit fork block number.
//  3. If no chain config was set and a fetcher exists, the chain config is
//     switched to [params.MainnetChainConfig].
//  4. If no block context is set, it is built from the header, with block
//     hashes resolved via the fetcher. Without a header, defaults are used.
//  5. Custom precompiles are merged over the precompiles that are active for
//     the block context.
//
// It returns an error wrapping [ErrFetch] if the batch call fails.
func (opts *options) Init() error {
	// set initial chain config
	isChainConfigSet := opts.chainConfig != nil
	if !isChainConfigSet {
		opts.chainConfig = params.MergedTestChainConfig
	}

	// set fetcher
	if opts.fetcher == nil && opts.forkClient != nil {
		var calls []w3types.RPCCaller

		// a nil fork block number means "latest"; resolve it in the batch call
		latest := opts.forkBlockNumber == nil
		if latest {
			calls = append(calls, eth.BlockNumber().Returns(&opts.forkBlockNumber))
		}
		// only fetch a header if the block context will be derived from it
		if opts.header == nil && opts.blockCtx == nil {
			if latest {
				calls = append(calls, eth.HeaderByNumber(pendingBlockNumber).Returns(&opts.header))
			} else {
				calls = append(calls, eth.HeaderByNumber(opts.forkBlockNumber).Returns(&opts.header))
			}
		}

		if err := opts.forkClient.Call(calls...); err != nil {
			return fmt.Errorf("%w: failed to fetch header: %v", ErrFetch, err)
		}

		// Presumably the fetcher's block number denotes the state after that
		// block, so state is read at the block preceding the executed one
		// (latest for the pending block, forkBlockNumber-1 otherwise).
		if latest {
			opts.fetcher = NewRPCFetcher(opts.forkClient, opts.forkBlockNumber)
		} else if opts.tb == nil {
			opts.fetcher = NewRPCFetcher(opts.forkClient, new(big.Int).Sub(opts.forkBlockNumber, w3.Big1))
		} else {
			opts.fetcher = NewTestingRPCFetcher(opts.tb, opts.chainConfig.ChainID.Uint64(), opts.forkClient, new(big.Int).Sub(opts.forkBlockNumber, w3.Big1))
		}
	}

	// potentially update chain config
	if !isChainConfigSet && opts.fetcher != nil {
		opts.chainConfig = params.MainnetChainConfig
	}

	if opts.blockCtx == nil {
		if opts.header != nil {
			opts.blockCtx = newBlockContext(opts.chainConfig, opts.header, fetcherHashFunc(opts.fetcher))
		} else {
			opts.blockCtx = defaultBlockContext()
		}
	}

	// set precompiles
	if len(opts.precompiles) > 0 {
		// a non-nil Random is passed as the "is merge" flag of the rules
		rules := opts.chainConfig.Rules(opts.blockCtx.BlockNumber, opts.blockCtx.Random != nil, opts.blockCtx.Time)

		// overwrite default precompiles
		precompiles := vm.ActivePrecompiledContracts(rules)
		for addr, contract := range opts.precompiles {
			precompiles[addr] = contract
		}
		opts.precompiles = precompiles
	}

	return nil
}
   502  
   503  func fetcherHashFunc(fetcher Fetcher) vm.GetHashFunc {
   504  	return func(blockNumber uint64) common.Hash {
   505  		hash, _ := fetcher.HeaderHash(blockNumber)
   506  		return hash
   507  	}
   508  }
   509  
// An Option configures a [VM]. Options are passed to [New]; nil options are
// ignored.
type Option func(*VM)
   512  
   513  // WithChainConfig sets the chain config for the VM.
   514  //
   515  // If not provided, the chain config defaults to [params.MainnetChainConfig].
   516  func WithChainConfig(cfg *params.ChainConfig) Option {
   517  	return func(vm *VM) { vm.opts.chainConfig = cfg }
   518  }
   519  
   520  // WithBlockContext sets the block context for the VM.
   521  func WithBlockContext(ctx *vm.BlockContext) Option {
   522  	return func(vm *VM) { vm.opts.blockCtx = ctx }
   523  }
   524  
   525  // WithPrecompile registers a precompile contract at the given address in the VM.
   526  func WithPrecompile(addr common.Address, contract vm.PrecompiledContract) Option {
   527  	return func(v *VM) {
   528  		if v.opts.precompiles == nil {
   529  			v.opts.precompiles = make(vm.PrecompiledContracts)
   530  		}
   531  		v.opts.precompiles[addr] = contract
   532  	}
   533  }
   534  
   535  // WithState sets the pre state of the VM.
   536  //
   537  // WithState can be used together with [WithFork] to only set the state of some
   538  // accounts, or partially overwrite the storage of an account.
   539  func WithState(state w3types.State) Option {
   540  	return func(vm *VM) { vm.opts.preState = state }
   541  }
   542  
   543  // WithStateDB sets the state DB for the VM, that is usually a snapshot
   544  // obtained from [VM.Snapshot].
   545  func WithStateDB(db *state.StateDB) Option {
   546  	return func(vm *VM) {
   547  		vm.db = db
   548  		vm.txIndex = uint64(db.TxIndex() + 1)
   549  	}
   550  }
   551  
   552  // WithNoBaseFee forces the EIP-1559 base fee to 0 for the VM.
   553  func WithNoBaseFee() Option {
   554  	return func(vm *VM) { vm.opts.noBaseFee = true }
   555  }
   556  
   557  // WithFork sets the client and block number to fetch state from and sets the
   558  // block context for the VM. If the block number is nil, the latest state is
   559  // fetched and the pending block is used for constructing the block context.
   560  //
   561  // If used together with [WithTB], fetched state is stored in the testdata
   562  // directory of the tests package.
   563  func WithFork(client *w3.Client, blockNumber *big.Int) Option {
   564  	return func(vm *VM) {
   565  		vm.opts.forkClient = client
   566  		vm.opts.forkBlockNumber = blockNumber
   567  	}
   568  }
   569  
   570  // WithHeader sets the block context for the VM based on the given header.
   571  func WithHeader(header *types.Header) Option {
   572  	return func(vm *VM) { vm.opts.header = header }
   573  }
   574  
   575  // WithFetcher sets the fetcher for the VM.
   576  func WithFetcher(fetcher Fetcher) Option {
   577  	return func(vm *VM) { vm.opts.fetcher = fetcher }
   578  }
   579  
   580  // WithTB enables persistent state caching when used together with [WithFork].
   581  // State is stored in the testdata directory of the tests package.
   582  func WithTB(tb testing.TB) Option {
   583  	return func(vm *VM) { vm.opts.tb = tb }
   584  }