github.com/ethersphere/bee/v2@v2.2.0/pkg/bmt/reference/reference_test.go (about) 1 // Copyright 2021 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package reference_test 6 7 import ( 8 "bytes" 9 crand "crypto/rand" 10 "fmt" 11 "hash" 12 "io" 13 "testing" 14 15 "github.com/ethersphere/bee/v2/pkg/bmt/reference" 16 17 "golang.org/x/crypto/sha3" 18 ) 19 20 // calculates the hash of the data using hash.Hash 21 func doSum(h hash.Hash, b []byte, data ...[]byte) ([]byte, error) { 22 h.Reset() 23 for _, v := range data { 24 var err error 25 _, err = h.Write(v) 26 if err != nil { 27 return nil, err 28 } 29 } 30 return h.Sum(b), nil 31 } 32 33 // calculates the Keccak256 SHA3 hash of the data 34 func sha3hash(t *testing.T, data ...[]byte) []byte { 35 t.Helper() 36 h := sha3.NewLegacyKeccak256() 37 r, err := doSum(h, nil, data...) 38 if err != nil { 39 t.Fatal(err) 40 } 41 return r 42 } 43 44 // TestRefHasher tests that the RefHasher computes the expected BMT hash for some small data lengths. 45 func TestRefHasher(t *testing.T) { 46 t.Parallel() 47 48 // the test struct is used to specify the expected BMT hash for 49 // segment counts between from and to and lengths from 1 to datalength 50 for _, x := range []struct { 51 from int 52 to int 53 expected func([]byte) []byte 54 }{ 55 // all lengths in [0,64] should be: 56 // 57 // sha3hash(data) 58 // 59 { 60 from: 1, 61 to: 2, 62 expected: func(d []byte) []byte { 63 data := make([]byte, 64) 64 copy(data, d) 65 return sha3hash(t, data) 66 }, 67 }, 68 // all lengths in [3,4] should be: 69 // 70 // sha3hash( 71 // sha3hash(data[:64]) 72 // sha3hash(data[64:]) 73 // ) 74 // 75 { 76 from: 3, 77 to: 4, 78 expected: func(d []byte) []byte { 79 data := make([]byte, 128) 80 copy(data, d) 81 return sha3hash(t, sha3hash(t, data[:64]), sha3hash(t, data[64:])) 82 }, 83 }, 84 // all bmttestutil.SegmentCounts in [5,8] should be: 85 // 86 // sha3hash( 87 // sha3hash( 88 // sha3hash(data[:64]) 89 // sha3hash(data[64:128]) 90 // ) 91 // sha3hash( 92 // sha3hash(data[128:192]) 93 // sha3hash(data[192:]) 94 // ) 95 // ) 96 // 97 { 98 from: 5, 99 to: 8, 100 expected: func(d []byte) []byte { 101 data := make([]byte, 256) 102 copy(data, d) 103 return sha3hash(t, sha3hash(t, sha3hash(t, data[:64]), sha3hash(t, data[64:128])), sha3hash(t, sha3hash(t, data[128:192]), sha3hash(t, data[192:]))) 104 }, 105 }, 106 } { 107 for segCount := x.from; segCount <= x.to; segCount++ { 108 for length := 1; length <= segCount*32; length++ { 109 length, segCount, x := length, segCount, x 110 111 t.Run(fmt.Sprintf("%d_segments_%d_bytes", segCount, length), func(t *testing.T) { 112 t.Parallel() 113 114 data := make([]byte, length) 115 _, err := io.ReadFull(crand.Reader, data) 116 if err != nil { 117 t.Fatal(err) 118 } 119 expected := x.expected(data) 120 actual, err := reference.NewRefHasher(sha3.NewLegacyKeccak256(), segCount).Hash(data) 121 if err != nil { 122 t.Fatal(err) 123 } 124 if !bytes.Equal(actual, expected) { 125 t.Fatalf("expected %x, got %x", expected, actual) 126 } 127 }) 128 } 129 } 130 } 131 }