github.com/lmittmann/w3@v0.20.0/w3vm/fetcher.go (about)

     1  package w3vm
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"math/big"
     9  	"os"
    10  	"path/filepath"
    11  	"sync"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/ethereum/go-ethereum/common"
    17  	"github.com/ethereum/go-ethereum/common/hexutil"
    18  	"github.com/ethereum/go-ethereum/core/types"
    19  	"github.com/gofrs/flock"
    20  	"github.com/holiman/uint256"
    21  	"github.com/lmittmann/w3"
    22  	"github.com/lmittmann/w3/internal/crypto"
    23  	w3hexutil "github.com/lmittmann/w3/internal/hexutil"
    24  	"github.com/lmittmann/w3/internal/mod"
    25  	"github.com/lmittmann/w3/module/eth"
    26  	"github.com/lmittmann/w3/w3types"
    27  )
    28  
    29  // Fetcher is the interface to access account state of a blockchain.
    30  type Fetcher interface {
    31  	// Account fetches the account of the given address.
    32  	Account(common.Address) (*types.StateAccount, error)
    33  
    34  	// Code fetches the code of the given code hash.
    35  	Code(common.Hash) ([]byte, error)
    36  
    37  	// StorageAt fetches the state of the given address and storage slot.
    38  	StorageAt(common.Address, common.Hash) (common.Hash, error)
    39  
    40  	// HeaderHash fetches the hash of the header with the given number.
    41  	HeaderHash(uint64) (common.Hash, error)
    42  }
    43  
    44  type rpcFetcher struct {
    45  	client      *w3.Client
    46  	blockNumber *big.Int
    47  
    48  	mux          sync.RWMutex
    49  	accounts     map[common.Address]func() (*types.StateAccount, error)
    50  	contracts    map[common.Hash]func() ([]byte, error)
    51  	mux2         sync.RWMutex
    52  	storage      map[storageKey]func() (common.Hash, error)
    53  	mux3         sync.RWMutex
    54  	headerHashes map[uint64]func() (common.Hash, error)
    55  
    56  	dirty uint32 // indicates whether new state has been fetched (0=false, 1=true)
    57  
    58  	// file modification times for testdata files
    59  	stateFileModTime        time.Time
    60  	contractsFileModTime    time.Time
    61  	headerHashesFileModTime time.Time
    62  }
    63  
    64  // NewRPCFetcher returns a new [Fetcher] that fetches account state from the given
    65  // RPC client for the given block number.
    66  //
    67  // Note, that the returned state for a given block number is the state after the
    68  // execution of that block.
    69  func NewRPCFetcher(client *w3.Client, blockNumber *big.Int) Fetcher {
    70  	return newRPCFetcher(client, blockNumber)
    71  }
    72  
    73  func newRPCFetcher(client *w3.Client, blockNumber *big.Int) *rpcFetcher {
    74  	return &rpcFetcher{
    75  		client:       client,
    76  		blockNumber:  blockNumber,
    77  		accounts:     make(map[common.Address]func() (*types.StateAccount, error)),
    78  		contracts:    make(map[common.Hash]func() ([]byte, error)),
    79  		storage:      make(map[storageKey]func() (common.Hash, error)),
    80  		headerHashes: make(map[uint64]func() (common.Hash, error)),
    81  	}
    82  }
    83  
    84  func (f *rpcFetcher) Account(addr common.Address) (a *types.StateAccount, e error) {
    85  	f.mux.RLock()
    86  	acc, ok := f.accounts[addr]
    87  	f.mux.RUnlock()
    88  	if ok {
    89  		return acc()
    90  	}
    91  	atomic.StoreUint32(&f.dirty, 1)
    92  
    93  	var (
    94  		accNew      = &types.StateAccount{Balance: new(uint256.Int)}
    95  		contractNew []byte
    96  
    97  		accCh      = make(chan func() (*types.StateAccount, error), 1)
    98  		contractCh = make(chan func() ([]byte, error), 1)
    99  	)
   100  	go func() {
   101  		err := f.call(
   102  			eth.Nonce(addr, f.blockNumber).Returns(&accNew.Nonce),
   103  			ethBalance(addr, f.blockNumber).Returns(accNew.Balance),
   104  			eth.Code(addr, f.blockNumber).Returns(&contractNew),
   105  		)
   106  		if err != nil {
   107  			accCh <- func() (*types.StateAccount, error) { return nil, err }
   108  			contractCh <- func() ([]byte, error) { return nil, err }
   109  			return
   110  		}
   111  
   112  		if len(contractNew) == 0 {
   113  			accNew.CodeHash = types.EmptyCodeHash[:]
   114  		} else {
   115  			accNew.CodeHash = crypto.Keccak256(contractNew)
   116  		}
   117  		accCh <- func() (*types.StateAccount, error) { return accNew, nil }
   118  		contractCh <- func() ([]byte, error) { return contractNew, nil }
   119  	}()
   120  
   121  	f.mux.Lock()
   122  	defer f.mux.Unlock()
   123  	accOnce := sync.OnceValues(<-accCh)
   124  	f.accounts[addr] = accOnce
   125  	accRet, err := accOnce()
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	f.contracts[common.BytesToHash(accRet.CodeHash)] = sync.OnceValues(<-contractCh)
   130  	return accRet, nil
   131  }
   132  
   133  func (f *rpcFetcher) Code(codeHash common.Hash) ([]byte, error) {
   134  	f.mux.RLock()
   135  	contract, ok := f.contracts[codeHash]
   136  	f.mux.RUnlock()
   137  	if !ok {
   138  		panic("not implemented")
   139  	}
   140  	return contract()
   141  }
   142  
   143  func (f *rpcFetcher) StorageAt(addr common.Address, slot common.Hash) (common.Hash, error) {
   144  	key := storageKey{addr, slot}
   145  
   146  	f.mux2.RLock()
   147  	storage, ok := f.storage[key]
   148  	f.mux2.RUnlock()
   149  	if ok {
   150  		return storage()
   151  	}
   152  	atomic.StoreUint32(&f.dirty, 1)
   153  
   154  	var (
   155  		storageVal   common.Hash
   156  		storageValCh = make(chan func() (common.Hash, error), 1)
   157  	)
   158  	go func() {
   159  		err := f.call(eth.StorageAt(addr, slot, f.blockNumber).Returns(&storageVal))
   160  		storageValCh <- func() (common.Hash, error) { return storageVal, err }
   161  	}()
   162  
   163  	storageValOnce := sync.OnceValues(<-storageValCh)
   164  	f.mux2.Lock()
   165  	f.storage[key] = storageValOnce
   166  	f.mux2.Unlock()
   167  	return storageValOnce()
   168  }
   169  
   170  func (f *rpcFetcher) HeaderHash(blockNumber uint64) (common.Hash, error) {
   171  	f.mux3.RLock()
   172  	hash, ok := f.headerHashes[blockNumber]
   173  	f.mux3.RUnlock()
   174  	if ok {
   175  		return hash()
   176  	}
   177  	atomic.StoreUint32(&f.dirty, 1)
   178  
   179  	var (
   180  		header       header
   181  		headerHashCh = make(chan func() (common.Hash, error), 1)
   182  	)
   183  	go func() {
   184  		err := f.call(ethHeaderHash(blockNumber).Returns(&header))
   185  		headerHashCh <- func() (common.Hash, error) { return header.Hash, err }
   186  	}()
   187  
   188  	headerHashOnce := sync.OnceValues(<-headerHashCh)
   189  	f.mux3.Lock()
   190  	f.headerHashes[blockNumber] = headerHashOnce
   191  	f.mux3.Unlock()
   192  	return headerHashOnce()
   193  }
   194  
   195  func (f *rpcFetcher) call(calls ...w3types.RPCCaller) error {
   196  	return f.client.Call(calls...)
   197  }
   198  
   199  ////////////////////////////////////////////////////////////////////////////////////////////////////
   200  // TestingRPCFetcher ///////////////////////////////////////////////////////////////////////////////
   201  ////////////////////////////////////////////////////////////////////////////////////////////////////
   202  
   203  // NewTestingRPCFetcher returns a new [Fetcher] like [NewRPCFetcher], but caches
   204  // the fetched state on disk in the testdata directory of the tests package.
   205  func NewTestingRPCFetcher(tb testing.TB, chainID uint64, client *w3.Client, blockNumber *big.Int) Fetcher {
   206  	if mod.Root == "" {
   207  		panic("w3vm: NewTestingRPCFetcher must be used in a module test")
   208  	}
   209  
   210  	fetcher := newRPCFetcher(client, blockNumber)
   211  	if err := fetcher.loadTestdataState(chainID); err != nil {
   212  		tb.Fatalf("w3vm: failed to load state from testdata: %v", err)
   213  	}
   214  
   215  	tb.Cleanup(func() {
   216  		if err := fetcher.storeTestdataState(chainID); err != nil {
   217  			tb.Fatalf("w3vm: failed to write state to testdata: %v", err)
   218  		}
   219  	})
   220  	return fetcher
   221  }
   222  
   223  var (
   224  	testdataMutex sync.RWMutex                      // in-process synchronization
   225  	testdataLock  = flock.New(testdataPath("LOCK")) // inter-process synchronization
   226  )
   227  
   228  func (f *rpcFetcher) loadTestdataState(chainID uint64) (err error) {
   229  	// lock testdata files
   230  	testdataMutex.RLock()
   231  	defer testdataMutex.RUnlock()
   232  	testdataLock.RLock()
   233  	defer testdataLock.Unlock()
   234  
   235  	// read testdata files
   236  	stateFn := fmt.Sprintf("%d_%v.json", chainID, f.blockNumber)
   237  	var state testdataState
   238  	if f.stateFileModTime, err = readTestdata(stateFn, &state, time.Time{}); err != nil {
   239  		return err
   240  	}
   241  
   242  	var contracts testdataContracts
   243  	if f.contractsFileModTime, err = readTestdata("contracts.json", &contracts, time.Time{}); err != nil {
   244  		return err
   245  	}
   246  
   247  	headerHashesFn := fmt.Sprintf("%d_header_hashes.json", chainID)
   248  	var headerHashes testdataHeaderHashes
   249  	if f.headerHashesFileModTime, err = readTestdata(headerHashesFn, &headerHashes, time.Time{}); err != nil {
   250  		return err
   251  	}
   252  
   253  	// build fetcher state
   254  	f.mux.Lock()
   255  	f.mux2.Lock()
   256  	f.mux3.Lock()
   257  	defer f.mux.Unlock()
   258  	defer f.mux2.Unlock()
   259  	defer f.mux3.Unlock()
   260  
   261  	for addr, acc := range state {
   262  		codeHash := acc.codeHash()
   263  
   264  		f.accounts[addr] = func() (*types.StateAccount, error) {
   265  			return &types.StateAccount{
   266  				Nonce:    uint64(acc.Nonce),
   267  				Balance:  (*uint256.Int)(acc.Balance),
   268  				CodeHash: codeHash[:],
   269  			}, nil
   270  		}
   271  		if _, ok := f.contracts[codeHash]; codeHash != types.EmptyCodeHash && !ok {
   272  			f.contracts[codeHash] = func() ([]byte, error) {
   273  				return contracts[codeHash], nil
   274  			}
   275  		}
   276  		for slot, val := range acc.Storage {
   277  			f.storage[storageKey{addr, (common.Hash)(slot)}] = func() (common.Hash, error) {
   278  				return (common.Hash)(val), nil
   279  			}
   280  		}
   281  		for blockNumber, hash := range headerHashes {
   282  			f.headerHashes[uint64(blockNumber)] = func() (common.Hash, error) {
   283  				return hash, nil
   284  			}
   285  		}
   286  	}
   287  	return nil
   288  }
   289  
   290  func (f *rpcFetcher) storeTestdataState(chainID uint64) (err error) {
   291  	if atomic.LoadUint32(&f.dirty) == 0 {
   292  		return nil // if no new state was fetched, we do not need to store it
   293  	}
   294  
   295  	// read fetcher state
   296  	f.mux.RLock()
   297  	f.mux2.RLock()
   298  	f.mux3.RLock()
   299  	defer f.mux.RUnlock()
   300  	defer f.mux2.RUnlock()
   301  	defer f.mux3.RUnlock()
   302  
   303  	var (
   304  		state        = make(testdataState)
   305  		contracts    = make(testdataContracts)
   306  		headerHashes = make(testdataHeaderHashes)
   307  	)
   308  	for addr, accFunc := range f.accounts {
   309  		acc, err := accFunc()
   310  		if err != nil {
   311  			continue
   312  		}
   313  
   314  		state[addr] = &testdataAccount{
   315  			Nonce:   hexutil.Uint64(acc.Nonce),
   316  			Balance: (*hexutil.U256)(acc.Balance),
   317  		}
   318  		if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash[:]) {
   319  			codeHash := common.BytesToHash(acc.CodeHash)
   320  			state[addr].CodeHash = codeHash
   321  			contracts[codeHash], _ = f.contracts[codeHash]()
   322  		}
   323  	}
   324  
   325  	for storageKey, storageValFunc := range f.storage {
   326  		storageVal, err := storageValFunc()
   327  		if err != nil {
   328  			continue
   329  		}
   330  
   331  		if _, ok := state[storageKey.addr]; !ok {
   332  			state[storageKey.addr] = &testdataAccount{
   333  				Storage: make(map[w3hexutil.Hash]w3hexutil.Hash),
   334  			}
   335  		} else if state[storageKey.addr].Storage == nil {
   336  			state[storageKey.addr].Storage = make(map[w3hexutil.Hash]w3hexutil.Hash)
   337  		}
   338  		state[storageKey.addr].Storage[w3hexutil.Hash(storageKey.slot)] = w3hexutil.Hash(storageVal)
   339  	}
   340  
   341  	for blockNumber, hashFunc := range f.headerHashes {
   342  		hash, err := hashFunc()
   343  		if err != nil {
   344  			continue
   345  		}
   346  		headerHashes[hexutil.Uint64(blockNumber)] = hash
   347  	}
   348  
   349  	// lock testdata files
   350  	testdataMutex.Lock()
   351  	defer testdataMutex.Unlock()
   352  	testdataLock.Lock()
   353  	defer testdataLock.Unlock()
   354  
   355  	// load current testdata state
   356  	stateFn := fmt.Sprintf("%d_%v.json", chainID, f.blockNumber)
   357  	var otherState testdataState
   358  	if _, err = readTestdata(stateFn, &otherState, f.stateFileModTime); err != nil {
   359  		return err
   360  	}
   361  
   362  	var otherContracts testdataContracts
   363  	if _, err = readTestdata("contracts.json", &otherContracts, f.contractsFileModTime); err != nil {
   364  		return err
   365  	}
   366  
   367  	headerHashesFn := fmt.Sprintf("%d_header_hashes.json", chainID)
   368  	var otherHeaderHashes testdataHeaderHashes
   369  	if _, err = readTestdata(headerHashesFn, &otherHeaderHashes, f.headerHashesFileModTime); err != nil {
   370  		return err
   371  	}
   372  
   373  	// merge
   374  	if err := state.Merge(otherState); err != nil {
   375  		return fmt.Errorf("failed to merge testdata state: %w", err)
   376  	}
   377  
   378  	if err := contracts.Merge(otherContracts); err != nil {
   379  		return fmt.Errorf("failed to merge testdata contracts: %w", err)
   380  	}
   381  
   382  	if err := headerHashes.Merge(otherHeaderHashes); err != nil {
   383  		return fmt.Errorf("failed to merge testdata header hashes: %w", err)
   384  	}
   385  
   386  	// write testdata files
   387  	if err := writeTestdata(stateFn, state); err != nil {
   388  		return err
   389  	}
   390  	if err := writeTestdata("contracts.json", contracts); err != nil {
   391  		return err
   392  	}
   393  	if err := writeTestdata(headerHashesFn, headerHashes); err != nil {
   394  		return err
   395  	}
   396  
   397  	return nil
   398  }
   399  
// storageKey identifies a single storage slot of a single account. It is the
// key type of the rpcFetcher.storage map.
type storageKey struct {
	addr common.Address // account that owns the slot
	slot common.Hash    // storage slot within that account
}
   404  
   405  // testdataState maps accounts to their state at a specific block in a specific
   406  // chain.
   407  type testdataState map[common.Address]*testdataAccount
   408  
   409  func (s testdataState) Merge(other testdataState) error {
   410  	for addr, otherAccount := range other {
   411  		if existingAccount, ok := s[addr]; ok {
   412  			if err := existingAccount.Merge(otherAccount); err != nil {
   413  				return fmt.Errorf("account conflict for address %s: %w", addr, err)
   414  			}
   415  		} else {
   416  			s[addr] = otherAccount
   417  		}
   418  	}
   419  	return nil
   420  }
   421  
// testdataAccount represents the state of a single account as it is stored in
// the JSON testdata files.
type testdataAccount struct {
	Nonce    hexutil.Uint64                    `json:"nonce"`
	Balance  *hexutil.U256                     `json:"balance"`
	CodeHash common.Hash                       `json:"codeHash,omitzero"`  // zero hash stands for the empty code hash, see codeHash
	Storage  map[w3hexutil.Hash]w3hexutil.Hash `json:"storage,omitempty"` // storage slot -> value
}
   429  
   430  func (a *testdataAccount) codeHash() common.Hash {
   431  	if a.CodeHash == w3.Hash0 {
   432  		return types.EmptyCodeHash
   433  	}
   434  	return a.CodeHash
   435  }
   436  
   437  func (a *testdataAccount) Merge(other *testdataAccount) error {
   438  	if a.Nonce != other.Nonce {
   439  		return fmt.Errorf("nonce conflict: %d != %d", a.Nonce, other.Nonce)
   440  	}
   441  	if (*uint256.Int)(a.Balance).Cmp((*uint256.Int)(other.Balance)) != 0 {
   442  		return fmt.Errorf("balance conflict: %s != %s", a.Balance, other.Balance)
   443  	}
   444  	if a.CodeHash != other.CodeHash {
   445  		return fmt.Errorf("code hash conflict: %s != %s", a.CodeHash, other.CodeHash)
   446  	}
   447  
   448  	// Merge storage maps
   449  	if a.Storage == nil {
   450  		a.Storage = make(map[w3hexutil.Hash]w3hexutil.Hash)
   451  	}
   452  	for slot, value := range other.Storage {
   453  		if existingValue, ok := a.Storage[slot]; ok {
   454  			if existingValue != value {
   455  				return fmt.Errorf("storage conflict at slot %s: %s != %s",
   456  					(common.Hash)(slot), (common.Hash)(existingValue), (common.Hash)(value),
   457  				)
   458  			}
   459  		} else {
   460  			a.Storage[slot] = value
   461  		}
   462  	}
   463  
   464  	return nil
   465  }
   466  
   467  // testdataContracts maps code hashes to their code.
   468  type testdataContracts map[common.Hash]hexutil.Bytes
   469  
   470  func (c testdataContracts) Merge(other testdataContracts) error {
   471  	for hash, code := range other {
   472  		if existingCode, ok := c[hash]; ok {
   473  			if !bytes.Equal(existingCode, code) {
   474  				return fmt.Errorf("bytecode conflict for code hash %s", hash)
   475  			}
   476  		} else {
   477  			c[hash] = code
   478  		}
   479  	}
   480  	return nil
   481  }
   482  
   483  // testdataHeaderHashes maps block numbers to their hashes for a specific chain.
   484  type testdataHeaderHashes map[hexutil.Uint64]common.Hash
   485  
   486  func (h testdataHeaderHashes) Merge(other testdataHeaderHashes) error {
   487  	for blockNumber, hash := range other {
   488  		if existingHash, ok := h[blockNumber]; ok {
   489  			if existingHash != hash {
   490  				return fmt.Errorf("header hash conflict for block %d", blockNumber)
   491  			}
   492  		} else {
   493  			h[blockNumber] = hash
   494  		}
   495  	}
   496  	return nil
   497  }
   498  
   499  func readTestdata(filename string, data any, onlyIfModifiedAfter time.Time) (time.Time, error) {
   500  	path := testdataPath(filename)
   501  
   502  	// get file info first
   503  	info, err := os.Stat(path)
   504  	if errors.Is(err, os.ErrNotExist) {
   505  		return time.Time{}, nil
   506  	} else if err != nil {
   507  		return time.Time{}, err
   508  	}
   509  
   510  	if info.ModTime().Before(onlyIfModifiedAfter) {
   511  		return info.ModTime(), nil // file was NOT modified after "onlyIfModifiedAfter"
   512  	}
   513  
   514  	// open and read file
   515  	f, err := os.Open(path)
   516  	if err != nil {
   517  		return time.Time{}, err
   518  	}
   519  	defer f.Close()
   520  
   521  	if err := json.NewDecoder(f).Decode(data); err != nil {
   522  		return time.Time{}, fmt.Errorf("decode json %s: %w", filename, err)
   523  	}
   524  	return info.ModTime(), nil
   525  }
   526  
   527  func writeTestdata(filename string, data any) error {
   528  	path := testdataPath(filename)
   529  
   530  	// create "testdata/w3vm"-dir, if it does not exist yet
   531  	dir := filepath.Dir(path)
   532  	if _, err := os.Stat(dir); errors.Is(err, os.ErrNotExist) {
   533  		if err := os.MkdirAll(dir, 0o775); err != nil {
   534  			return err
   535  		}
   536  	}
   537  
   538  	// create or open file
   539  	f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o664)
   540  	if err != nil {
   541  		return err
   542  	}
   543  	defer f.Close()
   544  
   545  	enc := json.NewEncoder(f)
   546  	enc.SetIndent("", "\t")
   547  	if err := enc.Encode(data); err != nil {
   548  		return fmt.Errorf("encode json %s: %w", filename, err)
   549  	}
   550  	return nil
   551  }
   552  
   553  func testdataPath(filename string) string {
   554  	return filepath.Join(mod.Root, "testdata", "w3vm", filename)
   555  }