github.com/leovct/zkevm-bridge-service@v0.4.4/bridgectrl/merkletree.go (about)

     1  package bridgectrl
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/0xPolygonHermez/zkevm-bridge-service/etherman"
     8  	"github.com/0xPolygonHermez/zkevm-bridge-service/log"
     9  	"github.com/0xPolygonHermez/zkevm-bridge-service/utils/gerror"
    10  	"github.com/ethereum/go-ethereum/common"
    11  	"github.com/jackc/pgx/v4"
    12  )
    13  
// zeroHashes is the pre-calculated zero hash array. It is filled once by the
// package init function with the result of generateZeroHashes and is indexed
// by tree level: zeroHashes[h] is the value used to pad level h.
var zeroHashes [][KeyLen]byte
    16  
// MerkleTree holds the in-memory state of a merkle tree whose nodes and roots
// are persisted through store.
type MerkleTree struct {
	// store is the database storage to store all node data
	store   merkleTreeStore
	// network is the network identifier passed to the store when roots are
	// read or written.
	network uint
	// height is the depth of the merkle tree
	height uint8
	// count is the number of deposits (leaves) currently in the tree; the
	// next leaf added must carry this value as its index.
	count uint
	// siblings is the array of siblings of the last leaf added, ordered from
	// the leaf level (index 0) upwards.
	siblings [][KeyLen]byte
}
    29  
    30  func init() {
    31  	/*
    32  	* We set 64 levels because the height is not known yet. Also it is initialized here to avoid run this
    33  	* function twice (one for mainnetExitTree and another for RollupExitTree).
    34  	* If we receive a height of 32, we would need to use only the first 32 values of the array.
    35  	* If we need more level than 64 for the mt we need to edit this value here and set for example 128.
    36  	 */
    37  	zeroHashes = generateZeroHashes(64) // nolint
    38  }
    39  
    40  // NewMerkleTree creates new MerkleTree.
    41  func NewMerkleTree(ctx context.Context, store merkleTreeStore, height uint8, network uint) (*MerkleTree, error) {
    42  	depositCnt, err := store.GetLastDepositCount(ctx, network, nil)
    43  	if err != nil {
    44  		if err != gerror.ErrStorageNotFound {
    45  			return nil, err
    46  		}
    47  		depositCnt = 0
    48  	} else {
    49  		depositCnt++
    50  	}
    51  
    52  	mt := &MerkleTree{
    53  		store:   store,
    54  		network: network,
    55  		height:  height,
    56  		count:   depositCnt,
    57  	}
    58  	mt.siblings, err = mt.initSiblings(ctx, nil)
    59  
    60  	return mt, err
    61  }
    62  
    63  // initSiblings returns the siblings of the node at the given index.
    64  // it is used to initialize the siblings array in the beginning.
    65  func (mt *MerkleTree) initSiblings(ctx context.Context, dbTx pgx.Tx) ([][KeyLen]byte, error) {
    66  	var (
    67  		left     [KeyLen]byte
    68  		siblings [][KeyLen]byte
    69  	)
    70  
    71  	if mt.count == 0 {
    72  		for h := 0; h < int(mt.height); h++ {
    73  			copy(left[:], zeroHashes[h][:])
    74  			siblings = append(siblings, left)
    75  		}
    76  		return siblings, nil
    77  	}
    78  
    79  	root, err := mt.getRoot(ctx, dbTx)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	// index is the index of the last node
    84  	index := mt.count - 1
    85  	cur := root
    86  
    87  	// It starts in height-1 because 0 is the level of the leafs
    88  	for h := int(mt.height - 1); h >= 0; h-- {
    89  		value, err := mt.store.Get(ctx, cur, dbTx)
    90  		if err != nil {
    91  			return nil, fmt.Errorf("height: %d, cur: %v, error: %v", h, cur, err)
    92  		}
    93  
    94  		copy(left[:], value[0])
    95  		// we will keep the left sibling of the last node
    96  		siblings = append(siblings, left)
    97  
    98  		if index&(1<<h) > 0 {
    99  			cur = value[1]
   100  		} else {
   101  			cur = value[0]
   102  		}
   103  	}
   104  
   105  	// We need to invert the siblings to go from leafs to the top
   106  	for st, en := 0, len(siblings)-1; st < en; st, en = st+1, en-1 {
   107  		siblings[st], siblings[en] = siblings[en], siblings[st]
   108  	}
   109  
   110  	return siblings, nil
   111  }
   112  
   113  func (mt *MerkleTree) addLeaf(ctx context.Context, depositID uint64, leaf [KeyLen]byte, index uint, dbTx pgx.Tx) error {
   114  	if index != mt.count {
   115  		return fmt.Errorf("mismatched deposit count: %d, expected: %d", index, mt.count)
   116  	}
   117  	cur := leaf
   118  	isFilledSubTree := true
   119  
   120  	var leaves [][][]byte
   121  	for h := uint8(0); h < mt.height; h++ {
   122  		if index&(1<<h) > 0 {
   123  			var child [KeyLen]byte
   124  			copy(child[:], cur[:])
   125  			parent := Hash(mt.siblings[h], child)
   126  			cur = parent
   127  			leaves = append(leaves, [][]byte{parent[:], mt.siblings[h][:], child[:]})
   128  		} else {
   129  			if isFilledSubTree {
   130  				// we will update the sibling when the sub tree is complete
   131  				copy(mt.siblings[h][:], cur[:])
   132  				// we have a left child in this layer, it means the right child is empty so the sub tree is not completed
   133  				isFilledSubTree = false
   134  			}
   135  			var child [KeyLen]byte
   136  			copy(child[:], cur[:])
   137  			parent := Hash(child, zeroHashes[h])
   138  			cur = parent
   139  			// the sibling of 0 bit should be the zero hash, since we are in the last node of the tree
   140  			leaves = append(leaves, [][]byte{parent[:], child[:], zeroHashes[h][:]})
   141  		}
   142  	}
   143  
   144  	err := mt.store.SetRoot(ctx, cur[:], depositID, mt.network, dbTx)
   145  	if err != nil {
   146  		return err
   147  	}
   148  	var nodes [][]interface{}
   149  	for _, leaf := range leaves {
   150  		nodes = append(nodes, []interface{}{leaf[0], [][]byte{leaf[1], leaf[2]}, depositID})
   151  	}
   152  	if err := mt.store.BulkSet(ctx, nodes, dbTx); err != nil {
   153  		return err
   154  	}
   155  
   156  	mt.count++
   157  	return nil
   158  }
   159  
   160  func (mt *MerkleTree) resetLeaf(ctx context.Context, depositCount uint, dbTx pgx.Tx) error {
   161  	var err error
   162  	mt.count = depositCount
   163  	mt.siblings, err = mt.initSiblings(ctx, dbTx)
   164  	return err
   165  }
   166  
   167  // this function is used to get the current root of the merkle tree
   168  func (mt *MerkleTree) getRoot(ctx context.Context, dbTx pgx.Tx) ([]byte, error) {
   169  	if mt.count == 0 {
   170  		return zeroHashes[mt.height][:], nil
   171  	}
   172  	return mt.store.GetRoot(ctx, mt.count-1, mt.network, dbTx)
   173  }
   174  
   175  func buildIntermediate(leaves [][KeyLen]byte) ([][][]byte, [][32]byte) {
   176  	var (
   177  		nodes  [][][]byte
   178  		hashes [][KeyLen]byte
   179  	)
   180  	for i := 0; i < len(leaves); i += 2 {
   181  		var left, right int = i, i + 1
   182  		hash := Hash(leaves[left], leaves[right])
   183  		nodes = append(nodes, [][]byte{hash[:], leaves[left][:], leaves[right][:]})
   184  		hashes = append(hashes, hash)
   185  	}
   186  	return nodes, hashes
   187  }
   188  
   189  func (mt *MerkleTree) updateLeaf(ctx context.Context, depositID uint64, leaves [][KeyLen]byte, dbTx pgx.Tx) error {
   190  	var (
   191  		nodes [][][][]byte
   192  		ns    [][][]byte
   193  	)
   194  	initLeavesCount := uint(len(leaves))
   195  	if len(leaves) == 0 {
   196  		leaves = append(leaves, zeroHashes[0])
   197  	}
   198  
   199  	for h := uint8(0); h < mt.height; h++ {
   200  		if len(leaves)%2 == 1 {
   201  			leaves = append(leaves, zeroHashes[h])
   202  		}
   203  		ns, leaves = buildIntermediate(leaves)
   204  		nodes = append(nodes, ns)
   205  	}
   206  	if len(ns) != 1 {
   207  		return fmt.Errorf("error: more than one root detected: %+v", nodes)
   208  	}
   209  	log.Debug("Root calculated: ", common.Bytes2Hex(ns[0][0]))
   210  	err := mt.store.SetRoot(ctx, ns[0][0], depositID, mt.network, dbTx)
   211  	if err != nil {
   212  		return err
   213  	}
   214  	var nodesToStore [][]interface{}
   215  	for _, leaves := range nodes {
   216  		for _, leaf := range leaves {
   217  			nodesToStore = append(nodesToStore, []interface{}{leaf[0], [][]byte{leaf[1], leaf[2]}, depositID})
   218  		}
   219  	}
   220  	if err := mt.store.BulkSet(ctx, nodesToStore, dbTx); err != nil {
   221  		return err
   222  	}
   223  	mt.count = initLeavesCount
   224  	return nil
   225  }
   226  
   227  func (mt *MerkleTree) getLeaves(ctx context.Context, dbTx pgx.Tx) ([][KeyLen]byte, error) {
   228  	root, err := mt.getRoot(ctx, dbTx)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  	cur := [][]byte{root}
   233  	// It starts in height-1 because 0 is the level of the leafs
   234  	for h := int(mt.height - 1); h >= 0; h-- {
   235  		var levelLeaves [][]byte
   236  		for _, c := range cur {
   237  			leaves, err := mt.store.Get(ctx, c, dbTx)
   238  			if err != nil {
   239  				var isZero bool
   240  				curHash := common.BytesToHash(c)
   241  				for _, h := range zeroHashes {
   242  					if common.BytesToHash(h[:]) == curHash {
   243  						isZero = true
   244  					}
   245  				}
   246  				if !isZero {
   247  					return nil, fmt.Errorf("height: %d, cur: %v, error: %v", h, cur, err)
   248  				}
   249  			}
   250  			levelLeaves = append(levelLeaves, leaves...)
   251  		}
   252  		cur = levelLeaves
   253  	}
   254  	var result [][KeyLen]byte
   255  	for _, l := range cur {
   256  		var aux [KeyLen]byte
   257  		copy(aux[:], l)
   258  		result = append(result, aux)
   259  	}
   260  	return result, nil
   261  }
   262  
   263  func (mt *MerkleTree) buildMTRoot(leaves [][KeyLen]byte) (common.Hash, error) {
   264  	var (
   265  		nodes [][][][]byte
   266  		ns    [][][]byte
   267  	)
   268  	if len(leaves) == 0 {
   269  		leaves = append(leaves, zeroHashes[0])
   270  	}
   271  
   272  	for h := uint8(0); h < mt.height; h++ {
   273  		if len(leaves)%2 == 1 {
   274  			leaves = append(leaves, zeroHashes[h])
   275  		}
   276  		ns, leaves = buildIntermediate(leaves)
   277  		nodes = append(nodes, ns)
   278  	}
   279  	if len(ns) != 1 {
   280  		return common.Hash{}, fmt.Errorf("error: more than one root detected: %+v", nodes)
   281  	}
   282  	log.Debug("Root calculated: ", common.Bytes2Hex(ns[0][0]))
   283  
   284  	return common.BytesToHash(ns[0][0]), nil
   285  }
   286  
   287  func (mt MerkleTree) storeLeaves(ctx context.Context, leaves [][KeyLen]byte, blockID uint64, dbTx pgx.Tx) error {
   288  	root, err := mt.buildMTRoot(leaves)
   289  	if err != nil {
   290  		return err
   291  	}
   292  	// Check if root is already stored. If so, don't save the leaves because they are already stored on the db.
   293  	exist, err := mt.store.IsRollupExitRoot(ctx, root, dbTx)
   294  	if err != nil {
   295  		return err
   296  	}
   297  	if !exist {
   298  		var inserts [][]interface{}
   299  		for i := range leaves {
   300  			inserts = append(inserts, []interface{}{leaves[i][:], i + 1, root.Bytes(), blockID})
   301  		}
   302  		if err := mt.store.AddRollupExitLeaves(ctx, inserts, dbTx); err != nil {
   303  			return err
   304  		}
   305  	}
   306  	return nil
   307  }
   308  
   309  // func (mt MerkleTree) getLatestRollupExitLeaves(ctx context.Context, dbTx pgx.Tx) ([]etherman.RollupExitLeaf, error) {
   310  // 	return mt.store.GetLatestRollupExitLeaves(ctx, dbTx)
   311  // }
   312  
   313  func (mt MerkleTree) addRollupExitLeaf(ctx context.Context, rollupLeaf etherman.RollupExitLeaf, dbTx pgx.Tx) error {
   314  	storedRollupLeaves, err := mt.store.GetLatestRollupExitLeaves(ctx, dbTx)
   315  	if err != nil {
   316  		log.Error("error getting latest rollup exit leaves. Error: ", err)
   317  		return err
   318  	}
   319  	// If rollupLeaf.RollupId is lower or equal than len(storedRollupLeaves), we can add it in the proper position of the array
   320  	// if rollupLeaf.RollupId <= uint64(len(storedRollupLeaves)) {
   321  	// 	if storedRollupLeaves[rollupLeaf.RollupId-1].RollupId == rollupLeaf.RollupId {
   322  	// 		storedRollupLeaves[rollupLeaf.RollupId-1] = rollupLeaf
   323  	// 	} else {
   324  	// 		return fmt.Errorf("error: RollupId doesn't match")
   325  	// 	}
   326  	// } else {
   327  
   328  	// If rollupLeaf.RollupId is higher than len(storedRollupLeaves), We have to add empty rollups until the new rollupID
   329  	for i := len(storedRollupLeaves); i < int(rollupLeaf.RollupId); i++ {
   330  		storedRollupLeaves = append(storedRollupLeaves, etherman.RollupExitLeaf{
   331  			BlockID:  rollupLeaf.BlockID,
   332  			RollupId: uint(i + 1),
   333  		})
   334  	}
   335  	if storedRollupLeaves[rollupLeaf.RollupId-1].RollupId == rollupLeaf.RollupId {
   336  		storedRollupLeaves[rollupLeaf.RollupId-1] = rollupLeaf
   337  	} else {
   338  		return fmt.Errorf("error: RollupId doesn't match")
   339  	}
   340  	// }
   341  	var leaves [][KeyLen]byte
   342  	for _, l := range storedRollupLeaves {
   343  		var aux [KeyLen]byte
   344  		copy(aux[:], l.Leaf[:])
   345  		leaves = append(leaves, aux)
   346  	}
   347  	err = mt.storeLeaves(ctx, leaves, rollupLeaf.BlockID, dbTx)
   348  	if err != nil {
   349  		log.Error("error storing leaves. Error: ", err)
   350  		return err
   351  	}
   352  	return nil
   353  }
   354  
   355  func ComputeSiblings(rollupIndex uint, leaves [][KeyLen]byte, height uint8) ([][KeyLen]byte, common.Hash, error) {
   356  	var ns [][][]byte
   357  	if len(leaves) == 0 {
   358  		leaves = append(leaves, zeroHashes[0])
   359  	}
   360  	var siblings [][KeyLen]byte
   361  	index := rollupIndex
   362  	for h := uint8(0); h < height; h++ {
   363  		if len(leaves)%2 == 1 {
   364  			leaves = append(leaves, zeroHashes[h])
   365  		}
   366  		if index%2 == 1 { //If it is odd
   367  			siblings = append(siblings, leaves[index-1])
   368  		} else { // It is even
   369  			if len(leaves) > 1 {
   370  				siblings = append(siblings, leaves[index+1])
   371  			}
   372  		}
   373  		var (
   374  			nsi    [][][]byte
   375  			hashes [][KeyLen]byte
   376  		)
   377  		for i := 0; i < len(leaves); i += 2 {
   378  			var left, right int = i, i + 1
   379  			hash := Hash(leaves[left], leaves[right])
   380  			nsi = append(nsi, [][]byte{hash[:], leaves[left][:], leaves[right][:]})
   381  			hashes = append(hashes, hash)
   382  		}
   383  		// Find the index of the leave in the next level of the tree.
   384  		// Divide the index by 2 to find the position in the upper level
   385  		index = uint(float64(index) / 2) //nolint:gomnd
   386  		ns = nsi
   387  		leaves = hashes
   388  	}
   389  	if len(ns) != 1 {
   390  		return nil, common.Hash{}, fmt.Errorf("error: more than one root detected: %+v", ns)
   391  	}
   392  
   393  	return siblings, common.BytesToHash(ns[0][0]), nil
   394  }
   395  
   396  func calculateRoot(leafHash common.Hash, smtProof [][KeyLen]byte, index uint, height uint8) common.Hash {
   397  	var node [KeyLen]byte
   398  	copy(node[:], leafHash[:])
   399  
   400  	// Check merkle proof
   401  	var h uint8
   402  	for h = 0; h < height; h++ {
   403  		if ((index >> h) & 1) == 1 {
   404  			node = Hash(smtProof[h], node)
   405  		} else {
   406  			node = Hash(node, smtProof[h])
   407  		}
   408  	}
   409  	return common.BytesToHash(node[:])
   410  }