github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/merkletree2/proof_verifier.go (about)

     1  package merkletree2
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/keybase/client/go/libkb"
     7  	"github.com/keybase/client/go/logger"
     8  )
     9  
    10  type MerkleProofVerifier struct {
    11  	cfg Config
    12  }
    13  
    14  func NewMerkleProofVerifier(c Config) MerkleProofVerifier {
    15  	return MerkleProofVerifier{cfg: c}
    16  }
    17  
    18  func (m *MerkleProofVerifier) VerifyInclusionProof(ctx logger.ContextInterface, kvp KeyValuePair, proof *MerkleInclusionProof, expRootHash Hash) (err error) {
    19  	if kvp.Value == nil {
    20  		return NewProofVerificationFailedError(fmt.Errorf("Keys cannot have nil values in the tree"))
    21  	}
    22  	return m.verifyInclusionOrExclusionProof(ctx, kvp, proof, expRootHash)
    23  }
    24  
    25  // VerifyExclusionProof uses a MerkleInclusionProof to assert that a specific key is not part of the tree
    26  func (m *MerkleProofVerifier) VerifyExclusionProof(ctx logger.ContextInterface, k Key, proof *MerkleInclusionProof, expRootHash Hash) (err error) {
    27  	return m.verifyInclusionOrExclusionProof(ctx, KeyValuePair{Key: k}, proof, expRootHash)
    28  }
    29  
    30  // if kvp.Value == nil, this functions checks that kvp.Key is not included in the tree. Otherwise, it checks that kvp is included in the tree.
    31  func (m *MerkleProofVerifier) verifyInclusionOrExclusionProof(ctx logger.ContextInterface, kvp KeyValuePair, proof *MerkleInclusionProof, expRootHash Hash) (err error) {
    32  	if proof == nil {
    33  		return NewProofVerificationFailedError(fmt.Errorf("nil proof"))
    34  	}
    35  
    36  	var kvpHash Hash
    37  	// Hash the key value pair if necessary
    38  	if kvp.Value != nil {
    39  		kvpHash, err = m.cfg.Encoder.HashKeyValuePairWithKeySpecificSecret(kvp, proof.KeySpecificSecret)
    40  		if err != nil {
    41  			return NewProofVerificationFailedError(err)
    42  		}
    43  	}
    44  
    45  	if proof.RootMetadataNoHash.RootVersion != RootVersionV1 {
    46  		return NewProofVerificationFailedError(libkb.NewAppOutdatedError(fmt.Errorf("RootVersion %v is not supported (this client can only handle V1)", proof.RootMetadataNoHash.RootVersion)))
    47  	}
    48  
    49  	if len(kvp.Key) != m.cfg.KeysByteLength {
    50  		return NewProofVerificationFailedError(fmt.Errorf("Key has wrong length for this tree: %v (expected %v)", len(kvp.Key), m.cfg.KeysByteLength))
    51  	}
    52  
    53  	// inclusion proofs for existing values can have at most MaxValuesPerLeaf - 1
    54  	// other pairs in the leaf, while exclusion proofs can have at most
    55  	// MaxValuesPerLeaf.
    56  	if (kvp.Value != nil && len(proof.OtherPairsInLeaf)+1 > m.cfg.MaxValuesPerLeaf) || (kvp.Value == nil && len(proof.OtherPairsInLeaf) > m.cfg.MaxValuesPerLeaf) {
    57  		return NewProofVerificationFailedError(fmt.Errorf("Too many keys in leaf: %v > %v", len(proof.OtherPairsInLeaf)+1, m.cfg.MaxValuesPerLeaf))
    58  	}
    59  
    60  	// Reconstruct the leaf node if necessary
    61  	var nodeHash Hash
    62  	if kvp.Value != nil || proof.OtherPairsInLeaf != nil {
    63  		valueToInsert := false
    64  		leafHashesLength := len(proof.OtherPairsInLeaf)
    65  		if kvp.Value != nil {
    66  			leafHashesLength++
    67  			valueToInsert = true
    68  		}
    69  		leaf := Node{LeafHashes: make([]KeyHashPair, leafHashesLength)}
    70  
    71  		// LeafHashes is obtained by adding kvp into OtherPairsInLeaf while maintaining sorted order
    72  		for i, j := 0, 0; i < leafHashesLength; i++ {
    73  			if (j < len(proof.OtherPairsInLeaf) && valueToInsert && proof.OtherPairsInLeaf[j].Key.Cmp(kvp.Key) > 0) || j >= len(proof.OtherPairsInLeaf) {
    74  				leaf.LeafHashes[i] = KeyHashPair{Key: kvp.Key, Hash: kvpHash}
    75  				valueToInsert = false
    76  			} else {
    77  				leaf.LeafHashes[i] = proof.OtherPairsInLeaf[j]
    78  				j++
    79  			}
    80  
    81  			// Ensure all the KeyHashPairs in the leaf node are different
    82  			if i > 0 && leaf.LeafHashes[i-1].Key.Cmp(leaf.LeafHashes[i].Key) >= 0 {
    83  				return NewProofVerificationFailedError(fmt.Errorf("Error in Leaf Key ordering or duplicated key: %v >= %v", leaf.LeafHashes[i-1].Key, leaf.LeafHashes[i].Key))
    84  			}
    85  		}
    86  
    87  		// Recompute the hashes on the nodes on the path from the leaf to the root.
    88  		_, nodeHash, err = m.cfg.Encoder.EncodeAndHashGeneric(leaf)
    89  		if err != nil {
    90  			return NewProofVerificationFailedError(err)
    91  		}
    92  	}
    93  
    94  	sibH := proof.SiblingHashesOnPath
    95  	if len(sibH)%(m.cfg.ChildrenPerNode-1) != 0 {
    96  		return NewProofVerificationFailedError(fmt.Errorf("Invalid number of SiblingHashes %v", len(sibH)))
    97  	}
    98  	keyAsPos, err := m.cfg.getDeepestPositionForKey(kvp.Key)
    99  	if err != nil {
   100  		return NewProofVerificationFailedError(err)
   101  	}
   102  	leafPosition := m.cfg.getParentAtLevel(keyAsPos, uint(len(sibH)/(m.cfg.ChildrenPerNode-1)))
   103  
   104  	// recompute the hash of the root node by recreating all the internal nodes
   105  	// on the path from the leaf to the root.
   106  	i := 0
   107  	for _, childIndex := range m.cfg.positionToChildIndexPath(leafPosition) {
   108  		sibHAtLevel := sibH[i : i+m.cfg.ChildrenPerNode-1]
   109  
   110  		node := Node{INodes: make([]Hash, m.cfg.ChildrenPerNode)}
   111  		copy(node.INodes, sibHAtLevel[:int(childIndex)])
   112  		node.INodes[int(childIndex)] = nodeHash
   113  		copy(node.INodes[int(childIndex)+1:], sibHAtLevel[int(childIndex):])
   114  
   115  		i += m.cfg.ChildrenPerNode - 1
   116  		_, nodeHash, err = m.cfg.Encoder.EncodeAndHashGeneric(node)
   117  		if err != nil {
   118  			return NewProofVerificationFailedError(err)
   119  		}
   120  	}
   121  
   122  	// Compute the hash of the RootMetadata by filling in the BareRootHash
   123  	// with the value computed above.
   124  	rootMetadata := proof.RootMetadataNoHash
   125  	rootMetadata.BareRootHash = nodeHash
   126  	_, rootHash, err := m.cfg.Encoder.EncodeAndHashGeneric(rootMetadata)
   127  	if err != nil {
   128  		return NewProofVerificationFailedError(err)
   129  	}
   130  
   131  	// Check the rootHash computed matches the expected value.
   132  	if !rootHash.Equal(expRootHash) {
   133  		return NewProofVerificationFailedError(fmt.Errorf("expected rootHash does not match the computed one (for key: %X, value: %v): expected %X but got %X", kvp.Key, kvp.Value, expRootHash, rootHash))
   134  	}
   135  
   136  	// Success!
   137  	return nil
   138  }
   139  
   140  func (m *MerkleProofVerifier) computeSkipsHashForSeqno(s Seqno, skipsMap map[Seqno]Hash) (Hash, error) {
   141  	skipSeqnos := SkipPointersForSeqno(s)
   142  	skips := make([]Hash, len(skipSeqnos))
   143  	for i, s := range skipSeqnos {
   144  		skip, found := skipsMap[s]
   145  		if !found {
   146  			return nil, fmt.Errorf("the skipsMap in the proof does not contain necessary hash of seqno %v ", s)
   147  		}
   148  		skips[i] = skip
   149  	}
   150  	_, hash, err := m.cfg.Encoder.EncodeAndHashGeneric(skips)
   151  	if err != nil {
   152  		return nil, fmt.Errorf("Error encoding %+v: %v", skips, err)
   153  	}
   154  
   155  	return hash, nil
   156  }
   157  
   158  // computeFinalSkipPointersHashFromPath recomputes the SkipPointersHash for all
   159  // the seqnos on the SkipPointersPath(initialSeqno, finalSeqno), using the
   160  // information contained in the proof and returns the last one, as well as a
   161  // boolean indicating wether the proof is a full proof or is compressed (i.e. it
   162  // is part of a MerkleInclusionExtensionProof).
   163  func (m *MerkleProofVerifier) computeFinalSkipPointersHashFromPath(ctx logger.ContextInterface, proof *MerkleExtensionProof, initialSeqno Seqno, initialRootHash Hash, finalSeqno Seqno) (h Hash, isPartOfIncExtProof bool, err error) {
   164  	// This function is annotated with an inline example.
   165  
   166  	// We denote with Hi the hash of the RootMetadata Ri with seqno i.
   167  	// Example inputs:
   168  	// initialSeqno = 11
   169  	// initialRootHash = H11
   170  	// finalSeqno = 30
   171  	// the proof contains:
   172  	// - RootHashes = [ H8, H10, H14, H15, H24, H28, H29]
   173  	// - PreviousRootsNoSkips = [R12, R16, (R30)] Note that R30 here is optional.
   174  	//      We set isPartOfIncExtProof = true if it isn't there, but the output h is the same.
   175  
   176  	// Note that:
   177  	// SkipPointersPath(11,30) = [12,16,30]
   178  	// SkipPointersForSeqno(12) = [8,10,11]
   179  	// SkipPointersForSeqno(16) = [8,12,14,15]
   180  	// SkipPointersForSeqno(30) = [16,24,28,29]
   181  
   182  	rootHashMap := make(map[Seqno]Hash)
   183  	rootHashSeqnos, err := ComputeRootHashSeqnosNeededInExtensionProof(initialSeqno, finalSeqno)
   184  	// rootHashSeqnos = [8, 10, 14, 15, 24, 28, 29]
   185  	if err != nil {
   186  		return nil, false, NewProofVerificationFailedError(err)
   187  	}
   188  	if len(rootHashSeqnos) != len(proof.RootHashes) {
   189  		return nil, false, NewProofVerificationFailedError(fmt.Errorf("The proof does not have the expected number of root hashes: exp %v, got %v", len(rootHashSeqnos), len(proof.RootHashes)))
   190  	}
   191  	for i, s := range rootHashSeqnos {
   192  		rootHashMap[s] = proof.RootHashes[i]
   193  	}
   194  	// rootHashMap : { 8 -> H8, 10 -> H10, 14 -> H14, ... }
   195  
   196  	rootMap := make(map[Seqno]RootMetadata)
   197  	rootSeqnos, err := ComputeRootMetadataSeqnosNeededInExtensionProof(initialSeqno, finalSeqno, true)
   198  	// rootSeqnos = [12, 16]
   199  	if err != nil {
   200  		return nil, false, NewProofVerificationFailedError(err)
   201  	}
   202  	if len(rootSeqnos) == len(proof.PreviousRootsNoSkips) {
   203  		// compressed proof
   204  		isPartOfIncExtProof = true
   205  		// if len(proof.PreviousRootsNoSkips) == 2 (R30 is not there), we set isPartOfIncExtProof == true
   206  	} else if len(rootSeqnos) == len(proof.PreviousRootsNoSkips)-1 {
   207  		// full proof
   208  		isPartOfIncExtProof = false
   209  		// if len(proof.PreviousRootsNoSkips) == 3 (R30 is there), we set isPartOfIncExtProof == false
   210  	} else {
   211  		// invalid proof
   212  		return nil, false, NewProofVerificationFailedError(fmt.Errorf("The proof does not have the expected number of roots: exp %v or %v, got %v", len(rootSeqnos), len(rootSeqnos)+1, len(proof.PreviousRootsNoSkips)))
   213  	}
   214  
   215  	for i, s := range rootSeqnos {
   216  		if proof.PreviousRootsNoSkips[i].RootVersion != RootVersionV1 {
   217  			return nil, false, NewProofVerificationFailedError(libkb.NewAppOutdatedError(fmt.Errorf("computeFinalSkipPointersHashFromPath: RootVersion %v at seqno %v is not supported (this client can only handle V1)", proof.PreviousRootsNoSkips[i].RootVersion, s)))
   218  		}
   219  
   220  		rootMap[s] = proof.PreviousRootsNoSkips[i]
   221  	}
   222  	// rootMap : { 12 -> R12, 16 -> R16}
   223  
   224  	prevRootHash := initialRootHash
   225  	// prevRootHash = H11
   226  	prevSeqno := initialSeqno
   227  	// prevSeqno = 11
   228  
   229  	var currentSkipsHash Hash
   230  
   231  	skipPath, err := SkipPointersPath(initialSeqno, finalSeqno)
   232  	// SkipPointersPath(11,30) = [12,16,30]
   233  	if err != nil {
   234  		return nil, false, NewProofVerificationFailedError(err)
   235  	}
   236  	for i, currentSeqno := range skipPath {
   237  		// We annotate this loop for i = 0, currentSeqno = 12
   238  
   239  		rootHashMap[prevSeqno] = prevRootHash
   240  		// rootHashMap : { 11 -> H11, 8 -> H8, 10 -> H10, 14 -> H14, ... }
   241  
   242  		currentSkipsHash, err = m.computeSkipsHashForSeqno(currentSeqno, rootHashMap)
   243  		// currentSkipsHash = SHA( [ H8, H10, H11 ] )
   244  		// It is the expected value of SkipPointersHash for the root at currentSeqno = 12
   245  		if err != nil {
   246  			return nil, false, NewProofVerificationFailedError(err)
   247  		}
   248  
   249  		// the rest of the loop prepares for the next loop iteration, so we can
   250  		// skip it the last time.
   251  		if i == len(skipPath)-1 {
   252  			break
   253  		}
   254  
   255  		currentMeta := rootMap[currentSeqno]
   256  		// currentMeta = R12 (note that R12.SkipPointersHash = nil)
   257  		currentMeta.SkipPointersHash = currentSkipsHash
   258  		// set R12.SkipPointersHash to the value computed above
   259  		_, currRootHash, err := m.cfg.Encoder.EncodeAndHashGeneric(currentMeta)
   260  		// compute H12 (the expected hash of the root at seqno 12)
   261  		if err != nil {
   262  			return nil, false, NewProofVerificationFailedError(err)
   263  		}
   264  
   265  		prevSeqno = currentSeqno
   266  		// prevSeqno = 12
   267  		prevRootHash = currRootHash
   268  		// prevRootHash = H12
   269  	}
   270  	// At the end of the loop, currentSkipsHash contains the expected
   271  	// SkipPointersHash for the root at Seqno 30.
   272  
   273  	return currentSkipsHash, isPartOfIncExtProof, nil
   274  }
   275  
   276  // verifyExtensionProofFinal inserts skipsHash as the SkipsPointersHash in
   277  // rootMetadata, hashes it and checks that such hash matches expRootHash.
   278  func (m *MerkleProofVerifier) verifyExtensionProofFinal(ctx logger.ContextInterface, rootMetadata RootMetadata, skipsHash Hash, expRootHash Hash) error {
   279  	rootMetadata.SkipPointersHash = skipsHash
   280  	if rootMetadata.RootVersion != RootVersionV1 {
   281  		return NewProofVerificationFailedError(libkb.NewAppOutdatedError(fmt.Errorf("verifyExtensionProofFinal: RootVersion %v is not supported (this client can only handle V1)", rootMetadata.RootVersion)))
   282  	}
   283  
   284  	_, hash, err := m.cfg.Encoder.EncodeAndHashGeneric(rootMetadata)
   285  	if err != nil {
   286  		return NewProofVerificationFailedError(err)
   287  	}
   288  	if !hash.Equal(expRootHash) {
   289  		return NewProofVerificationFailedError(fmt.Errorf("verifyExtensionProofFinal: hash mismatch %X != %X", expRootHash, hash))
   290  	}
   291  	return nil
   292  }
   293  
   294  func (m *MerkleProofVerifier) VerifyExtensionProof(ctx logger.ContextInterface, proof *MerkleExtensionProof, initialSeqno Seqno, initialRootHash Hash, finalSeqno Seqno, expRootHash Hash) error {
   295  	// Optimization: if initialSeqno == finalSeqno it is enough to compare
   296  	// hashes, so if the proof is empty we can just do that.
   297  	if initialSeqno == finalSeqno && (proof == nil || (len(proof.PreviousRootsNoSkips) == 0 && len(proof.RootHashes) == 0)) {
   298  		if initialRootHash.Equal(expRootHash) {
   299  			return nil
   300  		}
   301  		return NewProofVerificationFailedError(fmt.Errorf("Hash mismatch: initialSeqno == finalSeqno == %v but %X != %X", initialSeqno, initialRootHash, expRootHash))
   302  	}
   303  
   304  	if proof == nil {
   305  		return NewProofVerificationFailedError(fmt.Errorf("nil proof"))
   306  	}
   307  
   308  	skipsHash, isPartOfIncExtProof, err := m.computeFinalSkipPointersHashFromPath(ctx, proof, initialSeqno, initialRootHash, finalSeqno)
   309  	// For exmaple, if finalSeqno = 30, then skipsHash will be the expected
   310  	// SkipPointersHash of the root at seqno 30 (i.e. SHA512([H16, H24, H28,
   311  	// H29]) where Hi is the hash of the RootMetadata at seqno i)
   312  
   313  	if err != nil {
   314  		return err
   315  	}
   316  	if isPartOfIncExtProof {
   317  		return NewProofVerificationFailedError(fmt.Errorf("The proof does not have the expected number of roots: it appears to be a compressed proof, but is not part of a MerkleInclusionExtensionProof"))
   318  	}
   319  
   320  	return m.verifyExtensionProofFinal(ctx, proof.PreviousRootsNoSkips[len(proof.PreviousRootsNoSkips)-1], skipsHash, expRootHash)
   321  	// For exmaple, if finalSeqno = 30, this function uses the skipsHash above,
   322  	// puts it inside RootMetadata at seqno 30, hashes it and checks the hash
   323  	// matches expRootHash.
   324  }
   325  
   326  func (m *MerkleProofVerifier) VerifyInclusionExtensionProof(ctx logger.ContextInterface, kvp KeyValuePair, proof *MerkleInclusionExtensionProof, initialSeqno Seqno, initialRootHash Hash, finalSeqno Seqno, expRootHash Hash) (err error) {
   327  	if proof == nil {
   328  		return NewProofVerificationFailedError(fmt.Errorf("nil proof"))
   329  	}
   330  
   331  	// Shallow copy so that we can modify some fields without altering the original proof.
   332  	incProof := proof.MerkleInclusionProof
   333  
   334  	switch incProof.RootMetadataNoHash.Seqno {
   335  	case finalSeqno:
   336  		// pass
   337  	case 0:
   338  		// the seqno can be omitted for efficiency.
   339  		incProof.RootMetadataNoHash.Seqno = finalSeqno
   340  	default:
   341  		return NewProofVerificationFailedError(fmt.Errorf("inclusion proof contains wrong Seqno: exp %v got %v", finalSeqno, incProof.RootMetadataNoHash.Seqno))
   342  	}
   343  
   344  	// If initialSeqno == finalSeqno, no extension proof is necessary so if it is not there we skip checking it.
   345  	if initialSeqno != finalSeqno || len(proof.MerkleExtensionProof.PreviousRootsNoSkips) > 0 || len(proof.MerkleExtensionProof.PreviousRootsNoSkips) > 0 {
   346  		skipsHashForNewRoot, isPartOfIncExtProof, err := m.computeFinalSkipPointersHashFromPath(ctx, &proof.MerkleExtensionProof, initialSeqno, initialRootHash, incProof.RootMetadataNoHash.Seqno)
   347  		if err != nil {
   348  			return err
   349  		}
   350  
   351  		if !isPartOfIncExtProof {
   352  			// The server did not compress the inner extension proof, so we check it.
   353  			err := m.verifyExtensionProofFinal(ctx, proof.MerkleExtensionProof.PreviousRootsNoSkips[len(proof.MerkleExtensionProof.PreviousRootsNoSkips)-1], skipsHashForNewRoot, expRootHash)
   354  			if err != nil {
   355  				return err
   356  			}
   357  		}
   358  
   359  		// the server can also send a proof with an empty
   360  		// proof.RootMetadataNoHash.SkipPointersHash. If this value is not sent, and
   361  		// there was some other tampering, the skipsHashForNewRoot will cause the
   362  		// inclusion proof to fail.
   363  		if incProof.RootMetadataNoHash.SkipPointersHash != nil && !incProof.RootMetadataNoHash.SkipPointersHash.Equal(skipsHashForNewRoot) {
   364  			return NewProofVerificationFailedError(
   365  				fmt.Errorf("extension proof failed: expected %X but got %X", skipsHashForNewRoot, incProof.RootMetadataNoHash.SkipPointersHash))
   366  		}
   367  		incProof.RootMetadataNoHash.SkipPointersHash = skipsHashForNewRoot
   368  	}
   369  
   370  	return m.VerifyInclusionProof(ctx, kvp, &incProof, expRootHash)
   371  }