github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/iavl/node.go (about) 1 package iavl 2 3 // NOTE: This file favors int64 as opposed to int for size/counts. 4 // The Tree on the other hand favors int. This is intentional. 5 6 import ( 7 "bytes" 8 "fmt" 9 "io" 10 11 "github.com/gnolang/gno/tm2/pkg/amino" 12 "github.com/gnolang/gno/tm2/pkg/crypto/tmhash" 13 "github.com/gnolang/gno/tm2/pkg/errors" 14 ) 15 16 // Node represents a node in a Tree. 17 type Node struct { 18 key []byte 19 value []byte 20 version int64 21 height int8 22 size int64 23 hash []byte 24 leftHash []byte 25 leftNode *Node 26 rightHash []byte 27 rightNode *Node 28 persisted bool 29 } 30 31 // NewNode returns a new node from a key, value and version. 32 func NewNode(key []byte, value []byte, version int64) *Node { 33 return &Node{ 34 key: key, 35 value: value, 36 height: 0, 37 size: 1, 38 version: version, 39 } 40 } 41 42 // MakeNode constructs an *Node from an encoded byte slice. 43 // 44 // The new node doesn't have its hash saved or set. The caller must set it 45 // afterwards. 46 func MakeNode(buf []byte) (*Node, error) { 47 // Read node header (height, size, version, key). 48 height, n, cause := amino.DecodeVarint8(buf) 49 if cause != nil { 50 return nil, errors.Wrap(cause, "decoding node.height") 51 } 52 buf = buf[n:] 53 54 size, n, cause := amino.DecodeVarint(buf) 55 if cause != nil { 56 return nil, errors.Wrap(cause, "decoding node.size") 57 } 58 buf = buf[n:] 59 60 ver, n, cause := amino.DecodeVarint(buf) 61 if cause != nil { 62 return nil, errors.Wrap(cause, "decoding node.version") 63 } 64 buf = buf[n:] 65 66 key, n, cause := amino.DecodeByteSlice(buf) 67 if cause != nil { 68 return nil, errors.Wrap(cause, "decoding node.key") 69 } 70 buf = buf[n:] 71 72 node := &Node{ 73 height: height, 74 size: size, 75 version: ver, 76 key: key, 77 } 78 79 // Read node body. 80 81 if node.isLeaf() { 82 val, _, cause := amino.DecodeByteSlice(buf) 83 if cause != nil { 84 return nil, errors.Wrap(cause, "decoding node.value") 85 } 86 node.value = val 87 } else { // Read children. 88 leftHash, n, cause := amino.DecodeByteSlice(buf) 89 if cause != nil { 90 return nil, errors.Wrap(cause, "decoding node.leftHash") 91 } 92 buf = buf[n:] 93 94 rightHash, _, cause := amino.DecodeByteSlice(buf) 95 if cause != nil { 96 return nil, errors.Wrap(cause, "decoding node.rightHash") 97 } 98 node.leftHash = leftHash 99 node.rightHash = rightHash 100 } 101 return node, nil 102 } 103 104 // String returns a string representation of the node. 105 func (node *Node) String() string { 106 hashstr := "<no hash>" 107 if len(node.hash) > 0 { 108 hashstr = fmt.Sprintf("%X", node.hash) 109 } 110 return fmt.Sprintf("Node{%s:%s@%d %X;%X}#%s", 111 ColoredBytes(node.key, Green, Blue), 112 ColoredBytes(node.value, Cyan, Blue), 113 node.version, 114 node.leftHash, node.rightHash, 115 hashstr) 116 } 117 118 // clone creates a shallow copy of a node with its hash set to nil. 119 func (node *Node) clone(version int64) *Node { 120 if node.isLeaf() { 121 panic("Attempt to copy a leaf node") 122 } 123 return &Node{ 124 key: node.key, 125 height: node.height, 126 version: version, 127 size: node.size, 128 hash: nil, 129 leftHash: node.leftHash, 130 leftNode: node.leftNode, 131 rightHash: node.rightHash, 132 rightNode: node.rightNode, 133 persisted: false, 134 } 135 } 136 137 func (node *Node) isLeaf() bool { 138 return node.height == 0 139 } 140 141 // Check if the node has a descendant with the given key. 142 func (node *Node) has(t *ImmutableTree, key []byte) (has bool) { 143 if bytes.Equal(node.key, key) { 144 return true 145 } 146 if node.isLeaf() { 147 return false 148 } 149 if bytes.Compare(key, node.key) < 0 { 150 return node.getLeftNode(t).has(t, key) 151 } 152 return node.getRightNode(t).has(t, key) 153 } 154 155 // Get a key under the node. 156 func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte) { 157 if node.isLeaf() { 158 switch bytes.Compare(node.key, key) { 159 case -1: 160 return 1, nil 161 case 1: 162 return 0, nil 163 default: 164 return 0, node.value 165 } 166 } 167 168 if bytes.Compare(key, node.key) < 0 { 169 return node.getLeftNode(t).get(t, key) 170 } 171 rightNode := node.getRightNode(t) 172 index, value = rightNode.get(t, key) 173 index += node.size - rightNode.size 174 return index, value 175 } 176 177 func (node *Node) getByIndex(t *ImmutableTree, index int64) (key []byte, value []byte) { 178 if node.isLeaf() { 179 if index == 0 { 180 return node.key, node.value 181 } 182 return nil, nil 183 } 184 // TODO: could improve this by storing the 185 // sizes as well as left/right hash. 186 leftNode := node.getLeftNode(t) 187 188 if index < leftNode.size { 189 return leftNode.getByIndex(t, index) 190 } 191 return node.getRightNode(t).getByIndex(t, index-leftNode.size) 192 } 193 194 // Computes the hash of the node without computing its descendants. Must be 195 // called on nodes which have descendant node hashes already computed. 196 func (node *Node) _hash() []byte { 197 if node.hash != nil { 198 return node.hash 199 } 200 201 h := tmhash.New() 202 buf := new(bytes.Buffer) 203 if err := node.writeHashBytes(buf); err != nil { 204 panic(err) 205 } 206 h.Write(buf.Bytes()) 207 node.hash = h.Sum(nil) 208 209 return node.hash 210 } 211 212 // Hash the node and its descendants recursively. This usually mutates all 213 // descendant nodes. Returns the node hash and number of nodes hashed. 214 func (node *Node) hashWithCount() ([]byte, int64) { 215 if node.hash != nil { 216 return node.hash, 0 217 } 218 219 h := tmhash.New() 220 buf := new(bytes.Buffer) 221 hashCount, err := node.writeHashBytesRecursively(buf) 222 if err != nil { 223 panic(err) 224 } 225 h.Write(buf.Bytes()) 226 node.hash = h.Sum(nil) 227 228 return node.hash, hashCount + 1 229 } 230 231 // Writes the node's hash to the given io.Writer. This function expects 232 // child hashes to be already set. 233 func (node *Node) writeHashBytes(w io.Writer) error { 234 err := amino.EncodeVarint8(w, node.height) 235 if err != nil { 236 return errors.Wrap(err, "writing height") 237 } 238 err = amino.EncodeVarint(w, node.size) 239 if err != nil { 240 return errors.Wrap(err, "writing size") 241 } 242 err = amino.EncodeVarint(w, node.version) 243 if err != nil { 244 return errors.Wrap(err, "writing version") 245 } 246 247 // Key is not written for inner nodes, unlike writeBytes. 248 249 if node.isLeaf() { 250 err = amino.EncodeByteSlice(w, node.key) 251 if err != nil { 252 return errors.Wrap(err, "writing key") 253 } 254 // Indirection needed to provide proofs without values. 255 // (e.g. proofLeafNode.ValueHash) 256 valueHash := tmhash.Sum(node.value) 257 err = amino.EncodeByteSlice(w, valueHash) 258 if err != nil { 259 return errors.Wrap(err, "writing value") 260 } 261 } else { 262 if node.leftHash == nil || node.rightHash == nil { 263 panic("Found an empty child hash") 264 } 265 err = amino.EncodeByteSlice(w, node.leftHash) 266 if err != nil { 267 return errors.Wrap(err, "writing left hash") 268 } 269 err = amino.EncodeByteSlice(w, node.rightHash) 270 if err != nil { 271 return errors.Wrap(err, "writing right hash") 272 } 273 } 274 275 return nil 276 } 277 278 // Writes the node's hash to the given io.Writer. 279 // This function has the side-effect of calling hashWithCount. 280 func (node *Node) writeHashBytesRecursively(w io.Writer) (hashCount int64, err error) { 281 if node.leftNode != nil { 282 leftHash, leftCount := node.leftNode.hashWithCount() 283 node.leftHash = leftHash 284 hashCount += leftCount 285 } 286 if node.rightNode != nil { 287 rightHash, rightCount := node.rightNode.hashWithCount() 288 node.rightHash = rightHash 289 hashCount += rightCount 290 } 291 err = node.writeHashBytes(w) 292 293 return 294 } 295 296 // Writes the node as a serialized byte slice to the supplied io.Writer. 297 func (node *Node) writeBytes(w io.Writer) error { 298 var cause error 299 cause = amino.EncodeVarint8(w, node.height) 300 if cause != nil { 301 return errors.Wrap(cause, "writing height") 302 } 303 cause = amino.EncodeVarint(w, node.size) 304 if cause != nil { 305 return errors.Wrap(cause, "writing size") 306 } 307 cause = amino.EncodeVarint(w, node.version) 308 if cause != nil { 309 return errors.Wrap(cause, "writing version") 310 } 311 312 // Unlike writeHashBytes, key is written for inner nodes. 313 cause = amino.EncodeByteSlice(w, node.key) 314 if cause != nil { 315 return errors.Wrap(cause, "writing key") 316 } 317 318 if node.isLeaf() { 319 cause = amino.EncodeByteSlice(w, node.value) 320 if cause != nil { 321 return errors.Wrap(cause, "writing value") 322 } 323 } else { 324 if node.leftHash == nil { 325 panic("node.leftHash was nil in writeBytes") 326 } 327 cause = amino.EncodeByteSlice(w, node.leftHash) 328 if cause != nil { 329 return errors.Wrap(cause, "writing left hash") 330 } 331 332 if node.rightHash == nil { 333 panic("node.rightHash was nil in writeBytes") 334 } 335 cause = amino.EncodeByteSlice(w, node.rightHash) 336 if cause != nil { 337 return errors.Wrap(cause, "writing right hash") 338 } 339 } 340 return nil 341 } 342 343 func (node *Node) getLeftNode(t *ImmutableTree) *Node { 344 if node.leftNode != nil { 345 return node.leftNode 346 } 347 return t.ndb.GetNode(node.leftHash) 348 } 349 350 func (node *Node) getRightNode(t *ImmutableTree) *Node { 351 if node.rightNode != nil { 352 return node.rightNode 353 } 354 return t.ndb.GetNode(node.rightHash) 355 } 356 357 // NOTE: mutates height and size 358 func (node *Node) calcHeightAndSize(t *ImmutableTree) { 359 node.height = maxInt8(node.getLeftNode(t).height, node.getRightNode(t).height) + 1 360 node.size = node.getLeftNode(t).size + node.getRightNode(t).size 361 } 362 363 func (node *Node) calcBalance(t *ImmutableTree) int { 364 return int(node.getLeftNode(t).height) - int(node.getRightNode(t).height) 365 } 366 367 // traverse is a wrapper over traverseInRange when we want the whole tree 368 func (node *Node) traverse(t *ImmutableTree, ascending bool, cb func(*Node) bool) bool { 369 return node.traverseInRange(t, nil, nil, ascending, false, 0, func(node *Node, depth uint8) bool { 370 return cb(node) 371 }) 372 } 373 374 func (node *Node) traverseInRange(t *ImmutableTree, start, end []byte, ascending bool, inclusive bool, depth uint8, cb func(*Node, uint8) bool) bool { 375 afterStart := start == nil || bytes.Compare(start, node.key) < 0 376 startOrAfter := start == nil || bytes.Compare(start, node.key) <= 0 377 beforeEnd := end == nil || bytes.Compare(node.key, end) < 0 378 if inclusive { 379 beforeEnd = end == nil || bytes.Compare(node.key, end) <= 0 380 } 381 382 // Run callback per inner/leaf node. 383 stop := false 384 if !node.isLeaf() || (startOrAfter && beforeEnd) { 385 stop = cb(node, depth) 386 if stop { 387 return stop 388 } 389 } 390 if node.isLeaf() { 391 return stop 392 } 393 394 if ascending { 395 // check lower nodes, then higher 396 if afterStart { 397 stop = node.getLeftNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb) 398 } 399 if stop { 400 return stop 401 } 402 if beforeEnd { 403 stop = node.getRightNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb) 404 } 405 } else { 406 // check the higher nodes first 407 if beforeEnd { 408 stop = node.getRightNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb) 409 } 410 if stop { 411 return stop 412 } 413 if afterStart { 414 stop = node.getLeftNode(t).traverseInRange(t, start, end, ascending, inclusive, depth+1, cb) 415 } 416 } 417 418 return stop 419 } 420 421 // Only used in testing... 422 func (node *Node) lmd(t *ImmutableTree) *Node { 423 if node.isLeaf() { 424 return node 425 } 426 return node.getLeftNode(t).lmd(t) 427 }