github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/session/pingpong/hermes_promise_handler.go (about)

     1  /*
     2   * Copyright (C) 2019 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package pingpong
    19  
    20  import (
    21  	"encoding/hex"
    22  	"encoding/json"
    23  	stdErr "errors"
    24  	"fmt"
    25  	"math/big"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/ethereum/go-ethereum/common"
    30  	"github.com/mysteriumnetwork/node/config"
    31  	"github.com/mysteriumnetwork/node/core/node/event"
    32  	"github.com/mysteriumnetwork/node/eventbus"
    33  	"github.com/mysteriumnetwork/node/identity"
    34  	"github.com/mysteriumnetwork/node/identity/registry"
    35  	sessionEvent "github.com/mysteriumnetwork/node/session/event"
    36  	pinge "github.com/mysteriumnetwork/node/session/pingpong/event"
    37  	"github.com/mysteriumnetwork/payments/crypto"
    38  	"github.com/pkg/errors"
    39  	"github.com/rs/zerolog/log"
    40  )
    41  
    42  type hermesPromiseStorage interface {
    43  	Store(promise HermesPromise) error
    44  	Get(chainID int64, channelID string) (HermesPromise, error)
    45  }
    46  
    47  type feeProvider interface {
    48  	FetchSettleFees(chainID int64) (registry.FeesResponse, error)
    49  }
    50  
    51  // HermesHTTPRequester represents HTTP requests to Hermes.
    52  type HermesHTTPRequester interface {
    53  	PayAndSettle(rp RequestPromise) (crypto.Promise, error)
    54  	RequestPromise(rp RequestPromise) (crypto.Promise, error)
    55  	RevealR(r string, provider string, agreementID *big.Int) error
    56  	UpdatePromiseFee(promise crypto.Promise, newFee *big.Int) (crypto.Promise, error)
    57  	GetConsumerData(chainID int64, id string, cacheTime time.Duration) (HermesUserInfo, error)
    58  	GetProviderData(chainID int64, id string) (HermesUserInfo, error)
    59  	SyncProviderPromise(promise crypto.Promise, signer identity.Signer) error
    60  }
    61  
    62  type encryption interface {
    63  	Decrypt(addr common.Address, encrypted []byte) ([]byte, error)
    64  	Encrypt(addr common.Address, plaintext []byte) ([]byte, error)
    65  }
    66  
    67  // HermesCallerFactory represents Hermes caller factory.
    68  type HermesCallerFactory func(url string) HermesHTTPRequester
    69  
// HermesPromiseHandlerDeps represents the HermesPromiseHandler dependencies.
type HermesPromiseHandlerDeps struct {
	// HermesPromiseStorage stores received promises and provides the last
	// known promise when resyncing with hermes.
	HermesPromiseStorage hermesPromiseStorage
	// FeeProvider is the source of transactor settle fees per chain.
	FeeProvider          feeProvider
	// Encryption encrypts R recovery details for a provider and decrypts them
	// during R recovery.
	Encryption           encryption
	// EventBus receives the hermes-promise and tokens-earned events.
	EventBus             eventbus.Publisher
	// HermesURLGetter resolves the URL of a hermes on a given chain.
	HermesURLGetter      hermesURLGetter
	// HermesCallerFactory creates a hermes HTTP client from a URL.
	HermesCallerFactory  HermesCallerFactory
	// Signer creates signers for provider identities; used when syncing the
	// last known promise to hermes.
	Signer               identity.SignerFactory
	// Chains lists the chain IDs whose fees are fetched when the node starts.
	// When empty, NewHermesPromiseHandler fills it from the configured
	// chain 1 and chain 2 IDs.
	Chains               []int64
}
    81  
// HermesPromiseHandler handles the hermes promises for ongoing sessions.
//
// RequestPromise and PayAndSettle enqueue requests. A single loop processes
// them one at a time; the loop starts when a node-started event is received
// (see Subscribe / handleNodeEvents) and stops on a node-stopped event.
type HermesPromiseHandler struct {
	deps           HermesPromiseHandlerDeps
	// queue holds pending requests (buffered, size set by the constructor).
	queue          chan enqueuedRequest
	// stop is closed exactly once (guarded by stopOnce) to end the request loop.
	stop           chan struct{}
	stopOnce       sync.Once
	// startOnce ensures the fee prefetch and request loop are started once.
	startOnce      sync.Once
	// transactorFees caches settle fees per chain ID. It is accessed without
	// locking; within this file it is only touched from the start-up closure
	// in handleNodeEvents and the request loop that closure runs.
	transactorFees map[int64]registry.FeesResponse
}
    91  
    92  // NewHermesPromiseHandler returns a new instance of hermes promise handler.
    93  func NewHermesPromiseHandler(deps HermesPromiseHandlerDeps) *HermesPromiseHandler {
    94  	if len(deps.Chains) == 0 {
    95  		deps.Chains = []int64{config.GetInt64(config.FlagChain1ChainID), config.GetInt64(config.FlagChain2ChainID)}
    96  	}
    97  
    98  	return &HermesPromiseHandler{
    99  		deps:           deps,
   100  		queue:          make(chan enqueuedRequest, 100),
   101  		stop:           make(chan struct{}),
   102  		transactorFees: make(map[int64]registry.FeesResponse),
   103  	}
   104  }
   105  
// enqueuedRequest is a single pending promise request waiting in the handler
// queue.
type enqueuedRequest struct {
	// errChan receives at most one error for this request; requestPromise
	// closes it when processing finishes.
	errChan     chan error
	// r is the raw R value; it is hex-encoded into the encrypted recovery data
	// and stored alongside the received promise.
	r           []byte
	em          crypto.ExchangeMessage
	providerID  identity.Identity
	// requestFunc performs the hermes call that yields the promise
	// (promise request or pay-and-settle).
	requestFunc func(rp RequestPromise) (crypto.Promise, error)
	sessionID   string
}
   114  
   115  type hermesURLGetter interface {
   116  	GetHermesURL(chainID int64, address common.Address) (string, error)
   117  }
   118  
   119  // RequestPromise adds the request to the queue.
   120  func (aph *HermesPromiseHandler) RequestPromise(r []byte, em crypto.ExchangeMessage, providerID identity.Identity, sessionID string) <-chan error {
   121  	er := enqueuedRequest{
   122  		r:          r,
   123  		em:         em,
   124  		providerID: providerID,
   125  		errChan:    make(chan error),
   126  		sessionID:  sessionID,
   127  	}
   128  
   129  	hermesID := common.HexToAddress(em.HermesID)
   130  	hermesCaller, err := aph.getHermesCaller(em.ChainID, hermesID)
   131  	if err != nil {
   132  		go func() {
   133  			er.errChan <- fmt.Errorf("could not get hermes caller: %w", err)
   134  		}()
   135  		return er.errChan
   136  	}
   137  
   138  	er.requestFunc = aph.makeRequestPromiseFunc(providerID, hermesCaller)
   139  
   140  	aph.queue <- er
   141  	return er.errChan
   142  }
   143  
   144  func (aph *HermesPromiseHandler) makeRequestPromiseFunc(providerID identity.Identity, caller HermesHTTPRequester) func(rp RequestPromise) (crypto.Promise, error) {
   145  	return func(rp RequestPromise) (crypto.Promise, error) {
   146  		p, err := caller.RequestPromise(rp)
   147  		if err == nil {
   148  			return p, nil
   149  		}
   150  
   151  		if !stdErr.Is(err, ErrInvalidPreviuosLatestPromise) {
   152  			// We can only really handle the previuos promise is invalid error.
   153  			return crypto.Promise{}, err
   154  		}
   155  
   156  		chid, err := crypto.GenerateProviderChannelID(providerID.Address, rp.ExchangeMessage.HermesID)
   157  		if err != nil {
   158  			return crypto.Promise{}, fmt.Errorf("failed to generate provider ID in promise sync: %w", err)
   159  		}
   160  
   161  		stored, err := aph.deps.HermesPromiseStorage.Get(rp.ExchangeMessage.ChainID, chid)
   162  		if err != nil {
   163  			return crypto.Promise{}, fmt.Errorf("failed to get last known promise from bolt: %w", err)
   164  		}
   165  
   166  		signer := aph.deps.Signer(providerID)
   167  		if err := caller.SyncProviderPromise(stored.Promise, signer); err != nil {
   168  			return crypto.Promise{}, fmt.Errorf("failed to sync to last known promise: %w", err)
   169  		}
   170  
   171  		if err := caller.RevealR(stored.R, providerID.Address, stored.AgreementID); err != nil {
   172  			return crypto.Promise{}, fmt.Errorf("failed to reveal R after sync: %w", err)
   173  		}
   174  
   175  		return caller.RequestPromise(rp)
   176  	}
   177  }
   178  
   179  // PayAndSettle adds the request to the queue.
   180  func (aph *HermesPromiseHandler) PayAndSettle(r []byte, em crypto.ExchangeMessage, providerID identity.Identity, sessionID string) <-chan error {
   181  	er := enqueuedRequest{
   182  		r:          r,
   183  		em:         em,
   184  		providerID: providerID,
   185  		errChan:    make(chan error),
   186  		sessionID:  sessionID,
   187  	}
   188  
   189  	hermesID := common.HexToAddress(em.HermesID)
   190  	hermesCaller, err := aph.getHermesCaller(em.ChainID, hermesID)
   191  	if err != nil {
   192  		go func() {
   193  			er.errChan <- fmt.Errorf("could not get hermes caller: %w", err)
   194  		}()
   195  		return er.errChan
   196  	}
   197  	er.requestFunc = hermesCaller.PayAndSettle
   198  
   199  	aph.queue <- er
   200  	return er.errChan
   201  }
   202  
   203  func (aph *HermesPromiseHandler) getFees(chainID int64) (*big.Int, error) {
   204  	fee, ok := aph.transactorFees[chainID]
   205  	if ok && fee.IsValid() {
   206  		return fee.Fee, nil
   207  	}
   208  
   209  	if err := aph.updateFees(chainID); err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	if updatedFee, ok := aph.transactorFees[chainID]; ok {
   214  		return updatedFee.Fee, nil
   215  	}
   216  
   217  	return nil, errors.New("failed to fetch fees")
   218  }
   219  
   220  func (aph *HermesPromiseHandler) updateFees(chainID int64) error {
   221  	fees, err := aph.deps.FeeProvider.FetchSettleFees(chainID)
   222  	if err != nil {
   223  		return err
   224  	}
   225  
   226  	aph.transactorFees[chainID] = fees
   227  	return nil
   228  }
   229  
   230  func (aph *HermesPromiseHandler) handleRequests() {
   231  	log.Debug().Msgf("hermes promise handler started")
   232  	defer log.Debug().Msgf("hermes promise handler stopped")
   233  	for {
   234  		select {
   235  		case <-aph.stop:
   236  			return
   237  		case entry := <-aph.queue:
   238  			aph.requestPromise(entry)
   239  		}
   240  	}
   241  }
   242  
   243  // Subscribe subscribes HermesPromiseHandler to relevant events.
   244  func (aph *HermesPromiseHandler) Subscribe(bus eventbus.Subscriber) error {
   245  	err := bus.SubscribeAsync(event.AppTopicNode, aph.handleNodeEvents)
   246  	if err != nil {
   247  		return fmt.Errorf("could not subscribe to node events: %w", err)
   248  	}
   249  
   250  	return nil
   251  }
   252  
   253  func (aph *HermesPromiseHandler) doStop() {
   254  	aph.stopOnce.Do(func() {
   255  		close(aph.stop)
   256  	})
   257  }
   258  
   259  func (aph *HermesPromiseHandler) handleNodeEvents(e event.Payload) {
   260  	if e.Status == event.StatusStopped {
   261  		aph.doStop()
   262  		return
   263  	}
   264  	if e.Status == event.StatusStarted {
   265  		aph.startOnce.Do(
   266  			func() {
   267  				for _, c := range aph.deps.Chains {
   268  					if err := aph.updateFees(c); err != nil {
   269  						log.Warn().Err(err).Msg("could not fetch fees")
   270  					}
   271  				}
   272  
   273  				aph.handleRequests()
   274  			},
   275  		)
   276  		return
   277  	}
   278  }
   279  
// requestPromise processes one queued request: it obtains a promise from hermes,
// stores it, publishes the related events and then reveals R.
//
// Any failure is sent once on er.errChan; the channel is always closed when
// this function returns. Note that a failure in the final reveal step is
// reported after the promise has already been stored and the events published.
func (aph *HermesPromiseHandler) requestPromise(er enqueuedRequest) {
	defer close(er.errChan)

	providerID := er.providerID
	hermesID := common.HexToAddress(er.em.HermesID)
	fee, err := aph.getFees(er.em.ChainID)
	if err != nil {
		er.errChan <- fmt.Errorf("no fees for chain %v: %w", er.em.ChainID, err)
		return
	}

	// Build the R recovery payload: R and the agreement ID, JSON-encoded and
	// encrypted for the provider, so R can be recovered later from hermes
	// (see recoverR).
	details := rRecoveryDetails{
		R:           hex.EncodeToString(er.r),
		AgreementID: er.em.AgreementID,
	}

	bytes, err := json.Marshal(details)
	if err != nil {
		er.errChan <- fmt.Errorf("could not marshal R recovery details: %w", err)
		return
	}

	encrypted, err := aph.deps.Encryption.Encrypt(providerID.ToCommonAddress(), bytes)
	if err != nil {
		er.errChan <- fmt.Errorf("could not encrypt R: %w", err)
		return
	}

	request := RequestPromise{
		ExchangeMessage: er.em,
		TransactorFee:   fee,
		RRecoveryData:   hex.EncodeToString(encrypted),
	}

	// If hermes asked for R recovery and recovery succeeded, the request is
	// retried exactly once.
	// NOTE(review): handleHermesError maps ErrHermesNoPreviousPromise to nil,
	// in which case execution continues with whatever promise requestFunc
	// returned alongside that error — confirm that error cannot occur here.
	promise, err := er.requestFunc(request)
	err = aph.handleHermesError(err, providerID, er.em.ChainID, hermesID)
	if err != nil {
		if !errors.Is(err, errRrecovered) {
			er.errChan <- fmt.Errorf("hermes request promise error: %w", err)
			return
		}
		log.Info().Msgf("r recovered, will request again")

		promise, err = er.requestFunc(request)
		if err != nil {
			er.errChan <- fmt.Errorf("attempted to request promise again and got an error: %w", err)
			return
		}
	}

	// A chain ID mismatch is only logged; the promise is still stored and its
	// own chain ID is used for the reveal below.
	if promise.ChainID != request.ExchangeMessage.ChainID {
		log.Debug().Msgf("Received promise with wrong chain id from hermes. Expected %v, got %v", request.ExchangeMessage.ChainID, promise.ChainID)
	}

	ap := HermesPromise{
		ChannelID:   aph.normalizeChannelID(promise.ChannelID),
		Identity:    providerID,
		HermesID:    hermesID,
		Promise:     promise,
		R:           hex.EncodeToString(er.r),
		Revealed:    false,
		AgreementID: er.em.AgreementID,
	}

	// ErrAttemptToOverwrite is tolerated: processing continues as if stored.
	err = aph.deps.HermesPromiseStorage.Store(ap)
	if err != nil && !stdErr.Is(err, ErrAttemptToOverwrite) {
		er.errChan <- fmt.Errorf("could not store hermes promise: %w", err)
		return
	}

	aph.deps.EventBus.Publish(pinge.AppTopicHermesPromise, pinge.AppEventHermesPromise{
		Promise:    promise,
		HermesID:   hermesID,
		ProviderID: providerID,
	})
	aph.deps.EventBus.Publish(sessionEvent.AppTopicTokensEarned, sessionEvent.AppEventTokensEarned{
		ProviderID: providerID,
		SessionID:  er.sessionID,
		Total:      er.em.AgreementTotal,
	})

	// NOTE(review): revealR already routes hermes errors through
	// handleHermesError; this second pass only matters if revealR returns an
	// error that still wraps ErrNeedsRRecovery.
	err = aph.revealR(ap)
	err = aph.handleHermesError(err, providerID, ap.Promise.ChainID, hermesID)
	if err != nil {
		if errors.Is(err, errRrecovered) {
			log.Info().Msgf("r recovered")
			return
		}
		er.errChan <- fmt.Errorf("hermes reveal r error: %w", err)
		return
	}
}
   372  
   373  func (aph *HermesPromiseHandler) normalizeChannelID(chid []byte) string {
   374  	hexStr := common.Bytes2Hex(chid)
   375  	return "0x" + hexStr
   376  }
   377  
   378  func (aph *HermesPromiseHandler) getHermesCaller(chainID int64, hermesID common.Address) (HermesHTTPRequester, error) {
   379  	addr, err := aph.deps.HermesURLGetter.GetHermesURL(chainID, hermesID)
   380  	if err != nil {
   381  		return nil, fmt.Errorf("could not get hermes URL: %w", err)
   382  	}
   383  	return aph.deps.HermesCallerFactory(addr), nil
   384  }
   385  
   386  func (aph *HermesPromiseHandler) revealR(hermesPromise HermesPromise) error {
   387  	if hermesPromise.Revealed {
   388  		return nil
   389  	}
   390  
   391  	hermesCaller, err := aph.getHermesCaller(hermesPromise.Promise.ChainID, hermesPromise.HermesID)
   392  	if err != nil {
   393  		return fmt.Errorf("could not get hermes caller: %w", err)
   394  	}
   395  
   396  	err = hermesCaller.RevealR(hermesPromise.R, hermesPromise.Identity.Address, hermesPromise.AgreementID)
   397  	handledErr := aph.handleHermesError(err, hermesPromise.Identity, hermesPromise.Promise.ChainID, hermesPromise.HermesID)
   398  	if handledErr != nil {
   399  		if errors.Is(err, errRrecovered) {
   400  			log.Info().Msgf("r recovered")
   401  			return nil
   402  		}
   403  		return fmt.Errorf("could not reveal R: %w", err)
   404  	}
   405  
   406  	hermesPromise.Revealed = true
   407  	err = aph.deps.HermesPromiseStorage.Store(hermesPromise)
   408  	if err != nil && !stdErr.Is(err, ErrAttemptToOverwrite) {
   409  		return fmt.Errorf("could not store hermes promise: %w", err)
   410  	}
   411  
   412  	return nil
   413  }
   414  
   415  var errRrecovered = errors.New("R recovered")
   416  var errPreviuosPromise = errors.New("action cannot be performed as previuos promise is invalid")
   417  
   418  func (aph *HermesPromiseHandler) handleHermesError(err error, providerID identity.Identity, chainID int64, hermesID common.Address) error {
   419  	if err == nil {
   420  		return nil
   421  	}
   422  
   423  	switch {
   424  	case stdErr.Is(err, ErrNeedsRRecovery):
   425  		var aer HermesErrorResponse
   426  		ok := stdErr.As(err, &aer)
   427  		if !ok {
   428  			return errors.New("could not cast errNeedsRecovery to hermesError")
   429  		}
   430  		recoveryErr := aph.recoverR(aer, providerID, chainID, hermesID)
   431  		if recoveryErr != nil {
   432  			return recoveryErr
   433  		}
   434  		return errRrecovered
   435  	case stdErr.Is(err, ErrHermesNoPreviousPromise):
   436  		log.Info().Msg("no previous promise on hermes, will mark R as revealed")
   437  		return nil
   438  	default:
   439  		return err
   440  	}
   441  }
   442  
   443  func (aph *HermesPromiseHandler) recoverR(aerr hermesError, providerID identity.Identity, chainID int64, hermesID common.Address) error {
   444  	log.Info().Msg("Recovering R...")
   445  	decoded, err := hex.DecodeString(aerr.Data())
   446  	if err != nil {
   447  		return fmt.Errorf("could not decode R recovery details: %w", err)
   448  	}
   449  
   450  	decrypted, err := aph.deps.Encryption.Decrypt(providerID.ToCommonAddress(), decoded)
   451  	if err != nil {
   452  		return fmt.Errorf("could not decrypt R details: %w", err)
   453  	}
   454  
   455  	res := rRecoveryDetails{}
   456  	err = json.Unmarshal(decrypted, &res)
   457  	if err != nil {
   458  		return fmt.Errorf("could not unmarshal R details: %w", err)
   459  	}
   460  
   461  	log.Info().Msg("R recovered, will reveal...")
   462  	hermesCaller, err := aph.getHermesCaller(chainID, hermesID)
   463  	if err != nil {
   464  		return fmt.Errorf("could not get hermes caller: %w", err)
   465  	}
   466  
   467  	err = hermesCaller.RevealR(res.R, providerID.Address, res.AgreementID)
   468  	if err != nil {
   469  		return fmt.Errorf("could not reveal R: %w", err)
   470  	}
   471  
   472  	log.Info().Msg("R recovered successfully")
   473  	return nil
   474  }