github.com/yinchengtsinghua/golang-Eos-dpos-Ethereum@v0.0.0-20190121132951-92cc4225ed8e/swarm/bmt/bmt_test.go (about)

     1  
     2  //此源码被清华学神尹成大魔王专业翻译分析并修改
     3  //尹成QQ77025077
     4  //尹成微信18510341407
     5  //尹成所在QQ群721929980
     6  //尹成邮箱 yinc13@mails.tsinghua.edu.cn
     7  //尹成毕业于清华大学,微软区块链领域全球最有价值专家
     8  //https://mvp.microsoft.com/zh-cn/PublicProfile/4033620
     9  //
    10  //
    11  //
    12  //
    13  //
    14  //
    15  //
    16  //
    17  //
    18  //
    19  //
    20  //
    21  //
    22  //
    23  //
    24  
    25  package bmt
    26  
    27  import (
    28  	"bytes"
    29  	crand "crypto/rand"
    30  	"encoding/binary"
    31  	"fmt"
    32  	"io"
    33  	"math/rand"
    34  	"sync"
    35  	"sync/atomic"
    36  	"testing"
    37  	"time"
    38  
    39  	"github.com/ethereum/go-ethereum/crypto/sha3"
    40  )
    41  
    42  //
    43  const BufferSize = 4128
    44  
    45  const (
    46  //
    47  //
    48  //
    49  	segmentCount = 128
    50  )
    51  
    52  var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
    53  
    54  //
    55  func sha3hash(data ...[]byte) []byte {
    56  	h := sha3.NewKeccak256()
    57  	return doSum(h, nil, data...)
    58  }
    59  
    60  //
    61  //
    62  func TestRefHasher(t *testing.T) {
    63  //
    64  //
    65  	type test struct {
    66  		from     int
    67  		to       int
    68  		expected func([]byte) []byte
    69  	}
    70  
    71  	var tests []*test
    72  //
    73  //
    74  //
    75  //
    76  	tests = append(tests, &test{
    77  		from: 1,
    78  		to:   2,
    79  		expected: func(d []byte) []byte {
    80  			data := make([]byte, 64)
    81  			copy(data, d)
    82  			return sha3hash(data)
    83  		},
    84  	})
    85  
    86  //
    87  //
    88  //
    89  //
    90  //
    91  //
    92  //
    93  	tests = append(tests, &test{
    94  		from: 3,
    95  		to:   4,
    96  		expected: func(d []byte) []byte {
    97  			data := make([]byte, 128)
    98  			copy(data, d)
    99  			return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
   100  		},
   101  	})
   102  
   103  //
   104  //
   105  //
   106  //
   107  //
   108  //
   109  //
   110  //
   111  //
   112  //
   113  //
   114  //
   115  //
   116  	tests = append(tests, &test{
   117  		from: 5,
   118  		to:   8,
   119  		expected: func(d []byte) []byte {
   120  			data := make([]byte, 256)
   121  			copy(data, d)
   122  			return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
   123  		},
   124  	})
   125  
   126  //
   127  	for _, x := range tests {
   128  		for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
   129  			for length := 1; length <= segmentCount*32; length++ {
   130  				t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
   131  					data := make([]byte, length)
   132  					if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
   133  						t.Fatal(err)
   134  					}
   135  					expected := x.expected(data)
   136  					actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data)
   137  					if !bytes.Equal(actual, expected) {
   138  						t.Fatalf("expected %x, got %x", expected, actual)
   139  					}
   140  				})
   141  			}
   142  		}
   143  	}
   144  }
   145  
   146  //
   147  func TestHasherEmptyData(t *testing.T) {
   148  	hasher := sha3.NewKeccak256
   149  	var data []byte
   150  	for _, count := range counts {
   151  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   152  			pool := NewTreePool(hasher, count, PoolSize)
   153  			defer pool.Drain(0)
   154  			bmt := New(pool)
   155  			rbmt := NewRefHasher(hasher, count)
   156  			refHash := rbmt.Hash(data)
   157  			expHash := syncHash(bmt, nil, data)
   158  			if !bytes.Equal(expHash, refHash) {
   159  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   160  			}
   161  		})
   162  	}
   163  }
   164  
   165  //
   166  func TestSyncHasherCorrectness(t *testing.T) {
   167  	data := newData(BufferSize)
   168  	hasher := sha3.NewKeccak256
   169  	size := hasher().Size()
   170  
   171  	var err error
   172  	for _, count := range counts {
   173  		t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
   174  			max := count * size
   175  			var incr int
   176  			capacity := 1
   177  			pool := NewTreePool(hasher, count, capacity)
   178  			defer pool.Drain(0)
   179  			for n := 0; n <= max; n += incr {
   180  				incr = 1 + rand.Intn(5)
   181  				bmt := New(pool)
   182  				err = testHasherCorrectness(bmt, hasher, data, n, count)
   183  				if err != nil {
   184  					t.Fatal(err)
   185  				}
   186  			}
   187  		})
   188  	}
   189  }
   190  
   191  //
   192  func TestAsyncCorrectness(t *testing.T) {
   193  	data := newData(BufferSize)
   194  	hasher := sha3.NewKeccak256
   195  	size := hasher().Size()
   196  	whs := []whenHash{first, last, random}
   197  
   198  	for _, double := range []bool{false, true} {
   199  		for _, wh := range whs {
   200  			for _, count := range counts {
   201  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) {
   202  					max := count * size
   203  					var incr int
   204  					capacity := 1
   205  					pool := NewTreePool(hasher, count, capacity)
   206  					defer pool.Drain(0)
   207  					for n := 1; n <= max; n += incr {
   208  						incr = 1 + rand.Intn(5)
   209  						bmt := New(pool)
   210  						d := data[:n]
   211  						rbmt := NewRefHasher(hasher, count)
   212  						exp := rbmt.Hash(d)
   213  						got := syncHash(bmt, nil, d)
   214  						if !bytes.Equal(got, exp) {
   215  							t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got)
   216  						}
   217  						sw := bmt.NewAsyncWriter(double)
   218  						got = asyncHashRandom(sw, nil, d, wh)
   219  						if !bytes.Equal(got, exp) {
   220  							t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got)
   221  						}
   222  					}
   223  				})
   224  			}
   225  		}
   226  	}
   227  }
   228  
   229  //
   230  func TestHasherReuse(t *testing.T) {
   231  	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
   232  		testHasherReuse(1, t)
   233  	})
   234  	t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
   235  		testHasherReuse(PoolSize, t)
   236  	})
   237  }
   238  
   239  //
   240  func testHasherReuse(poolsize int, t *testing.T) {
   241  	hasher := sha3.NewKeccak256
   242  	pool := NewTreePool(hasher, segmentCount, poolsize)
   243  	defer pool.Drain(0)
   244  	bmt := New(pool)
   245  
   246  	for i := 0; i < 100; i++ {
   247  		data := newData(BufferSize)
   248  		n := rand.Intn(bmt.Size())
   249  		err := testHasherCorrectness(bmt, hasher, data, n, segmentCount)
   250  		if err != nil {
   251  			t.Fatal(err)
   252  		}
   253  	}
   254  }
   255  
   256  //
   257  func TestBMTConcurrentUse(t *testing.T) {
   258  	hasher := sha3.NewKeccak256
   259  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   260  	defer pool.Drain(0)
   261  	cycles := 100
   262  	errc := make(chan error)
   263  
   264  	for i := 0; i < cycles; i++ {
   265  		go func() {
   266  			bmt := New(pool)
   267  			data := newData(BufferSize)
   268  			n := rand.Intn(bmt.Size())
   269  			errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
   270  		}()
   271  	}
   272  LOOP:
   273  	for {
   274  		select {
   275  		case <-time.NewTimer(5 * time.Second).C:
   276  			t.Fatal("timed out")
   277  		case err := <-errc:
   278  			if err != nil {
   279  				t.Fatal(err)
   280  			}
   281  			cycles--
   282  			if cycles == 0 {
   283  				break LOOP
   284  			}
   285  		}
   286  	}
   287  }
   288  
   289  //
   290  //
   291  func TestBMTWriterBuffers(t *testing.T) {
   292  	hasher := sha3.NewKeccak256
   293  
   294  	for _, count := range counts {
   295  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   296  			errc := make(chan error)
   297  			pool := NewTreePool(hasher, count, PoolSize)
   298  			defer pool.Drain(0)
   299  			n := count * 32
   300  			bmt := New(pool)
   301  			data := newData(n)
   302  			rbmt := NewRefHasher(hasher, count)
   303  			refHash := rbmt.Hash(data)
   304  			expHash := syncHash(bmt, nil, data)
   305  			if !bytes.Equal(expHash, refHash) {
   306  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   307  			}
   308  			attempts := 10
   309  			f := func() error {
   310  				bmt := New(pool)
   311  				bmt.Reset()
   312  				var buflen int
   313  				for offset := 0; offset < n; offset += buflen {
   314  					buflen = rand.Intn(n-offset) + 1
   315  					read, err := bmt.Write(data[offset : offset+buflen])
   316  					if err != nil {
   317  						return err
   318  					}
   319  					if read != buflen {
   320  						return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
   321  					}
   322  				}
   323  				hash := bmt.Sum(nil)
   324  				if !bytes.Equal(hash, expHash) {
   325  					return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
   326  				}
   327  				return nil
   328  			}
   329  
   330  			for j := 0; j < attempts; j++ {
   331  				go func() {
   332  					errc <- f()
   333  				}()
   334  			}
   335  			timeout := time.NewTimer(2 * time.Second)
   336  			for {
   337  				select {
   338  				case err := <-errc:
   339  					if err != nil {
   340  						t.Fatal(err)
   341  					}
   342  					attempts--
   343  					if attempts == 0 {
   344  						return
   345  					}
   346  				case <-timeout.C:
   347  					t.Fatalf("timeout")
   348  				}
   349  			}
   350  		})
   351  	}
   352  }
   353  
   354  //
   355  //
   356  func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
   357  	span := make([]byte, 8)
   358  	if len(d) < n {
   359  		n = len(d)
   360  	}
   361  	binary.BigEndian.PutUint64(span, uint64(n))
   362  	data := d[:n]
   363  	rbmt := NewRefHasher(hasher, count)
   364  	exp := sha3hash(span, rbmt.Hash(data))
   365  	got := syncHash(bmt, span, data)
   366  	if !bytes.Equal(got, exp) {
   367  		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   368  	}
   369  	return err
   370  }
   371  
   372  //
   373  func BenchmarkBMT(t *testing.B) {
   374  	for size := 4096; size >= 128; size /= 2 {
   375  		t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) {
   376  			benchmarkSHA3(t, size)
   377  		})
   378  		t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) {
   379  			benchmarkBMTBaseline(t, size)
   380  		})
   381  		t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) {
   382  			benchmarkRefHasher(t, size)
   383  		})
   384  		t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) {
   385  			benchmarkBMT(t, size)
   386  		})
   387  	}
   388  }
   389  
   390  type whenHash = int
   391  
   392  const (
   393  	first whenHash = iota
   394  	last
   395  	random
   396  )
   397  
   398  func BenchmarkBMTAsync(t *testing.B) {
   399  	whs := []whenHash{first, last, random}
   400  	for size := 4096; size >= 128; size /= 2 {
   401  		for _, wh := range whs {
   402  			for _, double := range []bool{false, true} {
   403  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) {
   404  					benchmarkBMTAsync(t, size, wh, double)
   405  				})
   406  			}
   407  		}
   408  	}
   409  }
   410  
   411  func BenchmarkPool(t *testing.B) {
   412  	caps := []int{1, PoolSize}
   413  	for size := 4096; size >= 128; size /= 2 {
   414  		for _, c := range caps {
   415  			t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) {
   416  				benchmarkPool(t, c, size)
   417  			})
   418  		}
   419  	}
   420  }
   421  
   422  //
   423  func benchmarkSHA3(t *testing.B, n int) {
   424  	data := newData(n)
   425  	hasher := sha3.NewKeccak256
   426  	h := hasher()
   427  
   428  	t.ReportAllocs()
   429  	t.ResetTimer()
   430  	for i := 0; i < t.N; i++ {
   431  		doSum(h, nil, data)
   432  	}
   433  }
   434  
   435  //
   436  //
   437  //
   438  //
   439  //
   440  func benchmarkBMTBaseline(t *testing.B, n int) {
   441  	hasher := sha3.NewKeccak256
   442  	hashSize := hasher().Size()
   443  	data := newData(hashSize)
   444  
   445  	t.ReportAllocs()
   446  	t.ResetTimer()
   447  	for i := 0; i < t.N; i++ {
   448  		count := int32((n-1)/hashSize + 1)
   449  		wg := sync.WaitGroup{}
   450  		wg.Add(PoolSize)
   451  		var i int32
   452  		for j := 0; j < PoolSize; j++ {
   453  			go func() {
   454  				defer wg.Done()
   455  				h := hasher()
   456  				for atomic.AddInt32(&i, 1) < count {
   457  					doSum(h, nil, data)
   458  				}
   459  			}()
   460  		}
   461  		wg.Wait()
   462  	}
   463  }
   464  
   465  //
   466  func benchmarkBMT(t *testing.B, n int) {
   467  	data := newData(n)
   468  	hasher := sha3.NewKeccak256
   469  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   470  	bmt := New(pool)
   471  
   472  	t.ReportAllocs()
   473  	t.ResetTimer()
   474  	for i := 0; i < t.N; i++ {
   475  		syncHash(bmt, nil, data)
   476  	}
   477  }
   478  
   479  //
   480  func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) {
   481  	data := newData(n)
   482  	hasher := sha3.NewKeccak256
   483  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   484  	bmt := New(pool).NewAsyncWriter(double)
   485  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   486  	shuffle(len(idxs), func(i int, j int) {
   487  		idxs[i], idxs[j] = idxs[j], idxs[i]
   488  	})
   489  
   490  	t.ReportAllocs()
   491  	t.ResetTimer()
   492  	for i := 0; i < t.N; i++ {
   493  		asyncHash(bmt, nil, n, wh, idxs, segments)
   494  	}
   495  }
   496  
   497  //
   498  func benchmarkPool(t *testing.B, poolsize, n int) {
   499  	data := newData(n)
   500  	hasher := sha3.NewKeccak256
   501  	pool := NewTreePool(hasher, segmentCount, poolsize)
   502  	cycles := 100
   503  
   504  	t.ReportAllocs()
   505  	t.ResetTimer()
   506  	wg := sync.WaitGroup{}
   507  	for i := 0; i < t.N; i++ {
   508  		wg.Add(cycles)
   509  		for j := 0; j < cycles; j++ {
   510  			go func() {
   511  				defer wg.Done()
   512  				bmt := New(pool)
   513  				syncHash(bmt, nil, data)
   514  			}()
   515  		}
   516  		wg.Wait()
   517  	}
   518  }
   519  
   520  //
   521  func benchmarkRefHasher(t *testing.B, n int) {
   522  	data := newData(n)
   523  	hasher := sha3.NewKeccak256
   524  	rbmt := NewRefHasher(hasher, 128)
   525  
   526  	t.ReportAllocs()
   527  	t.ResetTimer()
   528  	for i := 0; i < t.N; i++ {
   529  		rbmt.Hash(data)
   530  	}
   531  }
   532  
   533  func newData(bufferSize int) []byte {
   534  	data := make([]byte, bufferSize)
   535  	_, err := io.ReadFull(crand.Reader, data)
   536  	if err != nil {
   537  		panic(err.Error())
   538  	}
   539  	return data
   540  }
   541  
   542  //
   543  func syncHash(h *Hasher, span, data []byte) []byte {
   544  	h.ResetWithLength(span)
   545  	h.Write(data)
   546  	return h.Sum(nil)
   547  }
   548  
   549  func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) {
   550  	l := len(data)
   551  	n := l / secsize
   552  	if l%secsize > 0 {
   553  		n++
   554  	}
   555  	for i := 0; i < n; i++ {
   556  		idxs = append(idxs, i)
   557  		end := (i + 1) * secsize
   558  		if end > l {
   559  			end = l
   560  		}
   561  		section := data[i*secsize : end]
   562  		segments = append(segments, section)
   563  	}
   564  	shuffle(n, func(i int, j int) {
   565  		idxs[i], idxs[j] = idxs[j], idxs[i]
   566  	})
   567  	return idxs, segments
   568  }
   569  
   570  //
   571  func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) {
   572  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   573  	return asyncHash(bmt, span, len(data), wh, idxs, segments)
   574  }
   575  
   576  //
   577  //
   578  //
   579  //
   580  func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) {
   581  	bmt.Reset()
   582  	if l == 0 {
   583  		return bmt.Sum(nil, l, span)
   584  	}
   585  	c := make(chan []byte, 1)
   586  	hashf := func() {
   587  		c <- bmt.Sum(nil, l, span)
   588  	}
   589  	maxsize := len(idxs)
   590  	var r int
   591  	if wh == random {
   592  		r = rand.Intn(maxsize)
   593  	}
   594  	for i, idx := range idxs {
   595  		bmt.Write(idx, segments[idx])
   596  		if (wh == first || wh == random) && i == r {
   597  			go hashf()
   598  		}
   599  	}
   600  	if wh == last {
   601  		return bmt.Sum(nil, l, span)
   602  	}
   603  	return <-c
   604  }
   605  
   606  //
   607  //
   608  //
   609  //
   610  func shuffle(n int, swap func(i, j int)) {
   611  	if n < 0 {
   612  		panic("invalid argument to Shuffle")
   613  	}
   614  
   615  //
   616  //
   617  //
   618  //
   619  //
   620  //
   621  	i := n - 1
   622  	for ; i > 1<<31-1-1; i-- {
   623  		j := int(rand.Int63n(int64(i + 1)))
   624  		swap(i, j)
   625  	}
   626  	for ; i > 0; i-- {
   627  		j := int(rand.Int31n(int32(i + 1)))
   628  		swap(i, j)
   629  	}
   630  }