github.com/ethersphere/bee/v2@v2.2.0/pkg/steward/steward_test.go (about)

     1  // Copyright 2021 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package steward_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/rand"
    11  	"errors"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/ethersphere/bee/v2/pkg/file/pipeline/builder"
    18  	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
    19  	postagetesting "github.com/ethersphere/bee/v2/pkg/postage/mock"
    20  	"github.com/ethersphere/bee/v2/pkg/steward"
    21  	storage "github.com/ethersphere/bee/v2/pkg/storage"
    22  	"github.com/ethersphere/bee/v2/pkg/storage/inmemchunkstore"
    23  	mockstorer "github.com/ethersphere/bee/v2/pkg/storer/mock"
    24  	"github.com/ethersphere/bee/v2/pkg/swarm"
    25  )
    26  
    27  type counter struct {
    28  	storage.ChunkStore
    29  	count atomic.Int32
    30  }
    31  
    32  func (c *counter) Put(ctx context.Context, ch swarm.Chunk) (err error) {
    33  	c.count.Add(1)
    34  	return c.ChunkStore.Put(ctx, ch)
    35  }
    36  
    37  func TestSteward(t *testing.T) {
    38  	t.Parallel()
    39  	inmem := &counter{ChunkStore: inmemchunkstore.New()}
    40  
    41  	var (
    42  		ctx            = context.Background()
    43  		chunks         = 1000
    44  		data           = make([]byte, chunks*4096) //1k chunks
    45  		chunkStore     = inmem
    46  		store          = mockstorer.NewWithChunkStore(chunkStore)
    47  		localRetrieval = &localRetriever{ChunkStore: chunkStore}
    48  		s              = steward.New(store, localRetrieval, inmem)
    49  		stamper        = postagetesting.NewStamper()
    50  	)
    51  	ctx = redundancy.SetLevelInContext(ctx, redundancy.NONE)
    52  
    53  	n, err := rand.Read(data)
    54  	if n != cap(data) {
    55  		t.Fatal("short read")
    56  	}
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  
    61  	pipe := builder.NewPipelineBuilder(ctx, chunkStore, false, 0)
    62  	addr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	chunkCount := int(inmem.count.Load())
    68  	done := make(chan struct{})
    69  	errc := make(chan error, 1)
    70  	go func() {
    71  		defer close(done)
    72  		count := 0
    73  		for op := range store.PusherFeed() {
    74  			has, err := chunkStore.Has(ctx, op.Chunk.Address())
    75  			if err != nil || !has {
    76  				if !has {
    77  					err = errors.New("chunk not found")
    78  				}
    79  				select {
    80  				case errc <- err:
    81  				default:
    82  				}
    83  				return
    84  			}
    85  			count++
    86  			if count == chunkCount {
    87  				return
    88  			}
    89  		}
    90  	}()
    91  
    92  	err = s.Reupload(ctx, addr, stamper)
    93  	if err != nil {
    94  		t.Fatal(err)
    95  	}
    96  
    97  	select {
    98  	case <-done:
    99  	case <-time.After(3 * time.Second):
   100  		t.Fatal("took too long to finish")
   101  	}
   102  
   103  	select {
   104  	case err := <-errc:
   105  		t.Fatalf("unexpected error: %v", err)
   106  	default:
   107  	}
   108  
   109  	isRetrievable, err := s.IsRetrievable(ctx, addr)
   110  	if err != nil {
   111  		t.Fatal(err)
   112  	}
   113  	if !isRetrievable {
   114  		t.Fatalf("re-uploaded content on %q should be retrievable", addr)
   115  	}
   116  
   117  	count := len(localRetrieval.retrievedChunks)
   118  	if count != chunkCount {
   119  		t.Fatalf("unexpected no of unique chunks retrieved: want %d have %d", chunkCount, count)
   120  	}
   121  }
   122  
   123  type localRetriever struct {
   124  	storage.ChunkStore
   125  	mu              sync.Mutex
   126  	retrievedChunks map[string]struct{}
   127  }
   128  
   129  func (lr *localRetriever) RetrieveChunk(ctx context.Context, addr, sourceAddr swarm.Address) (chunk swarm.Chunk, err error) {
   130  	lr.mu.Lock()
   131  	defer lr.mu.Unlock()
   132  
   133  	if lr.retrievedChunks == nil {
   134  		lr.retrievedChunks = make(map[string]struct{})
   135  	}
   136  	lr.retrievedChunks[addr.String()] = struct{}{}
   137  	return lr.Get(ctx, addr)
   138  }