github.com/ethersphere/bee/v2@v2.2.0/pkg/bmt/bmt_test.go (about) 1 // Copyright 2021 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bmt_test 6 7 import ( 8 "bytes" 9 "context" 10 "fmt" 11 "math/rand" 12 "sort" 13 "testing" 14 "time" 15 16 "github.com/ethersphere/bee/v2/pkg/bmt" 17 "github.com/ethersphere/bee/v2/pkg/bmt/reference" 18 "github.com/ethersphere/bee/v2/pkg/swarm" 19 "github.com/ethersphere/bee/v2/pkg/util/testutil" 20 "golang.org/x/sync/errgroup" 21 ) 22 23 const ( 24 // testPoolSize is the number of bmt trees the pool keeps when 25 testPoolSize = 16 26 // segmentCount is the maximum number of segments of the underlying chunk 27 // Should be equal to max-chunk-data-size / hash-size 28 // Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size) 29 testSegmentCount = 128 30 ) 31 32 var ( 33 testSegmentCounts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128} 34 hashSize = swarm.NewHasher().Size() 35 seed = time.Now().Unix() 36 ) 37 38 func refHash(count int, data []byte) ([]byte, error) { 39 rbmt := reference.NewRefHasher(swarm.NewHasher(), count) 40 refNoMetaHash, err := rbmt.Hash(data) 41 if err != nil { 42 return nil, err 43 } 44 return bmt.Sha3hash(bmt.LengthToSpan(int64(len(data))), refNoMetaHash) 45 } 46 47 // syncHash hashes the data and the span using the bmt hasher 48 func syncHash(h *bmt.Hasher, data []byte) ([]byte, error) { 49 h.Reset() 50 h.SetHeaderInt64(int64(len(data))) 51 _, err := h.Write(data) 52 if err != nil { 53 return nil, err 54 } 55 return h.Hash(nil) 56 } 57 58 // tests if hasher responds with correct hash comparing the reference implementation return value 59 func TestHasherEmptyData(t *testing.T) { 60 t.Parallel() 61 62 for _, count := range testSegmentCounts { 63 count := count 64 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 65 t.Parallel() 66 67 expHash, err := refHash(count, nil) 68 if err != nil { 69 t.Fatal(err) 70 } 71 pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, 1)) 72 h := pool.Get() 73 resHash, err := syncHash(h, nil) 74 if err != nil { 75 t.Fatal(err) 76 } 77 pool.Put(h) 78 if !bytes.Equal(expHash, resHash) { 79 t.Fatalf("hash mismatch with reference. expected %x, got %x", expHash, resHash) 80 } 81 }) 82 } 83 } 84 85 // tests sequential write with entire max size written in one go 86 func TestSyncHasherCorrectness(t *testing.T) { 87 t.Parallel() 88 testData := testutil.RandBytesWithSeed(t, 4096, seed) 89 90 for _, count := range testSegmentCounts { 91 count := count 92 t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) { 93 t.Parallel() 94 max := count * hashSize 95 var incr int 96 capacity := 1 97 pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, capacity)) 98 for n := 0; n <= max; n += incr { 99 h := pool.Get() 100 incr = 1 + rand.Intn(5) 101 err := testHasherCorrectness(h, testData, n, count) 102 if err != nil { 103 t.Fatalf("seed %d: %v", seed, err) 104 } 105 pool.Put(h) 106 } 107 }) 108 } 109 } 110 111 // tests that the BMT hasher can be synchronously reused with poolsizes 1 and testPoolSize 112 func TestHasherReuse(t *testing.T) { 113 t.Parallel() 114 115 t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) { 116 t.Parallel() 117 testHasherReuse(t, 1) 118 }) 119 120 t.Run(fmt.Sprintf("poolsize_%d", testPoolSize), func(t *testing.T) { 121 t.Parallel() 122 testHasherReuse(t, testPoolSize) 123 }) 124 } 125 126 // tests if bmt reuse is not corrupting result 127 func testHasherReuse(t *testing.T, poolsize int) { 128 t.Helper() 129 130 pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, poolsize)) 131 h := pool.Get() 132 defer pool.Put(h) 133 134 for i := 0; i < 100; i++ { 135 seed := int64(i) 136 testData := testutil.RandBytesWithSeed(t, 4096, seed) 137 n := rand.Intn(h.Capacity()) 138 err := testHasherCorrectness(h, testData, n, testSegmentCount) 139 if err != nil { 140 t.Fatalf("seed %d: %v", seed, err) 141 } 142 } 143 } 144 145 // tests if pool can be cleanly reused even in concurrent use by several hashers 146 func TestBMTConcurrentUse(t *testing.T) { 147 t.Parallel() 148 149 testData := testutil.RandBytesWithSeed(t, 4096, seed) 150 pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize)) 151 cycles := 100 152 153 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 154 defer cancel() 155 eg, ectx := errgroup.WithContext(ctx) 156 for i := 0; i < cycles; i++ { 157 eg.Go(func() error { 158 select { 159 case <-ectx.Done(): 160 return ectx.Err() 161 default: 162 } 163 h := pool.Get() 164 defer pool.Put(h) 165 166 n := rand.Intn(h.Capacity()) 167 return testHasherCorrectness(h, testData, n, testSegmentCount) 168 }) 169 } 170 if err := eg.Wait(); err != nil { 171 t.Fatalf("seed %d: %v", seed, err) 172 } 173 } 174 175 // tests BMT Hasher io.Writer interface is working correctly even with random short writes 176 func TestBMTWriterBuffers(t *testing.T) { 177 t.Parallel() 178 179 for i, count := range testSegmentCounts { 180 i, count := i, count 181 182 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 183 t.Parallel() 184 185 pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, testPoolSize)) 186 h := pool.Get() 187 defer pool.Put(h) 188 189 size := h.Capacity() 190 seed := int64(i) 191 testData := testutil.RandBytesWithSeed(t, 4096, seed) 192 193 resHash, err := syncHash(h, testData[:size]) 194 if err != nil { 195 t.Fatal(err) 196 } 197 expHash, err := refHash(count, testData[:size]) 198 if err != nil { 199 t.Fatal(err) 200 } 201 if !bytes.Equal(resHash, expHash) { 202 t.Fatalf("single write :hash mismatch with reference. expected %x, got %x", expHash, resHash) 203 } 204 attempts := 10 205 f := func() error { 206 h := pool.Get() 207 defer pool.Put(h) 208 209 reads := rand.Intn(count*2-1) + 1 210 offsets := make([]int, reads+1) 211 for i := 0; i < reads; i++ { 212 offsets[i] = rand.Intn(size) + 1 213 } 214 offsets[reads] = size 215 from := 0 216 sort.Ints(offsets) 217 for _, to := range offsets { 218 if from < to { 219 read, err := h.Write(testData[from:to]) 220 if err != nil { 221 return err 222 } 223 if read != to-from { 224 return fmt.Errorf("incorrect read. expected %v bytes, got %v", to-from, read) 225 } 226 from = to 227 } 228 } 229 h.SetHeaderInt64(int64(size)) 230 resHash, err := h.Hash(nil) 231 if err != nil { 232 return err 233 } 234 if !bytes.Equal(resHash, expHash) { 235 return fmt.Errorf("hash mismatch on %v. expected %x, got %x", offsets, expHash, resHash) 236 } 237 return nil 238 } 239 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 240 defer cancel() 241 eg, ectx := errgroup.WithContext(ctx) 242 for i := 0; i < attempts; i++ { 243 eg.Go(func() error { 244 select { 245 case <-ectx.Done(): 246 return ectx.Err() 247 default: 248 } 249 return f() 250 }) 251 } 252 if err := eg.Wait(); err != nil { 253 t.Fatalf("seed %d: %v", seed, err) 254 } 255 }) 256 } 257 } 258 259 // helper function that compares reference and optimised implementations for correctness 260 func testHasherCorrectness(h *bmt.Hasher, data []byte, n, count int) (err error) { 261 if len(data) < n { 262 n = len(data) 263 } 264 exp, err := refHash(count, data[:n]) 265 if err != nil { 266 return err 267 } 268 got, err := syncHash(h, data[:n]) 269 if err != nil { 270 return err 271 } 272 if !bytes.Equal(got, exp) { 273 return fmt.Errorf("wrong hash: expected %x, got %x", exp, got) 274 } 275 return nil 276 } 277 278 // verifies that the bmt.Hasher can be used with the hash.Hash interface 279 func TestUseSyncAsOrdinaryHasher(t *testing.T) { 280 t.Parallel() 281 282 pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize)) 283 h := pool.Get() 284 defer pool.Put(h) 285 data := []byte("moodbytesmoodbytesmoodbytesmoodbytes") 286 expHash, err := refHash(128, data) 287 if err != nil { 288 t.Fatal(err) 289 } 290 h.SetHeaderInt64(int64(len(data))) 291 _, err = h.Write(data) 292 if err != nil { 293 t.Fatal(err) 294 } 295 resHash := h.Sum(nil) 296 if !bytes.Equal(expHash, resHash) { 297 t.Fatalf("normalhash; expected %x, got %x", expHash, resHash) 298 } 299 }