github.com/aquanetwork/aquachain@v1.7.8/opt/bmt/bmt_test.go (about)

     1  // Copyright 2017 The aquachain Authors
     2  // This file is part of the aquachain library.
     3  //
     4  // The aquachain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The aquachain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package bmt
    18  
    19  import (
    20  	"bytes"
    21  	crand "crypto/rand"
    22  	"fmt"
    23  	"hash"
    24  	"io"
    25  	"math/rand"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"gitlab.com/aquachain/aquachain/crypto/sha3"
    32  )
    33  
const (
	// maxproccnt is the number of concurrent workers used by the
	// concurrency test and the baseline benchmark.
	maxproccnt = 8
)
    37  
    38  // TestRefHasher tests that the RefHasher computes the expected BMT hash for
    39  // all data lengths between 0 and 256 bytes
    40  func TestRefHasher(t *testing.T) {
    41  	hashFunc := sha3.NewKeccak256
    42  
    43  	sha3 := func(data ...[]byte) []byte {
    44  		h := hashFunc()
    45  		for _, v := range data {
    46  			h.Write(v)
    47  		}
    48  		return h.Sum(nil)
    49  	}
    50  
    51  	// the test struct is used to specify the expected BMT hash for data
    52  	// lengths between "from" and "to"
    53  	type test struct {
    54  		from     int64
    55  		to       int64
    56  		expected func([]byte) []byte
    57  	}
    58  
    59  	var tests []*test
    60  
    61  	// all lengths in [0,64] should be:
    62  	//
    63  	//   sha3(data)
    64  	//
    65  	tests = append(tests, &test{
    66  		from: 0,
    67  		to:   64,
    68  		expected: func(data []byte) []byte {
    69  			return sha3(data)
    70  		},
    71  	})
    72  
    73  	// all lengths in [65,96] should be:
    74  	//
    75  	//   sha3(
    76  	//     sha3(data[:64])
    77  	//     data[64:]
    78  	//   )
    79  	//
    80  	tests = append(tests, &test{
    81  		from: 65,
    82  		to:   96,
    83  		expected: func(data []byte) []byte {
    84  			return sha3(sha3(data[:64]), data[64:])
    85  		},
    86  	})
    87  
    88  	// all lengths in [97,128] should be:
    89  	//
    90  	//   sha3(
    91  	//     sha3(data[:64])
    92  	//     sha3(data[64:])
    93  	//   )
    94  	//
    95  	tests = append(tests, &test{
    96  		from: 97,
    97  		to:   128,
    98  		expected: func(data []byte) []byte {
    99  			return sha3(sha3(data[:64]), sha3(data[64:]))
   100  		},
   101  	})
   102  
   103  	// all lengths in [129,160] should be:
   104  	//
   105  	//   sha3(
   106  	//     sha3(
   107  	//       sha3(data[:64])
   108  	//       sha3(data[64:128])
   109  	//     )
   110  	//     data[128:]
   111  	//   )
   112  	//
   113  	tests = append(tests, &test{
   114  		from: 129,
   115  		to:   160,
   116  		expected: func(data []byte) []byte {
   117  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), data[128:])
   118  		},
   119  	})
   120  
   121  	// all lengths in [161,192] should be:
   122  	//
   123  	//   sha3(
   124  	//     sha3(
   125  	//       sha3(data[:64])
   126  	//       sha3(data[64:128])
   127  	//     )
   128  	//     sha3(data[128:])
   129  	//   )
   130  	//
   131  	tests = append(tests, &test{
   132  		from: 161,
   133  		to:   192,
   134  		expected: func(data []byte) []byte {
   135  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(data[128:]))
   136  		},
   137  	})
   138  
   139  	// all lengths in [193,224] should be:
   140  	//
   141  	//   sha3(
   142  	//     sha3(
   143  	//       sha3(data[:64])
   144  	//       sha3(data[64:128])
   145  	//     )
   146  	//     sha3(
   147  	//       sha3(data[128:192])
   148  	//       data[192:]
   149  	//     )
   150  	//   )
   151  	//
   152  	tests = append(tests, &test{
   153  		from: 193,
   154  		to:   224,
   155  		expected: func(data []byte) []byte {
   156  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), data[192:]))
   157  		},
   158  	})
   159  
   160  	// all lengths in [225,256] should be:
   161  	//
   162  	//   sha3(
   163  	//     sha3(
   164  	//       sha3(data[:64])
   165  	//       sha3(data[64:128])
   166  	//     )
   167  	//     sha3(
   168  	//       sha3(data[128:192])
   169  	//       sha3(data[192:])
   170  	//     )
   171  	//   )
   172  	//
   173  	tests = append(tests, &test{
   174  		from: 225,
   175  		to:   256,
   176  		expected: func(data []byte) []byte {
   177  			return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), sha3(data[192:])))
   178  		},
   179  	})
   180  
   181  	// run the tests
   182  	for _, x := range tests {
   183  		for length := x.from; length <= x.to; length++ {
   184  			t.Run(fmt.Sprintf("%d_bytes", length), func(t *testing.T) {
   185  				data := make([]byte, length)
   186  				if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
   187  					t.Fatal(err)
   188  				}
   189  				expected := x.expected(data)
   190  				actual := NewRefHasher(hashFunc, 128).Hash(data)
   191  				if !bytes.Equal(actual, expected) {
   192  					t.Fatalf("expected %x, got %x", expected, actual)
   193  				}
   194  			})
   195  		}
   196  	}
   197  }
   198  
   199  func testDataReader(l int) (r io.Reader) {
   200  	return io.LimitReader(crand.Reader, int64(l))
   201  }
   202  
   203  func TestHasherCorrectness(t *testing.T) {
   204  	err := testHasher(testBaseHasher)
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  }
   209  
   210  func testHasher(f func(BaseHasher, []byte, int, int) error) error {
   211  	tdata := testDataReader(4128)
   212  	data := make([]byte, 4128)
   213  	tdata.Read(data)
   214  	hasher := sha3.NewKeccak256
   215  	size := hasher().Size()
   216  	counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128}
   217  
   218  	var err error
   219  	for _, count := range counts {
   220  		max := count * size
   221  		incr := 1
   222  		for n := 0; n <= max+incr; n += incr {
   223  			err = f(hasher, data, n, count)
   224  			if err != nil {
   225  				return err
   226  			}
   227  		}
   228  	}
   229  	return nil
   230  }
   231  
   232  func TestHasherReuseWithoutRelease(t *testing.T) {
   233  	testHasherReuse(1, t)
   234  }
   235  
   236  func TestHasherReuseWithRelease(t *testing.T) {
   237  	testHasherReuse(maxproccnt, t)
   238  }
   239  
   240  func testHasherReuse(i int, t *testing.T) {
   241  	hasher := sha3.NewKeccak256
   242  	pool := NewTreePool(hasher, 128, i)
   243  	defer pool.Drain(0)
   244  	bmt := New(pool)
   245  
   246  	for i := 0; i < 500; i++ {
   247  		n := rand.Intn(4096)
   248  		tdata := testDataReader(n)
   249  		data := make([]byte, n)
   250  		tdata.Read(data)
   251  
   252  		err := testHasherCorrectness(bmt, hasher, data, n, 128)
   253  		if err != nil {
   254  			t.Fatal(err)
   255  		}
   256  	}
   257  }
   258  
   259  func TestHasherConcurrency(t *testing.T) {
   260  	hasher := sha3.NewKeccak256
   261  	pool := NewTreePool(hasher, 128, maxproccnt)
   262  	defer pool.Drain(0)
   263  	wg := sync.WaitGroup{}
   264  	cycles := 100
   265  	wg.Add(maxproccnt * cycles)
   266  	errc := make(chan error)
   267  
   268  	for p := 0; p < maxproccnt; p++ {
   269  		for i := 0; i < cycles; i++ {
   270  			go func() {
   271  				bmt := New(pool)
   272  				n := rand.Intn(4096)
   273  				tdata := testDataReader(n)
   274  				data := make([]byte, n)
   275  				tdata.Read(data)
   276  				err := testHasherCorrectness(bmt, hasher, data, n, 128)
   277  				wg.Done()
   278  				if err != nil {
   279  					errc <- err
   280  				}
   281  			}()
   282  		}
   283  	}
   284  	go func() {
   285  		wg.Wait()
   286  		close(errc)
   287  	}()
   288  	var err error
   289  	select {
   290  	case <-time.NewTimer(5 * time.Second).C:
   291  		err = fmt.Errorf("timed out")
   292  	case err = <-errc:
   293  	}
   294  	if err != nil {
   295  		t.Fatal(err)
   296  	}
   297  }
   298  
   299  func testBaseHasher(hasher BaseHasher, d []byte, n, count int) error {
   300  	pool := NewTreePool(hasher, count, 1)
   301  	defer pool.Drain(0)
   302  	bmt := New(pool)
   303  	return testHasherCorrectness(bmt, hasher, d, n, count)
   304  }
   305  
   306  func testHasherCorrectness(bmt hash.Hash, hasher BaseHasher, d []byte, n, count int) (err error) {
   307  	data := d[:n]
   308  	rbmt := NewRefHasher(hasher, count)
   309  	exp := rbmt.Hash(data)
   310  	timeout := time.NewTimer(time.Second)
   311  	c := make(chan error)
   312  
   313  	go func() {
   314  		bmt.Reset()
   315  		bmt.Write(data)
   316  		got := bmt.Sum(nil)
   317  		if !bytes.Equal(got, exp) {
   318  			c <- fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   319  		}
   320  		close(c)
   321  	}()
   322  	select {
   323  	case <-timeout.C:
   324  		err = fmt.Errorf("BMT hash calculation timed out")
   325  	case err = <-c:
   326  	}
   327  	return err
   328  }
   329  
   330  func BenchmarkSHA3_4k(t *testing.B)   { benchmarkSHA3(4096, t) }
   331  func BenchmarkSHA3_2k(t *testing.B)   { benchmarkSHA3(4096/2, t) }
   332  func BenchmarkSHA3_1k(t *testing.B)   { benchmarkSHA3(4096/4, t) }
   333  func BenchmarkSHA3_512b(t *testing.B) { benchmarkSHA3(4096/8, t) }
   334  func BenchmarkSHA3_256b(t *testing.B) { benchmarkSHA3(4096/16, t) }
   335  func BenchmarkSHA3_128b(t *testing.B) { benchmarkSHA3(4096/32, t) }
   336  
   337  func BenchmarkBMTBaseline_4k(t *testing.B)   { benchmarkBMTBaseline(4096, t) }
   338  func BenchmarkBMTBaseline_2k(t *testing.B)   { benchmarkBMTBaseline(4096/2, t) }
   339  func BenchmarkBMTBaseline_1k(t *testing.B)   { benchmarkBMTBaseline(4096/4, t) }
   340  func BenchmarkBMTBaseline_512b(t *testing.B) { benchmarkBMTBaseline(4096/8, t) }
   341  func BenchmarkBMTBaseline_256b(t *testing.B) { benchmarkBMTBaseline(4096/16, t) }
   342  func BenchmarkBMTBaseline_128b(t *testing.B) { benchmarkBMTBaseline(4096/32, t) }
   343  
   344  func BenchmarkRefHasher_4k(t *testing.B)   { benchmarkRefHasher(4096, t) }
   345  func BenchmarkRefHasher_2k(t *testing.B)   { benchmarkRefHasher(4096/2, t) }
   346  func BenchmarkRefHasher_1k(t *testing.B)   { benchmarkRefHasher(4096/4, t) }
   347  func BenchmarkRefHasher_512b(t *testing.B) { benchmarkRefHasher(4096/8, t) }
   348  func BenchmarkRefHasher_256b(t *testing.B) { benchmarkRefHasher(4096/16, t) }
   349  func BenchmarkRefHasher_128b(t *testing.B) { benchmarkRefHasher(4096/32, t) }
   350  
   351  func BenchmarkHasher_4k(t *testing.B)   { benchmarkHasher(4096, t) }
   352  func BenchmarkHasher_2k(t *testing.B)   { benchmarkHasher(4096/2, t) }
   353  func BenchmarkHasher_1k(t *testing.B)   { benchmarkHasher(4096/4, t) }
   354  func BenchmarkHasher_512b(t *testing.B) { benchmarkHasher(4096/8, t) }
   355  func BenchmarkHasher_256b(t *testing.B) { benchmarkHasher(4096/16, t) }
   356  func BenchmarkHasher_128b(t *testing.B) { benchmarkHasher(4096/32, t) }
   357  
   358  func BenchmarkHasherNoReuse_4k(t *testing.B)   { benchmarkHasherReuse(1, 4096, t) }
   359  func BenchmarkHasherNoReuse_2k(t *testing.B)   { benchmarkHasherReuse(1, 4096/2, t) }
   360  func BenchmarkHasherNoReuse_1k(t *testing.B)   { benchmarkHasherReuse(1, 4096/4, t) }
   361  func BenchmarkHasherNoReuse_512b(t *testing.B) { benchmarkHasherReuse(1, 4096/8, t) }
   362  func BenchmarkHasherNoReuse_256b(t *testing.B) { benchmarkHasherReuse(1, 4096/16, t) }
   363  func BenchmarkHasherNoReuse_128b(t *testing.B) { benchmarkHasherReuse(1, 4096/32, t) }
   364  
   365  func BenchmarkHasherReuse_4k(t *testing.B)   { benchmarkHasherReuse(16, 4096, t) }
   366  func BenchmarkHasherReuse_2k(t *testing.B)   { benchmarkHasherReuse(16, 4096/2, t) }
   367  func BenchmarkHasherReuse_1k(t *testing.B)   { benchmarkHasherReuse(16, 4096/4, t) }
   368  func BenchmarkHasherReuse_512b(t *testing.B) { benchmarkHasherReuse(16, 4096/8, t) }
   369  func BenchmarkHasherReuse_256b(t *testing.B) { benchmarkHasherReuse(16, 4096/16, t) }
   370  func BenchmarkHasherReuse_128b(t *testing.B) { benchmarkHasherReuse(16, 4096/32, t) }
   371  
   372  // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
   373  // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
   374  // doing it on n maxproccnt each reusing the base hasher
   375  // the premise is that this is the minimum computation needed for a BMT
   376  // therefore this serves as a theoretical optimum for concurrent implementations
   377  func benchmarkBMTBaseline(n int, t *testing.B) {
   378  	tdata := testDataReader(64)
   379  	data := make([]byte, 64)
   380  	tdata.Read(data)
   381  	hasher := sha3.NewKeccak256
   382  
   383  	t.ReportAllocs()
   384  	t.ResetTimer()
   385  	for i := 0; i < t.N; i++ {
   386  		count := int32((n-1)/hasher().Size() + 1)
   387  		wg := sync.WaitGroup{}
   388  		wg.Add(maxproccnt)
   389  		var i int32
   390  		for j := 0; j < maxproccnt; j++ {
   391  			go func() {
   392  				defer wg.Done()
   393  				h := hasher()
   394  				for atomic.AddInt32(&i, 1) < count {
   395  					h.Reset()
   396  					h.Write(data)
   397  					h.Sum(nil)
   398  				}
   399  			}()
   400  		}
   401  		wg.Wait()
   402  	}
   403  }
   404  
   405  func benchmarkHasher(n int, t *testing.B) {
   406  	tdata := testDataReader(n)
   407  	data := make([]byte, n)
   408  	tdata.Read(data)
   409  
   410  	size := 1
   411  	hasher := sha3.NewKeccak256
   412  	segmentCount := 128
   413  	pool := NewTreePool(hasher, segmentCount, size)
   414  	bmt := New(pool)
   415  
   416  	t.ReportAllocs()
   417  	t.ResetTimer()
   418  	for i := 0; i < t.N; i++ {
   419  		bmt.Reset()
   420  		bmt.Write(data)
   421  		bmt.Sum(nil)
   422  	}
   423  }
   424  
   425  func benchmarkHasherReuse(poolsize, n int, t *testing.B) {
   426  	tdata := testDataReader(n)
   427  	data := make([]byte, n)
   428  	tdata.Read(data)
   429  
   430  	hasher := sha3.NewKeccak256
   431  	segmentCount := 128
   432  	pool := NewTreePool(hasher, segmentCount, poolsize)
   433  	cycles := 200
   434  
   435  	t.ReportAllocs()
   436  	t.ResetTimer()
   437  	for i := 0; i < t.N; i++ {
   438  		wg := sync.WaitGroup{}
   439  		wg.Add(cycles)
   440  		for j := 0; j < cycles; j++ {
   441  			bmt := New(pool)
   442  			go func() {
   443  				defer wg.Done()
   444  				bmt.Reset()
   445  				bmt.Write(data)
   446  				bmt.Sum(nil)
   447  			}()
   448  		}
   449  		wg.Wait()
   450  	}
   451  }
   452  
   453  func benchmarkSHA3(n int, t *testing.B) {
   454  	data := make([]byte, n)
   455  	tdata := testDataReader(n)
   456  	tdata.Read(data)
   457  	hasher := sha3.NewKeccak256
   458  	h := hasher()
   459  
   460  	t.ReportAllocs()
   461  	t.ResetTimer()
   462  	for i := 0; i < t.N; i++ {
   463  		h.Reset()
   464  		h.Write(data)
   465  		h.Sum(nil)
   466  	}
   467  }
   468  
   469  func benchmarkRefHasher(n int, t *testing.B) {
   470  	data := make([]byte, n)
   471  	tdata := testDataReader(n)
   472  	tdata.Read(data)
   473  	hasher := sha3.NewKeccak256
   474  	rbmt := NewRefHasher(hasher, 128)
   475  
   476  	t.ReportAllocs()
   477  	t.ResetTimer()
   478  	for i := 0; i < t.N; i++ {
   479  		rbmt.Hash(data)
   480  	}
   481  }