github.com/zhiqiangxu/util@v0.0.0-20230112053021-0a7aee056cd5/mmr/mmr.go (about)

     1  package mmr
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  )
     7  
     8  // HashStore persists the hashes produced by an MMR as one sequential list
        // and returns them by position.
     9  type HashStore interface {
        	// Append adds hashes, in order, to the end of the stored sequence.
    10  	Append(hashes []HashType) error
        	// Flush is called by PushHash right after Append; presumably it makes
        	// the appended hashes durable — TODO confirm against implementations.
    11  	Flush() error
        	// Close is not called by the code visible here; presumably it releases
        	// the store's resources — TODO confirm.
    12  	Close()
        	// GetHash returns the hash stored at offset. Callers in this file pass
        	// a node position minus one, which suggests offsets are zero-based.
    13  	GetHash(offset uint64) (HashType, error)
    14  }
    15  
    16  // MMR is a Merkle Mountain Range: it tracks the peak hashes of an
        // append-only sequence of leaves and derives the root hash from them.
    17  type MMR struct {
        	// size is the number of leaves; PushHash increments it by one.
    18  	size          uint64
        	// peaks holds the current peak hashes; PushHash merges from the end
        	// of the slice and appends the newest peak last.
    19  	peaks         []HashType
        	// rootHash caches the result of Root; unknownHash means it has to be
        	// recomputed.
    20  	rootHash      HashType
        	// minPeakHeight is set from minPeakHeight(size) in update and to the
        	// number of merges performed by the latest PushHash.
    21  	minPeakHeight int
        	// hasher hashes leaves and combines child hashes.
    22  	hasher        Hasher
        	// store is optional; when nil nothing is persisted and the proof
        	// builders return ErrHashStoreNotAvailable.
    23  	store         HashStore
    24  }
    25  
        // unknownHash is the zero HashType; MMR stores it in rootHash to mark the
        // cached root as not computed yet.
    26  var unknownHash HashType
    27  
    28  // NewMMR is ctor for MMR
    29  func NewMMR(size uint64, peaks []HashType, hasher Hasher, store HashStore) *MMR {
    30  	if hasher == nil {
    31  		hasher = defaultHasher
    32  	}
    33  	m := &MMR{hasher: hasher, store: store}
    34  	m.update(size, peaks)
    35  	return m
    36  }
    37  
    38  func (m *MMR) update(size uint64, peaks []HashType) {
    39  	if len(peaks) != peakCount(size) {
    40  		panic("number of peaks != peakCount")
    41  	}
    42  	m.size = size
    43  	m.peaks = peaks
    44  	m.minPeakHeight = minPeakHeight(size)
    45  	m.rootHash = unknownHash
    46  }
    47  
    48  // Size of mmr
    49  func (m *MMR) Size() uint64 {
    50  	return m.size
    51  }
    52  
    53  // Push a leaf
    54  func (m *MMR) Push(leaf []byte, wantAP bool) []HashType {
    55  	h := m.hasher.Leaf(leaf)
    56  	return m.PushHash(h, wantAP)
    57  }
    58  
    59  // Push a hash
    60  func (m *MMR) PushHash(h HashType, wantAP bool) (ap []HashType) {
    61  	psize := len(m.peaks)
    62  
    63  	if wantAP {
    64  		ap = make([]HashType, psize, psize)
    65  		// reverse
    66  		for i, v := range m.peaks {
    67  			ap[psize-i-1] = v
    68  		}
    69  	}
    70  
    71  	newHashes := []HashType{h}
    72  	m.minPeakHeight = 0
    73  	for s := m.size; s%2 == 1; s = s >> 1 {
    74  		m.minPeakHeight++
    75  		h = m.hasher.Node(m.peaks[psize-1], h)
    76  		newHashes = append(newHashes, h)
    77  		psize--
    78  	}
    79  
    80  	if m.store != nil {
    81  		m.store.Append(newHashes)
    82  		m.store.Flush()
    83  	}
    84  
    85  	m.size++
    86  	m.peaks = m.peaks[0:psize]
    87  	m.peaks = append(m.peaks, h)
    88  	m.rootHash = unknownHash
    89  
    90  	return
    91  }
    92  
    93  // Root returns root hash
    94  func (m *MMR) Root() HashType {
    95  
    96  	if m.rootHash == unknownHash {
    97  		if len(m.peaks) > 0 {
    98  			m.rootHash = bagPeaks(m.hasher, m.peaks)
    99  		} else {
   100  			m.rootHash = m.hasher.Empty()
   101  		}
   102  	}
   103  	return m.rootHash
   104  }
   105  
   106  var (
   107  	// ErrRootNotAvailableYet is returned by InclusionProof and
        	// VerifyInclusion when the requested size exceeds the current size of
        	// the MMR.
   108  	ErrRootNotAvailableYet = errors.New("not available yet")
   109  	// ErrHashStoreNotAvailable is returned by InclusionProof and
        	// ConsistencyProof when the MMR has no HashStore.
   110  	ErrHashStoreNotAvailable = errors.New("hash store not available")
   111  )
   112  
   113  // InclusionProof returns the audit path of ti wrt size
   114  func (m *MMR) InclusionProof(leafIdx, size uint64) (hashes []HashType, err error) {
   115  	if leafIdx >= size {
   116  		err = fmt.Errorf("wrong parameters")
   117  		return
   118  	} else if m.size < size {
   119  		err = ErrRootNotAvailableYet
   120  		return
   121  	} else if m.store == nil {
   122  		err = ErrHashStoreNotAvailable
   123  		return
   124  	}
   125  
   126  	var (
   127  		offset       uint64
   128  		leftPeakHash HashType
   129  	)
   130  	// need no proof if size is 1
   131  	//
   132  	// for size > 1, we want to:
   133  	// 1. locate the target moutain M leafIdx is in
   134  	// 2. bag the right peaks of moutain M
   135  	// 3. collect the preceding leaks of moutain M
   136  	// 4. collect the proof of leafIdx within moutain M
   137  	for size > 1 {
   138  		// if size is not 2^n, left peak of size/size-1 is the same
   139  		// if size is 2^n, left peak of size-1 decomposes to the left sub peak
   140  		//
   141  		// this trick unifies the process of finding proofs within one moutain and amoung mountains.
   142  		//
   143  		// it's based on the invariant that the graph can always be decomposed into a sub left mountain Msub and right side
   144  		//
   145  		// as long as there're no fewer than 2 leaves, whether it's completely balanced or not.
   146  		//
   147  		// if leafIdx is within Msub, we find the proof for Msub and bag it with the right side
   148  		//
   149  		// if leafIdx is out of Msub, we find the proof for the right side and bag it with the peak of Msub
   150  		lpLeaf := leftPeakLeaf(size - 1) // -1 for a proper one
   151  		if leafIdx < lpLeaf {
   152  			rightPeaks := getMoutainPeaks(size - lpLeaf)
   153  			rightHashes := make([]HashType, len(rightPeaks), len(rightPeaks))
   154  			for i := range rightPeaks {
   155  				rightPeaks[i] += offset + 2*lpLeaf - 1
   156  				rightHashes[i], err = m.store.GetHash(rightPeaks[i] - 1)
   157  				if err != nil {
   158  					return
   159  				}
   160  			}
   161  			baggedRightHash := bagPeaks(m.hasher, rightHashes)
   162  			hashes = append(hashes, baggedRightHash)
   163  			size = lpLeaf
   164  		} else {
   165  			offset += 2*lpLeaf - 1
   166  			leftPeakHash, err = m.store.GetHash(offset - 1)
   167  			if err != nil {
   168  				return
   169  			}
   170  			hashes = append(hashes, leftPeakHash)
   171  			leafIdx -= lpLeaf
   172  			size -= lpLeaf
   173  		}
   174  	}
   175  
   176  	// reverse
   177  	// https://github.com/golang/go/wiki/SliceTricks#reversing
   178  	length := len(hashes)
   179  	for i := length/2 - 1; i >= 0; i-- {
   180  		opp := length - 1 - i
   181  		hashes[i], hashes[opp] = hashes[opp], hashes[i]
   182  	}
   183  
   184  	return
   185  }
   186  
   187  func (m *MMR) VerifyInclusion(leafHash, rootHash HashType, leafIdx, size uint64, proof []HashType) (err error) {
   188  	if m.size < size {
   189  		err = ErrRootNotAvailableYet
   190  		return
   191  	}
   192  
   193  	calculatedHash := leafHash
   194  	lastNode := size - 1
   195  	idx := 0
   196  	proofLen := len(proof)
   197  
   198  	for lastNode > 0 {
   199  		if idx >= proofLen {
   200  			err = fmt.Errorf("Proof too short. expected %d, got %d", proofLength(leafIdx, size), proofLen)
   201  			return
   202  		}
   203  
   204  		if leafIdx%2 == 1 {
   205  			calculatedHash = m.hasher.Node(proof[idx], calculatedHash)
   206  			idx++
   207  		} else if leafIdx < lastNode {
   208  			calculatedHash = m.hasher.Node(calculatedHash, proof[idx])
   209  			idx++
   210  		}
   211  
   212  		leafIdx /= 2
   213  		lastNode /= 2
   214  	}
   215  
   216  	if idx < proofLen {
   217  		err = fmt.Errorf("Proof too long")
   218  		return
   219  	}
   220  
   221  	if rootHash != calculatedHash {
   222  		err = fmt.Errorf(
   223  			"Constructed root hash differs from provided root hash. Constructed: %x, Expected: %x",
   224  			calculatedHash, rootHash)
   225  		return
   226  	}
   227  	return
   228  }
   229  
   230  // FYI: https://tools.ietf.org/id/draft-ietf-trans-rfc6962-bis-27.html#rfc.section.2.1.4
   231  func (m *MMR) ConsistencyProof(l, n uint64) (hashes []HashType, err error) {
   232  	if m.store == nil {
   233  		err = ErrHashStoreNotAvailable
   234  		return
   235  	}
   236  
   237  	hashes, err = m.subproof(l, n, true)
   238  	return
   239  }
   240  
   241  func (m *MMR) subproof(l, n uint64, compeleteST bool) (hashes []HashType, err error) {
   242  
   243  	var hash HashType
   244  	offset := uint64(0)
   245  	for l < n {
   246  		k := leftPeakLeaf(n - 1)
   247  		if l <= k {
   248  			rightPeaks := getMoutainPeaks(n - k)
   249  			rightHashes := make([]HashType, len(rightPeaks), len(rightPeaks))
   250  			for i := range rightPeaks {
   251  				rightPeaks[i] = offset + 2*k - 1
   252  				rightHashes[i], err = m.store.GetHash(rightPeaks[i] - 1)
   253  				if err != nil {
   254  					return
   255  				}
   256  			}
   257  			baggedRightHash := bagPeaks(m.hasher, rightHashes)
   258  			hashes = append(hashes, baggedRightHash)
   259  			n = k
   260  		} else {
   261  			offset += k*2 - 1
   262  			hash, err = m.store.GetHash(offset - 1)
   263  			if err != nil {
   264  				return
   265  			}
   266  			hashes = append(hashes, hash)
   267  			l -= k
   268  			n -= k
   269  			compeleteST = false
   270  		}
   271  	}
   272  
   273  	if !compeleteST {
   274  		peaks := getMoutainPeaks(l)
   275  		if len(peaks) != 1 {
   276  			panic("bug in subproof")
   277  		}
   278  		hash, err = m.store.GetHash(peaks[0] + offset - 1)
   279  		if err != nil {
   280  			return
   281  		}
   282  		hashes = append(hashes, hash)
   283  	}
   284  
   285  	// reverse
   286  	// https://github.com/golang/go/wiki/SliceTricks#reversing
   287  	length := len(hashes)
   288  	for i := length/2 - 1; i >= 0; i-- {
   289  		opp := length - 1 - i
   290  		hashes[i], hashes[opp] = hashes[opp], hashes[i]
   291  	}
   292  	return
   293  }
   294  
   295  func (m *MMR) VerifyConsistency(oldTreeSize, newTreeSize uint64, oldRoot, newRoot HashType, proof []HashType) (err error) {
   296  	if oldTreeSize > newTreeSize {
   297  		err = fmt.Errorf("oldTreeSize > newTreeSize")
   298  		return
   299  	}
   300  
   301  	if oldTreeSize == newTreeSize {
   302  		return
   303  	}
   304  
   305  	if oldTreeSize == 0 {
   306  		return
   307  	}
   308  
   309  	first := oldTreeSize - 1
   310  	last := newTreeSize - 1
   311  
   312  	for first%2 == 1 {
   313  		first /= 2
   314  		last /= 2
   315  	}
   316  
   317  	lenp := len(proof)
   318  	if lenp == 0 {
   319  		err = errors.New("Wrong proof length")
   320  		return
   321  	}
   322  
   323  	pos := 0
   324  	var newHash, oldHash HashType
   325  
   326  	if first != 0 {
   327  		newHash = proof[pos]
   328  		oldHash = proof[pos]
   329  		pos += 1
   330  	} else {
   331  		newHash = oldRoot
   332  		oldHash = oldRoot
   333  	}
   334  
   335  	for first != 0 {
   336  		if first%2 == 1 {
   337  			if pos >= lenp {
   338  				err = errors.New("Wrong proof length")
   339  				return
   340  			}
   341  			// node is a right child: left sibling exists in both trees
   342  			nextNode := proof[pos]
   343  			pos += 1
   344  			oldHash = m.hasher.Node(nextNode, oldHash)
   345  			newHash = m.hasher.Node(nextNode, newHash)
   346  		} else if first < last {
   347  			if pos >= lenp {
   348  				err = errors.New("Wrong proof length")
   349  				return
   350  			}
   351  			// node is a left child: right sibling only exists in the newer tree
   352  			nextNode := proof[pos]
   353  			pos += 1
   354  			newHash = m.hasher.Node(nextNode, newHash)
   355  		}
   356  
   357  		first /= 2
   358  		last /= 2
   359  	}
   360  
   361  	for last != 0 {
   362  		if pos >= lenp {
   363  			err = errors.New("Wrong proof length")
   364  			return
   365  		}
   366  		nextNode := proof[pos]
   367  		pos += 1
   368  		newHash = m.hasher.Node(nextNode, newHash)
   369  		last /= 2
   370  	}
   371  
   372  	if newHash != newRoot {
   373  		err = errors.New(fmt.Sprintf(`Bad Merkle proof: second root hash does not match. 
   374  			Expected hash:%x, computed hash: %x`, newRoot, newHash))
   375  		return
   376  	} else if oldHash != oldRoot {
   377  		err = errors.New(fmt.Sprintf(`Inconsistency: first root hash does not match."
   378  			"Expected hash: %x, computed hash:%x`, oldRoot, oldHash))
   379  		return
   380  	}
   381  
   382  	if pos != lenp {
   383  		err = errors.New("Proof too long")
   384  		return
   385  	}
   386  
   387  	return
   388  }