github.com/lmittmann/w3@v0.20.0/w3vm/vm.go (about) 1 /* 2 Package w3vm provides a VM for executing EVM messages. 3 */ 4 package w3vm 5 6 import ( 7 "cmp" 8 "crypto/rand" 9 "encoding/binary" 10 "errors" 11 "fmt" 12 "math/big" 13 "testing" 14 "time" 15 16 "github.com/ethereum/go-ethereum/accounts/abi" 17 "github.com/ethereum/go-ethereum/common" 18 "github.com/ethereum/go-ethereum/consensus/misc/eip4844" 19 "github.com/ethereum/go-ethereum/core" 20 "github.com/ethereum/go-ethereum/core/state" 21 "github.com/ethereum/go-ethereum/core/tracing" 22 "github.com/ethereum/go-ethereum/core/types" 23 "github.com/ethereum/go-ethereum/core/vm" 24 "github.com/ethereum/go-ethereum/crypto" 25 "github.com/ethereum/go-ethereum/params" 26 "github.com/holiman/uint256" 27 "github.com/lmittmann/w3" 28 "github.com/lmittmann/w3/module/eth" 29 "github.com/lmittmann/w3/w3types" 30 ) 31 32 var ( 33 pendingBlockNumber = big.NewInt(-1) 34 35 ErrFetch = errors.New("fetching failed") 36 ErrRevert = errors.New("execution reverted") 37 ) 38 39 type VM struct { 40 opts *options 41 42 txIndex uint64 43 db *state.StateDB 44 } 45 46 // New creates a new VM, that is configured with the given options. 47 func New(opts ...Option) (*VM, error) { 48 vm := &VM{opts: new(options)} 49 for _, opt := range opts { 50 if opt == nil { 51 continue 52 } 53 opt(vm) 54 } 55 56 if err := vm.opts.Init(); err != nil { 57 return nil, err 58 } 59 60 // set DB 61 db := newDB(vm.opts.fetcher) 62 if vm.db == nil { 63 vm.db, _ = state.New(w3.Hash0, db) 64 } 65 for addr, acc := range vm.opts.preState { 66 vm.db.SetNonce(addr, acc.Nonce, tracing.NonceChangeGenesis) 67 if acc.Balance != nil { 68 vm.db.SetBalance(addr, uint256.MustFromBig(acc.Balance), tracing.BalanceIncreaseGenesisBalance) 69 } 70 if acc.Code != nil { 71 vm.db.SetCode(addr, acc.Code) 72 } 73 for slot, val := range acc.Storage { 74 vm.db.SetState(addr, slot, val) 75 } 76 } 77 return vm, nil 78 } 79 80 // Apply the given message to the VM, and return its receipt. Multiple tracing hooks 81 // may be given to trace the execution of the message. 82 func (vm *VM) Apply(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) { 83 return vm.apply(msg, false, joinHooks(hooks)) 84 } 85 86 // ApplyTx is like [VM.Apply], but takes a transaction instead of a message. 87 func (vm *VM) ApplyTx(tx *types.Transaction, hooks ...*tracing.Hooks) (*Receipt, error) { 88 msg, err := new(w3types.Message).SetTx(tx, vm.opts.Signer()) 89 if err != nil { 90 return nil, err 91 } 92 return vm.Apply(msg, hooks...) 93 } 94 95 func (v *VM) apply(msg *w3types.Message, isCall bool, hooks *tracing.Hooks) (*Receipt, error) { 96 if v.db.Error() != nil { 97 return nil, ErrFetch 98 } 99 100 var db vm.StateDB 101 if hooks != nil { 102 db = state.NewHookedState(v.db, hooks) 103 } else { 104 db = v.db 105 } 106 107 coreMsg, err := v.buildMessage(msg, isCall) 108 if err != nil { 109 return nil, err 110 } 111 112 var txHash common.Hash 113 binary.BigEndian.PutUint64(txHash[:], v.txIndex) 114 v.db.SetTxContext(txHash, int(v.txIndex)) 115 v.txIndex++ 116 117 gp := new(core.GasPool).AddGas(coreMsg.GasLimit) 118 evm := vm.NewEVM(*v.opts.blockCtx, db, v.opts.chainConfig, vm.Config{ 119 Tracer: hooks, 120 NoBaseFee: v.opts.noBaseFee || isCall, 121 }) 122 123 if len(v.opts.precompiles) > 0 { 124 evm.SetPrecompiles(v.opts.precompiles) 125 } 126 127 snap := v.db.Snapshot() 128 129 // apply the message to the evm 130 result, err := core.ApplyMessage(evm, coreMsg, gp) 131 if err != nil { 132 return nil, err 133 } 134 135 // build receipt 136 receipt := &Receipt{ 137 f: msg.Func, 138 GasUsed: result.UsedGas, 139 MaxGasUsed: result.MaxUsedGas, 140 Output: result.ReturnData, 141 Logs: v.db.GetLogs(txHash, 0, w3.Hash0, 0), 142 } 143 144 // zero out the log tx hashes, indices and normalize the log indices 145 for i, log := range receipt.Logs { 146 log.Index = uint(i) 147 log.TxHash = w3.Hash0 148 log.TxIndex = 0 149 } 150 151 if err := result.Err; err != nil { 152 if reason, unpackErr := abi.UnpackRevert(result.ReturnData); unpackErr != nil { 153 receipt.Err = ErrRevert 154 } else { 155 receipt.Err = fmt.Errorf("%w: %s", ErrRevert, reason) 156 } 157 } 158 if msg.To == nil { 159 contractAddr := crypto.CreateAddress(msg.From, coreMsg.Nonce) 160 receipt.ContractAddress = &contractAddr 161 } 162 163 if isCall && !result.Failed() { 164 v.db.RevertToSnapshot(snap) 165 } 166 v.db.Finalise(false) 167 168 return receipt, receipt.Err 169 } 170 171 // Call the given message on the VM, and returns its receipt. Any state changes 172 // of a call are reverted. Multiple tracing hooks may be given to trace the execution 173 // of the message. 174 func (vm *VM) Call(msg *w3types.Message, hooks ...*tracing.Hooks) (*Receipt, error) { 175 return vm.apply(msg, true, joinHooks(hooks)) 176 } 177 178 // CallFunc is a utility function for [VM.Call] that calls the given function 179 // on the given contract address with the given arguments and decodes the 180 // output into the given returns. 181 // 182 // Example: 183 // 184 // funcBalanceOf := w3.MustNewFunc("balanceOf(address)", "uint256") 185 // 186 // var balance *big.Int 187 // err := vm.CallFunc(contractAddr, funcBalanceOf, addr).Returns(&balance) 188 // if err != nil { 189 // // ... 190 // } 191 func (vm *VM) CallFunc(contract common.Address, f w3types.Func, args ...any) *CallFuncFactory { 192 receipt, err := vm.Call(&w3types.Message{ 193 To: &contract, 194 Func: f, 195 Args: args, 196 }) 197 return &CallFuncFactory{receipt, err} 198 } 199 200 type CallFuncFactory struct { 201 receipt *Receipt 202 err error 203 } 204 205 func (cff *CallFuncFactory) Returns(returns ...any) error { 206 if err := cff.err; err != nil { 207 return err 208 } 209 return cff.receipt.DecodeReturns(returns...) 210 } 211 212 // Nonce returns the nonce of the given address. 213 func (vm *VM) Nonce(addr common.Address) (uint64, error) { 214 nonce := vm.db.GetNonce(addr) 215 if vm.db.Error() != nil { 216 return 0, fmt.Errorf("%w: failed to fetch nonce of %s", ErrFetch, addr) 217 } 218 return nonce, nil 219 } 220 221 // SetNonce sets the nonce of the given address. 222 func (vm *VM) SetNonce(addr common.Address, nonce uint64) { 223 vm.db.SetNonce(addr, nonce, tracing.NonceChangeUnspecified) 224 } 225 226 // Balance returns the balance of the given address. 227 func (vm *VM) Balance(addr common.Address) (*big.Int, error) { 228 balance := vm.db.GetBalance(addr) 229 if vm.db.Error() != nil { 230 return nil, fmt.Errorf("%w: failed to fetch balance of %s", ErrFetch, addr) 231 } 232 return balance.ToBig(), nil 233 } 234 235 // SetBalance sets the balance of the given address. 236 func (vm *VM) SetBalance(addr common.Address, balance *big.Int) { 237 vm.db.SetBalance(addr, uint256.MustFromBig(balance), tracing.BalanceChangeUnspecified) 238 } 239 240 // Code returns the code of the given address. 241 func (vm *VM) Code(addr common.Address) ([]byte, error) { 242 code := vm.db.GetCode(addr) 243 if vm.db.Error() != nil { 244 return nil, fmt.Errorf("%w: failed to fetch code of %s", ErrFetch, addr) 245 } 246 return code, nil 247 } 248 249 // SetCode sets the code of the given address. 250 func (vm *VM) SetCode(addr common.Address, code []byte) { 251 vm.db.SetCode(addr, code) 252 } 253 254 // StorageAt returns the state of the given address at the give storage slot. 255 func (vm *VM) StorageAt(addr common.Address, slot common.Hash) (common.Hash, error) { 256 val := vm.db.GetState(addr, slot) 257 if vm.db.Error() != nil { 258 return w3.Hash0, fmt.Errorf("%w: failed to fetch storage of %s at %s", ErrFetch, addr, slot) 259 } 260 return val, nil 261 } 262 263 // SetStorageAt sets the state of the given address at the given storage slot. 264 func (vm *VM) SetStorageAt(addr common.Address, slot, val common.Hash) { 265 vm.db.SetState(addr, slot, val) 266 } 267 268 // Snapshot the current state of the VM. The returned state can only be rolled 269 // back to once. Use [state.StateDB.Copy] if you need to rollback multiple times. 270 func (vm *VM) Snapshot() *state.StateDB { return vm.db.Copy() } 271 272 // Rollback the state of the VM to the given snapshot. 273 func (vm *VM) Rollback(snapshot *state.StateDB) { 274 vm.db = snapshot 275 vm.txIndex = uint64(snapshot.TxIndex()) + 1 276 } 277 278 func (v *VM) buildMessage(msg *w3types.Message, skipAccChecks bool) (*core.Message, error) { 279 nonce := msg.Nonce 280 if !skipAccChecks && nonce == 0 { 281 var err error 282 nonce, err = v.Nonce(msg.From) 283 if err != nil { 284 return nil, err 285 } 286 } 287 288 gasLimit := msg.Gas 289 if maxGasLimit := v.opts.blockCtx.GasLimit; gasLimit == 0 { 290 gasLimit = maxGasLimit 291 } else if gasLimit > maxGasLimit { 292 gasLimit = maxGasLimit 293 } 294 if gasLimit == 0 { 295 gasLimit = 15_000_000 296 } 297 298 var input []byte 299 if msg.Input == nil && msg.Func != nil { 300 var err error 301 input, err = msg.Func.EncodeArgs(msg.Args...) 302 if err != nil { 303 return nil, err 304 } 305 } else { 306 input = msg.Input 307 } 308 309 var gasPrice, gasFeeCap, gasTipCap *big.Int 310 if baseFee := v.opts.blockCtx.BaseFee; baseFee == nil { 311 gasPrice = new(big.Int).Set(cmp.Or(msg.GasPrice, w3.Big0)) 312 gasFeeCap, gasTipCap = gasPrice, gasPrice 313 } else { 314 if msg.GasPrice != nil && msg.GasFeeCap == nil && msg.GasTipCap == nil { 315 gasPrice = msg.GasPrice 316 gasFeeCap, gasTipCap = gasPrice, gasPrice 317 } else { 318 gasFeeCap = new(big.Int).Set(cmp.Or(msg.GasFeeCap, w3.Big0)) 319 gasTipCap = new(big.Int).Set(cmp.Or(msg.GasTipCap, w3.Big0)) 320 gasPrice = new(big.Int).Add(baseFee, gasTipCap) 321 if gasPrice.Cmp(gasFeeCap) > 0 { 322 gasPrice = gasFeeCap 323 } 324 } 325 } 326 327 if v.opts.noBaseFee { 328 gasFeeCap.SetInt64(0) 329 gasTipCap.SetInt64(0) 330 } 331 332 value := new(big.Int).Set(cmp.Or(msg.Value, w3.Big0)) 333 334 return &core.Message{ 335 To: msg.To, 336 From: msg.From, 337 Nonce: nonce, 338 Value: value, 339 GasLimit: gasLimit, 340 GasPrice: gasPrice, 341 GasFeeCap: gasFeeCap, 342 GasTipCap: gasTipCap, 343 Data: input, 344 AccessList: msg.AccessList, 345 BlobGasFeeCap: msg.BlobGasFeeCap, 346 BlobHashes: msg.BlobHashes, 347 SetCodeAuthorizations: msg.SetCodeAuthorizations, 348 SkipNonceChecks: skipAccChecks, 349 SkipFromEOACheck: skipAccChecks, 350 }, nil 351 } 352 353 func newBlockContext(config *params.ChainConfig, h *types.Header, getHash vm.GetHashFunc) *vm.BlockContext { 354 var random *common.Hash 355 if h.Difficulty == nil || h.Difficulty.Sign() == 0 { 356 random = &h.MixDigest 357 } 358 359 blockNumber := h.Number 360 if blockNumber == nil { 361 blockNumber = new(big.Int) 362 } 363 difficulty := h.Difficulty 364 if difficulty == nil { 365 difficulty = new(big.Int) 366 } 367 baseFee := h.BaseFee 368 if baseFee == nil { 369 baseFee = new(big.Int) 370 } 371 var blobBaseFee *big.Int 372 if h.ExcessBlobGas != nil { 373 blobBaseFee = eip4844.CalcBlobFee(config, h) 374 } 375 376 return &vm.BlockContext{ 377 CanTransfer: core.CanTransfer, 378 Transfer: core.Transfer, 379 GetHash: getHash, 380 Coinbase: h.Coinbase, 381 BlockNumber: blockNumber, 382 Time: h.Time, 383 Difficulty: difficulty, 384 BaseFee: baseFee, 385 BlobBaseFee: blobBaseFee, 386 GasLimit: h.GasLimit, 387 Random: random, 388 } 389 } 390 391 func defaultBlockContext() *vm.BlockContext { 392 var coinbase common.Address 393 rand.Read(coinbase[:]) 394 395 var random common.Hash 396 rand.Read(random[:]) 397 398 return &vm.BlockContext{ 399 CanTransfer: core.CanTransfer, 400 Transfer: core.Transfer, 401 GetHash: zeroHashFunc, 402 Coinbase: coinbase, 403 BlockNumber: new(big.Int), 404 Time: uint64(time.Now().Unix()), 405 Difficulty: new(big.Int), 406 BaseFee: new(big.Int), 407 GasLimit: params.MaxGasLimit, 408 Random: &random, 409 } 410 } 411 412 //////////////////////////////////////////////////////////////////////////////////////////////////// 413 // VM Option /////////////////////////////////////////////////////////////////////////////////////// 414 //////////////////////////////////////////////////////////////////////////////////////////////////// 415 416 type options struct { 417 chainConfig *params.ChainConfig 418 preState w3types.State 419 noBaseFee bool 420 421 blockCtx *vm.BlockContext 422 header *types.Header 423 424 forkClient *w3.Client 425 forkBlockNumber *big.Int 426 fetcher Fetcher 427 tb testing.TB 428 429 precompiles vm.PrecompiledContracts 430 } 431 432 func (opt *options) Signer() types.Signer { 433 if opt.fetcher == nil { 434 return types.LatestSigner(opt.chainConfig) 435 } 436 return types.MakeSigner(opt.chainConfig, opt.header.Number, opt.header.Time) 437 } 438 439 func (opts *options) Init() error { 440 // set initial chain config 441 isChainConfigSet := opts.chainConfig != nil 442 if !isChainConfigSet { 443 opts.chainConfig = params.MergedTestChainConfig 444 } 445 446 // set fetcher 447 if opts.fetcher == nil && opts.forkClient != nil { 448 var calls []w3types.RPCCaller 449 450 latest := opts.forkBlockNumber == nil 451 if latest { 452 calls = append(calls, eth.BlockNumber().Returns(&opts.forkBlockNumber)) 453 } 454 if opts.header == nil && opts.blockCtx == nil { 455 if latest { 456 calls = append(calls, eth.HeaderByNumber(pendingBlockNumber).Returns(&opts.header)) 457 } else { 458 calls = append(calls, eth.HeaderByNumber(opts.forkBlockNumber).Returns(&opts.header)) 459 } 460 } 461 462 if err := opts.forkClient.Call(calls...); err != nil { 463 return fmt.Errorf("%w: failed to fetch header: %v", ErrFetch, err) 464 } 465 466 if latest { 467 opts.fetcher = NewRPCFetcher(opts.forkClient, opts.forkBlockNumber) 468 } else if opts.tb == nil { 469 opts.fetcher = NewRPCFetcher(opts.forkClient, new(big.Int).Sub(opts.forkBlockNumber, w3.Big1)) 470 } else { 471 opts.fetcher = NewTestingRPCFetcher(opts.tb, opts.chainConfig.ChainID.Uint64(), opts.forkClient, new(big.Int).Sub(opts.forkBlockNumber, w3.Big1)) 472 } 473 } 474 475 // potentially update chain config 476 if !isChainConfigSet && opts.fetcher != nil { 477 opts.chainConfig = params.MainnetChainConfig 478 } 479 480 if opts.blockCtx == nil { 481 if opts.header != nil { 482 opts.blockCtx = newBlockContext(opts.chainConfig, opts.header, fetcherHashFunc(opts.fetcher)) 483 } else { 484 opts.blockCtx = defaultBlockContext() 485 } 486 } 487 488 // set precompiles 489 if len(opts.precompiles) > 0 { 490 rules := opts.chainConfig.Rules(opts.blockCtx.BlockNumber, opts.blockCtx.Random != nil, opts.blockCtx.Time) 491 492 // overwrite default precompiles 493 precompiles := vm.ActivePrecompiledContracts(rules) 494 for addr, contract := range opts.precompiles { 495 precompiles[addr] = contract 496 } 497 opts.precompiles = precompiles 498 } 499 500 return nil 501 } 502 503 func fetcherHashFunc(fetcher Fetcher) vm.GetHashFunc { 504 return func(blockNumber uint64) common.Hash { 505 hash, _ := fetcher.HeaderHash(blockNumber) 506 return hash 507 } 508 } 509 510 // An Option configures a [VM]. 511 type Option func(*VM) 512 513 // WithChainConfig sets the chain config for the VM. 514 // 515 // If not provided, the chain config defaults to [params.MainnetChainConfig]. 516 func WithChainConfig(cfg *params.ChainConfig) Option { 517 return func(vm *VM) { vm.opts.chainConfig = cfg } 518 } 519 520 // WithBlockContext sets the block context for the VM. 521 func WithBlockContext(ctx *vm.BlockContext) Option { 522 return func(vm *VM) { vm.opts.blockCtx = ctx } 523 } 524 525 // WithPrecompile registers a precompile contract at the given address in the VM. 526 func WithPrecompile(addr common.Address, contract vm.PrecompiledContract) Option { 527 return func(v *VM) { 528 if v.opts.precompiles == nil { 529 v.opts.precompiles = make(vm.PrecompiledContracts) 530 } 531 v.opts.precompiles[addr] = contract 532 } 533 } 534 535 // WithState sets the pre state of the VM. 536 // 537 // WithState can be used together with [WithFork] to only set the state of some 538 // accounts, or partially overwrite the storage of an account. 539 func WithState(state w3types.State) Option { 540 return func(vm *VM) { vm.opts.preState = state } 541 } 542 543 // WithStateDB sets the state DB for the VM, that is usually a snapshot 544 // obtained from [VM.Snapshot]. 545 func WithStateDB(db *state.StateDB) Option { 546 return func(vm *VM) { 547 vm.db = db 548 vm.txIndex = uint64(db.TxIndex() + 1) 549 } 550 } 551 552 // WithNoBaseFee forces the EIP-1559 base fee to 0 for the VM. 553 func WithNoBaseFee() Option { 554 return func(vm *VM) { vm.opts.noBaseFee = true } 555 } 556 557 // WithFork sets the client and block number to fetch state from and sets the 558 // block context for the VM. If the block number is nil, the latest state is 559 // fetched and the pending block is used for constructing the block context. 560 // 561 // If used together with [WithTB], fetched state is stored in the testdata 562 // directory of the tests package. 563 func WithFork(client *w3.Client, blockNumber *big.Int) Option { 564 return func(vm *VM) { 565 vm.opts.forkClient = client 566 vm.opts.forkBlockNumber = blockNumber 567 } 568 } 569 570 // WithHeader sets the block context for the VM based on the given header. 571 func WithHeader(header *types.Header) Option { 572 return func(vm *VM) { vm.opts.header = header } 573 } 574 575 // WithFetcher sets the fetcher for the VM. 576 func WithFetcher(fetcher Fetcher) Option { 577 return func(vm *VM) { vm.opts.fetcher = fetcher } 578 } 579 580 // WithTB enables persistent state caching when used together with [WithFork]. 581 // State is stored in the testdata directory of the tests package. 582 func WithTB(tb testing.TB) Option { 583 return func(vm *VM) { vm.opts.tb = tb } 584 }