github.com/onflow/flow-go@v0.33.17/fvm/evm/emulator/state/stateDB.go (about)

     1  package state
     2  
     3  import (
     4  	"bytes"
     5  	stdErrors "errors"
     6  	"fmt"
     7  	"math/big"
     8  	"sort"
     9  
    10  	gethCommon "github.com/ethereum/go-ethereum/common"
    11  	gethTypes "github.com/ethereum/go-ethereum/core/types"
    12  	gethParams "github.com/ethereum/go-ethereum/params"
    13  	"github.com/onflow/atree"
    14  
    15  	"github.com/onflow/flow-go/fvm/errors"
    16  	"github.com/onflow/flow-go/fvm/evm/types"
    17  	"github.com/onflow/flow-go/model/flow"
    18  )
    19  
    20  // StateDB implements a types.StateDB interface
    21  //
    22  // stateDB interface defined by the Geth doesn't support returning errors
    23  // when state calls are happening, and requires stateDB to cache the error
    24  // and return it at a later time (when commit is called). Only the first error
    25  // is expected to be returned.
    26  // Warning: current implementation of the StateDB is considered
    27  // to be used for a single EVM transaction execution and is not
    28  // thread safe. yet the current design supports addition of concurrency in the
    29  // future if needed
    30  type StateDB struct {
    31  	ledger      atree.Ledger
    32  	root        flow.Address
    33  	baseView    types.BaseView
    34  	views       []*DeltaView
    35  	cachedError error
    36  }
    37  
    38  var _ types.StateDB = &StateDB{}
    39  
    40  // NewStateDB constructs a new StateDB
    41  func NewStateDB(ledger atree.Ledger, root flow.Address) (*StateDB, error) {
    42  	bv, err := NewBaseView(ledger, root)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	return &StateDB{
    47  		ledger:      ledger,
    48  		root:        root,
    49  		baseView:    bv,
    50  		views:       []*DeltaView{NewDeltaView(bv)},
    51  		cachedError: nil,
    52  	}, nil
    53  }
    54  
    55  // Exist returns true if the given address exists in state.
    56  //
    57  // this should also return true for self destructed accounts during the transaction execution.
    58  func (db *StateDB) Exist(addr gethCommon.Address) bool {
    59  	exist, err := db.lastestView().Exist(addr)
    60  	db.handleError(err)
    61  	return exist
    62  }
    63  
    64  // Empty returns whether the given account is empty.
    65  //
    66  // Empty is defined according to EIP161 (balance = nonce = code = 0).
    67  func (db *StateDB) Empty(addr gethCommon.Address) bool {
    68  	if !db.Exist(addr) {
    69  		return true
    70  	}
    71  	return db.GetNonce(addr) == 0 &&
    72  		db.GetBalance(addr).Sign() == 0 &&
    73  		bytes.Equal(db.GetCodeHash(addr).Bytes(), gethTypes.EmptyCodeHash.Bytes())
    74  }
    75  
    76  // CreateAccount creates a new account for the given address
    77  // it sets the nonce to zero
    78  func (db *StateDB) CreateAccount(addr gethCommon.Address) {
    79  	err := db.lastestView().CreateAccount(addr)
    80  	db.handleError(err)
    81  }
    82  
    83  // IsCreated returns true if address is recently created (context of a transaction)
    84  func (db *StateDB) IsCreated(addr gethCommon.Address) bool {
    85  	return db.lastestView().IsCreated(addr)
    86  }
    87  
    88  // SelfDestruct flags the address for deletion.
    89  //
    90  // while this address exists for the rest of transaction,
    91  // the balance of this account is return zero after the SelfDestruct call.
    92  func (db *StateDB) SelfDestruct(addr gethCommon.Address) {
    93  	err := db.lastestView().SelfDestruct(addr)
    94  	db.handleError(err)
    95  }
    96  
    97  // Selfdestruct6780 would only follow the self destruct steps if account is created
    98  func (db *StateDB) Selfdestruct6780(addr gethCommon.Address) {
    99  	if db.IsCreated(addr) {
   100  		db.SelfDestruct(addr)
   101  	}
   102  }
   103  
   104  // HasSelfDestructed returns true if address is flaged with self destruct.
   105  func (db *StateDB) HasSelfDestructed(addr gethCommon.Address) bool {
   106  	destructed, _ := db.lastestView().HasSelfDestructed(addr)
   107  	return destructed
   108  }
   109  
   110  // SubBalance substitutes the amount from the balance of the given address
   111  func (db *StateDB) SubBalance(addr gethCommon.Address, amount *big.Int) {
   112  	err := db.lastestView().SubBalance(addr, amount)
   113  	db.handleError(err)
   114  }
   115  
   116  // AddBalance adds the amount from the balance of the given address
   117  func (db *StateDB) AddBalance(addr gethCommon.Address, amount *big.Int) {
   118  	err := db.lastestView().AddBalance(addr, amount)
   119  	db.handleError(err)
   120  }
   121  
   122  // GetBalance returns the balance of the given address
   123  func (db *StateDB) GetBalance(addr gethCommon.Address) *big.Int {
   124  	bal, err := db.lastestView().GetBalance(addr)
   125  	db.handleError(err)
   126  	return bal
   127  }
   128  
   129  // GetNonce returns the nonce of the given address
   130  func (db *StateDB) GetNonce(addr gethCommon.Address) uint64 {
   131  	nonce, err := db.lastestView().GetNonce(addr)
   132  	db.handleError(err)
   133  	return nonce
   134  }
   135  
   136  // SetNonce sets the nonce value for the given address
   137  func (db *StateDB) SetNonce(addr gethCommon.Address, nonce uint64) {
   138  	err := db.lastestView().SetNonce(addr, nonce)
   139  	db.handleError(err)
   140  }
   141  
   142  // GetCodeHash returns the code hash of the given address
   143  func (db *StateDB) GetCodeHash(addr gethCommon.Address) gethCommon.Hash {
   144  	hash, err := db.lastestView().GetCodeHash(addr)
   145  	db.handleError(err)
   146  	return hash
   147  }
   148  
   149  // GetCode returns the code for the given address
   150  func (db *StateDB) GetCode(addr gethCommon.Address) []byte {
   151  	code, err := db.lastestView().GetCode(addr)
   152  	db.handleError(err)
   153  	return code
   154  }
   155  
   156  // GetCodeSize returns the size of the code for the given address
   157  func (db *StateDB) GetCodeSize(addr gethCommon.Address) int {
   158  	codeSize, err := db.lastestView().GetCodeSize(addr)
   159  	db.handleError(err)
   160  	return codeSize
   161  }
   162  
   163  // SetCode sets the code for the given address
   164  func (db *StateDB) SetCode(addr gethCommon.Address, code []byte) {
   165  	err := db.lastestView().SetCode(addr, code)
   166  	db.handleError(err)
   167  }
   168  
   169  // AddRefund adds the amount to the total (gas) refund
   170  func (db *StateDB) AddRefund(amount uint64) {
   171  	err := db.lastestView().AddRefund(amount)
   172  	db.handleError(err)
   173  }
   174  
   175  // SubRefund subtracts the amount from the total (gas) refund
   176  func (db *StateDB) SubRefund(amount uint64) {
   177  	err := db.lastestView().SubRefund(amount)
   178  	db.handleError(err)
   179  }
   180  
   181  // GetRefund returns the total (gas) refund
   182  func (db *StateDB) GetRefund() uint64 {
   183  	return db.lastestView().GetRefund()
   184  }
   185  
   186  // GetCommittedState returns the value for the given storage slot considering only the commited state and not
   187  // changes in the scope of current transaction.
   188  func (db *StateDB) GetCommittedState(addr gethCommon.Address, key gethCommon.Hash) gethCommon.Hash {
   189  	value, err := db.baseView.GetState(types.SlotAddress{Address: addr, Key: key})
   190  	db.handleError(err)
   191  	return value
   192  }
   193  
   194  // GetState returns the value for the given storage slot
   195  func (db *StateDB) GetState(addr gethCommon.Address, key gethCommon.Hash) gethCommon.Hash {
   196  	state, err := db.lastestView().GetState(types.SlotAddress{Address: addr, Key: key})
   197  	db.handleError(err)
   198  	return state
   199  }
   200  
   201  // SetState sets a value for the given storage slot
   202  func (db *StateDB) SetState(addr gethCommon.Address, key gethCommon.Hash, value gethCommon.Hash) {
   203  	err := db.lastestView().SetState(types.SlotAddress{Address: addr, Key: key}, value)
   204  	db.handleError(err)
   205  }
   206  
   207  // GetTransientState returns the value for the given key of the transient storage
   208  func (db *StateDB) GetTransientState(addr gethCommon.Address, key gethCommon.Hash) gethCommon.Hash {
   209  	return db.lastestView().GetTransientState(types.SlotAddress{Address: addr, Key: key})
   210  }
   211  
   212  // SetTransientState sets a value for the given key of the transient storage
   213  func (db *StateDB) SetTransientState(addr gethCommon.Address, key, value gethCommon.Hash) {
   214  	db.lastestView().SetTransientState(types.SlotAddress{Address: addr, Key: key}, value)
   215  }
   216  
   217  // AddressInAccessList checks if an address is in the access list
   218  func (db *StateDB) AddressInAccessList(addr gethCommon.Address) bool {
   219  	return db.lastestView().AddressInAccessList(addr)
   220  }
   221  
   222  // SlotInAccessList checks if the given (address,slot) is in the access list
   223  func (db *StateDB) SlotInAccessList(addr gethCommon.Address, key gethCommon.Hash) (addressOk bool, slotOk bool) {
   224  	return db.lastestView().SlotInAccessList(types.SlotAddress{Address: addr, Key: key})
   225  }
   226  
   227  // AddAddressToAccessList adds the given address to the access list.
   228  func (db *StateDB) AddAddressToAccessList(addr gethCommon.Address) {
   229  	db.lastestView().AddAddressToAccessList(addr)
   230  }
   231  
   232  // AddSlotToAccessList adds the given (address,slot) to the access list.
   233  func (db *StateDB) AddSlotToAccessList(addr gethCommon.Address, key gethCommon.Hash) {
   234  	db.lastestView().AddSlotToAccessList(types.SlotAddress{Address: addr, Key: key})
   235  }
   236  
   237  // AddLog appends a lot to the collection of logs
   238  func (db *StateDB) AddLog(log *gethTypes.Log) {
   239  	db.lastestView().AddLog(log)
   240  }
   241  
   242  // AddPreimage adds a preimage to the collection of preimages
   243  func (db *StateDB) AddPreimage(hash gethCommon.Hash, data []byte) {
   244  	db.lastestView().AddPreimage(hash, data)
   245  }
   246  
   247  // RevertToSnapshot reverts the changes until we reach the given snaptshot
   248  func (db *StateDB) RevertToSnapshot(index int) {
   249  	if index > len(db.views) {
   250  		db.cachedError = fmt.Errorf("invalid revert")
   251  		return
   252  	}
   253  	db.views = db.views[:index]
   254  }
   255  
   256  // Snapshot takes an snapshot of the state and returns an int
   257  // that can be used later for revert calls.
   258  func (db *StateDB) Snapshot() int {
   259  	newView := db.lastestView().NewChildView()
   260  	db.views = append(db.views, newView)
   261  	return len(db.views) - 1
   262  }
   263  
   264  // Logs returns the list of logs
   265  // it also update each log with the block and tx info
   266  func (db *StateDB) Logs(
   267  	blockHash gethCommon.Hash,
   268  	blockNumber uint64,
   269  	txHash gethCommon.Hash,
   270  	txIndex uint,
   271  ) []*gethTypes.Log {
   272  	allLogs := make([]*gethTypes.Log, 0)
   273  	for _, view := range db.views {
   274  		for _, log := range view.Logs() {
   275  			log.BlockNumber = blockNumber
   276  			log.BlockHash = blockHash
   277  			log.TxHash = txHash
   278  			log.TxIndex = txIndex
   279  			allLogs = append(allLogs, log)
   280  		}
   281  	}
   282  	return allLogs
   283  }
   284  
   285  // Preimages returns a set of preimages
   286  func (db *StateDB) Preimages() map[gethCommon.Hash][]byte {
   287  	preImages := make(map[gethCommon.Hash][]byte, 0)
   288  	for _, view := range db.views {
   289  		for k, v := range view.Preimages() {
   290  			preImages[k] = v
   291  		}
   292  	}
   293  	return preImages
   294  }
   295  
   296  // Commit commits state changes back to the underlying
   297  func (db *StateDB) Commit() error {
   298  	// return error if any has been acumulated
   299  	if db.cachedError != nil {
   300  		return wrapError(db.cachedError)
   301  	}
   302  
   303  	var err error
   304  
   305  	// iterate views and collect dirty addresses and slots
   306  	addresses := make(map[gethCommon.Address]struct{})
   307  	slots := make(map[types.SlotAddress]struct{})
   308  	for _, view := range db.views {
   309  		for key := range view.DirtyAddresses() {
   310  			addresses[key] = struct{}{}
   311  		}
   312  		for key := range view.DirtySlots() {
   313  			slots[key] = struct{}{}
   314  		}
   315  	}
   316  
   317  	// sort addresses
   318  	sortedAddresses := make([]gethCommon.Address, 0, len(addresses))
   319  	for addr := range addresses {
   320  		sortedAddresses = append(sortedAddresses, addr)
   321  	}
   322  
   323  	sort.Slice(sortedAddresses,
   324  		func(i, j int) bool {
   325  			return bytes.Compare(sortedAddresses[i][:], sortedAddresses[j][:]) < 0
   326  		})
   327  
   328  	// update accounts
   329  	for _, addr := range sortedAddresses {
   330  		deleted := false
   331  		// first we need to delete accounts
   332  		if db.HasSelfDestructed(addr) {
   333  			err = db.baseView.DeleteAccount(addr)
   334  			if err != nil {
   335  				return wrapError(err)
   336  			}
   337  			deleted = true
   338  		}
   339  		// then create new ones
   340  		// an account might be in a single transaction be deleted and recreated
   341  		if db.IsCreated(addr) {
   342  			err = db.baseView.CreateAccount(
   343  				addr,
   344  				db.GetBalance(addr),
   345  				db.GetNonce(addr),
   346  				db.GetCode(addr),
   347  				db.GetCodeHash(addr),
   348  			)
   349  			if err != nil {
   350  				return wrapError(err)
   351  			}
   352  			continue
   353  		}
   354  		if deleted {
   355  			continue
   356  		}
   357  		err = db.baseView.UpdateAccount(
   358  			addr,
   359  			db.GetBalance(addr),
   360  			db.GetNonce(addr),
   361  			db.GetCode(addr),
   362  			db.GetCodeHash(addr),
   363  		)
   364  		if err != nil {
   365  			return wrapError(err)
   366  		}
   367  	}
   368  
   369  	// sort slots
   370  	sortedSlots := make([]types.SlotAddress, 0, len(slots))
   371  	for slot := range slots {
   372  		sortedSlots = append(sortedSlots, slot)
   373  	}
   374  	sort.Slice(sortedSlots, func(i, j int) bool {
   375  		comp := bytes.Compare(sortedSlots[i].Address[:], sortedSlots[j].Address[:])
   376  		if comp == 0 {
   377  			return bytes.Compare(sortedSlots[i].Key[:], sortedSlots[j].Key[:]) < 0
   378  		}
   379  		return comp < 0
   380  	})
   381  
   382  	// update slots
   383  	for _, sk := range sortedSlots {
   384  		err = db.baseView.UpdateSlot(
   385  			sk,
   386  			db.GetState(sk.Address, sk.Key),
   387  		)
   388  		if err != nil {
   389  			return wrapError(err)
   390  		}
   391  	}
   392  
   393  	// don't purge views yet, people might call the logs etc
   394  	err = db.baseView.Commit()
   395  	if err != nil {
   396  		return wrapError(err)
   397  	}
   398  	return nil
   399  }
   400  
   401  // Prepare is a highlevel logic that sadly is considered to be part of the
   402  // stateDB interface and not on the layers above.
   403  // based on parameters that are passed it updates accesslists
   404  func (db *StateDB) Prepare(rules gethParams.Rules, sender, coinbase gethCommon.Address, dest *gethCommon.Address, precompiles []gethCommon.Address, txAccesses gethTypes.AccessList) {
   405  	if rules.IsBerlin {
   406  		db.AddAddressToAccessList(sender)
   407  
   408  		if dest != nil {
   409  			db.AddAddressToAccessList(*dest)
   410  			// If it's a create-tx, the destination will be added inside egethVM.create
   411  		}
   412  		for _, addr := range precompiles {
   413  			db.AddAddressToAccessList(addr)
   414  		}
   415  		for _, el := range txAccesses {
   416  			db.AddAddressToAccessList(el.Address)
   417  			for _, key := range el.StorageKeys {
   418  				db.AddSlotToAccessList(el.Address, key)
   419  			}
   420  		}
   421  		if rules.IsShanghai { // EIP-3651: warm coinbase
   422  			db.AddAddressToAccessList(coinbase)
   423  		}
   424  	}
   425  }
   426  
   427  // Error returns the memorized database failure occurred earlier.
   428  func (s *StateDB) Error() error {
   429  	return wrapError(s.cachedError)
   430  }
   431  
   432  func (db *StateDB) lastestView() *DeltaView {
   433  	return db.views[len(db.views)-1]
   434  }
   435  
   436  // set error captures the first non-nil error it is called with.
   437  func (db *StateDB) handleError(err error) {
   438  	if err == nil {
   439  		return
   440  	}
   441  	if db.cachedError == nil {
   442  		db.cachedError = err
   443  	}
   444  }
   445  
   446  func wrapError(err error) error {
   447  	if err == nil {
   448  		return nil
   449  	}
   450  
   451  	var atreeUserError *atree.UserError
   452  	// if is an atree user error
   453  	if stdErrors.As(err, &atreeUserError) {
   454  		return types.NewStateError(err)
   455  	}
   456  
   457  	var atreeFatalError *atree.FatalError
   458  	// if is a atree fatal error or
   459  	if stdErrors.As(err, &atreeFatalError) {
   460  		return types.NewFatalError(err)
   461  	}
   462  
   463  	// if is fvm fatal error
   464  	if errors.IsFailure(err) {
   465  		return types.NewFatalError(err)
   466  	}
   467  
   468  	return types.NewStateError(err)
   469  }