github.com/linapex/ethereum-go-chinese@v0.0.0-20190316121929-f8b7a73c3fa1/swarm/bmt/bmt_test.go (about)

     1  
     2  //<developer>
     3  //    <name>linapex 曹一峰</name>
     4  //    <email>linapex@163.com</email>
     5  //    <wx>superexc</wx>
     6  //    <qqgroup>128148617</qqgroup>
     7  //    <url>https://jsq.ink</url>
     8  //    <role>pku engineer</role>
     9  //    <date>2019-03-16 19:16:43</date>
    10  //</624450112671715328>
    11  
    12  
    13  package bmt
    14  
    15  import (
    16  	"bytes"
    17  	"encoding/binary"
    18  	"fmt"
    19  	"math/rand"
    20  	"sync"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/ethereum/go-ethereum/swarm/testutil"
    26  	"golang.org/x/crypto/sha3"
    27  )
    28  
    29  //生成的实际数据长度(可能长于BMT的最大数据长度)
    30  const BufferSize = 4128
    31  
    32  const (
    33  //SegmentCount是基础块的最大段数
    34  //应等于最大块数据大小/哈希大小
    35  //当前设置为128==4096(默认块大小)/32(sha3.keccak256大小)
    36  	segmentCount = 128
    37  )
    38  
    39  var counts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
    40  
    41  //计算数据的keccak256 sha3哈希
    42  func sha3hash(data ...[]byte) []byte {
    43  	h := sha3.NewLegacyKeccak256()
    44  	return doSum(h, nil, data...)
    45  }
    46  
    47  //testrefhasher测试refhasher为其计算预期的bmt哈希
    48  //一些小数据长度
    49  func TestRefHasher(t *testing.T) {
    50  //测试结构用于指定预期的BMT哈希
    51  //从到之间的段计数和从1到数据长度的段计数
    52  	type test struct {
    53  		from     int
    54  		to       int
    55  		expected func([]byte) []byte
    56  	}
    57  
    58  	var tests []*test
    59  //[0,64]中的所有长度应为:
    60  //
    61  //SH3HASH(数据)
    62  //
    63  	tests = append(tests, &test{
    64  		from: 1,
    65  		to:   2,
    66  		expected: func(d []byte) []byte {
    67  			data := make([]byte, 64)
    68  			copy(data, d)
    69  			return sha3hash(data)
    70  		},
    71  	})
    72  
    73  //[3,4]中的所有长度应为:
    74  //
    75  //SH3HASH
    76  //sha3hash(数据[:64])
    77  //sha3hash(数据[64:]
    78  //)
    79  //
    80  	tests = append(tests, &test{
    81  		from: 3,
    82  		to:   4,
    83  		expected: func(d []byte) []byte {
    84  			data := make([]byte, 128)
    85  			copy(data, d)
    86  			return sha3hash(sha3hash(data[:64]), sha3hash(data[64:]))
    87  		},
    88  	})
    89  
    90  //[5,8]中的所有分段计数应为:
    91  //
    92  //SH3HASH
    93  //SH3HASH
    94  //sha3hash(数据[:64])
    95  //sha3hash(数据[64:128])
    96  //)
    97  //SH3HASH
    98  //sha3hash(数据[128:192])
    99  //sha3hash(数据[192:]
   100  //)
   101  //)
   102  //
   103  	tests = append(tests, &test{
   104  		from: 5,
   105  		to:   8,
   106  		expected: func(d []byte) []byte {
   107  			data := make([]byte, 256)
   108  			copy(data, d)
   109  			return sha3hash(sha3hash(sha3hash(data[:64]), sha3hash(data[64:128])), sha3hash(sha3hash(data[128:192]), sha3hash(data[192:])))
   110  		},
   111  	})
   112  
   113  //运行测试
   114  	for i, x := range tests {
   115  		for segmentCount := x.from; segmentCount <= x.to; segmentCount++ {
   116  			for length := 1; length <= segmentCount*32; length++ {
   117  				t.Run(fmt.Sprintf("%d_segments_%d_bytes", segmentCount, length), func(t *testing.T) {
   118  					data := testutil.RandomBytes(i, length)
   119  					expected := x.expected(data)
   120  					actual := NewRefHasher(sha3.NewLegacyKeccak256, segmentCount).Hash(data)
   121  					if !bytes.Equal(actual, expected) {
   122  						t.Fatalf("expected %x, got %x", expected, actual)
   123  					}
   124  				})
   125  			}
   126  		}
   127  	}
   128  }
   129  
   130  //测试哈希程序是否响应正确的哈希,比较引用实现返回值
   131  func TestHasherEmptyData(t *testing.T) {
   132  	hasher := sha3.NewLegacyKeccak256
   133  	var data []byte
   134  	for _, count := range counts {
   135  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   136  			pool := NewTreePool(hasher, count, PoolSize)
   137  			defer pool.Drain(0)
   138  			bmt := New(pool)
   139  			rbmt := NewRefHasher(hasher, count)
   140  			refHash := rbmt.Hash(data)
   141  			expHash := syncHash(bmt, nil, data)
   142  			if !bytes.Equal(expHash, refHash) {
   143  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   144  			}
   145  		})
   146  	}
   147  }
   148  
   149  //用一次性写入的整个最大大小测试顺序写入
   150  func TestSyncHasherCorrectness(t *testing.T) {
   151  	data := testutil.RandomBytes(1, BufferSize)
   152  	hasher := sha3.NewLegacyKeccak256
   153  	size := hasher().Size()
   154  
   155  	var err error
   156  	for _, count := range counts {
   157  		t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
   158  			max := count * size
   159  			var incr int
   160  			capacity := 1
   161  			pool := NewTreePool(hasher, count, capacity)
   162  			defer pool.Drain(0)
   163  			for n := 0; n <= max; n += incr {
   164  				incr = 1 + rand.Intn(5)
   165  				bmt := New(pool)
   166  				err = testHasherCorrectness(bmt, hasher, data, n, count)
   167  				if err != nil {
   168  					t.Fatal(err)
   169  				}
   170  			}
   171  		})
   172  	}
   173  }
   174  
   175  //测试顺序为非特定并发写入,一次写入整个最大大小
   176  func TestAsyncCorrectness(t *testing.T) {
   177  	data := testutil.RandomBytes(1, BufferSize)
   178  	hasher := sha3.NewLegacyKeccak256
   179  	size := hasher().Size()
   180  	whs := []whenHash{first, last, random}
   181  
   182  	for _, double := range []bool{false, true} {
   183  		for _, wh := range whs {
   184  			for _, count := range counts {
   185  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_segments_%v", double, wh, count), func(t *testing.T) {
   186  					max := count * size
   187  					var incr int
   188  					capacity := 1
   189  					pool := NewTreePool(hasher, count, capacity)
   190  					defer pool.Drain(0)
   191  					for n := 1; n <= max; n += incr {
   192  						incr = 1 + rand.Intn(5)
   193  						bmt := New(pool)
   194  						d := data[:n]
   195  						rbmt := NewRefHasher(hasher, count)
   196  						exp := rbmt.Hash(d)
   197  						got := syncHash(bmt, nil, d)
   198  						if !bytes.Equal(got, exp) {
   199  							t.Fatalf("wrong sync hash for datalength %v: expected %x (ref), got %x", n, exp, got)
   200  						}
   201  						sw := bmt.NewAsyncWriter(double)
   202  						got = asyncHashRandom(sw, nil, d, wh)
   203  						if !bytes.Equal(got, exp) {
   204  							t.Fatalf("wrong async hash for datalength %v: expected %x, got %x", n, exp, got)
   205  						}
   206  					}
   207  				})
   208  			}
   209  		}
   210  	}
   211  }
   212  
   213  //测试bmt散列器是否可以与poolsizes 1和poolsize同步重用
   214  func TestHasherReuse(t *testing.T) {
   215  	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
   216  		testHasherReuse(1, t)
   217  	})
   218  	t.Run(fmt.Sprintf("poolsize_%d", PoolSize), func(t *testing.T) {
   219  		testHasherReuse(PoolSize, t)
   220  	})
   221  }
   222  
   223  //测试BMT重用是否不会破坏结果
   224  func testHasherReuse(poolsize int, t *testing.T) {
   225  	hasher := sha3.NewLegacyKeccak256
   226  	pool := NewTreePool(hasher, segmentCount, poolsize)
   227  	defer pool.Drain(0)
   228  	bmt := New(pool)
   229  
   230  	for i := 0; i < 100; i++ {
   231  		data := testutil.RandomBytes(1, BufferSize)
   232  		n := rand.Intn(bmt.Size())
   233  		err := testHasherCorrectness(bmt, hasher, data, n, segmentCount)
   234  		if err != nil {
   235  			t.Fatal(err)
   236  		}
   237  	}
   238  }
   239  
   240  //测试池是否可以被干净地重用,即使在多个哈希程序同时使用时也是如此
   241  func TestBMTConcurrentUse(t *testing.T) {
   242  	hasher := sha3.NewLegacyKeccak256
   243  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   244  	defer pool.Drain(0)
   245  	cycles := 100
   246  	errc := make(chan error)
   247  
   248  	for i := 0; i < cycles; i++ {
   249  		go func() {
   250  			bmt := New(pool)
   251  			data := testutil.RandomBytes(1, BufferSize)
   252  			n := rand.Intn(bmt.Size())
   253  			errc <- testHasherCorrectness(bmt, hasher, data, n, 128)
   254  		}()
   255  	}
   256  LOOP:
   257  	for {
   258  		select {
   259  		case <-time.NewTimer(5 * time.Second).C:
   260  			t.Fatal("timed out")
   261  		case err := <-errc:
   262  			if err != nil {
   263  				t.Fatal(err)
   264  			}
   265  			cycles--
   266  			if cycles == 0 {
   267  				break LOOP
   268  			}
   269  		}
   270  	}
   271  }
   272  
   273  //测试bmt hasher io.writer接口是否正常工作
   274  //甚至多个短随机写缓冲区
   275  func TestBMTWriterBuffers(t *testing.T) {
   276  	hasher := sha3.NewLegacyKeccak256
   277  
   278  	for _, count := range counts {
   279  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   280  			errc := make(chan error)
   281  			pool := NewTreePool(hasher, count, PoolSize)
   282  			defer pool.Drain(0)
   283  			n := count * 32
   284  			bmt := New(pool)
   285  			data := testutil.RandomBytes(1, n)
   286  			rbmt := NewRefHasher(hasher, count)
   287  			refHash := rbmt.Hash(data)
   288  			expHash := syncHash(bmt, nil, data)
   289  			if !bytes.Equal(expHash, refHash) {
   290  				t.Fatalf("hash mismatch with reference. expected %x, got %x", refHash, expHash)
   291  			}
   292  			attempts := 10
   293  			f := func() error {
   294  				bmt := New(pool)
   295  				bmt.Reset()
   296  				var buflen int
   297  				for offset := 0; offset < n; offset += buflen {
   298  					buflen = rand.Intn(n-offset) + 1
   299  					read, err := bmt.Write(data[offset : offset+buflen])
   300  					if err != nil {
   301  						return err
   302  					}
   303  					if read != buflen {
   304  						return fmt.Errorf("incorrect read. expected %v bytes, got %v", buflen, read)
   305  					}
   306  				}
   307  				hash := bmt.Sum(nil)
   308  				if !bytes.Equal(hash, expHash) {
   309  					return fmt.Errorf("hash mismatch. expected %x, got %x", hash, expHash)
   310  				}
   311  				return nil
   312  			}
   313  
   314  			for j := 0; j < attempts; j++ {
   315  				go func() {
   316  					errc <- f()
   317  				}()
   318  			}
   319  			timeout := time.NewTimer(2 * time.Second)
   320  			for {
   321  				select {
   322  				case err := <-errc:
   323  					if err != nil {
   324  						t.Fatal(err)
   325  					}
   326  					attempts--
   327  					if attempts == 0 {
   328  						return
   329  					}
   330  				case <-timeout.C:
   331  					t.Fatalf("timeout")
   332  				}
   333  			}
   334  		})
   335  	}
   336  }
   337  
   338  //比较引用和优化实现的helper函数
   339  //正确性
   340  func testHasherCorrectness(bmt *Hasher, hasher BaseHasherFunc, d []byte, n, count int) (err error) {
   341  	span := make([]byte, 8)
   342  	if len(d) < n {
   343  		n = len(d)
   344  	}
   345  	binary.BigEndian.PutUint64(span, uint64(n))
   346  	data := d[:n]
   347  	rbmt := NewRefHasher(hasher, count)
   348  	exp := sha3hash(span, rbmt.Hash(data))
   349  	got := syncHash(bmt, span, data)
   350  	if !bytes.Equal(got, exp) {
   351  		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   352  	}
   353  	return err
   354  }
   355  
   356  //
   357  func BenchmarkBMT(t *testing.B) {
   358  	for size := 4096; size >= 128; size /= 2 {
   359  		t.Run(fmt.Sprintf("%v_size_%v", "SHA3", size), func(t *testing.B) {
   360  			benchmarkSHA3(t, size)
   361  		})
   362  		t.Run(fmt.Sprintf("%v_size_%v", "Baseline", size), func(t *testing.B) {
   363  			benchmarkBMTBaseline(t, size)
   364  		})
   365  		t.Run(fmt.Sprintf("%v_size_%v", "REF", size), func(t *testing.B) {
   366  			benchmarkRefHasher(t, size)
   367  		})
   368  		t.Run(fmt.Sprintf("%v_size_%v", "BMT", size), func(t *testing.B) {
   369  			benchmarkBMT(t, size)
   370  		})
   371  	}
   372  }
   373  
   374  type whenHash = int
   375  
   376  const (
   377  	first whenHash = iota
   378  	last
   379  	random
   380  )
   381  
   382  func BenchmarkBMTAsync(t *testing.B) {
   383  	whs := []whenHash{first, last, random}
   384  	for size := 4096; size >= 128; size /= 2 {
   385  		for _, wh := range whs {
   386  			for _, double := range []bool{false, true} {
   387  				t.Run(fmt.Sprintf("double_%v_hash_when_%v_size_%v", double, wh, size), func(t *testing.B) {
   388  					benchmarkBMTAsync(t, size, wh, double)
   389  				})
   390  			}
   391  		}
   392  	}
   393  }
   394  
   395  func BenchmarkPool(t *testing.B) {
   396  	caps := []int{1, PoolSize}
   397  	for size := 4096; size >= 128; size /= 2 {
   398  		for _, c := range caps {
   399  			t.Run(fmt.Sprintf("poolsize_%v_size_%v", c, size), func(t *testing.B) {
   400  				benchmarkPool(t, c, size)
   401  			})
   402  		}
   403  	}
   404  }
   405  
   406  //块上的简单sha3哈希基准
   407  func benchmarkSHA3(t *testing.B, n int) {
   408  	data := testutil.RandomBytes(1, n)
   409  	hasher := sha3.NewLegacyKeccak256
   410  	h := hasher()
   411  
   412  	t.ReportAllocs()
   413  	t.ResetTimer()
   414  	for i := 0; i < t.N; i++ {
   415  		doSum(h, nil, data)
   416  	}
   417  }
   418  
   419  //为平衡(为了简单起见)BMT设定最小哈希时间基准
   420  //通过计算/分段大小,并行散列2*分段大小字节
   421  //在n池上执行此操作,使每个池都重新使用基哈希表
   422  //前提是这是BMT所需的最小计算量
   423  //因此,这在理论上是并发实现的最佳选择。
   424  func benchmarkBMTBaseline(t *testing.B, n int) {
   425  	hasher := sha3.NewLegacyKeccak256
   426  	hashSize := hasher().Size()
   427  	data := testutil.RandomBytes(1, hashSize)
   428  
   429  	t.ReportAllocs()
   430  	t.ResetTimer()
   431  	for i := 0; i < t.N; i++ {
   432  		count := int32((n-1)/hashSize + 1)
   433  		wg := sync.WaitGroup{}
   434  		wg.Add(PoolSize)
   435  		var i int32
   436  		for j := 0; j < PoolSize; j++ {
   437  			go func() {
   438  				defer wg.Done()
   439  				h := hasher()
   440  				for atomic.AddInt32(&i, 1) < count {
   441  					doSum(h, nil, data)
   442  				}
   443  			}()
   444  		}
   445  		wg.Wait()
   446  	}
   447  }
   448  
   449  //基准BMT哈希
   450  func benchmarkBMT(t *testing.B, n int) {
   451  	data := testutil.RandomBytes(1, n)
   452  	hasher := sha3.NewLegacyKeccak256
   453  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   454  	bmt := New(pool)
   455  
   456  	t.ReportAllocs()
   457  	t.ResetTimer()
   458  	for i := 0; i < t.N; i++ {
   459  		syncHash(bmt, nil, data)
   460  	}
   461  }
   462  
   463  //具有异步并发段/节写入的基准BMT哈希器
   464  func benchmarkBMTAsync(t *testing.B, n int, wh whenHash, double bool) {
   465  	data := testutil.RandomBytes(1, n)
   466  	hasher := sha3.NewLegacyKeccak256
   467  	pool := NewTreePool(hasher, segmentCount, PoolSize)
   468  	bmt := New(pool).NewAsyncWriter(double)
   469  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   470  	rand.Shuffle(len(idxs), func(i int, j int) {
   471  		idxs[i], idxs[j] = idxs[j], idxs[i]
   472  	})
   473  
   474  	t.ReportAllocs()
   475  	t.ResetTimer()
   476  	for i := 0; i < t.N; i++ {
   477  		asyncHash(bmt, nil, n, wh, idxs, segments)
   478  	}
   479  }
   480  
   481  //基准100个具有池容量的并发BMT哈希
   482  func benchmarkPool(t *testing.B, poolsize, n int) {
   483  	data := testutil.RandomBytes(1, n)
   484  	hasher := sha3.NewLegacyKeccak256
   485  	pool := NewTreePool(hasher, segmentCount, poolsize)
   486  	cycles := 100
   487  
   488  	t.ReportAllocs()
   489  	t.ResetTimer()
   490  	wg := sync.WaitGroup{}
   491  	for i := 0; i < t.N; i++ {
   492  		wg.Add(cycles)
   493  		for j := 0; j < cycles; j++ {
   494  			go func() {
   495  				defer wg.Done()
   496  				bmt := New(pool)
   497  				syncHash(bmt, nil, data)
   498  			}()
   499  		}
   500  		wg.Wait()
   501  	}
   502  }
   503  
   504  //基准参考哈希
   505  func benchmarkRefHasher(t *testing.B, n int) {
   506  	data := testutil.RandomBytes(1, n)
   507  	hasher := sha3.NewLegacyKeccak256
   508  	rbmt := NewRefHasher(hasher, 128)
   509  
   510  	t.ReportAllocs()
   511  	t.ResetTimer()
   512  	for i := 0; i < t.N; i++ {
   513  		rbmt.Hash(data)
   514  	}
   515  }
   516  
   517  //散列使用bmt散列器散列数据和范围
   518  func syncHash(h *Hasher, span, data []byte) []byte {
   519  	h.ResetWithLength(span)
   520  	h.Write(data)
   521  	return h.Sum(nil)
   522  }
   523  
   524  func splitAndShuffle(secsize int, data []byte) (idxs []int, segments [][]byte) {
   525  	l := len(data)
   526  	n := l / secsize
   527  	if l%secsize > 0 {
   528  		n++
   529  	}
   530  	for i := 0; i < n; i++ {
   531  		idxs = append(idxs, i)
   532  		end := (i + 1) * secsize
   533  		if end > l {
   534  			end = l
   535  		}
   536  		section := data[i*secsize : end]
   537  		segments = append(segments, section)
   538  	}
   539  	rand.Shuffle(n, func(i int, j int) {
   540  		idxs[i], idxs[j] = idxs[j], idxs[i]
   541  	})
   542  	return idxs, segments
   543  }
   544  
   545  //拆分输入数据执行随机无序移动以模拟异步节写入
   546  func asyncHashRandom(bmt SectionWriter, span []byte, data []byte, wh whenHash) (s []byte) {
   547  	idxs, segments := splitAndShuffle(bmt.SectionSize(), data)
   548  	return asyncHash(bmt, span, len(data), wh, idxs, segments)
   549  }
   550  
   551  //模拟bmt sectionwriter的异步节写入
   552  //需要对段的所有索引的列表进行排列(随机无序排列)
   553  //把它们写在适当的部分
   554  //根据wh参数调用sum函数(第一个、最后一个、随机[相对于段写入])
   555  func asyncHash(bmt SectionWriter, span []byte, l int, wh whenHash, idxs []int, segments [][]byte) (s []byte) {
   556  	bmt.Reset()
   557  	if l == 0 {
   558  		return bmt.Sum(nil, l, span)
   559  	}
   560  	c := make(chan []byte, 1)
   561  	hashf := func() {
   562  		c <- bmt.Sum(nil, l, span)
   563  	}
   564  	maxsize := len(idxs)
   565  	var r int
   566  	if wh == random {
   567  		r = rand.Intn(maxsize)
   568  	}
   569  	for i, idx := range idxs {
   570  		bmt.Write(idx, segments[idx])
   571  		if (wh == first || wh == random) && i == r {
   572  			go hashf()
   573  		}
   574  	}
   575  	if wh == last {
   576  		return bmt.Sum(nil, l, span)
   577  	}
   578  	return <-c
   579  }
   580