github.com/ethersphere/bee/v2@v2.2.0/pkg/pricing/pricing_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package pricing_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"errors"
    11  	"math/big"
    12  	"testing"
    13  
    14  	"github.com/ethersphere/bee/v2/pkg/log"
    15  	"github.com/ethersphere/bee/v2/pkg/p2p"
    16  	"github.com/ethersphere/bee/v2/pkg/p2p/protobuf"
    17  	"github.com/ethersphere/bee/v2/pkg/p2p/streamtest"
    18  	"github.com/ethersphere/bee/v2/pkg/pricing"
    19  	"github.com/ethersphere/bee/v2/pkg/pricing/pb"
    20  	"github.com/ethersphere/bee/v2/pkg/swarm"
    21  )
    22  
    23  type testThresholdObserver struct {
    24  	called           bool
    25  	peer             swarm.Address
    26  	paymentThreshold *big.Int
    27  }
    28  
    29  func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, paymentThreshold *big.Int) error {
    30  	t.called = true
    31  	t.peer = peerAddr
    32  	t.paymentThreshold = paymentThreshold
    33  	return nil
    34  }
    35  
    36  func TestAnnouncePaymentThreshold(t *testing.T) {
    37  	t.Parallel()
    38  
    39  	logger := log.Noop
    40  	testThreshold := big.NewInt(100000)
    41  	testLightThreshold := big.NewInt(10000)
    42  
    43  	observer := &testThresholdObserver{}
    44  
    45  	recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, big.NewInt(1000))
    46  	recipient.SetPaymentThresholdObserver(observer)
    47  
    48  	peerID := swarm.MustParseHexAddress("9ee7add7")
    49  
    50  	recorder := streamtest.New(
    51  		streamtest.WithProtocols(recipient.Protocol()),
    52  		streamtest.WithBaseAddr(peerID),
    53  	)
    54  
    55  	payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, big.NewInt(1000))
    56  
    57  	paymentThreshold := big.NewInt(100000)
    58  
    59  	err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold)
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  
    64  	records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  
    69  	if l := len(records); l != 1 {
    70  		t.Fatalf("got %v records, want %v", l, 1)
    71  	}
    72  
    73  	record := records[0]
    74  
    75  	messages, err := protobuf.ReadMessages(
    76  		bytes.NewReader(record.In()),
    77  		func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) },
    78  	)
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  
    83  	if len(messages) != 1 {
    84  		t.Fatalf("got %v messages, want %v", len(messages), 1)
    85  	}
    86  
    87  	sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold)
    88  	if sentPaymentThreshold.Cmp(paymentThreshold) != 0 {
    89  		t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, paymentThreshold)
    90  	}
    91  
    92  	if !observer.called {
    93  		t.Fatal("expected observer to be called")
    94  	}
    95  
    96  	if observer.paymentThreshold.Cmp(paymentThreshold) != 0 {
    97  		t.Fatalf("observer called with wrong paymentThreshold. got %v, want %v", observer.paymentThreshold, paymentThreshold)
    98  	}
    99  
   100  	if !observer.peer.Equal(peerID) {
   101  		t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID)
   102  	}
   103  }
   104  
   105  func TestAnnouncePaymentWithInsufficientThreshold(t *testing.T) {
   106  	t.Parallel()
   107  
   108  	logger := log.Noop
   109  	testThreshold := big.NewInt(100_000)
   110  	testLightThreshold := big.NewInt(10_000)
   111  
   112  	observer := &testThresholdObserver{}
   113  
   114  	minThreshold := big.NewInt(1_000_000) // above requested threshold
   115  
   116  	recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, minThreshold)
   117  	recipient.SetPaymentThresholdObserver(observer)
   118  
   119  	peerID := swarm.MustParseHexAddress("9ee7add7")
   120  
   121  	recorder := streamtest.New(
   122  		streamtest.WithProtocols(recipient.Protocol()),
   123  		streamtest.WithBaseAddr(peerID),
   124  	)
   125  
   126  	payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, minThreshold)
   127  
   128  	paymentThreshold := big.NewInt(100_000)
   129  
   130  	err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold)
   131  	if err != nil {
   132  		t.Fatal(err)
   133  	}
   134  
   135  	records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  
   140  	if l := len(records); l != 1 {
   141  		t.Fatalf("got %v records, want %v", l, 1)
   142  	}
   143  
   144  	record := records[0]
   145  
   146  	if record.Err() == nil {
   147  		t.Fatal("expected error")
   148  	}
   149  
   150  	disconnectErr := &p2p.DisconnectError{}
   151  	if !errors.As(record.Err(), &disconnectErr) {
   152  		t.Fatalf("wanted %v, got %v", disconnectErr, record.Err())
   153  	}
   154  
   155  	if !errors.Is(record.Err(), pricing.ErrThresholdTooLow) {
   156  		t.Fatalf("wanted error %v, got %v", pricing.ErrThresholdTooLow, record.Err())
   157  	}
   158  
   159  	if observer.called {
   160  		t.Fatal("unexpected call to the observer")
   161  	}
   162  }
   163  
   164  func TestInitialPaymentThreshold(t *testing.T) {
   165  	t.Parallel()
   166  
   167  	logger := log.Noop
   168  	testThreshold := big.NewInt(100000)
   169  	testLightThreshold := big.NewInt(10000)
   170  
   171  	observer := &testThresholdObserver{}
   172  
   173  	recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, big.NewInt(1000))
   174  	recipient.SetPaymentThresholdObserver(observer)
   175  
   176  	peerID := swarm.MustParseHexAddress("9ee7add7")
   177  	peer := p2p.Peer{Address: peerID, FullNode: true}
   178  
   179  	recorder := streamtest.New(
   180  		streamtest.WithProtocols(recipient.Protocol()),
   181  		streamtest.WithBaseAddr(peerID),
   182  	)
   183  
   184  	payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, big.NewInt(1000))
   185  
   186  	err := payer.Init(context.Background(), peer)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  
   191  	records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	if l := len(records); l != 1 {
   197  		t.Fatalf("got %v records, want %v", l, 1)
   198  	}
   199  
   200  	record := records[0]
   201  
   202  	messages, err := protobuf.ReadMessages(
   203  		bytes.NewReader(record.In()),
   204  		func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) },
   205  	)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	if len(messages) != 1 {
   211  		t.Fatalf("got %v messages, want %v", len(messages), 1)
   212  	}
   213  
   214  	sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold)
   215  	if sentPaymentThreshold.Cmp(testThreshold) != 0 {
   216  		t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, testThreshold)
   217  	}
   218  
   219  	if !observer.called {
   220  		t.Fatal("expected observer to be called")
   221  	}
   222  
   223  	if observer.paymentThreshold.Cmp(testThreshold) != 0 {
   224  		t.Fatalf("observer called with wrong paymentThreshold, got %v, want %v", observer.paymentThreshold, testThreshold)
   225  	}
   226  
   227  	if !observer.peer.Equal(peerID) {
   228  		t.Fatalf("observer called with wrong peer, got %v, want %v", observer.peer, peerID)
   229  	}
   230  }
   231  
   232  func TestInitialPaymentThresholdLightNode(t *testing.T) {
   233  	t.Parallel()
   234  
   235  	logger := log.Noop
   236  	testThreshold := big.NewInt(100000)
   237  	testLightThreshold := big.NewInt(10000)
   238  
   239  	observer := &testThresholdObserver{}
   240  
   241  	recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, big.NewInt(1000))
   242  	recipient.SetPaymentThresholdObserver(observer)
   243  
   244  	peerID := swarm.MustParseHexAddress("9ee7add7")
   245  	peer := p2p.Peer{Address: peerID, FullNode: false}
   246  
   247  	recorder := streamtest.New(
   248  		streamtest.WithProtocols(recipient.Protocol()),
   249  		streamtest.WithBaseAddr(peerID),
   250  	)
   251  
   252  	payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, big.NewInt(1000))
   253  
   254  	err := payer.Init(context.Background(), peer)
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  
   259  	records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing")
   260  	if err != nil {
   261  		t.Fatal(err)
   262  	}
   263  
   264  	if l := len(records); l != 1 {
   265  		t.Fatalf("got %v records, want %v", l, 1)
   266  	}
   267  
   268  	record := records[0]
   269  
   270  	messages, err := protobuf.ReadMessages(
   271  		bytes.NewReader(record.In()),
   272  		func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) },
   273  	)
   274  	if err != nil {
   275  		t.Fatal(err)
   276  	}
   277  
   278  	if len(messages) != 1 {
   279  		t.Fatalf("got %v messages, want %v", len(messages), 1)
   280  	}
   281  
   282  	sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold)
   283  	if sentPaymentThreshold.Cmp(testLightThreshold) != 0 {
   284  		t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, testLightThreshold)
   285  	}
   286  
   287  	if !observer.called {
   288  		t.Fatal("expected observer to be called")
   289  	}
   290  
   291  	if observer.paymentThreshold.Cmp(testLightThreshold) != 0 {
   292  		t.Fatalf("observer called with wrong paymentThreshold, got %v, want %v", observer.paymentThreshold, testLightThreshold)
   293  	}
   294  
   295  	if !observer.peer.Equal(peerID) {
   296  		t.Fatalf("observer called with wrong peer, got %v, want %v", observer.peer, peerID)
   297  	}
   298  }