github.com/SmartMeshFoundation/Spectrum@v0.0.0-20220621030607-452a266fee1e/bmt/bmt_test.go (about) 1 // Copyright 2017 The Spectrum Authors 2 // This file is part of the Spectrum library. 3 // 4 // The Spectrum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The Spectrum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the Spectrum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package bmt 18 19 import ( 20 "bytes" 21 crand "crypto/rand" 22 "fmt" 23 "hash" 24 "io" 25 "math/rand" 26 "sync" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/SmartMeshFoundation/Spectrum/crypto/sha3" 32 ) 33 34 const ( 35 maxproccnt = 8 36 ) 37 38 // TestRefHasher tests that the RefHasher computes the expected BMT hash for 39 // all data lengths between 0 and 256 bytes 40 func TestRefHasher(t *testing.T) { 41 hashFunc := sha3.NewKeccak256 42 43 sha3 := func(data ...[]byte) []byte { 44 h := hashFunc() 45 for _, v := range data { 46 h.Write(v) 47 } 48 return h.Sum(nil) 49 } 50 51 // the test struct is used to specify the expected BMT hash for data 52 // lengths between "from" and "to" 53 type test struct { 54 from int64 55 to int64 56 expected func([]byte) []byte 57 } 58 59 var tests []*test 60 61 // all lengths in [0,64] should be: 62 // 63 // sha3(data) 64 // 65 tests = append(tests, &test{ 66 from: 0, 67 to: 64, 68 expected: func(data []byte) []byte { 69 return sha3(data) 70 }, 71 }) 72 73 // all lengths in [65,96] should be: 74 // 75 // sha3( 76 // sha3(data[:64]) 77 // data[64:] 78 // ) 79 // 80 tests = append(tests, &test{ 81 from: 65, 82 to: 96, 83 expected: func(data []byte) []byte { 84 return sha3(sha3(data[:64]), data[64:]) 85 }, 86 }) 87 88 // all lengths in [97,128] should be: 89 // 90 // sha3( 91 // sha3(data[:64]) 92 // sha3(data[64:]) 93 // ) 94 // 95 tests = append(tests, &test{ 96 from: 97, 97 to: 128, 98 expected: func(data []byte) []byte { 99 return sha3(sha3(data[:64]), sha3(data[64:])) 100 }, 101 }) 102 103 // all lengths in [129,160] should be: 104 // 105 // sha3( 106 // sha3( 107 // sha3(data[:64]) 108 // sha3(data[64:128]) 109 // ) 110 // data[128:] 111 // ) 112 // 113 tests = append(tests, &test{ 114 from: 129, 115 to: 160, 116 expected: func(data []byte) []byte { 117 return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), data[128:]) 118 }, 119 }) 120 121 // all lengths in [161,192] should be: 122 // 123 // sha3( 124 // sha3( 125 // sha3(data[:64]) 126 // sha3(data[64:128]) 127 // ) 128 // sha3(data[128:]) 129 // ) 130 // 131 tests = append(tests, &test{ 132 from: 161, 133 to: 192, 134 expected: func(data []byte) []byte { 135 return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(data[128:])) 136 }, 137 }) 138 139 // all lengths in [193,224] should be: 140 // 141 // sha3( 142 // sha3( 143 // sha3(data[:64]) 144 // sha3(data[64:128]) 145 // ) 146 // sha3( 147 // sha3(data[128:192]) 148 // data[192:] 149 // ) 150 // ) 151 // 152 tests = append(tests, &test{ 153 from: 193, 154 to: 224, 155 expected: func(data []byte) []byte { 156 return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), data[192:])) 157 }, 158 }) 159 160 // all lengths in [225,256] should be: 161 // 162 // sha3( 163 // sha3( 164 // sha3(data[:64]) 165 // sha3(data[64:128]) 166 // ) 167 // sha3( 168 // sha3(data[128:192]) 169 // sha3(data[192:]) 170 // ) 171 // ) 172 // 173 tests = append(tests, &test{ 174 from: 225, 175 to: 256, 176 expected: func(data []byte) []byte { 177 return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), sha3(data[192:]))) 178 }, 179 }) 180 181 // run the tests 182 for _, x := range tests { 183 for length := x.from; length <= x.to; length++ { 184 t.Run(fmt.Sprintf("%d_bytes", length), func(t *testing.T) { 185 data := make([]byte, length) 186 if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF { 187 t.Fatal(err) 188 } 189 expected := x.expected(data) 190 actual := NewRefHasher(hashFunc, 128).Hash(data) 191 if !bytes.Equal(actual, expected) { 192 t.Fatalf("expected %x, got %x", expected, actual) 193 } 194 }) 195 } 196 } 197 } 198 199 func testDataReader(l int) (r io.Reader) { 200 return io.LimitReader(crand.Reader, int64(l)) 201 } 202 203 func TestHasherCorrectness(t *testing.T) { 204 err := testHasher(testBaseHasher) 205 if err != nil { 206 t.Fatal(err) 207 } 208 } 209 210 func testHasher(f func(BaseHasher, []byte, int, int) error) error { 211 tdata := testDataReader(4128) 212 data := make([]byte, 4128) 213 tdata.Read(data) 214 hasher := sha3.NewKeccak256 215 size := hasher().Size() 216 counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128} 217 218 var err error 219 for _, count := range counts { 220 max := count * size 221 incr := 1 222 for n := 0; n <= max+incr; n += incr { 223 err = f(hasher, data, n, count) 224 if err != nil { 225 return err 226 } 227 } 228 } 229 return nil 230 } 231 232 func TestHasherReuseWithoutRelease(t *testing.T) { 233 testHasherReuse(1, t) 234 } 235 236 func TestHasherReuseWithRelease(t *testing.T) { 237 testHasherReuse(maxproccnt, t) 238 } 239 240 func testHasherReuse(i int, t *testing.T) { 241 hasher := sha3.NewKeccak256 242 pool := NewTreePool(hasher, 128, i) 243 defer pool.Drain(0) 244 bmt := New(pool) 245 246 for i := 0; i < 500; i++ { 247 n := rand.Intn(4096) 248 tdata := testDataReader(n) 249 data := make([]byte, n) 250 tdata.Read(data) 251 252 err := testHasherCorrectness(bmt, hasher, data, n, 128) 253 if err != nil { 254 t.Fatal(err) 255 } 256 } 257 } 258 259 func TestHasherConcurrency(t *testing.T) { 260 hasher := sha3.NewKeccak256 261 pool := NewTreePool(hasher, 128, maxproccnt) 262 defer pool.Drain(0) 263 wg := sync.WaitGroup{} 264 cycles := 100 265 wg.Add(maxproccnt * cycles) 266 errc := make(chan error) 267 268 for p := 0; p < maxproccnt; p++ { 269 for i := 0; i < cycles; i++ { 270 go func() { 271 bmt := New(pool) 272 n := rand.Intn(4096) 273 tdata := testDataReader(n) 274 data := make([]byte, n) 275 tdata.Read(data) 276 err := testHasherCorrectness(bmt, hasher, data, n, 128) 277 wg.Done() 278 if err != nil { 279 errc <- err 280 } 281 }() 282 } 283 } 284 go func() { 285 wg.Wait() 286 close(errc) 287 }() 288 var err error 289 select { 290 case <-time.NewTimer(5 * time.Second).C: 291 err = fmt.Errorf("timed out") 292 case err = <-errc: 293 } 294 if err != nil { 295 t.Fatal(err) 296 } 297 } 298 299 func testBaseHasher(hasher BaseHasher, d []byte, n, count int) error { 300 pool := NewTreePool(hasher, count, 1) 301 defer pool.Drain(0) 302 bmt := New(pool) 303 return testHasherCorrectness(bmt, hasher, d, n, count) 304 } 305 306 func testHasherCorrectness(bmt hash.Hash, hasher BaseHasher, d []byte, n, count int) (err error) { 307 data := d[:n] 308 rbmt := NewRefHasher(hasher, count) 309 exp := rbmt.Hash(data) 310 timeout := time.NewTimer(time.Second) 311 c := make(chan error) 312 313 go func() { 314 bmt.Reset() 315 bmt.Write(data) 316 got := bmt.Sum(nil) 317 if !bytes.Equal(got, exp) { 318 c <- fmt.Errorf("wrong hash: expected %x, got %x", exp, got) 319 } 320 close(c) 321 }() 322 select { 323 case <-timeout.C: 324 err = fmt.Errorf("BMT hash calculation timed out") 325 case err = <-c: 326 } 327 return err 328 } 329 330 func BenchmarkSHA3_4k(t *testing.B) { benchmarkSHA3(4096, t) } 331 func BenchmarkSHA3_2k(t *testing.B) { benchmarkSHA3(4096/2, t) } 332 func BenchmarkSHA3_1k(t *testing.B) { benchmarkSHA3(4096/4, t) } 333 func BenchmarkSHA3_512b(t *testing.B) { benchmarkSHA3(4096/8, t) } 334 func BenchmarkSHA3_256b(t *testing.B) { benchmarkSHA3(4096/16, t) } 335 func BenchmarkSHA3_128b(t *testing.B) { benchmarkSHA3(4096/32, t) } 336 337 func BenchmarkBMTBaseline_4k(t *testing.B) { benchmarkBMTBaseline(4096, t) } 338 func BenchmarkBMTBaseline_2k(t *testing.B) { benchmarkBMTBaseline(4096/2, t) } 339 func BenchmarkBMTBaseline_1k(t *testing.B) { benchmarkBMTBaseline(4096/4, t) } 340 func BenchmarkBMTBaseline_512b(t *testing.B) { benchmarkBMTBaseline(4096/8, t) } 341 func BenchmarkBMTBaseline_256b(t *testing.B) { benchmarkBMTBaseline(4096/16, t) } 342 func BenchmarkBMTBaseline_128b(t *testing.B) { benchmarkBMTBaseline(4096/32, t) } 343 344 func BenchmarkRefHasher_4k(t *testing.B) { benchmarkRefHasher(4096, t) } 345 func BenchmarkRefHasher_2k(t *testing.B) { benchmarkRefHasher(4096/2, t) } 346 func BenchmarkRefHasher_1k(t *testing.B) { benchmarkRefHasher(4096/4, t) } 347 func BenchmarkRefHasher_512b(t *testing.B) { benchmarkRefHasher(4096/8, t) } 348 func BenchmarkRefHasher_256b(t *testing.B) { benchmarkRefHasher(4096/16, t) } 349 func BenchmarkRefHasher_128b(t *testing.B) { benchmarkRefHasher(4096/32, t) } 350 351 func BenchmarkHasher_4k(t *testing.B) { benchmarkHasher(4096, t) } 352 func BenchmarkHasher_2k(t *testing.B) { benchmarkHasher(4096/2, t) } 353 func BenchmarkHasher_1k(t *testing.B) { benchmarkHasher(4096/4, t) } 354 func BenchmarkHasher_512b(t *testing.B) { benchmarkHasher(4096/8, t) } 355 func BenchmarkHasher_256b(t *testing.B) { benchmarkHasher(4096/16, t) } 356 func BenchmarkHasher_128b(t *testing.B) { benchmarkHasher(4096/32, t) } 357 358 func BenchmarkHasherNoReuse_4k(t *testing.B) { benchmarkHasherReuse(1, 4096, t) } 359 func BenchmarkHasherNoReuse_2k(t *testing.B) { benchmarkHasherReuse(1, 4096/2, t) } 360 func BenchmarkHasherNoReuse_1k(t *testing.B) { benchmarkHasherReuse(1, 4096/4, t) } 361 func BenchmarkHasherNoReuse_512b(t *testing.B) { benchmarkHasherReuse(1, 4096/8, t) } 362 func BenchmarkHasherNoReuse_256b(t *testing.B) { benchmarkHasherReuse(1, 4096/16, t) } 363 func BenchmarkHasherNoReuse_128b(t *testing.B) { benchmarkHasherReuse(1, 4096/32, t) } 364 365 func BenchmarkHasherReuse_4k(t *testing.B) { benchmarkHasherReuse(16, 4096, t) } 366 func BenchmarkHasherReuse_2k(t *testing.B) { benchmarkHasherReuse(16, 4096/2, t) } 367 func BenchmarkHasherReuse_1k(t *testing.B) { benchmarkHasherReuse(16, 4096/4, t) } 368 func BenchmarkHasherReuse_512b(t *testing.B) { benchmarkHasherReuse(16, 4096/8, t) } 369 func BenchmarkHasherReuse_256b(t *testing.B) { benchmarkHasherReuse(16, 4096/16, t) } 370 func BenchmarkHasherReuse_128b(t *testing.B) { benchmarkHasherReuse(16, 4096/32, t) } 371 372 // benchmarks the minimum hashing time for a balanced (for simplicity) BMT 373 // by doing count/segmentsize parallel hashings of 2*segmentsize bytes 374 // doing it on n maxproccnt each reusing the base hasher 375 // the premise is that this is the minimum computation needed for a BMT 376 // therefore this serves as a theoretical optimum for concurrent implementations 377 func benchmarkBMTBaseline(n int, t *testing.B) { 378 tdata := testDataReader(64) 379 data := make([]byte, 64) 380 tdata.Read(data) 381 hasher := sha3.NewKeccak256 382 383 t.ReportAllocs() 384 t.ResetTimer() 385 for i := 0; i < t.N; i++ { 386 count := int32((n-1)/hasher().Size() + 1) 387 wg := sync.WaitGroup{} 388 wg.Add(maxproccnt) 389 var i int32 390 for j := 0; j < maxproccnt; j++ { 391 go func() { 392 defer wg.Done() 393 h := hasher() 394 for atomic.AddInt32(&i, 1) < count { 395 h.Reset() 396 h.Write(data) 397 h.Sum(nil) 398 } 399 }() 400 } 401 wg.Wait() 402 } 403 } 404 405 func benchmarkHasher(n int, t *testing.B) { 406 tdata := testDataReader(n) 407 data := make([]byte, n) 408 tdata.Read(data) 409 410 size := 1 411 hasher := sha3.NewKeccak256 412 segmentCount := 128 413 pool := NewTreePool(hasher, segmentCount, size) 414 bmt := New(pool) 415 416 t.ReportAllocs() 417 t.ResetTimer() 418 for i := 0; i < t.N; i++ { 419 bmt.Reset() 420 bmt.Write(data) 421 bmt.Sum(nil) 422 } 423 } 424 425 func benchmarkHasherReuse(poolsize, n int, t *testing.B) { 426 tdata := testDataReader(n) 427 data := make([]byte, n) 428 tdata.Read(data) 429 430 hasher := sha3.NewKeccak256 431 segmentCount := 128 432 pool := NewTreePool(hasher, segmentCount, poolsize) 433 cycles := 200 434 435 t.ReportAllocs() 436 t.ResetTimer() 437 for i := 0; i < t.N; i++ { 438 wg := sync.WaitGroup{} 439 wg.Add(cycles) 440 for j := 0; j < cycles; j++ { 441 bmt := New(pool) 442 go func() { 443 defer wg.Done() 444 bmt.Reset() 445 bmt.Write(data) 446 bmt.Sum(nil) 447 }() 448 } 449 wg.Wait() 450 } 451 } 452 453 func benchmarkSHA3(n int, t *testing.B) { 454 data := make([]byte, n) 455 tdata := testDataReader(n) 456 tdata.Read(data) 457 hasher := sha3.NewKeccak256 458 h := hasher() 459 460 t.ReportAllocs() 461 t.ResetTimer() 462 for i := 0; i < t.N; i++ { 463 h.Reset() 464 h.Write(data) 465 h.Sum(nil) 466 } 467 } 468 469 func benchmarkRefHasher(n int, t *testing.B) { 470 data := make([]byte, n) 471 tdata := testDataReader(n) 472 tdata.Read(data) 473 hasher := sha3.NewKeccak256 474 rbmt := NewRefHasher(hasher, 128) 475 476 t.ReportAllocs() 477 t.ResetTimer() 478 for i := 0; i < t.N; i++ { 479 rbmt.Hash(data) 480 } 481 }