github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/mpt/batch.go (about) 1 package mpt 2 3 import ( 4 "bytes" 5 "sort" 6 ) 7 8 // Batch is a batch of storage changes. 9 // It stores key-value pairs in a sorted state. 10 type Batch struct { 11 kv []keyValue 12 } 13 14 type keyValue struct { 15 key []byte 16 value []byte 17 } 18 19 // MapToMPTBatch makes a Batch from an unordered set of storage changes. 20 func MapToMPTBatch(m map[string][]byte) Batch { 21 var b Batch 22 23 b.kv = make([]keyValue, 0, len(m)) 24 25 for k, v := range m { 26 b.kv = append(b.kv, keyValue{strToNibbles(k), v}) // Strip storage prefix. 27 } 28 sort.Slice(b.kv, func(i, j int) bool { 29 return bytes.Compare(b.kv[i].key, b.kv[j].key) < 0 30 }) 31 return b 32 } 33 34 // PutBatch puts a batch to a trie. 35 // It is not atomic (and probably cannot be without substantial slow-down) 36 // and returns the number of elements processed. 37 // If an error is returned, the trie may be in the inconsistent state in case of storage failures. 38 // This is due to the fact that we can remove multiple children from the branch node simultaneously 39 // and won't strip the resulting branch node. 40 // However, it is used mostly after block processing to update MPT, and error is not expected. 41 func (t *Trie) PutBatch(b Batch) (int, error) { 42 if len(b.kv) == 0 { 43 return 0, nil 44 } 45 r, n, err := t.putBatch(b.kv) 46 t.root = r 47 return n, err 48 } 49 50 func (t *Trie) putBatch(kv []keyValue) (Node, int, error) { 51 return t.putBatchIntoNode(t.root, kv) 52 } 53 54 func (t *Trie) putBatchIntoNode(curr Node, kv []keyValue) (Node, int, error) { 55 switch n := curr.(type) { 56 case *LeafNode: 57 return t.putBatchIntoLeaf(n, kv) 58 case *BranchNode: 59 return t.putBatchIntoBranch(n, kv) 60 case *ExtensionNode: 61 return t.putBatchIntoExtension(n, kv) 62 case *HashNode: 63 return t.putBatchIntoHash(n, kv) 64 case EmptyNode: 65 return t.putBatchIntoEmpty(kv) 66 default: 67 panic("invalid MPT node type") 68 } 69 } 70 71 func (t *Trie) putBatchIntoLeaf(curr *LeafNode, kv []keyValue) (Node, int, error) { 72 t.removeRef(curr.Hash(), curr.Bytes()) 73 return t.newSubTrieMany(nil, kv, curr.value) 74 } 75 76 func (t *Trie) putBatchIntoBranch(curr *BranchNode, kv []keyValue) (Node, int, error) { 77 return t.addToBranch(curr, kv, true) 78 } 79 80 func (t *Trie) mergeExtension(prefix []byte, sub Node) (Node, error) { 81 switch sn := sub.(type) { 82 case *ExtensionNode: 83 t.removeRef(sn.Hash(), sn.bytes) 84 sn.key = append(prefix, sn.key...) 85 sn.invalidateCache() 86 t.addRef(sn.Hash(), sn.bytes) 87 return sn, nil 88 case EmptyNode: 89 return sn, nil 90 case *HashNode: 91 n, err := t.getFromStore(sn.Hash()) 92 if err != nil { 93 return sn, err 94 } 95 return t.mergeExtension(prefix, n) 96 default: 97 if len(prefix) != 0 { 98 e := NewExtensionNode(prefix, sub) 99 t.addRef(e.Hash(), e.bytes) 100 return e, nil 101 } 102 return sub, nil 103 } 104 } 105 106 func (t *Trie) putBatchIntoExtension(curr *ExtensionNode, kv []keyValue) (Node, int, error) { 107 t.removeRef(curr.Hash(), curr.bytes) 108 109 common := lcpMany(kv) 110 pref := lcp(common, curr.key) 111 if len(pref) == len(curr.key) { 112 // Extension must be split into new nodes. 113 stripPrefix(len(curr.key), kv) 114 sub, n, err := t.putBatchIntoNode(curr.next, kv) 115 if err == nil { 116 sub, err = t.mergeExtension(pref, sub) 117 } 118 return sub, n, err 119 } 120 121 if len(pref) != 0 { 122 stripPrefix(len(pref), kv) 123 sub, n, err := t.putBatchIntoExtensionNoPrefix(curr.key[len(pref):], curr.next, kv) 124 if err == nil { 125 sub, err = t.mergeExtension(pref, sub) 126 } 127 return sub, n, err 128 } 129 return t.putBatchIntoExtensionNoPrefix(curr.key, curr.next, kv) 130 } 131 132 func (t *Trie) putBatchIntoExtensionNoPrefix(key []byte, next Node, kv []keyValue) (Node, int, error) { 133 b := NewBranchNode() 134 if len(key) > 1 { 135 b.Children[key[0]] = t.newSubTrie(key[1:], next, false) 136 } else { 137 b.Children[key[0]] = next 138 } 139 return t.addToBranch(b, kv, false) 140 } 141 142 func isEmpty(n Node) bool { 143 _, ok := n.(EmptyNode) 144 return ok 145 } 146 147 // addToBranch puts items into the branch node assuming b is not yet in trie. 148 func (t *Trie) addToBranch(b *BranchNode, kv []keyValue, inTrie bool) (Node, int, error) { 149 if inTrie { 150 t.removeRef(b.Hash(), b.bytes) 151 } 152 153 // An error during iterate means some storage failure (i.e. some hash node cannot be 154 // retrieved from storage). This can leave the trie in an inconsistent state because 155 // it can be impossible to strip the branch node after it has been changed. 156 // Consider a branch with 10 children, first 9 of which are deleted and the remaining one 157 // is a leaf node replaced by a hash node missing from the storage. 158 // This can't be fixed easily because we need to _revert_ changes in the reference counts 159 // for children which have been updated successfully. But storage access errors means we are 160 // in a bad state anyway. 161 n, err := t.iterateBatch(kv, func(c byte, kv []keyValue) (int, error) { 162 child, n, err := t.putBatchIntoNode(b.Children[c], kv) 163 b.Children[c] = child 164 return n, err 165 }) 166 if inTrie && n != 0 { 167 b.invalidateCache() 168 } 169 170 // Even if some of the children can't be put, we need to try to strip the branch 171 // and possibly update the refcounts. 172 nd, bErr := t.stripBranch(b) 173 if err == nil { 174 err = bErr 175 } 176 return nd, n, err 177 } 178 179 // stripsBranch strips the branch node after incomplete batch put. 180 // It assumes there is no reference to b in the trie. 181 func (t *Trie) stripBranch(b *BranchNode) (Node, error) { 182 var n int 183 var lastIndex byte 184 for i := range b.Children { 185 if !isEmpty(b.Children[i]) { 186 n++ 187 lastIndex = byte(i) 188 } 189 } 190 switch { 191 case n == 0: 192 return EmptyNode{}, nil 193 case n == 1: 194 if lastIndex != lastChild { 195 return t.mergeExtension([]byte{lastIndex}, b.Children[lastIndex]) 196 } 197 return b.Children[lastIndex], nil 198 default: 199 t.addRef(b.Hash(), b.bytes) 200 return b, nil 201 } 202 } 203 204 func (t *Trie) iterateBatch(kv []keyValue, f func(c byte, kv []keyValue) (int, error)) (int, error) { 205 var n int 206 for len(kv) != 0 { 207 c, i := getLastIndex(kv) 208 if c != lastChild { 209 stripPrefix(1, kv[:i]) 210 } 211 sub, err := f(c, kv[:i]) 212 n += sub 213 if err != nil { 214 return n, err 215 } 216 kv = kv[i:] 217 } 218 return n, nil 219 } 220 221 func (t *Trie) putBatchIntoEmpty(kv []keyValue) (Node, int, error) { 222 common := lcpMany(kv) 223 stripPrefix(len(common), kv) 224 return t.newSubTrieMany(common, kv, nil) 225 } 226 227 func (t *Trie) putBatchIntoHash(curr *HashNode, kv []keyValue) (Node, int, error) { 228 result, err := t.getFromStore(curr.hash) 229 if err != nil { 230 return curr, 0, err 231 } 232 return t.putBatchIntoNode(result, kv) 233 } 234 235 // Creates a new subtrie from the provided key-value pairs. 236 // Items in kv must have no common prefix. 237 // If there are any deletions in kv, error is returned. 238 // kv is not empty. 239 // kv is sorted by key. 240 // value is the current value stored by prefix. 241 func (t *Trie) newSubTrieMany(prefix []byte, kv []keyValue, value []byte) (Node, int, error) { 242 if len(kv[0].key) == 0 { 243 if kv[0].value == nil { 244 if len(kv) == 1 { 245 return EmptyNode{}, 1, nil 246 } 247 node, n, err := t.newSubTrieMany(prefix, kv[1:], nil) 248 return node, n + 1, err 249 } 250 if len(kv) == 1 { 251 return t.newSubTrie(prefix, NewLeafNode(kv[0].value), true), 1, nil 252 } 253 value = kv[0].value 254 } 255 256 // Prefix is empty and we have at least 2 children. 257 b := NewBranchNode() 258 if value != nil { 259 // Empty key is always first. 260 leaf := NewLeafNode(value) 261 t.addRef(leaf.Hash(), leaf.bytes) 262 b.Children[lastChild] = leaf 263 } 264 nd, n, err := t.addToBranch(b, kv, false) 265 if err == nil { 266 nd, err = t.mergeExtension(prefix, nd) 267 } 268 return nd, n, err 269 } 270 271 func stripPrefix(n int, kv []keyValue) { 272 for i := range kv { 273 kv[i].key = kv[i].key[n:] 274 } 275 } 276 277 func getLastIndex(kv []keyValue) (byte, int) { 278 if len(kv[0].key) == 0 { 279 return lastChild, 1 280 } 281 c := kv[0].key[0] 282 for i := range kv[1:] { 283 if kv[i+1].key[0] != c { 284 return c, i + 1 285 } 286 } 287 return c, len(kv) 288 }