github.com/ethersphere/bee/v2@v2.2.0/pkg/bmt/bmt_test.go (about)

     1  // Copyright 2021 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bmt_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"fmt"
    11  	"math/rand"
    12  	"sort"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/ethersphere/bee/v2/pkg/bmt"
    17  	"github.com/ethersphere/bee/v2/pkg/bmt/reference"
    18  	"github.com/ethersphere/bee/v2/pkg/swarm"
    19  	"github.com/ethersphere/bee/v2/pkg/util/testutil"
    20  	"golang.org/x/sync/errgroup"
    21  )
    22  
const (
	// testPoolSize is the number of bmt trees the pool keeps when idle,
	// i.e. the pool capacity used by the concurrency tests below.
	testPoolSize = 16
	// segmentCount is the maximum number of segments of the underlying chunk
	// Should be equal to max-chunk-data-size / hash-size
	// Currently set to 128 == 4096 (default chunk size) / 32 (sha3.keccak256 size)
	testSegmentCount = 128
)
    31  
var (
	// testSegmentCounts enumerates tree sizes (in segments) exercised by the
	// correctness tests, covering powers of two, off-by-one neighbours and
	// odd values up to the full chunk capacity of 128.
	testSegmentCounts = []int{1, 2, 3, 4, 5, 8, 9, 15, 16, 17, 32, 37, 42, 53, 63, 64, 65, 111, 127, 128}
	// hashSize is the digest size of the base hasher (32 for keccak256).
	hashSize = swarm.NewHasher().Size()
	// seed makes random test data reproducible; it is printed on failure.
	seed = time.Now().Unix()
)
    37  
    38  func refHash(count int, data []byte) ([]byte, error) {
    39  	rbmt := reference.NewRefHasher(swarm.NewHasher(), count)
    40  	refNoMetaHash, err := rbmt.Hash(data)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	return bmt.Sha3hash(bmt.LengthToSpan(int64(len(data))), refNoMetaHash)
    45  }
    46  
    47  // syncHash hashes the data and the span using the bmt hasher
    48  func syncHash(h *bmt.Hasher, data []byte) ([]byte, error) {
    49  	h.Reset()
    50  	h.SetHeaderInt64(int64(len(data)))
    51  	_, err := h.Write(data)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	return h.Hash(nil)
    56  }
    57  
    58  // tests if hasher responds with correct hash comparing the reference implementation return value
    59  func TestHasherEmptyData(t *testing.T) {
    60  	t.Parallel()
    61  
    62  	for _, count := range testSegmentCounts {
    63  		count := count
    64  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
    65  			t.Parallel()
    66  
    67  			expHash, err := refHash(count, nil)
    68  			if err != nil {
    69  				t.Fatal(err)
    70  			}
    71  			pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, 1))
    72  			h := pool.Get()
    73  			resHash, err := syncHash(h, nil)
    74  			if err != nil {
    75  				t.Fatal(err)
    76  			}
    77  			pool.Put(h)
    78  			if !bytes.Equal(expHash, resHash) {
    79  				t.Fatalf("hash mismatch with reference. expected %x, got %x", expHash, resHash)
    80  			}
    81  		})
    82  	}
    83  }
    84  
    85  // tests sequential write with entire max size written in one go
    86  func TestSyncHasherCorrectness(t *testing.T) {
    87  	t.Parallel()
    88  	testData := testutil.RandBytesWithSeed(t, 4096, seed)
    89  
    90  	for _, count := range testSegmentCounts {
    91  		count := count
    92  		t.Run(fmt.Sprintf("segments_%v", count), func(t *testing.T) {
    93  			t.Parallel()
    94  			max := count * hashSize
    95  			var incr int
    96  			capacity := 1
    97  			pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, capacity))
    98  			for n := 0; n <= max; n += incr {
    99  				h := pool.Get()
   100  				incr = 1 + rand.Intn(5)
   101  				err := testHasherCorrectness(h, testData, n, count)
   102  				if err != nil {
   103  					t.Fatalf("seed %d: %v", seed, err)
   104  				}
   105  				pool.Put(h)
   106  			}
   107  		})
   108  	}
   109  }
   110  
   111  // tests that the BMT hasher can be synchronously reused with poolsizes 1 and testPoolSize
   112  func TestHasherReuse(t *testing.T) {
   113  	t.Parallel()
   114  
   115  	t.Run(fmt.Sprintf("poolsize_%d", 1), func(t *testing.T) {
   116  		t.Parallel()
   117  		testHasherReuse(t, 1)
   118  	})
   119  
   120  	t.Run(fmt.Sprintf("poolsize_%d", testPoolSize), func(t *testing.T) {
   121  		t.Parallel()
   122  		testHasherReuse(t, testPoolSize)
   123  	})
   124  }
   125  
   126  // tests if bmt reuse is not corrupting result
   127  func testHasherReuse(t *testing.T, poolsize int) {
   128  	t.Helper()
   129  
   130  	pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, poolsize))
   131  	h := pool.Get()
   132  	defer pool.Put(h)
   133  
   134  	for i := 0; i < 100; i++ {
   135  		seed := int64(i)
   136  		testData := testutil.RandBytesWithSeed(t, 4096, seed)
   137  		n := rand.Intn(h.Capacity())
   138  		err := testHasherCorrectness(h, testData, n, testSegmentCount)
   139  		if err != nil {
   140  			t.Fatalf("seed %d: %v", seed, err)
   141  		}
   142  	}
   143  }
   144  
   145  // tests if pool can be cleanly reused even in concurrent use by several hashers
   146  func TestBMTConcurrentUse(t *testing.T) {
   147  	t.Parallel()
   148  
   149  	testData := testutil.RandBytesWithSeed(t, 4096, seed)
   150  	pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize))
   151  	cycles := 100
   152  
   153  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   154  	defer cancel()
   155  	eg, ectx := errgroup.WithContext(ctx)
   156  	for i := 0; i < cycles; i++ {
   157  		eg.Go(func() error {
   158  			select {
   159  			case <-ectx.Done():
   160  				return ectx.Err()
   161  			default:
   162  			}
   163  			h := pool.Get()
   164  			defer pool.Put(h)
   165  
   166  			n := rand.Intn(h.Capacity())
   167  			return testHasherCorrectness(h, testData, n, testSegmentCount)
   168  		})
   169  	}
   170  	if err := eg.Wait(); err != nil {
   171  		t.Fatalf("seed %d: %v", seed, err)
   172  	}
   173  }
   174  
   175  // tests BMT Hasher io.Writer interface is working correctly even with random short writes
   176  func TestBMTWriterBuffers(t *testing.T) {
   177  	t.Parallel()
   178  
   179  	for i, count := range testSegmentCounts {
   180  		i, count := i, count
   181  
   182  		t.Run(fmt.Sprintf("%d_segments", count), func(t *testing.T) {
   183  			t.Parallel()
   184  
   185  			pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, count, testPoolSize))
   186  			h := pool.Get()
   187  			defer pool.Put(h)
   188  
   189  			size := h.Capacity()
   190  			seed := int64(i)
   191  			testData := testutil.RandBytesWithSeed(t, 4096, seed)
   192  
   193  			resHash, err := syncHash(h, testData[:size])
   194  			if err != nil {
   195  				t.Fatal(err)
   196  			}
   197  			expHash, err := refHash(count, testData[:size])
   198  			if err != nil {
   199  				t.Fatal(err)
   200  			}
   201  			if !bytes.Equal(resHash, expHash) {
   202  				t.Fatalf("single write :hash mismatch with reference. expected %x, got %x", expHash, resHash)
   203  			}
   204  			attempts := 10
   205  			f := func() error {
   206  				h := pool.Get()
   207  				defer pool.Put(h)
   208  
   209  				reads := rand.Intn(count*2-1) + 1
   210  				offsets := make([]int, reads+1)
   211  				for i := 0; i < reads; i++ {
   212  					offsets[i] = rand.Intn(size) + 1
   213  				}
   214  				offsets[reads] = size
   215  				from := 0
   216  				sort.Ints(offsets)
   217  				for _, to := range offsets {
   218  					if from < to {
   219  						read, err := h.Write(testData[from:to])
   220  						if err != nil {
   221  							return err
   222  						}
   223  						if read != to-from {
   224  							return fmt.Errorf("incorrect read. expected %v bytes, got %v", to-from, read)
   225  						}
   226  						from = to
   227  					}
   228  				}
   229  				h.SetHeaderInt64(int64(size))
   230  				resHash, err := h.Hash(nil)
   231  				if err != nil {
   232  					return err
   233  				}
   234  				if !bytes.Equal(resHash, expHash) {
   235  					return fmt.Errorf("hash mismatch on %v. expected %x, got %x", offsets, expHash, resHash)
   236  				}
   237  				return nil
   238  			}
   239  			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   240  			defer cancel()
   241  			eg, ectx := errgroup.WithContext(ctx)
   242  			for i := 0; i < attempts; i++ {
   243  				eg.Go(func() error {
   244  					select {
   245  					case <-ectx.Done():
   246  						return ectx.Err()
   247  					default:
   248  					}
   249  					return f()
   250  				})
   251  			}
   252  			if err := eg.Wait(); err != nil {
   253  				t.Fatalf("seed %d: %v", seed, err)
   254  			}
   255  		})
   256  	}
   257  }
   258  
   259  // helper function that compares reference and optimised implementations for correctness
   260  func testHasherCorrectness(h *bmt.Hasher, data []byte, n, count int) (err error) {
   261  	if len(data) < n {
   262  		n = len(data)
   263  	}
   264  	exp, err := refHash(count, data[:n])
   265  	if err != nil {
   266  		return err
   267  	}
   268  	got, err := syncHash(h, data[:n])
   269  	if err != nil {
   270  		return err
   271  	}
   272  	if !bytes.Equal(got, exp) {
   273  		return fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
   274  	}
   275  	return nil
   276  }
   277  
   278  // verifies that the bmt.Hasher can be used with the hash.Hash interface
   279  func TestUseSyncAsOrdinaryHasher(t *testing.T) {
   280  	t.Parallel()
   281  
   282  	pool := bmt.NewPool(bmt.NewConf(swarm.NewHasher, testSegmentCount, testPoolSize))
   283  	h := pool.Get()
   284  	defer pool.Put(h)
   285  	data := []byte("moodbytesmoodbytesmoodbytesmoodbytes")
   286  	expHash, err := refHash(128, data)
   287  	if err != nil {
   288  		t.Fatal(err)
   289  	}
   290  	h.SetHeaderInt64(int64(len(data)))
   291  	_, err = h.Write(data)
   292  	if err != nil {
   293  		t.Fatal(err)
   294  	}
   295  	resHash := h.Sum(nil)
   296  	if !bytes.Equal(expHash, resHash) {
   297  		t.Fatalf("normalhash; expected %x, got %x", expHash, resHash)
   298  	}
   299  }