github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/iavl/node.go (about)

     1  package iavl
     2  
     3  // NOTE: This file favors int64 as opposed to int for size/counts.
     4  // The Tree on the other hand favors int.  This is intentional.
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  
    11  	"github.com/gnolang/gno/tm2/pkg/amino"
    12  	"github.com/gnolang/gno/tm2/pkg/crypto/tmhash"
    13  	"github.com/gnolang/gno/tm2/pkg/errors"
    14  )
    15  
    16  // Node represents a node in a Tree.
    17  type Node struct {
    18  	key       []byte
    19  	value     []byte
    20  	version   int64
    21  	height    int8
    22  	size      int64
    23  	hash      []byte
    24  	leftHash  []byte
    25  	leftNode  *Node
    26  	rightHash []byte
    27  	rightNode *Node
    28  	persisted bool
    29  }
    30  
    31  // NewNode returns a new node from a key, value and version.
    32  func NewNode(key []byte, value []byte, version int64) *Node {
    33  	return &Node{
    34  		key:     key,
    35  		value:   value,
    36  		height:  0,
    37  		size:    1,
    38  		version: version,
    39  	}
    40  }
    41  
    42  // MakeNode constructs an *Node from an encoded byte slice.
    43  //
    44  // The new node doesn't have its hash saved or set. The caller must set it
    45  // afterwards.
    46  func MakeNode(buf []byte) (*Node, error) {
    47  	// Read node header (height, size, version, key).
    48  	height, n, cause := amino.DecodeVarint8(buf)
    49  	if cause != nil {
    50  		return nil, errors.Wrap(cause, "decoding node.height")
    51  	}
    52  	buf = buf[n:]
    53  
    54  	size, n, cause := amino.DecodeVarint(buf)
    55  	if cause != nil {
    56  		return nil, errors.Wrap(cause, "decoding node.size")
    57  	}
    58  	buf = buf[n:]
    59  
    60  	ver, n, cause := amino.DecodeVarint(buf)
    61  	if cause != nil {
    62  		return nil, errors.Wrap(cause, "decoding node.version")
    63  	}
    64  	buf = buf[n:]
    65  
    66  	key, n, cause := amino.DecodeByteSlice(buf)
    67  	if cause != nil {
    68  		return nil, errors.Wrap(cause, "decoding node.key")
    69  	}
    70  	buf = buf[n:]
    71  
    72  	node := &Node{
    73  		height:  height,
    74  		size:    size,
    75  		version: ver,
    76  		key:     key,
    77  	}
    78  
    79  	// Read node body.
    80  
    81  	if node.isLeaf() {
    82  		val, _, cause := amino.DecodeByteSlice(buf)
    83  		if cause != nil {
    84  			return nil, errors.Wrap(cause, "decoding node.value")
    85  		}
    86  		node.value = val
    87  	} else { // Read children.
    88  		leftHash, n, cause := amino.DecodeByteSlice(buf)
    89  		if cause != nil {
    90  			return nil, errors.Wrap(cause, "decoding node.leftHash")
    91  		}
    92  		buf = buf[n:]
    93  
    94  		rightHash, _, cause := amino.DecodeByteSlice(buf)
    95  		if cause != nil {
    96  			return nil, errors.Wrap(cause, "decoding node.rightHash")
    97  		}
    98  		node.leftHash = leftHash
    99  		node.rightHash = rightHash
   100  	}
   101  	return node, nil
   102  }
   103  
   104  // String returns a string representation of the node.
   105  func (node *Node) String() string {
   106  	hashstr := "<no hash>"
   107  	if len(node.hash) > 0 {
   108  		hashstr = fmt.Sprintf("%X", node.hash)
   109  	}
   110  	return fmt.Sprintf("Node{%s:%s@%d %X;%X}#%s",
   111  		ColoredBytes(node.key, Green, Blue),
   112  		ColoredBytes(node.value, Cyan, Blue),
   113  		node.version,
   114  		node.leftHash, node.rightHash,
   115  		hashstr)
   116  }
   117  
   118  // clone creates a shallow copy of a node with its hash set to nil.
   119  func (node *Node) clone(version int64) *Node {
   120  	if node.isLeaf() {
   121  		panic("Attempt to copy a leaf node")
   122  	}
   123  	return &Node{
   124  		key:       node.key,
   125  		height:    node.height,
   126  		version:   version,
   127  		size:      node.size,
   128  		hash:      nil,
   129  		leftHash:  node.leftHash,
   130  		leftNode:  node.leftNode,
   131  		rightHash: node.rightHash,
   132  		rightNode: node.rightNode,
   133  		persisted: false,
   134  	}
   135  }
   136  
   137  func (node *Node) isLeaf() bool {
   138  	return node.height == 0
   139  }
   140  
   141  // Check if the node has a descendant with the given key.
   142  func (node *Node) has(t *ImmutableTree, key []byte) (has bool) {
   143  	if bytes.Equal(node.key, key) {
   144  		return true
   145  	}
   146  	if node.isLeaf() {
   147  		return false
   148  	}
   149  	if bytes.Compare(key, node.key) < 0 {
   150  		return node.getLeftNode(t).has(t, key)
   151  	}
   152  	return node.getRightNode(t).has(t, key)
   153  }
   154  
   155  // Get a key under the node.
   156  func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte) {
   157  	if node.isLeaf() {
   158  		switch bytes.Compare(node.key, key) {
   159  		case -1:
   160  			return 1, nil
   161  		case 1:
   162  			return 0, nil
   163  		default:
   164  			return 0, node.value
   165  		}
   166  	}
   167  
   168  	if bytes.Compare(key, node.key) < 0 {
   169  		return node.getLeftNode(t).get(t, key)
   170  	}
   171  	rightNode := node.getRightNode(t)
   172  	index, value = rightNode.get(t, key)
   173  	index += node.size - rightNode.size
   174  	return index, value
   175  }
   176  
   177  func (node *Node) getByIndex(t *ImmutableTree, index int64) (key []byte, value []byte) {
   178  	if node.isLeaf() {
   179  		if index == 0 {
   180  			return node.key, node.value
   181  		}
   182  		return nil, nil
   183  	}
   184  	// TODO: could improve this by storing the
   185  	// sizes as well as left/right hash.
   186  	leftNode := node.getLeftNode(t)
   187  
   188  	if index < leftNode.size {
   189  		return leftNode.getByIndex(t, index)
   190  	}
   191  	return node.getRightNode(t).getByIndex(t, index-leftNode.size)
   192  }
   193  
   194  // Computes the hash of the node without computing its descendants. Must be
   195  // called on nodes which have descendant node hashes already computed.
   196  func (node *Node) _hash() []byte {
   197  	if node.hash != nil {
   198  		return node.hash
   199  	}
   200  
   201  	h := tmhash.New()
   202  	buf := new(bytes.Buffer)
   203  	if err := node.writeHashBytes(buf); err != nil {
   204  		panic(err)
   205  	}
   206  	h.Write(buf.Bytes())
   207  	node.hash = h.Sum(nil)
   208  
   209  	return node.hash
   210  }
   211  
   212  // Hash the node and its descendants recursively. This usually mutates all
   213  // descendant nodes. Returns the node hash and number of nodes hashed.
   214  func (node *Node) hashWithCount() ([]byte, int64) {
   215  	if node.hash != nil {
   216  		return node.hash, 0
   217  	}
   218  
   219  	h := tmhash.New()
   220  	buf := new(bytes.Buffer)
   221  	hashCount, err := node.writeHashBytesRecursively(buf)
   222  	if err != nil {
   223  		panic(err)
   224  	}
   225  	h.Write(buf.Bytes())
   226  	node.hash = h.Sum(nil)
   227  
   228  	return node.hash, hashCount + 1
   229  }
   230  
   231  // Writes the node's hash to the given io.Writer. This function expects
   232  // child hashes to be already set.
   233  func (node *Node) writeHashBytes(w io.Writer) error {
   234  	err := amino.EncodeVarint8(w, node.height)
   235  	if err != nil {
   236  		return errors.Wrap(err, "writing height")
   237  	}
   238  	err = amino.EncodeVarint(w, node.size)
   239  	if err != nil {
   240  		return errors.Wrap(err, "writing size")
   241  	}
   242  	err = amino.EncodeVarint(w, node.version)
   243  	if err != nil {
   244  		return errors.Wrap(err, "writing version")
   245  	}
   246  
   247  	// Key is not written for inner nodes, unlike writeBytes.
   248  
   249  	if node.isLeaf() {
   250  		err = amino.EncodeByteSlice(w, node.key)
   251  		if err != nil {
   252  			return errors.Wrap(err, "writing key")
   253  		}
   254  		// Indirection needed to provide proofs without values.
   255  		// (e.g. proofLeafNode.ValueHash)
   256  		valueHash := tmhash.Sum(node.value)
   257  		err = amino.EncodeByteSlice(w, valueHash)
   258  		if err != nil {
   259  			return errors.Wrap(err, "writing value")
   260  		}
   261  	} else {
   262  		if node.leftHash == nil || node.rightHash == nil {
   263  			panic("Found an empty child hash")
   264  		}
   265  		err = amino.EncodeByteSlice(w, node.leftHash)
   266  		if err != nil {
   267  			return errors.Wrap(err, "writing left hash")
   268  		}
   269  		err = amino.EncodeByteSlice(w, node.rightHash)
   270  		if err != nil {
   271  			return errors.Wrap(err, "writing right hash")
   272  		}
   273  	}
   274  
   275  	return nil
   276  }
   277  
   278  // Writes the node's hash to the given io.Writer.
   279  // This function has the side-effect of calling hashWithCount.
   280  func (node *Node) writeHashBytesRecursively(w io.Writer) (hashCount int64, err error) {
   281  	if node.leftNode != nil {
   282  		leftHash, leftCount := node.leftNode.hashWithCount()
   283  		node.leftHash = leftHash
   284  		hashCount += leftCount
   285  	}
   286  	if node.rightNode != nil {
   287  		rightHash, rightCount := node.rightNode.hashWithCount()
   288  		node.rightHash = rightHash
   289  		hashCount += rightCount
   290  	}
   291  	err = node.writeHashBytes(w)
   292  
   293  	return
   294  }
   295  
   296  // Writes the node as a serialized byte slice to the supplied io.Writer.
   297  func (node *Node) writeBytes(w io.Writer) error {
   298  	var cause error
   299  	cause = amino.EncodeVarint8(w, node.height)
   300  	if cause != nil {
   301  		return errors.Wrap(cause, "writing height")
   302  	}
   303  	cause = amino.EncodeVarint(w, node.size)
   304  	if cause != nil {
   305  		return errors.Wrap(cause, "writing size")
   306  	}
   307  	cause = amino.EncodeVarint(w, node.version)
   308  	if cause != nil {
   309  		return errors.Wrap(cause, "writing version")
   310  	}
   311  
   312  	// Unlike writeHashBytes, key is written for inner nodes.
   313  	cause = amino.EncodeByteSlice(w, node.key)
   314  	if cause != nil {
   315  		return errors.Wrap(cause, "writing key")
   316  	}
   317  
   318  	if node.isLeaf() {
   319  		cause = amino.EncodeByteSlice(w, node.value)
   320  		if cause != nil {
   321  			return errors.Wrap(cause, "writing value")
   322  		}
   323  	} else {
   324  		if node.leftHash == nil {
   325  			panic("node.leftHash was nil in writeBytes")
   326  		}
   327  		cause = amino.EncodeByteSlice(w, node.leftHash)
   328  		if cause != nil {
   329  			return errors.Wrap(cause, "writing left hash")
   330  		}
   331  
   332  		if node.rightHash == nil {
   333  			panic("node.rightHash was nil in writeBytes")
   334  		}
   335  		cause = amino.EncodeByteSlice(w, node.rightHash)
   336  		if cause != nil {
   337  			return errors.Wrap(cause, "writing right hash")
   338  		}
   339  	}
   340  	return nil
   341  }
   342  
   343  func (node *Node) getLeftNode(t *ImmutableTree) *Node {
   344  	if node.leftNode != nil {
   345  		return node.leftNode
   346  	}
   347  	return t.ndb.GetNode(node.leftHash)
   348  }
   349  
   350  func (node *Node) getRightNode(t *ImmutableTree) *Node {
   351  	if node.rightNode != nil {
   352  		return node.rightNode
   353  	}
   354  	return t.ndb.GetNode(node.rightHash)
   355  }
   356  
   357  // NOTE: mutates height and size
   358  func (node *Node) calcHeightAndSize(t *ImmutableTree) {
   359  	node.height = maxInt8(node.getLeftNode(t).height, node.getRightNode(t).height) + 1
   360  	node.size = node.getLeftNode(t).size + node.getRightNode(t).size
   361  }
   362  
   363  func (node *Node) calcBalance(t *ImmutableTree) int {
   364  	return int(node.getLeftNode(t).height) - int(node.getRightNode(t).height)
   365  }
   366  
   367  // traverse is a wrapper over traverseInRange when we want the whole tree
   368  func (node *Node) traverse(t *ImmutableTree, ascending bool, cb func(*Node) bool) bool {
   369  	return node.traverseInRange(t, nil, nil, ascending, false, 0, func(node *Node, depth uint8) bool {
   370  		return cb(node)
   371  	})
   372  }
   373  
   374  func (node *Node) traverseInRange(t *ImmutableTree, start, end []byte, ascending bool, inclusive bool, depth uint8, cb func(*Node, uint8) bool) bool {
   375  	afterStart := start == nil || bytes.Compare(start, node.key) < 0
   376  	startOrAfter := start == nil || bytes.Compare(start, node.key) <= 0
   377  	beforeEnd := end == nil || bytes.Compare(node.key, end) < 0
   378  	if inclusive {
   379  		beforeEnd = end == nil || bytes.Compare(node.key, end) <= 0
   380  	}
   381  
   382  	// Run callback per inner/leaf node.
   383  	stop := false
   384  	if !node.isLeaf() || (startOrAfter && beforeEnd) {
   385  		stop = cb(node, depth)
   386  		if stop {
   387  			return stop
   388  		}
   389  	}
   390  	if node.isLeaf() {
   391  		return stop
   392  	}
   393  
   394  	if ascending {
   395  		// check lower nodes, then higher
   396  		if afterStart {
   397  			stop = node.getLeftNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb)
   398  		}
   399  		if stop {
   400  			return stop
   401  		}
   402  		if beforeEnd {
   403  			stop = node.getRightNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb)
   404  		}
   405  	} else {
   406  		// check the higher nodes first
   407  		if beforeEnd {
   408  			stop = node.getRightNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb)
   409  		}
   410  		if stop {
   411  			return stop
   412  		}
   413  		if afterStart {
   414  			stop = node.getLeftNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb)
   415  		}
   416  	}
   417  
   418  	return stop
   419  }
   420  
   421  // Only used in testing...
   422  func (node *Node) lmd(t *ImmutableTree) *Node {
   423  	if node.isLeaf() {
   424  		return node
   425  	}
   426  	return node.getLeftNode(t).lmd(t)
   427  }