github.com/muhammedhassanm/blockchain@v0.0.0-20200120143007-697261defd4d/go-ethereum-master/swarm/bmt/bmt_test.go

// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package bmt

import (
	"bytes"
	crand "crypto/rand"
	"encoding/binary"
	"fmt"
	"io"
	"math/rand"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/crypto/sha3"
)

// the actual data length generated (could be longer than max datalength of the BMT)
const BufferSize = 4128

var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}

// calculates the Keccak256 SHA3 hash of the concatenation of the given data slices
func sha3hash(data ...[]byte) []byte {
	h := sha3.NewKeccak256()
	return doHash(h, nil, data...)
}
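
// Illustrative sketch only (not part of the original suite, and the name
// sha3hashNaive is an assumption for this example): it spells out what the
// sha3hash helper above is expected to compute, assuming doHash (defined in
// bmt.go) simply hashes the concatenation of its inputs, i.e.
// sha3hash(a, b) == keccak256(a || b).
func sha3hashNaive(data ...[]byte) []byte {
	h := sha3.NewKeccak256()
	for _, d := range data {
		h.Write(d)
	}
	return h.Sum(nil)
}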

// TestRefHasher tests that the RefHasher computes the expected BMT hash for
// all data lengths between 1 and 256 bytes
func TestRefHasher(t *testing.T) {

	// the test struct is used to specify the expected BMT hash for
	// segment counts between from and to and lengths from 1 to datalength
	type test struct {
		from     int
		to       int
		expected func([]byte) []byte
	}

	var tests []*test
	// all segmentCounts in [1,2] should be:
	//
	//   sha3hash(data)
	//
	tests = append(tests, &test{
		from: 1,
		to:   2,
		expected: func(d []byte) []byte {
			data := make([]byte, 64)
			copy(data, d)
			return sha3hash(data)
		},
	})

	// all segmentCounts in [3,4] should be:
	//
	//   sha3hash(
	//     sha3hash(data[:64])
	//     sha3hash(data[64:])
	//   )
	//
	tests = append(tests, &test{
		from: 3,
		to:   4,
		expected: func(d []byte) []byte {
			data := make([]byte, 128)
			copy(data, d)
			return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
		},
	})

	// all segmentCounts in [5,8] should be:
	//
	//   sha3hash(
	//     sha3hash(
	//       sha3hash(data[:64])
	//       sha3hash(data[64:128])
	//     )
	//     sha3hash(
	//       sha3hash(data[128:192])
	//       sha3hash(data[192:])
	//     )
	//   )
	//
	tests = append(tests, &test{
		from: 5,
		to:   8,
		expected: func(d []byte) []byte {
			data := make([]byte, 256)
			copy(data, d)
			return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
		},
	})

	// run the tests
	for _, x := range tests {
		for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
			for length := 1; length <= segmentCount*32; length++ {
				t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
					data := make([]byte, length)
					if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
						t.Fatal(err)
					}
					expected := x.expected(data)
					actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data)
					if !bytes.Equal(actual, expected) {
						t.Fatalf("expected %x, got %x", expected, actual)
					}
				})
			}
		}
	}
}

// TestHasherEmptyData tests that the BMT hasher returns the correct hash for empty input
func TestHasherEmptyData(t *testing.T) {
	hasher := sha3.NewKeccak256
	var data []byte
	for _, count := range counts {
		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
			pool := NewTreePool(hasher, count, PoolSize)
			defer pool.Drain(0)
			bmt := New(pool)
			rbmt := NewRefHasher(hasher, count)
			refHash := rbmt.Hash(data)
			expHash := Hash(bmt, nil, data)
			if !bytes.Equal(expHash, refHash) {
				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
			}
		})
	}
}

// TestHasherCorrectness checks the hasher against the reference implementation
// for a range of segment counts and data lengths
func TestHasherCorrectness(t *testing.T) {
	data := newData(BufferSize)
	hasher := sha3.NewKeccak256
	size := hasher().Size()

	var err error
	for _, count := range counts {
		t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
			max := count * size
			incr := 1
			capacity := 1
			pool := NewTreePool(hasher, count, capacity)
			defer pool.Drain(0)
			for n := 0; n <= max; n += incr {
				incr = 1 + rand.Intn(5)
				bmt := New(pool)
				err = testHasherCorrectness(bmt, hasher, data, n, count)
				if err != nil {
					t.Fatal(err)
				}
			}
		})
	}
}

// Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize
func TestHasherReuse(t *testing.T) {
	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
		testHasherReuse(1, t)
	})
	t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
		testHasherReuse(PoolSize, t)
	})
}

// testHasherReuse hashes 100 random inputs of random length with the same
// Hasher drawn from a pool of the given size
func testHasherReuse(poolsize int, t *testing.T) {
	hasher := sha3.NewKeccak256
	pool := NewTreePool(hasher, SegmentCount, poolsize)
	defer pool.Drain(0)
	bmt := New(pool)

	for i := 0; i < 100; i++ {
		data := newData(BufferSize)
		n := rand.Intn(bmt.DataLength())
		err := testHasherCorrectness(bmt, hasher, data, n, SegmentCount)
		if err != nil {
			t.Fatal(err)
		}
	}
}

// Tests that the pool can be cleanly reused even under concurrent use
func TestBMTHasherConcurrentUse(t *testing.T) {
	hasher := sha3.NewKeccak256
	pool := NewTreePool(hasher, SegmentCount, PoolSize)
	defer pool.Drain(0)
	cycles := 100
	errc := make(chan error)

	for i := 0; i < cycles; i++ {
		go func() {
			bmt := New(pool)
			data := newData(BufferSize)
			n := rand.Intn(bmt.DataLength())
			errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
		}()
	}
LOOP:
	for {
		select {
		case <-time.NewTimer(5 * time.Second).C:
			t.Fatal("timed out")
		case err := <-errc:
			if err != nil {
				t.Fatal(err)
			}
			cycles--
			if cycles == 0 {
				break LOOP
			}
		}
	}
}

// Tests that the BMT Hasher's io.Writer interface works correctly
// even with multiple short random write buffers
func TestBMTHasherWriterBuffers(t *testing.T) {
	hasher := sha3.NewKeccak256

	for _, count := range counts {
		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
			errc := make(chan error)
			pool := NewTreePool(hasher, count, PoolSize)
			defer pool.Drain(0)
			n := count * 32
			bmt := New(pool)
			data := newData(n)
			rbmt := NewRefHasher(hasher, count)
			refHash := rbmt.Hash(data)
			expHash := Hash(bmt, nil, data)
			if !bytes.Equal(expHash, refHash) {
				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
			}
			attempts := 10
			f := func() error {
				bmt := New(pool)
				bmt.Reset()
				var buflen int
				for offset := 0; offset < n; offset += buflen {
					buflen = rand.Intn(n-offset) + 1
					written, err := bmt.Write(data[offset : offset+buflen])
					if err != nil {
						return err
					}
					if written != buflen {
						return fmt.Errorf("incorrect write. expected %v bytes, got %v", buflen, written)
					}
				}
				hash := bmt.Sum(nil)
				if !bytes.Equal(hash, expHash) {
					return fmt.Errorf("hash mismatch. expected %x, got %x", expHash, hash)
				}
				return nil
			}

			for j := 0; j < attempts; j++ {
				go func() {
					errc <- f()
				}()
			}
			timeout := time.NewTimer(2 * time.Second)
			for {
				select {
				case err := <-errc:
					if err != nil {
						t.Fatal(err)
					}
					attempts--
					if attempts == 0 {
						return
					}
				case <-timeout.C:
					t.Fatalf("timeout")
				}
			}
		})
	}
}

// helper function that compares the reference and the optimised implementations
// for correctness
func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
	span := make([]byte, 8)
	if len(d) < n {
		n = len(d)
	}
	binary.BigEndian.PutUint64(span, uint64(n))
	data := d[:n]
	rbmt := NewRefHasher(hasher, count)
	exp := sha3hash(span, rbmt.Hash(data))
	got := Hash(bmt, span, data)
	if !bytes.Equal(got, exp) {
		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
	}
	return err
}
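
// Sketch for illustration (not in the original file): this is the shape of the
// expected value testHasherCorrectness builds above, i.e. keccak256 over an
// 8-byte big-endian length prefix (the span) followed by the reference BMT
// root. The function name and the fixed segment count of 128 are assumptions
// made for this example.
func expectedChunkHashSketch(data []byte) []byte {
	span := make([]byte, 8)
	binary.BigEndian.PutUint64(span, uint64(len(data)))
	rbmt := NewRefHasher(sha3.NewKeccak256, 128)
	return sha3hash(span, rbmt.Hash(data))
}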

func BenchmarkSHA3_4k(t *testing.B)   { benchmarkSHA3(4096, t) }
func BenchmarkSHA3_2k(t *testing.B)   { benchmarkSHA3(4096/2, t) }
func BenchmarkSHA3_1k(t *testing.B)   { benchmarkSHA3(4096/4, t) }
func BenchmarkSHA3_512b(t *testing.B) { benchmarkSHA3(4096/8, t) }
func BenchmarkSHA3_256b(t *testing.B) { benchmarkSHA3(4096/16, t) }
func BenchmarkSHA3_128b(t *testing.B) { benchmarkSHA3(4096/32, t) }

func BenchmarkBMTBaseline_4k(t *testing.B)   { benchmarkBMTBaseline(4096, t) }
func BenchmarkBMTBaseline_2k(t *testing.B)   { benchmarkBMTBaseline(4096/2, t) }
func BenchmarkBMTBaseline_1k(t *testing.B)   { benchmarkBMTBaseline(4096/4, t) }
func BenchmarkBMTBaseline_512b(t *testing.B) { benchmarkBMTBaseline(4096/8, t) }
func BenchmarkBMTBaseline_256b(t *testing.B) { benchmarkBMTBaseline(4096/16, t) }
func BenchmarkBMTBaseline_128b(t *testing.B) { benchmarkBMTBaseline(4096/32, t) }

func BenchmarkRefHasher_4k(t *testing.B)   { benchmarkRefHasher(4096, t) }
func BenchmarkRefHasher_2k(t *testing.B)   { benchmarkRefHasher(4096/2, t) }
func BenchmarkRefHasher_1k(t *testing.B)   { benchmarkRefHasher(4096/4, t) }
func BenchmarkRefHasher_512b(t *testing.B) { benchmarkRefHasher(4096/8, t) }
func BenchmarkRefHasher_256b(t *testing.B) { benchmarkRefHasher(4096/16, t) }
func BenchmarkRefHasher_128b(t *testing.B) { benchmarkRefHasher(4096/32, t) }

func BenchmarkBMTHasher_4k(t *testing.B)   { benchmarkBMTHasher(4096, t) }
func BenchmarkBMTHasher_2k(t *testing.B)   { benchmarkBMTHasher(4096/2, t) }
func BenchmarkBMTHasher_1k(t *testing.B)   { benchmarkBMTHasher(4096/4, t) }
func BenchmarkBMTHasher_512b(t *testing.B) { benchmarkBMTHasher(4096/8, t) }
func BenchmarkBMTHasher_256b(t *testing.B) { benchmarkBMTHasher(4096/16, t) }
func BenchmarkBMTHasher_128b(t *testing.B) { benchmarkBMTHasher(4096/32, t) }

func BenchmarkBMTHasherNoPool_4k(t *testing.B)   { benchmarkBMTHasherPool(1, 4096, t) }
func BenchmarkBMTHasherNoPool_2k(t *testing.B)   { benchmarkBMTHasherPool(1, 4096/2, t) }
func BenchmarkBMTHasherNoPool_1k(t *testing.B)   { benchmarkBMTHasherPool(1, 4096/4, t) }
func BenchmarkBMTHasherNoPool_512b(t *testing.B) { benchmarkBMTHasherPool(1, 4096/8, t) }
func BenchmarkBMTHasherNoPool_256b(t *testing.B) { benchmarkBMTHasherPool(1, 4096/16, t) }
func BenchmarkBMTHasherNoPool_128b(t *testing.B) { benchmarkBMTHasherPool(1, 4096/32, t) }

func BenchmarkBMTHasherPool_4k(t *testing.B)   { benchmarkBMTHasherPool(PoolSize, 4096, t) }
func BenchmarkBMTHasherPool_2k(t *testing.B)   { benchmarkBMTHasherPool(PoolSize, 4096/2, t) }
func BenchmarkBMTHasherPool_1k(t *testing.B)   { benchmarkBMTHasherPool(PoolSize, 4096/4, t) }
func BenchmarkBMTHasherPool_512b(t *testing.B) { benchmarkBMTHasherPool(PoolSize, 4096/8, t) }
func BenchmarkBMTHasherPool_256b(t *testing.B) { benchmarkBMTHasherPool(PoolSize, 4096/16, t) }
func BenchmarkBMTHasherPool_128b(t *testing.B) { benchmarkBMTHasherPool(PoolSize, 4096/32, t) }

// benchmarks simple sha3 hash on chunks
func benchmarkSHA3(n int, t *testing.B) {
	data := newData(n)
	hasher := sha3.NewKeccak256
	h := hasher()

	t.ReportAllocs()
	t.ResetTimer()
	for i := 0; i < t.N; i++ {
		h.Reset()
		h.Write(data)
		h.Sum(nil)
	}
}

// benchmarks the minimum hashing time for a balanced (for simplicity) BMT
// by hashing ceil(n/hashsize) segments of hashsize bytes, spread over
// PoolSize goroutines that each reuse their base hasher.
// The premise is that this is the minimum computation needed for a BMT,
// therefore it serves as a theoretical optimum for concurrent implementations.
func benchmarkBMTBaseline(n int, t *testing.B) {
	hasher := sha3.NewKeccak256
	hashSize := hasher().Size()
	data := newData(hashSize)

	t.ReportAllocs()
	t.ResetTimer()
	for i := 0; i < t.N; i++ {
		count := int32((n-1)/hashSize + 1)
		wg := sync.WaitGroup{}
		wg.Add(PoolSize)
		// done counts the segments hashed across all goroutines
		var done int32
		for j := 0; j < PoolSize; j++ {
			go func() {
				defer wg.Done()
				h := hasher()
				for atomic.AddInt32(&done, 1) < count {
					h.Reset()
					h.Write(data)
					h.Sum(nil)
				}
			}()
		}
		wg.Wait()
	}
}

// benchmarks BMT Hasher
func benchmarkBMTHasher(n int, t *testing.B) {
	data := newData(n)
	hasher := sha3.NewKeccak256
	pool := NewTreePool(hasher, SegmentCount, PoolSize)

	t.ReportAllocs()
	t.ResetTimer()
	for i := 0; i < t.N; i++ {
		bmt := New(pool)
		Hash(bmt, nil, data)
	}
}

// benchmarks 100 concurrent bmt hashes with the given pool capacity
func benchmarkBMTHasherPool(poolsize, n int, t *testing.B) {
	data := newData(n)
	hasher := sha3.NewKeccak256
	pool := NewTreePool(hasher, SegmentCount, poolsize)
	cycles := 100

	t.ReportAllocs()
	t.ResetTimer()
	wg := sync.WaitGroup{}
	for i := 0; i < t.N; i++ {
		wg.Add(cycles)
		for j := 0; j < cycles; j++ {
			go func() {
				defer wg.Done()
				bmt := New(pool)
				Hash(bmt, nil, data)
			}()
		}
		wg.Wait()
	}
}

// benchmarks the reference hasher
func benchmarkRefHasher(n int, t *testing.B) {
	data := newData(n)
	hasher := sha3.NewKeccak256
	rbmt := NewRefHasher(hasher, 128)

	t.ReportAllocs()
	t.ResetTimer()
	for i := 0; i < t.N; i++ {
		rbmt.Hash(data)
	}
}

// newData returns bufferSize bytes of cryptographically random data, panicking if the read fails
func newData(bufferSize int) []byte {
	data := make([]byte, bufferSize)
	_, err := io.ReadFull(crand.Reader, data)
	if err != nil {
		panic(err.Error())
	}
	return data
}
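
// Usage sketch (an illustrative addition, not part of the original benchmarks;
// the function name is an assumption for this example): the flow exercised by
// the tests and benchmarks above is to build a TreePool once, draw a Hasher
// from it per chunk, and hash with an optional span prefix.
func hashChunkSketch(data []byte) []byte {
	pool := NewTreePool(sha3.NewKeccak256, SegmentCount, PoolSize)
	defer pool.Drain(0)
	bmt := New(pool)
	return Hash(bmt, nil, data)
}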