github.com/ethersphere/bee/v2@v2.2.0/pkg/file/pipeline/hashtrie/hashtrie_test.go (about) 1 // Copyright 2021 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package hashtrie_test 6 7 import ( 8 "bytes" 9 "context" 10 "encoding/binary" 11 "errors" 12 "sync/atomic" 13 "testing" 14 15 bmtUtils "github.com/ethersphere/bee/v2/pkg/bmt" 16 "github.com/ethersphere/bee/v2/pkg/cac" 17 "github.com/ethersphere/bee/v2/pkg/encryption" 18 dec "github.com/ethersphere/bee/v2/pkg/encryption/store" 19 "github.com/ethersphere/bee/v2/pkg/file" 20 "github.com/ethersphere/bee/v2/pkg/file/pipeline" 21 "github.com/ethersphere/bee/v2/pkg/file/pipeline/bmt" 22 enc "github.com/ethersphere/bee/v2/pkg/file/pipeline/encryption" 23 "github.com/ethersphere/bee/v2/pkg/file/pipeline/hashtrie" 24 "github.com/ethersphere/bee/v2/pkg/file/pipeline/mock" 25 "github.com/ethersphere/bee/v2/pkg/file/pipeline/store" 26 "github.com/ethersphere/bee/v2/pkg/file/redundancy" 27 "github.com/ethersphere/bee/v2/pkg/storage" 28 "github.com/ethersphere/bee/v2/pkg/storage/inmemchunkstore" 29 "github.com/ethersphere/bee/v2/pkg/swarm" 30 ) 31 32 var ( 33 addr swarm.Address 34 span []byte 35 ctx = context.Background() 36 ) 37 38 // nolint:gochecknoinits 39 func init() { 40 b := make([]byte, 32) 41 b[31] = 0x01 42 addr = swarm.NewAddress(b) 43 44 span = make([]byte, 8) 45 binary.LittleEndian.PutUint64(span, 1) 46 } 47 48 // newErasureHashTrieWriter returns back an redundancy param and a HastTrieWriter pipeline 49 // which are using simple BMT and StoreWriter pipelines for chunk writes 50 func newErasureHashTrieWriter( 51 ctx context.Context, 52 s storage.Putter, 53 rLevel redundancy.Level, 54 encryptChunks bool, 55 intermediateChunkPipeline, parityChunkPipeline pipeline.ChainWriter, 56 replicaPutter storage.Putter, 57 ) (redundancy.RedundancyParams, pipeline.ChainWriter) { 58 pf := func() pipeline.ChainWriter { 59 lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) 60 return bmt.NewBmtWriter(lsw) 61 } 62 if encryptChunks { 63 pf = func() pipeline.ChainWriter { 64 lsw := store.NewStoreWriter(ctx, s, intermediateChunkPipeline) 65 b := bmt.NewBmtWriter(lsw) 66 return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b) 67 } 68 } 69 ppf := func() pipeline.ChainWriter { 70 lsw := store.NewStoreWriter(ctx, s, parityChunkPipeline) 71 return bmt.NewBmtWriter(lsw) 72 } 73 74 hashSize := swarm.HashSize 75 if encryptChunks { 76 hashSize *= 2 77 } 78 79 r := redundancy.New(rLevel, encryptChunks, ppf) 80 ht := hashtrie.NewHashTrieWriter(ctx, hashSize, r, pf, replicaPutter) 81 return r, ht 82 } 83 84 func TestLevels(t *testing.T) { 85 t.Parallel() 86 87 var ( 88 hashSize = 32 89 ) 90 91 // to create a level wrap we need to do branching^(level-1) writes 92 for _, tc := range []struct { 93 desc string 94 writes int 95 }{ 96 { 97 desc: "2 at L1", 98 writes: 2, 99 }, 100 { 101 desc: "1 at L2, 1 at L1", // dangling chunk 102 writes: 16 + 1, 103 }, 104 { 105 desc: "1 at L3, 1 at L2, 1 at L1", 106 writes: 64 + 16 + 1, 107 }, 108 { 109 desc: "1 at L3, 2 at L2, 1 at L1", 110 writes: 64 + 16 + 16 + 1, 111 }, 112 { 113 desc: "1 at L5, 1 at L1", 114 writes: 1024 + 1, 115 }, 116 { 117 desc: "1 at L5, 1 at L3", 118 writes: 1024 + 1, 119 }, 120 { 121 desc: "2 at L5, 1 at L1", 122 writes: 1024 + 1024 + 1, 123 }, 124 { 125 desc: "3 at L5, 2 at L3, 1 at L1", 126 writes: 1024 + 1024 + 1024 + 64 + 64 + 1, 127 }, 128 { 129 desc: "1 at L7, 1 at L1", 130 writes: 4096 + 1, 131 }, 132 { 133 desc: "1 at L8", // balanced trie - all good 134 writes: 16384, 135 }, 136 } { 137 138 tc := tc 139 t.Run(tc.desc, func(t *testing.T) { 140 t.Parallel() 141 142 s := inmemchunkstore.New() 143 pf := func() pipeline.ChainWriter { 144 lsw := store.NewStoreWriter(ctx, s, nil) 145 return bmt.NewBmtWriter(lsw) 146 } 147 148 ht := hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s) 149 150 for i := 0; i < tc.writes; i++ { 151 a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span} 152 err := ht.ChainWrite(a) 153 if err != nil { 154 t.Fatal(err) 155 } 156 } 157 158 ref, err := ht.Sum() 159 if err != nil { 160 t.Fatal(err) 161 } 162 163 rootch, err := s.Get(ctx, swarm.NewAddress(ref)) 164 if err != nil { 165 t.Fatal(err) 166 } 167 168 //check the span. since write spans are 1 value 1, then expected span == tc.writes 169 sp := binary.LittleEndian.Uint64(rootch.Data()[:swarm.SpanSize]) 170 if sp != uint64(tc.writes) { 171 t.Fatalf("want span %d got %d", tc.writes, sp) 172 } 173 }) 174 } 175 } 176 177 type redundancyMock struct { 178 redundancy.Params 179 } 180 181 func (r redundancyMock) MaxShards() int { 182 return 4 183 } 184 185 func TestLevels_TrieFull(t *testing.T) { 186 t.Parallel() 187 188 var ( 189 hashSize = 32 190 writes = 16384 // this is to get a balanced trie 191 s = inmemchunkstore.New() 192 pf = func() pipeline.ChainWriter { 193 lsw := store.NewStoreWriter(ctx, s, nil) 194 return bmt.NewBmtWriter(lsw) 195 } 196 r = redundancy.New(0, false, pf) 197 rMock = &redundancyMock{ 198 Params: *r, 199 } 200 201 ht = hashtrie.NewHashTrieWriter(ctx, hashSize, rMock, pf, s) 202 ) 203 204 // to create a level wrap we need to do branching^(level-1) writes 205 for i := 0; i < writes; i++ { 206 a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span} 207 err := ht.ChainWrite(a) 208 if err != nil { 209 t.Fatal(err) 210 } 211 } 212 213 a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span} 214 err := ht.ChainWrite(a) 215 if !errors.Is(err, hashtrie.ErrTrieFull) { 216 t.Fatal(err) 217 } 218 219 // it is questionable whether the writer should go into some 220 // corrupt state after the last write which causes the trie full 221 // error, in which case we would return an error on Sum() 222 _, err = ht.Sum() 223 if err != nil { 224 t.Fatal(err) 225 } 226 } 227 228 // TestRegression is a regression test for the bug 229 // described in https://github.com/ethersphere/bee/issues/1175 230 func TestRegression(t *testing.T) { 231 t.Parallel() 232 233 var ( 234 hashSize = 32 235 writes = 67100000 / 4096 236 span = make([]byte, 8) 237 s = inmemchunkstore.New() 238 pf = func() pipeline.ChainWriter { 239 lsw := store.NewStoreWriter(ctx, s, nil) 240 return bmt.NewBmtWriter(lsw) 241 } 242 ht = hashtrie.NewHashTrieWriter(ctx, hashSize, redundancy.New(0, false, pf), pf, s) 243 ) 244 binary.LittleEndian.PutUint64(span, 4096) 245 246 for i := 0; i < writes; i++ { 247 a := &pipeline.PipeWriteArgs{Ref: addr.Bytes(), Span: span} 248 err := ht.ChainWrite(a) 249 if err != nil { 250 t.Fatal(err) 251 } 252 } 253 254 ref, err := ht.Sum() 255 if err != nil { 256 t.Fatal(err) 257 } 258 259 rootch, err := s.Get(ctx, swarm.NewAddress(ref)) 260 if err != nil { 261 t.Fatal(err) 262 } 263 264 sp := binary.LittleEndian.Uint64(rootch.Data()[:swarm.SpanSize]) 265 if sp != uint64(writes*4096) { 266 t.Fatalf("want span %d got %d", writes*4096, sp) 267 } 268 } 269 270 type replicaPutter struct { 271 storage.Putter 272 replicaCount atomic.Uint32 273 } 274 275 func (r *replicaPutter) Put(ctx context.Context, chunk swarm.Chunk) error { 276 r.replicaCount.Add(1) 277 return r.Putter.Put(ctx, chunk) 278 } 279 280 // TestRedundancy using erasure coding library and checks carrierChunk function and modified span in intermediate chunk 281 func TestRedundancy(t *testing.T) { 282 t.Parallel() 283 // chunks need to have data so that it will not throw error on redundancy caching 284 ch, err := cac.New(make([]byte, swarm.ChunkSize)) 285 if err != nil { 286 t.Fatal(err) 287 } 288 chData := ch.Data() 289 chSpan := chData[:swarm.SpanSize] 290 chAddr := ch.Address().Bytes() 291 292 // test logic assumes a simple 2 level chunk tree with carrier chunk 293 for _, tc := range []struct { 294 desc string 295 level redundancy.Level 296 encryption bool 297 writes int 298 parities int 299 }{ 300 { 301 desc: "redundancy write for not encrypted data", 302 level: redundancy.INSANE, 303 encryption: false, 304 writes: 98, // 97 chunk references fit into one chunk + 1 carrier 305 parities: 37, // 31 (full ch) + 6 (2 ref) 306 }, 307 { 308 desc: "redundancy write for encrypted data", 309 level: redundancy.PARANOID, 310 encryption: true, 311 writes: 21, // 21 encrypted chunk references fit into one chunk + 1 carrier 312 parities: 116, // // 87 (full ch) + 29 (2 ref) 313 }, 314 } { 315 tc := tc 316 t.Run(tc.desc, func(t *testing.T) { 317 t.Parallel() 318 subCtx := redundancy.SetLevelInContext(ctx, tc.level) 319 320 s := inmemchunkstore.New() 321 intermediateChunkCounter := mock.NewChainWriter() 322 parityChunkCounter := mock.NewChainWriter() 323 replicaChunkCounter := &replicaPutter{Putter: s} 324 325 r, ht := newErasureHashTrieWriter(subCtx, s, tc.level, tc.encryption, intermediateChunkCounter, parityChunkCounter, replicaChunkCounter) 326 327 // write data to the hashTrie 328 var key []byte 329 if tc.encryption { 330 key = make([]byte, swarm.HashSize) 331 } 332 for i := 0; i < tc.writes; i++ { 333 a := &pipeline.PipeWriteArgs{Data: chData, Span: chSpan, Ref: chAddr, Key: key} 334 err := ht.ChainWrite(a) 335 if err != nil { 336 t.Fatal(err) 337 } 338 } 339 340 ref, err := ht.Sum() 341 if err != nil { 342 t.Fatal(err) 343 } 344 345 // sanity check for the test samples 346 if tc.parities != parityChunkCounter.ChainWriteCalls() { 347 t.Errorf("generated parities should be %d. Got: %d", tc.parities, parityChunkCounter.ChainWriteCalls()) 348 } 349 if intermediateChunkCounter.ChainWriteCalls() != 2 { // root chunk and the chunk which was written before carrierChunk movement 350 t.Errorf("effective chunks should be %d. Got: %d", tc.writes, intermediateChunkCounter.ChainWriteCalls()) 351 } 352 353 rootch, err := s.Get(subCtx, swarm.NewAddress(ref[:swarm.HashSize])) 354 if err != nil { 355 t.Fatal(err) 356 } 357 chData := rootch.Data() 358 if tc.encryption { 359 chData, err = dec.DecryptChunkData(chData, ref[swarm.HashSize:]) 360 if err != nil { 361 t.Fatal(err) 362 } 363 } 364 365 // span check 366 level, sp := redundancy.DecodeSpan(chData[:swarm.SpanSize]) 367 expectedSpan := bmtUtils.LengthToSpan(int64(tc.writes * swarm.ChunkSize)) 368 if !bytes.Equal(expectedSpan, sp) { 369 t.Fatalf("want span %d got %d", expectedSpan, span) 370 } 371 if level != tc.level { 372 t.Fatalf("encoded level differs from the uploaded one %d. Got: %d", tc.level, level) 373 } 374 expectedParities := tc.parities - r.Parities(r.MaxShards()) 375 _, parity := file.ReferenceCount(bmtUtils.LengthFromSpan(sp), level, tc.encryption) 376 if expectedParities != parity { 377 t.Fatalf("want parity %d got %d", expectedParities, parity) 378 } 379 if tc.level.GetReplicaCount() != int(replicaChunkCounter.replicaCount.Load()) { 380 t.Fatalf("unexpected number of replicas: want %d. Got: %d", tc.level.GetReplicaCount(), int(replicaChunkCounter.replicaCount.Load())) 381 } 382 }) 383 } 384 }