github.com/codingfuture/orig-energi3@v0.8.4/swarm/bmt/bmt_test.go (about) 1 // Copyright 2018 The Energi Core Authors 2 // Copyright 2018 The go-ethereum Authors 3 // This file is part of the Energi Core library. 4 // 5 // The Energi Core library is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Lesser General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // The Energi Core library is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Lesser General Public License for more details. 14 // 15 // You should have received a copy of the GNU Lesser General Public License 16 // along with the Energi Core library. If not, see <http://www.gnu.org/licenses/>. 17 18 package bmt 19 20 import ( 21 "bytes" 22 "encoding/binary" 23 "fmt" 24 "math/rand" 25 "sync" 26 "sync/atomic" 27 "testing" 28 "time" 29 30 "github.com/ethereum/go-ethereum/swarm/testutil" 31 "golang.org/x/crypto/sha3" 32 ) 33 34 // the actual data length generated (could be longer than max datalength of the BMT) 35 const BufferSize = 4128 36 37 const ( 38 // segmentCount is the maximum number of segments of the underlying chunk 39 // Should be equal to max-chunk-data-size / hash-size 40 // Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size) 41 segmentCount = 128 42 ) 43 44 var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128} 45 46 // calculates the Keccak256 SHA3 hash of the data 47 func sha3hash(data ...[]byte) []byte { 48 h := sha3.NewLegacyKeccak256() 49 return doSum(h, nil, data...) 50 } 51 52 // TestRefHasher tests that the RefHasher computes the expected BMT hash for 53 // some small data lengths 54 func TestRefHasher(t *testing.T) { 55 // the test struct is used to specify the expected BMT hash for 56 // segment counts between from and to and lengths from 1 to datalength 57 type test struct { 58 from int 59 to int 60 expected func([]byte) []byte 61 } 62 63 var tests []*test 64 // all lengths in [0,64] should be: 65 // 66 // sha3hash(data) 67 // 68 tests = append(tests, &test{ 69 from: 1, 70 to: 2, 71 expected: func(d []byte) []byte { 72 data := make([]byte, 64) 73 copy(data, d) 74 return sha3hash(data) 75 }, 76 }) 77 78 // all lengths in [3,4] should be: 79 // 80 // sha3hash( 81 // sha3hash(data[:64]) 82 // sha3hash(data[64:]) 83 // ) 84 // 85 tests = append(tests, &test{ 86 from: 3, 87 to: 4, 88 expected: func(d []byte) []byte { 89 data := make([]byte, 128) 90 copy(data, d) 91 return sha3hash(sha3hash(data[:64]), sha3hash(data[64:])) 92 }, 93 }) 94 95 // all segmentCounts in [5,8] should be: 96 // 97 // sha3hash( 98 // sha3hash( 99 // sha3hash(data[:64]) 100 // sha3hash(data[64:128]) 101 // ) 102 // sha3hash( 103 // sha3hash(data[128:192]) 104 // sha3hash(data[192:]) 105 // ) 106 // ) 107 // 108 tests = append(tests, &test{ 109 from: 5, 110 to: 8, 111 expected: func(d []byte) []byte { 112 data := make([]byte, 256) 113 copy(data, d) 114 return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:]))) 115 }, 116 }) 117 118 // run the tests 119 for i, x := range tests { 120 for segmentCount := x.from; segmentCount <= x.to; segmentCount++ { 121 for length := 1; length <= segmentCount*32; length++ { 122 t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) { 123 data := testutil.RandomBytes(i, length) 124 expected := x.expected(data) 125 actual := NewRefHasher(sha3.NewLegacyKeccak256, segmentCount).Hash(data) 126 if !bytes.Equal(actual, expected) { 127 t.Fatalf("expected %x, got %x", expected, actual) 128 } 129 }) 130 } 131 } 132 } 133 } 134 135 // tests if hasher responds with correct hash comparing the reference implementation return value 136 func TestHasherEmptyData(t *testing.T) { 137 hasher := sha3.NewLegacyKeccak256 138 var data []byte 139 for _, count := range counts { 140 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 141 pool := NewTreePool(hasher, count, PoolSize) 142 defer pool.Drain(0) 143 bmt := New(pool) 144 rbmt := NewRefHasher(hasher, count) 145 refHash := rbmt.Hash(data) 146 expHash := syncHash(bmt, nil, data) 147 if !bytes.Equal(expHash, refHash) { 148 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 149 } 150 }) 151 } 152 } 153 154 // tests sequential write with entire max size written in one go 155 func TestSyncHasherCorrectness(t *testing.T) { 156 data := testutil.RandomBytes(1, BufferSize) 157 hasher := sha3.NewLegacyKeccak256 158 size := hasher().Size() 159 160 var err error 161 for _, count := range counts { 162 t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) { 163 max := count * size 164 var incr int 165 capacity := 1 166 pool := NewTreePool(hasher, count, capacity) 167 defer pool.Drain(0) 168 for n := 0; n <= max; n += incr { 169 incr = 1 + rand.Intn(5) 170 bmt := New(pool) 171 err = testHasherCorrectness(bmt, hasher, data, n, count) 172 if err != nil { 173 t.Fatal(err) 174 } 175 } 176 }) 177 } 178 } 179 180 // tests order-neutral concurrent writes with entire max size written in one go 181 func TestAsyncCorrectness(t *testing.T) { 182 data := testutil.RandomBytes(1, BufferSize) 183 hasher := sha3.NewLegacyKeccak256 184 size := hasher().Size() 185 whs := []whenHash{first, last, random} 186 187 for _, double := range []bool{false, true} { 188 for _, wh := range whs { 189 for _, count := range counts { 190 t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) { 191 max := count * size 192 var incr int 193 capacity := 1 194 pool := NewTreePool(hasher, count, capacity) 195 defer pool.Drain(0) 196 for n := 1; n <= max; n += incr { 197 incr = 1 + rand.Intn(5) 198 bmt := New(pool) 199 d := data[:n] 200 rbmt := NewRefHasher(hasher, count) 201 exp := rbmt.Hash(d) 202 got := syncHash(bmt, nil, d) 203 if !bytes.Equal(got, exp) { 204 t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got) 205 } 206 sw := bmt.NewAsyncWriter(double) 207 got = asyncHashRandom(sw, nil, d, wh) 208 if !bytes.Equal(got, exp) { 209 t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got) 210 } 211 } 212 }) 213 } 214 } 215 } 216 } 217 218 // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize 219 func TestHasherReuse(t *testing.T) { 220 t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) { 221 testHasherReuse(1, t) 222 }) 223 t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) { 224 testHasherReuse(PoolSize, t) 225 }) 226 } 227 228 // tests if bmt reuse is not corrupting result 229 func testHasherReuse(poolsize int, t *testing.T) { 230 hasher := sha3.NewLegacyKeccak256 231 pool := NewTreePool(hasher, segmentCount, poolsize) 232 defer pool.Drain(0) 233 bmt := New(pool) 234 235 for i := 0; i < 100; i++ { 236 data := testutil.RandomBytes(1, BufferSize) 237 n := rand.Intn(bmt.Size()) 238 err := testHasherCorrectness(bmt, hasher, data, n, segmentCount) 239 if err != nil { 240 t.Fatal(err) 241 } 242 } 243 } 244 245 // Tests if pool can be cleanly reused even in concurrent use by several hasher 246 func TestBMTConcurrentUse(t *testing.T) { 247 hasher := sha3.NewLegacyKeccak256 248 pool := NewTreePool(hasher, segmentCount, PoolSize) 249 defer pool.Drain(0) 250 cycles := 100 251 errc := make(chan error) 252 253 for i := 0; i < cycles; i++ { 254 go func() { 255 bmt := New(pool) 256 data := testutil.RandomBytes(1, BufferSize) 257 n := rand.Intn(bmt.Size()) 258 errc <- testHasherCorrectness(bmt, hasher, data, n, 128) 259 }() 260 } 261 LOOP: 262 for { 263 select { 264 case <-time.NewTimer(5 * time.Second).C: 265 t.Fatal("timed out") 266 case err := <-errc: 267 if err != nil { 268 t.Fatal(err) 269 } 270 cycles-- 271 if cycles == 0 { 272 break LOOP 273 } 274 } 275 } 276 } 277 278 // Tests BMT Hasher io.Writer interface is working correctly 279 // even multiple short random write buffers 280 func TestBMTWriterBuffers(t *testing.T) { 281 hasher := sha3.NewLegacyKeccak256 282 283 for _, count := range counts { 284 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 285 errc := make(chan error) 286 pool := NewTreePool(hasher, count, PoolSize) 287 defer pool.Drain(0) 288 n := count * 32 289 bmt := New(pool) 290 data := testutil.RandomBytes(1, n) 291 rbmt := NewRefHasher(hasher, count) 292 refHash := rbmt.Hash(data) 293 expHash := syncHash(bmt, nil, data) 294 if !bytes.Equal(expHash, refHash) { 295 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 296 } 297 attempts := 10 298 f := func() error { 299 bmt := New(pool) 300 bmt.Reset() 301 var buflen int 302 for offset := 0; offset < n; offset += buflen { 303 buflen = rand.Intn(n-offset) + 1 304 read, err := bmt.Write(data[offset : offset+buflen]) 305 if err != nil { 306 return err 307 } 308 if read != buflen { 309 return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read) 310 } 311 } 312 hash := bmt.Sum(nil) 313 if !bytes.Equal(hash, expHash) { 314 return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash) 315 } 316 return nil 317 } 318 319 for j := 0; j < attempts; j++ { 320 go func() { 321 errc <- f() 322 }() 323 } 324 timeout := time.NewTimer(2 * time.Second) 325 for { 326 select { 327 case err := <-errc: 328 if err != nil { 329 t.Fatal(err) 330 } 331 attempts-- 332 if attempts == 0 { 333 return 334 } 335 case <-timeout.C: 336 t.Fatalf("timeout") 337 } 338 } 339 }) 340 } 341 } 342 343 // helper function that compares reference and optimised implementations on 344 // correctness 345 func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) { 346 span := make([]byte, 8) 347 if len(d) < n { 348 n = len(d) 349 } 350 binary.BigEndian.PutUint64(span, uint64(n)) 351 data := d[:n] 352 rbmt := NewRefHasher(hasher, count) 353 exp := sha3hash(span, rbmt.Hash(data)) 354 got := syncHash(bmt, span, data) 355 if !bytes.Equal(got, exp) { 356 return fmt.Errorf("wrong hash: expected %x, got %x", exp, got) 357 } 358 return err 359 } 360 361 // 362 func BenchmarkBMT(t *testing.B) { 363 for size := 4096; size >= 128; size /= 2 { 364 t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) { 365 benchmarkSHA3(t, size) 366 }) 367 t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) { 368 benchmarkBMTBaseline(t, size) 369 }) 370 t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) { 371 benchmarkRefHasher(t, size) 372 }) 373 t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) { 374 benchmarkBMT(t, size) 375 }) 376 } 377 } 378 379 type whenHash = int 380 381 const ( 382 first whenHash = iota 383 last 384 random 385 ) 386 387 func BenchmarkBMTAsync(t *testing.B) { 388 whs := []whenHash{first, last, random} 389 for size := 4096; size >= 128; size /= 2 { 390 for _, wh := range whs { 391 for _, double := range []bool{false, true} { 392 t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) { 393 benchmarkBMTAsync(t, size, wh, double) 394 }) 395 } 396 } 397 } 398 } 399 400 func BenchmarkPool(t *testing.B) { 401 caps := []int{1, PoolSize} 402 for size := 4096; size >= 128; size /= 2 { 403 for _, c := range caps { 404 t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) { 405 benchmarkPool(t, c, size) 406 }) 407 } 408 } 409 } 410 411 // benchmarks simple sha3 hash on chunks 412 func benchmarkSHA3(t *testing.B, n int) { 413 data := testutil.RandomBytes(1, n) 414 hasher := sha3.NewLegacyKeccak256 415 h := hasher() 416 417 t.ReportAllocs() 418 t.ResetTimer() 419 for i := 0; i < t.N; i++ { 420 doSum(h, nil, data) 421 } 422 } 423 424 // benchmarks the minimum hashing time for a balanced (for simplicity) BMT 425 // by doing count/segmentsize parallel hashings of 2*segmentsize bytes 426 // doing it on n PoolSize each reusing the base hasher 427 // the premise is that this is the minimum computation needed for a BMT 428 // therefore this serves as a theoretical optimum for concurrent implementations 429 func benchmarkBMTBaseline(t *testing.B, n int) { 430 hasher := sha3.NewLegacyKeccak256 431 hashSize := hasher().Size() 432 data := testutil.RandomBytes(1, hashSize) 433 434 t.ReportAllocs() 435 t.ResetTimer() 436 for i := 0; i < t.N; i++ { 437 count := int32((n-1)/hashSize + 1) 438 wg := sync.WaitGroup{} 439 wg.Add(PoolSize) 440 var i int32 441 for j := 0; j < PoolSize; j++ { 442 go func() { 443 defer wg.Done() 444 h := hasher() 445 for atomic.AddInt32(&i, 1) < count { 446 doSum(h, nil, data) 447 } 448 }() 449 } 450 wg.Wait() 451 } 452 } 453 454 // benchmarks BMT Hasher 455 func benchmarkBMT(t *testing.B, n int) { 456 data := testutil.RandomBytes(1, n) 457 hasher := sha3.NewLegacyKeccak256 458 pool := NewTreePool(hasher, segmentCount, PoolSize) 459 bmt := New(pool) 460 461 t.ReportAllocs() 462 t.ResetTimer() 463 for i := 0; i < t.N; i++ { 464 syncHash(bmt, nil, data) 465 } 466 } 467 468 // benchmarks BMT hasher with asynchronous concurrent segment/section writes 469 func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) { 470 data := testutil.RandomBytes(1, n) 471 hasher := sha3.NewLegacyKeccak256 472 pool := NewTreePool(hasher, segmentCount, PoolSize) 473 bmt := New(pool).NewAsyncWriter(double) 474 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 475 rand.Shuffle(len(idxs), func(i int, j int) { 476 idxs[i], idxs[j] = idxs[j], idxs[i] 477 }) 478 479 t.ReportAllocs() 480 t.ResetTimer() 481 for i := 0; i < t.N; i++ { 482 asyncHash(bmt, nil, n, wh, idxs, segments) 483 } 484 } 485 486 // benchmarks 100 concurrent bmt hashes with pool capacity 487 func benchmarkPool(t *testing.B, poolsize, n int) { 488 data := testutil.RandomBytes(1, n) 489 hasher := sha3.NewLegacyKeccak256 490 pool := NewTreePool(hasher, segmentCount, poolsize) 491 cycles := 100 492 493 t.ReportAllocs() 494 t.ResetTimer() 495 wg := sync.WaitGroup{} 496 for i := 0; i < t.N; i++ { 497 wg.Add(cycles) 498 for j := 0; j < cycles; j++ { 499 go func() { 500 defer wg.Done() 501 bmt := New(pool) 502 syncHash(bmt, nil, data) 503 }() 504 } 505 wg.Wait() 506 } 507 } 508 509 // benchmarks the reference hasher 510 func benchmarkRefHasher(t *testing.B, n int) { 511 data := testutil.RandomBytes(1, n) 512 hasher := sha3.NewLegacyKeccak256 513 rbmt := NewRefHasher(hasher, 128) 514 515 t.ReportAllocs() 516 t.ResetTimer() 517 for i := 0; i < t.N; i++ { 518 rbmt.Hash(data) 519 } 520 } 521 522 // Hash hashes the data and the span using the bmt hasher 523 func syncHash(h *Hasher, span, data []byte) []byte { 524 h.ResetWithLength(span) 525 h.Write(data) 526 return h.Sum(nil) 527 } 528 529 func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) { 530 l := len(data) 531 n := l / secsize 532 if l%secsize > 0 { 533 n++ 534 } 535 for i := 0; i < n; i++ { 536 idxs = append(idxs, i) 537 end := (i + 1) * secsize 538 if end > l { 539 end = l 540 } 541 section := data[i*secsize : end] 542 segments = append(segments, section) 543 } 544 rand.Shuffle(n, func(i int, j int) { 545 idxs[i], idxs[j] = idxs[j], idxs[i] 546 }) 547 return idxs, segments 548 } 549 550 // splits the input data performs a random shuffle to mock async section writes 551 func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) { 552 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 553 return asyncHash(bmt, span, len(data), wh, idxs, segments) 554 } 555 556 // mock for async section writes for BMT SectionWriter 557 // requires a permutation (a random shuffle) of list of all indexes of segments 558 // and writes them in order to the appropriate section 559 // the Sum function is called according to the wh parameter (first, last, random [relative to segment writes]) 560 func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) { 561 bmt.Reset() 562 if l == 0 { 563 return bmt.Sum(nil, l, span) 564 } 565 c := make(chan []byte, 1) 566 hashf := func() { 567 c <- bmt.Sum(nil, l, span) 568 } 569 maxsize := len(idxs) 570 var r int 571 if wh == random { 572 r = rand.Intn(maxsize) 573 } 574 for i, idx := range idxs { 575 bmt.Write(idx, segments[idx]) 576 if (wh == first || wh == random) && i == r { 577 go hashf() 578 } 579 } 580 if wh == last { 581 return bmt.Sum(nil, l, span) 582 } 583 return <-c 584 }