github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/protocol/bc/types/merkle.go (about)

     1  package types
     2  
     3  import (
     4  	"container/list"
     5  	"io"
     6  	"math"
     7  
     8  	"gopkg.in/fatih/set.v0"
     9  
    10  	"github.com/bytom/bytom/crypto/sha3pool"
    11  	"github.com/bytom/bytom/protocol/bc"
    12  )
    13  
    14  // merkleFlag represent the type of merkle tree node, it's used to generate the structure of merkle tree
    15  // Bitcoin has only two flags, which zero means the hash of assist node. And one means the hash of the related
    16  // transaction node or it's parents, which distinguish them according to the height of the tree. But in the bytom,
    17  // the height of transaction node is not fixed, so we need three flags to distinguish these nodes.
    18  const (
    19  	// FlagAssist represent assist node
    20  	FlagAssist = iota
    21  	// FlagTxParent represent the parent of transaction of node
    22  	FlagTxParent
    23  	// FlagTxLeaf represent transaction of node
    24  	FlagTxLeaf
    25  )
    26  
    27  var (
    28  	leafPrefix     = []byte{0x00}
    29  	interiorPrefix = []byte{0x01}
    30  )
    31  
    32  type merkleNode interface {
    33  	WriteTo(io.Writer) (int64, error)
    34  }
    35  
    36  func merkleRoot(nodes []merkleNode) (root bc.Hash, err error) {
    37  	switch {
    38  	case len(nodes) == 0:
    39  		return bc.EmptyStringHash, nil
    40  
    41  	case len(nodes) == 1:
    42  		root = leafMerkleHash(nodes[0])
    43  		return root, nil
    44  
    45  	default:
    46  		k := prevPowerOfTwo(len(nodes))
    47  		left, err := merkleRoot(nodes[:k])
    48  		if err != nil {
    49  			return root, err
    50  		}
    51  
    52  		right, err := merkleRoot(nodes[k:])
    53  		if err != nil {
    54  			return root, err
    55  		}
    56  
    57  		root = interiorMerkleHash(&left, &right)
    58  		return root, nil
    59  	}
    60  }
    61  
    62  func interiorMerkleHash(left merkleNode, right merkleNode) (hash bc.Hash) {
    63  	h := sha3pool.Get256()
    64  	defer sha3pool.Put256(h)
    65  	h.Write(interiorPrefix)
    66  	left.WriteTo(h)
    67  	right.WriteTo(h)
    68  	hash.ReadFrom(h)
    69  	return hash
    70  }
    71  
    72  func leafMerkleHash(node merkleNode) (hash bc.Hash) {
    73  	h := sha3pool.Get256()
    74  	defer sha3pool.Put256(h)
    75  	h.Write(leafPrefix)
    76  	node.WriteTo(h)
    77  	hash.ReadFrom(h)
    78  	return hash
    79  }
    80  
    81  type merkleTreeNode struct {
    82  	hash  bc.Hash
    83  	left  *merkleTreeNode
    84  	right *merkleTreeNode
    85  }
    86  
    87  // buildMerkleTree construct a merkle tree based on the provide node data
    88  func buildMerkleTree(rawDatas []merkleNode) *merkleTreeNode {
    89  	switch len(rawDatas) {
    90  	case 0:
    91  		return nil
    92  	case 1:
    93  		rawData := rawDatas[0]
    94  		merkleHash := leafMerkleHash(rawData)
    95  		node := newMerkleTreeNode(merkleHash, nil, nil)
    96  		return node
    97  	default:
    98  		k := prevPowerOfTwo(len(rawDatas))
    99  		left := buildMerkleTree(rawDatas[:k])
   100  		right := buildMerkleTree(rawDatas[k:])
   101  		merkleHash := interiorMerkleHash(&left.hash, &right.hash)
   102  		node := newMerkleTreeNode(merkleHash, left, right)
   103  		return node
   104  	}
   105  }
   106  
   107  func (node *merkleTreeNode) getMerkleTreeProof(merkleHashSet *set.Set) ([]*bc.Hash, []uint8) {
   108  	var hashes []*bc.Hash
   109  	var flags []uint8
   110  
   111  	if node.left == nil && node.right == nil {
   112  		if key := node.hash.String(); merkleHashSet.Has(key) {
   113  			hashes = append(hashes, &node.hash)
   114  			flags = append(flags, FlagTxLeaf)
   115  			return hashes, flags
   116  		}
   117  		return hashes, flags
   118  	}
   119  	var leftHashes, rightHashes []*bc.Hash
   120  	var leftFlags, rightFlags []uint8
   121  	if node.left != nil {
   122  		leftHashes, leftFlags = node.left.getMerkleTreeProof(merkleHashSet)
   123  	}
   124  	if node.right != nil {
   125  		rightHashes, rightFlags = node.right.getMerkleTreeProof(merkleHashSet)
   126  	}
   127  	leftFind, rightFind := len(leftHashes) > 0, len(rightHashes) > 0
   128  
   129  	if leftFind || rightFind {
   130  		flags = append(flags, FlagTxParent)
   131  	} else {
   132  		return hashes, flags
   133  	}
   134  
   135  	if leftFind {
   136  		hashes = append(hashes, leftHashes...)
   137  		flags = append(flags, leftFlags...)
   138  	} else {
   139  		hashes = append(hashes, &node.left.hash)
   140  		flags = append(flags, FlagAssist)
   141  	}
   142  
   143  	if rightFind {
   144  		hashes = append(hashes, rightHashes...)
   145  		flags = append(flags, rightFlags...)
   146  	} else {
   147  		hashes = append(hashes, &node.right.hash)
   148  		flags = append(flags, FlagAssist)
   149  	}
   150  	return hashes, flags
   151  }
   152  
   153  func getMerkleTreeProof(rawDatas []merkleNode, relatedRawDatas []merkleNode) ([]*bc.Hash, []uint8) {
   154  	merkleTree := buildMerkleTree(rawDatas)
   155  	if merkleTree == nil {
   156  		return []*bc.Hash{}, []uint8{}
   157  	}
   158  	merkleHashSet := set.New()
   159  	for _, data := range relatedRawDatas {
   160  		merkleHash := leafMerkleHash(data)
   161  		merkleHashSet.Add(merkleHash.String())
   162  	}
   163  	if merkleHashSet.Size() == 0 {
   164  		return []*bc.Hash{&merkleTree.hash}, []uint8{FlagAssist}
   165  	}
   166  	return merkleTree.getMerkleTreeProof(merkleHashSet)
   167  }
   168  
   169  func (node *merkleTreeNode) getMerkleTreeProofByFlags(flagList *list.List) []*bc.Hash {
   170  	var hashes []*bc.Hash
   171  
   172  	if flagList.Len() == 0 {
   173  		return hashes
   174  	}
   175  	flagEle := flagList.Front()
   176  	flag := flagEle.Value.(uint8)
   177  	flagList.Remove(flagEle)
   178  
   179  	if flag == FlagTxLeaf || flag == FlagAssist {
   180  		hashes = append(hashes, &node.hash)
   181  		return hashes
   182  	}
   183  	if node.left != nil {
   184  		leftHashes := node.left.getMerkleTreeProofByFlags(flagList)
   185  		hashes = append(hashes, leftHashes...)
   186  	}
   187  	if node.right != nil {
   188  		rightHashes := node.right.getMerkleTreeProofByFlags(flagList)
   189  		hashes = append(hashes, rightHashes...)
   190  	}
   191  	return hashes
   192  }
   193  
   194  func getMerkleTreeProofByFlags(rawDatas []merkleNode, flagList *list.List) []*bc.Hash {
   195  	tree := buildMerkleTree(rawDatas)
   196  	return tree.getMerkleTreeProofByFlags(flagList)
   197  }
   198  
   199  // GetTxMerkleTreeProof return a proof of merkle tree, which used to proof the transaction does
   200  // exist in the merkle tree
   201  func GetTxMerkleTreeProof(txs []*Tx, relatedTxs []*Tx) ([]*bc.Hash, []uint8) {
   202  	var rawDatas []merkleNode
   203  	var relatedRawDatas []merkleNode
   204  	for _, tx := range txs {
   205  		rawDatas = append(rawDatas, &tx.ID)
   206  	}
   207  	for _, relatedTx := range relatedTxs {
   208  		relatedRawDatas = append(relatedRawDatas, &relatedTx.ID)
   209  	}
   210  	return getMerkleTreeProof(rawDatas, relatedRawDatas)
   211  }
   212  
   213  // GetStatusMerkleTreeProof return a proof of merkle tree, which used to proof the status of transaction is valid
   214  func GetStatusMerkleTreeProof(statuses []*bc.TxVerifyResult, flags []uint8) []*bc.Hash {
   215  	var rawDatas []merkleNode
   216  	for _, status := range statuses {
   217  		rawDatas = append(rawDatas, status)
   218  	}
   219  	flagList := list.New()
   220  	for _, flag := range flags {
   221  		flagList.PushBack(flag)
   222  	}
   223  	return getMerkleTreeProofByFlags(rawDatas, flagList)
   224  }
   225  
   226  // getMerkleRootByProof caculate the merkle root hash according to the proof
   227  func getMerkleRootByProof(hashList *list.List, flagList *list.List, merkleHashes *list.List) bc.Hash {
   228  	if flagList.Len() == 0 || hashList.Len() == 0 {
   229  		return bc.EmptyStringHash
   230  	}
   231  	flagEle := flagList.Front()
   232  	flag := flagEle.Value.(uint8)
   233  	flagList.Remove(flagEle)
   234  	switch flag {
   235  	case FlagAssist:
   236  		{
   237  			hash := hashList.Front()
   238  			hashList.Remove(hash)
   239  			return hash.Value.(bc.Hash)
   240  		}
   241  	case FlagTxLeaf:
   242  		{
   243  			if merkleHashes.Len() == 0 {
   244  				return bc.EmptyStringHash
   245  			}
   246  			hashEle := hashList.Front()
   247  			hash := hashEle.Value.(bc.Hash)
   248  			relatedHashEle := merkleHashes.Front()
   249  			relatedHash := relatedHashEle.Value.(bc.Hash)
   250  			if hash == relatedHash {
   251  				hashList.Remove(hashEle)
   252  				merkleHashes.Remove(relatedHashEle)
   253  				return hash
   254  			}
   255  		}
   256  	case FlagTxParent:
   257  		{
   258  			leftHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
   259  			rightHash := getMerkleRootByProof(hashList, flagList, merkleHashes)
   260  			hash := interiorMerkleHash(&leftHash, &rightHash)
   261  			return hash
   262  		}
   263  	}
   264  	return bc.EmptyStringHash
   265  }
   266  
   267  func newMerkleTreeNode(merkleHash bc.Hash, left *merkleTreeNode, right *merkleTreeNode) *merkleTreeNode {
   268  	return &merkleTreeNode{
   269  		hash:  merkleHash,
   270  		left:  left,
   271  		right: right,
   272  	}
   273  }
   274  
   275  // ValidateMerkleTreeProof caculate the merkle root according to the hash of node and the flags
   276  // only if the merkle root by caculated equals to the specify merkle root, and the merkle tree
   277  // contains all of the related raw datas, the validate result will be true.
   278  func validateMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedNodes []merkleNode, merkleRoot bc.Hash) bool {
   279  	merkleHashes := list.New()
   280  	for _, relatedNode := range relatedNodes {
   281  		merkleHashes.PushBack(leafMerkleHash(relatedNode))
   282  	}
   283  	hashList := list.New()
   284  	for _, hash := range hashes {
   285  		hashList.PushBack(*hash)
   286  	}
   287  	flagList := list.New()
   288  	for _, flag := range flags {
   289  		flagList.PushBack(flag)
   290  	}
   291  	root := getMerkleRootByProof(hashList, flagList, merkleHashes)
   292  	return root == merkleRoot && merkleHashes.Len() == 0
   293  }
   294  
   295  // ValidateTxMerkleTreeProof validate the merkle tree of transactions
   296  func ValidateTxMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedHashes []*bc.Hash, merkleRoot bc.Hash) bool {
   297  	var relatedNodes []merkleNode
   298  	for _, hash := range relatedHashes {
   299  		relatedNodes = append(relatedNodes, hash)
   300  	}
   301  	return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
   302  }
   303  
   304  // ValidateStatusMerkleTreeProof validate the merkle tree of transaction status
   305  func ValidateStatusMerkleTreeProof(hashes []*bc.Hash, flags []uint8, relatedStatus []*bc.TxVerifyResult, merkleRoot bc.Hash) bool {
   306  	var relatedNodes []merkleNode
   307  	for _, result := range relatedStatus {
   308  		relatedNodes = append(relatedNodes, result)
   309  	}
   310  	return validateMerkleTreeProof(hashes, flags, relatedNodes, merkleRoot)
   311  }
   312  
   313  // TxStatusMerkleRoot creates a merkle tree from a slice of bc.TxVerifyResult
   314  func TxStatusMerkleRoot(tvrs []*bc.TxVerifyResult) (root bc.Hash, err error) {
   315  	nodes := []merkleNode{}
   316  	for _, tvr := range tvrs {
   317  		nodes = append(nodes, tvr)
   318  	}
   319  	return merkleRoot(nodes)
   320  }
   321  
   322  // TxMerkleRoot creates a merkle tree from a slice of transactions
   323  // and returns the root hash of the tree.
   324  func TxMerkleRoot(transactions []*bc.Tx) (root bc.Hash, err error) {
   325  	nodes := []merkleNode{}
   326  	for _, tx := range transactions {
   327  		nodes = append(nodes, &tx.ID)
   328  	}
   329  	return merkleRoot(nodes)
   330  }
   331  
   332  // prevPowerOfTwo returns the largest power of two that is smaller than a given number.
   333  // In other words, for some input n, the prevPowerOfTwo k is a power of two such that
   334  // k < n <= 2k. This is a helper function used during the calculation of a merkle tree.
   335  func prevPowerOfTwo(n int) int {
   336  	// If the number is a power of two, divide it by 2 and return.
   337  	if n&(n-1) == 0 {
   338  		return n / 2
   339  	}
   340  
   341  	// Otherwise, find the previous PoT.
   342  	exponent := uint(math.Log2(float64(n)))
   343  	return 1 << exponent // 2^exponent
   344  }