github.com/ethersphere/bee/v2@v2.2.0/pkg/replicas/getter_test.go (about)

     1  // Copyright 2023 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package replicas_test
     6  
     7  import (
     8  	"context"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/ethersphere/bee/v2/pkg/cac"
    18  	"github.com/ethersphere/bee/v2/pkg/file/redundancy"
    19  	"github.com/ethersphere/bee/v2/pkg/replicas"
    20  	"github.com/ethersphere/bee/v2/pkg/soc"
    21  	"github.com/ethersphere/bee/v2/pkg/storage"
    22  	"github.com/ethersphere/bee/v2/pkg/swarm"
    23  )
    24  
    25  type testGetter struct {
    26  	ch         swarm.Chunk
    27  	now        time.Time
    28  	origCalled chan struct{}
    29  	origIndex  int
    30  	errf       func(int) chan struct{}
    31  	firstFound int32
    32  	attempts   atomic.Int32
    33  	cancelled  chan struct{}
    34  	addresses  [17]swarm.Address
    35  	latencies  [17]time.Duration
    36  }
    37  
    38  func (tg *testGetter) Get(ctx context.Context, addr swarm.Address) (ch swarm.Chunk, err error) {
    39  	i := tg.attempts.Add(1) - 1
    40  	tg.addresses[i] = addr
    41  	tg.latencies[i] = time.Since(tg.now)
    42  
    43  	if addr.Equal(tg.ch.Address()) {
    44  		tg.origIndex = int(i)
    45  		close(tg.origCalled)
    46  		ch = tg.ch
    47  	}
    48  
    49  	if i != tg.firstFound {
    50  		select {
    51  		case <-ctx.Done():
    52  			return nil, ctx.Err()
    53  		case <-tg.errf(int(i)):
    54  			return nil, storage.ErrNotFound
    55  		}
    56  	}
    57  	defer func() {
    58  		go func() {
    59  			select {
    60  			case <-time.After(100 * time.Millisecond):
    61  			case <-ctx.Done():
    62  				close(tg.cancelled)
    63  			}
    64  		}()
    65  	}()
    66  
    67  	if ch != nil {
    68  		return ch, nil
    69  	}
    70  	return soc.New(addr.Bytes(), tg.ch).Sign(replicas.Signer)
    71  }
    72  
    73  func newTestGetter(ch swarm.Chunk, firstFound int, errf func(int) chan struct{}) *testGetter {
    74  	return &testGetter{
    75  		ch:         ch,
    76  		errf:       errf,
    77  		firstFound: int32(firstFound),
    78  		cancelled:  make(chan struct{}),
    79  		origCalled: make(chan struct{}),
    80  	}
    81  }
    82  
    83  // Close implements the storage.Getter interface
    84  func (tg *testGetter) Close() error {
    85  	return nil
    86  }
    87  
    88  func TestGetter(t *testing.T) {
    89  	t.Parallel()
    90  	// failure is a struct that defines a failure scenario to test
    91  	type failure struct {
    92  		name string
    93  		err  error
    94  		errf func(int, int) func(int) chan struct{}
    95  	}
    96  	// failures is a list of failure scenarios to test
    97  	failures := []failure{
    98  		{
    99  			"timeout",
   100  			context.Canceled,
   101  			func(_, _ int) func(i int) chan struct{} {
   102  				return func(i int) chan struct{} {
   103  					return nil
   104  				}
   105  			},
   106  		},
   107  		{
   108  			"not found",
   109  			storage.ErrNotFound,
   110  			func(_, _ int) func(i int) chan struct{} {
   111  				c := make(chan struct{})
   112  				close(c)
   113  				return func(i int) chan struct{} {
   114  					return c
   115  				}
   116  			},
   117  		},
   118  	}
   119  	type test struct {
   120  		name    string
   121  		failure failure
   122  		level   int
   123  		count   int
   124  		found   int
   125  	}
   126  
   127  	var tests []test
   128  	for _, f := range failures {
   129  		for level, c := range redundancy.GetReplicaCounts() {
   130  			for j := 0; j <= c*2+1; j++ {
   131  				tests = append(tests, test{
   132  					name:    fmt.Sprintf("%s level %d count %d found %d", f.name, level, c, j),
   133  					failure: f,
   134  					level:   level,
   135  					count:   c,
   136  					found:   j,
   137  				})
   138  			}
   139  		}
   140  	}
   141  
   142  	// initialise the base chunk
   143  	chunkLen := 420
   144  	buf := make([]byte, chunkLen)
   145  	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
   146  		t.Fatal(err)
   147  	}
   148  	ch, err := cac.New(buf)
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  	// reset retry interval to speed up tests
   153  	retryInterval := replicas.RetryInterval
   154  	defer func() { replicas.RetryInterval = retryInterval }()
   155  	replicas.RetryInterval = 100 * time.Millisecond
   156  
   157  	// run the tests
   158  	for _, tc := range tests {
   159  		t.Run(tc.name, func(t *testing.T) {
   160  			// initiate a chunk retrieval session using replicas.Getter
   161  			// embedding a testGetter that simulates the behaviour of a chunk store
   162  			store := newTestGetter(ch, tc.found, tc.failure.errf(tc.found, tc.count))
   163  			g := replicas.NewGetter(store, redundancy.Level(tc.level))
   164  			store.now = time.Now()
   165  			ctx, cancel := context.WithCancel(context.Background())
   166  			if tc.found > tc.count {
   167  				wait := replicas.RetryInterval / 2 * time.Duration(1+2*tc.level)
   168  				go func() {
   169  					time.Sleep(wait)
   170  					cancel()
   171  				}()
   172  			}
   173  			_, err := g.Get(ctx, ch.Address())
   174  			replicas.Wait(g)
   175  			cancel()
   176  
   177  			// test the returned error
   178  			if tc.found <= tc.count {
   179  				if err != nil {
   180  					t.Fatalf("expected no error. got %v", err)
   181  				}
   182  				// if j <= c, the original chunk should be retrieved and the context should be cancelled
   183  				t.Run("retrievals cancelled", func(t *testing.T) {
   184  
   185  					select {
   186  					case <-time.After(100 * time.Millisecond):
   187  						t.Fatal("timed out waiting for context to be cancelled")
   188  					case <-store.cancelled:
   189  					}
   190  				})
   191  
   192  			} else {
   193  				if err == nil {
   194  					t.Fatalf("expected error. got <nil>")
   195  				}
   196  
   197  				t.Run("returns correct error", func(t *testing.T) {
   198  					if !errors.Is(err, replicas.ErrSwarmageddon) {
   199  						t.Fatalf("incorrect error. want Swarmageddon. got %v", err)
   200  					}
   201  					if !errors.Is(err, tc.failure.err) {
   202  						t.Fatalf("incorrect error. want it to wrap %v. got %v", tc.failure.err, err)
   203  					}
   204  				})
   205  			}
   206  
   207  			attempts := int(store.attempts.Load())
   208  			// the original chunk should be among those attempted for retrieval
   209  			addresses := store.addresses[:attempts]
   210  			latencies := store.latencies[:attempts]
   211  			t.Run("original address called", func(t *testing.T) {
   212  				select {
   213  				case <-time.After(100 * time.Millisecond):
   214  					t.Fatal("timed out waiting form original address to be attempted for retrieval")
   215  				case <-store.origCalled:
   216  					i := store.origIndex
   217  					if i > 2 {
   218  						t.Fatalf("original address called too late. want at most 2 (preceding attempts). got %v (latency: %v)", i, latencies[i])
   219  					}
   220  					addresses = append(addresses[:i], addresses[i+1:]...)
   221  					latencies = append(latencies[:i], latencies[i+1:]...)
   222  					attempts--
   223  				}
   224  			})
   225  
   226  			t.Run("retrieved count", func(t *testing.T) {
   227  				if attempts > tc.count {
   228  					t.Fatalf("too many attempts to retrieve a replica: want at most %v. got %v.", tc.count, attempts)
   229  				}
   230  				if tc.found > tc.count {
   231  					if attempts < tc.count {
   232  						t.Fatalf("too few attempts to retrieve a replica: want at least %v. got %v.", tc.count, attempts)
   233  					}
   234  					return
   235  				}
   236  				max := 2
   237  				for i := 1; i < tc.level && max < tc.found; i++ {
   238  					max = max * 2
   239  				}
   240  				if attempts > max {
   241  					t.Fatalf("too many attempts to retrieve a replica: want at most %v. got %v. latencies %v", max, attempts, latencies)
   242  				}
   243  			})
   244  
   245  			t.Run("dispersion", func(t *testing.T) {
   246  
   247  				if err := dispersed(redundancy.Level(tc.level), ch, addresses); err != nil {
   248  					t.Fatalf("addresses are not dispersed: %v", err)
   249  				}
   250  			})
   251  
   252  			t.Run("latency", func(t *testing.T) {
   253  				counts := redundancy.GetReplicaCounts()
   254  				for i, latency := range latencies {
   255  					multiplier := latency / replicas.RetryInterval
   256  					if multiplier > 0 && i < counts[multiplier-1] {
   257  						t.Fatalf("incorrect latency for retrieving replica %d: %v", i, err)
   258  					}
   259  				}
   260  			})
   261  		})
   262  	}
   263  }