github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/chain/trie/trie.go (about)

     1  package trie
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  
     7  	"github.com/neatio-net/neatio/chain/log"
     8  	"github.com/neatio-net/neatio/utilities/common"
     9  	"github.com/neatio-net/neatio/utilities/crypto"
    10  )
    11  
    12  var (
    13  	emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
    14  
    15  	emptyState = crypto.Keccak256Hash(nil)
    16  )
    17  
    18  type LeafCallback func(leaf []byte, parent common.Hash) error
    19  
    20  type Trie struct {
    21  	db   *Database
    22  	root node
    23  }
    24  
    25  func (t *Trie) newFlag() nodeFlag {
    26  	return nodeFlag{dirty: true}
    27  }
    28  
    29  func New(root common.Hash, db *Database) (*Trie, error) {
    30  	if db == nil {
    31  		panic("trie.New called without a database")
    32  	}
    33  	trie := &Trie{
    34  		db: db,
    35  	}
    36  	if root != (common.Hash{}) && root != emptyRoot {
    37  		rootnode, err := trie.resolveHash(root[:], nil)
    38  		if err != nil {
    39  			return nil, err
    40  		}
    41  		trie.root = rootnode
    42  	}
    43  	return trie, nil
    44  }
    45  
    46  func (t *Trie) NodeIterator(start []byte) NodeIterator {
    47  	return newNodeIterator(t, start)
    48  }
    49  
    50  func (t *Trie) Get(key []byte) []byte {
    51  	res, err := t.TryGet(key)
    52  	if err != nil {
    53  		log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    54  	}
    55  	return res
    56  }
    57  
    58  func (t *Trie) TryGet(key []byte) ([]byte, error) {
    59  	key = keybytesToHex(key)
    60  	value, newroot, didResolve, err := t.tryGet(t.root, key, 0)
    61  	if err == nil && didResolve {
    62  		t.root = newroot
    63  	}
    64  	return value, err
    65  }
    66  
    67  func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
    68  	switch n := (origNode).(type) {
    69  	case nil:
    70  		return nil, nil, false, nil
    71  	case valueNode:
    72  		return n, n, false, nil
    73  	case *shortNode:
    74  		if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
    75  
    76  			return nil, n, false, nil
    77  		}
    78  		value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
    79  		if err == nil && didResolve {
    80  			n = n.copy()
    81  			n.Val = newnode
    82  		}
    83  		return value, n, didResolve, err
    84  	case *fullNode:
    85  		value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
    86  		if err == nil && didResolve {
    87  			n = n.copy()
    88  			n.Children[key[pos]] = newnode
    89  		}
    90  		return value, n, didResolve, err
    91  	case hashNode:
    92  		side, err := t.resolveHash(n, key[:pos])
    93  		if err != nil {
    94  			return nil, n, true, err
    95  		}
    96  		value, newnode, _, err := t.tryGet(side, key, pos)
    97  		return value, newnode, true, err
    98  	default:
    99  		panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
   100  	}
   101  }
   102  
   103  func (t *Trie) Update(key, value []byte) {
   104  	if err := t.TryUpdate(key, value); err != nil {
   105  		log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
   106  	}
   107  }
   108  
   109  func (t *Trie) TryUpdate(key, value []byte) error {
   110  	k := keybytesToHex(key)
   111  	if len(value) != 0 {
   112  		_, n, err := t.insert(t.root, nil, k, valueNode(value))
   113  		if err != nil {
   114  			return err
   115  		}
   116  		t.root = n
   117  	} else {
   118  		_, n, err := t.delete(t.root, nil, k)
   119  		if err != nil {
   120  			return err
   121  		}
   122  		t.root = n
   123  	}
   124  	return nil
   125  }
   126  
   127  func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
   128  	if len(key) == 0 {
   129  		if v, ok := n.(valueNode); ok {
   130  			return !bytes.Equal(v, value.(valueNode)), value, nil
   131  		}
   132  		return true, value, nil
   133  	}
   134  	switch n := n.(type) {
   135  	case *shortNode:
   136  		matchlen := prefixLen(key, n.Key)
   137  
   138  		if matchlen == len(n.Key) {
   139  			dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
   140  			if !dirty || err != nil {
   141  				return false, n, err
   142  			}
   143  			return true, &shortNode{n.Key, nn, t.newFlag()}, nil
   144  		}
   145  
   146  		branch := &fullNode{flags: t.newFlag()}
   147  		var err error
   148  		_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
   149  		if err != nil {
   150  			return false, nil, err
   151  		}
   152  		_, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value)
   153  		if err != nil {
   154  			return false, nil, err
   155  		}
   156  
   157  		if matchlen == 0 {
   158  			return true, branch, nil
   159  		}
   160  
   161  		return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
   162  
   163  	case *fullNode:
   164  		dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
   165  		if !dirty || err != nil {
   166  			return false, n, err
   167  		}
   168  		n = n.copy()
   169  		n.flags = t.newFlag()
   170  		n.Children[key[0]] = nn
   171  		return true, n, nil
   172  
   173  	case nil:
   174  		return true, &shortNode{key, value, t.newFlag()}, nil
   175  
   176  	case hashNode:
   177  
   178  		rn, err := t.resolveHash(n, prefix)
   179  		if err != nil {
   180  			return false, nil, err
   181  		}
   182  		dirty, nn, err := t.insert(rn, prefix, key, value)
   183  		if !dirty || err != nil {
   184  			return false, rn, err
   185  		}
   186  		return true, nn, nil
   187  
   188  	default:
   189  		panic(fmt.Sprintf("%T: invalid node: %v", n, n))
   190  	}
   191  }
   192  
   193  func (t *Trie) Delete(key []byte) {
   194  	if err := t.TryDelete(key); err != nil {
   195  		log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
   196  	}
   197  }
   198  
   199  func (t *Trie) TryDelete(key []byte) error {
   200  	k := keybytesToHex(key)
   201  	_, n, err := t.delete(t.root, nil, k)
   202  	if err != nil {
   203  		return err
   204  	}
   205  	t.root = n
   206  	return nil
   207  }
   208  
   209  func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
   210  	switch n := n.(type) {
   211  	case *shortNode:
   212  		matchlen := prefixLen(key, n.Key)
   213  		if matchlen < len(n.Key) {
   214  			return false, n, nil
   215  		}
   216  		if matchlen == len(key) {
   217  			return true, nil, nil
   218  		}
   219  
   220  		dirty, side, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
   221  		if !dirty || err != nil {
   222  			return false, n, err
   223  		}
   224  		switch side := side.(type) {
   225  		case *shortNode:
   226  
   227  			return true, &shortNode{concat(n.Key, side.Key...), side.Val, t.newFlag()}, nil
   228  		default:
   229  			return true, &shortNode{n.Key, side, t.newFlag()}, nil
   230  		}
   231  
   232  	case *fullNode:
   233  		dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
   234  		if !dirty || err != nil {
   235  			return false, n, err
   236  		}
   237  		n = n.copy()
   238  		n.flags = t.newFlag()
   239  		n.Children[key[0]] = nn
   240  
   241  		pos := -1
   242  		for i, cld := range &n.Children {
   243  			if cld != nil {
   244  				if pos == -1 {
   245  					pos = i
   246  				} else {
   247  					pos = -2
   248  					break
   249  				}
   250  			}
   251  		}
   252  		if pos >= 0 {
   253  			if pos != 16 {
   254  
   255  				cnode, err := t.resolve(n.Children[pos], prefix)
   256  				if err != nil {
   257  					return false, nil, err
   258  				}
   259  				if cnode, ok := cnode.(*shortNode); ok {
   260  					k := append([]byte{byte(pos)}, cnode.Key...)
   261  					return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
   262  				}
   263  			}
   264  
   265  			return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
   266  		}
   267  
   268  		return true, n, nil
   269  
   270  	case valueNode:
   271  		return true, nil, nil
   272  
   273  	case nil:
   274  		return false, nil, nil
   275  
   276  	case hashNode:
   277  
   278  		rn, err := t.resolveHash(n, prefix)
   279  		if err != nil {
   280  			return false, nil, err
   281  		}
   282  		dirty, nn, err := t.delete(rn, prefix, key)
   283  		if !dirty || err != nil {
   284  			return false, rn, err
   285  		}
   286  		return true, nn, nil
   287  
   288  	default:
   289  		panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
   290  	}
   291  }
   292  
   293  func concat(s1 []byte, s2 ...byte) []byte {
   294  	r := make([]byte, len(s1)+len(s2))
   295  	copy(r, s1)
   296  	copy(r[len(s1):], s2)
   297  	return r
   298  }
   299  
   300  func (t *Trie) resolve(n node, prefix []byte) (node, error) {
   301  	if n, ok := n.(hashNode); ok {
   302  		return t.resolveHash(n, prefix)
   303  	}
   304  	return n, nil
   305  }
   306  
   307  func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
   308  	hash := common.BytesToHash(n)
   309  	if node := t.db.node(hash); node != nil {
   310  		return node, nil
   311  	}
   312  	return nil, &MissingNodeError{NodeHash: hash, Path: prefix}
   313  }
   314  
   315  func (t *Trie) Hash() common.Hash {
   316  	hash, cached, _ := t.hashRoot(nil, nil)
   317  	t.root = cached
   318  	return common.BytesToHash(hash.(hashNode))
   319  }
   320  
   321  func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
   322  	if t.db == nil {
   323  		panic("commit called on trie with nil database")
   324  	}
   325  	hash, cached, err := t.hashRoot(t.db, onleaf)
   326  	if err != nil {
   327  		return common.Hash{}, err
   328  	}
   329  	t.root = cached
   330  	return common.BytesToHash(hash.(hashNode)), nil
   331  }
   332  
   333  func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) {
   334  	if t.root == nil {
   335  		return hashNode(emptyRoot.Bytes()), nil, nil
   336  	}
   337  	h := newHasher(onleaf)
   338  	defer returnHasherToPool(h)
   339  	return h.hash(t.root, db, true)
   340  }