github.com/ethersphere/bee/v2@v2.2.0/pkg/file/redundancy/redundancy_test.go

// Copyright 2023 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package redundancy_test

import (
	"crypto/rand"
	"fmt"
	"io"
	"sync"
	"testing"

	"github.com/ethersphere/bee/v2/pkg/file/pipeline"
	"github.com/ethersphere/bee/v2/pkg/file/pipeline/bmt"
	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
	"github.com/ethersphere/bee/v2/pkg/swarm"
)

// mockEncoder is a stand-in for the real erasure encoder; it produces
// deterministic parity chunks that the test can verify later.
type mockEncoder struct {
	shards, parities int
}

func newMockEncoder(shards, parities int) (redundancy.ErasureEncoder, error) {
	return &mockEncoder{
		shards:   shards,
		parities: parities,
	}, nil
}

// Encode writes mock parity chunks: the most significant (last) span byte and
// the first payload byte of each parity chunk both hold the parity index.
func (m *mockEncoder) Encode(buffer [][]byte) error {
	// write parity data
	indicatedValue := 0
	for i := m.shards; i < m.shards+m.parities; i++ {
		data := make([]byte, 32)
		data[swarm.SpanSize-1], data[swarm.SpanSize] = uint8(indicatedValue), uint8(indicatedValue)
		buffer[i] = data
		indicatedValue++
	}
	return nil
}

// ParityChainWriter counts pipeline calls made for parity chunks and records
// whether each written chunk carried the expected parity index.
type ParityChainWriter struct {
	sync.Mutex
	chainWriteCalls int
	sumCalls        int
	validCalls      []bool
}

func NewParityChainWriter() *ParityChainWriter {
	return &ParityChainWriter{}
}

func (c *ParityChainWriter) ChainWriteCalls() int {
	c.Lock()
	defer c.Unlock()
	return c.chainWriteCalls
}

func (c *ParityChainWriter) SumCalls() int { c.Lock(); defer c.Unlock(); return c.sumCalls }

func (c *ParityChainWriter) ChainWrite(args *pipeline.PipeWriteArgs) error {
	c.Lock()
	defer c.Unlock()
	// mockEncoder.Encode put the parity index at the end of the span and right
	// after the span prefix of the data, so both must equal the call count.
	valid := args.Span[len(args.Span)-1] == args.Data[len(args.Span)] && args.Data[len(args.Span)] == byte(c.chainWriteCalls)
	c.chainWriteCalls++
	c.validCalls = append(c.validCalls, valid)
	return nil
}

func (c *ParityChainWriter) Sum() ([]byte, error) {
	c.Lock()
	defer c.Unlock()
	c.sumCalls++
	return nil, nil
}

func TestEncode(t *testing.T) {
	t.Parallel()
	// installs mockEncoder -> creates shard chunks -> calls ChunkWrite on the
	// redundancy params -> calls Encode
	erasureEncoder := redundancy.GetErasureEncoder()
	defer func() {
		redundancy.SetErasureEncoder(erasureEncoder)
	}()
	redundancy.SetErasureEncoder(newMockEncoder)

	// exercise every redundancy level, with and without encryption
	for _, level := range []redundancy.Level{redundancy.MEDIUM, redundancy.STRONG, redundancy.INSANE, redundancy.PARANOID} {
		for _, encrypted := range []bool{false, true} {
			maxShards := level.GetMaxShards()
			if encrypted {
				maxShards = level.GetMaxEncShards()
			}
			for shardCount := 1; shardCount <= maxShards; shardCount++ {
				t.Run(fmt.Sprintf("redundancy level %d is checked with %d shards", level, shardCount), func(t *testing.T) {
					parityChainWriter := NewParityChainWriter()
					ppf := func() pipeline.ChainWriter {
						return bmt.NewBmtWriter(parityChainWriter)
					}
					params := redundancy.New(level, encrypted, ppf)

					// check that parity pipeline calls are valid
					parityCount := 0
					parityCallback := func(level int, span, address []byte) error {
						parityCount++
						return nil
					}

					for i := 0; i < shardCount; i++ {
						buffer := make([]byte, 32)
						_, err := io.ReadFull(rand.Reader, buffer)
						if err != nil {
							t.Fatal(err)
						}
						err = params.ChunkWrite(0, buffer, parityCallback)
						if err != nil {
							t.Fatal(err)
						}
					}
					if shardCount != maxShards {
						// Encode runs automatically only when maxShards is reached,
						// so it has to be called explicitly for partial shard sets
						err := params.Encode(0, parityCallback)
						if err != nil {
							t.Fatal(err)
						}
					}

					// CHECKS

					if parityCount != parityChainWriter.chainWriteCalls {
						t.Fatalf("parity callback was called %d times while chainwrite was called %d times", parityCount, parityChainWriter.chainWriteCalls)
					}

					expectedParityCount := params.Level().GetParities(shardCount)
					if encrypted {
						expectedParityCount = params.Level().GetEncParities(shardCount)
					}
					if parityCount != expectedParityCount {
						t.Fatalf("parity callback was called %d times but the expected parity count is %d", parityCount, expectedParityCount)
					}

					for i, validCall := range parityChainWriter.validCalls {
						if !validCall {
							t.Fatalf("parity chunk data is wrong at parity index %d", i)
						}
					}
				})
			}
		}
	}
}
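
// A minimal sketch, not part of the upstream file: the same encoder override
// used in TestEncode, scoped to a single test with t.Cleanup instead of a
// defer. It reuses only GetErasureEncoder, SetErasureEncoder and
// newMockEncoder from above; the helper name is illustrative.
func swapInMockEncoder(t *testing.T) {
	t.Helper()
	prev := redundancy.GetErasureEncoder()
	t.Cleanup(func() { redundancy.SetErasureEncoder(prev) })
	redundancy.SetErasureEncoder(newMockEncoder)
}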