github.com/ethersphere/bee/v2@v2.2.0/pkg/file/pipeline/hashtrie/hashtrie_test.go (about)

     1  // Copyright 2021 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hashtrie_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"encoding/binary"
    11  	"errors"
    12  	"sync/atomic"
    13  	"testing"
    14  
    15  	bmtUtils "github.com/ethersphere/bee/v2/pkg/bmt"
    16  	"github.com/ethersphere/bee/v2/pkg/cac"
    17  	"github.com/ethersphere/bee/v2/pkg/encryption"
    18  	dec "github.com/ethersphere/bee/v2/pkg/encryption/store"
    19  	"github.com/ethersphere/bee/v2/pkg/file"
    20  	"github.com/ethersphere/bee/v2/pkg/file/pipeline"
    21  	"github.com/ethersphere/bee/v2/pkg/file/pipeline/bmt"
    22  	enc "github.com/ethersphere/bee/v2/pkg/file/pipeline/encryption"
    23  	"github.com/ethersphere/bee/v2/pkg/file/pipeline/hashtrie"
    24  	"github.com/ethersphere/bee/v2/pkg/file/pipeline/mock"
    25  	"github.com/ethersphere/bee/v2/pkg/file/pipeline/store"
    26  	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
    27  	"github.com/ethersphere/bee/v2/pkg/storage"
    28  	"github.com/ethersphere/bee/v2/pkg/storage/inmemchunkstore"
    29  	"github.com/ethersphere/bee/v2/pkg/swarm"
    30  )
    31  
    32  var (
    33  	addr swarm.Address
    34  	span []byte
    35  	ctx  = context.Background()
    36  )
    37  
    38  // nolint:gochecknoinits
    39  func init() {
    40  	b := make([]byte, 32)
    41  	b[31] = 0x01
    42  	addr = swarm.NewAddress(b)
    43  
    44  	span = make([]byte, 8)
    45  	binary.LittleEndian.PutUint64(span, 1)
    46  }
    47  
    48  // newErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline
    49  // which are using simple BMT and StoreWriter pipelines for chunk writes
    50  func newErasureHashTrieWriter(
    51  	ctx context.Context,
    52  	s storage.Putter,
    53  	rLevel redundancy.Level,
    54  	encryptChunks bool,
    55  	intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter,
    56  	replicaPutter storage.Putter,
    57  ) (redundancy.RedundancyParams, pipeline.ChainWriter) {
    58  	pf := func() pipeline.ChainWriter {
    59  		lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline)
    60  		return bmt.NewBmtWriter(lsw)
    61  	}
    62  	if encryptChunks {
    63  		pf = func() pipeline.ChainWriter {
    64  			lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline)
    65  			b := bmt.NewBmtWriter(lsw)
    66  			return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b)
    67  		}
    68  	}
    69  	ppf := func() pipeline.ChainWriter {
    70  		lsw := store.NewStoreWriter(ctx, s, parityChunkPipeline)
    71  		return bmt.NewBmtWriter(lsw)
    72  	}
    73  
    74  	hashSize := swarm.HashSize
    75  	if encryptChunks {
    76  		hashSize *= 2
    77  	}
    78  
    79  	r := redundancy.New(rLevel, encryptChunks, ppf)
    80  	ht := hashtrie.NewHashTrieWriter(ctx, hashSize, r, pf, replicaPutter)
    81  	return r, ht
    82  }
    83  
    84  func TestLevels(t *testing.T) {
    85  	t.Parallel()
    86  
    87  	var (
    88  		hashSize = 32
    89  	)
    90  
    91  	// to create a level wrap we need to do branching^(level-1) writes
    92  	for _, tc := range []struct {
    93  		desc   string
    94  		writes int
    95  	}{
    96  		{
    97  			desc:   "2 at L1",
    98  			writes: 2,
    99  		},
   100  		{
   101  			desc:   "1 at L2, 1 at L1", // dangling chunk
   102  			writes: 16 + 1,
   103  		},
   104  		{
   105  			desc:   "1 at L3, 1 at L2, 1 at L1",
   106  			writes: 64 + 16 + 1,
   107  		},
   108  		{
   109  			desc:   "1 at L3, 2 at L2, 1 at L1",
   110  			writes: 64 + 16 + 16 + 1,
   111  		},
   112  		{
   113  			desc:   "1 at L5, 1 at L1",
   114  			writes: 1024 + 1,
   115  		},
   116  		{
   117  			desc:   "1 at L5, 1 at L3",
   118  			writes: 1024 + 1,
   119  		},
   120  		{
   121  			desc:   "2 at L5, 1 at L1",
   122  			writes: 1024 + 1024 + 1,
   123  		},
   124  		{
   125  			desc:   "3 at L5, 2 at L3, 1 at L1",
   126  			writes: 1024 + 1024 + 1024 + 64 + 64 + 1,
   127  		},
   128  		{
   129  			desc:   "1 at L7, 1 at L1",
   130  			writes: 4096 + 1,
   131  		},
   132  		{
   133  			desc:   "1 at L8", // balanced trie - all good
   134  			writes: 16384,
   135  		},
   136  	} {
   137  
   138  		tc := tc
   139  		t.Run(tc.desc, func(t *testing.T) {
   140  			t.Parallel()
   141  
   142  			s := inmemchunkstore.New()
   143  			pf := func() pipeline.ChainWriter {
   144  				lsw := store.NewStoreWriter(ctx, s, nil)
   145  				return bmt.NewBmtWriter(lsw)
   146  			}
   147  
   148  			ht := hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s)
   149  
   150  			for i := 0; i < tc.writes; i++ {
   151  				a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span}
   152  				err := ht.ChainWrite(a)
   153  				if err != nil {
   154  					t.Fatal(err)
   155  				}
   156  			}
   157  
   158  			ref, err := ht.Sum()
   159  			if err != nil {
   160  				t.Fatal(err)
   161  			}
   162  
   163  			rootch, err := s.Get(ctx, swarm.NewAddress(ref))
   164  			if err != nil {
   165  				t.Fatal(err)
   166  			}
   167  
   168  			//check the span. since write spans are 1 value 1, then expected span == tc.writes
   169  			sp := binary.LittleEndian.Uint64(rootch.Data()[:swarm.SpanSize])
   170  			if sp != uint64(tc.writes) {
   171  				t.Fatalf("want span %d got %d", tc.writes, sp)
   172  			}
   173  		})
   174  	}
   175  }
   176  
// redundancyMock embeds redundancy.Params and overrides MaxShards with a small
// fixed value, so a trie-full condition can be provoked in TestLevels_TrieFull
// without an impractical number of writes.
type redundancyMock struct {
	redundancy.Params
}

// MaxShards reports a fixed shard limit of 4, shadowing the embedded Params method.
func (r redundancyMock) MaxShards() int {
	return 4
}
   184  
   185  func TestLevels_TrieFull(t *testing.T) {
   186  	t.Parallel()
   187  
   188  	var (
   189  		hashSize = 32
   190  		writes   = 16384 // this is to get a balanced trie
   191  		s        = inmemchunkstore.New()
   192  		pf       = func() pipeline.ChainWriter {
   193  			lsw := store.NewStoreWriter(ctx, s, nil)
   194  			return bmt.NewBmtWriter(lsw)
   195  		}
   196  		r     = redundancy.New(0, false, pf)
   197  		rMock = &redundancyMock{
   198  			Params: *r,
   199  		}
   200  
   201  		ht = hashtrie.NewHashTrieWriter(ctx, hashSize, rMock, pf, s)
   202  	)
   203  
   204  	// to create a level wrap we need to do branching^(level-1) writes
   205  	for i := 0; i < writes; i++ {
   206  		a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span}
   207  		err := ht.ChainWrite(a)
   208  		if err != nil {
   209  			t.Fatal(err)
   210  		}
   211  	}
   212  
   213  	a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span}
   214  	err := ht.ChainWrite(a)
   215  	if !errors.Is(err, hashtrie.ErrTrieFull) {
   216  		t.Fatal(err)
   217  	}
   218  
   219  	// it is questionable whether the writer should go into some
   220  	// corrupt state after the last write which causes the trie full
   221  	// error, in which case we would return an error on Sum()
   222  	_, err = ht.Sum()
   223  	if err != nil {
   224  		t.Fatal(err)
   225  	}
   226  }
   227  
   228  // TestRegression is a regression test for the bug
   229  // described in https://github.com/ethersphere/bee/issues/1175
   230  func TestRegression(t *testing.T) {
   231  	t.Parallel()
   232  
   233  	var (
   234  		hashSize = 32
   235  		writes   = 67100000 / 4096
   236  		span     = make([]byte, 8)
   237  		s        = inmemchunkstore.New()
   238  		pf       = func() pipeline.ChainWriter {
   239  			lsw := store.NewStoreWriter(ctx, s, nil)
   240  			return bmt.NewBmtWriter(lsw)
   241  		}
   242  		ht = hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s)
   243  	)
   244  	binary.LittleEndian.PutUint64(span, 4096)
   245  
   246  	for i := 0; i < writes; i++ {
   247  		a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span}
   248  		err := ht.ChainWrite(a)
   249  		if err != nil {
   250  			t.Fatal(err)
   251  		}
   252  	}
   253  
   254  	ref, err := ht.Sum()
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  
   259  	rootch, err := s.Get(ctx, swarm.NewAddress(ref))
   260  	if err != nil {
   261  		t.Fatal(err)
   262  	}
   263  
   264  	sp := binary.LittleEndian.Uint64(rootch.Data()[:swarm.SpanSize])
   265  	if sp != uint64(writes*4096) {
   266  		t.Fatalf("want span %d got %d", writes*4096, sp)
   267  	}
   268  }
   269  
// replicaPutter wraps a storage.Putter and counts how many chunks are
// put through it, so tests can assert on the number of replica writes.
type replicaPutter struct {
	storage.Putter
	// replicaCount is incremented once per Put call; read with Load.
	replicaCount atomic.Uint32
}

// Put counts the chunk and delegates storage to the wrapped Putter.
func (r *replicaPutter) Put(ctx context.Context, chunk swarm.Chunk) error {
	r.replicaCount.Add(1)
	return r.Putter.Put(ctx, chunk)
}
   279  
   280  // TestRedundancy using erasure coding library and checks carrierChunk function and modified span in intermediate chunk
   281  func TestRedundancy(t *testing.T) {
   282  	t.Parallel()
   283  	// chunks need to have data so that it will not throw error on redundancy caching
   284  	ch, err := cac.New(make([]byte, swarm.ChunkSize))
   285  	if err != nil {
   286  		t.Fatal(err)
   287  	}
   288  	chData := ch.Data()
   289  	chSpan := chData[:swarm.SpanSize]
   290  	chAddr := ch.Address().Bytes()
   291  
   292  	// test logic assumes a simple 2 level chunk tree with carrier chunk
   293  	for _, tc := range []struct {
   294  		desc       string
   295  		level      redundancy.Level
   296  		encryption bool
   297  		writes     int
   298  		parities   int
   299  	}{
   300  		{
   301  			desc:       "redundancy write for not encrypted data",
   302  			level:      redundancy.INSANE,
   303  			encryption: false,
   304  			writes:     98, // 97 chunk references fit into one chunk + 1 carrier
   305  			parities:   37, // 31 (full ch) + 6 (2 ref)
   306  		},
   307  		{
   308  			desc:       "redundancy write for encrypted data",
   309  			level:      redundancy.PARANOID,
   310  			encryption: true,
   311  			writes:     21,  // 21 encrypted chunk references fit into one chunk + 1 carrier
   312  			parities:   116, // // 87 (full ch) + 29 (2 ref)
   313  		},
   314  	} {
   315  		tc := tc
   316  		t.Run(tc.desc, func(t *testing.T) {
   317  			t.Parallel()
   318  			subCtx := redundancy.SetLevelInContext(ctx, tc.level)
   319  
   320  			s := inmemchunkstore.New()
   321  			intermediateChunkCounter := mock.NewChainWriter()
   322  			parityChunkCounter := mock.NewChainWriter()
   323  			replicaChunkCounter := &replicaPutter{Putter: s}
   324  
   325  			r, ht := newErasureHashTrieWriter(subCtx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter, replicaChunkCounter)
   326  
   327  			// write data to the hashTrie
   328  			var key []byte
   329  			if tc.encryption {
   330  				key = make([]byte, swarm.HashSize)
   331  			}
   332  			for i := 0; i < tc.writes; i++ {
   333  				a := &pipeline.PipeWriteArgs{Data: chData, Span: chSpan, Ref: chAddr, Key: key}
   334  				err := ht.ChainWrite(a)
   335  				if err != nil {
   336  					t.Fatal(err)
   337  				}
   338  			}
   339  
   340  			ref, err := ht.Sum()
   341  			if err != nil {
   342  				t.Fatal(err)
   343  			}
   344  
   345  			// sanity check for the test samples
   346  			if tc.parities != parityChunkCounter.ChainWriteCalls() {
   347  				t.Errorf("generated parities should be %d. Got: %d", tc.parities, parityChunkCounter.ChainWriteCalls())
   348  			}
   349  			if intermediateChunkCounter.ChainWriteCalls() != 2 { // root chunk and the chunk which was written before carrierChunk movement
   350  				t.Errorf("effective chunks should be %d. Got: %d", tc.writes, intermediateChunkCounter.ChainWriteCalls())
   351  			}
   352  
   353  			rootch, err := s.Get(subCtx, swarm.NewAddress(ref[:swarm.HashSize]))
   354  			if err != nil {
   355  				t.Fatal(err)
   356  			}
   357  			chData := rootch.Data()
   358  			if tc.encryption {
   359  				chData, err = dec.DecryptChunkData(chData, ref[swarm.HashSize:])
   360  				if err != nil {
   361  					t.Fatal(err)
   362  				}
   363  			}
   364  
   365  			// span check
   366  			level, sp := redundancy.DecodeSpan(chData[:swarm.SpanSize])
   367  			expectedSpan := bmtUtils.LengthToSpan(int64(tc.writes * swarm.ChunkSize))
   368  			if !bytes.Equal(expectedSpan, sp) {
   369  				t.Fatalf("want span %d got %d", expectedSpan, span)
   370  			}
   371  			if level != tc.level {
   372  				t.Fatalf("encoded level differs from the uploaded one %d. Got: %d", tc.level, level)
   373  			}
   374  			expectedParities := tc.parities - r.Parities(r.MaxShards())
   375  			_, parity := file.ReferenceCount(bmtUtils.LengthFromSpan(sp), level, tc.encryption)
   376  			if expectedParities != parity {
   377  				t.Fatalf("want parity %d got %d", expectedParities, parity)
   378  			}
   379  			if tc.level.GetReplicaCount() != int(replicaChunkCounter.replicaCount.Load()) {
   380  				t.Fatalf("unexpected number of replicas: want %d. Got: %d", tc.level.GetReplicaCount(), int(replicaChunkCounter.replicaCount.Load()))
   381  			}
   382  		})
   383  	}
   384  }