github.com/lmittmann/w3@v0.20.0/w3vm/util.go (about)

     1  package w3vm
     2  
     3  import (
     4  	"crypto/rand"
     5  	"math/big"
     6  
     7  	"github.com/ethereum/go-ethereum/common"
     8  	"github.com/ethereum/go-ethereum/common/hexutil"
     9  	"github.com/ethereum/go-ethereum/core/tracing"
    10  	"github.com/ethereum/go-ethereum/core/types"
    11  	"github.com/holiman/uint256"
    12  	"github.com/lmittmann/w3"
    13  	"github.com/lmittmann/w3/internal/crypto"
    14  	"github.com/lmittmann/w3/internal/module"
    15  	"github.com/lmittmann/w3/w3types"
    16  )
    17  
    18  // RandA returns a random address.
    19  func RandA() (addr common.Address) {
    20  	rand.Read(addr[:])
    21  	return addr
    22  }
    23  
    24  var (
    25  	weth9BalancePos   = common.BytesToHash([]byte{3})
    26  	weth9AllowancePos = common.BytesToHash([]byte{4})
    27  )
    28  
    29  // WETHBalanceSlot returns the storage slot that stores the WETH balance of
    30  // the given addr.
    31  func WETHBalanceSlot(addr common.Address) common.Hash {
    32  	return SoliditySlot(weth9BalancePos, common.BytesToHash(addr[:]))
    33  }
    34  
    35  // WETHAllowanceSlot returns the storage slot that stores the WETH allowance
    36  // of the given owner to the spender.
    37  func WETHAllowanceSlot(owner, spender common.Address) common.Hash {
    38  	return SoliditySlot2(weth9AllowancePos, common.BytesToHash(owner[:]), common.BytesToHash(spender[:]))
    39  }
    40  
    41  // SoliditySlot returns the storage slot of a mapping with the given position and key.
    42  //
    43  //	mapping(bytes32 => bytes32)
    44  func SoliditySlot(pos, key common.Hash) common.Hash {
    45  	return crypto.Keccak256Hash(key[:], pos[:])
    46  }
    47  
    48  // SoliditySlot2 returns the storage slot of a double mapping with the given position
    49  // and keys.
    50  //
    51  //	mapping(bytes32 => mapping(bytes32 => bytes32))
    52  func SoliditySlot2(pos, key0, key1 common.Hash) common.Hash {
    53  	return crypto.Keccak256Hash(
    54  		key1[:],
    55  		crypto.Keccak256(key0[:], pos[:]),
    56  	)
    57  }
    58  
    59  // SoliditySlot3 returns the storage slot of a triple mapping with the given position
    60  // and keys.
    61  //
    62  //	mapping(bytes32 => mapping(bytes32 => mapping(bytes32 => bytes32)))
    63  func SoliditySlot3(pos, key0, key1, key2 common.Hash) common.Hash {
    64  	return crypto.Keccak256Hash(
    65  		key2[:],
    66  		crypto.Keccak256(
    67  			key1[:],
    68  			crypto.Keccak256(key0[:], pos[:]),
    69  		),
    70  	)
    71  }
    72  
    73  // VyperSlot returns the storage slot of a mapping with the given position and key.
    74  //
    75  //	HashMap[bytes32, bytes32]
    76  func VyperSlot(pos, key common.Hash) common.Hash {
    77  	return crypto.Keccak256Hash(pos[:], key[:])
    78  }
    79  
    80  // VyperSlot2 returns the storage slot of a double mapping with the given position
    81  // and keys.
    82  //
    83  //	HashMap[bytes32, HashMap[bytes32, bytes32]]
    84  func VyperSlot2(pos, key0, key1 common.Hash) common.Hash {
    85  	return crypto.Keccak256Hash(
    86  		crypto.Keccak256(pos[:], key0[:]),
    87  		key1[:],
    88  	)
    89  }
    90  
    91  // VyperSlot3 returns the storage slot of a triple mapping with the given position
    92  // and keys.
    93  //
    94  //	HashMap[bytes32, HashMap[bytes32, HashMap[bytes32, bytes32]]]
    95  func VyperSlot3(pos, key0, key1, key2 common.Hash) common.Hash {
    96  	return crypto.Keccak256Hash(
    97  		crypto.Keccak256(
    98  			crypto.Keccak256(pos[:], key0[:]),
    99  			key1[:],
   100  		),
   101  		key2[:],
   102  	)
   103  }
   104  
   105  // Slot returns the storage slot of a mapping with the given position and key.
   106  //
   107  // Slot follows the Solidity storage layout for:
   108  //
   109  //	mapping(bytes32 => bytes32)
   110  //
   111  // Deprecated: Use SoliditySlot instead.
   112  func Slot(pos, key common.Hash) common.Hash {
   113  	return SoliditySlot(pos, key)
   114  }
   115  
   116  // Slot2 returns the storage slot of a double mapping with the given position
   117  // and keys.
   118  //
   119  // Slot2 follows the Solidity storage layout for:
   120  //
   121  //	mapping(bytes32 => mapping(bytes32 => bytes32))
   122  //
   123  // Deprecated: Use SoliditySlot2 instead.
   124  func Slot2(pos, key0, key1 common.Hash) common.Hash {
   125  	return SoliditySlot2(pos, key0, key1)
   126  }
   127  
   128  // Slot3 returns the storage slot of a triple mapping with the given position
   129  // and keys.
   130  //
   131  // Slot3 follows the Solidity storage layout for:
   132  //
   133  //	mapping(bytes32 => mapping(bytes32 => mapping(bytes32 => bytes32)))
   134  //
   135  // Deprecated: Use SoliditySlot3 instead.
   136  func Slot3(pos, key0, key1, key2 common.Hash) common.Hash {
   137  	return SoliditySlot3(pos, key0, key1, key2)
   138  }
   139  
   140  // zeroHashFunc implements a [vm.GetHashFunc] that always returns the zero hash.
   141  func zeroHashFunc(uint64) common.Hash {
   142  	return w3.Hash0
   143  }
   144  
   145  ////////////////////////////////////////////////////////////////////////////////////////////////////
   146  // w3types.RPCCaller's /////////////////////////////////////////////////////////////////////////////
   147  ////////////////////////////////////////////////////////////////////////////////////////////////////
   148  
   149  // ethBalance is like [eth.Balance], but returns the balance as [uint256.Int].
   150  func ethBalance(addr common.Address, blockNumber *big.Int) w3types.RPCCallerFactory[uint256.Int] {
   151  	return module.NewFactory(
   152  		"eth_getBalance",
   153  		[]any{addr, module.BlockNumberArg(blockNumber)},
   154  		module.WithRetWrapper(func(ret *uint256.Int) any { return (*hexutil.U256)(ret) }),
   155  	)
   156  }
   157  
   158  // ethHeaderHash is like [eth.Header], but only parses the header hash.
   159  func ethHeaderHash(blockNumber uint64) w3types.RPCCallerFactory[header] {
   160  	return module.NewFactory[header](
   161  		"eth_getBlockByNumber",
   162  		[]any{hexutil.Uint64(blockNumber), false},
   163  	)
   164  }
   165  
   166  type header struct {
   167  	Hash common.Hash `json:"hash"`
   168  }
   169  
   170  ////////////////////////////////////////////////////////////////////////////////////////////////////
   171  // tracing.Hook's //////////////////////////////////////////////////////////////////////////////////
   172  ////////////////////////////////////////////////////////////////////////////////////////////////////
   173  
   174  // joinHooks joins multiple hooks into one.
   175  func joinHooks(hooks []*tracing.Hooks) *tracing.Hooks {
   176  	// hot path
   177  	switch len(hooks) {
   178  	case 0:
   179  		return nil
   180  	case 1:
   181  		return hooks[0]
   182  	}
   183  
   184  	// vm hooks
   185  	var onEnters []tracing.EnterHook
   186  	var onExits []tracing.ExitHook
   187  	var onOpcodes []tracing.OpcodeHook
   188  	var onFaults []tracing.FaultHook
   189  	var onGasChanges []tracing.GasChangeHook
   190  	// state hooks
   191  	var onBalanceChanges []tracing.BalanceChangeHook
   192  	var onNonceChanges []tracing.NonceChangeHook
   193  	var onCodeChanges []tracing.CodeChangeHook
   194  	var onStorageChanges []tracing.StorageChangeHook
   195  	var onLogs []tracing.LogHook
   196  
   197  	for _, h := range hooks {
   198  		if h == nil {
   199  			continue
   200  		}
   201  		// vm hooks
   202  		if h.OnEnter != nil {
   203  			onEnters = append(onEnters, h.OnEnter)
   204  		}
   205  		if h.OnExit != nil {
   206  			onExits = append(onExits, h.OnExit)
   207  		}
   208  		if h.OnOpcode != nil {
   209  			onOpcodes = append(onOpcodes, h.OnOpcode)
   210  		}
   211  		if h.OnFault != nil {
   212  			onFaults = append(onFaults, h.OnFault)
   213  		}
   214  		if h.OnGasChange != nil {
   215  			onGasChanges = append(onGasChanges, h.OnGasChange)
   216  		}
   217  		// state hooks
   218  		if h.OnBalanceChange != nil {
   219  			onBalanceChanges = append(onBalanceChanges, h.OnBalanceChange)
   220  		}
   221  		if h.OnNonceChange != nil {
   222  			onNonceChanges = append(onNonceChanges, h.OnNonceChange)
   223  		}
   224  		if h.OnCodeChange != nil {
   225  			onCodeChanges = append(onCodeChanges, h.OnCodeChange)
   226  		}
   227  		if h.OnStorageChange != nil {
   228  			onStorageChanges = append(onStorageChanges, h.OnStorageChange)
   229  		}
   230  		if h.OnLog != nil {
   231  			onLogs = append(onLogs, h.OnLog)
   232  		}
   233  	}
   234  
   235  	hook := new(tracing.Hooks)
   236  	// vm hooks
   237  	if len(onEnters) > 0 {
   238  		hook.OnEnter = func(depth int, typ byte, from, to common.Address, input []byte, gas uint64, value *big.Int) {
   239  			for _, h := range onEnters {
   240  				h(depth, typ, from, to, input, gas, value)
   241  			}
   242  		}
   243  	}
   244  	if len(onExits) > 0 {
   245  		hook.OnExit = func(depth int, output []byte, gasUsed uint64, err error, reverted bool) {
   246  			for _, h := range onExits {
   247  				h(depth, output, gasUsed, err, reverted)
   248  			}
   249  		}
   250  	}
   251  	if len(onOpcodes) > 0 {
   252  		hook.OnOpcode = func(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, rData []byte, depth int, err error) {
   253  			for _, h := range onOpcodes {
   254  				h(pc, op, gas, cost, scope, rData, depth, err)
   255  			}
   256  		}
   257  	}
   258  	if len(onFaults) > 0 {
   259  		hook.OnFault = func(pc uint64, op byte, gas, cost uint64, scope tracing.OpContext, depth int, err error) {
   260  			for _, h := range onFaults {
   261  				h(pc, op, gas, cost, scope, depth, err)
   262  			}
   263  		}
   264  	}
   265  	if len(onGasChanges) > 0 {
   266  		hook.OnGasChange = func(old, new uint64, reason tracing.GasChangeReason) {
   267  			for _, h := range onGasChanges {
   268  				h(old, new, reason)
   269  			}
   270  		}
   271  	}
   272  	// state hooks
   273  	if len(onBalanceChanges) > 0 {
   274  		hook.OnBalanceChange = func(addr common.Address, prev, new *big.Int, reason tracing.BalanceChangeReason) {
   275  			for _, h := range onBalanceChanges {
   276  				h(addr, prev, new, reason)
   277  			}
   278  		}
   279  	}
   280  	if len(onNonceChanges) > 0 {
   281  		hook.OnNonceChange = func(addr common.Address, prev, new uint64) {
   282  			for _, h := range onNonceChanges {
   283  				h(addr, prev, new)
   284  			}
   285  		}
   286  	}
   287  	if len(onCodeChanges) > 0 {
   288  		hook.OnCodeChange = func(addr common.Address, prevCodeHash common.Hash, prevCode []byte, codeHash common.Hash, code []byte) {
   289  			for _, h := range onCodeChanges {
   290  				h(addr, prevCodeHash, prevCode, codeHash, code)
   291  			}
   292  		}
   293  	}
   294  	if len(onStorageChanges) > 0 {
   295  		hook.OnStorageChange = func(addr common.Address, slot, prev, new common.Hash) {
   296  			for _, h := range onStorageChanges {
   297  				h(addr, slot, prev, new)
   298  			}
   299  		}
   300  	}
   301  	if len(onLogs) > 0 {
   302  		hook.OnLog = func(log *types.Log) {
   303  			for _, h := range onLogs {
   304  				h(log)
   305  			}
   306  		}
   307  	}
   308  	return hook
   309  }