github.com/ethersphere/bee/v2@v2.2.0/pkg/retrieval/retrieval_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package retrieval_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"encoding/hex"
    11  	"errors"
    12  	"fmt"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/ethersphere/bee/v2/pkg/accounting"
    19  	accountingmock "github.com/ethersphere/bee/v2/pkg/accounting/mock"
    20  	"github.com/ethersphere/bee/v2/pkg/log"
    21  	"github.com/ethersphere/bee/v2/pkg/p2p"
    22  	"github.com/ethersphere/bee/v2/pkg/p2p/protobuf"
    23  	"github.com/ethersphere/bee/v2/pkg/p2p/streamtest"
    24  	"github.com/ethersphere/bee/v2/pkg/pricer"
    25  	pricermock "github.com/ethersphere/bee/v2/pkg/pricer/mock"
    26  	"github.com/ethersphere/bee/v2/pkg/retrieval"
    27  	pb "github.com/ethersphere/bee/v2/pkg/retrieval/pb"
    28  	"github.com/ethersphere/bee/v2/pkg/spinlock"
    29  	"github.com/ethersphere/bee/v2/pkg/storage"
    30  	"github.com/ethersphere/bee/v2/pkg/storage/inmemchunkstore"
    31  	testingc "github.com/ethersphere/bee/v2/pkg/storage/testing"
    32  	storemock "github.com/ethersphere/bee/v2/pkg/storer/mock"
    33  	"github.com/ethersphere/bee/v2/pkg/swarm"
    34  	"github.com/ethersphere/bee/v2/pkg/topology"
    35  	"github.com/ethersphere/bee/v2/pkg/tracing"
    36  
    37  	topologymock "github.com/ethersphere/bee/v2/pkg/topology/mock"
    38  )
    39  
// Shared fixtures for every test in this package.
var (
	testTimeout  = 5 * time.Second // upper bound on a single retrieval round trip
	defaultPrice = uint64(10)      // flat per-chunk price served by the pricer mocks
)
    44  
// testStorer adapts a bare storage.ChunkStore to the retrieval.Storer
// interface by serving the same underlying store for both local lookups
// and forwarder caching.
type testStorer struct {
	storage.ChunkStore
}

// Lookup returns the getter used to resolve chunks from the local store.
func (t *testStorer) Lookup() storage.Getter { return t.ChunkStore }

// Cache returns the putter used to cache forwarded chunks.
func (t *testStorer) Cache() storage.Putter { return t.ChunkStore }
    52  
    53  // TestDelivery tests that a naive request -> delivery flow works.
    54  func TestDelivery(t *testing.T) {
    55  	t.Parallel()
    56  
    57  	var (
    58  		chunk                = testingc.FixtureChunk("0033")
    59  		logger               = log.Noop
    60  		mockStorer           = &testStorer{ChunkStore: inmemchunkstore.New()}
    61  		clientMockAccounting = accountingmock.NewAccounting()
    62  		serverMockAccounting = accountingmock.NewAccounting()
    63  		clientAddr           = swarm.MustParseHexAddress("9ee7add8")
    64  		serverAddr           = swarm.MustParseHexAddress("9ee7add7")
    65  
    66  		pricerMock = pricermock.NewMockService(defaultPrice, defaultPrice)
    67  	)
    68  	// put testdata in the mock store of the server
    69  	err := mockStorer.Put(context.Background(), chunk)
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  
    74  	// create the server that will handle the request and will serve the response
    75  	server := createRetrieval(t, swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil, false)
    76  	recorder := streamtest.New(
    77  		streamtest.WithProtocols(server.Protocol()),
    78  		streamtest.WithBaseAddr(clientAddr),
    79  	)
    80  
    81  	// client mock storer does not store any data at this point
    82  	// but should be checked at at the end of the test for the
    83  	// presence of the chunk address key and value to ensure delivery
    84  	// was successful
    85  	clientMockStorer := &testStorer{ChunkStore: inmemchunkstore.New()}
    86  
    87  	mt := topologymock.NewTopologyDriver(topologymock.WithClosestPeer(serverAddr))
    88  
    89  	client := createRetrieval(t, clientAddr, clientMockStorer, recorder, mt, logger, clientMockAccounting, pricerMock, nil, false)
    90  	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
    91  	defer cancel()
    92  	v, err := client.RetrieveChunk(ctx, chunk.Address(), swarm.ZeroAddress)
    93  	if err != nil {
    94  		t.Fatal(err)
    95  	}
    96  	if !bytes.Equal(v.Data(), chunk.Data()) {
    97  		t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data())
    98  	}
    99  	records, err := recorder.Records(serverAddr, "retrieval", "1.4.0", "retrieval")
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  	if l := len(records); l != 1 {
   104  		t.Fatalf("got %v records, want %v", l, 1)
   105  	}
   106  
   107  	record := records[0]
   108  
   109  	messages, err := protobuf.ReadMessages(
   110  		bytes.NewReader(record.In()),
   111  		func() protobuf.Message { return new(pb.Request) },
   112  	)
   113  	if err != nil {
   114  		t.Fatal(err)
   115  	}
   116  	reqs := make([]string, 0, len(messages))
   117  	for _, m := range messages {
   118  		reqs = append(reqs, hex.EncodeToString(m.(*pb.Request).Addr))
   119  	}
   120  
   121  	if len(reqs) != 1 {
   122  		t.Fatalf("got too many requests. want 1 got %d", len(reqs))
   123  	}
   124  
   125  	messages, err = protobuf.ReadMessages(
   126  		bytes.NewReader(record.Out()),
   127  		func() protobuf.Message { return new(pb.Delivery) },
   128  	)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  
   133  	gotDeliveries := make([]string, 0, len(messages))
   134  	for _, m := range messages {
   135  		gotDeliveries = append(gotDeliveries, string(m.(*pb.Delivery).Data))
   136  	}
   137  
   138  	if len(gotDeliveries) != 1 {
   139  		t.Fatalf("got too many deliveries. want 1 got %d", len(gotDeliveries))
   140  	}
   141  
   142  	clientBalance, _ := clientMockAccounting.Balance(serverAddr)
   143  	if clientBalance.Int64() != -int64(defaultPrice) {
   144  		t.Fatalf("unexpected balance on client. want %d got %d", -defaultPrice, clientBalance)
   145  	}
   146  
   147  	serverBalance, _ := serverMockAccounting.Balance(clientAddr)
   148  	if serverBalance.Int64() != int64(defaultPrice) {
   149  		t.Fatalf("unexpected balance on server. want %d got %d", defaultPrice, serverBalance)
   150  	}
   151  }
   152  
// TestWaitForInflight exercises the client's behaviour when its first peer
// fails: the stream middleware rejects the very first handler invocation
// ("peer not reachable") and delays every subsequent one by two seconds,
// so the client must fall back to the other peer and wait out the inflight
// request before the chunk is finally delivered.
func TestWaitForInflight(t *testing.T) {
	t.Parallel()

	var (
		chunk      = testingc.FixtureChunk("7000")
		logger     = log.Noop
		pricerMock = pricermock.NewMockService(defaultPrice, defaultPrice)

		badMockStorer           = &testStorer{ChunkStore: inmemchunkstore.New()}
		badServerMockAccounting = accountingmock.NewAccounting()
		badServerAddr           = swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")

		mockStorer           = &testStorer{ChunkStore: inmemchunkstore.New()}
		serverMockAccounting = accountingmock.NewAccounting()
		serverAddr           = swarm.MustParseHexAddress("5000000000000000000000000000000000000000000000000000000000000000")

		clientMockStorer     = &testStorer{ChunkStore: inmemchunkstore.New()}
		clientMockAccounting = accountingmock.NewAccounting()
		clientAddr           = swarm.MustParseHexAddress("9ee7add8")
	)

	// put testdata in the mock store of the server
	err := mockStorer.Put(context.Background(), chunk)
	if err != nil {
		t.Fatal(err)
	}

	// create the server that will handle the request and will serve the response
	server := createRetrieval(t, serverAddr, mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil, false)

	// badServer's store is empty; its first stream is failed by the middleware below
	badServer := createRetrieval(t, badServerAddr, badMockStorer, nil, nil, logger, badServerMockAccounting, pricerMock, nil, false)

	// fail is a one-shot flag for the "peer not reachable" failure; lock guards
	// it because the middleware can be invoked from concurrent streams
	var fail = true
	var lock sync.Mutex

	recorder := streamtest.New(
		streamtest.WithBaseAddr(clientAddr),
		streamtest.WithProtocols(badServer.Protocol(), server.Protocol()),
		streamtest.WithMiddlewares(func(h p2p.HandlerFunc) p2p.HandlerFunc {
			return func(ctx context.Context, p p2p.Peer, s p2p.Stream) error {
				lock.Lock()
				defer lock.Unlock()

				// the very first attempt fails outright
				if fail {
					fail = false
					s.Close()
					return errors.New("peer not reachable")
				}

				// subsequent attempts are delayed, keeping the request inflight
				time.Sleep(time.Second * 2)

				if err := h(ctx, p, s); err != nil {
					return err
				}
				// close stream after all previous middlewares wrote to it
				// so that the receiving peer can get all the post messages
				return s.Close()
			}
		}),
	)

	mt := topologymock.NewTopologyDriver(topologymock.WithPeers(badServerAddr, serverAddr))

	client := createRetrieval(t, clientAddr, clientMockStorer, recorder, mt, logger, clientMockAccounting, pricerMock, nil, false)

	// generous deadline: the test exercises retry/inflight logic, not the timeout
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*30)
	defer cancel()

	v, err := client.RetrieveChunk(ctx, chunk.Address(), swarm.ZeroAddress)
	if err != nil {
		t.Fatal(err)
	}

	if !bytes.Equal(v.Data(), chunk.Data()) {
		t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data())
	}
}
   230  
// TestRetrieveChunk covers three retrieval scenarios: fetching from a
// directly-connected downstream peer, fetching through a forwarding peer
// that should cache the chunk, and propagating a delivery error back to
// the origin node.
func TestRetrieveChunk(t *testing.T) {
	t.Parallel()

	var (
		logger = log.Noop
		pricer = pricermock.NewMockService(defaultPrice, defaultPrice)
	)

	// requesting a chunk from downstream peer is expected
	t.Run("downstream", func(t *testing.T) {
		t.Parallel()
		// NOTE(review): this subtest is currently skipped
		t.Skip()

		serverAddress := swarm.MustParseHexAddress("03")
		clientAddress := swarm.MustParseHexAddress("01")
		chunk := testingc.FixtureChunk("02c2")

		serverStorer := &testStorer{ChunkStore: inmemchunkstore.New()}
		err := serverStorer.Put(context.Background(), chunk)
		if err != nil {
			t.Fatal(err)
		}

		server := createRetrieval(t, serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, nil, false)
		recorder := streamtest.New(streamtest.WithProtocols(server.Protocol()))

		mt := topologymock.NewTopologyDriver(topologymock.WithClosestPeer(serverAddress))

		client := createRetrieval(t, clientAddress, nil, recorder, mt, logger, accountingmock.NewAccounting(), pricer, nil, false)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
	})

	t.Run("forward", func(t *testing.T) {
		t.Parallel()
		// NOTE(review): this subtest is currently skipped
		t.Skip()

		chunk := testingc.FixtureChunk("0025")

		serverAddress := swarm.MustParseHexAddress("0100000000000000000000000000000000000000000000000000000000000000")
		forwarderAddress := swarm.MustParseHexAddress("0200000000000000000000000000000000000000000000000000000000000000")
		// NOTE(review): this literal is 66 hex digits (33 bytes) while the peers
		// above are 32 bytes — presumably a typo; confirm the intended address
		clientAddress := swarm.MustParseHexAddress("030000000000000000000000000000000000000000000000000000000000000000")

		serverStorer := &testStorer{ChunkStore: inmemchunkstore.New()}
		err := serverStorer.Put(context.Background(), chunk)
		if err != nil {
			t.Fatal(err)
		}

		server := createRetrieval(t,
			serverAddress,
			serverStorer, // chunk is in server's store
			nil,
			nil,
			logger,
			accountingmock.NewAccounting(),
			pricer,
			nil,
			false,
		)

		forwarderStore := &testStorer{ChunkStore: inmemchunkstore.New()}

		forwarder := createRetrieval(t,
			forwarderAddress,
			forwarderStore, // no chunk in forwarder's store
			streamtest.New(streamtest.WithProtocols(server.Protocol())), // connect to server
			topologymock.NewTopologyDriver(topologymock.WithClosestPeer(serverAddress)),
			logger,
			accountingmock.NewAccounting(),
			pricer,
			nil,
			true, // note explicit caching
		)

		client := createRetrieval(t,
			clientAddress,
			storemock.New(), // no chunk in clients's store
			streamtest.New(streamtest.WithProtocols(forwarder.Protocol())), // connect to forwarder
			topologymock.NewTopologyDriver(topologymock.WithClosestPeer(forwarderAddress)),
			logger,
			accountingmock.NewAccounting(),
			pricer,
			nil,
			false,
		)

		if got, _ := forwarderStore.Has(context.Background(), chunk.Address()); got {
			t.Fatalf("forwarder node already has chunk")
		}

		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}

		// caching happens after delivery; poll until the chunk shows up
		err = spinlock.Wait(time.Second, func() bool {
			gots, _ := forwarderStore.Has(context.Background(), chunk.Address())
			return gots
		})
		if err != nil {
			t.Fatalf("forwarder did not cache chunk")
		}
	})

	t.Run("propagate error to origin", func(t *testing.T) {
		t.Parallel()

		chunk := testingc.FixtureChunk("0025")

		serverAddress := swarm.MustParseHexAddress("0100000000000000000000000000000000000000000000000000000000000000")
		forwarderAddress := swarm.MustParseHexAddress("0200000000000000000000000000000000000000000000000000000000000000")
		// NOTE(review): this literal is 66 hex digits (33 bytes) while the peers
		// above are 32 bytes — presumably a typo; confirm the intended address
		clientAddress := swarm.MustParseHexAddress("030000000000000000000000000000000000000000000000000000000000000000")

		// capture the client's log output so the delivered error can be asserted
		buf := new(bytes.Buffer)
		captureLogger := log.NewLogger("test", log.WithSink(buf))

		// the server has neither the chunk nor any peers to forward to
		server := createRetrieval(t,
			serverAddress,
			&testStorer{ChunkStore: inmemchunkstore.New()},
			nil,
			topologymock.NewTopologyDriver(),
			logger,
			accountingmock.NewAccounting(),
			pricer,
			nil,
			false,
		)

		forwarderStore := &testStorer{ChunkStore: inmemchunkstore.New()}

		forwarder := createRetrieval(t,
			forwarderAddress,
			forwarderStore, // no chunk in forwarder's store
			streamtest.New(streamtest.WithProtocols(server.Protocol())), // connect to server
			topologymock.NewTopologyDriver(topologymock.WithClosestPeer(serverAddress)),
			logger,
			accountingmock.NewAccounting(),
			pricer,
			nil,
			true, // note explicit caching
		)

		client := createRetrieval(t,
			clientAddress,
			storemock.New(), // no chunk in clients's store
			streamtest.New(streamtest.WithProtocols(forwarder.Protocol())), // connect to forwarder
			topologymock.NewTopologyDriver(topologymock.WithClosestPeer(forwarderAddress)),
			captureLogger,
			accountingmock.NewAccounting(),
			pricer,
			nil,
			false,
		)

		_, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err == nil {
			t.Fatal("should have received an error")
		}

		// the server's "no peer found" failure must surface in the client's log
		want := p2p.NewChunkDeliveryError("retrieve chunk: no peer found")
		if got := buf.String(); !strings.Contains(got, want.Error()) {
			t.Fatalf("got log %s, want %s", got, want)
		}
	})
}
   406  
// TestRetrievePreemptiveRetry exercises the client's retry and preemptive
// request behaviour against two servers: an unreachable peer, a peer that
// lacks the chunk, a peer that responds slowly, and a peer that forwards
// the request onward.
func TestRetrievePreemptiveRetry(t *testing.T) {
	t.Parallel()

	logger := log.Noop

	chunk := testingc.FixtureChunk("0025")
	someOtherChunk := testingc.FixtureChunk("0033")

	pricerMock := pricermock.NewMockService(defaultPrice, defaultPrice)

	clientAddress := swarm.MustParseHexAddress("1010")

	serverAddress1 := swarm.MustParseHexAddress("1000000000000000000000000000000000000000000000000000000000000000")
	serverAddress2 := swarm.MustParseHexAddress("0200000000000000000000000000000000000000000000000000000000000000")
	peers := []swarm.Address{
		serverAddress1,
		serverAddress2,
	}

	serverStorer1 := &testStorer{ChunkStore: inmemchunkstore.New()}
	serverStorer2 := &testStorer{ChunkStore: inmemchunkstore.New()}

	// we put some other chunk on server 1
	err := serverStorer1.Put(context.Background(), someOtherChunk)
	if err != nil {
		t.Fatal(err)
	}
	// we put chunk we need on server 2
	err = serverStorer2.Put(context.Background(), chunk)
	if err != nil {
		t.Fatal(err)
	}

	noClosestPeer := topologymock.NewTopologyDriver()
	// NOTE(review): "closetPeers" is presumably a typo for "closestPeers"
	closetPeers := topologymock.NewTopologyDriver(topologymock.WithPeers(peers...))

	server1 := createRetrieval(t, serverAddress1, serverStorer1, nil, noClosestPeer, logger, accountingmock.NewAccounting(), pricerMock, nil, false)
	server2 := createRetrieval(t, serverAddress2, serverStorer2, nil, noClosestPeer, logger, accountingmock.NewAccounting(), pricerMock, nil, false)

	t.Run("peer not reachable", func(t *testing.T) {
		t.Parallel()

		// ranOnce makes only the FIRST handler invocation fail; ranMux guards it
		ranOnce := true
		ranMux := sync.Mutex{}
		recorder := streamtest.New(
			streamtest.WithProtocols(
				server1.Protocol(),
				server2.Protocol(),
			),
			streamtest.WithMiddlewares(
				func(h p2p.HandlerFunc) p2p.HandlerFunc {
					return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
						ranMux.Lock()
						defer ranMux.Unlock()
						// NOTE: return error for peer1
						if ranOnce {
							ranOnce = false
							return fmt.Errorf("peer not reachable: %s", peer.Address.String())
						}

						return server2.Handler(ctx, peer, stream)
					}
				},
			),
			streamtest.WithBaseAddr(clientAddress),
		)

		client := createRetrieval(t, clientAddress, nil, recorder, closetPeers, logger, accountingmock.NewAccounting(), pricerMock, nil, false)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
	})

	t.Run("peer does not have chunk", func(t *testing.T) {
		t.Parallel()

		// first request goes to server1 (which holds a different chunk);
		// the retry is served by server2
		ranOnce := true
		ranMux := sync.Mutex{}
		recorder := streamtest.New(
			streamtest.WithProtocols(
				server1.Protocol(),
				server2.Protocol(),
			),
			streamtest.WithMiddlewares(
				func(h p2p.HandlerFunc) p2p.HandlerFunc {
					return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
						ranMux.Lock()
						defer ranMux.Unlock()
						if ranOnce {
							ranOnce = false
							return server1.Handler(ctx, peer, stream)
						}

						return server2.Handler(ctx, peer, stream)
					}
				},
			),
		)

		client := createRetrieval(t, clientAddress, nil, recorder, closetPeers, logger, accountingmock.NewAccounting(), pricerMock, nil, false)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
	})

	t.Run("one peer is slower", func(t *testing.T) {
		t.Parallel()

		serverStorer1 := &testStorer{ChunkStore: inmemchunkstore.New()}
		serverStorer2 := &testStorer{ChunkStore: inmemchunkstore.New()}

		// both peers have required chunk
		err := serverStorer1.Put(context.Background(), chunk)
		if err != nil {
			t.Fatal(err)
		}
		err = serverStorer2.Put(context.Background(), chunk)
		if err != nil {
			t.Fatal(err)
		}

		server1MockAccounting := accountingmock.NewAccounting()
		server2MockAccounting := accountingmock.NewAccounting()

		server1 := createRetrieval(t, serverAddress1, serverStorer1, nil, noClosestPeer, logger, server1MockAccounting, pricerMock, nil, false)
		server2 := createRetrieval(t, serverAddress2, serverStorer2, nil, noClosestPeer, logger, server2MockAccounting, pricerMock, nil, false)

		// NOTE: must be more than retry duration
		// (here one second more)
		server1ResponseDelayDuration := 2 * time.Second

		ranOnce := true
		ranMux := sync.Mutex{}
		recorder := streamtest.New(
			streamtest.WithProtocols(
				server1.Protocol(),
				server2.Protocol(),
			),
			streamtest.WithMiddlewares(
				func(h p2p.HandlerFunc) p2p.HandlerFunc {
					return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
						ranMux.Lock()
						if ranOnce {
							// NOTE: sleep time must be more than retry duration
							ranOnce = false
							ranMux.Unlock()
							time.Sleep(server1ResponseDelayDuration)
							// server2 is picked first because it's address is closer to the chunk than server1
							return server2.Handler(ctx, peer, stream)
						}
						ranMux.Unlock()

						return server1.Handler(ctx, peer, stream)
					}
				},
			),
		)

		clientMockAccounting := accountingmock.NewAccounting()

		client := createRetrieval(t, clientAddress, nil, recorder, closetPeers, logger, clientMockAccounting, pricerMock, nil, false)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}

		// immediately after delivery only the fast peer has been debited
		clientServer1Balance, _ := clientMockAccounting.Balance(serverAddress1)
		if clientServer1Balance.Int64() != -int64(defaultPrice) {
			t.Fatalf("unexpected balance on client. want %d got %d", -int64(defaultPrice), clientServer1Balance)
		}

		clientServer2Balance, _ := clientMockAccounting.Balance(serverAddress2)
		if clientServer2Balance.Int64() != 0 {
			t.Fatalf("unexpected balance on client. want %d got %d", 0, clientServer2Balance)
		}

		// wait and check balance again
		// (yet one second more than before, minus original duration)
		time.Sleep(2 * time.Second)

		clientServer1Balance, _ = clientMockAccounting.Balance(serverAddress1)
		if clientServer1Balance.Int64() != -int64(defaultPrice) {
			t.Fatalf("unexpected balance on client. want %d got %d", -int64(defaultPrice), clientServer1Balance)
		}

		// by now the slow peer's delivery has also been paid for
		clientServer2Balance, _ = clientMockAccounting.Balance(serverAddress2)
		if clientServer2Balance.Int64() != -int64(defaultPrice) {
			t.Fatalf("unexpected balance on client. want %d got %d", -int64(defaultPrice), clientServer2Balance)
		}
	})

	t.Run("peer forwards request", func(t *testing.T) {
		t.Parallel()

		// server 2 has the chunk
		server2 := createRetrieval(t, serverAddress2, serverStorer2, nil, noClosestPeer, logger, accountingmock.NewAccounting(), pricerMock, nil, false)

		server1Recorder := streamtest.New(
			streamtest.WithProtocols(server2.Protocol()),
		)

		// server 1 will forward request to server 2
		server1 := createRetrieval(t, serverAddress1, serverStorer1, server1Recorder, topologymock.NewTopologyDriver(topologymock.WithPeers(serverAddress2)), logger, accountingmock.NewAccounting(), pricerMock, nil, true)

		clientRecorder := streamtest.New(
			streamtest.WithProtocols(server1.Protocol()),
		)

		// client only knows about server 1
		client := createRetrieval(t, clientAddress, nil, clientRecorder, topologymock.NewTopologyDriver(topologymock.WithPeers(serverAddress1)), logger, accountingmock.NewAccounting(), pricerMock, nil, false)

		if got, _ := serverStorer1.Has(context.Background(), chunk.Address()); got {
			t.Fatalf("forwarder node already has chunk")
		}

		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), swarm.ZeroAddress)
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
		// the forwarder caches asynchronously; poll until the chunk appears
		err = spinlock.Wait(time.Second, func() bool {
			has, _ := serverStorer1.Has(context.Background(), chunk.Address())
			return has
		})
		if err != nil {
			t.Fatalf("forwarder node does not have chunk")
		}
	})
}
   656  
   657  func TestClosestPeer(t *testing.T) {
   658  	t.Parallel()
   659  
   660  	srvAd := swarm.MustParseHexAddress("0100000000000000000000000000000000000000000000000000000000000000")
   661  
   662  	addr1 := swarm.MustParseHexAddress("0200000000000000000000000000000000000000000000000000000000000000")
   663  	addr2 := swarm.MustParseHexAddress("0300000000000000000000000000000000000000000000000000000000000000")
   664  	addr3 := swarm.MustParseHexAddress("0400000000000000000000000000000000000000000000000000000000000000")
   665  
   666  	ret := createRetrieval(t, srvAd, nil, nil, topologymock.NewTopologyDriver(topologymock.WithPeers(addr1, addr2, addr3)), log.Noop, nil, nil, nil, false)
   667  
   668  	t.Run("closest", func(t *testing.T) {
   669  		t.Parallel()
   670  
   671  		addr, err := ret.ClosestPeer(addr1, nil, false)
   672  		if err != nil {
   673  			t.Fatal("closest peer", err)
   674  		}
   675  		if !addr.Equal(addr1) {
   676  			t.Fatalf("want %s, got %s", addr1.String(), addr.String())
   677  		}
   678  	})
   679  
   680  	t.Run("second closest", func(t *testing.T) {
   681  		t.Parallel()
   682  
   683  		addr, err := ret.ClosestPeer(addr1, []swarm.Address{addr1}, false)
   684  		if err != nil {
   685  			t.Fatal("closest peer", err)
   686  		}
   687  		if !addr.Equal(addr2) {
   688  			t.Fatalf("want %s, got %s", addr2.String(), addr.String())
   689  		}
   690  	})
   691  
   692  	t.Run("closest is further than base addr", func(t *testing.T) {
   693  		t.Parallel()
   694  
   695  		_, err := ret.ClosestPeer(srvAd, nil, false)
   696  		if !errors.Is(err, topology.ErrNotFound) {
   697  			t.Fatal("closest peer", err)
   698  		}
   699  	})
   700  }
   701  
   702  func createRetrieval(
   703  	t *testing.T,
   704  	addr swarm.Address,
   705  	storer retrieval.Storer,
   706  	streamer p2p.Streamer,
   707  	chunkPeerer topology.ClosestPeerer,
   708  	logger log.Logger,
   709  	accounting accounting.Interface,
   710  	pricer pricer.Interface,
   711  	tracer *tracing.Tracer,
   712  	forwarderCaching bool,
   713  ) *retrieval.Service {
   714  	t.Helper()
   715  
   716  	radiusF := func() (uint8, error) { return swarm.MaxBins, nil }
   717  
   718  	ret := retrieval.New(addr, radiusF, storer, streamer, chunkPeerer, logger, accounting, pricer, tracer, forwarderCaching)
   719  	t.Cleanup(func() { ret.Close() })
   720  	return ret
   721  }