github.com/codingfuture/orig-energi3@v0.8.4/swarm/bmt/bmt_test.go (about)

     1  // Copyright 2018 The Energi Core Authors
     2  // Copyright 2018 The go-ethereum Authors
     3  // This file is part of the Energi Core library.
     4  //
     5  // The Energi Core library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The Energi Core library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the Energi Core library. If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package bmt
    19  
    20  import (
    21  	"bytes"
    22  	"encoding/binary"
    23  	"fmt"
    24  	"math/rand"
    25  	"sync"
    26  	"sync/atomic"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/ethereum/go-ethereum/swarm/testutil"
    31  	"golang.org/x/crypto/sha3"
    32  )
    33  
    34  // the actual data length generated (could be longer than max datalength of the BMT)
    35  const BufferSize = 4128
    36  
    37  const (
    38  	// segmentCount is the maximum number of segments of the underlying chunk
    39  	// Should be equal to max-chunk-data-size / hash-size
    40  	// Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size)
    41  	segmentCount = 128
    42  )
    43  
    44  var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
    45  
    46  // calculates the Keccak256 SHA3 hash of the data
    47  func sha3hash(data ...[]byte) []byte {
    48  	h := sha3.NewLegacyKeccak256()
    49  	return doSum(h, nil, data...)
    50  }
    51  
    52  // TestRefHasher tests that the RefHasher computes the expected BMT hash for
    53  // some small data lengths
    54  func TestRefHasher(t *testing.T) {
    55  	// the test struct is used to specify the expected BMT hash for
    56  	// segment counts between from and to and lengths from 1 to datalength
    57  	type test struct {
    58  		from     int
    59  		to       int
    60  		expected func([]byte) []byte
    61  	}
    62  
    63  	var tests []*test
    64  	// all lengths in [0,64] should be:
    65  	//
    66  	//   sha3hash(data)
    67  	//
    68  	tests = append(tests, &test{
    69  		from: 1,
    70  		to:   2,
    71  		expected: func(d []byte) []byte {
    72  			data := make([]byte, 64)
    73  			copy(data, d)
    74  			return sha3hash(data)
    75  		},
    76  	})
    77  
    78  	// all lengths in [3,4] should be:
    79  	//
    80  	//   sha3hash(
    81  	//     sha3hash(data[:64])
    82  	//     sha3hash(data[64:])
    83  	//   )
    84  	//
    85  	tests = append(tests, &test{
    86  		from: 3,
    87  		to:   4,
    88  		expected: func(d []byte) []byte {
    89  			data := make([]byte, 128)
    90  			copy(data, d)
    91  			return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
    92  		},
    93  	})
    94  
    95  	// all segmentCounts in [5,8] should be:
    96  	//
    97  	//   sha3hash(
    98  	//     sha3hash(
    99  	//       sha3hash(data[:64])
   100  	//       sha3hash(data[64:128])
   101  	//     )
   102  	//     sha3hash(
   103  	//       sha3hash(data[128:192])
   104  	//       sha3hash(data[192:])
   105  	//     )
   106  	//   )
   107  	//
   108  	tests = append(tests, &test{
   109  		from: 5,
   110  		to:   8,
   111  		expected: func(d []byte) []byte {
   112  			data := make([]byte, 256)
   113  			copy(data, d)
   114  			return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
   115  		},
   116  	})
   117  
   118  	// run the tests
   119  	for i, x := range tests {
   120  		for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
   121  			for length := 1; length <= segmentCount*32; length++ {
   122  				t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
   123  					data := testutil.RandomBytes(i, length)
   124  					expected := x.expected(data)
   125  					actual := NewRefHasher(sha3.NewLegacyKeccak256, segmentCount).Hash(data)
   126  					if !bytes.Equal(actual, expected) {
   127  						t.Fatalf("expected %x, got %x", expected, actual)
   128  					}
   129  				})
   130  			}
   131  		}
   132  	}
   133  }
   134  
   135  // tests if hasher responds with correct hash comparing the reference implementation return value
   136  func TestHasherEmptyData(t *testing.T) {
   137  	hasher := sha3.NewLegacyKeccak256
   138  	var data []byte
   139  	for _, count := range counts {
   140  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   141  			pool := NewTreePool(hasher, count, PoolSize)
   142  			defer pool.Drain(0)
   143  			bmt := New(pool)
   144  			rbmt := NewRefHasher(hasher, count)
   145  			refHash := rbmt.Hash(data)
   146  			expHash := syncHash(bmt, nil, data)
   147  			if !bytes.Equal(expHash, refHash) {
   148  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   149  			}
   150  		})
   151  	}
   152  }
   153  
   154  // tests sequential write with entire max size written in one go
   155  func TestSyncHasherCorrectness(t *testing.T) {
   156  	data := testutil.RandomBytes(1, BufferSize)
   157  	hasher := sha3.NewLegacyKeccak256
   158  	size := hasher().Size()
   159  
   160  	var err error
   161  	for _, count := range counts {
   162  		t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
   163  			max := count * size
   164  			var incr int
   165  			capacity := 1
   166  			pool := NewTreePool(hasher, count, capacity)
   167  			defer pool.Drain(0)
   168  			for n := 0; n <= max; n += incr {
   169  				incr = 1 + rand.Intn(5)
   170  				bmt := New(pool)
   171  				err = testHasherCorrectness(bmt, hasher, data, n, count)
   172  				if err != nil {
   173  					t.Fatal(err)
   174  				}
   175  			}
   176  		})
   177  	}
   178  }
   179  
   180  // tests order-neutral concurrent writes with entire max size written in one go
   181  func TestAsyncCorrectness(t *testing.T) {
   182  	data := testutil.RandomBytes(1, BufferSize)
   183  	hasher := sha3.NewLegacyKeccak256
   184  	size := hasher().Size()
   185  	whs := []whenHash{first, last, random}
   186  
   187  	for _, double := range []bool{false, true} {
   188  		for _, wh := range whs {
   189  			for _, count := range counts {
   190  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) {
   191  					max := count * size
   192  					var incr int
   193  					capacity := 1
   194  					pool := NewTreePool(hasher, count, capacity)
   195  					defer pool.Drain(0)
   196  					for n := 1; n <= max; n += incr {
   197  						incr = 1 + rand.Intn(5)
   198  						bmt := New(pool)
   199  						d := data[:n]
   200  						rbmt := NewRefHasher(hasher, count)
   201  						exp := rbmt.Hash(d)
   202  						got := syncHash(bmt, nil, d)
   203  						if !bytes.Equal(got, exp) {
   204  							t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got)
   205  						}
   206  						sw := bmt.NewAsyncWriter(double)
   207  						got = asyncHashRandom(sw, nil, d, wh)
   208  						if !bytes.Equal(got, exp) {
   209  							t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got)
   210  						}
   211  					}
   212  				})
   213  			}
   214  		}
   215  	}
   216  }
   217  
   218  // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize
   219  func TestHasherReuse(t *testing.T) {
   220  	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
   221  		testHasherReuse(1, t)
   222  	})
   223  	t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
   224  		testHasherReuse(PoolSize, t)
   225  	})
   226  }
   227  
   228  // tests if bmt reuse is not corrupting result
   229  func testHasherReuse(poolsize int, t *testing.T) {
   230  	hasher := sha3.NewLegacyKeccak256
   231  	pool := NewTreePool(hasher, segmentCount, poolsize)
   232  	defer pool.Drain(0)
   233  	bmt := New(pool)
   234  
   235  	for i := 0; i < 100; i++ {
   236  		data := testutil.RandomBytes(1, BufferSize)
   237  		n := rand.Intn(bmt.Size())
   238  		err := testHasherCorrectness(bmt, hasher, data, n, segmentCount)
   239  		if err != nil {
   240  			t.Fatal(err)
   241  		}
   242  	}
   243  }
   244  
   245  // Tests if pool can be cleanly reused even in concurrent use by several hasher
   246  func TestBMTConcurrentUse(t *testing.T) {
   247  	hasher := sha3.NewLegacyKeccak256
   248  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   249  	defer pool.Drain(0)
   250  	cycles := 100
   251  	errc := make(chan error)
   252  
   253  	for i := 0; i < cycles; i++ {
   254  		go func() {
   255  			bmt := New(pool)
   256  			data := testutil.RandomBytes(1, BufferSize)
   257  			n := rand.Intn(bmt.Size())
   258  			errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
   259  		}()
   260  	}
   261  LOOP:
   262  	for {
   263  		select {
   264  		case <-time.NewTimer(5 * time.Second).C:
   265  			t.Fatal("timed out")
   266  		case err := <-errc:
   267  			if err != nil {
   268  				t.Fatal(err)
   269  			}
   270  			cycles--
   271  			if cycles == 0 {
   272  				break LOOP
   273  			}
   274  		}
   275  	}
   276  }
   277  
   278  // Tests BMT Hasher io.Writer interface is working correctly
   279  // even multiple short random write buffers
   280  func TestBMTWriterBuffers(t *testing.T) {
   281  	hasher := sha3.NewLegacyKeccak256
   282  
   283  	for _, count := range counts {
   284  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   285  			errc := make(chan error)
   286  			pool := NewTreePool(hasher, count, PoolSize)
   287  			defer pool.Drain(0)
   288  			n := count * 32
   289  			bmt := New(pool)
   290  			data := testutil.RandomBytes(1, n)
   291  			rbmt := NewRefHasher(hasher, count)
   292  			refHash := rbmt.Hash(data)
   293  			expHash := syncHash(bmt, nil, data)
   294  			if !bytes.Equal(expHash, refHash) {
   295  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   296  			}
   297  			attempts := 10
   298  			f := func() error {
   299  				bmt := New(pool)
   300  				bmt.Reset()
   301  				var buflen int
   302  				for offset := 0; offset < n; offset += buflen {
   303  					buflen = rand.Intn(n-offset) + 1
   304  					read, err := bmt.Write(data[offset : offset+buflen])
   305  					if err != nil {
   306  						return err
   307  					}
   308  					if read != buflen {
   309  						return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
   310  					}
   311  				}
   312  				hash := bmt.Sum(nil)
   313  				if !bytes.Equal(hash, expHash) {
   314  					return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
   315  				}
   316  				return nil
   317  			}
   318  
   319  			for j := 0; j < attempts; j++ {
   320  				go func() {
   321  					errc <- f()
   322  				}()
   323  			}
   324  			timeout := time.NewTimer(2 * time.Second)
   325  			for {
   326  				select {
   327  				case err := <-errc:
   328  					if err != nil {
   329  						t.Fatal(err)
   330  					}
   331  					attempts--
   332  					if attempts == 0 {
   333  						return
   334  					}
   335  				case <-timeout.C:
   336  					t.Fatalf("timeout")
   337  				}
   338  			}
   339  		})
   340  	}
   341  }
   342  
   343  // helper function that compares reference and optimised implementations on
   344  // correctness
   345  func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
   346  	span := make([]byte, 8)
   347  	if len(d) < n {
   348  		n = len(d)
   349  	}
   350  	binary.BigEndian.PutUint64(span, uint64(n))
   351  	data := d[:n]
   352  	rbmt := NewRefHasher(hasher, count)
   353  	exp := sha3hash(span, rbmt.Hash(data))
   354  	got := syncHash(bmt, span, data)
   355  	if !bytes.Equal(got, exp) {
   356  		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   357  	}
   358  	return err
   359  }
   360  
   361  //
   362  func BenchmarkBMT(t *testing.B) {
   363  	for size := 4096; size >= 128; size /= 2 {
   364  		t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) {
   365  			benchmarkSHA3(t, size)
   366  		})
   367  		t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) {
   368  			benchmarkBMTBaseline(t, size)
   369  		})
   370  		t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) {
   371  			benchmarkRefHasher(t, size)
   372  		})
   373  		t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) {
   374  			benchmarkBMT(t, size)
   375  		})
   376  	}
   377  }
   378  
   379  type whenHash = int
   380  
   381  const (
   382  	first whenHash = iota
   383  	last
   384  	random
   385  )
   386  
   387  func BenchmarkBMTAsync(t *testing.B) {
   388  	whs := []whenHash{first, last, random}
   389  	for size := 4096; size >= 128; size /= 2 {
   390  		for _, wh := range whs {
   391  			for _, double := range []bool{false, true} {
   392  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) {
   393  					benchmarkBMTAsync(t, size, wh, double)
   394  				})
   395  			}
   396  		}
   397  	}
   398  }
   399  
   400  func BenchmarkPool(t *testing.B) {
   401  	caps := []int{1, PoolSize}
   402  	for size := 4096; size >= 128; size /= 2 {
   403  		for _, c := range caps {
   404  			t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) {
   405  				benchmarkPool(t, c, size)
   406  			})
   407  		}
   408  	}
   409  }
   410  
   411  // benchmarks simple sha3 hash on chunks
   412  func benchmarkSHA3(t *testing.B, n int) {
   413  	data := testutil.RandomBytes(1, n)
   414  	hasher := sha3.NewLegacyKeccak256
   415  	h := hasher()
   416  
   417  	t.ReportAllocs()
   418  	t.ResetTimer()
   419  	for i := 0; i < t.N; i++ {
   420  		doSum(h, nil, data)
   421  	}
   422  }
   423  
   424  // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
   425  // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
   426  // doing it on n PoolSize each reusing the base hasher
   427  // the premise is that this is the minimum computation needed for a BMT
   428  // therefore this serves as a theoretical optimum for concurrent implementations
   429  func benchmarkBMTBaseline(t *testing.B, n int) {
   430  	hasher := sha3.NewLegacyKeccak256
   431  	hashSize := hasher().Size()
   432  	data := testutil.RandomBytes(1, hashSize)
   433  
   434  	t.ReportAllocs()
   435  	t.ResetTimer()
   436  	for i := 0; i < t.N; i++ {
   437  		count := int32((n-1)/hashSize + 1)
   438  		wg := sync.WaitGroup{}
   439  		wg.Add(PoolSize)
   440  		var i int32
   441  		for j := 0; j < PoolSize; j++ {
   442  			go func() {
   443  				defer wg.Done()
   444  				h := hasher()
   445  				for atomic.AddInt32(&i, 1) < count {
   446  					doSum(h, nil, data)
   447  				}
   448  			}()
   449  		}
   450  		wg.Wait()
   451  	}
   452  }
   453  
   454  // benchmarks BMT Hasher
   455  func benchmarkBMT(t *testing.B, n int) {
   456  	data := testutil.RandomBytes(1, n)
   457  	hasher := sha3.NewLegacyKeccak256
   458  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   459  	bmt := New(pool)
   460  
   461  	t.ReportAllocs()
   462  	t.ResetTimer()
   463  	for i := 0; i < t.N; i++ {
   464  		syncHash(bmt, nil, data)
   465  	}
   466  }
   467  
   468  // benchmarks BMT hasher with asynchronous concurrent segment/section writes
   469  func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) {
   470  	data := testutil.RandomBytes(1, n)
   471  	hasher := sha3.NewLegacyKeccak256
   472  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   473  	bmt := New(pool).NewAsyncWriter(double)
   474  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   475  	rand.Shuffle(len(idxs), func(i int, j int) {
   476  		idxs[i], idxs[j] = idxs[j], idxs[i]
   477  	})
   478  
   479  	t.ReportAllocs()
   480  	t.ResetTimer()
   481  	for i := 0; i < t.N; i++ {
   482  		asyncHash(bmt, nil, n, wh, idxs, segments)
   483  	}
   484  }
   485  
   486  // benchmarks 100 concurrent bmt hashes with pool capacity
   487  func benchmarkPool(t *testing.B, poolsize, n int) {
   488  	data := testutil.RandomBytes(1, n)
   489  	hasher := sha3.NewLegacyKeccak256
   490  	pool := NewTreePool(hasher, segmentCount, poolsize)
   491  	cycles := 100
   492  
   493  	t.ReportAllocs()
   494  	t.ResetTimer()
   495  	wg := sync.WaitGroup{}
   496  	for i := 0; i < t.N; i++ {
   497  		wg.Add(cycles)
   498  		for j := 0; j < cycles; j++ {
   499  			go func() {
   500  				defer wg.Done()
   501  				bmt := New(pool)
   502  				syncHash(bmt, nil, data)
   503  			}()
   504  		}
   505  		wg.Wait()
   506  	}
   507  }
   508  
   509  // benchmarks the reference hasher
   510  func benchmarkRefHasher(t *testing.B, n int) {
   511  	data := testutil.RandomBytes(1, n)
   512  	hasher := sha3.NewLegacyKeccak256
   513  	rbmt := NewRefHasher(hasher, 128)
   514  
   515  	t.ReportAllocs()
   516  	t.ResetTimer()
   517  	for i := 0; i < t.N; i++ {
   518  		rbmt.Hash(data)
   519  	}
   520  }
   521  
   522  // Hash hashes the data and the span using the bmt hasher
   523  func syncHash(h *Hasher, span, data []byte) []byte {
   524  	h.ResetWithLength(span)
   525  	h.Write(data)
   526  	return h.Sum(nil)
   527  }
   528  
   529  func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) {
   530  	l := len(data)
   531  	n := l / secsize
   532  	if l%secsize > 0 {
   533  		n++
   534  	}
   535  	for i := 0; i < n; i++ {
   536  		idxs = append(idxs, i)
   537  		end := (i + 1) * secsize
   538  		if end > l {
   539  			end = l
   540  		}
   541  		section := data[i*secsize : end]
   542  		segments = append(segments, section)
   543  	}
   544  	rand.Shuffle(n, func(i int, j int) {
   545  		idxs[i], idxs[j] = idxs[j], idxs[i]
   546  	})
   547  	return idxs, segments
   548  }
   549  
   550  // splits the input data performs a random shuffle to mock async section writes
   551  func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) {
   552  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   553  	return asyncHash(bmt, span, len(data), wh, idxs, segments)
   554  }
   555  
   556  // mock for async section writes for BMT SectionWriter
   557  // requires a permutation (a random shuffle) of list of all indexes of segments
   558  // and writes them in order to the appropriate section
   559  // the Sum function is called according to the wh parameter (first, last, random [relative to segment writes])
   560  func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) {
   561  	bmt.Reset()
   562  	if l == 0 {
   563  		return bmt.Sum(nil, l, span)
   564  	}
   565  	c := make(chan []byte, 1)
   566  	hashf := func() {
   567  		c <- bmt.Sum(nil, l, span)
   568  	}
   569  	maxsize := len(idxs)
   570  	var r int
   571  	if wh == random {
   572  		r = rand.Intn(maxsize)
   573  	}
   574  	for i, idx := range idxs {
   575  		bmt.Write(idx, segments[idx])
   576  		if (wh == first || wh == random) && i == r {
   577  			go hashf()
   578  		}
   579  	}
   580  	if wh == last {
   581  		return bmt.Sum(nil, l, span)
   582  	}
   583  	return <-c
   584  }