github.com/tacshi/go-ethereum@v0.0.0-20230616113857-84a434e20921/arbitrum/recordingdb.go (about) 1 package arbitrum 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/hex" 7 "errors" 8 "fmt" 9 "sync" 10 11 "github.com/tacshi/go-ethereum/common" 12 "github.com/tacshi/go-ethereum/consensus" 13 "github.com/tacshi/go-ethereum/core" 14 "github.com/tacshi/go-ethereum/core/rawdb" 15 "github.com/tacshi/go-ethereum/core/state" 16 "github.com/tacshi/go-ethereum/core/types" 17 "github.com/tacshi/go-ethereum/core/vm" 18 "github.com/tacshi/go-ethereum/crypto" 19 "github.com/tacshi/go-ethereum/ethdb" 20 "github.com/tacshi/go-ethereum/log" 21 "github.com/tacshi/go-ethereum/rlp" 22 "github.com/tacshi/go-ethereum/trie" 23 ) 24 25 type RecordingKV struct { 26 inner *trie.Database 27 diskDb ethdb.KeyValueStore 28 readDbEntries map[common.Hash][]byte 29 enableBypass bool 30 } 31 32 func newRecordingKV(inner *trie.Database, diskDb ethdb.KeyValueStore) *RecordingKV { 33 return &RecordingKV{inner, diskDb, make(map[common.Hash][]byte), false} 34 } 35 36 func (db *RecordingKV) Has(key []byte) (bool, error) { 37 return false, errors.New("recording KV doesn't support Has") 38 } 39 40 func (db *RecordingKV) Get(key []byte) ([]byte, error) { 41 var hash common.Hash 42 var res []byte 43 var err error 44 if len(key) == 32 { 45 copy(hash[:], key) 46 res, err = db.inner.Node(hash) 47 } else if len(key) == len(rawdb.CodePrefix)+32 && bytes.HasPrefix(key, rawdb.CodePrefix) { 48 // Retrieving code 49 copy(hash[:], key[len(rawdb.CodePrefix):]) 50 res, err = db.diskDb.Get(key) 51 } else { 52 err = fmt.Errorf("recording KV attempted to access non-hash key %v", hex.EncodeToString(key)) 53 } 54 if err != nil { 55 return nil, err 56 } 57 if db.enableBypass { 58 return res, nil 59 } 60 if crypto.Keccak256Hash(res) != hash { 61 return nil, fmt.Errorf("recording KV attempted to access non-hash key %v", hash) 62 } 63 db.readDbEntries[hash] = res 64 return res, nil 65 } 66 67 func (db *RecordingKV) Put(key []byte, value []byte) error { 68 return errors.New("recording KV doesn't support Put") 69 } 70 71 func (db *RecordingKV) Delete(key []byte) error { 72 return errors.New("recording KV doesn't support Delete") 73 } 74 75 func (db *RecordingKV) NewBatch() ethdb.Batch { 76 if db.enableBypass { 77 return db.diskDb.NewBatch() 78 } 79 log.Error("recording KV: attempted to create batch when bypass not enabled") 80 return nil 81 } 82 83 func (db *RecordingKV) NewBatchWithSize(size int) ethdb.Batch { 84 if db.enableBypass { 85 return db.diskDb.NewBatchWithSize(size) 86 } 87 log.Error("recording KV: attempted to create batch when bypass not enabled") 88 return nil 89 } 90 91 func (db *RecordingKV) NewIterator(prefix []byte, start []byte) ethdb.Iterator { 92 if db.enableBypass { 93 return db.diskDb.NewIterator(prefix, start) 94 } 95 log.Error("recording KV: attempted to create iterator when bypass not enabled") 96 return nil 97 } 98 99 func (db *RecordingKV) NewSnapshot() (ethdb.Snapshot, error) { 100 // This is fine as RecordingKV doesn't support mutation 101 return db, nil 102 } 103 104 func (db *RecordingKV) Stat(property string) (string, error) { 105 return "", errors.New("recording KV doesn't support Stat") 106 } 107 108 func (db *RecordingKV) Compact(start []byte, limit []byte) error { 109 return nil 110 } 111 112 func (db *RecordingKV) Close() error { 113 return nil 114 } 115 116 func (db *RecordingKV) Release() {} 117 118 func (db *RecordingKV) GetRecordedEntries() map[common.Hash][]byte { 119 return db.readDbEntries 120 } 121 func (db *RecordingKV) EnableBypass() { 122 db.enableBypass = true 123 } 124 125 type RecordingChainContext struct { 126 bc core.ChainContext 127 minBlockNumberAccessed uint64 128 initialBlockNumber uint64 129 } 130 131 func newRecordingChainContext(inner core.ChainContext, blocknumber uint64) *RecordingChainContext { 132 return &RecordingChainContext{ 133 bc: inner, 134 minBlockNumberAccessed: blocknumber, 135 initialBlockNumber: blocknumber, 136 } 137 } 138 139 func (r *RecordingChainContext) Engine() consensus.Engine { 140 return r.bc.Engine() 141 } 142 143 func (r *RecordingChainContext) GetHeader(hash common.Hash, num uint64) *types.Header { 144 if num < r.minBlockNumberAccessed { 145 r.minBlockNumberAccessed = num 146 } 147 return r.bc.GetHeader(hash, num) 148 } 149 150 func (r *RecordingChainContext) GetMinBlockNumberAccessed() uint64 { 151 return r.minBlockNumberAccessed 152 } 153 154 type RecordingDatabase struct { 155 db state.Database 156 bc *core.BlockChain 157 mutex sync.Mutex // protects StateFor and Dereference 158 references int64 159 } 160 161 func NewRecordingDatabase(ethdb ethdb.Database, blockchain *core.BlockChain) *RecordingDatabase { 162 return &RecordingDatabase{ 163 db: state.NewDatabaseWithConfig(ethdb, &trie.Config{Cache: 16}), //TODO cache needed? configurable? 164 bc: blockchain, 165 } 166 } 167 168 // Normal geth state.New + Reference is not atomic vs Dereference. This one is. 169 // This function does not recreate a state 170 func (r *RecordingDatabase) StateFor(header *types.Header) (*state.StateDB, error) { 171 r.mutex.Lock() 172 defer r.mutex.Unlock() 173 174 sdb, err := state.NewDeterministic(header.Root, r.db) 175 if err == nil { 176 r.referenceRootLockHeld(header.Root) 177 } 178 return sdb, err 179 } 180 181 func (r *RecordingDatabase) Dereference(header *types.Header) { 182 if header != nil { 183 r.dereferenceRoot(header.Root) 184 } 185 } 186 187 func (r *RecordingDatabase) WriteStateToDatabase(header *types.Header) error { 188 if header != nil { 189 return r.db.TrieDB().Commit(header.Root, true) 190 } 191 return nil 192 } 193 194 // lock must be held when calling that 195 func (r *RecordingDatabase) referenceRootLockHeld(root common.Hash) { 196 r.references++ 197 r.db.TrieDB().Reference(root, common.Hash{}) 198 } 199 200 func (r *RecordingDatabase) dereferenceRoot(root common.Hash) { 201 r.mutex.Lock() 202 defer r.mutex.Unlock() 203 r.references-- 204 r.db.TrieDB().Dereference(root) 205 } 206 207 func (r *RecordingDatabase) addStateVerify(statedb *state.StateDB, expected common.Hash) error { 208 r.mutex.Lock() 209 defer r.mutex.Unlock() 210 result, err := statedb.Commit(true) 211 if err != nil { 212 return err 213 } 214 if result != expected { 215 return fmt.Errorf("bad root hash expected: %v got: %v", expected, result) 216 } 217 r.referenceRootLockHeld(result) 218 return nil 219 } 220 221 type StateBuildingLogFunction func(targetHeader, header *types.Header, hasState bool) 222 223 func (r *RecordingDatabase) PrepareRecording(ctx context.Context, lastBlockHeader *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, core.ChainContext, *RecordingKV, error) { 224 _, err := r.GetOrRecreateState(ctx, lastBlockHeader, logFunc) 225 if err != nil { 226 return nil, nil, nil, err 227 } 228 finalDereference := lastBlockHeader // dereference in case of error 229 defer func() { r.Dereference(finalDereference) }() 230 recordingKeyValue := newRecordingKV(r.db.TrieDB(), r.db.DiskDB()) 231 232 recordingStateDatabase := state.NewDatabase(rawdb.NewDatabase(recordingKeyValue)) 233 var prevRoot common.Hash 234 if lastBlockHeader != nil { 235 prevRoot = lastBlockHeader.Root 236 } 237 recordingStateDb, err := state.NewDeterministic(prevRoot, recordingStateDatabase) 238 if err != nil { 239 return nil, nil, nil, fmt.Errorf("failed to create recordingStateDb: %w", err) 240 } 241 var recordingChainContext *RecordingChainContext 242 if lastBlockHeader != nil { 243 if !lastBlockHeader.Number.IsUint64() { 244 return nil, nil, nil, errors.New("block number not uint64") 245 } 246 recordingChainContext = newRecordingChainContext(r.bc, lastBlockHeader.Number.Uint64()) 247 } 248 finalDereference = nil 249 return recordingStateDb, recordingChainContext, recordingKeyValue, nil 250 } 251 252 func (r *RecordingDatabase) PreimagesFromRecording(chainContextIf core.ChainContext, recordingDb *RecordingKV) (map[common.Hash][]byte, error) { 253 entries := recordingDb.GetRecordedEntries() 254 recordingChainContext, ok := chainContextIf.(*RecordingChainContext) 255 if (recordingChainContext == nil) || (!ok) { 256 return nil, errors.New("recordingChainContext invalid") 257 } 258 259 for i := recordingChainContext.GetMinBlockNumberAccessed(); i <= recordingChainContext.initialBlockNumber; i++ { 260 header := r.bc.GetHeaderByNumber(i) 261 hash := header.Hash() 262 bytes, err := rlp.EncodeToBytes(header) 263 if err != nil { 264 return nil, fmt.Errorf("Error RLP encoding header: %v\n", err) 265 } 266 entries[hash] = bytes 267 } 268 return entries, nil 269 } 270 271 func (r *RecordingDatabase) GetOrRecreateState(ctx context.Context, header *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, error) { 272 stateDb, err := r.StateFor(header) 273 if err == nil { 274 return stateDb, nil 275 } 276 returnedBlockNumber := header.Number.Uint64() 277 genesis := r.bc.Config().ArbitrumChainParams.GenesisBlockNum 278 currentHeader := header 279 var lastRoot common.Hash 280 for ctx.Err() == nil { 281 if logFunc != nil { 282 logFunc(header, currentHeader, false) 283 } 284 if currentHeader.Number.Uint64() <= genesis { 285 return nil, fmt.Errorf("moved beyond genesis looking for state looking for %d, genesis %d, err %w", returnedBlockNumber, genesis, err) 286 } 287 lastHeader := currentHeader 288 currentHeader = r.bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1) 289 if currentHeader == nil { 290 return nil, fmt.Errorf("chain doesn't contain parent of block %d hash %v (expected parent hash %v)", lastHeader.Number, lastHeader.Hash(), lastHeader.ParentHash) 291 } 292 stateDb, err = r.StateFor(currentHeader) 293 if err == nil { 294 lastRoot = currentHeader.Root 295 break 296 } 297 } 298 defer func() { 299 if (lastRoot != common.Hash{}) { 300 r.dereferenceRoot(lastRoot) 301 } 302 }() 303 blockToRecreate := currentHeader.Number.Uint64() + 1 304 prevHash := currentHeader.Hash() 305 for ctx.Err() == nil { 306 block := r.bc.GetBlockByNumber(blockToRecreate) 307 if block == nil { 308 return nil, fmt.Errorf("block not found while recreating: %d", blockToRecreate) 309 } 310 if block.ParentHash() != prevHash { 311 return nil, fmt.Errorf("reorg detected: number %d expectedPrev: %v foundPrev: %v", blockToRecreate, prevHash, block.ParentHash()) 312 } 313 prevHash = block.Hash() 314 if logFunc != nil { 315 logFunc(header, block.Header(), true) 316 } 317 _, _, _, err := r.bc.Processor().Process(block, stateDb, vm.Config{}) 318 if err != nil { 319 return nil, fmt.Errorf("failed recreating state for block %d : %w", blockToRecreate, err) 320 } 321 err = r.addStateVerify(stateDb, block.Root()) 322 if err != nil { 323 return nil, fmt.Errorf("failed committing state for block %d : %w", blockToRecreate, err) 324 } 325 r.dereferenceRoot(lastRoot) 326 lastRoot = block.Root() 327 if blockToRecreate >= returnedBlockNumber { 328 if block.Hash() != header.Hash() { 329 return nil, fmt.Errorf("blockHash doesn't match when recreating number: %d expected: %v got: %v", blockToRecreate, header.Hash(), block.Hash()) 330 } 331 // don't dereference this one 332 lastRoot = common.Hash{} 333 return stateDb, nil 334 } 335 blockToRecreate++ 336 } 337 return nil, ctx.Err() 338 } 339 340 func (r *RecordingDatabase) ReferenceCount() int64 { 341 return r.references 342 }