github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/fvm/evm/emulator/state/stateDB.go (about)

     1  package state
     2  
     3  import (
     4  	"bytes"
     5  	stdErrors "errors"
     6  	"fmt"
     7  	"math/big"
     8  	"sort"
     9  
    10  	"github.com/onflow/atree"
    11  	gethCommon "github.com/onflow/go-ethereum/common"
    12  	gethTypes "github.com/onflow/go-ethereum/core/types"
    13  	gethParams "github.com/onflow/go-ethereum/params"
    14  
    15  	"github.com/onflow/flow-go/fvm/evm/types"
    16  	"github.com/onflow/flow-go/model/flow"
    17  )
    18  
    19  // StateDB implements a types.StateDB interface
    20  //
    21  // stateDB interface defined by the Geth doesn't support returning errors
    22  // when state calls are happening, and requires stateDB to cache the error
    23  // and return it at a later time (when commit is called). Only the first error
    24  // is expected to be returned.
    25  // Warning: current implementation of the StateDB is considered
    26  // to be used for a single EVM transaction execution and is not
    27  // thread safe. yet the current design supports addition of concurrency in the
    28  // future if needed
    29  type StateDB struct {
    30  	ledger      atree.Ledger
    31  	root        flow.Address
    32  	baseView    types.BaseView
    33  	views       []*DeltaView
    34  	cachedError error
    35  }
    36  
    37  var _ types.StateDB = &StateDB{}
    38  
    39  // NewStateDB constructs a new StateDB
    40  func NewStateDB(ledger atree.Ledger, root flow.Address) (*StateDB, error) {
    41  	bv, err := NewBaseView(ledger, root)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	return &StateDB{
    46  		ledger:      ledger,
    47  		root:        root,
    48  		baseView:    bv,
    49  		views:       []*DeltaView{NewDeltaView(bv)},
    50  		cachedError: nil,
    51  	}, nil
    52  }
    53  
    54  // Exist returns true if the given address exists in state.
    55  //
    56  // this should also return true for self destructed accounts during the transaction execution.
    57  func (db *StateDB) Exist(addr gethCommon.Address) bool {
    58  	exist, err := db.lastestView().Exist(addr)
    59  	db.handleError(err)
    60  	return exist
    61  }
    62  
    63  // Empty returns whether the given account is empty.
    64  //
    65  // Empty is defined according to EIP161 (balance = nonce = code = 0).
    66  func (db *StateDB) Empty(addr gethCommon.Address) bool {
    67  	if !db.Exist(addr) {
    68  		return true
    69  	}
    70  	return db.GetNonce(addr) == 0 &&
    71  		db.GetBalance(addr).Sign() == 0 &&
    72  		bytes.Equal(db.GetCodeHash(addr).Bytes(), gethTypes.EmptyCodeHash.Bytes())
    73  }
    74  
    75  // CreateAccount creates a new account for the given address
    76  // it sets the nonce to zero
    77  func (db *StateDB) CreateAccount(addr gethCommon.Address) {
    78  	err := db.lastestView().CreateAccount(addr)
    79  	db.handleError(err)
    80  }
    81  
    82  // IsCreated returns true if address is recently created (context of a transaction)
    83  func (db *StateDB) IsCreated(addr gethCommon.Address) bool {
    84  	return db.lastestView().IsCreated(addr)
    85  }
    86  
    87  // SelfDestruct flags the address for deletion.
    88  //
    89  // while this address exists for the rest of transaction,
    90  // the balance of this account is return zero after the SelfDestruct call.
    91  func (db *StateDB) SelfDestruct(addr gethCommon.Address) {
    92  	err := db.lastestView().SelfDestruct(addr)
    93  	db.handleError(err)
    94  }
    95  
    96  // Selfdestruct6780 would only follow the self destruct steps if account is created
    97  func (db *StateDB) Selfdestruct6780(addr gethCommon.Address) {
    98  	if db.IsCreated(addr) {
    99  		db.SelfDestruct(addr)
   100  	}
   101  }
   102  
   103  // HasSelfDestructed returns true if address is flaged with self destruct.
   104  func (db *StateDB) HasSelfDestructed(addr gethCommon.Address) bool {
   105  	destructed, _ := db.lastestView().HasSelfDestructed(addr)
   106  	return destructed
   107  }
   108  
   109  // SubBalance substitutes the amount from the balance of the given address
   110  func (db *StateDB) SubBalance(addr gethCommon.Address, amount *big.Int) {
   111  	err := db.lastestView().SubBalance(addr, amount)
   112  	db.handleError(err)
   113  }
   114  
   115  // AddBalance adds the amount from the balance of the given address
   116  func (db *StateDB) AddBalance(addr gethCommon.Address, amount *big.Int) {
   117  	err := db.lastestView().AddBalance(addr, amount)
   118  	db.handleError(err)
   119  }
   120  
   121  // GetBalance returns the balance of the given address
   122  func (db *StateDB) GetBalance(addr gethCommon.Address) *big.Int {
   123  	bal, err := db.lastestView().GetBalance(addr)
   124  	db.handleError(err)
   125  	return bal
   126  }
   127  
   128  // GetNonce returns the nonce of the given address
   129  func (db *StateDB) GetNonce(addr gethCommon.Address) uint64 {
   130  	nonce, err := db.lastestView().GetNonce(addr)
   131  	db.handleError(err)
   132  	return nonce
   133  }
   134  
   135  // SetNonce sets the nonce value for the given address
   136  func (db *StateDB) SetNonce(addr gethCommon.Address, nonce uint64) {
   137  	err := db.lastestView().SetNonce(addr, nonce)
   138  	db.handleError(err)
   139  }
   140  
   141  // GetCodeHash returns the code hash of the given address
   142  func (db *StateDB) GetCodeHash(addr gethCommon.Address) gethCommon.Hash {
   143  	hash, err := db.lastestView().GetCodeHash(addr)
   144  	db.handleError(err)
   145  	return hash
   146  }
   147  
   148  // GetCode returns the code for the given address
   149  func (db *StateDB) GetCode(addr gethCommon.Address) []byte {
   150  	code, err := db.lastestView().GetCode(addr)
   151  	db.handleError(err)
   152  	return code
   153  }
   154  
   155  // GetCodeSize returns the size of the code for the given address
   156  func (db *StateDB) GetCodeSize(addr gethCommon.Address) int {
   157  	codeSize, err := db.lastestView().GetCodeSize(addr)
   158  	db.handleError(err)
   159  	return codeSize
   160  }
   161  
   162  // SetCode sets the code for the given address
   163  func (db *StateDB) SetCode(addr gethCommon.Address, code []byte) {
   164  	err := db.lastestView().SetCode(addr, code)
   165  	db.handleError(err)
   166  }
   167  
   168  // AddRefund adds the amount to the total (gas) refund
   169  func (db *StateDB) AddRefund(amount uint64) {
   170  	err := db.lastestView().AddRefund(amount)
   171  	db.handleError(err)
   172  }
   173  
   174  // SubRefund subtracts the amount from the total (gas) refund
   175  func (db *StateDB) SubRefund(amount uint64) {
   176  	err := db.lastestView().SubRefund(amount)
   177  	db.handleError(err)
   178  }
   179  
   180  // GetRefund returns the total (gas) refund
   181  func (db *StateDB) GetRefund() uint64 {
   182  	return db.lastestView().GetRefund()
   183  }
   184  
   185  // GetCommittedState returns the value for the given storage slot considering only the commited state and not
   186  // changes in the scope of current transaction.
   187  func (db *StateDB) GetCommittedState(addr gethCommon.Address, key gethCommon.Hash) gethCommon.Hash {
   188  	value, err := db.baseView.GetState(types.SlotAddress{Address: addr, Key: key})
   189  	db.handleError(err)
   190  	return value
   191  }
   192  
   193  // GetState returns the value for the given storage slot
   194  func (db *StateDB) GetState(addr gethCommon.Address, key gethCommon.Hash) gethCommon.Hash {
   195  	state, err := db.lastestView().GetState(types.SlotAddress{Address: addr, Key: key})
   196  	db.handleError(err)
   197  	return state
   198  }
   199  
   200  // SetState sets a value for the given storage slot
   201  func (db *StateDB) SetState(addr gethCommon.Address, key gethCommon.Hash, value gethCommon.Hash) {
   202  	err := db.lastestView().SetState(types.SlotAddress{Address: addr, Key: key}, value)
   203  	db.handleError(err)
   204  }
   205  
   206  // GetTransientState returns the value for the given key of the transient storage
   207  func (db *StateDB) GetTransientState(addr gethCommon.Address, key gethCommon.Hash) gethCommon.Hash {
   208  	return db.lastestView().GetTransientState(types.SlotAddress{Address: addr, Key: key})
   209  }
   210  
   211  // SetTransientState sets a value for the given key of the transient storage
   212  func (db *StateDB) SetTransientState(addr gethCommon.Address, key, value gethCommon.Hash) {
   213  	db.lastestView().SetTransientState(types.SlotAddress{Address: addr, Key: key}, value)
   214  }
   215  
   216  // AddressInAccessList checks if an address is in the access list
   217  func (db *StateDB) AddressInAccessList(addr gethCommon.Address) bool {
   218  	return db.lastestView().AddressInAccessList(addr)
   219  }
   220  
   221  // SlotInAccessList checks if the given (address,slot) is in the access list
   222  func (db *StateDB) SlotInAccessList(addr gethCommon.Address, key gethCommon.Hash) (addressOk bool, slotOk bool) {
   223  	return db.lastestView().SlotInAccessList(types.SlotAddress{Address: addr, Key: key})
   224  }
   225  
   226  // AddAddressToAccessList adds the given address to the access list.
   227  func (db *StateDB) AddAddressToAccessList(addr gethCommon.Address) {
   228  	db.lastestView().AddAddressToAccessList(addr)
   229  }
   230  
   231  // AddSlotToAccessList adds the given (address,slot) to the access list.
   232  func (db *StateDB) AddSlotToAccessList(addr gethCommon.Address, key gethCommon.Hash) {
   233  	db.lastestView().AddSlotToAccessList(types.SlotAddress{Address: addr, Key: key})
   234  }
   235  
   236  // AddLog appends a lot to the collection of logs
   237  func (db *StateDB) AddLog(log *gethTypes.Log) {
   238  	db.lastestView().AddLog(log)
   239  }
   240  
   241  // AddPreimage adds a preimage to the collection of preimages
   242  func (db *StateDB) AddPreimage(hash gethCommon.Hash, data []byte) {
   243  	db.lastestView().AddPreimage(hash, data)
   244  }
   245  
   246  // RevertToSnapshot reverts the changes until we reach the given snaptshot
   247  func (db *StateDB) RevertToSnapshot(index int) {
   248  	if index > len(db.views) {
   249  		db.cachedError = fmt.Errorf("invalid revert")
   250  		return
   251  	}
   252  	db.views = db.views[:index]
   253  }
   254  
   255  // Snapshot takes an snapshot of the state and returns an int
   256  // that can be used later for revert calls.
   257  func (db *StateDB) Snapshot() int {
   258  	newView := db.lastestView().NewChildView()
   259  	db.views = append(db.views, newView)
   260  	return len(db.views) - 1
   261  }
   262  
   263  // Logs returns the list of logs
   264  // it also update each log with the block and tx info
   265  func (db *StateDB) Logs(
   266  	blockNumber uint64,
   267  	txHash gethCommon.Hash,
   268  	txIndex uint,
   269  ) []*gethTypes.Log {
   270  	allLogs := make([]*gethTypes.Log, 0)
   271  	for _, view := range db.views {
   272  		for _, log := range view.Logs() {
   273  			log.BlockNumber = blockNumber
   274  			log.TxHash = txHash
   275  			log.TxIndex = txIndex
   276  			allLogs = append(allLogs, log)
   277  		}
   278  	}
   279  	return allLogs
   280  }
   281  
   282  // Preimages returns a set of preimages
   283  func (db *StateDB) Preimages() map[gethCommon.Hash][]byte {
   284  	preImages := make(map[gethCommon.Hash][]byte, 0)
   285  	for _, view := range db.views {
   286  		for k, v := range view.Preimages() {
   287  			preImages[k] = v
   288  		}
   289  	}
   290  	return preImages
   291  }
   292  
   293  // Commit commits state changes back to the underlying
   294  func (db *StateDB) Commit(finalize bool) error {
   295  	// return error if any has been acumulated
   296  	if db.cachedError != nil {
   297  		return wrapError(db.cachedError)
   298  	}
   299  
   300  	var err error
   301  
   302  	// iterate views and collect dirty addresses and slots
   303  	addresses := make(map[gethCommon.Address]struct{})
   304  	slots := make(map[types.SlotAddress]struct{})
   305  	for _, view := range db.views {
   306  		for key := range view.DirtyAddresses() {
   307  			addresses[key] = struct{}{}
   308  		}
   309  		for key := range view.DirtySlots() {
   310  			slots[key] = struct{}{}
   311  		}
   312  	}
   313  
   314  	// sort addresses
   315  	sortedAddresses := make([]gethCommon.Address, 0, len(addresses))
   316  	for addr := range addresses {
   317  		sortedAddresses = append(sortedAddresses, addr)
   318  	}
   319  
   320  	sort.Slice(sortedAddresses,
   321  		func(i, j int) bool {
   322  			return bytes.Compare(sortedAddresses[i][:], sortedAddresses[j][:]) < 0
   323  		})
   324  
   325  	// update accounts
   326  	for _, addr := range sortedAddresses {
   327  		deleted := false
   328  		// first we need to delete accounts
   329  		if db.HasSelfDestructed(addr) {
   330  			err = db.baseView.DeleteAccount(addr)
   331  			if err != nil {
   332  				return wrapError(err)
   333  			}
   334  			deleted = true
   335  		}
   336  		// then create new ones
   337  		// an account might be in a single transaction be deleted and recreated
   338  		if db.IsCreated(addr) {
   339  			err = db.baseView.CreateAccount(
   340  				addr,
   341  				db.GetBalance(addr),
   342  				db.GetNonce(addr),
   343  				db.GetCode(addr),
   344  				db.GetCodeHash(addr),
   345  			)
   346  			if err != nil {
   347  				return wrapError(err)
   348  			}
   349  			continue
   350  		}
   351  		if deleted {
   352  			continue
   353  		}
   354  		err = db.baseView.UpdateAccount(
   355  			addr,
   356  			db.GetBalance(addr),
   357  			db.GetNonce(addr),
   358  			db.GetCode(addr),
   359  			db.GetCodeHash(addr),
   360  		)
   361  		if err != nil {
   362  			return wrapError(err)
   363  		}
   364  	}
   365  
   366  	// sort slots
   367  	sortedSlots := make([]types.SlotAddress, 0, len(slots))
   368  	for slot := range slots {
   369  		sortedSlots = append(sortedSlots, slot)
   370  	}
   371  	sort.Slice(sortedSlots, func(i, j int) bool {
   372  		comp := bytes.Compare(sortedSlots[i].Address[:], sortedSlots[j].Address[:])
   373  		if comp == 0 {
   374  			return bytes.Compare(sortedSlots[i].Key[:], sortedSlots[j].Key[:]) < 0
   375  		}
   376  		return comp < 0
   377  	})
   378  
   379  	// update slots
   380  	for _, sk := range sortedSlots {
   381  		err = db.baseView.UpdateSlot(
   382  			sk,
   383  			db.GetState(sk.Address, sk.Key),
   384  		)
   385  		if err != nil {
   386  			return wrapError(err)
   387  		}
   388  	}
   389  
   390  	// don't purge views yet, people might call the logs etc
   391  	if finalize {
   392  		return db.Finalize()
   393  	}
   394  	return nil
   395  }
   396  
   397  // Finalize flushes all the changes
   398  // to the permanent storage
   399  func (db *StateDB) Finalize() error {
   400  	err := db.baseView.Commit()
   401  	return wrapError(err)
   402  }
   403  
   404  // Prepare is a highlevel logic that sadly is considered to be part of the
   405  // stateDB interface and not on the layers above.
   406  // based on parameters that are passed it updates accesslists
   407  func (db *StateDB) Prepare(rules gethParams.Rules, sender, coinbase gethCommon.Address, dest *gethCommon.Address, precompiles []gethCommon.Address, txAccesses gethTypes.AccessList) {
   408  	if rules.IsBerlin {
   409  		db.AddAddressToAccessList(sender)
   410  
   411  		if dest != nil {
   412  			db.AddAddressToAccessList(*dest)
   413  			// If it's a create-tx, the destination will be added inside egethVM.create
   414  		}
   415  		for _, addr := range precompiles {
   416  			db.AddAddressToAccessList(addr)
   417  		}
   418  		for _, el := range txAccesses {
   419  			db.AddAddressToAccessList(el.Address)
   420  			for _, key := range el.StorageKeys {
   421  				db.AddSlotToAccessList(el.Address, key)
   422  			}
   423  		}
   424  		if rules.IsShanghai { // EIP-3651: warm coinbase
   425  			db.AddAddressToAccessList(coinbase)
   426  		}
   427  	}
   428  }
   429  
   430  // Reset resets uncommitted changes and transient artifacts such as error, logs,
   431  // preimages, access lists, ...
   432  // The method is often called between execution of different transactions
   433  func (db *StateDB) Reset() {
   434  	db.views = []*DeltaView{NewDeltaView(db.baseView)}
   435  	db.cachedError = nil
   436  }
   437  
   438  // Error returns the memorized database failure occurred earlier.
   439  func (s *StateDB) Error() error {
   440  	return wrapError(s.cachedError)
   441  }
   442  
   443  func (db *StateDB) lastestView() *DeltaView {
   444  	return db.views[len(db.views)-1]
   445  }
   446  
   447  // set error captures the first non-nil error it is called with.
   448  func (db *StateDB) handleError(err error) {
   449  	if err == nil {
   450  		return
   451  	}
   452  	if db.cachedError == nil {
   453  		db.cachedError = err
   454  	}
   455  }
   456  
   457  func wrapError(err error) error {
   458  	if err == nil {
   459  		return nil
   460  	}
   461  
   462  	var atreeUserError *atree.UserError
   463  	// if is an atree user error
   464  	if stdErrors.As(err, &atreeUserError) {
   465  		return types.NewStateError(err)
   466  	}
   467  
   468  	var atreeFatalError *atree.FatalError
   469  	// if is a atree fatal error or
   470  	if stdErrors.As(err, &atreeFatalError) {
   471  		return types.NewFatalError(err)
   472  	}
   473  
   474  	// if is a fatal error
   475  	if types.IsAFatalError(err) {
   476  		return err
   477  	}
   478  
   479  	return types.NewStateError(err)
   480  }