github.com/leovct/zkevm-bridge-service@v0.4.4/bridgectrl/merkletree.go (about) 1 package bridgectrl 2 3 import ( 4 "context" 5 "fmt" 6 7 "github.com/0xPolygonHermez/zkevm-bridge-service/etherman" 8 "github.com/0xPolygonHermez/zkevm-bridge-service/log" 9 "github.com/0xPolygonHermez/zkevm-bridge-service/utils/gerror" 10 "github.com/ethereum/go-ethereum/common" 11 "github.com/jackc/pgx/v4" 12 ) 13 14 // zeroHashes is the pre-calculated zero hash array 15 var zeroHashes [][KeyLen]byte 16 17 // MerkleTree struct 18 type MerkleTree struct { 19 // store is the database storage to store all node data 20 store merkleTreeStore 21 network uint 22 // height is the depth of the merkle tree 23 height uint8 24 // count is the number of deposit 25 count uint 26 // siblings is the array of sibling of the last leaf added 27 siblings [][KeyLen]byte 28 } 29 30 func init() { 31 /* 32 * We set 64 levels because the height is not known yet. Also it is initialized here to avoid run this 33 * function twice (one for mainnetExitTree and another for RollupExitTree). 34 * If we receive a height of 32, we would need to use only the first 32 values of the array. 35 * If we need more level than 64 for the mt we need to edit this value here and set for example 128. 36 */ 37 zeroHashes = generateZeroHashes(64) // nolint 38 } 39 40 // NewMerkleTree creates new MerkleTree. 41 func NewMerkleTree(ctx context.Context, store merkleTreeStore, height uint8, network uint) (*MerkleTree, error) { 42 depositCnt, err := store.GetLastDepositCount(ctx, network, nil) 43 if err != nil { 44 if err != gerror.ErrStorageNotFound { 45 return nil, err 46 } 47 depositCnt = 0 48 } else { 49 depositCnt++ 50 } 51 52 mt := &MerkleTree{ 53 store: store, 54 network: network, 55 height: height, 56 count: depositCnt, 57 } 58 mt.siblings, err = mt.initSiblings(ctx, nil) 59 60 return mt, err 61 } 62 63 // initSiblings returns the siblings of the node at the given index. 64 // it is used to initialize the siblings array in the beginning. 65 func (mt *MerkleTree) initSiblings(ctx context.Context, dbTx pgx.Tx) ([][KeyLen]byte, error) { 66 var ( 67 left [KeyLen]byte 68 siblings [][KeyLen]byte 69 ) 70 71 if mt.count == 0 { 72 for h := 0; h < int(mt.height); h++ { 73 copy(left[:], zeroHashes[h][:]) 74 siblings = append(siblings, left) 75 } 76 return siblings, nil 77 } 78 79 root, err := mt.getRoot(ctx, dbTx) 80 if err != nil { 81 return nil, err 82 } 83 // index is the index of the last node 84 index := mt.count - 1 85 cur := root 86 87 // It starts in height-1 because 0 is the level of the leafs 88 for h := int(mt.height - 1); h >= 0; h-- { 89 value, err := mt.store.Get(ctx, cur, dbTx) 90 if err != nil { 91 return nil, fmt.Errorf("height: %d, cur: %v, error: %v", h, cur, err) 92 } 93 94 copy(left[:], value[0]) 95 // we will keep the left sibling of the last node 96 siblings = append(siblings, left) 97 98 if index&(1<<h) > 0 { 99 cur = value[1] 100 } else { 101 cur = value[0] 102 } 103 } 104 105 // We need to invert the siblings to go from leafs to the top 106 for st, en := 0, len(siblings)-1; st < en; st, en = st+1, en-1 { 107 siblings[st], siblings[en] = siblings[en], siblings[st] 108 } 109 110 return siblings, nil 111 } 112 113 func (mt *MerkleTree) addLeaf(ctx context.Context, depositID uint64, leaf [KeyLen]byte, index uint, dbTx pgx.Tx) error { 114 if index != mt.count { 115 return fmt.Errorf("mismatched deposit count: %d, expected: %d", index, mt.count) 116 } 117 cur := leaf 118 isFilledSubTree := true 119 120 var leaves [][][]byte 121 for h := uint8(0); h < mt.height; h++ { 122 if index&(1<<h) > 0 { 123 var child [KeyLen]byte 124 copy(child[:], cur[:]) 125 parent := Hash(mt.siblings[h], child) 126 cur = parent 127 leaves = append(leaves, [][]byte{parent[:], mt.siblings[h][:], child[:]}) 128 } else { 129 if isFilledSubTree { 130 // we will update the sibling when the sub tree is complete 131 copy(mt.siblings[h][:], cur[:]) 132 // we have a left child in this layer, it means the right child is empty so the sub tree is not completed 133 isFilledSubTree = false 134 } 135 var child [KeyLen]byte 136 copy(child[:], cur[:]) 137 parent := Hash(child, zeroHashes[h]) 138 cur = parent 139 // the sibling of 0 bit should be the zero hash, since we are in the last node of the tree 140 leaves = append(leaves, [][]byte{parent[:], child[:], zeroHashes[h][:]}) 141 } 142 } 143 144 err := mt.store.SetRoot(ctx, cur[:], depositID, mt.network, dbTx) 145 if err != nil { 146 return err 147 } 148 var nodes [][]interface{} 149 for _, leaf := range leaves { 150 nodes = append(nodes, []interface{}{leaf[0], [][]byte{leaf[1], leaf[2]}, depositID}) 151 } 152 if err := mt.store.BulkSet(ctx, nodes, dbTx); err != nil { 153 return err 154 } 155 156 mt.count++ 157 return nil 158 } 159 160 func (mt *MerkleTree) resetLeaf(ctx context.Context, depositCount uint, dbTx pgx.Tx) error { 161 var err error 162 mt.count = depositCount 163 mt.siblings, err = mt.initSiblings(ctx, dbTx) 164 return err 165 } 166 167 // this function is used to get the current root of the merkle tree 168 func (mt *MerkleTree) getRoot(ctx context.Context, dbTx pgx.Tx) ([]byte, error) { 169 if mt.count == 0 { 170 return zeroHashes[mt.height][:], nil 171 } 172 return mt.store.GetRoot(ctx, mt.count-1, mt.network, dbTx) 173 } 174 175 func buildIntermediate(leaves [][KeyLen]byte) ([][][]byte, [][32]byte) { 176 var ( 177 nodes [][][]byte 178 hashes [][KeyLen]byte 179 ) 180 for i := 0; i < len(leaves); i += 2 { 181 var left, right int = i, i + 1 182 hash := Hash(leaves[left], leaves[right]) 183 nodes = append(nodes, [][]byte{hash[:], leaves[left][:], leaves[right][:]}) 184 hashes = append(hashes, hash) 185 } 186 return nodes, hashes 187 } 188 189 func (mt *MerkleTree) updateLeaf(ctx context.Context, depositID uint64, leaves [][KeyLen]byte, dbTx pgx.Tx) error { 190 var ( 191 nodes [][][][]byte 192 ns [][][]byte 193 ) 194 initLeavesCount := uint(len(leaves)) 195 if len(leaves) == 0 { 196 leaves = append(leaves, zeroHashes[0]) 197 } 198 199 for h := uint8(0); h < mt.height; h++ { 200 if len(leaves)%2 == 1 { 201 leaves = append(leaves, zeroHashes[h]) 202 } 203 ns, leaves = buildIntermediate(leaves) 204 nodes = append(nodes, ns) 205 } 206 if len(ns) != 1 { 207 return fmt.Errorf("error: more than one root detected: %+v", nodes) 208 } 209 log.Debug("Root calculated: ", common.Bytes2Hex(ns[0][0])) 210 err := mt.store.SetRoot(ctx, ns[0][0], depositID, mt.network, dbTx) 211 if err != nil { 212 return err 213 } 214 var nodesToStore [][]interface{} 215 for _, leaves := range nodes { 216 for _, leaf := range leaves { 217 nodesToStore = append(nodesToStore, []interface{}{leaf[0], [][]byte{leaf[1], leaf[2]}, depositID}) 218 } 219 } 220 if err := mt.store.BulkSet(ctx, nodesToStore, dbTx); err != nil { 221 return err 222 } 223 mt.count = initLeavesCount 224 return nil 225 } 226 227 func (mt *MerkleTree) getLeaves(ctx context.Context, dbTx pgx.Tx) ([][KeyLen]byte, error) { 228 root, err := mt.getRoot(ctx, dbTx) 229 if err != nil { 230 return nil, err 231 } 232 cur := [][]byte{root} 233 // It starts in height-1 because 0 is the level of the leafs 234 for h := int(mt.height - 1); h >= 0; h-- { 235 var levelLeaves [][]byte 236 for _, c := range cur { 237 leaves, err := mt.store.Get(ctx, c, dbTx) 238 if err != nil { 239 var isZero bool 240 curHash := common.BytesToHash(c) 241 for _, h := range zeroHashes { 242 if common.BytesToHash(h[:]) == curHash { 243 isZero = true 244 } 245 } 246 if !isZero { 247 return nil, fmt.Errorf("height: %d, cur: %v, error: %v", h, cur, err) 248 } 249 } 250 levelLeaves = append(levelLeaves, leaves...) 251 } 252 cur = levelLeaves 253 } 254 var result [][KeyLen]byte 255 for _, l := range cur { 256 var aux [KeyLen]byte 257 copy(aux[:], l) 258 result = append(result, aux) 259 } 260 return result, nil 261 } 262 263 func (mt *MerkleTree) buildMTRoot(leaves [][KeyLen]byte) (common.Hash, error) { 264 var ( 265 nodes [][][][]byte 266 ns [][][]byte 267 ) 268 if len(leaves) == 0 { 269 leaves = append(leaves, zeroHashes[0]) 270 } 271 272 for h := uint8(0); h < mt.height; h++ { 273 if len(leaves)%2 == 1 { 274 leaves = append(leaves, zeroHashes[h]) 275 } 276 ns, leaves = buildIntermediate(leaves) 277 nodes = append(nodes, ns) 278 } 279 if len(ns) != 1 { 280 return common.Hash{}, fmt.Errorf("error: more than one root detected: %+v", nodes) 281 } 282 log.Debug("Root calculated: ", common.Bytes2Hex(ns[0][0])) 283 284 return common.BytesToHash(ns[0][0]), nil 285 } 286 287 func (mt MerkleTree) storeLeaves(ctx context.Context, leaves [][KeyLen]byte, blockID uint64, dbTx pgx.Tx) error { 288 root, err := mt.buildMTRoot(leaves) 289 if err != nil { 290 return err 291 } 292 // Check if root is already stored. If so, don't save the leaves because they are already stored on the db. 293 exist, err := mt.store.IsRollupExitRoot(ctx, root, dbTx) 294 if err != nil { 295 return err 296 } 297 if !exist { 298 var inserts [][]interface{} 299 for i := range leaves { 300 inserts = append(inserts, []interface{}{leaves[i][:], i + 1, root.Bytes(), blockID}) 301 } 302 if err := mt.store.AddRollupExitLeaves(ctx, inserts, dbTx); err != nil { 303 return err 304 } 305 } 306 return nil 307 } 308 309 // func (mt MerkleTree) getLatestRollupExitLeaves(ctx context.Context, dbTx pgx.Tx) ([]etherman.RollupExitLeaf, error) { 310 // return mt.store.GetLatestRollupExitLeaves(ctx, dbTx) 311 // } 312 313 func (mt MerkleTree) addRollupExitLeaf(ctx context.Context, rollupLeaf etherman.RollupExitLeaf, dbTx pgx.Tx) error { 314 storedRollupLeaves, err := mt.store.GetLatestRollupExitLeaves(ctx, dbTx) 315 if err != nil { 316 log.Error("error getting latest rollup exit leaves. Error: ", err) 317 return err 318 } 319 // If rollupLeaf.RollupId is lower or equal than len(storedRollupLeaves), we can add it in the proper position of the array 320 // if rollupLeaf.RollupId <= uint64(len(storedRollupLeaves)) { 321 // if storedRollupLeaves[rollupLeaf.RollupId-1].RollupId == rollupLeaf.RollupId { 322 // storedRollupLeaves[rollupLeaf.RollupId-1] = rollupLeaf 323 // } else { 324 // return fmt.Errorf("error: RollupId doesn't match") 325 // } 326 // } else { 327 328 // If rollupLeaf.RollupId is higher than len(storedRollupLeaves), We have to add empty rollups until the new rollupID 329 for i := len(storedRollupLeaves); i < int(rollupLeaf.RollupId); i++ { 330 storedRollupLeaves = append(storedRollupLeaves, etherman.RollupExitLeaf{ 331 BlockID: rollupLeaf.BlockID, 332 RollupId: uint(i + 1), 333 }) 334 } 335 if storedRollupLeaves[rollupLeaf.RollupId-1].RollupId == rollupLeaf.RollupId { 336 storedRollupLeaves[rollupLeaf.RollupId-1] = rollupLeaf 337 } else { 338 return fmt.Errorf("error: RollupId doesn't match") 339 } 340 // } 341 var leaves [][KeyLen]byte 342 for _, l := range storedRollupLeaves { 343 var aux [KeyLen]byte 344 copy(aux[:], l.Leaf[:]) 345 leaves = append(leaves, aux) 346 } 347 err = mt.storeLeaves(ctx, leaves, rollupLeaf.BlockID, dbTx) 348 if err != nil { 349 log.Error("error storing leaves. Error: ", err) 350 return err 351 } 352 return nil 353 } 354 355 func ComputeSiblings(rollupIndex uint, leaves [][KeyLen]byte, height uint8) ([][KeyLen]byte, common.Hash, error) { 356 var ns [][][]byte 357 if len(leaves) == 0 { 358 leaves = append(leaves, zeroHashes[0]) 359 } 360 var siblings [][KeyLen]byte 361 index := rollupIndex 362 for h := uint8(0); h < height; h++ { 363 if len(leaves)%2 == 1 { 364 leaves = append(leaves, zeroHashes[h]) 365 } 366 if index%2 == 1 { //If it is odd 367 siblings = append(siblings, leaves[index-1]) 368 } else { // It is even 369 if len(leaves) > 1 { 370 siblings = append(siblings, leaves[index+1]) 371 } 372 } 373 var ( 374 nsi [][][]byte 375 hashes [][KeyLen]byte 376 ) 377 for i := 0; i < len(leaves); i += 2 { 378 var left, right int = i, i + 1 379 hash := Hash(leaves[left], leaves[right]) 380 nsi = append(nsi, [][]byte{hash[:], leaves[left][:], leaves[right][:]}) 381 hashes = append(hashes, hash) 382 } 383 // Find the index of the leave in the next level of the tree. 384 // Divide the index by 2 to find the position in the upper level 385 index = uint(float64(index) / 2) //nolint:gomnd 386 ns = nsi 387 leaves = hashes 388 } 389 if len(ns) != 1 { 390 return nil, common.Hash{}, fmt.Errorf("error: more than one root detected: %+v", ns) 391 } 392 393 return siblings, common.BytesToHash(ns[0][0]), nil 394 } 395 396 func calculateRoot(leafHash common.Hash, smtProof [][KeyLen]byte, index uint, height uint8) common.Hash { 397 var node [KeyLen]byte 398 copy(node[:], leafHash[:]) 399 400 // Check merkle proof 401 var h uint8 402 for h = 0; h < height; h++ { 403 if ((index >> h) & 1) == 1 { 404 node = Hash(smtProof[h], node) 405 } else { 406 node = Hash(node, smtProof[h]) 407 } 408 } 409 return common.BytesToHash(node[:]) 410 }