github.com/ethersphere/bee/v2@v2.2.0/pkg/pricing/pricing_test.go (about) 1 // Copyright 2020 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package pricing_test 6 7 import ( 8 "bytes" 9 "context" 10 "errors" 11 "math/big" 12 "testing" 13 14 "github.com/ethersphere/bee/v2/pkg/log" 15 "github.com/ethersphere/bee/v2/pkg/p2p" 16 "github.com/ethersphere/bee/v2/pkg/p2p/protobuf" 17 "github.com/ethersphere/bee/v2/pkg/p2p/streamtest" 18 "github.com/ethersphere/bee/v2/pkg/pricing" 19 "github.com/ethersphere/bee/v2/pkg/pricing/pb" 20 "github.com/ethersphere/bee/v2/pkg/swarm" 21 ) 22 23 type testThresholdObserver struct { 24 called bool 25 peer swarm.Address 26 paymentThreshold *big.Int 27 } 28 29 func (t *testThresholdObserver) NotifyPaymentThreshold(peerAddr swarm.Address, paymentThreshold *big.Int) error { 30 t.called = true 31 t.peer = peerAddr 32 t.paymentThreshold = paymentThreshold 33 return nil 34 } 35 36 func TestAnnouncePaymentThreshold(t *testing.T) { 37 t.Parallel() 38 39 logger := log.Noop 40 testThreshold := big.NewInt(100000) 41 testLightThreshold := big.NewInt(10000) 42 43 observer := &testThresholdObserver{} 44 45 recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, big.NewInt(1000)) 46 recipient.SetPaymentThresholdObserver(observer) 47 48 peerID := swarm.MustParseHexAddress("9ee7add7") 49 50 recorder := streamtest.New( 51 streamtest.WithProtocols(recipient.Protocol()), 52 streamtest.WithBaseAddr(peerID), 53 ) 54 55 payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, big.NewInt(1000)) 56 57 paymentThreshold := big.NewInt(100000) 58 59 err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold) 60 if err != nil { 61 t.Fatal(err) 62 } 63 64 records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing") 65 if err != nil { 66 t.Fatal(err) 67 } 68 69 if l := len(records); l != 1 { 70 t.Fatalf("got %v records, want %v", l, 1) 71 } 72 73 record := records[0] 74 75 messages, err := protobuf.ReadMessages( 76 bytes.NewReader(record.In()), 77 func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) }, 78 ) 79 if err != nil { 80 t.Fatal(err) 81 } 82 83 if len(messages) != 1 { 84 t.Fatalf("got %v messages, want %v", len(messages), 1) 85 } 86 87 sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold) 88 if sentPaymentThreshold.Cmp(paymentThreshold) != 0 { 89 t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, paymentThreshold) 90 } 91 92 if !observer.called { 93 t.Fatal("expected observer to be called") 94 } 95 96 if observer.paymentThreshold.Cmp(paymentThreshold) != 0 { 97 t.Fatalf("observer called with wrong paymentThreshold. got %v, want %v", observer.paymentThreshold, paymentThreshold) 98 } 99 100 if !observer.peer.Equal(peerID) { 101 t.Fatalf("observer called with wrong peer. got %v, want %v", observer.peer, peerID) 102 } 103 } 104 105 func TestAnnouncePaymentWithInsufficientThreshold(t *testing.T) { 106 t.Parallel() 107 108 logger := log.Noop 109 testThreshold := big.NewInt(100_000) 110 testLightThreshold := big.NewInt(10_000) 111 112 observer := &testThresholdObserver{} 113 114 minThreshold := big.NewInt(1_000_000) // above requested threshold 115 116 recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, minThreshold) 117 recipient.SetPaymentThresholdObserver(observer) 118 119 peerID := swarm.MustParseHexAddress("9ee7add7") 120 121 recorder := streamtest.New( 122 streamtest.WithProtocols(recipient.Protocol()), 123 streamtest.WithBaseAddr(peerID), 124 ) 125 126 payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, minThreshold) 127 128 paymentThreshold := big.NewInt(100_000) 129 130 err := payer.AnnouncePaymentThreshold(context.Background(), peerID, paymentThreshold) 131 if err != nil { 132 t.Fatal(err) 133 } 134 135 records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing") 136 if err != nil { 137 t.Fatal(err) 138 } 139 140 if l := len(records); l != 1 { 141 t.Fatalf("got %v records, want %v", l, 1) 142 } 143 144 record := records[0] 145 146 if record.Err() == nil { 147 t.Fatal("expected error") 148 } 149 150 disconnectErr := &p2p.DisconnectError{} 151 if !errors.As(record.Err(), &disconnectErr) { 152 t.Fatalf("wanted %v, got %v", disconnectErr, record.Err()) 153 } 154 155 if !errors.Is(record.Err(), pricing.ErrThresholdTooLow) { 156 t.Fatalf("wanted error %v, got %v", pricing.ErrThresholdTooLow, record.Err()) 157 } 158 159 if observer.called { 160 t.Fatal("unexpected call to the observer") 161 } 162 } 163 164 func TestInitialPaymentThreshold(t *testing.T) { 165 t.Parallel() 166 167 logger := log.Noop 168 testThreshold := big.NewInt(100000) 169 testLightThreshold := big.NewInt(10000) 170 171 observer := &testThresholdObserver{} 172 173 recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, big.NewInt(1000)) 174 recipient.SetPaymentThresholdObserver(observer) 175 176 peerID := swarm.MustParseHexAddress("9ee7add7") 177 peer := p2p.Peer{Address: peerID, FullNode: true} 178 179 recorder := streamtest.New( 180 streamtest.WithProtocols(recipient.Protocol()), 181 streamtest.WithBaseAddr(peerID), 182 ) 183 184 payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, big.NewInt(1000)) 185 186 err := payer.Init(context.Background(), peer) 187 if err != nil { 188 t.Fatal(err) 189 } 190 191 records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing") 192 if err != nil { 193 t.Fatal(err) 194 } 195 196 if l := len(records); l != 1 { 197 t.Fatalf("got %v records, want %v", l, 1) 198 } 199 200 record := records[0] 201 202 messages, err := protobuf.ReadMessages( 203 bytes.NewReader(record.In()), 204 func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) }, 205 ) 206 if err != nil { 207 t.Fatal(err) 208 } 209 210 if len(messages) != 1 { 211 t.Fatalf("got %v messages, want %v", len(messages), 1) 212 } 213 214 sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold) 215 if sentPaymentThreshold.Cmp(testThreshold) != 0 { 216 t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, testThreshold) 217 } 218 219 if !observer.called { 220 t.Fatal("expected observer to be called") 221 } 222 223 if observer.paymentThreshold.Cmp(testThreshold) != 0 { 224 t.Fatalf("observer called with wrong paymentThreshold, got %v, want %v", observer.paymentThreshold, testThreshold) 225 } 226 227 if !observer.peer.Equal(peerID) { 228 t.Fatalf("observer called with wrong peer, got %v, want %v", observer.peer, peerID) 229 } 230 } 231 232 func TestInitialPaymentThresholdLightNode(t *testing.T) { 233 t.Parallel() 234 235 logger := log.Noop 236 testThreshold := big.NewInt(100000) 237 testLightThreshold := big.NewInt(10000) 238 239 observer := &testThresholdObserver{} 240 241 recipient := pricing.New(nil, logger, testThreshold, testLightThreshold, big.NewInt(1000)) 242 recipient.SetPaymentThresholdObserver(observer) 243 244 peerID := swarm.MustParseHexAddress("9ee7add7") 245 peer := p2p.Peer{Address: peerID, FullNode: false} 246 247 recorder := streamtest.New( 248 streamtest.WithProtocols(recipient.Protocol()), 249 streamtest.WithBaseAddr(peerID), 250 ) 251 252 payer := pricing.New(recorder, logger, testThreshold, testLightThreshold, big.NewInt(1000)) 253 254 err := payer.Init(context.Background(), peer) 255 if err != nil { 256 t.Fatal(err) 257 } 258 259 records, err := recorder.Records(peerID, "pricing", "1.0.0", "pricing") 260 if err != nil { 261 t.Fatal(err) 262 } 263 264 if l := len(records); l != 1 { 265 t.Fatalf("got %v records, want %v", l, 1) 266 } 267 268 record := records[0] 269 270 messages, err := protobuf.ReadMessages( 271 bytes.NewReader(record.In()), 272 func() protobuf.Message { return new(pb.AnnouncePaymentThreshold) }, 273 ) 274 if err != nil { 275 t.Fatal(err) 276 } 277 278 if len(messages) != 1 { 279 t.Fatalf("got %v messages, want %v", len(messages), 1) 280 } 281 282 sentPaymentThreshold := big.NewInt(0).SetBytes(messages[0].(*pb.AnnouncePaymentThreshold).PaymentThreshold) 283 if sentPaymentThreshold.Cmp(testLightThreshold) != 0 { 284 t.Fatalf("got message with amount %v, want %v", sentPaymentThreshold, testLightThreshold) 285 } 286 287 if !observer.called { 288 t.Fatal("expected observer to be called") 289 } 290 291 if observer.paymentThreshold.Cmp(testLightThreshold) != 0 { 292 t.Fatalf("observer called with wrong paymentThreshold, got %v, want %v", observer.paymentThreshold, testLightThreshold) 293 } 294 295 if !observer.peer.Equal(peerID) { 296 t.Fatalf("observer called with wrong peer, got %v, want %v", observer.peer, peerID) 297 } 298 }