github.com/tacshi/go-ethereum@v0.0.0-20230616113857-84a434e20921/arbitrum/recordingdb.go (about)

     1  package arbitrum
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"sync"
    10  
    11  	"github.com/tacshi/go-ethereum/common"
    12  	"github.com/tacshi/go-ethereum/consensus"
    13  	"github.com/tacshi/go-ethereum/core"
    14  	"github.com/tacshi/go-ethereum/core/rawdb"
    15  	"github.com/tacshi/go-ethereum/core/state"
    16  	"github.com/tacshi/go-ethereum/core/types"
    17  	"github.com/tacshi/go-ethereum/core/vm"
    18  	"github.com/tacshi/go-ethereum/crypto"
    19  	"github.com/tacshi/go-ethereum/ethdb"
    20  	"github.com/tacshi/go-ethereum/log"
    21  	"github.com/tacshi/go-ethereum/rlp"
    22  	"github.com/tacshi/go-ethereum/trie"
    23  )
    24  
    25  type RecordingKV struct {
    26  	inner         *trie.Database
    27  	diskDb        ethdb.KeyValueStore
    28  	readDbEntries map[common.Hash][]byte
    29  	enableBypass  bool
    30  }
    31  
    32  func newRecordingKV(inner *trie.Database, diskDb ethdb.KeyValueStore) *RecordingKV {
    33  	return &RecordingKV{inner, diskDb, make(map[common.Hash][]byte), false}
    34  }
    35  
    36  func (db *RecordingKV) Has(key []byte) (bool, error) {
    37  	return false, errors.New("recording KV doesn't support Has")
    38  }
    39  
    40  func (db *RecordingKV) Get(key []byte) ([]byte, error) {
    41  	var hash common.Hash
    42  	var res []byte
    43  	var err error
    44  	if len(key) == 32 {
    45  		copy(hash[:], key)
    46  		res, err = db.inner.Node(hash)
    47  	} else if len(key) == len(rawdb.CodePrefix)+32 && bytes.HasPrefix(key, rawdb.CodePrefix) {
    48  		// Retrieving code
    49  		copy(hash[:], key[len(rawdb.CodePrefix):])
    50  		res, err = db.diskDb.Get(key)
    51  	} else {
    52  		err = fmt.Errorf("recording KV attempted to access non-hash key %v", hex.EncodeToString(key))
    53  	}
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	if db.enableBypass {
    58  		return res, nil
    59  	}
    60  	if crypto.Keccak256Hash(res) != hash {
    61  		return nil, fmt.Errorf("recording KV attempted to access non-hash key %v", hash)
    62  	}
    63  	db.readDbEntries[hash] = res
    64  	return res, nil
    65  }
    66  
    67  func (db *RecordingKV) Put(key []byte, value []byte) error {
    68  	return errors.New("recording KV doesn't support Put")
    69  }
    70  
    71  func (db *RecordingKV) Delete(key []byte) error {
    72  	return errors.New("recording KV doesn't support Delete")
    73  }
    74  
    75  func (db *RecordingKV) NewBatch() ethdb.Batch {
    76  	if db.enableBypass {
    77  		return db.diskDb.NewBatch()
    78  	}
    79  	log.Error("recording KV: attempted to create batch when bypass not enabled")
    80  	return nil
    81  }
    82  
    83  func (db *RecordingKV) NewBatchWithSize(size int) ethdb.Batch {
    84  	if db.enableBypass {
    85  		return db.diskDb.NewBatchWithSize(size)
    86  	}
    87  	log.Error("recording KV: attempted to create batch when bypass not enabled")
    88  	return nil
    89  }
    90  
    91  func (db *RecordingKV) NewIterator(prefix []byte, start []byte) ethdb.Iterator {
    92  	if db.enableBypass {
    93  		return db.diskDb.NewIterator(prefix, start)
    94  	}
    95  	log.Error("recording KV: attempted to create iterator when bypass not enabled")
    96  	return nil
    97  }
    98  
    99  func (db *RecordingKV) NewSnapshot() (ethdb.Snapshot, error) {
   100  	// This is fine as RecordingKV doesn't support mutation
   101  	return db, nil
   102  }
   103  
   104  func (db *RecordingKV) Stat(property string) (string, error) {
   105  	return "", errors.New("recording KV doesn't support Stat")
   106  }
   107  
   108  func (db *RecordingKV) Compact(start []byte, limit []byte) error {
   109  	return nil
   110  }
   111  
   112  func (db *RecordingKV) Close() error {
   113  	return nil
   114  }
   115  
   116  func (db *RecordingKV) Release() {}
   117  
   118  func (db *RecordingKV) GetRecordedEntries() map[common.Hash][]byte {
   119  	return db.readDbEntries
   120  }
   121  func (db *RecordingKV) EnableBypass() {
   122  	db.enableBypass = true
   123  }
   124  
   125  type RecordingChainContext struct {
   126  	bc                     core.ChainContext
   127  	minBlockNumberAccessed uint64
   128  	initialBlockNumber     uint64
   129  }
   130  
   131  func newRecordingChainContext(inner core.ChainContext, blocknumber uint64) *RecordingChainContext {
   132  	return &RecordingChainContext{
   133  		bc:                     inner,
   134  		minBlockNumberAccessed: blocknumber,
   135  		initialBlockNumber:     blocknumber,
   136  	}
   137  }
   138  
   139  func (r *RecordingChainContext) Engine() consensus.Engine {
   140  	return r.bc.Engine()
   141  }
   142  
   143  func (r *RecordingChainContext) GetHeader(hash common.Hash, num uint64) *types.Header {
   144  	if num < r.minBlockNumberAccessed {
   145  		r.minBlockNumberAccessed = num
   146  	}
   147  	return r.bc.GetHeader(hash, num)
   148  }
   149  
   150  func (r *RecordingChainContext) GetMinBlockNumberAccessed() uint64 {
   151  	return r.minBlockNumberAccessed
   152  }
   153  
   154  type RecordingDatabase struct {
   155  	db         state.Database
   156  	bc         *core.BlockChain
   157  	mutex      sync.Mutex // protects StateFor and Dereference
   158  	references int64
   159  }
   160  
   161  func NewRecordingDatabase(ethdb ethdb.Database, blockchain *core.BlockChain) *RecordingDatabase {
   162  	return &RecordingDatabase{
   163  		db: state.NewDatabaseWithConfig(ethdb, &trie.Config{Cache: 16}), //TODO cache needed? configurable?
   164  		bc: blockchain,
   165  	}
   166  }
   167  
   168  // Normal geth state.New + Reference is not atomic vs Dereference. This one is.
   169  // This function does not recreate a state
   170  func (r *RecordingDatabase) StateFor(header *types.Header) (*state.StateDB, error) {
   171  	r.mutex.Lock()
   172  	defer r.mutex.Unlock()
   173  
   174  	sdb, err := state.NewDeterministic(header.Root, r.db)
   175  	if err == nil {
   176  		r.referenceRootLockHeld(header.Root)
   177  	}
   178  	return sdb, err
   179  }
   180  
   181  func (r *RecordingDatabase) Dereference(header *types.Header) {
   182  	if header != nil {
   183  		r.dereferenceRoot(header.Root)
   184  	}
   185  }
   186  
   187  func (r *RecordingDatabase) WriteStateToDatabase(header *types.Header) error {
   188  	if header != nil {
   189  		return r.db.TrieDB().Commit(header.Root, true)
   190  	}
   191  	return nil
   192  }
   193  
   194  // lock must be held when calling that
   195  func (r *RecordingDatabase) referenceRootLockHeld(root common.Hash) {
   196  	r.references++
   197  	r.db.TrieDB().Reference(root, common.Hash{})
   198  }
   199  
   200  func (r *RecordingDatabase) dereferenceRoot(root common.Hash) {
   201  	r.mutex.Lock()
   202  	defer r.mutex.Unlock()
   203  	r.references--
   204  	r.db.TrieDB().Dereference(root)
   205  }
   206  
   207  func (r *RecordingDatabase) addStateVerify(statedb *state.StateDB, expected common.Hash) error {
   208  	r.mutex.Lock()
   209  	defer r.mutex.Unlock()
   210  	result, err := statedb.Commit(true)
   211  	if err != nil {
   212  		return err
   213  	}
   214  	if result != expected {
   215  		return fmt.Errorf("bad root hash expected: %v got: %v", expected, result)
   216  	}
   217  	r.referenceRootLockHeld(result)
   218  	return nil
   219  }
   220  
   221  type StateBuildingLogFunction func(targetHeader, header *types.Header, hasState bool)
   222  
   223  func (r *RecordingDatabase) PrepareRecording(ctx context.Context, lastBlockHeader *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, core.ChainContext, *RecordingKV, error) {
   224  	_, err := r.GetOrRecreateState(ctx, lastBlockHeader, logFunc)
   225  	if err != nil {
   226  		return nil, nil, nil, err
   227  	}
   228  	finalDereference := lastBlockHeader // dereference in case of error
   229  	defer func() { r.Dereference(finalDereference) }()
   230  	recordingKeyValue := newRecordingKV(r.db.TrieDB(), r.db.DiskDB())
   231  
   232  	recordingStateDatabase := state.NewDatabase(rawdb.NewDatabase(recordingKeyValue))
   233  	var prevRoot common.Hash
   234  	if lastBlockHeader != nil {
   235  		prevRoot = lastBlockHeader.Root
   236  	}
   237  	recordingStateDb, err := state.NewDeterministic(prevRoot, recordingStateDatabase)
   238  	if err != nil {
   239  		return nil, nil, nil, fmt.Errorf("failed to create recordingStateDb: %w", err)
   240  	}
   241  	var recordingChainContext *RecordingChainContext
   242  	if lastBlockHeader != nil {
   243  		if !lastBlockHeader.Number.IsUint64() {
   244  			return nil, nil, nil, errors.New("block number not uint64")
   245  		}
   246  		recordingChainContext = newRecordingChainContext(r.bc, lastBlockHeader.Number.Uint64())
   247  	}
   248  	finalDereference = nil
   249  	return recordingStateDb, recordingChainContext, recordingKeyValue, nil
   250  }
   251  
   252  func (r *RecordingDatabase) PreimagesFromRecording(chainContextIf core.ChainContext, recordingDb *RecordingKV) (map[common.Hash][]byte, error) {
   253  	entries := recordingDb.GetRecordedEntries()
   254  	recordingChainContext, ok := chainContextIf.(*RecordingChainContext)
   255  	if (recordingChainContext == nil) || (!ok) {
   256  		return nil, errors.New("recordingChainContext invalid")
   257  	}
   258  
   259  	for i := recordingChainContext.GetMinBlockNumberAccessed(); i <= recordingChainContext.initialBlockNumber; i++ {
   260  		header := r.bc.GetHeaderByNumber(i)
   261  		hash := header.Hash()
   262  		bytes, err := rlp.EncodeToBytes(header)
   263  		if err != nil {
   264  			return nil, fmt.Errorf("Error RLP encoding header: %v\n", err)
   265  		}
   266  		entries[hash] = bytes
   267  	}
   268  	return entries, nil
   269  }
   270  
   271  func (r *RecordingDatabase) GetOrRecreateState(ctx context.Context, header *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, error) {
   272  	stateDb, err := r.StateFor(header)
   273  	if err == nil {
   274  		return stateDb, nil
   275  	}
   276  	returnedBlockNumber := header.Number.Uint64()
   277  	genesis := r.bc.Config().ArbitrumChainParams.GenesisBlockNum
   278  	currentHeader := header
   279  	var lastRoot common.Hash
   280  	for ctx.Err() == nil {
   281  		if logFunc != nil {
   282  			logFunc(header, currentHeader, false)
   283  		}
   284  		if currentHeader.Number.Uint64() <= genesis {
   285  			return nil, fmt.Errorf("moved beyond genesis looking for state looking for %d, genesis %d, err %w", returnedBlockNumber, genesis, err)
   286  		}
   287  		lastHeader := currentHeader
   288  		currentHeader = r.bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1)
   289  		if currentHeader == nil {
   290  			return nil, fmt.Errorf("chain doesn't contain parent of block %d hash %v (expected parent hash %v)", lastHeader.Number, lastHeader.Hash(), lastHeader.ParentHash)
   291  		}
   292  		stateDb, err = r.StateFor(currentHeader)
   293  		if err == nil {
   294  			lastRoot = currentHeader.Root
   295  			break
   296  		}
   297  	}
   298  	defer func() {
   299  		if (lastRoot != common.Hash{}) {
   300  			r.dereferenceRoot(lastRoot)
   301  		}
   302  	}()
   303  	blockToRecreate := currentHeader.Number.Uint64() + 1
   304  	prevHash := currentHeader.Hash()
   305  	for ctx.Err() == nil {
   306  		block := r.bc.GetBlockByNumber(blockToRecreate)
   307  		if block == nil {
   308  			return nil, fmt.Errorf("block not found while recreating: %d", blockToRecreate)
   309  		}
   310  		if block.ParentHash() != prevHash {
   311  			return nil, fmt.Errorf("reorg detected: number %d expectedPrev: %v foundPrev: %v", blockToRecreate, prevHash, block.ParentHash())
   312  		}
   313  		prevHash = block.Hash()
   314  		if logFunc != nil {
   315  			logFunc(header, block.Header(), true)
   316  		}
   317  		_, _, _, err := r.bc.Processor().Process(block, stateDb, vm.Config{})
   318  		if err != nil {
   319  			return nil, fmt.Errorf("failed recreating state for block %d : %w", blockToRecreate, err)
   320  		}
   321  		err = r.addStateVerify(stateDb, block.Root())
   322  		if err != nil {
   323  			return nil, fmt.Errorf("failed committing state for block %d : %w", blockToRecreate, err)
   324  		}
   325  		r.dereferenceRoot(lastRoot)
   326  		lastRoot = block.Root()
   327  		if blockToRecreate >= returnedBlockNumber {
   328  			if block.Hash() != header.Hash() {
   329  				return nil, fmt.Errorf("blockHash doesn't match when recreating number: %d expected: %v got: %v", blockToRecreate, header.Hash(), block.Hash())
   330  			}
   331  			// don't dereference this one
   332  			lastRoot = common.Hash{}
   333  			return stateDb, nil
   334  		}
   335  		blockToRecreate++
   336  	}
   337  	return nil, ctx.Err()
   338  }
   339  
   340  func (r *RecordingDatabase) ReferenceCount() int64 {
   341  	return r.references
   342  }