github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/secure_trie.go (about)

     1  package trie
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/neatlab/neatio/chain/log"
     7  	"github.com/neatlab/neatio/utilities/common"
     8  )
     9  
    10  type SecureTrie struct {
    11  	trie             Trie
    12  	hashKeyBuf       [common.HashLength]byte
    13  	secKeyCache      map[string][]byte
    14  	secKeyCacheOwner *SecureTrie
    15  }
    16  
    17  func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
    18  	if db == nil {
    19  		panic("trie.NewSecure called without a database")
    20  	}
    21  	trie, err := New(root, db)
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  	return &SecureTrie{trie: *trie}, nil
    26  }
    27  
    28  func (t *SecureTrie) Get(key []byte) []byte {
    29  	res, err := t.TryGet(key)
    30  	if err != nil {
    31  		log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    32  	}
    33  	return res
    34  }
    35  
    36  func (t *SecureTrie) TryGet(key []byte) ([]byte, error) {
    37  	return t.trie.TryGet(t.hashKey(key))
    38  }
    39  
    40  func (t *SecureTrie) Update(key, value []byte) {
    41  	if err := t.TryUpdate(key, value); err != nil {
    42  		log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    43  	}
    44  }
    45  
    46  func (t *SecureTrie) TryUpdate(key, value []byte) error {
    47  	hk := t.hashKey(key)
    48  	err := t.trie.TryUpdate(hk, value)
    49  	if err != nil {
    50  		return err
    51  	}
    52  	t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
    53  	return nil
    54  }
    55  
    56  func (t *SecureTrie) Delete(key []byte) {
    57  	if err := t.TryDelete(key); err != nil {
    58  		log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    59  	}
    60  }
    61  
    62  func (t *SecureTrie) TryDelete(key []byte) error {
    63  	hk := t.hashKey(key)
    64  	delete(t.getSecKeyCache(), string(hk))
    65  	return t.trie.TryDelete(hk)
    66  }
    67  
    68  func (t *SecureTrie) GetKey(shaKey []byte) []byte {
    69  	if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
    70  		return key
    71  	}
    72  	key, _ := t.trie.db.preimage(common.BytesToHash(shaKey))
    73  	return key
    74  }
    75  
    76  func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
    77  
    78  	if len(t.getSecKeyCache()) > 0 {
    79  		t.trie.db.lock.Lock()
    80  		for hk, key := range t.secKeyCache {
    81  			t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key)
    82  		}
    83  		t.trie.db.lock.Unlock()
    84  
    85  		t.secKeyCache = make(map[string][]byte)
    86  	}
    87  
    88  	return t.trie.Commit(onleaf)
    89  }
    90  
    91  func (t *SecureTrie) Hash() common.Hash {
    92  	return t.trie.Hash()
    93  }
    94  
    95  func (t *SecureTrie) Copy() *SecureTrie {
    96  	cpy := *t
    97  	return &cpy
    98  }
    99  
   100  func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {
   101  	return t.trie.NodeIterator(start)
   102  }
   103  
   104  func (t *SecureTrie) hashKey(key []byte) []byte {
   105  	h := newHasher(nil)
   106  	h.sha.Reset()
   107  	h.sha.Write(key)
   108  	buf := h.sha.Sum(t.hashKeyBuf[:0])
   109  	returnHasherToPool(h)
   110  	return buf
   111  }
   112  
   113  func (t *SecureTrie) getSecKeyCache() map[string][]byte {
   114  	if t != t.secKeyCacheOwner {
   115  		t.secKeyCacheOwner = t
   116  		t.secKeyCache = make(map[string][]byte)
   117  	}
   118  	return t.secKeyCache
   119  }