github.com/devwanda/aphelion-staking@v0.33.9/crypto/merkle/simple_tree_test.go (about)

     1  package merkle
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/require"
     7  
     8  	tmrand "github.com/devwanda/aphelion-staking/libs/rand"
     9  	. "github.com/devwanda/aphelion-staking/libs/test"
    10  
    11  	"github.com/devwanda/aphelion-staking/crypto/tmhash"
    12  )
    13  
    14  type testItem []byte
    15  
    16  func (tI testItem) Hash() []byte {
    17  	return []byte(tI)
    18  }
    19  
    20  func TestSimpleProof(t *testing.T) {
    21  
    22  	total := 100
    23  
    24  	items := make([][]byte, total)
    25  	for i := 0; i < total; i++ {
    26  		items[i] = testItem(tmrand.Bytes(tmhash.Size))
    27  	}
    28  
    29  	rootHash := SimpleHashFromByteSlices(items)
    30  
    31  	rootHash2, proofs := SimpleProofsFromByteSlices(items)
    32  
    33  	require.Equal(t, rootHash, rootHash2, "Unmatched root hashes: %X vs %X", rootHash, rootHash2)
    34  
    35  	// For each item, check the trail.
    36  	for i, item := range items {
    37  		proof := proofs[i]
    38  
    39  		// Check total/index
    40  		require.Equal(t, proof.Index, i, "Unmatched indicies: %d vs %d", proof.Index, i)
    41  
    42  		require.Equal(t, proof.Total, total, "Unmatched totals: %d vs %d", proof.Total, total)
    43  
    44  		// Verify success
    45  		err := proof.Verify(rootHash, item)
    46  		require.NoError(t, err, "Verification failed: %v.", err)
    47  
    48  		// Trail too long should make it fail
    49  		origAunts := proof.Aunts
    50  		proof.Aunts = append(proof.Aunts, tmrand.Bytes(32))
    51  		err = proof.Verify(rootHash, item)
    52  		require.Error(t, err, "Expected verification to fail for wrong trail length")
    53  
    54  		proof.Aunts = origAunts
    55  
    56  		// Trail too short should make it fail
    57  		proof.Aunts = proof.Aunts[0 : len(proof.Aunts)-1]
    58  		err = proof.Verify(rootHash, item)
    59  		require.Error(t, err, "Expected verification to fail for wrong trail length")
    60  
    61  		proof.Aunts = origAunts
    62  
    63  		// Mutating the itemHash should make it fail.
    64  		err = proof.Verify(rootHash, MutateByteSlice(item))
    65  		require.Error(t, err, "Expected verification to fail for mutated leaf hash")
    66  
    67  		// Mutating the rootHash should make it fail.
    68  		err = proof.Verify(MutateByteSlice(rootHash), item)
    69  		require.Error(t, err, "Expected verification to fail for mutated root hash")
    70  	}
    71  }
    72  
    73  func TestSimpleHashAlternatives(t *testing.T) {
    74  
    75  	total := 100
    76  
    77  	items := make([][]byte, total)
    78  	for i := 0; i < total; i++ {
    79  		items[i] = testItem(tmrand.Bytes(tmhash.Size))
    80  	}
    81  
    82  	rootHash1 := SimpleHashFromByteSlicesIterative(items)
    83  	rootHash2 := SimpleHashFromByteSlices(items)
    84  	require.Equal(t, rootHash1, rootHash2, "Unmatched root hashes: %X vs %X", rootHash1, rootHash2)
    85  }
    86  
    87  func BenchmarkSimpleHashAlternatives(b *testing.B) {
    88  	total := 100
    89  
    90  	items := make([][]byte, total)
    91  	for i := 0; i < total; i++ {
    92  		items[i] = testItem(tmrand.Bytes(tmhash.Size))
    93  	}
    94  
    95  	b.ResetTimer()
    96  	b.Run("recursive", func(b *testing.B) {
    97  		for i := 0; i < b.N; i++ {
    98  			_ = SimpleHashFromByteSlices(items)
    99  		}
   100  	})
   101  
   102  	b.Run("iterative", func(b *testing.B) {
   103  		for i := 0; i < b.N; i++ {
   104  			_ = SimpleHashFromByteSlicesIterative(items)
   105  		}
   106  	})
   107  }
   108  
   109  func Test_getSplitPoint(t *testing.T) {
   110  	tests := []struct {
   111  		length int
   112  		want   int
   113  	}{
   114  		{1, 0},
   115  		{2, 1},
   116  		{3, 2},
   117  		{4, 2},
   118  		{5, 4},
   119  		{10, 8},
   120  		{20, 16},
   121  		{100, 64},
   122  		{255, 128},
   123  		{256, 128},
   124  		{257, 256},
   125  	}
   126  	for _, tt := range tests {
   127  		got := getSplitPoint(tt.length)
   128  		require.Equal(t, tt.want, got, "getSplitPoint(%d) = %v, want %v", tt.length, got, tt.want)
   129  	}
   130  }