github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/session/pingpong/invoice_tracker.go (about) 1 /* 2 * Copyright (C) 2019 The "MysteriumNetwork/node" Authors. 3 * 4 * This program is free software: you can redistribute it and/or modify 5 * it under the terms of the GNU General Public License as published by 6 * the Free Software Foundation, either version 3 of the License, or 7 * (at your option) any later version. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 */ 17 18 package pingpong 19 20 import ( 21 "bytes" 22 crand "crypto/rand" 23 "encoding/hex" 24 stdErr "errors" 25 "fmt" 26 "math" 27 "math/big" 28 "strings" 29 "sync" 30 "time" 31 32 "github.com/ethereum/go-ethereum/common" 33 "github.com/pkg/errors" 34 "github.com/rs/zerolog/log" 35 36 "github.com/mysteriumnetwork/node/config" 37 "github.com/mysteriumnetwork/node/eventbus" 38 "github.com/mysteriumnetwork/node/identity" 39 "github.com/mysteriumnetwork/node/market" 40 "github.com/mysteriumnetwork/node/p2p" 41 sessionEvent "github.com/mysteriumnetwork/node/session/event" 42 "github.com/mysteriumnetwork/node/session/pingpong/event" 43 "github.com/mysteriumnetwork/payments/crypto" 44 ) 45 46 // ErrConsumerPromiseValidationFailed represents an error where consumer tries to cheat us with incorrect promises. 
var ErrConsumerPromiseValidationFailed = errors.New("consumer failed to issue promise for the correct amount")

// Sentinel errors reported while tracking invoices.
var (
	// ErrHermesFeeTooLarge indicates that the consumer's hermes charges more than we are willing to accept.
	ErrHermesFeeTooLarge = errors.New("hermes fee exceeds predefined limits")

	// ErrHermesInactive indicates that the consumer's hermes is not active.
	ErrHermesInactive = errors.New("hermes is not active")

	// ErrInvoiceExpired indicates that an exchange message refers to a hashlock that is no longer pending.
	ErrInvoiceExpired = errors.New("invoice expired")

	// ErrExchangeWaitTimeout indicates that too many invoices in a row were not paid in time.
	ErrExchangeWaitTimeout = errors.New("did not get a new exchange message")

	// ErrInvoiceSendMaxFailCountReached indicates that too many invoices in a row could not be sent to the consumer.
	ErrInvoiceSendMaxFailCountReached = errors.New("did not sent a new exchange message")

	// ErrExchangeValidationFailed indicates that the exchange message signature did not validate.
	ErrExchangeValidationFailed = errors.New("exchange validation failed")

	// ErrConsumerNotRegistered indicates that the consumer is not registered.
	ErrConsumerNotRegistered = errors.New("consumer not registered")
)

// providerFirstInvoiceValue is the amount requested by the first invoice of a session,
// i.e. while nothing has been paid yet but something is already owed.
var providerFirstInvoiceValue = big.NewInt(1)

// PeerInvoiceSender delivers invoices to the remote peer.
73 type PeerInvoiceSender interface { 74 Send(crypto.Invoice) error 75 } 76 77 type hermesStatusChecker interface { 78 GetHermesStatus(chainID int64, registryAddress common.Address, hermesID common.Address) (HermesStatus, error) 79 } 80 81 type providerInvoiceStorage interface { 82 Get(providerIdentity, consumerIdentity identity.Identity) (crypto.Invoice, error) 83 Store(providerIdentity, consumerIdentity identity.Identity, invoice crypto.Invoice) error 84 StoreR(providerIdentity identity.Identity, agreementID *big.Int, r string) error 85 GetR(providerID identity.Identity, agreementID *big.Int) (string, error) 86 } 87 88 type promiseHandler interface { 89 RequestPromise(r []byte, em crypto.ExchangeMessage, providerID identity.Identity, sessionID string) <-chan error 90 } 91 92 type sentInvoice struct { 93 invoice crypto.Invoice 94 r []byte 95 isCritical bool 96 } 97 98 // DataTransferred represents the data transferred in a session. 99 type DataTransferred struct { 100 Up, Down uint64 101 } 102 103 func (dt DataTransferred) sum() uint64 { 104 return dt.Up + dt.Down 105 } 106 107 // InvoiceTracker keeps tab of invoices and sends them to the consumer. 
108 type InvoiceTracker struct { 109 stop chan struct{} 110 promiseErrors chan error 111 invoiceChannel chan bool 112 hermesFailureCount uint64 113 hermesFailureCountLock sync.Mutex 114 115 notReceivedExchangeMessageCount uint64 116 notSentExchangeMessageCount uint64 117 exchangeMessageCountLock sync.Mutex 118 119 maxNotReceivedExchangeMessages uint64 120 maxNotSentExchangeMessages uint64 121 once sync.Once 122 agreementID *big.Int 123 firstInvoicePaid bool 124 invoicesSent map[string]sentInvoice 125 invoiceLock sync.Mutex 126 deps InvoiceTrackerDeps 127 128 dataTransferred DataTransferred 129 dataTransferredLock sync.Mutex 130 131 criticalInvoiceErrors chan error 132 lastInvoiceSent time.Duration 133 invoiceDebounceRate time.Duration 134 135 lastExchangeMessage crypto.ExchangeMessage 136 lastExchangeMessageLock sync.Mutex 137 } 138 139 // InvoiceTrackerDeps contains all the deps needed for invoice tracker. 140 type InvoiceTrackerDeps struct { 141 AgreedPrice market.Price 142 Peer identity.Identity 143 PeerInvoiceSender PeerInvoiceSender 144 InvoiceStorage providerInvoiceStorage 145 TimeTracker timeTracker 146 ChargePeriodLeeway time.Duration 147 ExchangeMessageChan chan crypto.ExchangeMessage 148 ExchangeMessageWaitTimeout time.Duration 149 ProviderID identity.Identity 150 ConsumersHermesID common.Address 151 AddressProvider addressProvider 152 MaxHermesFailureCount uint64 153 MaxAllowedHermesFee uint16 154 HermesStatusChecker hermesStatusChecker 155 EventBus eventbus.EventBus 156 SessionID string 157 PromiseHandler promiseHandler 158 ChainID int64 159 ChargePeriod time.Duration 160 LimitChargePeriod time.Duration 161 LimitNotPaidInvoice *big.Int 162 MaxNotPaidInvoice *big.Int 163 Observer observerApi 164 } 165 166 // NewInvoiceTracker creates a new instance of invoice tracker. 
167 func NewInvoiceTracker( 168 itd InvoiceTrackerDeps, 169 ) *InvoiceTracker { 170 return &InvoiceTracker{ 171 lastExchangeMessage: crypto.ExchangeMessage{ 172 Promise: crypto.Promise{ 173 Amount: new(big.Int), 174 Fee: new(big.Int), 175 }, 176 AgreementID: new(big.Int), 177 AgreementTotal: new(big.Int), 178 }, 179 stop: make(chan struct{}), 180 deps: itd, 181 maxNotReceivedExchangeMessages: calculateMaxNotReceivedExchangeMessageCount(itd.ChargePeriodLeeway, itd.ChargePeriod), 182 maxNotSentExchangeMessages: calculateMaxNotSentExchangeMessageCount(itd.ChargePeriodLeeway, itd.ChargePeriod), 183 invoicesSent: make(map[string]sentInvoice), 184 promiseErrors: make(chan error), 185 criticalInvoiceErrors: make(chan error), 186 invoiceChannel: make(chan bool), 187 invoiceDebounceRate: time.Second * 5, 188 } 189 } 190 191 func calculateMaxNotReceivedExchangeMessageCount(chargeLeeway, chargePeriod time.Duration) uint64 { 192 return uint64(math.Round(float64(chargeLeeway) / float64(chargePeriod))) 193 } 194 195 func calculateMaxNotSentExchangeMessageCount(chargeLeeway, chargePeriod time.Duration) uint64 { 196 return uint64(math.Round(float64(chargeLeeway) / float64(chargePeriod))) 197 } 198 199 func (it *InvoiceTracker) markInvoiceSent(invoice sentInvoice) { 200 it.invoiceLock.Lock() 201 defer it.invoiceLock.Unlock() 202 203 it.invoicesSent[invoice.invoice.Hashlock] = invoice 204 } 205 206 func (it *InvoiceTracker) markInvoicePaid(hashlock []byte) { 207 it.invoiceLock.Lock() 208 defer it.invoiceLock.Unlock() 209 210 if !it.firstInvoicePaid { 211 it.firstInvoicePaid = true 212 } 213 214 delete(it.invoicesSent, hex.EncodeToString(hashlock)) 215 } 216 217 func (it *InvoiceTracker) getMarkedInvoice(hashlock []byte) (invoice sentInvoice, ok bool) { 218 it.invoiceLock.Lock() 219 defer it.invoiceLock.Unlock() 220 in, ok := it.invoicesSent[hex.EncodeToString(hashlock)] 221 return in, ok 222 } 223 224 func (it *InvoiceTracker) listenForExchangeMessages() error { 225 for { 226 select 
{ 227 case pm := <-it.deps.ExchangeMessageChan: 228 err := it.handleExchangeMessage(pm) 229 if err != nil && err != ErrInvoiceExpired { 230 return err 231 } 232 case <-it.stop: 233 return nil 234 } 235 } 236 } 237 238 func (it *InvoiceTracker) generateAgreementID() { 239 agreementID := make([]byte, 32) 240 _, err := crand.Read(agreementID) 241 if err != nil { 242 panic(err) 243 } 244 it.agreementID = new(big.Int).SetBytes(agreementID) 245 } 246 247 func (it *InvoiceTracker) handleExchangeMessage(em crypto.ExchangeMessage) error { 248 invoice, ok := it.getMarkedInvoice(em.Promise.Hashlock) 249 if !ok { 250 log.Debug().Msgf("consumer sent exchange message with missing expired hashlock %s, skipping", invoice.invoice.Hashlock) 251 return ErrInvoiceExpired 252 } 253 254 err := it.validateExchangeMessage(em) 255 if err != nil { 256 return err 257 } 258 259 it.saveLastExchangeMessage(em) 260 it.markInvoicePaid(em.Promise.Hashlock) 261 it.resetNotReceivedExchangeMessageCount() 262 it.resetNotSentExchangeMessageCount() 263 264 // incase of zero payment, we'll just skip going to the hermes 265 if it.deps.AgreedPrice.IsFree() { 266 return nil 267 } 268 269 err = it.deps.InvoiceStorage.StoreR(it.deps.ProviderID, it.agreementID, hex.EncodeToString(invoice.r)) 270 if err != nil { 271 return errors.Wrap(err, fmt.Sprintf("could not store r: %s", hex.EncodeToString(invoice.r))) 272 } 273 errChan := it.deps.PromiseHandler.RequestPromise(invoice.r, em, it.deps.ProviderID, it.deps.SessionID) 274 go it.handlePromiseErrors(errChan) 275 return nil 276 } 277 278 // Start stars the invoice tracker 279 func (it *InvoiceTracker) Start() error { 280 log.Debug().Msgf("Starting invoice tracker for session %s", it.deps.SessionID) 281 it.deps.TimeTracker.StartTracking() 282 283 if err := it.deps.EventBus.SubscribeWithUID(sessionEvent.AppTopicDataTransferred, it.deps.SessionID, it.consumeDataTransferredEvent); err != nil { 284 return err 285 } 286 287 registry, err := 
it.deps.AddressProvider.GetRegistryAddress(it.deps.ChainID) 288 if err != nil { 289 return err 290 } 291 292 status, err := it.deps.HermesStatusChecker.GetHermesStatus(it.deps.ChainID, registry, it.deps.ConsumersHermesID) 293 if err != nil { 294 return fmt.Errorf("could not check hermes status: %w", err) 295 } 296 297 if !status.IsActive { 298 log.Error().Msgf("Hermes(%v) is inactive", it.deps.ConsumersHermesID.Hex()) 299 return ErrHermesInactive 300 } 301 302 if status.Fee > it.deps.MaxAllowedHermesFee { 303 log.Error().Msgf("Hermes fee too large, asking for %v where %v is the limit", status.Fee, it.deps.MaxAllowedHermesFee) 304 return ErrHermesFeeTooLarge 305 } 306 307 it.generateAgreementID() 308 309 emErrors := make(chan error) 310 go func() { 311 emErrors <- it.listenForExchangeMessages() 312 }() 313 314 err = it.sendInvoice(true) 315 if err != nil { 316 return fmt.Errorf("could not send first invoice: %w", err) 317 } 318 319 go it.sendInvoicesWhenNeeded(time.Second) 320 for { 321 select { 322 case <-it.stop: 323 return nil 324 case critical := <-it.invoiceChannel: 325 err := it.sendInvoice(critical) 326 if err != nil { 327 if stdErr.Is(err, p2p.ErrSendTimeout) { 328 log.Warn().Err(err).Msg("Marking invoice as not sent") 329 it.markExchangeMessageNotSent() 330 } else { 331 return fmt.Errorf("sending of invoice failed: %w", err) 332 } 333 } 334 case err := <-it.criticalInvoiceErrors: 335 return err 336 case emErr := <-emErrors: 337 if emErr != nil { 338 return errors.Wrap(emErr, "failed to get exchange message") 339 } 340 case pErr := <-it.promiseErrors: 341 err := it.handleHermesError(pErr) 342 if err != nil { 343 return fmt.Errorf("could not request promise: %w", err) 344 } 345 } 346 } 347 } 348 349 func (it *InvoiceTracker) sendInvoicesWhenNeeded(interval time.Duration) { 350 it.lastInvoiceSent = it.deps.TimeTracker.Elapsed() 351 for { 352 select { 353 case <-it.stop: 354 return 355 case <-time.After(interval): 356 currentlyElapsed := 
it.deps.TimeTracker.Elapsed() 357 shouldBe := CalculatePaymentAmount(currentlyElapsed, it.getDataTransferred(), it.deps.AgreedPrice) 358 lastEM := it.getLastExchangeMessage() 359 diff := safeSub(shouldBe, lastEM.AgreementTotal) 360 if diff.Cmp(it.deps.MaxNotPaidInvoice) >= 0 && currentlyElapsed-it.lastInvoiceSent > it.invoiceDebounceRate { 361 it.lastInvoiceSent = it.deps.TimeTracker.Elapsed() 362 it.invoiceChannel <- true 363 364 it.updateMaxUnpaid() 365 } else if currentlyElapsed-it.lastInvoiceSent > it.deps.ChargePeriod { 366 it.lastInvoiceSent = it.deps.TimeTracker.Elapsed() 367 it.invoiceChannel <- false 368 369 it.updateTimer() 370 } 371 } 372 } 373 } 374 375 const sessionInvoiceIncreaseSlope = 3 376 377 func (it *InvoiceTracker) updateMaxUnpaid() { 378 limit := it.deps.LimitNotPaidInvoice 379 if limit == nil || it.deps.MaxNotPaidInvoice.Cmp(limit) >= 0 { 380 return 381 } 382 383 add := new(big.Int).Div(it.deps.MaxNotPaidInvoice, new(big.Int).SetInt64(sessionInvoiceIncreaseSlope)) 384 bigger := new(big.Int).Add(it.deps.MaxNotPaidInvoice, add) 385 if bigger.Cmp(limit) > 0 { 386 bigger = limit 387 } 388 389 it.deps.MaxNotPaidInvoice = bigger 390 log.Debug().Str("invoice_amount", it.deps.MaxNotPaidInvoice.String()).Msg("Max invoice amount increased") 391 } 392 393 func (it *InvoiceTracker) updateTimer() { 394 maxTime := it.deps.LimitChargePeriod 395 if it.deps.ChargePeriod >= maxTime { 396 return 397 } 398 399 newMaxTime := it.deps.ChargePeriod/sessionInvoiceIncreaseSlope + it.deps.ChargePeriod 400 if newMaxTime > maxTime { 401 newMaxTime = maxTime 402 } 403 it.deps.ChargePeriod = newMaxTime 404 log.Debug().Int64("change_period (ms)", it.deps.ChargePeriod.Milliseconds()).Msg("Max charge period increased") 405 } 406 407 // WaitFirstInvoice waits for a first invoice to be paid. 
408 func (it *InvoiceTracker) WaitFirstInvoice(wait time.Duration) error { 409 timeout := time.After(wait) 410 411 for { 412 select { 413 case <-time.After(10 * time.Millisecond): 414 it.invoiceLock.Lock() 415 paid := it.firstInvoicePaid 416 it.invoiceLock.Unlock() 417 if paid { 418 return nil 419 } 420 case <-timeout: 421 return fmt.Errorf("failed waiting for first invoice") 422 case <-it.stop: 423 return nil 424 } 425 } 426 } 427 428 func (it *InvoiceTracker) handlePromiseErrors(ch <-chan error) { 429 for err := range ch { 430 it.promiseErrors <- err 431 } 432 } 433 434 func (it *InvoiceTracker) markExchangeMessageNotReceived() { 435 it.exchangeMessageCountLock.Lock() 436 defer it.exchangeMessageCountLock.Unlock() 437 it.notReceivedExchangeMessageCount++ 438 } 439 440 func (it *InvoiceTracker) markExchangeMessageNotSent() { 441 it.exchangeMessageCountLock.Lock() 442 defer it.exchangeMessageCountLock.Unlock() 443 it.notSentExchangeMessageCount++ 444 } 445 446 func (it *InvoiceTracker) resetNotReceivedExchangeMessageCount() { 447 it.exchangeMessageCountLock.Lock() 448 defer it.exchangeMessageCountLock.Unlock() 449 it.notReceivedExchangeMessageCount = 0 450 } 451 452 func (it *InvoiceTracker) resetNotSentExchangeMessageCount() { 453 it.exchangeMessageCountLock.Lock() 454 defer it.exchangeMessageCountLock.Unlock() 455 it.notSentExchangeMessageCount = 0 456 } 457 458 func (it *InvoiceTracker) getNotReceivedExchangeMessageCount() uint64 { 459 it.exchangeMessageCountLock.Lock() 460 defer it.exchangeMessageCountLock.Unlock() 461 return it.notReceivedExchangeMessageCount 462 } 463 464 func (it *InvoiceTracker) getNotSentExchangeMessageCount() uint64 { 465 it.exchangeMessageCountLock.Lock() 466 defer it.exchangeMessageCountLock.Unlock() 467 return it.notSentExchangeMessageCount 468 } 469 470 func (it *InvoiceTracker) saveLastExchangeMessage(em crypto.ExchangeMessage) { 471 it.lastExchangeMessageLock.Lock() 472 defer it.lastExchangeMessageLock.Unlock() 473 
it.lastExchangeMessage = em 474 } 475 476 func (it *InvoiceTracker) getLastExchangeMessage() crypto.ExchangeMessage { 477 it.lastExchangeMessageLock.Lock() 478 defer it.lastExchangeMessageLock.Unlock() 479 return it.lastExchangeMessage 480 } 481 482 func (it *InvoiceTracker) chainID() int64 { 483 return config.GetInt64(config.FlagChainID) 484 } 485 486 func (it *InvoiceTracker) sendInvoice(isCritical bool) error { 487 if it.getNotSentExchangeMessageCount() >= it.maxNotSentExchangeMessages { 488 return ErrInvoiceSendMaxFailCountReached 489 } 490 491 if it.getNotReceivedExchangeMessageCount() >= it.maxNotReceivedExchangeMessages { 492 return ErrExchangeWaitTimeout 493 } 494 495 shouldBe := CalculatePaymentAmount(it.deps.TimeTracker.Elapsed(), it.getDataTransferred(), it.deps.AgreedPrice) 496 497 lastEm := it.getLastExchangeMessage() 498 if lastEm.AgreementTotal.Cmp(big.NewInt(0)) == 0 && shouldBe.Cmp(big.NewInt(0)) == 1 { 499 // The first invoice should have minimal static value. 500 shouldBe = providerFirstInvoiceValue 501 log.Debug().Msgf("Being lenient for the first payment, asking for %v", shouldBe) 502 } 503 504 r, err := crypto.GenerateR() 505 if err != nil { 506 return fmt.Errorf("failed to generate R: %w", err) 507 } 508 invoice, err := crypto.CreateInvoice(it.agreementID, shouldBe, new(big.Int), r, it.chainID()) 509 if err != nil { 510 return fmt.Errorf("failed to create invoice: %w", err) 511 } 512 513 invoice.Provider = it.deps.ProviderID.Address 514 err = it.deps.PeerInvoiceSender.Send(invoice) 515 if err != nil { 516 return err 517 } 518 519 it.markInvoiceSent(sentInvoice{ 520 invoice: invoice, 521 r: r, 522 isCritical: isCritical, 523 }) 524 525 hlock, err := hex.DecodeString(invoice.Hashlock) 526 if err != nil { 527 return err 528 } 529 530 go it.waitForInvoicePayment(hlock) 531 532 err = it.deps.InvoiceStorage.Store(it.deps.ProviderID, it.deps.Peer, invoice) 533 return errors.Wrap(err, "could not store invoice") 534 } 535 536 func (it *InvoiceTracker) 
waitForInvoicePayment(hlock []byte) { 537 select { 538 case <-time.After(it.deps.ExchangeMessageWaitTimeout): 539 inv, ok := it.getMarkedInvoice(hlock) 540 if !ok { 541 return 542 } 543 544 if inv.isCritical { 545 log.Info().Msgf("did not get paid for invoice with hashlock %v, invoice is critical. Aborting.", inv.invoice.Hashlock) 546 it.criticalInvoiceErrors <- fmt.Errorf("did not get paid for critical invoice with hashlock %v", inv.invoice.Hashlock) 547 return 548 } 549 550 log.Info().Msgf("did not get paid for invoice with hashlock %v, incrementing failure count", inv.invoice.Hashlock) 551 it.markInvoicePaid(hlock) 552 it.markExchangeMessageNotReceived() 553 case <-it.stop: 554 return 555 } 556 } 557 558 func (it *InvoiceTracker) handleHermesError(err error) error { 559 if err == nil { 560 it.resetHermesFailureCount() 561 return nil 562 } 563 564 switch { 565 case 566 stdErr.Is(err, ErrHermesHashlockMissmatch), 567 stdErr.Is(err, ErrHermesPreviousRNotRevealed), 568 stdErr.Is(err, ErrHermesInternal), 569 stdErr.Is(err, ErrHermesNotFound), 570 stdErr.Is(err, ErrHermesMalformedJSON), 571 stdErr.Is(err, ErrTooManyRequests): 572 // these are ignorable, we'll eventually fail 573 if it.incrementHermesFailureCount() > it.deps.MaxHermesFailureCount { 574 return err 575 } 576 log.Warn().Err(err).Msg("hermes error, will retry") 577 return nil 578 case 579 stdErr.Is(err, ErrHermesInvalidSignature), 580 stdErr.Is(err, ErrHermesPaymentValueTooLow), 581 stdErr.Is(err, ErrHermesPromiseValueTooLow), 582 stdErr.Is(err, ErrHermesOverspend), 583 stdErr.Is(err, ErrConsumerUnregistered): 584 // these are critical, return and cancel session 585 return err 586 // under normal use, this should not occur. If it does, we should drop sessions until we settle because we're not getting paid. 
587 case stdErr.Is(err, ErrHermesProviderBalanceExhausted): 588 hermes, err := it.deps.AddressProvider.GetActiveHermes(it.chainID()) 589 if err != nil { 590 return err 591 } 592 it.deps.EventBus.Publish( 593 event.AppTopicSettlementRequest, 594 event.AppEventSettlementRequest{ 595 ChainID: it.chainID(), 596 HermesID: hermes, 597 ProviderID: it.deps.ProviderID, 598 }, 599 ) 600 return err 601 default: 602 if it.incrementHermesFailureCount() > it.deps.MaxHermesFailureCount { 603 return err 604 } 605 log.Warn().Err(err).Msg("unknown hermes error encountered, will retry") 606 return nil 607 } 608 } 609 610 func (it *InvoiceTracker) incrementHermesFailureCount() uint64 { 611 it.hermesFailureCountLock.Lock() 612 defer it.hermesFailureCountLock.Unlock() 613 it.hermesFailureCount++ 614 log.Trace().Msgf("hermes error count %v/%v", it.hermesFailureCount, it.deps.MaxHermesFailureCount) 615 return it.hermesFailureCount 616 } 617 618 func (it *InvoiceTracker) resetHermesFailureCount() { 619 it.hermesFailureCountLock.Lock() 620 defer it.hermesFailureCountLock.Unlock() 621 it.hermesFailureCount = 0 622 } 623 624 func (it *InvoiceTracker) validateExchangeMessage(em crypto.ExchangeMessage) error { 625 peerAddr := common.HexToAddress(it.deps.Peer.Address) 626 if res := em.IsMessageValid(peerAddr); !res { 627 return ErrExchangeValidationFailed 628 } 629 630 if em.ChainID != it.chainID() { 631 return fmt.Errorf("invalid chain id in exchange message: expected %v, got %v", it.chainID(), em.ChainID) 632 } 633 634 signer, err := em.Promise.RecoverSigner() 635 if err != nil { 636 return errors.Wrap(err, "could not recover promise signature") 637 } 638 639 if signer.Hex() != peerAddr.Hex() { 640 return errors.New("identity missmatch") 641 } 642 643 lastEm := it.getLastExchangeMessage() 644 if em.Promise.Amount.Cmp(lastEm.Promise.Amount) == -1 { 645 log.Warn().Msgf("Consumer sent an invalid amount. 
Expected < %v, got %v", lastEm.Promise.Amount, em.Promise.Amount) 646 return errors.Wrap(ErrConsumerPromiseValidationFailed, "invalid amount") 647 } 648 649 registry, err := it.deps.AddressProvider.GetRegistryAddress(em.ChainID) 650 if err != nil { 651 return errors.Wrap(err, "could not get registry address") 652 } 653 654 hermesId := common.HexToAddress(em.HermesID) 655 chimp, err := it.deps.AddressProvider.GetChannelImplementationForHermes(em.ChainID, hermesId) 656 if err != nil { 657 log.Err(err).Msgf("Failed to get channel implementation for hermes %s, using fallback", em.HermesID) 658 hermesData, err := it.deps.Observer.GetHermesData(em.ChainID, hermesId) 659 if err != nil { 660 return errors.Wrap(err, "could not get channel implementation") 661 } 662 chimp = hermesData.ChannelImpl 663 } 664 665 addr, err := it.deps.AddressProvider.GetArbitraryChannelAddress(common.HexToAddress(em.HermesID), registry, chimp, it.deps.Peer.ToCommonAddress()) 666 if err != nil { 667 return errors.Wrap(err, "could not generate channel address") 668 } 669 670 expectedChannel, err := hex.DecodeString(strings.TrimPrefix(addr.Hex(), "0x")) 671 if err != nil { 672 return errors.Wrap(err, "could not decode expected chanel") 673 } 674 675 if !bytes.Equal(expectedChannel, em.Promise.ChannelID) { 676 log.Warn().Msgf("Consumer sent an invalid channel address. Expected %q, got %q", addr.Hex(), hex.EncodeToString(em.Promise.ChannelID)) 677 return errors.Wrap(ErrConsumerPromiseValidationFailed, "invalid channel address") 678 } 679 return nil 680 } 681 682 // Stop stops the invoice tracker. 
683 func (it *InvoiceTracker) Stop() { 684 it.once.Do(func() { 685 log.Debug().Msgf("Stopping invoice tracker for session %s", it.deps.SessionID) 686 _ = it.deps.EventBus.UnsubscribeWithUID(sessionEvent.AppTopicDataTransferred, it.deps.SessionID, it.consumeDataTransferredEvent) 687 close(it.stop) 688 }) 689 } 690 691 func (it *InvoiceTracker) consumeDataTransferredEvent(e sessionEvent.AppEventDataTransferred) { 692 // skip irrelevant sessions 693 if !strings.EqualFold(e.ID, it.deps.SessionID) { 694 return 695 } 696 697 // From a server perspective, bytes up are the actual bytes the client downloaded(aka the bytes we pushed to the consumer) 698 // To lessen the confusion, I suggest having the bytes reversed on the session instance. 699 // This way, the session will show that it downloaded the bytes in a manner that is easier to comprehend. 700 it.updateDataTransfer(e.Down, e.Up) 701 } 702 703 func (it *InvoiceTracker) updateDataTransfer(up, down uint64) { 704 it.dataTransferredLock.Lock() 705 defer it.dataTransferredLock.Unlock() 706 707 newUp := it.dataTransferred.Up 708 if up > it.dataTransferred.Up { 709 newUp = up 710 } 711 712 newDown := it.dataTransferred.Down 713 if down > it.dataTransferred.Down { 714 newDown = down 715 } 716 717 it.dataTransferred = DataTransferred{ 718 Up: newUp, 719 Down: newDown, 720 } 721 } 722 723 func (it *InvoiceTracker) getDataTransferred() DataTransferred { 724 it.dataTransferredLock.Lock() 725 defer it.dataTransferredLock.Unlock() 726 727 return it.dataTransferred 728 }