github.com/ethersphere/bee/v2@v2.2.0/pkg/pullsync/pullsync_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package pullsync_test
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"io"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/ethersphere/bee/v2/pkg/log"
    15  	"github.com/ethersphere/bee/v2/pkg/p2p"
    16  	"github.com/ethersphere/bee/v2/pkg/p2p/streamtest"
    17  	"github.com/ethersphere/bee/v2/pkg/postage"
    18  	postagetesting "github.com/ethersphere/bee/v2/pkg/postage/testing"
    19  	"github.com/ethersphere/bee/v2/pkg/pullsync"
    20  	"github.com/ethersphere/bee/v2/pkg/storage"
    21  	testingc "github.com/ethersphere/bee/v2/pkg/storage/testing"
    22  	"github.com/ethersphere/bee/v2/pkg/storer"
    23  	mock "github.com/ethersphere/bee/v2/pkg/storer/mock"
    24  	"github.com/ethersphere/bee/v2/pkg/swarm"
    25  )
    26  
    27  var (
    28  	results []*storer.BinC
    29  	addrs   []swarm.Address
    30  	chunks  []swarm.Chunk
    31  )
    32  
    33  func someChunks(i ...int) (c []swarm.Chunk) {
    34  	for _, v := range i {
    35  		c = append(c, chunks[v])
    36  	}
    37  	return c
    38  }
    39  
    40  // nolint:gochecknoinits
    41  func init() {
    42  	n := 5
    43  	chunks = make([]swarm.Chunk, n)
    44  	addrs = make([]swarm.Address, n)
    45  	results = make([]*storer.BinC, n)
    46  	for i := 0; i < n; i++ {
    47  		chunks[i] = testingc.GenerateTestRandomChunk()
    48  		addrs[i] = chunks[i].Address()
    49  		stampHash, _ := chunks[i].Stamp().Hash()
    50  		results[i] = &storer.BinC{
    51  			Address:   addrs[i],
    52  			BatchID:   chunks[i].Stamp().BatchID(),
    53  			BinID:     uint64(i),
    54  			StampHash: stampHash,
    55  		}
    56  	}
    57  }
    58  
    59  func TestIncoming_WantNone(t *testing.T) {
    60  	t.Parallel()
    61  
    62  	var (
    63  		topMost            = uint64(4)
    64  		ps, _              = newPullSync(t, nil, 5, mock.WithSubscribeResp(results, nil), mock.WithChunks(chunks...))
    65  		recorder           = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
    66  		psClient, clientDb = newPullSync(t, recorder, 0, mock.WithChunks(chunks...))
    67  	)
    68  
    69  	topmost, _, err := psClient.Sync(context.Background(), swarm.ZeroAddress, 0, 0)
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  
    74  	if topmost != topMost {
    75  		t.Fatalf("got offer topmost %d but want %d", topmost, topMost)
    76  	}
    77  	if clientDb.PutCalls() > 0 {
    78  		t.Fatal("too many puts")
    79  	}
    80  }
    81  
    82  func TestIncoming_ContextTimeout(t *testing.T) {
    83  	t.Parallel()
    84  
    85  	var (
    86  		ps, _       = newPullSync(t, nil, 0, mock.WithSubscribeResp(results, nil), mock.WithChunks(chunks...))
    87  		recorder    = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
    88  		psClient, _ = newPullSync(t, recorder, 0, mock.WithChunks(chunks...))
    89  	)
    90  
    91  	ctx, cancel := context.WithTimeout(context.Background(), 0)
    92  	cancel()
    93  	_, _, err := psClient.Sync(ctx, swarm.ZeroAddress, 0, 0)
    94  	if !errors.Is(err, context.DeadlineExceeded) {
    95  		t.Fatalf("wanted error %v, got %v", context.DeadlineExceeded, err)
    96  	}
    97  }
    98  
    99  func TestIncoming_WantOne(t *testing.T) {
   100  	t.Parallel()
   101  
   102  	var (
   103  		topMost            = uint64(4)
   104  		ps, _              = newPullSync(t, nil, 5, mock.WithSubscribeResp(results, nil), mock.WithChunks(chunks...))
   105  		recorder           = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   106  		psClient, clientDb = newPullSync(t, recorder, 0, mock.WithChunks(someChunks(1, 2, 3, 4)...))
   107  	)
   108  
   109  	topmost, _, err := psClient.Sync(context.Background(), swarm.ZeroAddress, 0, 0)
   110  	if err != nil {
   111  		t.Fatal(err)
   112  	}
   113  
   114  	if topmost != topMost {
   115  		t.Fatalf("got offer topmost %d but want %d", topmost, topMost)
   116  	}
   117  
   118  	// should have all
   119  	haveChunks(t, clientDb, chunks...)
   120  	if clientDb.PutCalls() != 1 {
   121  		t.Fatalf("want 1 puts but got %d", clientDb.PutCalls())
   122  	}
   123  }
   124  
   125  func TestIncoming_WantAll(t *testing.T) {
   126  	t.Parallel()
   127  
   128  	var (
   129  		topMost            = uint64(4)
   130  		ps, _              = newPullSync(t, nil, 5, mock.WithSubscribeResp(results, nil), mock.WithChunks(chunks...))
   131  		recorder           = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   132  		psClient, clientDb = newPullSync(t, recorder, 0)
   133  	)
   134  
   135  	topmost, _, err := psClient.Sync(context.Background(), swarm.ZeroAddress, 0, 0)
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  
   140  	if topmost != topMost {
   141  		t.Fatalf("got offer topmost %d but want %d", topmost, topMost)
   142  	}
   143  
   144  	// should have all
   145  	haveChunks(t, clientDb, chunks...)
   146  	if p := clientDb.PutCalls(); p != len(chunks) {
   147  		t.Fatalf("want %d puts but got %d", len(chunks), p)
   148  	}
   149  }
   150  
   151  func TestIncoming_WantErrors(t *testing.T) {
   152  	t.Parallel()
   153  
   154  	tChunks := testingc.GenerateTestRandomChunks(4)
   155  	// add same chunk with a different batch id
   156  	ch := swarm.NewChunk(tChunks[3].Address(), tChunks[3].Data()).WithStamp(postagetesting.MustNewStamp())
   157  	tChunks = append(tChunks, ch)
   158  	// add invalid chunk
   159  	tChunks = append(tChunks, testingc.GenerateTestRandomInvalidChunk())
   160  
   161  	tResults := make([]*storer.BinC, len(tChunks))
   162  	for i, c := range tChunks {
   163  		stampHash, err := c.Stamp().Hash()
   164  		if err != nil {
   165  			t.Fatal(err)
   166  		}
   167  		tResults[i] = &storer.BinC{
   168  			Address:   c.Address(),
   169  			BatchID:   c.Stamp().BatchID(),
   170  			BinID:     uint64(i + 5), // start from a higher bin id
   171  			StampHash: stampHash,
   172  		}
   173  	}
   174  
   175  	putHook := func(c swarm.Chunk) error {
   176  		if c.Address().Equal(tChunks[1].Address()) {
   177  			return storage.ErrOverwriteNewerChunk
   178  		}
   179  		return nil
   180  	}
   181  
   182  	validStampErr := errors.New("valid stamp error")
   183  	validStamp := func(c swarm.Chunk) (swarm.Chunk, error) {
   184  		if c.Address().Equal(tChunks[2].Address()) {
   185  			return nil, validStampErr
   186  		}
   187  		return c, nil
   188  	}
   189  
   190  	var (
   191  		topMost            = uint64(10)
   192  		ps, _              = newPullSync(t, nil, 20, mock.WithSubscribeResp(tResults, nil), mock.WithChunks(tChunks...))
   193  		recorder           = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   194  		psClient, clientDb = newPullSyncWithStamperValidator(t, recorder, 0, validStamp, mock.WithPutHook(putHook))
   195  	)
   196  
   197  	topmost, count, err := psClient.Sync(context.Background(), swarm.ZeroAddress, 0, 0)
   198  	for _, e := range []error{storage.ErrOverwriteNewerChunk, validStampErr, swarm.ErrInvalidChunk} {
   199  		if !errors.Is(err, e) {
   200  			t.Fatalf("expected error %v", err)
   201  		}
   202  	}
   203  
   204  	if count != 3 {
   205  		t.Fatalf("got %d chunks but want %d", count, 3)
   206  	}
   207  
   208  	if topmost != topMost {
   209  		t.Fatalf("got offer topmost %d but want %d", topmost, topMost)
   210  	}
   211  
   212  	haveChunks(t, clientDb, append(tChunks[:1], tChunks[3:5]...)...)
   213  	if p := clientDb.PutCalls(); p != len(chunks)-1 {
   214  		t.Fatalf("want %d puts but got %d", len(chunks), p)
   215  	}
   216  }
   217  
   218  func TestIncoming_UnsolicitedChunk(t *testing.T) {
   219  	t.Parallel()
   220  
   221  	evilAddr := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000666")
   222  	evilData := []byte{0x66, 0x66, 0x66}
   223  	stamp := postagetesting.MustNewStamp()
   224  	evil := swarm.NewChunk(evilAddr, evilData).WithStamp(stamp)
   225  
   226  	var (
   227  		ps, _       = newPullSync(t, nil, 5, mock.WithSubscribeResp(results, nil), mock.WithChunks(chunks...), mock.WithEvilChunk(addrs[4], evil))
   228  		recorder    = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   229  		psClient, _ = newPullSync(t, recorder, 0)
   230  	)
   231  
   232  	_, _, err := psClient.Sync(context.Background(), swarm.ZeroAddress, 0, 0)
   233  	if !errors.Is(err, pullsync.ErrUnsolicitedChunk) {
   234  		t.Fatalf("expected err %v but got %v", pullsync.ErrUnsolicitedChunk, err)
   235  	}
   236  }
   237  
   238  func TestMissingChunk(t *testing.T) {
   239  	t.Parallel()
   240  
   241  	var (
   242  		zeroChunk   = swarm.NewChunk(swarm.ZeroAddress, nil)
   243  		topMost     = uint64(4)
   244  		ps, _       = newPullSync(t, nil, 5, mock.WithSubscribeResp(results, nil), mock.WithChunks([]swarm.Chunk{zeroChunk}...))
   245  		recorder    = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   246  		psClient, _ = newPullSync(t, recorder, 0)
   247  	)
   248  
   249  	topmost, count, err := psClient.Sync(context.Background(), swarm.ZeroAddress, 0, 0)
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  
   254  	if topmost != topMost {
   255  		t.Fatalf("got offer topmost %d but want %d", topmost, topMost)
   256  	}
   257  	if count != 0 {
   258  		t.Fatalf("got count %d but want %d", count, 0)
   259  	}
   260  }
   261  
   262  func TestGetCursors(t *testing.T) {
   263  	t.Parallel()
   264  
   265  	var (
   266  		epochTs     = uint64(time.Now().Unix())
   267  		mockCursors = []uint64{100, 101, 102, 103}
   268  		ps, _       = newPullSync(t, nil, 0, mock.WithCursors(mockCursors, epochTs))
   269  		recorder    = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   270  		psClient, _ = newPullSync(t, recorder, 0)
   271  	)
   272  
   273  	curs, epoch, err := psClient.GetCursors(context.Background(), swarm.ZeroAddress)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  
   278  	if len(curs) != len(mockCursors) {
   279  		t.Fatalf("length mismatch got %d want %d", len(curs), len(mockCursors))
   280  	}
   281  
   282  	if epochTs != epoch {
   283  		t.Fatalf("epochs do not match got %d want %d", epoch, epochTs)
   284  	}
   285  
   286  	for i, v := range mockCursors {
   287  		if curs[i] != v {
   288  			t.Errorf("cursor mismatch. index %d want %d got %d", i, v, curs[i])
   289  		}
   290  	}
   291  }
   292  
   293  func TestGetCursorsError(t *testing.T) {
   294  	t.Parallel()
   295  
   296  	var (
   297  		e           = errors.New("erring")
   298  		ps, _       = newPullSync(t, nil, 0, mock.WithCursorsErr(e))
   299  		recorder    = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
   300  		psClient, _ = newPullSync(t, recorder, 0)
   301  	)
   302  
   303  	_, _, err := psClient.GetCursors(context.Background(), swarm.ZeroAddress)
   304  	if err == nil {
   305  		t.Fatal("expected error but got none")
   306  	}
   307  	if !errors.Is(err, io.EOF) {
   308  		t.Fatalf("expect error '%v' but got '%v'", e, err)
   309  	}
   310  }
   311  
   312  func haveChunks(t *testing.T, s *mock.ReserveStore, chunks ...swarm.Chunk) {
   313  	t.Helper()
   314  	for _, c := range chunks {
   315  		stampHash, err := c.Stamp().Hash()
   316  		if err != nil {
   317  			t.Fatal(err)
   318  		}
   319  		have, err := s.ReserveHas(c.Address(), c.Stamp().BatchID(), stampHash)
   320  		if err != nil {
   321  			t.Fatal(err)
   322  		}
   323  		if !have {
   324  			t.Errorf("storage does not have chunk %s", c.Address())
   325  		}
   326  	}
   327  }
   328  
   329  func newPullSync(
   330  	t *testing.T,
   331  	s p2p.Streamer,
   332  	maxPage uint64,
   333  	o ...mock.Option,
   334  ) (*pullsync.Syncer, *mock.ReserveStore) {
   335  	t.Helper()
   336  
   337  	validStamp := func(ch swarm.Chunk) (swarm.Chunk, error) {
   338  		return ch, nil
   339  	}
   340  
   341  	return newPullSyncWithStamperValidator(t, s, maxPage, validStamp, o...)
   342  }
   343  
   344  func newPullSyncWithStamperValidator(
   345  	t *testing.T,
   346  	s p2p.Streamer,
   347  	maxPage uint64,
   348  	validStamp postage.ValidStampFn,
   349  	o ...mock.Option,
   350  ) (*pullsync.Syncer, *mock.ReserveStore) {
   351  	t.Helper()
   352  
   353  	storage := mock.NewReserve(o...)
   354  	logger := log.Noop
   355  	unwrap := func(swarm.Chunk) {}
   356  	ps := pullsync.New(
   357  		s,
   358  		storage,
   359  		unwrap,
   360  		validStamp,
   361  		logger,
   362  		maxPage,
   363  	)
   364  
   365  	t.Cleanup(func() {
   366  		err := ps.Close()
   367  		if err != nil {
   368  			t.Errorf("failed closing pullsync: %v", err)
   369  		}
   370  	})
   371  	return ps, storage
   372  }