git.gammaspectra.live/P2Pool/consensus/v3@v3.8.0/monero/crypto/merkle.go (about)

     1  package crypto
     2  
     3  import (
     4  	"git.gammaspectra.live/P2Pool/consensus/v3/types"
     5  	"git.gammaspectra.live/P2Pool/consensus/v3/utils"
     6  	"git.gammaspectra.live/P2Pool/sha3"
     7  )
     8  
// BinaryTreeHash is the ordered list of leaf hashes of a Merkle tree; its
// methods RootHash and MainBranch compute the root and the branch of the
// first leaf.
type BinaryTreeHash []types.Hash
    10  
    11  func leafHash(data []types.Hash, hasher *sha3.HasherState) (rootHash types.Hash) {
    12  	switch len(data) {
    13  	case 0:
    14  		panic("unsupported length")
    15  	case 1:
    16  		return data[0]
    17  	default:
    18  		//only hash the next two items
    19  		hasher.Reset()
    20  		_, _ = hasher.Write(data[0][:])
    21  		_, _ = hasher.Write(data[1][:])
    22  		HashFastSum(hasher, rootHash[:])
    23  		return rootHash
    24  	}
    25  }
    26  
    27  // RootHash Calculates the Merkle root hash of the tree
    28  func (t BinaryTreeHash) RootHash() (rootHash types.Hash) {
    29  	hasher := GetKeccak256Hasher()
    30  	defer PutKeccak256Hasher(hasher)
    31  
    32  	count := len(t)
    33  	if count <= 2 {
    34  		return leafHash(t, hasher)
    35  	}
    36  
    37  	pow2cnt := utils.PreviousPowerOfTwo(uint64(count))
    38  	offset := pow2cnt*2 - count
    39  
    40  	temporaryTree := make(BinaryTreeHash, pow2cnt)
    41  	copy(temporaryTree, t[:offset])
    42  
    43  	//TODO: maybe can be done zero-alloc
    44  	//temporaryTree := t[:max(pow2cnt, offset)]
    45  
    46  	offsetTree := temporaryTree[offset:]
    47  	for i := range offsetTree {
    48  		offsetTree[i] = leafHash(t[offset+i*2:], hasher)
    49  	}
    50  
    51  	for pow2cnt >>= 1; pow2cnt > 1; pow2cnt >>= 1 {
    52  		for i := range temporaryTree[:pow2cnt] {
    53  			temporaryTree[i] = leafHash(temporaryTree[i*2:], hasher)
    54  		}
    55  	}
    56  
    57  	rootHash = leafHash(temporaryTree, hasher)
    58  
    59  	return
    60  }
    61  
    62  func (t BinaryTreeHash) MainBranch() (mainBranch []types.Hash) {
    63  	count := len(t)
    64  	if count <= 2 {
    65  		return nil
    66  	}
    67  
    68  	hasher := GetKeccak256Hasher()
    69  	defer PutKeccak256Hasher(hasher)
    70  
    71  	pow2cnt := utils.PreviousPowerOfTwo(uint64(count))
    72  	offset := pow2cnt*2 - count
    73  
    74  	temporaryTree := make(BinaryTreeHash, pow2cnt)
    75  	copy(temporaryTree, t[:offset])
    76  
    77  	offsetTree := temporaryTree[offset:]
    78  
    79  	for i := range offsetTree {
    80  		if (offset + i*2) == 0 {
    81  			mainBranch = append(mainBranch, t[1])
    82  		}
    83  		offsetTree[i] = leafHash(t[offset+i*2:], hasher)
    84  	}
    85  
    86  	for pow2cnt >>= 1; pow2cnt > 1; pow2cnt >>= 1 {
    87  		for i := range temporaryTree[:pow2cnt] {
    88  			if i == 0 {
    89  				mainBranch = append(mainBranch, temporaryTree[1])
    90  			}
    91  
    92  			temporaryTree[i] = leafHash(temporaryTree[i*2:], hasher)
    93  		}
    94  	}
    95  
    96  	mainBranch = append(mainBranch, temporaryTree[1])
    97  
    98  	return
    99  }
   100  
// MerkleProof is a list of sibling hashes, consumed by GetRoot starting from
// the level nearest the leaf (proof[0]) and moving towards the root.
type MerkleProof []types.Hash
   102  
   103  func (proof MerkleProof) Verify(h types.Hash, index, count int, rootHash types.Hash) bool {
   104  	return proof.GetRoot(h, index, count) == rootHash
   105  }
   106  
   107  func pairHash(index int, h, p types.Hash, hasher *sha3.HasherState) (out types.Hash) {
   108  	hasher.Reset()
   109  
   110  	if index&1 > 0 {
   111  		_, _ = hasher.Write(p[:])
   112  		_, _ = hasher.Write(h[:])
   113  	} else {
   114  		_, _ = hasher.Write(h[:])
   115  		_, _ = hasher.Write(p[:])
   116  	}
   117  
   118  	HashFastSum(hasher, out[:])
   119  	return out
   120  }
   121  
   122  func (proof MerkleProof) GetRoot(h types.Hash, index, count int) types.Hash {
   123  	if count == 1 {
   124  		return h
   125  	}
   126  
   127  	if index >= count {
   128  		return types.ZeroHash
   129  	}
   130  
   131  	hasher := GetKeccak256Hasher()
   132  	defer PutKeccak256Hasher(hasher)
   133  
   134  	if count == 2 {
   135  		if len(proof) == 0 {
   136  			return types.ZeroHash
   137  		}
   138  
   139  		h = pairHash(index, h, proof[0], hasher)
   140  	} else {
   141  		pow2cnt := utils.PreviousPowerOfTwo(uint64(count))
   142  		k := pow2cnt*2 - count
   143  
   144  		var proofIndex int
   145  
   146  		if index >= k {
   147  			index -= k
   148  
   149  			if len(proof) == 0 {
   150  				return types.ZeroHash
   151  			}
   152  
   153  			h = pairHash(index, h, proof[0], hasher)
   154  
   155  			index = (index >> 1) + k
   156  			proofIndex = 1
   157  
   158  		}
   159  
   160  		for ; pow2cnt >= 2; proofIndex, index, pow2cnt = proofIndex+1, index>>1, pow2cnt>>1 {
   161  			if proofIndex >= len(proof) {
   162  				return types.ZeroHash
   163  			}
   164  
   165  			h = pairHash(index, h, proof[proofIndex], hasher)
   166  		}
   167  	}
   168  
   169  	return h
   170  }