/*
 * Copyright (C) 2019 The "MysteriumNetwork/node" Authors.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

package pingpong

import (
	"fmt"
	"math/big"
	"strings"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/accounts"
	"github.com/ethereum/go-ethereum/common"
	"github.com/gofrs/uuid"
	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"

	"github.com/mysteriumnetwork/node/core/connection/connectionstate"
	"github.com/mysteriumnetwork/node/datasize"
	"github.com/mysteriumnetwork/node/eventbus"
	"github.com/mysteriumnetwork/node/identity"
	"github.com/mysteriumnetwork/node/market"
	"github.com/mysteriumnetwork/node/session/pingpong/event"
	"github.com/mysteriumnetwork/payments/crypto"
)

// ErrWrongProvider represents an issue where the wrong provider is supplied:
// the invoice names a provider (compared case-insensitively) other than the
// peer this payer was created for.
var ErrWrongProvider = errors.New("wrong provider supplied")

// ErrProviderOvercharge represents an issue where the provider is trying to overcharge us:
// the invoice's agreement total exceeds the upper bound derived from the locally
// observed session time and traffic at the agreed price, multiplied by a tolerance factor.
46 var ErrProviderOvercharge = errors.New("provider is overcharging") 47 48 // consumerInvoiceBasicTolerance provider traffic amount compensation due to: 49 // - different MTU sizes 50 // - measurement timing inaccuracies 51 // - possible in-transit packet fragmentation 52 // - non-agreed traffic: traffic blocked / dropped / not reachable / failed retransmits on provider 53 const consumerInvoiceBasicTolerance = 1.11 54 55 // PeerExchangeMessageSender allows for sending of exchange messages. 56 type PeerExchangeMessageSender interface { 57 Send(crypto.ExchangeMessage) error 58 } 59 60 type consumerTotalsStorage interface { 61 Store(chainID int64, id identity.Identity, hermesID common.Address, amount *big.Int) error 62 Get(chainID int64, id identity.Identity, hermesID common.Address) (*big.Int, error) 63 Add(chainID int64, id identity.Identity, hermesID common.Address, amount *big.Int) error 64 } 65 66 type timeTracker interface { 67 StartTracking() 68 Elapsed() time.Duration 69 } 70 71 type channelAddressCalculator interface { 72 GetChannelAddress(id identity.Identity) (common.Address, error) 73 } 74 75 // InvoicePayer keeps track of exchange messages and sends them to the provider. 76 type InvoicePayer struct { 77 stop chan struct{} 78 once sync.Once 79 channelAddress identity.Identity 80 81 lastInvoice crypto.Invoice 82 deps InvoicePayerDeps 83 84 dataTransferred DataTransferred 85 dataTransferredLock sync.Mutex 86 87 sessionIDLock sync.Mutex 88 } 89 90 type hashSigner interface { 91 SignHash(a accounts.Account, hash []byte) ([]byte, error) 92 } 93 94 // InvoicePayerDeps contains all the dependencies for the exchange message tracker. 
95 type InvoicePayerDeps struct { 96 InvoiceChan chan crypto.Invoice 97 PeerExchangeMessageSender PeerExchangeMessageSender 98 ConsumerTotalsStorage consumerTotalsStorage 99 TimeTracker timeTracker 100 Ks hashSigner 101 Identity, Peer identity.Identity 102 AgreedPrice market.Price 103 SenderUUID string 104 SessionID string 105 AddressProvider addressProvider 106 EventBus eventbus.EventBus 107 HermesAddress common.Address 108 DataLeeway datasize.BitSize 109 ChainID int64 110 } 111 112 // NewInvoicePayer returns a new instance of exchange message tracker. 113 func NewInvoicePayer(ipd InvoicePayerDeps) *InvoicePayer { 114 return &InvoicePayer{ 115 stop: make(chan struct{}), 116 deps: ipd, 117 lastInvoice: crypto.Invoice{ 118 AgreementID: new(big.Int), 119 AgreementTotal: new(big.Int), 120 TransactorFee: new(big.Int), 121 }, 122 } 123 } 124 125 // ErrInvoiceMissmatch represents an error that occurs when invoices do not match. 126 var ErrInvoiceMissmatch = errors.New("invoice mismatch") 127 128 // Start starts the message exchange tracker. Blocks. 
129 func (ip *InvoicePayer) Start() error { 130 log.Debug().Msg("Starting...") 131 addr, err := ip.deps.AddressProvider.GetActiveChannelAddress(ip.deps.ChainID, ip.deps.Identity.ToCommonAddress()) 132 if err != nil { 133 return errors.Wrap(err, "could not generate channel address") 134 } 135 ip.channelAddress = identity.FromAddress(addr.Hex()) 136 137 ip.deps.TimeTracker.StartTracking() 138 139 uid, err := uuid.NewV4() 140 if err != nil { 141 return err 142 } 143 144 err = ip.deps.EventBus.SubscribeWithUID(connectionstate.AppTopicConnectionStatistics, uid.String(), ip.consumeDataTransferredEvent) 145 if err != nil { 146 return errors.Wrap(err, "could not subscribe to data transfer events") 147 } 148 149 for { 150 select { 151 case <-ip.stop: 152 _ = ip.deps.EventBus.UnsubscribeWithUID(connectionstate.AppTopicConnectionStatistics, uid.String(), ip.consumeDataTransferredEvent) 153 154 return nil 155 case invoice := <-ip.deps.InvoiceChan: 156 log.Debug().Msgf("Invoice received: %v", invoice) 157 err := ip.isInvoiceOK(invoice) 158 if err != nil { 159 return errors.Wrap(err, "invoice not valid") 160 } 161 162 err = ip.issueExchangeMessage(invoice) 163 if err != nil { 164 return err 165 } 166 167 ip.lastInvoice = invoice 168 } 169 } 170 } 171 172 func (ip *InvoicePayer) incrementGrandTotalPromised(amount big.Int) error { 173 return ip.deps.ConsumerTotalsStorage.Add(ip.chainID(), ip.deps.Identity, ip.deps.HermesAddress, &amount) 174 } 175 176 func (ip *InvoicePayer) isInvoiceOK(invoice crypto.Invoice) error { 177 if !strings.EqualFold(invoice.Provider, ip.deps.Peer.Address) { 178 return ErrWrongProvider 179 } 180 181 transferred := ip.getDataTransferred() 182 transferred.Up += ip.deps.DataLeeway.Bytes() 183 184 shouldBe := CalculatePaymentAmount(ip.deps.TimeTracker.Elapsed(), transferred, ip.deps.AgreedPrice) 185 estimatedTolerance := estimateInvoiceTolerance(ip.deps.TimeTracker.Elapsed(), transferred) 186 187 upperBound, _ := 
new(big.Float).Mul(new(big.Float).SetInt(shouldBe), big.NewFloat(estimatedTolerance)).Int(nil) 188 189 log.Debug().Msgf("Estimated tolerance %.4v, upper bound %v", estimatedTolerance, upperBound) 190 191 if invoice.AgreementTotal.Cmp(upperBound) == 1 { 192 log.Warn().Msg("Provider trying to overcharge") 193 return ErrProviderOvercharge 194 } 195 196 return nil 197 } 198 199 func estimateInvoiceTolerance(elapsed time.Duration, transferred DataTransferred) float64 { 200 if elapsed.Seconds() < 1 { 201 return 3 202 } 203 204 totalMiBytesTransferred := float64(transferred.sum()) / (1024 * 1024) 205 avgSpeedInMiBits := totalMiBytesTransferred / elapsed.Seconds() * 8 206 207 // correction calculation based on total session duration. 208 durInMinutes := elapsed.Minutes() 209 210 if elapsed.Minutes() < 1 { 211 durInMinutes = 1 212 } 213 214 durationComponent := 1 - durInMinutes/(1+durInMinutes) 215 216 // correction calculation based on average session speed. 217 if avgSpeedInMiBits == 0 { 218 avgSpeedInMiBits = 1 219 } 220 221 avgSpeedComponent := 1 - 1/(1+avgSpeedInMiBits/1024) 222 223 return durationComponent + avgSpeedComponent + consumerInvoiceBasicTolerance 224 } 225 226 func (ip *InvoicePayer) calculateAmountToPromise(invoice crypto.Invoice) (toPromise *big.Int, diff *big.Int, err error) { 227 diff = safeSub(invoice.AgreementTotal, ip.lastInvoice.AgreementTotal) 228 totalPromised, err := ip.deps.ConsumerTotalsStorage.Get(ip.chainID(), ip.deps.Identity, ip.deps.HermesAddress) 229 if err != nil { 230 if err != ErrNotFound { 231 return new(big.Int), new(big.Int), fmt.Errorf("could not get previous grand total: %w", err) 232 } 233 log.Debug().Msg("No previous promised total, assuming 0") 234 totalPromised = new(big.Int) 235 } 236 237 // This is a new agreement, we need to take in the agreement total and just add it to total promised 238 if ip.lastInvoice.AgreementID.Cmp(invoice.AgreementID) != 0 { 239 diff = invoice.AgreementTotal 240 } 241 242 log.Debug().Msgf("Loaded 
previous state: already promised: %v", totalPromised) 243 log.Debug().Msgf("Incrementing promised amount by %v", diff) 244 amountToPromise := new(big.Int).Add(totalPromised, diff) 245 return amountToPromise, diff, nil 246 } 247 248 func (ip *InvoicePayer) chainID() int64 { 249 return ip.deps.ChainID 250 } 251 252 func (ip *InvoicePayer) issueExchangeMessage(invoice crypto.Invoice) error { 253 amountToPromise, diff, err := ip.calculateAmountToPromise(invoice) 254 if err != nil { 255 return errors.Wrap(err, "could not calculate amount to promise") 256 } 257 258 msg, err := crypto.CreateExchangeMessage(ip.chainID(), invoice, amountToPromise, ip.channelAddress.Address, ip.deps.HermesAddress.Hex(), ip.deps.Ks, common.HexToAddress(ip.deps.Identity.Address)) 259 if err != nil { 260 return errors.Wrap(err, "could not create exchange message") 261 } 262 263 err = ip.deps.PeerExchangeMessageSender.Send(*msg) 264 if err != nil { 265 log.Warn().Err(err).Msg("Failed to send exchange message") 266 } 267 268 ip.publishInvoicePayedEvent(invoice) 269 270 // TODO: we'd probably want to check if we have enough balance here 271 err = ip.incrementGrandTotalPromised(*diff) 272 return errors.Wrap(err, "could not increment grand total") 273 } 274 275 func (ip *InvoicePayer) publishInvoicePayedEvent(invoice crypto.Invoice) { 276 ip.sessionIDLock.Lock() 277 defer ip.sessionIDLock.Unlock() 278 279 // session id might be set later than we start paying invoices, skip in that case. 280 if ip.deps.SessionID == "" { 281 return 282 } 283 284 ip.deps.EventBus.Publish(event.AppTopicInvoicePaid, event.AppEventInvoicePaid{ 285 UUID: ip.deps.SenderUUID, 286 ConsumerID: ip.deps.Identity, 287 SessionID: ip.deps.SessionID, 288 Invoice: invoice, 289 }) 290 } 291 292 // Stop stops the message tracker. 
293 func (ip *InvoicePayer) Stop() { 294 ip.once.Do(func() { 295 log.Debug().Msg("Stopping...") 296 close(ip.stop) 297 }) 298 } 299 300 func (ip *InvoicePayer) consumeDataTransferredEvent(e connectionstate.AppEventConnectionStatistics) { 301 // From a server perspective, bytes up are the actual bytes the client downloaded(aka the bytes we pushed to the consumer) 302 // To lessen the confusion, I suggest having the bytes reversed on the session instance. 303 // This way, the session will show that it downloaded the bytes in a manner that is easier to comprehend. 304 ip.updateDataTransfer(e.Stats.BytesSent, e.Stats.BytesReceived) 305 } 306 307 func (ip *InvoicePayer) updateDataTransfer(up, down uint64) { 308 ip.dataTransferredLock.Lock() 309 defer ip.dataTransferredLock.Unlock() 310 311 newUp := ip.dataTransferred.Up 312 if up > ip.dataTransferred.Up { 313 newUp = up 314 } 315 316 newDown := ip.dataTransferred.Down 317 if down > ip.dataTransferred.Down { 318 newDown = down 319 } 320 321 ip.dataTransferred = DataTransferred{ 322 Up: newUp, 323 Down: newDown, 324 } 325 } 326 327 func (ip *InvoicePayer) getDataTransferred() DataTransferred { 328 ip.dataTransferredLock.Lock() 329 defer ip.dataTransferredLock.Unlock() 330 331 return ip.dataTransferred 332 } 333 334 // SetSessionID updates invoice payer dependencies to set session ID once session established. 335 func (ip *InvoicePayer) SetSessionID(sessionID string) { 336 ip.sessionIDLock.Lock() 337 defer ip.sessionIDLock.Unlock() 338 ip.deps.SessionID = sessionID 339 }