github.com/gobitfly/go-ethereum@v1.8.12/swarm/bmt/bmt_test.go (about)

     1  // Copyright 2017 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package bmt
    18  
    19  import (
    20  	"bytes"
    21  	crand "crypto/rand"
    22  	"encoding/binary"
    23  	"fmt"
    24  	"io"
    25  	"math/rand"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/ethereum/go-ethereum/crypto/sha3"
    32  )
    33  
// BufferSize is the actual data length generated for the tests
// (could be longer than the max datalength of the BMT)
const BufferSize = 4128
    36  
    37  func sha3hash(data ...[]byte) []byte {
    38  	h := sha3.NewKeccak256()
    39  	for _, v := range data {
    40  		h.Write(v)
    41  	}
    42  	return h.Sum(nil)
    43  }
    44  
// TestRefHasher tests that the RefHasher computes the expected BMT hash for
// all data lengths between 0 and 256 bytes
func TestRefHasher(t *testing.T) {

	// the test struct is used to specify the expected BMT hash for
	// segment counts between from and to and lengths from 1 to datalength
	type test struct {
		from     int
		to       int
		expected func([]byte) []byte
	}

	var tests []*test
	// segment counts 1 and 2 (lengths in [1,64], zero-padded to a full
	// 64-byte section) should hash to:
	//
	//   sha3hash(data)
	//
	tests = append(tests, &test{
		from: 1,
		to:   2,
		expected: func(d []byte) []byte {
			data := make([]byte, 64)
			copy(data, d)
			return sha3hash(data)
		},
	})

	// segment counts 3 and 4 (lengths in [65,128], zero-padded to 128
	// bytes) should hash to:
	//
	//   sha3hash(
	//     sha3hash(data[:64])
	//     sha3hash(data[64:])
	//   )
	//
	tests = append(tests, &test{
		from: 3,
		to:   4,
		expected: func(d []byte) []byte {
			data := make([]byte, 128)
			copy(data, d)
			return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
		},
	})

	// all segmentCounts in [5,8] (zero-padded to 256 bytes) should be:
	//
	//   sha3hash(
	//     sha3hash(
	//       sha3hash(data[:64])
	//       sha3hash(data[64:128])
	//     )
	//     sha3hash(
	//       sha3hash(data[128:192])
	//       sha3hash(data[192:])
	//     )
	//   )
	//
	tests = append(tests, &test{
		from: 5,
		to:   8,
		expected: func(d []byte) []byte {
			data := make([]byte, 256)
			copy(data, d)
			return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
		},
	})

	// run the tests: for each covered segment count, try every data length
	// from 1 byte up to the tree's full capacity (segmentCount * 32 bytes)
	for _, x := range tests {
		for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
			for length := 1; length <= segmentCount*32; length++ {
				t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
					data := make([]byte, length)
					// io.ReadFull only returns io.EOF for a zero-length read,
					// so the EOF guard is effectively a no-op for length >= 1
					if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
						t.Fatal(err)
					}
					expected := x.expected(data)
					actual := NewRefHasher(sha3.NewKeccak256, segmentCount).Hash(data)
					if !bytes.Equal(actual, expected) {
						t.Fatalf("expected %x, got %x", expected, actual)
					}
				})
			}
		}
	}
}
   131  
   132  func TestHasherCorrectness(t *testing.T) {
   133  	err := testHasher(testBaseHasher)
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  }
   138  
   139  func testHasher(f func(BaseHasherFunc, []byte, int, int) error) error {
   140  	data := newData(BufferSize)
   141  	hasher := sha3.NewKeccak256
   142  	size := hasher().Size()
   143  	counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128}
   144  
   145  	var err error
   146  	for _, count := range counts {
   147  		max := count * size
   148  		incr := 1
   149  		for n := 1; n <= max; n += incr {
   150  			err = f(hasher, data, n, count)
   151  			if err != nil {
   152  				return err
   153  			}
   154  		}
   155  	}
   156  	return nil
   157  }
   158  
   159  // Tests that the BMT hasher can be synchronously reused with poolsizes 1 and PoolSize
   160  func TestHasherReuse(t *testing.T) {
   161  	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
   162  		testHasherReuse(1, t)
   163  	})
   164  	t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
   165  		testHasherReuse(PoolSize, t)
   166  	})
   167  }
   168  
   169  func testHasherReuse(poolsize int, t *testing.T) {
   170  	hasher := sha3.NewKeccak256
   171  	pool := NewTreePool(hasher, SegmentCount, poolsize)
   172  	defer pool.Drain(0)
   173  	bmt := New(pool)
   174  
   175  	for i := 0; i < 100; i++ {
   176  		data := newData(BufferSize)
   177  		n := rand.Intn(bmt.DataLength())
   178  		err := testHasherCorrectness(bmt, hasher, data, n, SegmentCount)
   179  		if err != nil {
   180  			t.Fatal(err)
   181  		}
   182  	}
   183  }
   184  
   185  // Tests if pool can be cleanly reused even in concurrent use
   186  func TestBMTHasherConcurrentUse(t *testing.T) {
   187  	hasher := sha3.NewKeccak256
   188  	pool := NewTreePool(hasher, SegmentCount, PoolSize)
   189  	defer pool.Drain(0)
   190  	cycles := 100
   191  	errc := make(chan error)
   192  
   193  	for i := 0; i < cycles; i++ {
   194  		go func() {
   195  			bmt := New(pool)
   196  			data := newData(BufferSize)
   197  			n := rand.Intn(bmt.DataLength())
   198  			errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
   199  		}()
   200  	}
   201  LOOP:
   202  	for {
   203  		select {
   204  		case <-time.NewTimer(5 * time.Second).C:
   205  			t.Fatal("timed out")
   206  		case err := <-errc:
   207  			if err != nil {
   208  				t.Fatal(err)
   209  			}
   210  			cycles--
   211  			if cycles == 0 {
   212  				break LOOP
   213  			}
   214  		}
   215  	}
   216  }
   217  
   218  // helper function that creates  a tree pool
   219  func testBaseHasher(hasher BaseHasherFunc, d []byte, n, count int) error {
   220  	pool := NewTreePool(hasher, count, 1)
   221  	defer pool.Drain(0)
   222  	bmt := New(pool)
   223  	return testHasherCorrectness(bmt, hasher, d, n, count)
   224  }
   225  
   226  // helper function that compares reference and optimised implementations on
   227  // correctness
   228  func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
   229  	span := make([]byte, 8)
   230  	if len(d) < n {
   231  		n = len(d)
   232  	}
   233  	binary.BigEndian.PutUint64(span, uint64(n))
   234  	data := d[:n]
   235  	rbmt := NewRefHasher(hasher, count)
   236  	exp := sha3hash(span, rbmt.Hash(data))
   237  	got := Hash(bmt, span, data)
   238  	if !bytes.Equal(got, exp) {
   239  		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   240  	}
   241  	return err
   242  }
   243  
// flat Keccak256 hash over the whole input, the lower bound for any hasher
func BenchmarkSHA3_4k(t *testing.B)   { benchmarkSHA3(4096, t) }
func BenchmarkSHA3_2k(t *testing.B)   { benchmarkSHA3(4096/2, t) }
func BenchmarkSHA3_1k(t *testing.B)   { benchmarkSHA3(4096/4, t) }
func BenchmarkSHA3_512b(t *testing.B) { benchmarkSHA3(4096/8, t) }
func BenchmarkSHA3_256b(t *testing.B) { benchmarkSHA3(4096/16, t) }
func BenchmarkSHA3_128b(t *testing.B) { benchmarkSHA3(4096/32, t) }

// theoretical optimum for a concurrent BMT: segment hashes done in parallel
func BenchmarkBMTBaseline_4k(t *testing.B)   { benchmarkBMTBaseline(4096, t) }
func BenchmarkBMTBaseline_2k(t *testing.B)   { benchmarkBMTBaseline(4096/2, t) }
func BenchmarkBMTBaseline_1k(t *testing.B)   { benchmarkBMTBaseline(4096/4, t) }
func BenchmarkBMTBaseline_512b(t *testing.B) { benchmarkBMTBaseline(4096/8, t) }
func BenchmarkBMTBaseline_256b(t *testing.B) { benchmarkBMTBaseline(4096/16, t) }
func BenchmarkBMTBaseline_128b(t *testing.B) { benchmarkBMTBaseline(4096/32, t) }

// sequential reference BMT implementation
func BenchmarkRefHasher_4k(t *testing.B)   { benchmarkRefHasher(4096, t) }
func BenchmarkRefHasher_2k(t *testing.B)   { benchmarkRefHasher(4096/2, t) }
func BenchmarkRefHasher_1k(t *testing.B)   { benchmarkRefHasher(4096/4, t) }
func BenchmarkRefHasher_512b(t *testing.B) { benchmarkRefHasher(4096/8, t) }
func BenchmarkRefHasher_256b(t *testing.B) { benchmarkRefHasher(4096/16, t) }
func BenchmarkRefHasher_128b(t *testing.B) { benchmarkRefHasher(4096/32, t) }

// optimised BMT hasher, sequential use from a shared pool
func BenchmarkBMTHasher_4k(t *testing.B)   { benchmarkBMTHasher(4096, t) }
func BenchmarkBMTHasher_2k(t *testing.B)   { benchmarkBMTHasher(4096/2, t) }
func BenchmarkBMTHasher_1k(t *testing.B)   { benchmarkBMTHasher(4096/4, t) }
func BenchmarkBMTHasher_512b(t *testing.B) { benchmarkBMTHasher(4096/8, t) }
func BenchmarkBMTHasher_256b(t *testing.B) { benchmarkBMTHasher(4096/16, t) }
func BenchmarkBMTHasher_128b(t *testing.B) { benchmarkBMTHasher(4096/32, t) }

// concurrent hashes contending on a pool of capacity 1 (no reuse)
func BenchmarkBMTHasherNoPool_4k(t *testing.B)   { benchmarkBMTHasherPool(1, 4096, t) }
func BenchmarkBMTHasherNoPool_2k(t *testing.B)   { benchmarkBMTHasherPool(1, 4096/2, t) }
func BenchmarkBMTHasherNoPool_1k(t *testing.B)   { benchmarkBMTHasherPool(1, 4096/4, t) }
func BenchmarkBMTHasherNoPool_512b(t *testing.B) { benchmarkBMTHasherPool(1, 4096/8, t) }
func BenchmarkBMTHasherNoPool_256b(t *testing.B) { benchmarkBMTHasherPool(1, 4096/16, t) }
func BenchmarkBMTHasherNoPool_128b(t *testing.B) { benchmarkBMTHasherPool(1, 4096/32, t) }

// concurrent hashes with a full-capacity pool
func BenchmarkBMTHasherPool_4k(t *testing.B)   { benchmarkBMTHasherPool(PoolSize, 4096, t) }
func BenchmarkBMTHasherPool_2k(t *testing.B)   { benchmarkBMTHasherPool(PoolSize, 4096/2, t) }
func BenchmarkBMTHasherPool_1k(t *testing.B)   { benchmarkBMTHasherPool(PoolSize, 4096/4, t) }
func BenchmarkBMTHasherPool_512b(t *testing.B) { benchmarkBMTHasherPool(PoolSize, 4096/8, t) }
func BenchmarkBMTHasherPool_256b(t *testing.B) { benchmarkBMTHasherPool(PoolSize, 4096/16, t) }
func BenchmarkBMTHasherPool_128b(t *testing.B) { benchmarkBMTHasherPool(PoolSize, 4096/32, t) }
   285  
   286  // benchmarks simple sha3 hash on chunks
   287  func benchmarkSHA3(n int, t *testing.B) {
   288  	data := newData(n)
   289  	hasher := sha3.NewKeccak256
   290  	h := hasher()
   291  
   292  	t.ReportAllocs()
   293  	t.ResetTimer()
   294  	for i := 0; i < t.N; i++ {
   295  		h.Reset()
   296  		h.Write(data)
   297  		h.Sum(nil)
   298  	}
   299  }
   300  
   301  // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
   302  // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
   303  // doing it on n PoolSize each reusing the base hasher
   304  // the premise is that this is the minimum computation needed for a BMT
   305  // therefore this serves as a theoretical optimum for concurrent implementations
   306  func benchmarkBMTBaseline(n int, t *testing.B) {
   307  	hasher := sha3.NewKeccak256
   308  	hashSize := hasher().Size()
   309  	data := newData(hashSize)
   310  
   311  	t.ReportAllocs()
   312  	t.ResetTimer()
   313  	for i := 0; i < t.N; i++ {
   314  		count := int32((n-1)/hashSize + 1)
   315  		wg := sync.WaitGroup{}
   316  		wg.Add(PoolSize)
   317  		var i int32
   318  		for j := 0; j < PoolSize; j++ {
   319  			go func() {
   320  				defer wg.Done()
   321  				h := hasher()
   322  				for atomic.AddInt32(&i, 1) < count {
   323  					h.Reset()
   324  					h.Write(data)
   325  					h.Sum(nil)
   326  				}
   327  			}()
   328  		}
   329  		wg.Wait()
   330  	}
   331  }
   332  
   333  // benchmarks BMT Hasher
   334  func benchmarkBMTHasher(n int, t *testing.B) {
   335  	data := newData(n)
   336  	hasher := sha3.NewKeccak256
   337  	pool := NewTreePool(hasher, SegmentCount, PoolSize)
   338  
   339  	t.ReportAllocs()
   340  	t.ResetTimer()
   341  	for i := 0; i < t.N; i++ {
   342  		bmt := New(pool)
   343  		Hash(bmt, nil, data)
   344  	}
   345  }
   346  
   347  // benchmarks 100 concurrent bmt hashes with pool capacity
   348  func benchmarkBMTHasherPool(poolsize, n int, t *testing.B) {
   349  	data := newData(n)
   350  	hasher := sha3.NewKeccak256
   351  	pool := NewTreePool(hasher, SegmentCount, poolsize)
   352  	cycles := 100
   353  
   354  	t.ReportAllocs()
   355  	t.ResetTimer()
   356  	wg := sync.WaitGroup{}
   357  	for i := 0; i < t.N; i++ {
   358  		wg.Add(cycles)
   359  		for j := 0; j < cycles; j++ {
   360  			go func() {
   361  				defer wg.Done()
   362  				bmt := New(pool)
   363  				Hash(bmt, nil, data)
   364  			}()
   365  		}
   366  		wg.Wait()
   367  	}
   368  }
   369  
   370  // benchmarks the reference hasher
   371  func benchmarkRefHasher(n int, t *testing.B) {
   372  	data := newData(n)
   373  	hasher := sha3.NewKeccak256
   374  	rbmt := NewRefHasher(hasher, 128)
   375  
   376  	t.ReportAllocs()
   377  	t.ResetTimer()
   378  	for i := 0; i < t.N; i++ {
   379  		rbmt.Hash(data)
   380  	}
   381  }
   382  
   383  func newData(bufferSize int) []byte {
   384  	data := make([]byte, bufferSize)
   385  	_, err := io.ReadFull(crand.Reader, data)
   386  	if err != nil {
   387  		panic(err.Error())
   388  	}
   389  	return data
   390  }