github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/crypto/merkle/simple_tree_test.go (about)

     1  package merkle
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/require"
     7  
     8  	"github.com/gnolang/gno/tm2/pkg/crypto/tmhash"
     9  	"github.com/gnolang/gno/tm2/pkg/random"
    10  	"github.com/gnolang/gno/tm2/pkg/testutils"
    11  )
    12  
    13  type testItem []byte
    14  
    15  func (tI testItem) Hash() []byte {
    16  	return []byte(tI)
    17  }
    18  
    19  func TestSimpleProof(t *testing.T) {
    20  	t.Parallel()
    21  
    22  	total := 100
    23  
    24  	items := make([][]byte, total)
    25  	for i := 0; i < total; i++ {
    26  		items[i] = testItem(random.RandBytes(tmhash.Size))
    27  	}
    28  
    29  	rootHash := SimpleHashFromByteSlices(items)
    30  
    31  	rootHash2, proofs := SimpleProofsFromByteSlices(items)
    32  
    33  	require.Equal(t, rootHash, rootHash2, "Unmatched root hashes: %X vs %X", rootHash, rootHash2)
    34  
    35  	// For each item, check the trail.
    36  	for i, item := range items {
    37  		proof := proofs[i]
    38  
    39  		// Check total/index
    40  		require.Equal(t, proof.Index, i, "Unmatched indices: %d vs %d", proof.Index, i)
    41  
    42  		require.Equal(t, proof.Total, total, "Unmatched totals: %d vs %d", proof.Total, total)
    43  
    44  		// Verify success
    45  		err := proof.Verify(rootHash, item)
    46  		require.NoError(t, err, "Verification failed: %v.", err)
    47  
    48  		// Trail too long should make it fail
    49  		origAunts := proof.Aunts
    50  		proof.Aunts = append(proof.Aunts, random.RandBytes(32))
    51  		err = proof.Verify(rootHash, item)
    52  		require.Error(t, err, "Expected verification to fail for wrong trail length")
    53  
    54  		proof.Aunts = origAunts
    55  
    56  		// Trail too short should make it fail
    57  		proof.Aunts = proof.Aunts[0 : len(proof.Aunts)-1]
    58  		err = proof.Verify(rootHash, item)
    59  		require.Error(t, err, "Expected verification to fail for wrong trail length")
    60  
    61  		proof.Aunts = origAunts
    62  
    63  		// Mutating the itemHash should make it fail.
    64  		err = proof.Verify(rootHash, testutils.MutateByteSlice(item))
    65  		require.Error(t, err, "Expected verification to fail for mutated leaf hash")
    66  
    67  		// Mutating the rootHash should make it fail.
    68  		err = proof.Verify(testutils.MutateByteSlice(rootHash), item)
    69  		require.Error(t, err, "Expected verification to fail for mutated root hash")
    70  	}
    71  }
    72  
    73  func TestSimpleHashAlternatives(t *testing.T) {
    74  	t.Parallel()
    75  
    76  	total := 100
    77  
    78  	items := make([][]byte, total)
    79  	for i := 0; i < total; i++ {
    80  		items[i] = testItem(random.RandBytes(tmhash.Size))
    81  	}
    82  
    83  	rootHash1 := SimpleHashFromByteSlicesIterative(items)
    84  	rootHash2 := SimpleHashFromByteSlices(items)
    85  	require.Equal(t, rootHash1, rootHash2, "Unmatched root hashes: %X vs %X", rootHash1, rootHash2)
    86  }
    87  
    88  func BenchmarkSimpleHashAlternatives(b *testing.B) {
    89  	total := 100
    90  
    91  	items := make([][]byte, total)
    92  	for i := 0; i < total; i++ {
    93  		items[i] = testItem(random.RandBytes(tmhash.Size))
    94  	}
    95  
    96  	b.ResetTimer()
    97  	b.Run("recursive", func(b *testing.B) {
    98  		for i := 0; i < b.N; i++ {
    99  			_ = SimpleHashFromByteSlices(items)
   100  		}
   101  	})
   102  
   103  	b.Run("iterative", func(b *testing.B) {
   104  		for i := 0; i < b.N; i++ {
   105  			_ = SimpleHashFromByteSlicesIterative(items)
   106  		}
   107  	})
   108  }
   109  
   110  func Test_getSplitPoint(t *testing.T) {
   111  	t.Parallel()
   112  
   113  	tests := []struct {
   114  		length int
   115  		want   int
   116  	}{
   117  		{1, 0},
   118  		{2, 1},
   119  		{3, 2},
   120  		{4, 2},
   121  		{5, 4},
   122  		{10, 8},
   123  		{20, 16},
   124  		{100, 64},
   125  		{255, 128},
   126  		{256, 128},
   127  		{257, 256},
   128  	}
   129  	for _, tt := range tests {
   130  		got := getSplitPoint(tt.length)
   131  		require.Equal(t, tt.want, got, "getSplitPoint(%d) = %v, want %v", tt.length, got, tt.want)
   132  	}
   133  }