github.com/aergoio/aergo@v1.3.1/pkg/trie/trie_tools.go (about)

     1  /**
     2   *  @file
     3   *  @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package trie
     7  
     8  import (
     9  	"bytes"
    10  	"fmt"
    11  
    12  	"github.com/aergoio/aergo-lib/db"
    13  )
    14  
    15  // LoadCache loads the first layers of the merkle tree given a root
    16  // This is called after a node restarts so that it doesnt become slow with db reads
    17  // LoadCache also updates the Root with the given root.
    18  func (s *Trie) LoadCache(root []byte) error {
    19  	if s.db.Store == nil {
    20  		return fmt.Errorf("DB not connected to trie")
    21  	}
    22  	s.db.liveCache = make(map[Hash][][]byte)
    23  	ch := make(chan error, 1)
    24  	s.loadCache(root, nil, 0, s.TrieHeight, ch)
    25  	s.Root = root
    26  	return <-ch
    27  }
    28  
    29  // loadCache loads the first layers of the merkle tree given a root
    30  func (s *Trie) loadCache(root []byte, batch [][]byte, iBatch, height int, ch chan<- (error)) {
    31  	if height < s.CacheHeightLimit || len(root) == 0 {
    32  		ch <- nil
    33  		return
    34  	}
    35  	if height%4 == 0 {
    36  		// Load the node from db
    37  		s.db.lock.Lock()
    38  		dbval := s.db.Store.Get(root[:HashLength])
    39  		s.db.lock.Unlock()
    40  		if len(dbval) == 0 {
    41  			ch <- fmt.Errorf("the trie node %x is unavailable in the disk db, db may be corrupted", root)
    42  			return
    43  		}
    44  		//Store node in cache.
    45  		var node Hash
    46  		copy(node[:], root)
    47  		batch = s.parseBatch(dbval)
    48  		s.db.liveMux.Lock()
    49  		s.db.liveCache[node] = batch
    50  		s.db.liveMux.Unlock()
    51  		iBatch = 0
    52  		if batch[0][0] == 1 {
    53  			// if height == 0 this will also return
    54  			ch <- nil
    55  			return
    56  		}
    57  	}
    58  	if iBatch != 0 && batch[iBatch][HashLength] == 1 {
    59  		// Check if node is a leaf node
    60  		ch <- nil
    61  	} else {
    62  		// Load subtree
    63  		lnode, rnode := batch[2*iBatch+1], batch[2*iBatch+2]
    64  
    65  		lch := make(chan error, 1)
    66  		rch := make(chan error, 1)
    67  		go s.loadCache(lnode, batch, 2*iBatch+1, height-1, lch)
    68  		go s.loadCache(rnode, batch, 2*iBatch+2, height-1, rch)
    69  		if err := <-lch; err != nil {
    70  			ch <- err
    71  			return
    72  		}
    73  		if err := <-rch; err != nil {
    74  			ch <- err
    75  			return
    76  		}
    77  		ch <- nil
    78  	}
    79  }
    80  
    81  // Get fetches the value of a key by going down the current trie root.
    82  func (s *Trie) Get(key []byte) ([]byte, error) {
    83  	s.lock.RLock()
    84  	defer s.lock.RUnlock()
    85  	s.atomicUpdate = false
    86  	return s.get(s.Root, key, nil, 0, s.TrieHeight)
    87  }
    88  
    89  // get fetches the value of a key given a trie root
    90  func (s *Trie) get(root, key []byte, batch [][]byte, iBatch, height int) ([]byte, error) {
    91  	if len(root) == 0 {
    92  		// the trie does not contain the key
    93  		return nil, nil
    94  	}
    95  	// Fetch the children of the node
    96  	batch, iBatch, lnode, rnode, isShortcut, err := s.loadChildren(root, height, iBatch, batch)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	if isShortcut {
   101  		if bytes.Equal(lnode[:HashLength], key) {
   102  			return rnode[:HashLength], nil
   103  		}
   104  		// also returns nil if height 0 is not a shortcut
   105  		return nil, nil
   106  	}
   107  	if bitIsSet(key, s.TrieHeight-height) {
   108  		return s.get(rnode, key, batch, 2*iBatch+2, height-1)
   109  	}
   110  	return s.get(lnode, key, batch, 2*iBatch+1, height-1)
   111  }
   112  
   113  // TrieRootExists returns true if the root exists in Database.
   114  func (s *Trie) TrieRootExists(root []byte) bool {
   115  	s.db.lock.RLock()
   116  	dbval := s.db.Store.Get(root)
   117  	s.db.lock.RUnlock()
   118  	if len(dbval) != 0 {
   119  		return true
   120  	}
   121  	return false
   122  }
   123  
   124  // Commit stores the updated nodes to disk.
   125  // Commit should be called for every block otherwise past tries
   126  // are not recorded and it is not possible to revert to them
   127  // (except if AtomicUpdate is used, which records every state).
   128  func (s *Trie) Commit() error {
   129  	if s.db.Store == nil {
   130  		return fmt.Errorf("DB not connected to trie")
   131  	}
   132  	// NOTE The tx interface doesnt handle ErrTxnTooBig
   133  	txn := s.db.Store.NewTx().(DbTx)
   134  	s.StageUpdates(txn)
   135  	txn.(db.Transaction).Commit()
   136  	return nil
   137  }
   138  
   139  // StageUpdates requires a database transaction as input
   140  // Unlike Commit(), it doesnt commit the transaction
   141  // the database transaction MUST be commited otherwise the
   142  // state ROOT will not exist.
   143  func (s *Trie) StageUpdates(txn DbTx) {
   144  	s.lock.Lock()
   145  	defer s.lock.Unlock()
   146  	// Commit the new nodes to database, clear updatedNodes and store the Root in pastTries for reverts.
   147  	if !s.atomicUpdate {
   148  		// if previously AtomicUpdate was called, then past tries is already updated
   149  		s.updatePastTries()
   150  	}
   151  	s.db.commit(&txn)
   152  
   153  	s.db.updatedNodes = make(map[Hash][][]byte)
   154  	s.prevRoot = s.Root
   155  }
   156  
   157  // Stash rolls back the changes made by previous updates
   158  // and loads the cache from before the rollback.
   159  func (s *Trie) Stash(rollbackCache bool) error {
   160  	s.lock.Lock()
   161  	defer s.lock.Unlock()
   162  	s.Root = s.prevRoot
   163  	if rollbackCache {
   164  		// Making a temporary liveCache requires it to be copied, so it's quicker
   165  		// to just load the cache from DB if a block state root was incorrect.
   166  		s.db.liveCache = make(map[Hash][][]byte)
   167  		ch := make(chan error, 1)
   168  		s.loadCache(s.Root, nil, 0, s.TrieHeight, ch)
   169  		err := <-ch
   170  		if err != nil {
   171  			return err
   172  		}
   173  	} else {
   174  		s.db.liveCache = make(map[Hash][][]byte)
   175  	}
   176  	s.db.updatedNodes = make(map[Hash][][]byte)
   177  	// also stash past tries created by Atomic update
   178  	for i := len(s.pastTries) - 1; i >= 0; i-- {
   179  		if bytes.Equal(s.pastTries[i], s.Root) {
   180  			break
   181  		} else {
   182  			// remove from past tries
   183  			s.pastTries = s.pastTries[:len(s.pastTries)-1]
   184  		}
   185  	}
   186  	return nil
   187  }