github.com/jincm/wesharechain@v0.0.0-20210122032815-1537409ce26a/chain/swarm/bmt/bmt_test.go (about) 1 // Copyright 2017 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package bmt 18 19 import ( 20 "bytes" 21 "encoding/binary" 22 "fmt" 23 "math/rand" 24 "sync" 25 "sync/atomic" 26 "testing" 27 "time" 28 29 "github.com/ethereum/go-ethereum/swarm/testutil" 30 "golang.org/x/crypto/sha3" 31 ) 32 33 // the actual data length generated (could be longer than max datalength of the BMT) 34 const BufferSize = 4128 35 36 const ( 37 // segmentCount is the maximum number of segments of the underlying chunk 38 // Should be equal to max-chunk-data-size / hash-size 39 // Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size) 40 segmentCount = 128 41 ) 42 43 var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128} 44 45 // calculates the Keccak256 SHA3 hash of the data 46 func sha3hash(data ...[]byte) []byte { 47 h := sha3.NewLegacyKeccak256() 48 return doSum(h, nil, data...) 49 } 50 51 // TestRefHasher tests that the RefHasher computes the expected BMT hash for 52 // some small data lengths 53 func TestRefHasher(t *testing.T) { 54 // the test struct is used to specify the expected BMT hash for 55 // segment counts between from and to and lengths from 1 to datalength 56 type test struct { 57 from int 58 to int 59 expected func([]byte) []byte 60 } 61 62 var tests []*test 63 // all lengths in [0,64] should be: 64 // 65 // sha3hash(data) 66 // 67 tests = append(tests, &test{ 68 from: 1, 69 to: 2, 70 expected: func(d []byte) []byte { 71 data := make([]byte, 64) 72 copy(data, d) 73 return sha3hash(data) 74 }, 75 }) 76 77 // all lengths in [3,4] should be: 78 // 79 // sha3hash( 80 // sha3hash(data[:64]) 81 // sha3hash(data[64:]) 82 // ) 83 // 84 tests = append(tests, &test{ 85 from: 3, 86 to: 4, 87 expected: func(d []byte) []byte { 88 data := make([]byte, 128) 89 copy(data, d) 90 return sha3hash(sha3hash(data[:64]), sha3hash(data[64:])) 91 }, 92 }) 93 94 // all segmentCounts in [5,8] should be: 95 // 96 // sha3hash( 97 // sha3hash( 98 // sha3hash(data[:64]) 99 // sha3hash(data[64:128]) 100 // ) 101 // sha3hash( 102 // sha3hash(data[128:192]) 103 // sha3hash(data[192:]) 104 // ) 105 // ) 106 // 107 tests = append(tests, &test{ 108 from: 5, 109 to: 8, 110 expected: func(d []byte) []byte { 111 data := make([]byte, 256) 112 copy(data, d) 113 return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:]))) 114 }, 115 }) 116 117 // run the tests 118 for i, x := range tests { 119 for segmentCount := x.from; segmentCount <= x.to; segmentCount++ { 120 for length := 1; length <= segmentCount*32; length++ { 121 t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) { 122 data := testutil.RandomBytes(i, length) 123 expected := x.expected(data) 124 actual := NewRefHasher(sha3.NewLegacyKeccak256, segmentCount).Hash(data) 125 if !bytes.Equal(actual, expected) { 126 t.Fatalf("expected %x, got %x", expected, actual) 127 } 128 }) 129 } 130 } 131 } 132 } 133 134 // tests if hasher responds with correct hash comparing the reference implementation return value 135 func TestHasherEmptyData(t *testing.T) { 136 hasher := sha3.NewLegacyKeccak256 137 var data []byte 138 for _, count := range counts { 139 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 140 pool := NewTreePool(hasher, count, PoolSize) 141 defer pool.Drain(0) 142 bmt := New(pool) 143 rbmt := NewRefHasher(hasher, count) 144 refHash := rbmt.Hash(data) 145 expHash := syncHash(bmt, nil, data) 146 if !bytes.Equal(expHash, refHash) { 147 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 148 } 149 }) 150 } 151 } 152 153 // tests sequential write with entire max size written in one go 154 func TestSyncHasherCorrectness(t *testing.T) { 155 data := testutil.RandomBytes(1, BufferSize) 156 hasher := sha3.NewLegacyKeccak256 157 size := hasher().Size() 158 159 var err error 160 for _, count := range counts { 161 t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) { 162 max := count * size 163 var incr int 164 capacity := 1 165 pool := NewTreePool(hasher, count, capacity) 166 defer pool.Drain(0) 167 for n := 0; n <= max; n += incr { 168 incr = 1 + rand.Intn(5) 169 bmt := New(pool) 170 err = testHasherCorrectness(bmt, hasher, data, n, count) 171 if err != nil { 172 t.Fatal(err) 173 } 174 } 175 }) 176 } 177 } 178 179 // tests order-neutral concurrent writes with entire max size written in one go 180 func TestAsyncCorrectness(t *testing.T) { 181 data := testutil.RandomBytes(1, BufferSize) 182 hasher := sha3.NewLegacyKeccak256 183 size := hasher().Size() 184 whs := []whenHash{first, last, random} 185 186 for _, double := range []bool{false, true} { 187 for _, wh := range whs { 188 for _, count := range counts { 189 t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) { 190 max := count * size 191 var incr int 192 capacity := 1 193 pool := NewTreePool(hasher, count, capacity) 194 defer pool.Drain(0) 195 for n := 1; n <= max; n += incr { 196 incr = 1 + rand.Intn(5) 197 bmt := New(pool) 198 d := data[:n] 199 rbmt := NewRefHasher(hasher, count) 200 exp := rbmt.Hash(d) 201 got := syncHash(bmt, nil, d) 202 if !bytes.Equal(got, exp) { 203 t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got) 204 } 205 sw := bmt.NewAsyncWriter(double) 206 got = asyncHashRandom(sw, nil, d, wh) 207 if !bytes.Equal(got, exp) { 208 t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got) 209 } 210 } 211 }) 212 } 213 } 214 } 215 } 216 217 // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize 218 func TestHasherReuse(t *testing.T) { 219 t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) { 220 testHasherReuse(1, t) 221 }) 222 t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) { 223 testHasherReuse(PoolSize, t) 224 }) 225 } 226 227 // tests if bmt reuse is not corrupting result 228 func testHasherReuse(poolsize int, t *testing.T) { 229 hasher := sha3.NewLegacyKeccak256 230 pool := NewTreePool(hasher, segmentCount, poolsize) 231 defer pool.Drain(0) 232 bmt := New(pool) 233 234 for i := 0; i < 100; i++ { 235 data := testutil.RandomBytes(1, BufferSize) 236 n := rand.Intn(bmt.Size()) 237 err := testHasherCorrectness(bmt, hasher, data, n, segmentCount) 238 if err != nil { 239 t.Fatal(err) 240 } 241 } 242 } 243 244 // Tests if pool can be cleanly reused even in concurrent use by several hasher 245 func TestBMTConcurrentUse(t *testing.T) { 246 hasher := sha3.NewLegacyKeccak256 247 pool := NewTreePool(hasher, segmentCount, PoolSize) 248 defer pool.Drain(0) 249 cycles := 100 250 errc := make(chan error) 251 252 for i := 0; i < cycles; i++ { 253 go func() { 254 bmt := New(pool) 255 data := testutil.RandomBytes(1, BufferSize) 256 n := rand.Intn(bmt.Size()) 257 errc <- testHasherCorrectness(bmt, hasher, data, n, 128) 258 }() 259 } 260 LOOP: 261 for { 262 select { 263 case <-time.NewTimer(5 * time.Second).C: 264 t.Fatal("timed out") 265 case err := <-errc: 266 if err != nil { 267 t.Fatal(err) 268 } 269 cycles-- 270 if cycles == 0 { 271 break LOOP 272 } 273 } 274 } 275 } 276 277 // Tests BMT Hasher io.Writer interface is working correctly 278 // even multiple short random write buffers 279 func TestBMTWriterBuffers(t *testing.T) { 280 hasher := sha3.NewLegacyKeccak256 281 282 for _, count := range counts { 283 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 284 errc := make(chan error) 285 pool := NewTreePool(hasher, count, PoolSize) 286 defer pool.Drain(0) 287 n := count * 32 288 bmt := New(pool) 289 data := testutil.RandomBytes(1, n) 290 rbmt := NewRefHasher(hasher, count) 291 refHash := rbmt.Hash(data) 292 expHash := syncHash(bmt, nil, data) 293 if !bytes.Equal(expHash, refHash) { 294 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 295 } 296 attempts := 10 297 f := func() error { 298 bmt := New(pool) 299 bmt.Reset() 300 var buflen int 301 for offset := 0; offset < n; offset += buflen { 302 buflen = rand.Intn(n-offset) + 1 303 read, err := bmt.Write(data[offset : offset+buflen]) 304 if err != nil { 305 return err 306 } 307 if read != buflen { 308 return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read) 309 } 310 } 311 hash := bmt.Sum(nil) 312 if !bytes.Equal(hash, expHash) { 313 return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash) 314 } 315 return nil 316 } 317 318 for j := 0; j < attempts; j++ { 319 go func() { 320 errc <- f() 321 }() 322 } 323 timeout := time.NewTimer(2 * time.Second) 324 for { 325 select { 326 case err := <-errc: 327 if err != nil { 328 t.Fatal(err) 329 } 330 attempts-- 331 if attempts == 0 { 332 return 333 } 334 case <-timeout.C: 335 t.Fatalf("timeout") 336 } 337 } 338 }) 339 } 340 } 341 342 // helper function that compares reference and optimised implementations on 343 // correctness 344 func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) { 345 span := make([]byte, 8) 346 if len(d) < n { 347 n = len(d) 348 } 349 binary.BigEndian.PutUint64(span, uint64(n)) 350 data := d[:n] 351 rbmt := NewRefHasher(hasher, count) 352 exp := sha3hash(span, rbmt.Hash(data)) 353 got := syncHash(bmt, span, data) 354 if !bytes.Equal(got, exp) { 355 return fmt.Errorf("wrong hash: expected %x, got %x", exp, got) 356 } 357 return err 358 } 359 360 // 361 func BenchmarkBMT(t *testing.B) { 362 for size := 4096; size >= 128; size /= 2 { 363 t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) { 364 benchmarkSHA3(t, size) 365 }) 366 t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) { 367 benchmarkBMTBaseline(t, size) 368 }) 369 t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) { 370 benchmarkRefHasher(t, size) 371 }) 372 t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) { 373 benchmarkBMT(t, size) 374 }) 375 } 376 } 377 378 type whenHash = int 379 380 const ( 381 first whenHash = iota 382 last 383 random 384 ) 385 386 func BenchmarkBMTAsync(t *testing.B) { 387 whs := []whenHash{first, last, random} 388 for size := 4096; size >= 128; size /= 2 { 389 for _, wh := range whs { 390 for _, double := range []bool{false, true} { 391 t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) { 392 benchmarkBMTAsync(t, size, wh, double) 393 }) 394 } 395 } 396 } 397 } 398 399 func BenchmarkPool(t *testing.B) { 400 caps := []int{1, PoolSize} 401 for size := 4096; size >= 128; size /= 2 { 402 for _, c := range caps { 403 t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) { 404 benchmarkPool(t, c, size) 405 }) 406 } 407 } 408 } 409 410 // benchmarks simple sha3 hash on chunks 411 func benchmarkSHA3(t *testing.B, n int) { 412 data := testutil.RandomBytes(1, n) 413 hasher := sha3.NewLegacyKeccak256 414 h := hasher() 415 416 t.ReportAllocs() 417 t.ResetTimer() 418 for i := 0; i < t.N; i++ { 419 doSum(h, nil, data) 420 } 421 } 422 423 // benchmarks the minimum hashing time for a balanced (for simplicity) BMT 424 // by doing count/segmentsize parallel hashings of 2*segmentsize bytes 425 // doing it on n PoolSize each reusing the base hasher 426 // the premise is that this is the minimum computation needed for a BMT 427 // therefore this serves as a theoretical optimum for concurrent implementations 428 func benchmarkBMTBaseline(t *testing.B, n int) { 429 hasher := sha3.NewLegacyKeccak256 430 hashSize := hasher().Size() 431 data := testutil.RandomBytes(1, hashSize) 432 433 t.ReportAllocs() 434 t.ResetTimer() 435 for i := 0; i < t.N; i++ { 436 count := int32((n-1)/hashSize + 1) 437 wg := sync.WaitGroup{} 438 wg.Add(PoolSize) 439 var i int32 440 for j := 0; j < PoolSize; j++ { 441 go func() { 442 defer wg.Done() 443 h := hasher() 444 for atomic.AddInt32(&i, 1) < count { 445 doSum(h, nil, data) 446 } 447 }() 448 } 449 wg.Wait() 450 } 451 } 452 453 // benchmarks BMT Hasher 454 func benchmarkBMT(t *testing.B, n int) { 455 data := testutil.RandomBytes(1, n) 456 hasher := sha3.NewLegacyKeccak256 457 pool := NewTreePool(hasher, segmentCount, PoolSize) 458 bmt := New(pool) 459 460 t.ReportAllocs() 461 t.ResetTimer() 462 for i := 0; i < t.N; i++ { 463 syncHash(bmt, nil, data) 464 } 465 } 466 467 // benchmarks BMT hasher with asynchronous concurrent segment/section writes 468 func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) { 469 data := testutil.RandomBytes(1, n) 470 hasher := sha3.NewLegacyKeccak256 471 pool := NewTreePool(hasher, segmentCount, PoolSize) 472 bmt := New(pool).NewAsyncWriter(double) 473 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 474 rand.Shuffle(len(idxs), func(i int, j int) { 475 idxs[i], idxs[j] = idxs[j], idxs[i] 476 }) 477 478 t.ReportAllocs() 479 t.ResetTimer() 480 for i := 0; i < t.N; i++ { 481 asyncHash(bmt, nil, n, wh, idxs, segments) 482 } 483 } 484 485 // benchmarks 100 concurrent bmt hashes with pool capacity 486 func benchmarkPool(t *testing.B, poolsize, n int) { 487 data := testutil.RandomBytes(1, n) 488 hasher := sha3.NewLegacyKeccak256 489 pool := NewTreePool(hasher, segmentCount, poolsize) 490 cycles := 100 491 492 t.ReportAllocs() 493 t.ResetTimer() 494 wg := sync.WaitGroup{} 495 for i := 0; i < t.N; i++ { 496 wg.Add(cycles) 497 for j := 0; j < cycles; j++ { 498 go func() { 499 defer wg.Done() 500 bmt := New(pool) 501 syncHash(bmt, nil, data) 502 }() 503 } 504 wg.Wait() 505 } 506 } 507 508 // benchmarks the reference hasher 509 func benchmarkRefHasher(t *testing.B, n int) { 510 data := testutil.RandomBytes(1, n) 511 hasher := sha3.NewLegacyKeccak256 512 rbmt := NewRefHasher(hasher, 128) 513 514 t.ReportAllocs() 515 t.ResetTimer() 516 for i := 0; i < t.N; i++ { 517 rbmt.Hash(data) 518 } 519 } 520 521 // Hash hashes the data and the span using the bmt hasher 522 func syncHash(h *Hasher, span, data []byte) []byte { 523 h.ResetWithLength(span) 524 h.Write(data) 525 return h.Sum(nil) 526 } 527 528 func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) { 529 l := len(data) 530 n := l / secsize 531 if l%secsize > 0 { 532 n++ 533 } 534 for i := 0; i < n; i++ { 535 idxs = append(idxs, i) 536 end := (i + 1) * secsize 537 if end > l { 538 end = l 539 } 540 section := data[i*secsize : end] 541 segments = append(segments, section) 542 } 543 rand.Shuffle(n, func(i int, j int) { 544 idxs[i], idxs[j] = idxs[j], idxs[i] 545 }) 546 return idxs, segments 547 } 548 549 // splits the input data performs a random shuffle to mock async section writes 550 func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) { 551 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 552 return asyncHash(bmt, span, len(data), wh, idxs, segments) 553 } 554 555 // mock for async section writes for BMT SectionWriter 556 // requires a permutation (a random shuffle) of list of all indexes of segments 557 // and writes them in order to the appropriate section 558 // the Sum function is called according to the wh parameter (first, last, random [relative to segment writes]) 559 func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) { 560 bmt.Reset() 561 if l == 0 { 562 return bmt.Sum(nil, l, span) 563 } 564 c := make(chan []byte, 1) 565 hashf := func() { 566 c <- bmt.Sum(nil, l, span) 567 } 568 maxsize := len(idxs) 569 var r int 570 if wh == random { 571 r = rand.Intn(maxsize) 572 } 573 for i, idx := range idxs { 574 bmt.Write(idx, segments[idx]) 575 if (wh == first || wh == random) && i == r { 576 go hashf() 577 } 578 } 579 if wh == last { 580 return bmt.Sum(nil, l, span) 581 } 582 return <-c 583 }