github.com/yinchengtsinghua/golang-Eos-dpos-Ethereum@v0.0.0-20190121132951-92cc4225ed8e/swarm/bmt/bmt_test.go (about) 1 2 //此源码被清华学神尹成大魔王专业翻译分析并修改 3 //尹成QQ77025077 4 //尹成微信18510341407 5 //尹成所在QQ群721929980 6 //尹成邮箱 yinc13@mails.tsinghua.edu.cn 7 //尹成毕业于清华大学,微软区块链领域全球最有价值专家 8 //https://mvp.microsoft.com/zh-cn/PublicProfile/4033620 9 // 10 // 11 // 12 // 13 // 14 // 15 // 16 // 17 // 18 // 19 // 20 // 21 // 22 // 23 // 24 25 package bmt 26 27 import ( 28 "bytes" 29 crand "crypto/rand" 30 "encoding/binary" 31 "fmt" 32 "io" 33 "math/rand" 34 "sync" 35 "sync/atomic" 36 "testing" 37 "time" 38 39 "github.com/ethereum/go-ethereum/crypto/sha3" 40 ) 41 42 // 43 const BufferSize = 4128 44 45 const ( 46 // 47 // 48 // 49 segmentCount = 128 50 ) 51 52 var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128} 53 54 // 55 func sha3hash(data ...[]byte) []byte { 56 h := sha3.NewKeccak256() 57 return doSum(h, nil, data...) 58 } 59 60 // 61 // 62 func TestRefHasher(t *testing.T) { 63 // 64 // 65 type test struct { 66 from int 67 to int 68 expected func([]byte) []byte 69 } 70 71 var tests []*test 72 // 73 // 74 // 75 // 76 tests = append(tests, &test{ 77 from: 1, 78 to: 2, 79 expected: func(d []byte) []byte { 80 data := make([]byte, 64) 81 copy(data, d) 82 return sha3hash(data) 83 }, 84 }) 85 86 // 87 // 88 // 89 // 90 // 91 // 92 // 93 tests = append(tests, &test{ 94 from: 3, 95 to: 4, 96 expected: func(d []byte) []byte { 97 data := make([]byte, 128) 98 copy(data, d) 99 return sha3hash(sha3hash(data[:64]), sha3hash(data[64:])) 100 }, 101 }) 102 103 // 104 // 105 // 106 // 107 // 108 // 109 // 110 // 111 // 112 // 113 // 114 // 115 // 116 tests = append(tests, &test{ 117 from: 5, 118 to: 8, 119 expected: func(d []byte) []byte { 120 data := make([]byte, 256) 121 copy(data, d) 122 return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:]))) 123 }, 124 }) 125 126 // 127 for _, x := range tests { 128 for segmentCount := x.from; segmentCount <= x.to; segmentCount++ { 129 for length := 1; length <= segmentCount*32; length++ { 130 t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) { 131 data := make([]byte, length) 132 if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF { 133 t.Fatal(err) 134 } 135 expected := x.expected(data) 136 actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data) 137 if !bytes.Equal(actual, expected) { 138 t.Fatalf("expected %x, got %x", expected, actual) 139 } 140 }) 141 } 142 } 143 } 144 } 145 146 // 147 func TestHasherEmptyData(t *testing.T) { 148 hasher := sha3.NewKeccak256 149 var data []byte 150 for _, count := range counts { 151 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 152 pool := NewTreePool(hasher, count, PoolSize) 153 defer pool.Drain(0) 154 bmt := New(pool) 155 rbmt := NewRefHasher(hasher, count) 156 refHash := rbmt.Hash(data) 157 expHash := syncHash(bmt, nil, data) 158 if !bytes.Equal(expHash, refHash) { 159 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 160 } 161 }) 162 } 163 } 164 165 // 166 func TestSyncHasherCorrectness(t *testing.T) { 167 data := newData(BufferSize) 168 hasher := sha3.NewKeccak256 169 size := hasher().Size() 170 171 var err error 172 for _, count := range counts { 173 t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) { 174 max := count * size 175 var incr int 176 capacity := 1 177 pool := NewTreePool(hasher, count, capacity) 178 defer pool.Drain(0) 179 for n := 0; n <= max; n += incr { 180 incr = 1 + rand.Intn(5) 181 bmt := New(pool) 182 err = testHasherCorrectness(bmt, hasher, data, n, count) 183 if err != nil { 184 t.Fatal(err) 185 } 186 } 187 }) 188 } 189 } 190 191 // 192 func TestAsyncCorrectness(t *testing.T) { 193 data := newData(BufferSize) 194 hasher := sha3.NewKeccak256 195 size := hasher().Size() 196 whs := []whenHash{first, last, random} 197 198 for _, double := range []bool{false, true} { 199 for _, wh := range whs { 200 for _, count := range counts { 201 t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) { 202 max := count * size 203 var incr int 204 capacity := 1 205 pool := NewTreePool(hasher, count, capacity) 206 defer pool.Drain(0) 207 for n := 1; n <= max; n += incr { 208 incr = 1 + rand.Intn(5) 209 bmt := New(pool) 210 d := data[:n] 211 rbmt := NewRefHasher(hasher, count) 212 exp := rbmt.Hash(d) 213 got := syncHash(bmt, nil, d) 214 if !bytes.Equal(got, exp) { 215 t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got) 216 } 217 sw := bmt.NewAsyncWriter(double) 218 got = asyncHashRandom(sw, nil, d, wh) 219 if !bytes.Equal(got, exp) { 220 t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got) 221 } 222 } 223 }) 224 } 225 } 226 } 227 } 228 229 // 230 func TestHasherReuse(t *testing.T) { 231 t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) { 232 testHasherReuse(1, t) 233 }) 234 t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) { 235 testHasherReuse(PoolSize, t) 236 }) 237 } 238 239 // 240 func testHasherReuse(poolsize int, t *testing.T) { 241 hasher := sha3.NewKeccak256 242 pool := NewTreePool(hasher, segmentCount, poolsize) 243 defer pool.Drain(0) 244 bmt := New(pool) 245 246 for i := 0; i < 100; i++ { 247 data := newData(BufferSize) 248 n := rand.Intn(bmt.Size()) 249 err := testHasherCorrectness(bmt, hasher, data, n, segmentCount) 250 if err != nil { 251 t.Fatal(err) 252 } 253 } 254 } 255 256 // 257 func TestBMTConcurrentUse(t *testing.T) { 258 hasher := sha3.NewKeccak256 259 pool := NewTreePool(hasher, segmentCount, PoolSize) 260 defer pool.Drain(0) 261 cycles := 100 262 errc := make(chan error) 263 264 for i := 0; i < cycles; i++ { 265 go func() { 266 bmt := New(pool) 267 data := newData(BufferSize) 268 n := rand.Intn(bmt.Size()) 269 errc <- testHasherCorrectness(bmt, hasher, data, n, 128) 270 }() 271 } 272 LOOP: 273 for { 274 select { 275 case <-time.NewTimer(5 * time.Second).C: 276 t.Fatal("timed out") 277 case err := <-errc: 278 if err != nil { 279 t.Fatal(err) 280 } 281 cycles-- 282 if cycles == 0 { 283 break LOOP 284 } 285 } 286 } 287 } 288 289 // 290 // 291 func TestBMTWriterBuffers(t *testing.T) { 292 hasher := sha3.NewKeccak256 293 294 for _, count := range counts { 295 t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) { 296 errc := make(chan error) 297 pool := NewTreePool(hasher, count, PoolSize) 298 defer pool.Drain(0) 299 n := count * 32 300 bmt := New(pool) 301 data := newData(n) 302 rbmt := NewRefHasher(hasher, count) 303 refHash := rbmt.Hash(data) 304 expHash := syncHash(bmt, nil, data) 305 if !bytes.Equal(expHash, refHash) { 306 t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash) 307 } 308 attempts := 10 309 f := func() error { 310 bmt := New(pool) 311 bmt.Reset() 312 var buflen int 313 for offset := 0; offset < n; offset += buflen { 314 buflen = rand.Intn(n-offset) + 1 315 read, err := bmt.Write(data[offset : offset+buflen]) 316 if err != nil { 317 return err 318 } 319 if read != buflen { 320 return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read) 321 } 322 } 323 hash := bmt.Sum(nil) 324 if !bytes.Equal(hash, expHash) { 325 return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash) 326 } 327 return nil 328 } 329 330 for j := 0; j < attempts; j++ { 331 go func() { 332 errc <- f() 333 }() 334 } 335 timeout := time.NewTimer(2 * time.Second) 336 for { 337 select { 338 case err := <-errc: 339 if err != nil { 340 t.Fatal(err) 341 } 342 attempts-- 343 if attempts == 0 { 344 return 345 } 346 case <-timeout.C: 347 t.Fatalf("timeout") 348 } 349 } 350 }) 351 } 352 } 353 354 // 355 // 356 func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) { 357 span := make([]byte, 8) 358 if len(d) < n { 359 n = len(d) 360 } 361 binary.BigEndian.PutUint64(span, uint64(n)) 362 data := d[:n] 363 rbmt := NewRefHasher(hasher, count) 364 exp := sha3hash(span, rbmt.Hash(data)) 365 got := syncHash(bmt, span, data) 366 if !bytes.Equal(got, exp) { 367 return fmt.Errorf("wrong hash: expected %x, got %x", exp, got) 368 } 369 return err 370 } 371 372 // 373 func BenchmarkBMT(t *testing.B) { 374 for size := 4096; size >= 128; size /= 2 { 375 t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) { 376 benchmarkSHA3(t, size) 377 }) 378 t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) { 379 benchmarkBMTBaseline(t, size) 380 }) 381 t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) { 382 benchmarkRefHasher(t, size) 383 }) 384 t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) { 385 benchmarkBMT(t, size) 386 }) 387 } 388 } 389 390 type whenHash = int 391 392 const ( 393 first whenHash = iota 394 last 395 random 396 ) 397 398 func BenchmarkBMTAsync(t *testing.B) { 399 whs := []whenHash{first, last, random} 400 for size := 4096; size >= 128; size /= 2 { 401 for _, wh := range whs { 402 for _, double := range []bool{false, true} { 403 t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) { 404 benchmarkBMTAsync(t, size, wh, double) 405 }) 406 } 407 } 408 } 409 } 410 411 func BenchmarkPool(t *testing.B) { 412 caps := []int{1, PoolSize} 413 for size := 4096; size >= 128; size /= 2 { 414 for _, c := range caps { 415 t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) { 416 benchmarkPool(t, c, size) 417 }) 418 } 419 } 420 } 421 422 // 423 func benchmarkSHA3(t *testing.B, n int) { 424 data := newData(n) 425 hasher := sha3.NewKeccak256 426 h := hasher() 427 428 t.ReportAllocs() 429 t.ResetTimer() 430 for i := 0; i < t.N; i++ { 431 doSum(h, nil, data) 432 } 433 } 434 435 // 436 // 437 // 438 // 439 // 440 func benchmarkBMTBaseline(t *testing.B, n int) { 441 hasher := sha3.NewKeccak256 442 hashSize := hasher().Size() 443 data := newData(hashSize) 444 445 t.ReportAllocs() 446 t.ResetTimer() 447 for i := 0; i < t.N; i++ { 448 count := int32((n-1)/hashSize + 1) 449 wg := sync.WaitGroup{} 450 wg.Add(PoolSize) 451 var i int32 452 for j := 0; j < PoolSize; j++ { 453 go func() { 454 defer wg.Done() 455 h := hasher() 456 for atomic.AddInt32(&i, 1) < count { 457 doSum(h, nil, data) 458 } 459 }() 460 } 461 wg.Wait() 462 } 463 } 464 465 // 466 func benchmarkBMT(t *testing.B, n int) { 467 data := newData(n) 468 hasher := sha3.NewKeccak256 469 pool := NewTreePool(hasher, segmentCount, PoolSize) 470 bmt := New(pool) 471 472 t.ReportAllocs() 473 t.ResetTimer() 474 for i := 0; i < t.N; i++ { 475 syncHash(bmt, nil, data) 476 } 477 } 478 479 // 480 func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) { 481 data := newData(n) 482 hasher := sha3.NewKeccak256 483 pool := NewTreePool(hasher, segmentCount, PoolSize) 484 bmt := New(pool).NewAsyncWriter(double) 485 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 486 shuffle(len(idxs), func(i int, j int) { 487 idxs[i], idxs[j] = idxs[j], idxs[i] 488 }) 489 490 t.ReportAllocs() 491 t.ResetTimer() 492 for i := 0; i < t.N; i++ { 493 asyncHash(bmt, nil, n, wh, idxs, segments) 494 } 495 } 496 497 // 498 func benchmarkPool(t *testing.B, poolsize, n int) { 499 data := newData(n) 500 hasher := sha3.NewKeccak256 501 pool := NewTreePool(hasher, segmentCount, poolsize) 502 cycles := 100 503 504 t.ReportAllocs() 505 t.ResetTimer() 506 wg := sync.WaitGroup{} 507 for i := 0; i < t.N; i++ { 508 wg.Add(cycles) 509 for j := 0; j < cycles; j++ { 510 go func() { 511 defer wg.Done() 512 bmt := New(pool) 513 syncHash(bmt, nil, data) 514 }() 515 } 516 wg.Wait() 517 } 518 } 519 520 // 521 func benchmarkRefHasher(t *testing.B, n int) { 522 data := newData(n) 523 hasher := sha3.NewKeccak256 524 rbmt := NewRefHasher(hasher, 128) 525 526 t.ReportAllocs() 527 t.ResetTimer() 528 for i := 0; i < t.N; i++ { 529 rbmt.Hash(data) 530 } 531 } 532 533 func newData(bufferSize int) []byte { 534 data := make([]byte, bufferSize) 535 _, err := io.ReadFull(crand.Reader, data) 536 if err != nil { 537 panic(err.Error()) 538 } 539 return data 540 } 541 542 // 543 func syncHash(h *Hasher, span, data []byte) []byte { 544 h.ResetWithLength(span) 545 h.Write(data) 546 return h.Sum(nil) 547 } 548 549 func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) { 550 l := len(data) 551 n := l / secsize 552 if l%secsize > 0 { 553 n++ 554 } 555 for i := 0; i < n; i++ { 556 idxs = append(idxs, i) 557 end := (i + 1) * secsize 558 if end > l { 559 end = l 560 } 561 section := data[i*secsize : end] 562 segments = append(segments, section) 563 } 564 shuffle(n, func(i int, j int) { 565 idxs[i], idxs[j] = idxs[j], idxs[i] 566 }) 567 return idxs, segments 568 } 569 570 // 571 func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) { 572 idxs, segments := splitAndShuffle(bmt.SectionSize(), data) 573 return asyncHash(bmt, span, len(data), wh, idxs, segments) 574 } 575 576 // 577 // 578 // 579 // 580 func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) { 581 bmt.Reset() 582 if l == 0 { 583 return bmt.Sum(nil, l, span) 584 } 585 c := make(chan []byte, 1) 586 hashf := func() { 587 c <- bmt.Sum(nil, l, span) 588 } 589 maxsize := len(idxs) 590 var r int 591 if wh == random { 592 r = rand.Intn(maxsize) 593 } 594 for i, idx := range idxs { 595 bmt.Write(idx, segments[idx]) 596 if (wh == first || wh == random) && i == r { 597 go hashf() 598 } 599 } 600 if wh == last { 601 return bmt.Sum(nil, l, span) 602 } 603 return <-c 604 } 605 606 // 607 // 608 // 609 // 610 func shuffle(n int, swap func(i, j int)) { 611 if n < 0 { 612 panic("invalid argument to Shuffle") 613 } 614 615 // 616 // 617 // 618 // 619 // 620 // 621 i := n - 1 622 for ; i > 1<<31-1-1; i-- { 623 j := int(rand.Int63n(int64(i + 1))) 624 swap(i, j) 625 } 626 for ; i > 0; i-- { 627 j := int(rand.Int31n(int32(i + 1))) 628 swap(i, j) 629 } 630 }