github.com/NlaakStudiosLLC/gwfchain/v3@v3.0.0-20210902130704-413c69345317/bmt/bmt_test.go (about)

     1  // Copyright 2017 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package bmt
    18  
    19  import (
    20  	"bytes"
    21  	crand "crypto/rand"
    22  	"fmt"
    23  	"hash"
    24  	"io"
    25  	"math/rand"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"golang.org/x/crypto/sha3"
    32  )
    33  
const (
	// maxproccnt is the number of concurrent hashers exercised by the
	// reuse and concurrency tests; it also sizes the tree pool capacity
	// under test.
	maxproccnt = 8
)
    37  
    38  // TestRefHasher tests that the RefHasher computes the expected BMT hash for
    39  // all data lengths between 0 and 256 bytes
    40  func TestRefHasher(t *testing.T) {
    41  	hashFunc := sha3.NewLegacyKeccak256
    42  
    43  	sha3 := func(data ...[]byte) []byte {
    44  		h := hashFunc()
    45  		for _, v := range data {
    46  			h.Write(v)
    47  		}
    48  		return h.Sum(nil)
    49  	}
    50  
    51  	// the test struct is used to specify the expected BMT hash for data
    52  	// lengths between "from" and "to"
    53  	type test struct {
    54  		from     int64
    55  		to       int64
    56  		expected func([]byte) []byte
    57  	}
    58  
    59  	var tests []*test
    60  
    61  	// all lengths in [0,64] should be:
    62  	//
    63  	//   sha3(data)
    64  	//
    65  	tests = append(tests, &test{
    66  		from: 0,
    67  		to:   64,
    68  		expected: func(data []byte) []byte {
    69  			return sha3(data)
    70  		},
    71  	})
    72  
    73  	// all lengths in [65,96] should be:
    74  	//
    75  	//   sha3(
    76  	//     sha3(data[:64])
    77  	//     data[64:]
    78  	//   )
    79  	//
    80  	tests = append(tests, &test{
    81  		from: 65,
    82  		to:   96,
    83  		expected: func(data []byte) []byte {
    84  			return sha3(sha3(data[:64]), data[64:])
    85  		},
    86  	})
    87  
    88  	// all lengths in [97,128] should be:
    89  	//
    90  	//   sha3(
    91  	//     sha3(data[:64])
    92  	//     sha3(data[64:])
    93  	//   )
    94  	//
    95  	tests = append(tests, &test{
    96  		from: 97,
    97  		to:   128,
    98  		expected: func(data []byte) []byte {
    99  			return sha3(sha3(data[:64]), sha3(data[64:]))
   100  		},
   101  	})
   102  
   103  	// all lengths in [129,160] should be:
   104  	//
   105  	//   sha3(
   106  	//     sha3(
   107  	//       sha3(data[:64])
   108  	//       sha3(data[64:128])
   109  	//     )
   110  	//     data[128:]
   111  	//   )
   112  	//
   113  	tests = append(tests, &test{
   114  		from: 129,
   115  		to:   160,
   116  		expected: func(data []byte) []byte {
   117  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), data[128:])
   118  		},
   119  	})
   120  
   121  	// all lengths in [161,192] should be:
   122  	//
   123  	//   sha3(
   124  	//     sha3(
   125  	//       sha3(data[:64])
   126  	//       sha3(data[64:128])
   127  	//     )
   128  	//     sha3(data[128:])
   129  	//   )
   130  	//
   131  	tests = append(tests, &test{
   132  		from: 161,
   133  		to:   192,
   134  		expected: func(data []byte) []byte {
   135  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(data[128:]))
   136  		},
   137  	})
   138  
   139  	// all lengths in [193,224] should be:
   140  	//
   141  	//   sha3(
   142  	//     sha3(
   143  	//       sha3(data[:64])
   144  	//       sha3(data[64:128])
   145  	//     )
   146  	//     sha3(
   147  	//       sha3(data[128:192])
   148  	//       data[192:]
   149  	//     )
   150  	//   )
   151  	//
   152  	tests = append(tests, &test{
   153  		from: 193,
   154  		to:   224,
   155  		expected: func(data []byte) []byte {
   156  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), data[192:]))
   157  		},
   158  	})
   159  
   160  	// all lengths in [225,256] should be:
   161  	//
   162  	//   sha3(
   163  	//     sha3(
   164  	//       sha3(data[:64])
   165  	//       sha3(data[64:128])
   166  	//     )
   167  	//     sha3(
   168  	//       sha3(data[128:192])
   169  	//       sha3(data[192:])
   170  	//     )
   171  	//   )
   172  	//
   173  	tests = append(tests, &test{
   174  		from: 225,
   175  		to:   256,
   176  		expected: func(data []byte) []byte {
   177  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), sha3(data[192:])))
   178  		},
   179  	})
   180  
   181  	// run the tests
   182  	for _, x := range tests {
   183  		for length := x.from; length <= x.to; length++ {
   184  			t.Run(fmt.Sprintf("%d_bytes", length), func(t *testing.T) {
   185  				data := make([]byte, length)
   186  				if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
   187  					t.Fatal(err)
   188  				}
   189  				expected := x.expected(data)
   190  				actual := NewRefHasher(hashFunc, 128).Hash(data)
   191  				if !bytes.Equal(actual, expected) {
   192  					t.Fatalf("expected %x, got %x", expected, actual)
   193  				}
   194  			})
   195  		}
   196  	}
   197  }
   198  
   199  func testDataReader(l int) (r io.Reader) {
   200  	return io.LimitReader(crand.Reader, int64(l))
   201  }
   202  
   203  func TestHasherCorrectness(t *testing.T) {
   204  	err := testHasher(testBaseHasher)
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  }
   209  
   210  func testHasher(f func(BaseHasher, []byte, int, int) error) error {
   211  	tdata := testDataReader(4128)
   212  	data := make([]byte, 4128)
   213  	tdata.Read(data)
   214  	hasher := sha3.NewLegacyKeccak256
   215  	size := hasher().Size()
   216  	counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128}
   217  
   218  	var err error
   219  	for _, count := range counts {
   220  		max := count * size
   221  		incr := 1
   222  		for n := 0; n <= max+incr; n += incr {
   223  			err = f(hasher, data, n, count)
   224  			if err != nil {
   225  				return err
   226  			}
   227  		}
   228  	}
   229  	return nil
   230  }
   231  
// TestHasherReuseWithoutRelease reuses one Hasher serially against a tree
// pool of capacity 1, so a tree is never released back while another run
// is in flight.
func TestHasherReuseWithoutRelease(t *testing.T) {
	testHasherReuse(1, t)
}
   235  
// TestHasherReuseWithRelease reuses one Hasher against a pool sized to
// maxproccnt, exercising the release-back-to-pool path.
func TestHasherReuseWithRelease(t *testing.T) {
	testHasherReuse(maxproccnt, t)
}
   239  
   240  func testHasherReuse(i int, t *testing.T) {
   241  	hasher := sha3.NewLegacyKeccak256
   242  	pool := NewTreePool(hasher, 128, i)
   243  	defer pool.Drain(0)
   244  	bmt := New(pool)
   245  
   246  	for i := 0; i < 500; i++ {
   247  		n := rand.Intn(4096)
   248  		tdata := testDataReader(n)
   249  		data := make([]byte, n)
   250  		tdata.Read(data)
   251  
   252  		err := testHasherCorrectness(bmt, hasher, data, n, 128)
   253  		if err != nil {
   254  			t.Fatal(err)
   255  		}
   256  	}
   257  }
   258  
   259  func TestHasherConcurrency(t *testing.T) {
   260  	hasher := sha3.NewLegacyKeccak256
   261  	pool := NewTreePool(hasher, 128, maxproccnt)
   262  	defer pool.Drain(0)
   263  	wg := sync.WaitGroup{}
   264  	cycles := 100
   265  	n := maxproccnt * cycles
   266  	wg.Add(n)
   267  	errc := make(chan error, n)
   268  
   269  	for p := 0; p < maxproccnt; p++ {
   270  		for i := 0; i < cycles; i++ {
   271  			go func() {
   272  				defer wg.Done()
   273  				bmt := New(pool)
   274  				n := rand.Intn(4096)
   275  				tdata := testDataReader(n)
   276  				data := make([]byte, n)
   277  				tdata.Read(data)
   278  				err := testHasherCorrectness(bmt, hasher, data, n, 128)
   279  				if err != nil {
   280  					errc <- err
   281  				}
   282  			}()
   283  		}
   284  	}
   285  	go func() {
   286  		wg.Wait()
   287  		close(errc)
   288  	}()
   289  	for err := range errc {
   290  		t.Error(err)
   291  	}
   292  }
   293  
   294  func testBaseHasher(hasher BaseHasher, d []byte, n, count int) error {
   295  	pool := NewTreePool(hasher, count, 1)
   296  	defer pool.Drain(0)
   297  	bmt := New(pool)
   298  	return testHasherCorrectness(bmt, hasher, d, n, count)
   299  }
   300  
// testHasherCorrectness hashes d[:n] with the BMT hasher under test and
// compares the digest against the sequential reference hasher. The BMT
// run executes in its own goroutine so a deadlock in the implementation
// surfaces as a timeout error rather than hanging the test binary.
func testHasherCorrectness(bmt hash.Hash, hasher BaseHasher, d []byte, n, count int) (err error) {
	data := d[:n]
	rbmt := NewRefHasher(hasher, count)
	exp := rbmt.Hash(data)
	c := make(chan error)

	go func() {
		// Closing c without sending makes the receive below yield nil,
		// which signals success to the select.
		defer close(c)
		bmt.Reset()
		bmt.Write(data)
		got := bmt.Sum(nil)
		if !bytes.Equal(got, exp) {
			c <- fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
		}
	}()
	select {
	case <-time.After(20 * time.Second):
		// Watchdog: the BMT computation never finished.
		err = fmt.Errorf("BMT hash calculation timed out")
	case err = <-c:
	}
	return err
}
   323  
   324  func BenchmarkSHA3_4k(t *testing.B)   { benchmarkSHA3(4096, t) }
   325  func BenchmarkSHA3_2k(t *testing.B)   { benchmarkSHA3(4096/2, t) }
   326  func BenchmarkSHA3_1k(t *testing.B)   { benchmarkSHA3(4096/4, t) }
   327  func BenchmarkSHA3_512b(t *testing.B) { benchmarkSHA3(4096/8, t) }
   328  func BenchmarkSHA3_256b(t *testing.B) { benchmarkSHA3(4096/16, t) }
   329  func BenchmarkSHA3_128b(t *testing.B) { benchmarkSHA3(4096/32, t) }
   330  
   331  func BenchmarkBMTBaseline_4k(t *testing.B)   { benchmarkBMTBaseline(4096, t) }
   332  func BenchmarkBMTBaseline_2k(t *testing.B)   { benchmarkBMTBaseline(4096/2, t) }
   333  func BenchmarkBMTBaseline_1k(t *testing.B)   { benchmarkBMTBaseline(4096/4, t) }
   334  func BenchmarkBMTBaseline_512b(t *testing.B) { benchmarkBMTBaseline(4096/8, t) }
   335  func BenchmarkBMTBaseline_256b(t *testing.B) { benchmarkBMTBaseline(4096/16, t) }
   336  func BenchmarkBMTBaseline_128b(t *testing.B) { benchmarkBMTBaseline(4096/32, t) }
   337  
   338  func BenchmarkRefHasher_4k(t *testing.B)   { benchmarkRefHasher(4096, t) }
   339  func BenchmarkRefHasher_2k(t *testing.B)   { benchmarkRefHasher(4096/2, t) }
   340  func BenchmarkRefHasher_1k(t *testing.B)   { benchmarkRefHasher(4096/4, t) }
   341  func BenchmarkRefHasher_512b(t *testing.B) { benchmarkRefHasher(4096/8, t) }
   342  func BenchmarkRefHasher_256b(t *testing.B) { benchmarkRefHasher(4096/16, t) }
   343  func BenchmarkRefHasher_128b(t *testing.B) { benchmarkRefHasher(4096/32, t) }
   344  
   345  func BenchmarkHasher_4k(t *testing.B)   { benchmarkHasher(4096, t) }
   346  func BenchmarkHasher_2k(t *testing.B)   { benchmarkHasher(4096/2, t) }
   347  func BenchmarkHasher_1k(t *testing.B)   { benchmarkHasher(4096/4, t) }
   348  func BenchmarkHasher_512b(t *testing.B) { benchmarkHasher(4096/8, t) }
   349  func BenchmarkHasher_256b(t *testing.B) { benchmarkHasher(4096/16, t) }
   350  func BenchmarkHasher_128b(t *testing.B) { benchmarkHasher(4096/32, t) }
   351  
   352  func BenchmarkHasherNoReuse_4k(t *testing.B)   { benchmarkHasherReuse(1, 4096, t) }
   353  func BenchmarkHasherNoReuse_2k(t *testing.B)   { benchmarkHasherReuse(1, 4096/2, t) }
   354  func BenchmarkHasherNoReuse_1k(t *testing.B)   { benchmarkHasherReuse(1, 4096/4, t) }
   355  func BenchmarkHasherNoReuse_512b(t *testing.B) { benchmarkHasherReuse(1, 4096/8, t) }
   356  func BenchmarkHasherNoReuse_256b(t *testing.B) { benchmarkHasherReuse(1, 4096/16, t) }
   357  func BenchmarkHasherNoReuse_128b(t *testing.B) { benchmarkHasherReuse(1, 4096/32, t) }
   358  
   359  func BenchmarkHasherReuse_4k(t *testing.B)   { benchmarkHasherReuse(16, 4096, t) }
   360  func BenchmarkHasherReuse_2k(t *testing.B)   { benchmarkHasherReuse(16, 4096/2, t) }
   361  func BenchmarkHasherReuse_1k(t *testing.B)   { benchmarkHasherReuse(16, 4096/4, t) }
   362  func BenchmarkHasherReuse_512b(t *testing.B) { benchmarkHasherReuse(16, 4096/8, t) }
   363  func BenchmarkHasherReuse_256b(t *testing.B) { benchmarkHasherReuse(16, 4096/16, t) }
   364  func BenchmarkHasherReuse_128b(t *testing.B) { benchmarkHasherReuse(16, 4096/32, t) }
   365  
   366  // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
   367  // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
   368  // doing it on n maxproccnt each reusing the base hasher
   369  // the premise is that this is the minimum computation needed for a BMT
   370  // therefore this serves as a theoretical optimum for concurrent implementations
   371  func benchmarkBMTBaseline(n int, t *testing.B) {
   372  	tdata := testDataReader(64)
   373  	data := make([]byte, 64)
   374  	tdata.Read(data)
   375  	hasher := sha3.NewLegacyKeccak256
   376  
   377  	t.ReportAllocs()
   378  	t.ResetTimer()
   379  	for i := 0; i < t.N; i++ {
   380  		count := int32((n-1)/hasher().Size() + 1)
   381  		wg := sync.WaitGroup{}
   382  		wg.Add(maxproccnt)
   383  		var i int32
   384  		for j := 0; j < maxproccnt; j++ {
   385  			go func() {
   386  				defer wg.Done()
   387  				h := hasher()
   388  				for atomic.AddInt32(&i, 1) < count {
   389  					h.Reset()
   390  					h.Write(data)
   391  					h.Sum(nil)
   392  				}
   393  			}()
   394  		}
   395  		wg.Wait()
   396  	}
   397  }
   398  
   399  func benchmarkHasher(n int, t *testing.B) {
   400  	tdata := testDataReader(n)
   401  	data := make([]byte, n)
   402  	tdata.Read(data)
   403  
   404  	size := 1
   405  	hasher := sha3.NewLegacyKeccak256
   406  	segmentCount := 128
   407  	pool := NewTreePool(hasher, segmentCount, size)
   408  	bmt := New(pool)
   409  
   410  	t.ReportAllocs()
   411  	t.ResetTimer()
   412  	for i := 0; i < t.N; i++ {
   413  		bmt.Reset()
   414  		bmt.Write(data)
   415  		bmt.Sum(nil)
   416  	}
   417  }
   418  
   419  func benchmarkHasherReuse(poolsize, n int, t *testing.B) {
   420  	tdata := testDataReader(n)
   421  	data := make([]byte, n)
   422  	tdata.Read(data)
   423  
   424  	hasher := sha3.NewLegacyKeccak256
   425  	segmentCount := 128
   426  	pool := NewTreePool(hasher, segmentCount, poolsize)
   427  	cycles := 200
   428  
   429  	t.ReportAllocs()
   430  	t.ResetTimer()
   431  	for i := 0; i < t.N; i++ {
   432  		wg := sync.WaitGroup{}
   433  		wg.Add(cycles)
   434  		for j := 0; j < cycles; j++ {
   435  			bmt := New(pool)
   436  			go func() {
   437  				defer wg.Done()
   438  				bmt.Reset()
   439  				bmt.Write(data)
   440  				bmt.Sum(nil)
   441  			}()
   442  		}
   443  		wg.Wait()
   444  	}
   445  }
   446  
   447  func benchmarkSHA3(n int, t *testing.B) {
   448  	data := make([]byte, n)
   449  	tdata := testDataReader(n)
   450  	tdata.Read(data)
   451  	hasher := sha3.NewLegacyKeccak256
   452  	h := hasher()
   453  
   454  	t.ReportAllocs()
   455  	t.ResetTimer()
   456  	for i := 0; i < t.N; i++ {
   457  		h.Reset()
   458  		h.Write(data)
   459  		h.Sum(nil)
   460  	}
   461  }
   462  
   463  func benchmarkRefHasher(n int, t *testing.B) {
   464  	data := make([]byte, n)
   465  	tdata := testDataReader(n)
   466  	tdata.Read(data)
   467  	hasher := sha3.NewLegacyKeccak256
   468  	rbmt := NewRefHasher(hasher, 128)
   469  
   470  	t.ReportAllocs()
   471  	t.ResetTimer()
   472  	for i := 0; i < t.N; i++ {
   473  		rbmt.Hash(data)
   474  	}
   475  }