github.com/linapex/ethereum-dpos-chinese@v0.0.0-20190316121959-b78b3a4a1ece/swarm/bmt/bmt_test.go (about)

     1  
     2  //<developer>
     3  //    <name>linapex 曹一峰</name>
     4  //    <email>linapex@163.com</email>
     5  //    <wx>superexc</wx>
     6  //    <qqgroup>128148617</qqgroup>
     7  //    <url>https://jsq.ink</url>
     8  //    <role>pku engineer</role>
     9  //    <date>2019-03-16 12:09:47</date>
    10  //</624342670646448128>
    11  
    12  //
    13  //
    14  //
    15  //
    16  //
    17  //
    18  //
    19  //
    20  //
    21  //
    22  //
    23  //
    24  //
    25  //
    26  //
    27  
    28  package bmt
    29  
    30  import (
    31  	"bytes"
    32  	crand "crypto/rand"
    33  	"encoding/binary"
    34  	"fmt"
    35  	"io"
    36  	"math/rand"
    37  	"sync"
    38  	"sync/atomic"
    39  	"testing"
    40  	"time"
    41  
    42  	"github.com/ethereum/go-ethereum/crypto/sha3"
    43  )
    44  
    45  //
    46  const BufferSize = 4128
    47  
    48  const (
    49  //
    50  //
    51  //
    52  	segmentCount = 128
    53  )
    54  
    55  var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
    56  
    57  //
    58  func sha3hash(data ...[]byte) []byte {
    59  	h := sha3.NewKeccak256()
    60  	return doSum(h, nil, data...)
    61  }
    62  
    63  //
    64  //
    65  func TestRefHasher(t *testing.T) {
    66  //
    67  //
    68  	type test struct {
    69  		from     int
    70  		to       int
    71  		expected func([]byte) []byte
    72  	}
    73  
    74  	var tests []*test
    75  //
    76  //
    77  //
    78  //
    79  	tests = append(tests, &test{
    80  		from: 1,
    81  		to:   2,
    82  		expected: func(d []byte) []byte {
    83  			data := make([]byte, 64)
    84  			copy(data, d)
    85  			return sha3hash(data)
    86  		},
    87  	})
    88  
    89  //
    90  //
    91  //
    92  //
    93  //
    94  //
    95  //
    96  	tests = append(tests, &test{
    97  		from: 3,
    98  		to:   4,
    99  		expected: func(d []byte) []byte {
   100  			data := make([]byte, 128)
   101  			copy(data, d)
   102  			return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
   103  		},
   104  	})
   105  
   106  //
   107  //
   108  //
   109  //
   110  //
   111  //
   112  //
   113  //
   114  //
   115  //
   116  //
   117  //
   118  //
   119  	tests = append(tests, &test{
   120  		from: 5,
   121  		to:   8,
   122  		expected: func(d []byte) []byte {
   123  			data := make([]byte, 256)
   124  			copy(data, d)
   125  			return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
   126  		},
   127  	})
   128  
   129  //
   130  	for _, x := range tests {
   131  		for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
   132  			for length := 1; length <= segmentCount*32; length++ {
   133  				t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
   134  					data := make([]byte, length)
   135  					if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
   136  						t.Fatal(err)
   137  					}
   138  					expected := x.expected(data)
   139  					actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data)
   140  					if !bytes.Equal(actual, expected) {
   141  						t.Fatalf("expected %x, got %x", expected, actual)
   142  					}
   143  				})
   144  			}
   145  		}
   146  	}
   147  }
   148  
   149  //
   150  func TestHasherEmptyData(t *testing.T) {
   151  	hasher := sha3.NewKeccak256
   152  	var data []byte
   153  	for _, count := range counts {
   154  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   155  			pool := NewTreePool(hasher, count, PoolSize)
   156  			defer pool.Drain(0)
   157  			bmt := New(pool)
   158  			rbmt := NewRefHasher(hasher, count)
   159  			refHash := rbmt.Hash(data)
   160  			expHash := syncHash(bmt, nil, data)
   161  			if !bytes.Equal(expHash, refHash) {
   162  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   163  			}
   164  		})
   165  	}
   166  }
   167  
   168  //
   169  func TestSyncHasherCorrectness(t *testing.T) {
   170  	data := newData(BufferSize)
   171  	hasher := sha3.NewKeccak256
   172  	size := hasher().Size()
   173  
   174  	var err error
   175  	for _, count := range counts {
   176  		t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
   177  			max := count * size
   178  			var incr int
   179  			capacity := 1
   180  			pool := NewTreePool(hasher, count, capacity)
   181  			defer pool.Drain(0)
   182  			for n := 0; n <= max; n += incr {
   183  				incr = 1 + rand.Intn(5)
   184  				bmt := New(pool)
   185  				err = testHasherCorrectness(bmt, hasher, data, n, count)
   186  				if err != nil {
   187  					t.Fatal(err)
   188  				}
   189  			}
   190  		})
   191  	}
   192  }
   193  
   194  //
   195  func TestAsyncCorrectness(t *testing.T) {
   196  	data := newData(BufferSize)
   197  	hasher := sha3.NewKeccak256
   198  	size := hasher().Size()
   199  	whs := []whenHash{first, last, random}
   200  
   201  	for _, double := range []bool{false, true} {
   202  		for _, wh := range whs {
   203  			for _, count := range counts {
   204  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) {
   205  					max := count * size
   206  					var incr int
   207  					capacity := 1
   208  					pool := NewTreePool(hasher, count, capacity)
   209  					defer pool.Drain(0)
   210  					for n := 1; n <= max; n += incr {
   211  						incr = 1 + rand.Intn(5)
   212  						bmt := New(pool)
   213  						d := data[:n]
   214  						rbmt := NewRefHasher(hasher, count)
   215  						exp := rbmt.Hash(d)
   216  						got := syncHash(bmt, nil, d)
   217  						if !bytes.Equal(got, exp) {
   218  							t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got)
   219  						}
   220  						sw := bmt.NewAsyncWriter(double)
   221  						got = asyncHashRandom(sw, nil, d, wh)
   222  						if !bytes.Equal(got, exp) {
   223  							t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got)
   224  						}
   225  					}
   226  				})
   227  			}
   228  		}
   229  	}
   230  }
   231  
   232  //
   233  func TestHasherReuse(t *testing.T) {
   234  	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
   235  		testHasherReuse(1, t)
   236  	})
   237  	t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
   238  		testHasherReuse(PoolSize, t)
   239  	})
   240  }
   241  
   242  //
   243  func testHasherReuse(poolsize int, t *testing.T) {
   244  	hasher := sha3.NewKeccak256
   245  	pool := NewTreePool(hasher, segmentCount, poolsize)
   246  	defer pool.Drain(0)
   247  	bmt := New(pool)
   248  
   249  	for i := 0; i < 100; i++ {
   250  		data := newData(BufferSize)
   251  		n := rand.Intn(bmt.Size())
   252  		err := testHasherCorrectness(bmt, hasher, data, n, segmentCount)
   253  		if err != nil {
   254  			t.Fatal(err)
   255  		}
   256  	}
   257  }
   258  
   259  //
   260  func TestBMTConcurrentUse(t *testing.T) {
   261  	hasher := sha3.NewKeccak256
   262  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   263  	defer pool.Drain(0)
   264  	cycles := 100
   265  	errc := make(chan error)
   266  
   267  	for i := 0; i < cycles; i++ {
   268  		go func() {
   269  			bmt := New(pool)
   270  			data := newData(BufferSize)
   271  			n := rand.Intn(bmt.Size())
   272  			errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
   273  		}()
   274  	}
   275  LOOP:
   276  	for {
   277  		select {
   278  		case <-time.NewTimer(5 * time.Second).C:
   279  			t.Fatal("timed out")
   280  		case err := <-errc:
   281  			if err != nil {
   282  				t.Fatal(err)
   283  			}
   284  			cycles--
   285  			if cycles == 0 {
   286  				break LOOP
   287  			}
   288  		}
   289  	}
   290  }
   291  
   292  //
   293  //
   294  func TestBMTWriterBuffers(t *testing.T) {
   295  	hasher := sha3.NewKeccak256
   296  
   297  	for _, count := range counts {
   298  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   299  			errc := make(chan error)
   300  			pool := NewTreePool(hasher, count, PoolSize)
   301  			defer pool.Drain(0)
   302  			n := count * 32
   303  			bmt := New(pool)
   304  			data := newData(n)
   305  			rbmt := NewRefHasher(hasher, count)
   306  			refHash := rbmt.Hash(data)
   307  			expHash := syncHash(bmt, nil, data)
   308  			if !bytes.Equal(expHash, refHash) {
   309  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   310  			}
   311  			attempts := 10
   312  			f := func() error {
   313  				bmt := New(pool)
   314  				bmt.Reset()
   315  				var buflen int
   316  				for offset := 0; offset < n; offset += buflen {
   317  					buflen = rand.Intn(n-offset) + 1
   318  					read, err := bmt.Write(data[offset : offset+buflen])
   319  					if err != nil {
   320  						return err
   321  					}
   322  					if read != buflen {
   323  						return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
   324  					}
   325  				}
   326  				hash := bmt.Sum(nil)
   327  				if !bytes.Equal(hash, expHash) {
   328  					return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
   329  				}
   330  				return nil
   331  			}
   332  
   333  			for j := 0; j < attempts; j++ {
   334  				go func() {
   335  					errc <- f()
   336  				}()
   337  			}
   338  			timeout := time.NewTimer(2 * time.Second)
   339  			for {
   340  				select {
   341  				case err := <-errc:
   342  					if err != nil {
   343  						t.Fatal(err)
   344  					}
   345  					attempts--
   346  					if attempts == 0 {
   347  						return
   348  					}
   349  				case <-timeout.C:
   350  					t.Fatalf("timeout")
   351  				}
   352  			}
   353  		})
   354  	}
   355  }
   356  
   357  //
   358  //
   359  func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
   360  	span := make([]byte, 8)
   361  	if len(d) < n {
   362  		n = len(d)
   363  	}
   364  	binary.BigEndian.PutUint64(span, uint64(n))
   365  	data := d[:n]
   366  	rbmt := NewRefHasher(hasher, count)
   367  	exp := sha3hash(span, rbmt.Hash(data))
   368  	got := syncHash(bmt, span, data)
   369  	if !bytes.Equal(got, exp) {
   370  		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   371  	}
   372  	return err
   373  }
   374  
   375  //
   376  func BenchmarkBMT(t *testing.B) {
   377  	for size := 4096; size >= 128; size /= 2 {
   378  		t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) {
   379  			benchmarkSHA3(t, size)
   380  		})
   381  		t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) {
   382  			benchmarkBMTBaseline(t, size)
   383  		})
   384  		t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) {
   385  			benchmarkRefHasher(t, size)
   386  		})
   387  		t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) {
   388  			benchmarkBMT(t, size)
   389  		})
   390  	}
   391  }
   392  
   393  type whenHash = int
   394  
   395  const (
   396  	first whenHash = iota
   397  	last
   398  	random
   399  )
   400  
   401  func BenchmarkBMTAsync(t *testing.B) {
   402  	whs := []whenHash{first, last, random}
   403  	for size := 4096; size >= 128; size /= 2 {
   404  		for _, wh := range whs {
   405  			for _, double := range []bool{false, true} {
   406  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) {
   407  					benchmarkBMTAsync(t, size, wh, double)
   408  				})
   409  			}
   410  		}
   411  	}
   412  }
   413  
   414  func BenchmarkPool(t *testing.B) {
   415  	caps := []int{1, PoolSize}
   416  	for size := 4096; size >= 128; size /= 2 {
   417  		for _, c := range caps {
   418  			t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) {
   419  				benchmarkPool(t, c, size)
   420  			})
   421  		}
   422  	}
   423  }
   424  
   425  //
   426  func benchmarkSHA3(t *testing.B, n int) {
   427  	data := newData(n)
   428  	hasher := sha3.NewKeccak256
   429  	h := hasher()
   430  
   431  	t.ReportAllocs()
   432  	t.ResetTimer()
   433  	for i := 0; i < t.N; i++ {
   434  		doSum(h, nil, data)
   435  	}
   436  }
   437  
   438  //
   439  //
   440  //
   441  //
   442  //
   443  func benchmarkBMTBaseline(t *testing.B, n int) {
   444  	hasher := sha3.NewKeccak256
   445  	hashSize := hasher().Size()
   446  	data := newData(hashSize)
   447  
   448  	t.ReportAllocs()
   449  	t.ResetTimer()
   450  	for i := 0; i < t.N; i++ {
   451  		count := int32((n-1)/hashSize + 1)
   452  		wg := sync.WaitGroup{}
   453  		wg.Add(PoolSize)
   454  		var i int32
   455  		for j := 0; j < PoolSize; j++ {
   456  			go func() {
   457  				defer wg.Done()
   458  				h := hasher()
   459  				for atomic.AddInt32(&i, 1) < count {
   460  					doSum(h, nil, data)
   461  				}
   462  			}()
   463  		}
   464  		wg.Wait()
   465  	}
   466  }
   467  
   468  //
   469  func benchmarkBMT(t *testing.B, n int) {
   470  	data := newData(n)
   471  	hasher := sha3.NewKeccak256
   472  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   473  	bmt := New(pool)
   474  
   475  	t.ReportAllocs()
   476  	t.ResetTimer()
   477  	for i := 0; i < t.N; i++ {
   478  		syncHash(bmt, nil, data)
   479  	}
   480  }
   481  
   482  //
   483  func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) {
   484  	data := newData(n)
   485  	hasher := sha3.NewKeccak256
   486  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   487  	bmt := New(pool).NewAsyncWriter(double)
   488  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   489  	shuffle(len(idxs), func(i int, j int) {
   490  		idxs[i], idxs[j] = idxs[j], idxs[i]
   491  	})
   492  
   493  	t.ReportAllocs()
   494  	t.ResetTimer()
   495  	for i := 0; i < t.N; i++ {
   496  		asyncHash(bmt, nil, n, wh, idxs, segments)
   497  	}
   498  }
   499  
   500  //
   501  func benchmarkPool(t *testing.B, poolsize, n int) {
   502  	data := newData(n)
   503  	hasher := sha3.NewKeccak256
   504  	pool := NewTreePool(hasher, segmentCount, poolsize)
   505  	cycles := 100
   506  
   507  	t.ReportAllocs()
   508  	t.ResetTimer()
   509  	wg := sync.WaitGroup{}
   510  	for i := 0; i < t.N; i++ {
   511  		wg.Add(cycles)
   512  		for j := 0; j < cycles; j++ {
   513  			go func() {
   514  				defer wg.Done()
   515  				bmt := New(pool)
   516  				syncHash(bmt, nil, data)
   517  			}()
   518  		}
   519  		wg.Wait()
   520  	}
   521  }
   522  
   523  //
   524  func benchmarkRefHasher(t *testing.B, n int) {
   525  	data := newData(n)
   526  	hasher := sha3.NewKeccak256
   527  	rbmt := NewRefHasher(hasher, 128)
   528  
   529  	t.ReportAllocs()
   530  	t.ResetTimer()
   531  	for i := 0; i < t.N; i++ {
   532  		rbmt.Hash(data)
   533  	}
   534  }
   535  
   536  func newData(bufferSize int) []byte {
   537  	data := make([]byte, bufferSize)
   538  	_, err := io.ReadFull(crand.Reader, data)
   539  	if err != nil {
   540  		panic(err.Error())
   541  	}
   542  	return data
   543  }
   544  
   545  //
   546  func syncHash(h *Hasher, span, data []byte) []byte {
   547  	h.ResetWithLength(span)
   548  	h.Write(data)
   549  	return h.Sum(nil)
   550  }
   551  
   552  func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) {
   553  	l := len(data)
   554  	n := l / secsize
   555  	if l%secsize > 0 {
   556  		n++
   557  	}
   558  	for i := 0; i < n; i++ {
   559  		idxs = append(idxs, i)
   560  		end := (i + 1) * secsize
   561  		if end > l {
   562  			end = l
   563  		}
   564  		section := data[i*secsize : end]
   565  		segments = append(segments, section)
   566  	}
   567  	shuffle(n, func(i int, j int) {
   568  		idxs[i], idxs[j] = idxs[j], idxs[i]
   569  	})
   570  	return idxs, segments
   571  }
   572  
   573  //
   574  func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) {
   575  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   576  	return asyncHash(bmt, span, len(data), wh, idxs, segments)
   577  }
   578  
   579  //
   580  //
   581  //
   582  //
   583  func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) {
   584  	bmt.Reset()
   585  	if l == 0 {
   586  		return bmt.Sum(nil, l, span)
   587  	}
   588  	c := make(chan []byte, 1)
   589  	hashf := func() {
   590  		c <- bmt.Sum(nil, l, span)
   591  	}
   592  	maxsize := len(idxs)
   593  	var r int
   594  	if wh == random {
   595  		r = rand.Intn(maxsize)
   596  	}
   597  	for i, idx := range idxs {
   598  		bmt.Write(idx, segments[idx])
   599  		if (wh == first || wh == random) && i == r {
   600  			go hashf()
   601  		}
   602  	}
   603  	if wh == last {
   604  		return bmt.Sum(nil, l, span)
   605  	}
   606  	return <-c
   607  }
   608  
   609  //
   610  //
   611  //
   612  //
   613  func shuffle(n int, swap func(i, j int)) {
   614  	if n < 0 {
   615  		panic("invalid argument to Shuffle")
   616  	}
   617  
   618  //
   619  //
   620  //
   621  //
   622  //
   623  //
   624  	i := n - 1
   625  	for ; i > 1<<31-1-1; i-- {
   626  		j := int(rand.Int63n(int64(i + 1)))
   627  		swap(i, j)
   628  	}
   629  	for ; i > 0; i-- {
   630  		j := int(rand.Int31n(int32(i + 1)))
   631  		swap(i, j)
   632  	}
   633  }
   634