github.com/oskarth/go-ethereum@v1.6.8-0.20191013093314-dac24a9d3494/swarm/bmt/bmt_test.go (about) 1 // Copyright 2017 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package bmt 18 19 import ( 20 "bytes" 21 crand "crypto/rand" 22 "encoding/binary" 23 "fmt" 24 "io" 25 "math/rand" 26 "sync" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/ethereum/go-ethereum/crypto/sha3" 32 ) 33 34 // the actual data length generated (could be longer than max datalength of the BMT) 35 const BufferSize = 4128 36 37 const ( 38 // segmentCount is the maximum number of segments of the underlying chunk 39 // Should be equal to max-chunk-data-size / hash-size 40 // Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size) 41 segmentCount = 128 42 ) 43 44 var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128} 45 46 // calculates the Keccak256 SHA3 hash of the data 47 func sha3hash(data ...[]byte) []byte { 48 h := sha3.NewKeccak256() 49 return doSum(h, nil, data...) 50 } 51 52 // TestRefHasher tests that the RefHasher computes the expected BMT hash for 53 // some small data lengths 54 func TestRefHasher(t *testing.T) { 55 // the test struct is used to specify the expected BMT hash for 56 // segment counts between from and to and lengths from 1 to datalength 57 type test struct { 58 from int 59 to int 60 expected func([]byte) []byte 61 } 62 63 var tests []*test 64 // all lengths in [0,64] should be: 65 // 66 // sha3hash(data) 67 // 68 tests = append(tests, &test{ 69 from: 1, 70 to: 2, 71 expected: func(d []byte) []byte { 72 data := make([]byte, 64) 73 copy(data, d) 74 return sha3hash(data) 75 }, 76 }) 77 78 // all lengths in [3,4] should be: 79 // 80 // sha3hash( 81 // sha3hash(data[:64]) 82 // sha3hash(data[64:]) 83 // ) 84 // 85 tests = append(tests, &test{ 86 from: 3, 87 to: 4, 88 expected: func(d []byte) []byte { 89 data := make([]byte, 128) 90 copy(data, d) 91 return sha3hash(sha3hash(data[:64]), sha3hash(data[64:])) 92 }, 93 }) 94 95 // all segmentCounts in [5,8] should be: 96 // 97 // sha3hash( 98 // sha3hash( 99 // sha3hash(data[:64]) 100 // sha3hash(data[64:128]) 101 // ) 102 // sha3hash( 103 // sha3hash(data[128:192]) 104 // sha3hash(data[192:]) 105 // ) 106 // ) 107 // 108 tests = append(tests, &test{ 109 from: 5, 110 to: 8, 111 expected: func(d []byte) []byte { 112 data := make([]byte, 256) 113 copy(data, d) 114 return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:]))) 115 }, 116 }) 117 118 // run the tests 119 for _, x := range tests { 120 for segmentCount := x.from; segmentCount <= x.to; segmentCount++ { 121 for length := 1; length <= segmentCount*32; length++ { 122 t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) { 123 data := make([]byte, length) 124 if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF { 125 t.Fatal(err) 126 } 127 expected := x.expected(data) 128 actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data) 129 if !bytes.Equal(actual, expected) { 130 t.Fatalf("expected %x, got %x", expected, actual) 131 } 132 }) 133 } 134 } 135 } 136 } 137 138 // tests if hasher responds with correct hash comparing the reference implementation return value 139 func TestHasherEmptyData(t *testing.T) { 140 hasher := sha3.NewKeccak256 141 var data []byte 142 for _, count := range counts { 143 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 144 pool := NewTreePool(hasher, count, PoolSize) 145 defer pool.Drain(0) 146 bmt := New(pool) 147 rbmt := NewRefHasher(hasher, count) 148 refHash := rbmt.Hash(data) 149 expHash := syncHash(bmt, nil, data) 150 if !bytes.Equal(expHash, refHash) { 151 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 152 } 153 }) 154 } 155 } 156 157 // tests sequential write with entire max size written in one go 158 func TestSyncHasherCorrectness(t *testing.T) { 159 data := newData(BufferSize) 160 hasher := sha3.NewKeccak256 161 size := hasher().Size() 162 163 var err error 164 for _, count := range counts { 165 t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) { 166 max := count * size 167 var incr int 168 capacity := 1 169 pool := NewTreePool(hasher, count, capacity) 170 defer pool.Drain(0) 171 for n := 0; n <= max; n += incr { 172 incr = 1 + rand.Intn(5) 173 bmt := New(pool) 174 err = testHasherCorrectness(bmt, hasher, data, n, count) 175 if err != nil { 176 t.Fatal(err) 177 } 178 } 179 }) 180 } 181 } 182 183 // tests order-neutral concurrent writes with entire max size written in one go 184 func TestAsyncCorrectness(t *testing.T) { 185 data := newData(BufferSize) 186 hasher := sha3.NewKeccak256 187 size := hasher().Size() 188 whs := []whenHash{first, last, random} 189 190 for _, double := range []bool{false, true} { 191 for _, wh := range whs { 192 for _, count := range counts { 193 t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) { 194 max := count * size 195 var incr int 196 capacity := 1 197 pool := NewTreePool(hasher, count, capacity) 198 defer pool.Drain(0) 199 for n := 1; n <= max; n += incr { 200 incr = 1 + rand.Intn(5) 201 bmt := New(pool) 202 d := data[:n] 203 rbmt := NewRefHasher(hasher, count) 204 exp := rbmt.Hash(d) 205 got := syncHash(bmt, nil, d) 206 if !bytes.Equal(got, exp) { 207 t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got) 208 } 209 sw := bmt.NewAsyncWriter(double) 210 got = asyncHashRandom(sw, nil, d, wh) 211 if !bytes.Equal(got, exp) { 212 t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got) 213 } 214 } 215 }) 216 } 217 } 218 } 219 } 220 221 // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize 222 func TestHasherReuse(t *testing.T) { 223 t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) { 224 testHasherReuse(1, t) 225 }) 226 t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) { 227 testHasherReuse(PoolSize, t) 228 }) 229 } 230 231 // tests if bmt reuse is not corrupting result 232 func testHasherReuse(poolsize int, t *testing.T) { 233 hasher := sha3.NewKeccak256 234 pool := NewTreePool(hasher, segmentCount, poolsize) 235 defer pool.Drain(0) 236 bmt := New(pool) 237 238 for i := 0; i < 100; i++ { 239 data := newData(BufferSize) 240 n := rand.Intn(bmt.Size()) 241 err := testHasherCorrectness(bmt, hasher, data, n, segmentCount) 242 if err != nil { 243 t.Fatal(err) 244 } 245 } 246 } 247 248 // Tests if pool can be cleanly reused even in concurrent use by several hasher 249 func TestBMTConcurrentUse(t *testing.T) { 250 hasher := sha3.NewKeccak256 251 pool := NewTreePool(hasher, segmentCount, PoolSize) 252 defer pool.Drain(0) 253 cycles := 100 254 errc := make(chan error) 255 256 for i := 0; i < cycles; i++ { 257 go func() { 258 bmt := New(pool) 259 data := newData(BufferSize) 260 n := rand.Intn(bmt.Size()) 261 errc <- testHasherCorrectness(bmt, hasher, data, n, 128) 262 }() 263 } 264 LOOP: 265 for { 266 select { 267 case <-time.NewTimer(5 * time.Second).C: 268 t.Fatal("timed out") 269 case err := <-errc: 270 if err != nil { 271 t.Fatal(err) 272 } 273 cycles-- 274 if cycles == 0 { 275 break LOOP 276 } 277 } 278 } 279 } 280 281 // Tests BMT Hasher io.Writer interface is working correctly 282 // even multiple short random write buffers 283 func TestBMTWriterBuffers(t *testing.T) { 284 hasher := sha3.NewKeccak256 285 286 for _, count := range counts { 287 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 288 errc := make(chan error) 289 pool := NewTreePool(hasher, count, PoolSize) 290 defer pool.Drain(0) 291 n := count * 32 292 bmt := New(pool) 293 data := newData(n) 294 rbmt := NewRefHasher(hasher, count) 295 refHash := rbmt.Hash(data) 296 expHash := syncHash(bmt, nil, data) 297 if !bytes.Equal(expHash, refHash) { 298 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 299 } 300 attempts := 10 301 f := func() error { 302 bmt := New(pool) 303 bmt.Reset() 304 var buflen int 305 for offset := 0; offset < n; offset += buflen { 306 buflen = rand.Intn(n-offset) + 1 307 read, err := bmt.Write(data[offset : offset+buflen]) 308 if err != nil { 309 return err 310 } 311 if read != buflen { 312 return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read) 313 } 314 } 315 hash := bmt.Sum(nil) 316 if !bytes.Equal(hash, expHash) { 317 return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash) 318 } 319 return nil 320 } 321 322 for j := 0; j < attempts; j++ { 323 go func() { 324 errc <- f() 325 }() 326 } 327 timeout := time.NewTimer(2 * time.Second) 328 for { 329 select { 330 case err := <-errc: 331 if err != nil { 332 t.Fatal(err) 333 } 334 attempts-- 335 if attempts == 0 { 336 return 337 } 338 case <-timeout.C: 339 t.Fatalf("timeout") 340 } 341 } 342 }) 343 } 344 } 345 346 // helper function that compares reference and optimised implementations on 347 // correctness 348 func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) { 349 span := make([]byte, 8) 350 if len(d) < n { 351 n = len(d) 352 } 353 binary.BigEndian.PutUint64(span, uint64(n)) 354 data := d[:n] 355 rbmt := NewRefHasher(hasher, count) 356 exp := sha3hash(span, rbmt.Hash(data)) 357 got := syncHash(bmt, span, data) 358 if !bytes.Equal(got, exp) { 359 return fmt.Errorf("wrong hash: expected %x, got %x", exp, got) 360 } 361 return err 362 } 363 364 // 365 func BenchmarkBMT(t *testing.B) { 366 for size := 4096; size >= 128; size /= 2 { 367 t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) { 368 benchmarkSHA3(t, size) 369 }) 370 t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) { 371 benchmarkBMTBaseline(t, size) 372 }) 373 t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) { 374 benchmarkRefHasher(t, size) 375 }) 376 t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) { 377 benchmarkBMT(t, size) 378 }) 379 } 380 } 381 382 type whenHash = int 383 384 const ( 385 first whenHash = iota 386 last 387 random 388 ) 389 390 func BenchmarkBMTAsync(t *testing.B) { 391 whs := []whenHash{first, last, random} 392 for size := 4096; size >= 128; size /= 2 { 393 for _, wh := range whs { 394 for _, double := range []bool{false, true} { 395 t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) { 396 benchmarkBMTAsync(t, size, wh, double) 397 }) 398 } 399 } 400 } 401 } 402 403 func BenchmarkPool(t *testing.B) { 404 caps := []int{1, PoolSize} 405 for size := 4096; size >= 128; size /= 2 { 406 for _, c := range caps { 407 t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) { 408 benchmarkPool(t, c, size) 409 }) 410 } 411 } 412 } 413 414 // benchmarks simple sha3 hash on chunks 415 func benchmarkSHA3(t *testing.B, n int) { 416 data := newData(n) 417 hasher := sha3.NewKeccak256 418 h := hasher() 419 420 t.ReportAllocs() 421 t.ResetTimer() 422 for i := 0; i < t.N; i++ { 423 doSum(h, nil, data) 424 } 425 } 426 427 // benchmarks the minimum hashing time for a balanced (for simplicity) BMT 428 // by doing count/segmentsize parallel hashings of 2*segmentsize bytes 429 // doing it on n PoolSize each reusing the base hasher 430 // the premise is that this is the minimum computation needed for a BMT 431 // therefore this serves as a theoretical optimum for concurrent implementations 432 func benchmarkBMTBaseline(t *testing.B, n int) { 433 hasher := sha3.NewKeccak256 434 hashSize := hasher().Size() 435 data := newData(hashSize) 436 437 t.ReportAllocs() 438 t.ResetTimer() 439 for i := 0; i < t.N; i++ { 440 count := int32((n-1)/hashSize + 1) 441 wg := sync.WaitGroup{} 442 wg.Add(PoolSize) 443 var i int32 444 for j := 0; j < PoolSize; j++ { 445 go func() { 446 defer wg.Done() 447 h := hasher() 448 for atomic.AddInt32(&i, 1) < count { 449 doSum(h, nil, data) 450 } 451 }() 452 } 453 wg.Wait() 454 } 455 } 456 457 // benchmarks BMT Hasher 458 func benchmarkBMT(t *testing.B, n int) { 459 data := newData(n) 460 hasher := sha3.NewKeccak256 461 pool := NewTreePool(hasher, segmentCount, PoolSize) 462 bmt := New(pool) 463 464 t.ReportAllocs() 465 t.ResetTimer() 466 for i := 0; i < t.N; i++ { 467 syncHash(bmt, nil, data) 468 } 469 } 470 471 // benchmarks BMT hasher with asynchronous concurrent segment/section writes 472 func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) { 473 data := newData(n) 474 hasher := sha3.NewKeccak256 475 pool := NewTreePool(hasher, segmentCount, PoolSize) 476 bmt := New(pool).NewAsyncWriter(double) 477 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 478 shuffle(len(idxs), func(i int, j int) { 479 idxs[i], idxs[j] = idxs[j], idxs[i] 480 }) 481 482 t.ReportAllocs() 483 t.ResetTimer() 484 for i := 0; i < t.N; i++ { 485 asyncHash(bmt, nil, n, wh, idxs, segments) 486 } 487 } 488 489 // benchmarks 100 concurrent bmt hashes with pool capacity 490 func benchmarkPool(t *testing.B, poolsize, n int) { 491 data := newData(n) 492 hasher := sha3.NewKeccak256 493 pool := NewTreePool(hasher, segmentCount, poolsize) 494 cycles := 100 495 496 t.ReportAllocs() 497 t.ResetTimer() 498 wg := sync.WaitGroup{} 499 for i := 0; i < t.N; i++ { 500 wg.Add(cycles) 501 for j := 0; j < cycles; j++ { 502 go func() { 503 defer wg.Done() 504 bmt := New(pool) 505 syncHash(bmt, nil, data) 506 }() 507 } 508 wg.Wait() 509 } 510 } 511 512 // benchmarks the reference hasher 513 func benchmarkRefHasher(t *testing.B, n int) { 514 data := newData(n) 515 hasher := sha3.NewKeccak256 516 rbmt := NewRefHasher(hasher, 128) 517 518 t.ReportAllocs() 519 t.ResetTimer() 520 for i := 0; i < t.N; i++ { 521 rbmt.Hash(data) 522 } 523 } 524 525 func newData(bufferSize int) []byte { 526 data := make([]byte, bufferSize) 527 _, err := io.ReadFull(crand.Reader, data) 528 if err != nil { 529 panic(err.Error()) 530 } 531 return data 532 } 533 534 // Hash hashes the data and the span using the bmt hasher 535 func syncHash(h *Hasher, span, data []byte) []byte { 536 h.ResetWithLength(span) 537 h.Write(data) 538 return h.Sum(nil) 539 } 540 541 func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) { 542 l := len(data) 543 n := l / secsize 544 if l%secsize > 0 { 545 n++ 546 } 547 for i := 0; i < n; i++ { 548 idxs = append(idxs, i) 549 end := (i + 1) * secsize 550 if end > l { 551 end = l 552 } 553 section := data[i*secsize : end] 554 segments = append(segments, section) 555 } 556 shuffle(n, func(i int, j int) { 557 idxs[i], idxs[j] = idxs[j], idxs[i] 558 }) 559 return idxs, segments 560 } 561 562 // splits the input data performs a random shuffle to mock async section writes 563 func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) { 564 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 565 return asyncHash(bmt, span, len(data), wh, idxs, segments) 566 } 567 568 // mock for async section writes for BMT SectionWriter 569 // requires a permutation (a random shuffle) of list of all indexes of segments 570 // and writes them in order to the appropriate section 571 // the Sum function is called according to the wh parameter (first, last, random [relative to segment writes]) 572 func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) { 573 bmt.Reset() 574 if l == 0 { 575 return bmt.Sum(nil, l, span) 576 } 577 c := make(chan []byte, 1) 578 hashf := func() { 579 c <- bmt.Sum(nil, l, span) 580 } 581 maxsize := len(idxs) 582 var r int 583 if wh == random { 584 r = rand.Intn(maxsize) 585 } 586 for i, idx := range idxs { 587 bmt.Write(idx, segments[idx]) 588 if (wh == first || wh == random) && i == r { 589 go hashf() 590 } 591 } 592 if wh == last { 593 return bmt.Sum(nil, l, span) 594 } 595 return <-c 596 } 597 598 // this is also in swarm/network_test.go 599 // shuffle pseudo-randomizes the order of elements. 600 // n is the number of elements. Shuffle panics if n < 0. 601 // swap swaps the elements with indexes i and j. 602 func shuffle(n int, swap func(i, j int)) { 603 if n < 0 { 604 panic("invalid argument to Shuffle") 605 } 606 607 // Fisher-Yates shuffle: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle 608 // Shuffle really ought not be called with n that doesn't fit in 32 bits. 609 // Not only will it take a very long time, but with 2³¹! possible permutations, 610 // there's no way that any PRNG can have a big enough internal state to 611 // generate even a minuscule percentage of the possible permutations. 612 // Nevertheless, the right API signature accepts an int n, so handle it as best we can. 613 i := n - 1 614 for ; i > 1<<31-1-1; i-- { 615 j := int(rand.Int63n(int64(i + 1))) 616 swap(i, j) 617 } 618 for ; i > 0; i-- { 619 j := int(rand.Int31n(int32(i + 1))) 620 swap(i, j) 621 } 622 }