github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/session/pingpong/invoice_payer.go (about)

     1  /*
     2   * Copyright (C) 2019 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package pingpong
    19  
    20  import (
    21  	"fmt"
    22  	"math/big"
    23  	"strings"
    24  	"sync"
    25  	"time"
    26  
    27  	"github.com/ethereum/go-ethereum/accounts"
    28  	"github.com/ethereum/go-ethereum/common"
    29  	"github.com/gofrs/uuid"
    30  	"github.com/pkg/errors"
    31  	"github.com/rs/zerolog/log"
    32  
    33  	"github.com/mysteriumnetwork/node/core/connection/connectionstate"
    34  	"github.com/mysteriumnetwork/node/datasize"
    35  	"github.com/mysteriumnetwork/node/eventbus"
    36  	"github.com/mysteriumnetwork/node/identity"
    37  	"github.com/mysteriumnetwork/node/market"
    38  	"github.com/mysteriumnetwork/node/session/pingpong/event"
    39  	"github.com/mysteriumnetwork/payments/crypto"
    40  )
    41  
    42  // ErrWrongProvider represents an issue where the wrong provider is supplied.
    43  var ErrWrongProvider = errors.New("wrong provider supplied")
    44  
    45  // ErrProviderOvercharge represents an issue where the provider is trying to overcharge us.
    46  var ErrProviderOvercharge = errors.New("provider is overcharging")
    47  
    48  // consumerInvoiceBasicTolerance provider traffic amount compensation due to:
    49  //   - different MTU sizes
    50  //   - measurement timing inaccuracies
    51  //   - possible in-transit packet fragmentation
    52  //   - non-agreed traffic: traffic blocked / dropped / not reachable / failed retransmits on provider
    53  const consumerInvoiceBasicTolerance = 1.11
    54  
    55  // PeerExchangeMessageSender allows for sending of exchange messages.
    56  type PeerExchangeMessageSender interface {
    57  	Send(crypto.ExchangeMessage) error
    58  }
    59  
    60  type consumerTotalsStorage interface {
    61  	Store(chainID int64, id identity.Identity, hermesID common.Address, amount *big.Int) error
    62  	Get(chainID int64, id identity.Identity, hermesID common.Address) (*big.Int, error)
    63  	Add(chainID int64, id identity.Identity, hermesID common.Address, amount *big.Int) error
    64  }
    65  
    66  type timeTracker interface {
    67  	StartTracking()
    68  	Elapsed() time.Duration
    69  }
    70  
    71  type channelAddressCalculator interface {
    72  	GetChannelAddress(id identity.Identity) (common.Address, error)
    73  }
    74  
// InvoicePayer keeps track of exchange messages and sends them to the provider.
//
// It works on the consumer side: Start receives invoices from
// InvoicePayerDeps.InvoiceChan, checks each one against the elapsed session
// time and the observed traffic, and answers accepted invoices with an
// exchange message promising the new cumulative total. Stop ends that loop.
type InvoicePayer struct {
	// stop is closed by Stop (at most once, guarded by once) to make Start return.
	stop           chan struct{}
	once           sync.Once
	// channelAddress is the consumer's active channel address, resolved in Start.
	channelAddress identity.Identity

	// lastInvoice is the most recently paid invoice; the next invoice's
	// increment is computed relative to it. NewInvoicePayer seeds it with
	// zero-valued agreement ID, total and transactor fee.
	lastInvoice crypto.Invoice
	deps        InvoicePayerDeps

	// dataTransferred holds the largest byte counters observed from connection
	// statistics events; guarded by dataTransferredLock.
	dataTransferred     DataTransferred
	dataTransferredLock sync.Mutex

	// sessionIDLock guards deps.SessionID, which SetSessionID may change while
	// invoices are already being paid.
	sessionIDLock sync.Mutex
}

// hashSigner signs a hash on behalf of an account. An implementation is passed
// to crypto.CreateExchangeMessage as the key store; presumably it signs the
// exchange message with the consumer's key — confirm against that function.
type hashSigner interface {
	SignHash(a accounts.Account, hash []byte) ([]byte, error)
}
    93  
// InvoicePayerDeps contains all the dependencies for the exchange message tracker.
type InvoicePayerDeps struct {
	// InvoiceChan delivers the provider's invoices to Start.
	InvoiceChan               chan crypto.Invoice
	// PeerExchangeMessageSender sends the resulting exchange messages to the provider.
	PeerExchangeMessageSender PeerExchangeMessageSender
	// ConsumerTotalsStorage holds the grand total promised per chain, identity and hermes.
	ConsumerTotalsStorage     consumerTotalsStorage
	// TimeTracker is started by Start; its elapsed time is used to price and bound invoices.
	TimeTracker               timeTracker
	// Ks is handed to crypto.CreateExchangeMessage as the signer.
	Ks                        hashSigner
	// Identity is the paying consumer. Peer is the provider whose invoices are
	// accepted: invoice.Provider must equal Peer.Address (case-insensitively).
	Identity, Peer            identity.Identity
	// AgreedPrice is used to compute the expected amount for an invoice.
	AgreedPrice               market.Price
	// SenderUUID is copied into published invoice-paid events.
	SenderUUID                string
	// SessionID may be empty initially and set later via SetSessionID; no
	// invoice-paid event is published while it is empty.
	SessionID                 string
	// AddressProvider resolves the consumer's active channel address in Start.
	AddressProvider           addressProvider
	// EventBus delivers connection statistics and receives invoice-paid events.
	EventBus                  eventbus.EventBus
	// HermesAddress is included in exchange messages and keys the promised totals.
	HermesAddress             common.Address
	// DataLeeway is added to the observed upload bytes before validating an invoice.
	DataLeeway                datasize.BitSize
	// ChainID is the chain the payer operates on.
	ChainID                   int64
}
   111  
   112  // NewInvoicePayer returns a new instance of exchange message tracker.
   113  func NewInvoicePayer(ipd InvoicePayerDeps) *InvoicePayer {
   114  	return &InvoicePayer{
   115  		stop: make(chan struct{}),
   116  		deps: ipd,
   117  		lastInvoice: crypto.Invoice{
   118  			AgreementID:    new(big.Int),
   119  			AgreementTotal: new(big.Int),
   120  			TransactorFee:  new(big.Int),
   121  		},
   122  	}
   123  }
   124  
   125  // ErrInvoiceMissmatch represents an error that occurs when invoices do not match.
   126  var ErrInvoiceMissmatch = errors.New("invoice mismatch")
   127  
   128  // Start starts the message exchange tracker. Blocks.
   129  func (ip *InvoicePayer) Start() error {
   130  	log.Debug().Msg("Starting...")
   131  	addr, err := ip.deps.AddressProvider.GetActiveChannelAddress(ip.deps.ChainID, ip.deps.Identity.ToCommonAddress())
   132  	if err != nil {
   133  		return errors.Wrap(err, "could not generate channel address")
   134  	}
   135  	ip.channelAddress = identity.FromAddress(addr.Hex())
   136  
   137  	ip.deps.TimeTracker.StartTracking()
   138  
   139  	uid, err := uuid.NewV4()
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	err = ip.deps.EventBus.SubscribeWithUID(connectionstate.AppTopicConnectionStatistics, uid.String(), ip.consumeDataTransferredEvent)
   145  	if err != nil {
   146  		return errors.Wrap(err, "could not subscribe to data transfer events")
   147  	}
   148  
   149  	for {
   150  		select {
   151  		case <-ip.stop:
   152  			_ = ip.deps.EventBus.UnsubscribeWithUID(connectionstate.AppTopicConnectionStatistics, uid.String(), ip.consumeDataTransferredEvent)
   153  
   154  			return nil
   155  		case invoice := <-ip.deps.InvoiceChan:
   156  			log.Debug().Msgf("Invoice received: %v", invoice)
   157  			err := ip.isInvoiceOK(invoice)
   158  			if err != nil {
   159  				return errors.Wrap(err, "invoice not valid")
   160  			}
   161  
   162  			err = ip.issueExchangeMessage(invoice)
   163  			if err != nil {
   164  				return err
   165  			}
   166  
   167  			ip.lastInvoice = invoice
   168  		}
   169  	}
   170  }
   171  
   172  func (ip *InvoicePayer) incrementGrandTotalPromised(amount big.Int) error {
   173  	return ip.deps.ConsumerTotalsStorage.Add(ip.chainID(), ip.deps.Identity, ip.deps.HermesAddress, &amount)
   174  }
   175  
   176  func (ip *InvoicePayer) isInvoiceOK(invoice crypto.Invoice) error {
   177  	if !strings.EqualFold(invoice.Provider, ip.deps.Peer.Address) {
   178  		return ErrWrongProvider
   179  	}
   180  
   181  	transferred := ip.getDataTransferred()
   182  	transferred.Up += ip.deps.DataLeeway.Bytes()
   183  
   184  	shouldBe := CalculatePaymentAmount(ip.deps.TimeTracker.Elapsed(), transferred, ip.deps.AgreedPrice)
   185  	estimatedTolerance := estimateInvoiceTolerance(ip.deps.TimeTracker.Elapsed(), transferred)
   186  
   187  	upperBound, _ := new(big.Float).Mul(new(big.Float).SetInt(shouldBe), big.NewFloat(estimatedTolerance)).Int(nil)
   188  
   189  	log.Debug().Msgf("Estimated tolerance %.4v, upper bound %v", estimatedTolerance, upperBound)
   190  
   191  	if invoice.AgreementTotal.Cmp(upperBound) == 1 {
   192  		log.Warn().Msg("Provider trying to overcharge")
   193  		return ErrProviderOvercharge
   194  	}
   195  
   196  	return nil
   197  }
   198  
   199  func estimateInvoiceTolerance(elapsed time.Duration, transferred DataTransferred) float64 {
   200  	if elapsed.Seconds() < 1 {
   201  		return 3
   202  	}
   203  
   204  	totalMiBytesTransferred := float64(transferred.sum()) / (1024 * 1024)
   205  	avgSpeedInMiBits := totalMiBytesTransferred / elapsed.Seconds() * 8
   206  
   207  	// correction calculation based on total session duration.
   208  	durInMinutes := elapsed.Minutes()
   209  
   210  	if elapsed.Minutes() < 1 {
   211  		durInMinutes = 1
   212  	}
   213  
   214  	durationComponent := 1 - durInMinutes/(1+durInMinutes)
   215  
   216  	// correction calculation based on average session speed.
   217  	if avgSpeedInMiBits == 0 {
   218  		avgSpeedInMiBits = 1
   219  	}
   220  
   221  	avgSpeedComponent := 1 - 1/(1+avgSpeedInMiBits/1024)
   222  
   223  	return durationComponent + avgSpeedComponent + consumerInvoiceBasicTolerance
   224  }
   225  
   226  func (ip *InvoicePayer) calculateAmountToPromise(invoice crypto.Invoice) (toPromise *big.Int, diff *big.Int, err error) {
   227  	diff = safeSub(invoice.AgreementTotal, ip.lastInvoice.AgreementTotal)
   228  	totalPromised, err := ip.deps.ConsumerTotalsStorage.Get(ip.chainID(), ip.deps.Identity, ip.deps.HermesAddress)
   229  	if err != nil {
   230  		if err != ErrNotFound {
   231  			return new(big.Int), new(big.Int), fmt.Errorf("could not get previous grand total: %w", err)
   232  		}
   233  		log.Debug().Msg("No previous promised total, assuming 0")
   234  		totalPromised = new(big.Int)
   235  	}
   236  
   237  	// This is a new agreement, we need to take in the agreement total and just add it to total promised
   238  	if ip.lastInvoice.AgreementID.Cmp(invoice.AgreementID) != 0 {
   239  		diff = invoice.AgreementTotal
   240  	}
   241  
   242  	log.Debug().Msgf("Loaded previous state: already promised: %v", totalPromised)
   243  	log.Debug().Msgf("Incrementing promised amount by %v", diff)
   244  	amountToPromise := new(big.Int).Add(totalPromised, diff)
   245  	return amountToPromise, diff, nil
   246  }
   247  
   248  func (ip *InvoicePayer) chainID() int64 {
   249  	return ip.deps.ChainID
   250  }
   251  
   252  func (ip *InvoicePayer) issueExchangeMessage(invoice crypto.Invoice) error {
   253  	amountToPromise, diff, err := ip.calculateAmountToPromise(invoice)
   254  	if err != nil {
   255  		return errors.Wrap(err, "could not calculate amount to promise")
   256  	}
   257  
   258  	msg, err := crypto.CreateExchangeMessage(ip.chainID(), invoice, amountToPromise, ip.channelAddress.Address, ip.deps.HermesAddress.Hex(), ip.deps.Ks, common.HexToAddress(ip.deps.Identity.Address))
   259  	if err != nil {
   260  		return errors.Wrap(err, "could not create exchange message")
   261  	}
   262  
   263  	err = ip.deps.PeerExchangeMessageSender.Send(*msg)
   264  	if err != nil {
   265  		log.Warn().Err(err).Msg("Failed to send exchange message")
   266  	}
   267  
   268  	ip.publishInvoicePayedEvent(invoice)
   269  
   270  	// TODO: we'd probably want to check if we have enough balance here
   271  	err = ip.incrementGrandTotalPromised(*diff)
   272  	return errors.Wrap(err, "could not increment grand total")
   273  }
   274  
   275  func (ip *InvoicePayer) publishInvoicePayedEvent(invoice crypto.Invoice) {
   276  	ip.sessionIDLock.Lock()
   277  	defer ip.sessionIDLock.Unlock()
   278  
   279  	// session id might be set later than we start paying invoices, skip in that case.
   280  	if ip.deps.SessionID == "" {
   281  		return
   282  	}
   283  
   284  	ip.deps.EventBus.Publish(event.AppTopicInvoicePaid, event.AppEventInvoicePaid{
   285  		UUID:       ip.deps.SenderUUID,
   286  		ConsumerID: ip.deps.Identity,
   287  		SessionID:  ip.deps.SessionID,
   288  		Invoice:    invoice,
   289  	})
   290  }
   291  
   292  // Stop stops the message tracker.
   293  func (ip *InvoicePayer) Stop() {
   294  	ip.once.Do(func() {
   295  		log.Debug().Msg("Stopping...")
   296  		close(ip.stop)
   297  	})
   298  }
   299  
   300  func (ip *InvoicePayer) consumeDataTransferredEvent(e connectionstate.AppEventConnectionStatistics) {
   301  	// From a server perspective, bytes up are the actual bytes the client downloaded(aka the bytes we pushed to the consumer)
   302  	// To lessen the confusion, I suggest having the bytes reversed on the session instance.
   303  	// This way, the session will show that it downloaded the bytes in a manner that is easier to comprehend.
   304  	ip.updateDataTransfer(e.Stats.BytesSent, e.Stats.BytesReceived)
   305  }
   306  
   307  func (ip *InvoicePayer) updateDataTransfer(up, down uint64) {
   308  	ip.dataTransferredLock.Lock()
   309  	defer ip.dataTransferredLock.Unlock()
   310  
   311  	newUp := ip.dataTransferred.Up
   312  	if up > ip.dataTransferred.Up {
   313  		newUp = up
   314  	}
   315  
   316  	newDown := ip.dataTransferred.Down
   317  	if down > ip.dataTransferred.Down {
   318  		newDown = down
   319  	}
   320  
   321  	ip.dataTransferred = DataTransferred{
   322  		Up:   newUp,
   323  		Down: newDown,
   324  	}
   325  }
   326  
   327  func (ip *InvoicePayer) getDataTransferred() DataTransferred {
   328  	ip.dataTransferredLock.Lock()
   329  	defer ip.dataTransferredLock.Unlock()
   330  
   331  	return ip.dataTransferred
   332  }
   333  
   334  // SetSessionID updates invoice payer dependencies to set session ID once session established.
   335  func (ip *InvoicePayer) SetSessionID(sessionID string) {
   336  	ip.sessionIDLock.Lock()
   337  	defer ip.sessionIDLock.Unlock()
   338  	ip.deps.SessionID = sessionID
   339  }