github.com/MikyChow/arbitrum-go-ethereum@v0.0.0-20230306102812-078da49636de/arbitrum/recordingdb.go (about)

     1  package arbitrum
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"sync"
    10  
    11  	"github.com/MikyChow/arbitrum-go-ethereum/common"
    12  	"github.com/MikyChow/arbitrum-go-ethereum/consensus"
    13  	"github.com/MikyChow/arbitrum-go-ethereum/core"
    14  	"github.com/MikyChow/arbitrum-go-ethereum/core/rawdb"
    15  	"github.com/MikyChow/arbitrum-go-ethereum/core/state"
    16  	"github.com/MikyChow/arbitrum-go-ethereum/core/types"
    17  	"github.com/MikyChow/arbitrum-go-ethereum/core/vm"
    18  	"github.com/MikyChow/arbitrum-go-ethereum/crypto"
    19  	"github.com/MikyChow/arbitrum-go-ethereum/ethdb"
    20  	"github.com/MikyChow/arbitrum-go-ethereum/log"
    21  	"github.com/MikyChow/arbitrum-go-ethereum/rlp"
    22  	"github.com/MikyChow/arbitrum-go-ethereum/trie"
    23  )
    24  
// RecordingKV is a key-value view over a trie.Database that records what it
// reads. With bypass disabled, every successful Get is checked against its
// Keccak-256 hash and stored in readDbEntries, so the set of preimages a
// computation touched can be retrieved afterwards via GetRecordedEntries.
// Writes (Put/Delete) are not supported.
type RecordingKV struct {
	inner         *trie.Database         // trie nodes come from here; code is read from its DiskDB()
	readDbEntries map[common.Hash][]byte // hash -> value for every verified read made by Get
	enableBypass  bool                   // see EnableBypass
}
    30  
    31  func newRecordingKV(inner *trie.Database) *RecordingKV {
    32  	return &RecordingKV{inner, make(map[common.Hash][]byte), false}
    33  }
    34  
    35  func (db *RecordingKV) Has(key []byte) (bool, error) {
    36  	return false, errors.New("recording KV doesn't support Has")
    37  }
    38  
    39  func (db *RecordingKV) Get(key []byte) ([]byte, error) {
    40  	var hash common.Hash
    41  	var res []byte
    42  	var err error
    43  	if len(key) == 32 {
    44  		copy(hash[:], key)
    45  		res, err = db.inner.Node(hash)
    46  	} else if len(key) == len(rawdb.CodePrefix)+32 && bytes.HasPrefix(key, rawdb.CodePrefix) {
    47  		// Retrieving code
    48  		copy(hash[:], key[len(rawdb.CodePrefix):])
    49  		res, err = db.inner.DiskDB().Get(key)
    50  	} else {
    51  		err = fmt.Errorf("recording KV attempted to access non-hash key %v", hex.EncodeToString(key))
    52  	}
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	if db.enableBypass {
    57  		return res, nil
    58  	}
    59  	if crypto.Keccak256Hash(res) != hash {
    60  		return nil, fmt.Errorf("recording KV attempted to access non-hash key %v", hash)
    61  	}
    62  	db.readDbEntries[hash] = res
    63  	return res, nil
    64  }
    65  
    66  func (db *RecordingKV) Put(key []byte, value []byte) error {
    67  	return errors.New("recording KV doesn't support Put")
    68  }
    69  
    70  func (db *RecordingKV) Delete(key []byte) error {
    71  	return errors.New("recording KV doesn't support Delete")
    72  }
    73  
    74  func (db *RecordingKV) NewBatch() ethdb.Batch {
    75  	if db.enableBypass {
    76  		return db.inner.DiskDB().NewBatch()
    77  	}
    78  	log.Error("recording KV: attempted to create batch when bypass not enabled")
    79  	return nil
    80  }
    81  
    82  func (db *RecordingKV) NewBatchWithSize(size int) ethdb.Batch {
    83  	if db.enableBypass {
    84  		return db.inner.DiskDB().NewBatchWithSize(size)
    85  	}
    86  	log.Error("recording KV: attempted to create batch when bypass not enabled")
    87  	return nil
    88  }
    89  
    90  func (db *RecordingKV) NewIterator(prefix []byte, start []byte) ethdb.Iterator {
    91  	if db.enableBypass {
    92  		return db.inner.DiskDB().NewIterator(prefix, start)
    93  	}
    94  	log.Error("recording KV: attempted to create iterator when bypass not enabled")
    95  	return nil
    96  }
    97  
    98  func (db *RecordingKV) NewSnapshot() (ethdb.Snapshot, error) {
    99  	// This is fine as RecordingKV doesn't support mutation
   100  	return db, nil
   101  }
   102  
   103  func (db *RecordingKV) Stat(property string) (string, error) {
   104  	return "", errors.New("recording KV doesn't support Stat")
   105  }
   106  
   107  func (db *RecordingKV) Compact(start []byte, limit []byte) error {
   108  	return nil
   109  }
   110  
   111  func (db *RecordingKV) Close() error {
   112  	return nil
   113  }
   114  
   115  func (db *RecordingKV) Release() {
   116  	return
   117  }
   118  
   119  func (db *RecordingKV) GetRecordedEntries() map[common.Hash][]byte {
   120  	return db.readDbEntries
   121  }
   122  func (db *RecordingKV) EnableBypass() {
   123  	db.enableBypass = true
   124  }
   125  
// RecordingChainContext wraps a core.ChainContext and remembers the lowest
// block number whose header was requested through GetHeader, so that
// PreimagesFromRecording can include every header the execution may have read.
type RecordingChainContext struct {
	bc                     core.ChainContext // the wrapped chain context
	minBlockNumberAccessed uint64            // lowest number passed to GetHeader; starts at initialBlockNumber
	initialBlockNumber     uint64            // block number the recording started from
}
   131  
   132  func newRecordingChainContext(inner core.ChainContext, blocknumber uint64) *RecordingChainContext {
   133  	return &RecordingChainContext{
   134  		bc:                     inner,
   135  		minBlockNumberAccessed: blocknumber,
   136  		initialBlockNumber:     blocknumber,
   137  	}
   138  }
   139  
   140  func (r *RecordingChainContext) Engine() consensus.Engine {
   141  	return r.bc.Engine()
   142  }
   143  
   144  func (r *RecordingChainContext) GetHeader(hash common.Hash, num uint64) *types.Header {
   145  	if num < r.minBlockNumberAccessed {
   146  		r.minBlockNumberAccessed = num
   147  	}
   148  	return r.bc.GetHeader(hash, num)
   149  }
   150  
   151  func (r *RecordingChainContext) GetMinBlockNumberAccessed() uint64 {
   152  	return r.minBlockNumberAccessed
   153  }
   154  
// RecordingDatabase owns the state.Database used to open, recreate and commit
// block states for bc. It also counts the state-root references it has taken
// in that database's trie layer.
type RecordingDatabase struct {
	db         state.Database
	bc         *core.BlockChain
	mutex      sync.Mutex // serializes StateFor, addStateVerify and dereferenceRoot (every update to references)
	references int64      // references taken minus references released; see ReferenceCount
}
   161  
   162  func NewRecordingDatabase(ethdb ethdb.Database, blockchain *core.BlockChain) *RecordingDatabase {
   163  	return &RecordingDatabase{
   164  		db: state.NewDatabaseWithConfig(ethdb, &trie.Config{Cache: 16}), //TODO cache needed? configurable?
   165  		bc: blockchain,
   166  	}
   167  }
   168  
   169  // Normal geth state.New + Reference is not atomic vs Dereference. This one is.
   170  // This function does not recreate a state
   171  func (r *RecordingDatabase) StateFor(header *types.Header) (*state.StateDB, error) {
   172  	r.mutex.Lock()
   173  	defer r.mutex.Unlock()
   174  
   175  	sdb, err := state.NewDeterministic(header.Root, r.db)
   176  	if err == nil {
   177  		r.referenceRootLockHeld(header.Root)
   178  	}
   179  	return sdb, err
   180  }
   181  
   182  func (r *RecordingDatabase) Dereference(header *types.Header) {
   183  	if header != nil {
   184  		r.dereferenceRoot(header.Root)
   185  	}
   186  }
   187  
   188  func (r *RecordingDatabase) WriteStateToDatabase(header *types.Header) error {
   189  	if header != nil {
   190  		return r.db.TrieDB().Commit(header.Root, true, nil)
   191  	}
   192  	return nil
   193  }
   194  
   195  // lock must be held when calling that
   196  func (r *RecordingDatabase) referenceRootLockHeld(root common.Hash) {
   197  	r.references++
   198  	r.db.TrieDB().Reference(root, common.Hash{})
   199  }
   200  
   201  func (r *RecordingDatabase) dereferenceRoot(root common.Hash) {
   202  	r.mutex.Lock()
   203  	defer r.mutex.Unlock()
   204  	r.references--
   205  	r.db.TrieDB().Dereference(root)
   206  }
   207  
   208  func (r *RecordingDatabase) addStateVerify(statedb *state.StateDB, expected common.Hash) error {
   209  	r.mutex.Lock()
   210  	defer r.mutex.Unlock()
   211  	result, err := statedb.Commit(true)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	if result != expected {
   216  		return fmt.Errorf("bad root hash expected: %v got: %v", expected, result)
   217  	}
   218  	r.referenceRootLockHeld(result)
   219  	return nil
   220  }
   221  
// StateBuildingLogFunction is an optional progress callback for
// GetOrRecreateState. targetHeader is the header whose state was requested and
// header is the block currently being handled. hasState is false while
// walking back through ancestors looking for available state, and true while
// blocks are being re-executed forward from that state.
type StateBuildingLogFunction func(targetHeader, header *types.Header, hasState bool)
   223  
   224  func (r *RecordingDatabase) PrepareRecording(ctx context.Context, lastBlockHeader *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, core.ChainContext, *RecordingKV, error) {
   225  	_, err := r.GetOrRecreateState(ctx, lastBlockHeader, logFunc)
   226  	if err != nil {
   227  		return nil, nil, nil, err
   228  	}
   229  	finalDereference := lastBlockHeader // dereference in case of error
   230  	defer func() { r.Dereference(finalDereference) }()
   231  	recordingKeyValue := newRecordingKV(r.db.TrieDB())
   232  
   233  	recordingStateDatabase := state.NewDatabase(rawdb.NewDatabase(recordingKeyValue))
   234  	var prevRoot common.Hash
   235  	if lastBlockHeader != nil {
   236  		prevRoot = lastBlockHeader.Root
   237  	}
   238  	recordingStateDb, err := state.NewDeterministic(prevRoot, recordingStateDatabase)
   239  	if err != nil {
   240  		return nil, nil, nil, fmt.Errorf("failed to create recordingStateDb: %w", err)
   241  	}
   242  	var recordingChainContext *RecordingChainContext
   243  	if lastBlockHeader != nil {
   244  		if !lastBlockHeader.Number.IsUint64() {
   245  			return nil, nil, nil, errors.New("block number not uint64")
   246  		}
   247  		recordingChainContext = newRecordingChainContext(r.bc, lastBlockHeader.Number.Uint64())
   248  	}
   249  	finalDereference = nil
   250  	return recordingStateDb, recordingChainContext, recordingKeyValue, nil
   251  }
   252  
   253  func (r *RecordingDatabase) PreimagesFromRecording(chainContextIf core.ChainContext, recordingDb *RecordingKV) (map[common.Hash][]byte, error) {
   254  	entries := recordingDb.GetRecordedEntries()
   255  	recordingChainContext, ok := chainContextIf.(*RecordingChainContext)
   256  	if (recordingChainContext == nil) || (!ok) {
   257  		return nil, errors.New("recordingChainContext invalid")
   258  	}
   259  
   260  	for i := recordingChainContext.GetMinBlockNumberAccessed(); i <= recordingChainContext.initialBlockNumber; i++ {
   261  		header := r.bc.GetHeaderByNumber(i)
   262  		hash := header.Hash()
   263  		bytes, err := rlp.EncodeToBytes(header)
   264  		if err != nil {
   265  			return nil, fmt.Errorf("Error RLP encoding header: %v\n", err)
   266  		}
   267  		entries[hash] = bytes
   268  	}
   269  	return entries, nil
   270  }
   271  
// GetOrRecreateState returns a StateDB for header.Root and leaves one reference
// on that root held in the trie database; the caller releases it with
// Dereference(header). header must be non-nil (StateFor reads header.Root).
//
// If StateFor cannot open the state directly, the function walks back through
// ancestors (by ParentHash) until it finds one whose state can be opened. It
// then re-executes the following blocks, fetched by number, one at a time on
// that StateDB. Each block's resulting state is committed and its root checked
// against block.Root(), until header's block number is reached. Each
// intermediate root is dereferenced as soon as the next one has been
// referenced, so on success only header.Root stays referenced.
//
// It returns an error in these cases:
//   - the walk reaches ArbitrumChainParams.GenesisBlockNum without finding state;
//   - a parent header or a block to re-execute is missing;
//   - a fetched block does not chain onto the previous one (reorg);
//   - processing or root verification fails;
//   - the final block's hash differs from header.Hash();
//   - ctx is cancelled.
//
// logFunc, if non-nil, is called for every block visited: with hasState=false
// while walking back, and with hasState=true while re-executing.
func (r *RecordingDatabase) GetOrRecreateState(ctx context.Context, header *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, error) {
	stateDb, err := r.StateFor(header)
	if err == nil {
		return stateDb, nil
	}
	returnedBlockNumber := header.Number.Uint64()
	genesis := r.bc.Config().ArbitrumChainParams.GenesisBlockNum
	currentHeader := header
	// lastRoot is the root this function currently holds a reference on
	// (zero hash when none).
	var lastRoot common.Hash
	// Walk back through ancestors until one has state available.
	for ctx.Err() == nil {
		if logFunc != nil {
			logFunc(header, currentHeader, false)
		}
		if currentHeader.Number.Uint64() <= genesis {
			return nil, fmt.Errorf("moved beyond genesis looking for state looking for %d, genesis %d, err %w", returnedBlockNumber, genesis, err)
		}
		lastHeader := currentHeader
		currentHeader = r.bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1)
		if currentHeader == nil {
			return nil, fmt.Errorf("chain doesn't contain parent of block %d hash %v (expected parent hash %v)", lastHeader.Number, lastHeader.Hash(), lastHeader.ParentHash)
		}
		stateDb, err = r.StateFor(currentHeader)
		if err == nil {
			// StateFor took a reference on this root; track it so it is released.
			lastRoot = currentHeader.Root
			break
		}
	}
	// On every return below except success, release the reference still held
	// on the most recent intermediate root.
	defer func() {
		if (lastRoot != common.Hash{}) {
			r.dereferenceRoot(lastRoot)
		}
	}()
	// If ctx was cancelled during the walk back, the loop below does not run
	// and ctx.Err() is returned.
	blockToRecreate := currentHeader.Number.Uint64() + 1
	prevHash := currentHeader.Hash()
	for ctx.Err() == nil {
		block := r.bc.GetBlockByNumber(blockToRecreate)
		if block == nil {
			return nil, fmt.Errorf("block not found while recreating: %d", blockToRecreate)
		}
		// Blocks are fetched by number, so verify each one chains onto the previous.
		if block.ParentHash() != prevHash {
			return nil, fmt.Errorf("reorg detected: number %d expectedPrev: %v foundPrev: %v", blockToRecreate, prevHash, block.ParentHash())
		}
		prevHash = block.Hash()
		if logFunc != nil {
			logFunc(header, block.Header(), true)
		}
		_, _, _, err := r.bc.Processor().Process(block, stateDb, vm.Config{})
		if err != nil {
			return nil, fmt.Errorf("failed recreating state for block %d : %w", blockToRecreate, err)
		}
		err = r.addStateVerify(stateDb, block.Root())
		if err != nil {
			return nil, fmt.Errorf("failed commiting state for block %d : %w", blockToRecreate, err)
		}
		// addStateVerify referenced the new root; drop the previous one.
		r.dereferenceRoot(lastRoot)
		lastRoot = block.Root()
		if blockToRecreate >= returnedBlockNumber {
			if block.Hash() != header.Hash() {
				return nil, fmt.Errorf("blockHash doesn't match when recreating number: %d expected: %v got: %v", blockToRecreate, header.Hash(), block.Hash())
			}
			// don't dereference this one: the reference passes to the caller
			lastRoot = common.Hash{}
			return stateDb, nil
		}
		blockToRecreate++
	}
	return nil, ctx.Err()
}
   340  
   341  func (r *RecordingDatabase) ReferenceCount() int64 {
   342  	return r.references
   343  }