github.com/ethersphere/bee/v2@v2.2.0/pkg/file/redundancy/redundancy_test.go (about)

     1  // Copyright 2023 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package redundancy_test
     6  
     7  import (
     8  	"crypto/rand"
     9  	"fmt"
    10  	"io"
    11  	"sync"
    12  	"testing"
    13  
    14  	"github.com/ethersphere/bee/v2/pkg/file/pipeline"
    15  	"github.com/ethersphere/bee/v2/pkg/file/pipeline/bmt"
    16  	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
    17  	"github.com/ethersphere/bee/v2/pkg/swarm"
    18  )
    19  
    20  type mockEncoder struct {
    21  	shards, parities int
    22  }
    23  
    24  func newMockEncoder(shards, parities int) (redundancy.ErasureEncoder, error) {
    25  	return &mockEncoder{
    26  		shards:   shards,
    27  		parities: parities,
    28  	}, nil
    29  }
    30  
    31  // Encode makes MSB of span equal to data
    32  func (m *mockEncoder) Encode(buffer [][]byte) error {
    33  	// writes parity data
    34  	indicatedValue := 0
    35  	for i := m.shards; i < m.shards+m.parities; i++ {
    36  		data := make([]byte, 32)
    37  		data[swarm.SpanSize-1], data[swarm.SpanSize] = uint8(indicatedValue), uint8(indicatedValue)
    38  		buffer[i] = data
    39  		indicatedValue++
    40  	}
    41  	return nil
    42  }
    43  
    44  type ParityChainWriter struct {
    45  	sync.Mutex
    46  	chainWriteCalls int
    47  	sumCalls        int
    48  	validCalls      []bool
    49  }
    50  
    51  func NewParityChainWriter() *ParityChainWriter {
    52  	return &ParityChainWriter{}
    53  }
    54  
    55  func (c *ParityChainWriter) ChainWriteCalls() int {
    56  	c.Lock()
    57  	defer c.Unlock()
    58  	return c.chainWriteCalls
    59  }
    60  func (c *ParityChainWriter) SumCalls() int { c.Lock(); defer c.Unlock(); return c.sumCalls }
    61  
    62  func (c *ParityChainWriter) ChainWrite(args *pipeline.PipeWriteArgs) error {
    63  	c.Lock()
    64  	defer c.Unlock()
    65  	valid := args.Span[len(args.Span)-1] == args.Data[len(args.Span)] && args.Data[len(args.Span)] == byte(c.chainWriteCalls)
    66  	c.chainWriteCalls++
    67  	c.validCalls = append(c.validCalls, valid)
    68  	return nil
    69  }
    70  func (c *ParityChainWriter) Sum() ([]byte, error) {
    71  	c.Lock()
    72  	defer c.Unlock()
    73  	c.sumCalls++
    74  	return nil, nil
    75  }
    76  
    77  func TestEncode(t *testing.T) {
    78  	t.Parallel()
    79  	// initializes mockEncoder -> creates shard chunks -> redundancy.chunkWrites -> call encode
    80  	erasureEncoder := redundancy.GetErasureEncoder()
    81  	defer func() {
    82  		redundancy.SetErasureEncoder(erasureEncoder)
    83  	}()
    84  	redundancy.SetErasureEncoder(newMockEncoder)
    85  
    86  	// test on the data level
    87  	for _, level := range []redundancy.Level{redundancy.MEDIUM, redundancy.STRONG, redundancy.INSANE, redundancy.PARANOID} {
    88  		for _, encrypted := range []bool{false, true} {
    89  			maxShards := level.GetMaxShards()
    90  			if encrypted {
    91  				maxShards = level.GetMaxEncShards()
    92  			}
    93  			for shardCount := 1; shardCount <= maxShards; shardCount++ {
    94  				t.Run(fmt.Sprintf("redundancy level %d is checked with %d shards", level, shardCount), func(t *testing.T) {
    95  					parityChainWriter := NewParityChainWriter()
    96  					ppf := func() pipeline.ChainWriter {
    97  						return bmt.NewBmtWriter(parityChainWriter)
    98  					}
    99  					params := redundancy.New(level, encrypted, ppf)
   100  					// checks parity pipelinecalls are valid
   101  
   102  					parityCount := 0
   103  					parityCallback := func(level int, span, address []byte) error {
   104  						parityCount++
   105  						return nil
   106  					}
   107  
   108  					for i := 0; i < shardCount; i++ {
   109  						buffer := make([]byte, 32)
   110  						_, err := io.ReadFull(rand.Reader, buffer)
   111  						if err != nil {
   112  							t.Fatal(err)
   113  						}
   114  						err = params.ChunkWrite(0, buffer, parityCallback)
   115  						if err != nil {
   116  							t.Fatal(err)
   117  						}
   118  					}
   119  					if shardCount != maxShards {
   120  						// encode should be called automatically when reaching maxshards
   121  						err := params.Encode(0, parityCallback)
   122  						if err != nil {
   123  							t.Fatal(err)
   124  						}
   125  					}
   126  
   127  					// CHECKS
   128  
   129  					if parityCount != parityChainWriter.chainWriteCalls {
   130  						t.Fatalf("parity callback was called %d times meanwhile chainwrite was called %d times", parityCount, parityChainWriter.chainWriteCalls)
   131  					}
   132  
   133  					expectedParityCount := params.Level().GetParities(shardCount)
   134  					if encrypted {
   135  						expectedParityCount = params.Level().GetEncParities(shardCount)
   136  					}
   137  					if parityCount != expectedParityCount {
   138  						t.Fatalf("parity callback was called %d times meanwhile expected parity number should be %d", parityCount, expectedParityCount)
   139  					}
   140  
   141  					for i, validCall := range parityChainWriter.validCalls {
   142  						if !validCall {
   143  							t.Fatalf("parity chunk data is wrong at parity index %d", i)
   144  						}
   145  					}
   146  				})
   147  			}
   148  		}
   149  	}
   150  }