github.com/lmittmann/w3@v0.20.0/w3vm/fetcher.go (about) 1 package w3vm 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "math/big" 9 "os" 10 "path/filepath" 11 "sync" 12 "sync/atomic" 13 "testing" 14 "time" 15 16 "github.com/ethereum/go-ethereum/common" 17 "github.com/ethereum/go-ethereum/common/hexutil" 18 "github.com/ethereum/go-ethereum/core/types" 19 "github.com/gofrs/flock" 20 "github.com/holiman/uint256" 21 "github.com/lmittmann/w3" 22 "github.com/lmittmann/w3/internal/crypto" 23 w3hexutil "github.com/lmittmann/w3/internal/hexutil" 24 "github.com/lmittmann/w3/internal/mod" 25 "github.com/lmittmann/w3/module/eth" 26 "github.com/lmittmann/w3/w3types" 27 ) 28 29 // Fetcher is the interface to access account state of a blockchain. 30 type Fetcher interface { 31 // Account fetches the account of the given address. 32 Account(common.Address) (*types.StateAccount, error) 33 34 // Code fetches the code of the given code hash. 35 Code(common.Hash) ([]byte, error) 36 37 // StorageAt fetches the state of the given address and storage slot. 38 StorageAt(common.Address, common.Hash) (common.Hash, error) 39 40 // HeaderHash fetches the hash of the header with the given number. 41 HeaderHash(uint64) (common.Hash, error) 42 } 43 44 type rpcFetcher struct { 45 client *w3.Client 46 blockNumber *big.Int 47 48 mux sync.RWMutex 49 accounts map[common.Address]func() (*types.StateAccount, error) 50 contracts map[common.Hash]func() ([]byte, error) 51 mux2 sync.RWMutex 52 storage map[storageKey]func() (common.Hash, error) 53 mux3 sync.RWMutex 54 headerHashes map[uint64]func() (common.Hash, error) 55 56 dirty uint32 // indicates whether new state has been fetched (0=false, 1=true) 57 58 // file modification times for testdata files 59 stateFileModTime time.Time 60 contractsFileModTime time.Time 61 headerHashesFileModTime time.Time 62 } 63 64 // NewRPCFetcher returns a new [Fetcher] that fetches account state from the given 65 // RPC client for the given block number. 66 // 67 // Note, that the returned state for a given block number is the state after the 68 // execution of that block. 69 func NewRPCFetcher(client *w3.Client, blockNumber *big.Int) Fetcher { 70 return newRPCFetcher(client, blockNumber) 71 } 72 73 func newRPCFetcher(client *w3.Client, blockNumber *big.Int) *rpcFetcher { 74 return &rpcFetcher{ 75 client: client, 76 blockNumber: blockNumber, 77 accounts: make(map[common.Address]func() (*types.StateAccount, error)), 78 contracts: make(map[common.Hash]func() ([]byte, error)), 79 storage: make(map[storageKey]func() (common.Hash, error)), 80 headerHashes: make(map[uint64]func() (common.Hash, error)), 81 } 82 } 83 84 func (f *rpcFetcher) Account(addr common.Address) (a *types.StateAccount, e error) { 85 f.mux.RLock() 86 acc, ok := f.accounts[addr] 87 f.mux.RUnlock() 88 if ok { 89 return acc() 90 } 91 atomic.StoreUint32(&f.dirty, 1) 92 93 var ( 94 accNew = &types.StateAccount{Balance: new(uint256.Int)} 95 contractNew []byte 96 97 accCh = make(chan func() (*types.StateAccount, error), 1) 98 contractCh = make(chan func() ([]byte, error), 1) 99 ) 100 go func() { 101 err := f.call( 102 eth.Nonce(addr, f.blockNumber).Returns(&accNew.Nonce), 103 ethBalance(addr, f.blockNumber).Returns(accNew.Balance), 104 eth.Code(addr, f.blockNumber).Returns(&contractNew), 105 ) 106 if err != nil { 107 accCh <- func() (*types.StateAccount, error) { return nil, err } 108 contractCh <- func() ([]byte, error) { return nil, err } 109 return 110 } 111 112 if len(contractNew) == 0 { 113 accNew.CodeHash = types.EmptyCodeHash[:] 114 } else { 115 accNew.CodeHash = crypto.Keccak256(contractNew) 116 } 117 accCh <- func() (*types.StateAccount, error) { return accNew, nil } 118 contractCh <- func() ([]byte, error) { return contractNew, nil } 119 }() 120 121 f.mux.Lock() 122 defer f.mux.Unlock() 123 accOnce := sync.OnceValues(<-accCh) 124 f.accounts[addr] = accOnce 125 accRet, err := accOnce() 126 if err != nil { 127 return nil, err 128 } 129 f.contracts[common.BytesToHash(accRet.CodeHash)] = sync.OnceValues(<-contractCh) 130 return accRet, nil 131 } 132 133 func (f *rpcFetcher) Code(codeHash common.Hash) ([]byte, error) { 134 f.mux.RLock() 135 contract, ok := f.contracts[codeHash] 136 f.mux.RUnlock() 137 if !ok { 138 panic("not implemented") 139 } 140 return contract() 141 } 142 143 func (f *rpcFetcher) StorageAt(addr common.Address, slot common.Hash) (common.Hash, error) { 144 key := storageKey{addr, slot} 145 146 f.mux2.RLock() 147 storage, ok := f.storage[key] 148 f.mux2.RUnlock() 149 if ok { 150 return storage() 151 } 152 atomic.StoreUint32(&f.dirty, 1) 153 154 var ( 155 storageVal common.Hash 156 storageValCh = make(chan func() (common.Hash, error), 1) 157 ) 158 go func() { 159 err := f.call(eth.StorageAt(addr, slot, f.blockNumber).Returns(&storageVal)) 160 storageValCh <- func() (common.Hash, error) { return storageVal, err } 161 }() 162 163 storageValOnce := sync.OnceValues(<-storageValCh) 164 f.mux2.Lock() 165 f.storage[key] = storageValOnce 166 f.mux2.Unlock() 167 return storageValOnce() 168 } 169 170 func (f *rpcFetcher) HeaderHash(blockNumber uint64) (common.Hash, error) { 171 f.mux3.RLock() 172 hash, ok := f.headerHashes[blockNumber] 173 f.mux3.RUnlock() 174 if ok { 175 return hash() 176 } 177 atomic.StoreUint32(&f.dirty, 1) 178 179 var ( 180 header header 181 headerHashCh = make(chan func() (common.Hash, error), 1) 182 ) 183 go func() { 184 err := f.call(ethHeaderHash(blockNumber).Returns(&header)) 185 headerHashCh <- func() (common.Hash, error) { return header.Hash, err } 186 }() 187 188 headerHashOnce := sync.OnceValues(<-headerHashCh) 189 f.mux3.Lock() 190 f.headerHashes[blockNumber] = headerHashOnce 191 f.mux3.Unlock() 192 return headerHashOnce() 193 } 194 195 func (f *rpcFetcher) call(calls ...w3types.RPCCaller) error { 196 return f.client.Call(calls...) 197 } 198 199 //////////////////////////////////////////////////////////////////////////////////////////////////// 200 // TestingRPCFetcher /////////////////////////////////////////////////////////////////////////////// 201 //////////////////////////////////////////////////////////////////////////////////////////////////// 202 203 // NewTestingRPCFetcher returns a new [Fetcher] like [NewRPCFetcher], but caches 204 // the fetched state on disk in the testdata directory of the tests package. 205 func NewTestingRPCFetcher(tb testing.TB, chainID uint64, client *w3.Client, blockNumber *big.Int) Fetcher { 206 if mod.Root == "" { 207 panic("w3vm: NewTestingRPCFetcher must be used in a module test") 208 } 209 210 fetcher := newRPCFetcher(client, blockNumber) 211 if err := fetcher.loadTestdataState(chainID); err != nil { 212 tb.Fatalf("w3vm: failed to load state from testdata: %v", err) 213 } 214 215 tb.Cleanup(func() { 216 if err := fetcher.storeTestdataState(chainID); err != nil { 217 tb.Fatalf("w3vm: failed to write state to testdata: %v", err) 218 } 219 }) 220 return fetcher 221 } 222 223 var ( 224 testdataMutex sync.RWMutex // in-process synchronization 225 testdataLock = flock.New(testdataPath("LOCK")) // inter-process synchronization 226 ) 227 228 func (f *rpcFetcher) loadTestdataState(chainID uint64) (err error) { 229 // lock testdata files 230 testdataMutex.RLock() 231 defer testdataMutex.RUnlock() 232 testdataLock.RLock() 233 defer testdataLock.Unlock() 234 235 // read testdata files 236 stateFn := fmt.Sprintf("%d_%v.json", chainID, f.blockNumber) 237 var state testdataState 238 if f.stateFileModTime, err = readTestdata(stateFn, &state, time.Time{}); err != nil { 239 return err 240 } 241 242 var contracts testdataContracts 243 if f.contractsFileModTime, err = readTestdata("contracts.json", &contracts, time.Time{}); err != nil { 244 return err 245 } 246 247 headerHashesFn := fmt.Sprintf("%d_header_hashes.json", chainID) 248 var headerHashes testdataHeaderHashes 249 if f.headerHashesFileModTime, err = readTestdata(headerHashesFn, &headerHashes, time.Time{}); err != nil { 250 return err 251 } 252 253 // build fetcher state 254 f.mux.Lock() 255 f.mux2.Lock() 256 f.mux3.Lock() 257 defer f.mux.Unlock() 258 defer f.mux2.Unlock() 259 defer f.mux3.Unlock() 260 261 for addr, acc := range state { 262 codeHash := acc.codeHash() 263 264 f.accounts[addr] = func() (*types.StateAccount, error) { 265 return &types.StateAccount{ 266 Nonce: uint64(acc.Nonce), 267 Balance: (*uint256.Int)(acc.Balance), 268 CodeHash: codeHash[:], 269 }, nil 270 } 271 if _, ok := f.contracts[codeHash]; codeHash != types.EmptyCodeHash && !ok { 272 f.contracts[codeHash] = func() ([]byte, error) { 273 return contracts[codeHash], nil 274 } 275 } 276 for slot, val := range acc.Storage { 277 f.storage[storageKey{addr, (common.Hash)(slot)}] = func() (common.Hash, error) { 278 return (common.Hash)(val), nil 279 } 280 } 281 for blockNumber, hash := range headerHashes { 282 f.headerHashes[uint64(blockNumber)] = func() (common.Hash, error) { 283 return hash, nil 284 } 285 } 286 } 287 return nil 288 } 289 290 func (f *rpcFetcher) storeTestdataState(chainID uint64) (err error) { 291 if atomic.LoadUint32(&f.dirty) == 0 { 292 return nil // if no new state was fetched, we do not need to store it 293 } 294 295 // read fetcher state 296 f.mux.RLock() 297 f.mux2.RLock() 298 f.mux3.RLock() 299 defer f.mux.RUnlock() 300 defer f.mux2.RUnlock() 301 defer f.mux3.RUnlock() 302 303 var ( 304 state = make(testdataState) 305 contracts = make(testdataContracts) 306 headerHashes = make(testdataHeaderHashes) 307 ) 308 for addr, accFunc := range f.accounts { 309 acc, err := accFunc() 310 if err != nil { 311 continue 312 } 313 314 state[addr] = &testdataAccount{ 315 Nonce: hexutil.Uint64(acc.Nonce), 316 Balance: (*hexutil.U256)(acc.Balance), 317 } 318 if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash[:]) { 319 codeHash := common.BytesToHash(acc.CodeHash) 320 state[addr].CodeHash = codeHash 321 contracts[codeHash], _ = f.contracts[codeHash]() 322 } 323 } 324 325 for storageKey, storageValFunc := range f.storage { 326 storageVal, err := storageValFunc() 327 if err != nil { 328 continue 329 } 330 331 if _, ok := state[storageKey.addr]; !ok { 332 state[storageKey.addr] = &testdataAccount{ 333 Storage: make(map[w3hexutil.Hash]w3hexutil.Hash), 334 } 335 } else if state[storageKey.addr].Storage == nil { 336 state[storageKey.addr].Storage = make(map[w3hexutil.Hash]w3hexutil.Hash) 337 } 338 state[storageKey.addr].Storage[w3hexutil.Hash(storageKey.slot)] = w3hexutil.Hash(storageVal) 339 } 340 341 for blockNumber, hashFunc := range f.headerHashes { 342 hash, err := hashFunc() 343 if err != nil { 344 continue 345 } 346 headerHashes[hexutil.Uint64(blockNumber)] = hash 347 } 348 349 // lock testdata files 350 testdataMutex.Lock() 351 defer testdataMutex.Unlock() 352 testdataLock.Lock() 353 defer testdataLock.Unlock() 354 355 // load current testdata state 356 stateFn := fmt.Sprintf("%d_%v.json", chainID, f.blockNumber) 357 var otherState testdataState 358 if _, err = readTestdata(stateFn, &otherState, f.stateFileModTime); err != nil { 359 return err 360 } 361 362 var otherContracts testdataContracts 363 if _, err = readTestdata("contracts.json", &otherContracts, f.contractsFileModTime); err != nil { 364 return err 365 } 366 367 headerHashesFn := fmt.Sprintf("%d_header_hashes.json", chainID) 368 var otherHeaderHashes testdataHeaderHashes 369 if _, err = readTestdata(headerHashesFn, &otherHeaderHashes, f.headerHashesFileModTime); err != nil { 370 return err 371 } 372 373 // merge 374 if err := state.Merge(otherState); err != nil { 375 return fmt.Errorf("failed to merge testdata state: %w", err) 376 } 377 378 if err := contracts.Merge(otherContracts); err != nil { 379 return fmt.Errorf("failed to merge testdata contracts: %w", err) 380 } 381 382 if err := headerHashes.Merge(otherHeaderHashes); err != nil { 383 return fmt.Errorf("failed to merge testdata header hashes: %w", err) 384 } 385 386 // write testdata files 387 if err := writeTestdata(stateFn, state); err != nil { 388 return err 389 } 390 if err := writeTestdata("contracts.json", contracts); err != nil { 391 return err 392 } 393 if err := writeTestdata(headerHashesFn, headerHashes); err != nil { 394 return err 395 } 396 397 return nil 398 } 399 400 type storageKey struct { 401 addr common.Address 402 slot common.Hash 403 } 404 405 // testdataState maps accounts to their state at a specific block in a specific 406 // chain. 407 type testdataState map[common.Address]*testdataAccount 408 409 func (s testdataState) Merge(other testdataState) error { 410 for addr, otherAccount := range other { 411 if existingAccount, ok := s[addr]; ok { 412 if err := existingAccount.Merge(otherAccount); err != nil { 413 return fmt.Errorf("account conflict for address %s: %w", addr, err) 414 } 415 } else { 416 s[addr] = otherAccount 417 } 418 } 419 return nil 420 } 421 422 // testdataAccount represents the state of a single account. 423 type testdataAccount struct { 424 Nonce hexutil.Uint64 `json:"nonce"` 425 Balance *hexutil.U256 `json:"balance"` 426 CodeHash common.Hash `json:"codeHash,omitzero"` 427 Storage map[w3hexutil.Hash]w3hexutil.Hash `json:"storage,omitempty"` 428 } 429 430 func (a *testdataAccount) codeHash() common.Hash { 431 if a.CodeHash == w3.Hash0 { 432 return types.EmptyCodeHash 433 } 434 return a.CodeHash 435 } 436 437 func (a *testdataAccount) Merge(other *testdataAccount) error { 438 if a.Nonce != other.Nonce { 439 return fmt.Errorf("nonce conflict: %d != %d", a.Nonce, other.Nonce) 440 } 441 if (*uint256.Int)(a.Balance).Cmp((*uint256.Int)(other.Balance)) != 0 { 442 return fmt.Errorf("balance conflict: %s != %s", a.Balance, other.Balance) 443 } 444 if a.CodeHash != other.CodeHash { 445 return fmt.Errorf("code hash conflict: %s != %s", a.CodeHash, other.CodeHash) 446 } 447 448 // Merge storage maps 449 if a.Storage == nil { 450 a.Storage = make(map[w3hexutil.Hash]w3hexutil.Hash) 451 } 452 for slot, value := range other.Storage { 453 if existingValue, ok := a.Storage[slot]; ok { 454 if existingValue != value { 455 return fmt.Errorf("storage conflict at slot %s: %s != %s", 456 (common.Hash)(slot), (common.Hash)(existingValue), (common.Hash)(value), 457 ) 458 } 459 } else { 460 a.Storage[slot] = value 461 } 462 } 463 464 return nil 465 } 466 467 // testdataContracts maps code hashes to their code. 468 type testdataContracts map[common.Hash]hexutil.Bytes 469 470 func (c testdataContracts) Merge(other testdataContracts) error { 471 for hash, code := range other { 472 if existingCode, ok := c[hash]; ok { 473 if !bytes.Equal(existingCode, code) { 474 return fmt.Errorf("bytecode conflict for code hash %s", hash) 475 } 476 } else { 477 c[hash] = code 478 } 479 } 480 return nil 481 } 482 483 // testdataHeaderHashes maps block numbers to their hashes for a specific chain. 484 type testdataHeaderHashes map[hexutil.Uint64]common.Hash 485 486 func (h testdataHeaderHashes) Merge(other testdataHeaderHashes) error { 487 for blockNumber, hash := range other { 488 if existingHash, ok := h[blockNumber]; ok { 489 if existingHash != hash { 490 return fmt.Errorf("header hash conflict for block %d", blockNumber) 491 } 492 } else { 493 h[blockNumber] = hash 494 } 495 } 496 return nil 497 } 498 499 func readTestdata(filename string, data any, onlyIfModifiedAfter time.Time) (time.Time, error) { 500 path := testdataPath(filename) 501 502 // get file info first 503 info, err := os.Stat(path) 504 if errors.Is(err, os.ErrNotExist) { 505 return time.Time{}, nil 506 } else if err != nil { 507 return time.Time{}, err 508 } 509 510 if info.ModTime().Before(onlyIfModifiedAfter) { 511 return info.ModTime(), nil // file was NOT modified after "onlyIfModifiedAfter" 512 } 513 514 // open and read file 515 f, err := os.Open(path) 516 if err != nil { 517 return time.Time{}, err 518 } 519 defer f.Close() 520 521 if err := json.NewDecoder(f).Decode(data); err != nil { 522 return time.Time{}, fmt.Errorf("decode json %s: %w", filename, err) 523 } 524 return info.ModTime(), nil 525 } 526 527 func writeTestdata(filename string, data any) error { 528 path := testdataPath(filename) 529 530 // create "testdata/w3vm"-dir, if it does not exist yet 531 dir := filepath.Dir(path) 532 if _, err := os.Stat(dir); errors.Is(err, os.ErrNotExist) { 533 if err := os.MkdirAll(dir, 0o775); err != nil { 534 return err 535 } 536 } 537 538 // create or open file 539 f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o664) 540 if err != nil { 541 return err 542 } 543 defer f.Close() 544 545 enc := json.NewEncoder(f) 546 enc.SetIndent("", "\t") 547 if err := enc.Encode(data); err != nil { 548 return fmt.Errorf("encode json %s: %w", filename, err) 549 } 550 return nil 551 } 552 553 func testdataPath(filename string) string { 554 return filepath.Join(mod.Root, "testdata", "w3vm", filename) 555 }