github.com/aergoio/aergo@v1.3.1/state/statedb.go (about)

     1  /**
     2   *  @file
     3   *  @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package state
     7  
     8  import (
     9  	"bytes"
    10  	"errors"
    11  	"fmt"
    12  	"math/big"
    13  	"sync"
    14  
    15  	"github.com/aergoio/aergo-lib/db"
    16  	"github.com/aergoio/aergo-lib/log"
    17  	"github.com/aergoio/aergo/internal/common"
    18  	"github.com/aergoio/aergo/internal/enc"
    19  	"github.com/aergoio/aergo/pkg/trie"
    20  	"github.com/aergoio/aergo/types"
    21  )
    22  
    23  const (
    24  	stateName   = "state"
    25  	stateLatest = stateName + ".latest"
    26  )
    27  
    28  var (
    29  	stateMarker = []byte{0x54, 0x45} // marker: tail end
    30  )
    31  
    32  var (
    33  	logger = log.NewLogger(stateName)
    34  )
    35  
    36  var (
    37  	emptyHashID    = types.HashID{}
    38  	emptyBlockID   = types.BlockID{}
    39  	emptyAccountID = types.AccountID{}
    40  )
    41  
    42  var (
    43  	errSaveData = errors.New("Failed to save data: invalid key")
    44  	errLoadData = errors.New("Failed to load data: invalid key")
    45  
    46  	errLoadStateData = errors.New("Failed to load StateData: invalid HashID")
    47  	// errSaveStateData = errors.New("Failed to save StateData: invalid HashID")
    48  
    49  	// errInvalidArgs = errors.New("invalid arguments")
    50  	// errInvalidRoot = errors.New("invalid root")
    51  	// errSetRoot     = errors.New("Failed to set root: invalid root")
    52  	// errLoadRoot    = errors.New("Failed to load root: invalid root")
    53  
    54  	errGetState = errors.New("Failed to get state: invalid account id")
    55  	errPutState = errors.New("Failed to put state: invalid account id")
    56  )
    57  
    58  // StateDB manages trie of states
    59  type StateDB struct {
    60  	lock     sync.RWMutex
    61  	buffer   *stateBuffer
    62  	cache    *storageCache
    63  	trie     *trie.Trie
    64  	store    db.DB
    65  	batchtx  db.Transaction
    66  	testmode bool
    67  }
    68  
    69  // NewStateDB craete StateDB instance
    70  func NewStateDB(dbstore db.DB, root []byte, test bool) *StateDB {
    71  	sdb := StateDB{
    72  		buffer:   newStateBuffer(),
    73  		cache:    newStorageCache(),
    74  		trie:     trie.NewTrie(root, common.Hasher, dbstore),
    75  		store:    dbstore,
    76  		testmode: test,
    77  	}
    78  	return &sdb
    79  }
    80  
    81  // Clone returns a new StateDB which has same store and Root
    82  func (states *StateDB) Clone() *StateDB {
    83  	states.lock.RLock()
    84  	defer states.lock.RUnlock()
    85  
    86  	return NewStateDB(states.store, states.GetRoot(), states.testmode)
    87  }
    88  
    89  // GetRoot returns root hash of trie
    90  func (states *StateDB) GetRoot() []byte {
    91  	states.lock.RLock()
    92  	defer states.lock.RUnlock()
    93  	return states.trie.Root
    94  }
    95  
    96  // SetRoot updates root node of trie as a given root hash
    97  func (states *StateDB) SetRoot(root []byte) error {
    98  	states.lock.Lock()
    99  	defer states.lock.Unlock()
   100  	// update root node
   101  	states.trie.Root = root
   102  	// reset buffer
   103  	return states.buffer.reset()
   104  }
   105  
   106  // LoadCache reads first layer of trie given root hash
   107  // and also updates root node of trie as a given root hash
   108  func (states *StateDB) LoadCache(root []byte) error {
   109  	states.lock.Lock()
   110  	defer states.lock.Unlock()
   111  	// update root node and load cache
   112  	err := states.trie.LoadCache(root)
   113  	if err != nil {
   114  		return err
   115  	}
   116  	// reset buffer
   117  	return states.buffer.reset()
   118  }
   119  
   120  // Revert rollbacks trie to previous root hash
   121  func (states *StateDB) Revert(root types.HashID) error {
   122  	states.lock.Lock()
   123  	defer states.lock.Unlock()
   124  	// // handle nil bytes
   125  	// targetRoot := root.Bytes()
   126  
   127  	// // revert trie
   128  	// err := states.trie.Revert(targetRoot)
   129  	// if err != nil {
   130  	// 	// when targetRoot is not contained in the cached tries.
   131  	// 	states.trie.Root = targetRoot
   132  	// }
   133  
   134  	// just update root node as targetRoot.
   135  	// revert trie consumes unnecessarily long time.
   136  	states.trie.Root = root.Bytes()
   137  
   138  	// reset buffer
   139  	return states.buffer.reset()
   140  }
   141  
   142  // PutState puts account id and its state into state buffer.
   143  func (states *StateDB) PutState(id types.AccountID, state *types.State) error {
   144  	states.lock.Lock()
   145  	defer states.lock.Unlock()
   146  	if id == emptyAccountID {
   147  		return errPutState
   148  	}
   149  	states.buffer.put(newValueEntry(types.HashID(id), state))
   150  	return nil
   151  }
   152  
   153  // GetAccountState gets state of account id from statedb.
   154  // empty state is returned when there is no state corresponding to account id.
   155  func (states *StateDB) GetAccountState(id types.AccountID) (*types.State, error) {
   156  	st, err := states.GetState(id)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	if st == nil {
   161  		if states.testmode {
   162  			amount := new(big.Int).Add(types.StakingMinimum, types.StakingMinimum)
   163  			return &types.State{Balance: amount.Bytes()}, nil
   164  		}
   165  		return &types.State{}, nil
   166  	}
   167  	return st, nil
   168  }
   169  
   170  type V struct {
   171  	sdb      *StateDB
   172  	id       []byte
   173  	aid      types.AccountID
   174  	oldV     *types.State
   175  	newV     *types.State
   176  	newOne   bool
   177  	deploy   int8
   178  	buffer   *stateBuffer
   179  }
   180  
   181  const (
   182  	deployFlag = 0x01 << iota
   183  	redeployFlag
   184  )
   185  
   186  func (v *V) ID() []byte {
   187  	if len(v.id) < types.AddressLength {
   188  		v.id = types.AddressPadding(v.id)
   189  	}
   190  	return v.id
   191  }
   192  
   193  func (v *V) AccountID() types.AccountID {
   194  	return v.aid
   195  }
   196  
   197  func (v *V) State() *types.State {
   198  	return v.newV
   199  }
   200  
   201  func (v *V) SetNonce(nonce uint64) {
   202  	v.newV.Nonce = nonce
   203  }
   204  
   205  func (v *V) Balance() *big.Int {
   206  	return new(big.Int).SetBytes(v.newV.Balance)
   207  }
   208  
   209  func (v *V) AddBalance(amount *big.Int) {
   210  	balance := new(big.Int).SetBytes(v.newV.Balance)
   211  	v.newV.Balance = new(big.Int).Add(balance, amount).Bytes()
   212  }
   213  
   214  func (v *V) SubBalance(amount *big.Int) {
   215  	balance := new(big.Int).SetBytes(v.newV.Balance)
   216  	v.newV.Balance = new(big.Int).Sub(balance, amount).Bytes()
   217  }
   218  
   219  func (v *V) RP() uint64 {
   220  	return v.newV.SqlRecoveryPoint
   221  }
   222  
   223  func (v *V) IsNew() bool {
   224  	return v.newOne
   225  }
   226  
   227  func (v *V) IsDeploy() bool {
   228  	return v.deploy & deployFlag != 0
   229  }
   230  
   231  func (v *V) SetRedeploy() {
   232  	v.deploy = deployFlag | redeployFlag
   233  }
   234  
   235  func (v *V) IsRedeploy() bool {
   236  	return v.deploy & redeployFlag != 0
   237  }
   238  
   239  func (v *V) Reset() {
   240  	*v.newV = types.State(*v.oldV)
   241  }
   242  
   243  func (v *V) PutState() error {
   244  	return v.sdb.PutState(v.aid, v.newV)
   245  }
   246  
   247  func (states *StateDB) CreateAccountStateV(id []byte) (*V, error) {
   248  	v, err := states.GetAccountStateV(id)
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	if !v.newOne {
   253  		return nil, fmt.Errorf("account(%s) aleardy exists", types.EncodeAddress(v.ID()))
   254  	}
   255  	v.newV.SqlRecoveryPoint = 1
   256  	v.deploy = deployFlag
   257  	return v, nil
   258  }
   259  
   260  func (states *StateDB) GetAccountStateV(id []byte) (*V, error) {
   261  	aid := types.ToAccountID(id)
   262  	st, err := states.GetState(aid)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  	if st == nil {
   267  		if states.testmode {
   268  			amount := new(big.Int).Add(types.StakingMinimum, types.StakingMinimum)
   269  			return &V{
   270  				sdb:    states,
   271  				id:     id,
   272  				aid:    aid,
   273  				oldV:   &types.State{},
   274  				newV:   &types.State{Balance: amount.Bytes()},
   275  				newOne: true,
   276  			}, nil
   277  		}
   278  		return &V{
   279  			sdb:    states,
   280  			id:     id,
   281  			aid:    aid,
   282  			oldV:   &types.State{},
   283  			newV:   &types.State{},
   284  			newOne: true,
   285  		}, nil
   286  	}
   287  	newV := new(types.State)
   288  	*newV = types.State(*st)
   289  	return &V{
   290  		sdb:  states,
   291  		id:   id,
   292  		aid:  aid,
   293  		oldV: st,
   294  		newV: newV,
   295  	}, nil
   296  }
   297  
   298  func (states *StateDB) InitAccountStateV(id []byte, old *types.State, new *types.State) *V {
   299  	return &V{
   300  		sdb:  states,
   301  		id:   id,
   302  		oldV: old,
   303  		newV: new,
   304  	}
   305  }
   306  
   307  // GetState gets state of account id from state buffer and trie.
   308  // nil value is returned when there is no state corresponding to account id.
   309  func (states *StateDB) GetState(id types.AccountID) (*types.State, error) {
   310  	states.lock.RLock()
   311  	defer states.lock.RUnlock()
   312  	if id == emptyAccountID {
   313  		return nil, errGetState
   314  	}
   315  	return states.getState(id)
   316  }
   317  
   318  // getState returns state of account id from buffer and trie.
   319  // nil value is returned when there is no state corresponding to account id.
   320  func (states *StateDB) getState(id types.AccountID) (*types.State, error) {
   321  	// get state from buffer
   322  	if entry := states.buffer.get(types.HashID(id)); entry != nil {
   323  		return entry.Value().(*types.State), nil
   324  	}
   325  	// get state from trie
   326  	return states.getTrieState(id)
   327  }
   328  
   329  // getTrieState gets state of account id from trie.
   330  // nil value is returned when there is no state corresponding to account id.
   331  func (states *StateDB) getTrieState(id types.AccountID) (*types.State, error) {
   332  	key, err := states.trie.Get(id[:])
   333  	if err != nil {
   334  		return nil, err
   335  	}
   336  	if key == nil || len(key) == 0 {
   337  		return nil, nil
   338  	}
   339  	return states.loadStateData(key)
   340  }
   341  
   342  func (states *StateDB) TrieQuery(id []byte, root []byte, compressed bool) ([]byte, [][]byte, int, bool, []byte, []byte, error) {
   343  	var ap [][]byte
   344  	var proofKey, proofVal, bitmap []byte
   345  	var isIncluded bool
   346  	var err error
   347  	var height int
   348  	states.lock.RLock()
   349  	defer states.lock.RUnlock()
   350  
   351  	if len(root) != 0 {
   352  		if compressed {
   353  			bitmap, ap, height, isIncluded, proofKey, proofVal, err = states.trie.MerkleProofCompressedR(id, root)
   354  		} else {
   355  			// Get the state and proof of the account for a past state
   356  			ap, isIncluded, proofKey, proofVal, err = states.trie.MerkleProofR(id, root)
   357  		}
   358  	} else {
   359  		// Get the state and proof of the account at the latest trie
   360  		// The wallet should check that state hashes to proofVal and verify the audit path,
   361  		// The returned proofVal shouldn't be trusted by the wallet, it is used to proove non inclusion
   362  		if compressed {
   363  			bitmap, ap, height, isIncluded, proofKey, proofVal, err = states.trie.MerkleProofCompressed(id)
   364  		} else {
   365  			ap, isIncluded, proofKey, proofVal, err = states.trie.MerkleProof(id)
   366  		}
   367  	}
   368  	return bitmap, ap, height, isIncluded, proofKey, proofVal, err
   369  }
   370  
   371  // GetVarAndProof gets the value of a variable in the given contract trie root.
   372  func (states *StateDB) GetVarAndProof(id []byte, root []byte, compressed bool) (*types.ContractVarProof, error) {
   373  	var value []byte
   374  	bitmap, ap, height, isIncluded, proofKey, dbKey, err := states.TrieQuery(id, root, compressed)
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  	if isIncluded {
   379  		value = []byte{}
   380  		if err := loadData(states.store, dbKey, &value); err != nil {
   381  			return nil, err
   382  		}
   383  		// proofKey and proofVal are only not nil for prooving exclusion with another leaf on the path
   384  		dbKey = nil
   385  	}
   386  	contractVarProof := &types.ContractVarProof{
   387  		Value:     value,
   388  		Inclusion: isIncluded,
   389  		ProofKey:  proofKey,
   390  		ProofVal:  dbKey,
   391  		Bitmap:    bitmap,
   392  		Height:    uint32(height),
   393  		AuditPath: ap,
   394  	}
   395  	logger.Debug().Str("contract root : ", enc.ToString(root)).Msg("Get contract variable and Proof")
   396  	return contractVarProof, nil
   397  
   398  }
   399  
   400  // GetAccountAndProof gets the state and associated proof of an account
   401  // in the given trie root. If the account doesnt exist, a proof of
   402  // non existence is returned.
   403  func (states *StateDB) GetAccountAndProof(id []byte, root []byte, compressed bool) (*types.AccountProof, error) {
   404  	var state *types.State
   405  	bitmap, ap, height, isIncluded, proofKey, dbKey, err := states.TrieQuery(id, root, compressed)
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  	if isIncluded {
   410  		state, err = states.loadStateData(dbKey)
   411  		if err != nil {
   412  			return nil, err
   413  		}
   414  		dbKey = nil
   415  	}
   416  	accountProof := &types.AccountProof{
   417  		State:     state,
   418  		Inclusion: isIncluded,
   419  		ProofKey:  proofKey,
   420  		ProofVal:  dbKey,
   421  		Bitmap:    bitmap,
   422  		Height:    uint32(height),
   423  		AuditPath: ap,
   424  	}
   425  	logger.Debug().Str("state root : ", enc.ToString(root)).Msg("Get Account and Proof")
   426  	return accountProof, nil
   427  }
   428  
   429  // Snapshot represents revision number of statedb
   430  type Snapshot int
   431  
   432  // Snapshot returns revision number of state buffer
   433  func (states *StateDB) Snapshot() Snapshot {
   434  	states.lock.RLock()
   435  	defer states.lock.RUnlock()
   436  	return Snapshot(states.buffer.snapshot())
   437  }
   438  
   439  // Rollback discards changes of state buffer to revision number
   440  func (states *StateDB) Rollback(revision Snapshot) error {
   441  	states.lock.Lock()
   442  	defer states.lock.Unlock()
   443  	return states.buffer.rollback(int(revision))
   444  }
   445  
   446  // Update applies changes of state buffer to trie
   447  func (states *StateDB) Update() error {
   448  	states.lock.Lock()
   449  	defer states.lock.Unlock()
   450  
   451  	if err := states.update(); err != nil {
   452  		return err
   453  	}
   454  	return nil
   455  }
   456  
   457  func (states *StateDB) update() error {
   458  	// update storage and put state with changed storage root
   459  	if err := states.updateStorage(); err != nil {
   460  		return err
   461  	}
   462  	// export buffer and update to trie
   463  	if err := states.buffer.updateTrie(states.trie); err != nil {
   464  		return err
   465  	}
   466  	return nil
   467  }
   468  
   469  func (states *StateDB) updateStorage() error {
   470  	before := states.buffer.snapshot()
   471  	for id, storage := range states.cache.storages {
   472  		// update storage
   473  		if err := storage.update(); err != nil {
   474  			states.buffer.rollback(before)
   475  			return err
   476  		}
   477  		// update state if storage root changed
   478  		if storage.isDirty() {
   479  			st, err := states.getState(id)
   480  			if err != nil {
   481  				states.buffer.rollback(before)
   482  				return err
   483  			}
   484  			if st == nil {
   485  				st = &types.State{}
   486  			}
   487  			// put state with storage root
   488  			st.StorageRoot = storage.trie.Root
   489  			states.buffer.put(newValueEntry(types.HashID(id), st))
   490  		}
   491  	}
   492  	return nil
   493  }
   494  
   495  // Commit writes state buffer and trie to db
   496  func (states *StateDB) Commit() error {
   497  	states.lock.Lock()
   498  	defer states.lock.Unlock()
   499  
   500  	bulk := states.store.NewBulk()
   501  	for _, storage := range states.cache.storages {
   502  		// stage changes
   503  		if err := storage.stage(bulk); err != nil {
   504  			bulk.DiscardLast()
   505  			return err
   506  		}
   507  	}
   508  	if err := states.stage(bulk); err != nil {
   509  		bulk.DiscardLast()
   510  		return err
   511  	}
   512  	bulk.Flush()
   513  	return nil
   514  }
   515  
   516  func (states *StateDB) stage(txn trie.DbTx) error {
   517  	// stage trie and buffer
   518  	states.trie.StageUpdates(txn)
   519  	if err := states.buffer.stage(txn); err != nil {
   520  		return err
   521  	}
   522  	// set marker
   523  	states.setMarker(txn)
   524  	// reset buffer
   525  	if err := states.buffer.reset(); err != nil {
   526  		return err
   527  	}
   528  	return nil
   529  }
   530  
   531  // setMarker store the marker that represents finalization of the state root.
   532  func (states *StateDB) setMarker(txn trie.DbTx) {
   533  	if states.trie.Root == nil {
   534  		return
   535  	}
   536  	// logger.Debug().Str("stateRoot", enc.ToString(states.trie.Root)).Msg("setMarker")
   537  	txn.Set(common.Hasher(states.trie.Root), stateMarker)
   538  }
   539  
   540  // HasMarker represents that the state root is finalized or not.
   541  func (states *StateDB) HasMarker(root []byte) bool {
   542  	if root == nil {
   543  		return false
   544  	}
   545  	marker := states.store.Get(common.Hasher(root))
   546  	if marker != nil && bytes.Equal(marker, stateMarker) {
   547  		// logger.Debug().Str("stateRoot", enc.ToString(root)).Str("marker", hex.EncodeToString(marker)).Msg("IsMarked")
   548  		return true
   549  	}
   550  	return false
   551  }