github.com/MikyChow/arbitrum-go-ethereum@v0.0.0-20230306102812-078da49636de/arbitrum/recordingdb.go (about) 1 package arbitrum 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/hex" 7 "errors" 8 "fmt" 9 "sync" 10 11 "github.com/MikyChow/arbitrum-go-ethereum/common" 12 "github.com/MikyChow/arbitrum-go-ethereum/consensus" 13 "github.com/MikyChow/arbitrum-go-ethereum/core" 14 "github.com/MikyChow/arbitrum-go-ethereum/core/rawdb" 15 "github.com/MikyChow/arbitrum-go-ethereum/core/state" 16 "github.com/MikyChow/arbitrum-go-ethereum/core/types" 17 "github.com/MikyChow/arbitrum-go-ethereum/core/vm" 18 "github.com/MikyChow/arbitrum-go-ethereum/crypto" 19 "github.com/MikyChow/arbitrum-go-ethereum/ethdb" 20 "github.com/MikyChow/arbitrum-go-ethereum/log" 21 "github.com/MikyChow/arbitrum-go-ethereum/rlp" 22 "github.com/MikyChow/arbitrum-go-ethereum/trie" 23 ) 24 25 type RecordingKV struct { 26 inner *trie.Database 27 readDbEntries map[common.Hash][]byte 28 enableBypass bool 29 } 30 31 func newRecordingKV(inner *trie.Database) *RecordingKV { 32 return &RecordingKV{inner, make(map[common.Hash][]byte), false} 33 } 34 35 func (db *RecordingKV) Has(key []byte) (bool, error) { 36 return false, errors.New("recording KV doesn't support Has") 37 } 38 39 func (db *RecordingKV) Get(key []byte) ([]byte, error) { 40 var hash common.Hash 41 var res []byte 42 var err error 43 if len(key) == 32 { 44 copy(hash[:], key) 45 res, err = db.inner.Node(hash) 46 } else if len(key) == len(rawdb.CodePrefix)+32 && bytes.HasPrefix(key, rawdb.CodePrefix) { 47 // Retrieving code 48 copy(hash[:], key[len(rawdb.CodePrefix):]) 49 res, err = db.inner.DiskDB().Get(key) 50 } else { 51 err = fmt.Errorf("recording KV attempted to access non-hash key %v", hex.EncodeToString(key)) 52 } 53 if err != nil { 54 return nil, err 55 } 56 if db.enableBypass { 57 return res, nil 58 } 59 if crypto.Keccak256Hash(res) != hash { 60 return nil, fmt.Errorf("recording KV attempted to access non-hash key %v", hash) 61 } 62 db.readDbEntries[hash] = res 63 return res, nil 64 } 65 66 func (db *RecordingKV) Put(key []byte, value []byte) error { 67 return errors.New("recording KV doesn't support Put") 68 } 69 70 func (db *RecordingKV) Delete(key []byte) error { 71 return errors.New("recording KV doesn't support Delete") 72 } 73 74 func (db *RecordingKV) NewBatch() ethdb.Batch { 75 if db.enableBypass { 76 return db.inner.DiskDB().NewBatch() 77 } 78 log.Error("recording KV: attempted to create batch when bypass not enabled") 79 return nil 80 } 81 82 func (db *RecordingKV) NewBatchWithSize(size int) ethdb.Batch { 83 if db.enableBypass { 84 return db.inner.DiskDB().NewBatchWithSize(size) 85 } 86 log.Error("recording KV: attempted to create batch when bypass not enabled") 87 return nil 88 } 89 90 func (db *RecordingKV) NewIterator(prefix []byte, start []byte) ethdb.Iterator { 91 if db.enableBypass { 92 return db.inner.DiskDB().NewIterator(prefix, start) 93 } 94 log.Error("recording KV: attempted to create iterator when bypass not enabled") 95 return nil 96 } 97 98 func (db *RecordingKV) NewSnapshot() (ethdb.Snapshot, error) { 99 // This is fine as RecordingKV doesn't support mutation 100 return db, nil 101 } 102 103 func (db *RecordingKV) Stat(property string) (string, error) { 104 return "", errors.New("recording KV doesn't support Stat") 105 } 106 107 func (db *RecordingKV) Compact(start []byte, limit []byte) error { 108 return nil 109 } 110 111 func (db *RecordingKV) Close() error { 112 return nil 113 } 114 115 func (db *RecordingKV) Release() { 116 return 117 } 118 119 func (db *RecordingKV) GetRecordedEntries() map[common.Hash][]byte { 120 return db.readDbEntries 121 } 122 func (db *RecordingKV) EnableBypass() { 123 db.enableBypass = true 124 } 125 126 type RecordingChainContext struct { 127 bc core.ChainContext 128 minBlockNumberAccessed uint64 129 initialBlockNumber uint64 130 } 131 132 func newRecordingChainContext(inner core.ChainContext, blocknumber uint64) *RecordingChainContext { 133 return &RecordingChainContext{ 134 bc: inner, 135 minBlockNumberAccessed: blocknumber, 136 initialBlockNumber: blocknumber, 137 } 138 } 139 140 func (r *RecordingChainContext) Engine() consensus.Engine { 141 return r.bc.Engine() 142 } 143 144 func (r *RecordingChainContext) GetHeader(hash common.Hash, num uint64) *types.Header { 145 if num < r.minBlockNumberAccessed { 146 r.minBlockNumberAccessed = num 147 } 148 return r.bc.GetHeader(hash, num) 149 } 150 151 func (r *RecordingChainContext) GetMinBlockNumberAccessed() uint64 { 152 return r.minBlockNumberAccessed 153 } 154 155 type RecordingDatabase struct { 156 db state.Database 157 bc *core.BlockChain 158 mutex sync.Mutex // protects StateFor and Dereference 159 references int64 160 } 161 162 func NewRecordingDatabase(ethdb ethdb.Database, blockchain *core.BlockChain) *RecordingDatabase { 163 return &RecordingDatabase{ 164 db: state.NewDatabaseWithConfig(ethdb, &trie.Config{Cache: 16}), //TODO cache needed? configurable? 165 bc: blockchain, 166 } 167 } 168 169 // Normal geth state.New + Reference is not atomic vs Dereference. This one is. 170 // This function does not recreate a state 171 func (r *RecordingDatabase) StateFor(header *types.Header) (*state.StateDB, error) { 172 r.mutex.Lock() 173 defer r.mutex.Unlock() 174 175 sdb, err := state.NewDeterministic(header.Root, r.db) 176 if err == nil { 177 r.referenceRootLockHeld(header.Root) 178 } 179 return sdb, err 180 } 181 182 func (r *RecordingDatabase) Dereference(header *types.Header) { 183 if header != nil { 184 r.dereferenceRoot(header.Root) 185 } 186 } 187 188 func (r *RecordingDatabase) WriteStateToDatabase(header *types.Header) error { 189 if header != nil { 190 return r.db.TrieDB().Commit(header.Root, true, nil) 191 } 192 return nil 193 } 194 195 // lock must be held when calling that 196 func (r *RecordingDatabase) referenceRootLockHeld(root common.Hash) { 197 r.references++ 198 r.db.TrieDB().Reference(root, common.Hash{}) 199 } 200 201 func (r *RecordingDatabase) dereferenceRoot(root common.Hash) { 202 r.mutex.Lock() 203 defer r.mutex.Unlock() 204 r.references-- 205 r.db.TrieDB().Dereference(root) 206 } 207 208 func (r *RecordingDatabase) addStateVerify(statedb *state.StateDB, expected common.Hash) error { 209 r.mutex.Lock() 210 defer r.mutex.Unlock() 211 result, err := statedb.Commit(true) 212 if err != nil { 213 return err 214 } 215 if result != expected { 216 return fmt.Errorf("bad root hash expected: %v got: %v", expected, result) 217 } 218 r.referenceRootLockHeld(result) 219 return nil 220 } 221 222 type StateBuildingLogFunction func(targetHeader, header *types.Header, hasState bool) 223 224 func (r *RecordingDatabase) PrepareRecording(ctx context.Context, lastBlockHeader *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, core.ChainContext, *RecordingKV, error) { 225 _, err := r.GetOrRecreateState(ctx, lastBlockHeader, logFunc) 226 if err != nil { 227 return nil, nil, nil, err 228 } 229 finalDereference := lastBlockHeader // dereference in case of error 230 defer func() { r.Dereference(finalDereference) }() 231 recordingKeyValue := newRecordingKV(r.db.TrieDB()) 232 233 recordingStateDatabase := state.NewDatabase(rawdb.NewDatabase(recordingKeyValue)) 234 var prevRoot common.Hash 235 if lastBlockHeader != nil { 236 prevRoot = lastBlockHeader.Root 237 } 238 recordingStateDb, err := state.NewDeterministic(prevRoot, recordingStateDatabase) 239 if err != nil { 240 return nil, nil, nil, fmt.Errorf("failed to create recordingStateDb: %w", err) 241 } 242 var recordingChainContext *RecordingChainContext 243 if lastBlockHeader != nil { 244 if !lastBlockHeader.Number.IsUint64() { 245 return nil, nil, nil, errors.New("block number not uint64") 246 } 247 recordingChainContext = newRecordingChainContext(r.bc, lastBlockHeader.Number.Uint64()) 248 } 249 finalDereference = nil 250 return recordingStateDb, recordingChainContext, recordingKeyValue, nil 251 } 252 253 func (r *RecordingDatabase) PreimagesFromRecording(chainContextIf core.ChainContext, recordingDb *RecordingKV) (map[common.Hash][]byte, error) { 254 entries := recordingDb.GetRecordedEntries() 255 recordingChainContext, ok := chainContextIf.(*RecordingChainContext) 256 if (recordingChainContext == nil) || (!ok) { 257 return nil, errors.New("recordingChainContext invalid") 258 } 259 260 for i := recordingChainContext.GetMinBlockNumberAccessed(); i <= recordingChainContext.initialBlockNumber; i++ { 261 header := r.bc.GetHeaderByNumber(i) 262 hash := header.Hash() 263 bytes, err := rlp.EncodeToBytes(header) 264 if err != nil { 265 return nil, fmt.Errorf("Error RLP encoding header: %v\n", err) 266 } 267 entries[hash] = bytes 268 } 269 return entries, nil 270 } 271 272 func (r *RecordingDatabase) GetOrRecreateState(ctx context.Context, header *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, error) { 273 stateDb, err := r.StateFor(header) 274 if err == nil { 275 return stateDb, nil 276 } 277 returnedBlockNumber := header.Number.Uint64() 278 genesis := r.bc.Config().ArbitrumChainParams.GenesisBlockNum 279 currentHeader := header 280 var lastRoot common.Hash 281 for ctx.Err() == nil { 282 if logFunc != nil { 283 logFunc(header, currentHeader, false) 284 } 285 if currentHeader.Number.Uint64() <= genesis { 286 return nil, fmt.Errorf("moved beyond genesis looking for state looking for %d, genesis %d, err %w", returnedBlockNumber, genesis, err) 287 } 288 lastHeader := currentHeader 289 currentHeader = r.bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1) 290 if currentHeader == nil { 291 return nil, fmt.Errorf("chain doesn't contain parent of block %d hash %v (expected parent hash %v)", lastHeader.Number, lastHeader.Hash(), lastHeader.ParentHash) 292 } 293 stateDb, err = r.StateFor(currentHeader) 294 if err == nil { 295 lastRoot = currentHeader.Root 296 break 297 } 298 } 299 defer func() { 300 if (lastRoot != common.Hash{}) { 301 r.dereferenceRoot(lastRoot) 302 } 303 }() 304 blockToRecreate := currentHeader.Number.Uint64() + 1 305 prevHash := currentHeader.Hash() 306 for ctx.Err() == nil { 307 block := r.bc.GetBlockByNumber(blockToRecreate) 308 if block == nil { 309 return nil, fmt.Errorf("block not found while recreating: %d", blockToRecreate) 310 } 311 if block.ParentHash() != prevHash { 312 return nil, fmt.Errorf("reorg detected: number %d expectedPrev: %v foundPrev: %v", blockToRecreate, prevHash, block.ParentHash()) 313 } 314 prevHash = block.Hash() 315 if logFunc != nil { 316 logFunc(header, block.Header(), true) 317 } 318 _, _, _, err := r.bc.Processor().Process(block, stateDb, vm.Config{}) 319 if err != nil { 320 return nil, fmt.Errorf("failed recreating state for block %d : %w", blockToRecreate, err) 321 } 322 err = r.addStateVerify(stateDb, block.Root()) 323 if err != nil { 324 return nil, fmt.Errorf("failed commiting state for block %d : %w", blockToRecreate, err) 325 } 326 r.dereferenceRoot(lastRoot) 327 lastRoot = block.Root() 328 if blockToRecreate >= returnedBlockNumber { 329 if block.Hash() != header.Hash() { 330 return nil, fmt.Errorf("blockHash doesn't match when recreating number: %d expected: %v got: %v", blockToRecreate, header.Hash(), block.Hash()) 331 } 332 // don't dereference this one 333 lastRoot = common.Hash{} 334 return stateDb, nil 335 } 336 blockToRecreate++ 337 } 338 return nil, ctx.Err() 339 } 340 341 func (r *RecordingDatabase) ReferenceCount() int64 { 342 return r.references 343 }