github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/core/service/session_manager.go (about)

     1  /*
     2   * Copyright (C) 2017 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package service
    19  
    20  import (
    21  	"context"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"math/big"
    26  	"net"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/ethereum/go-ethereum/common"
    31  	"github.com/rs/zerolog/log"
    32  
    33  	"github.com/mysteriumnetwork/node/config"
    34  	"github.com/mysteriumnetwork/node/core/quality"
    35  	"github.com/mysteriumnetwork/node/identity"
    36  	"github.com/mysteriumnetwork/node/market"
    37  	"github.com/mysteriumnetwork/node/nat/event"
    38  	"github.com/mysteriumnetwork/node/p2p"
    39  	"github.com/mysteriumnetwork/node/pb"
    40  	"github.com/mysteriumnetwork/node/session"
    41  	sevent "github.com/mysteriumnetwork/node/session/event"
    42  	"github.com/mysteriumnetwork/node/utils/reftracker"
    43  	"github.com/mysteriumnetwork/payments/crypto"
    44  )
    45  
    46  var (
    47  	// ErrorInvalidProposal is validation error then invalid proposal requested for session creation
    48  	ErrorInvalidProposal = errors.New("proposal does not exist")
    49  	// ErrorSessionNotExists returned when consumer tries to destroy session that does not exists
    50  	ErrorSessionNotExists = errors.New("session does not exists")
    51  	// ErrorWrongSessionOwner returned when consumer tries to destroy session that does not belongs to him
    52  	ErrorWrongSessionOwner = errors.New("wrong session owner")
    53  )
    54  
    55  // IDGenerator defines method for session id generation
    56  type IDGenerator func() (session.ID, error)
    57  
    58  // ConfigParams session configuration parameters
    59  type ConfigParams struct {
    60  	SessionServiceConfig   ServiceConfiguration
    61  	SessionDestroyCallback DestroyCallback
    62  }
    63  
    64  // ServiceConfiguration defines service configuration from underlying transport mechanism to be passed to remote party
    65  // should be serializable to json format.
    66  type ServiceConfiguration interface{}
    67  
    68  type publisher interface {
    69  	Publish(topic string, data interface{})
    70  }
    71  
    72  // KeepAliveConfig contains keep alive options.
    73  type KeepAliveConfig struct {
    74  	SendInterval    time.Duration
    75  	SendTimeout     time.Duration
    76  	MaxSendErrCount int
    77  }
    78  
    79  // Config contains common configuration options for session manager.
    80  type Config struct {
    81  	KeepAlive KeepAliveConfig
    82  }
    83  
    84  // DefaultConfig returns default params.
    85  func DefaultConfig() Config {
    86  	return Config{
    87  		KeepAlive: KeepAliveConfig{
    88  			SendInterval:    14 * time.Second,
    89  			SendTimeout:     5 * time.Second,
    90  			MaxSendErrCount: 5,
    91  		},
    92  	}
    93  }
    94  
    95  // ConfigProvider is able to handle config negotiations
    96  type ConfigProvider interface {
    97  	ProvideConfig(sessionID string, sessionConfig json.RawMessage, conn *net.UDPConn) (*ConfigParams, error)
    98  }
    99  
   100  // DestroyCallback cleanups session
   101  type DestroyCallback func()
   102  
   103  // PromiseProcessor processes promises at provider side.
   104  // Provider checks promises from consumer and signs them also.
   105  // Provider clears promises from consumer.
   106  type PromiseProcessor interface {
   107  	Start(proposal market.ServiceProposal) error
   108  	Stop() error
   109  }
   110  
   111  // PaymentEngineFactory creates a new instance of payment engine
   112  type PaymentEngineFactory func(providerID, consumerID identity.Identity, chainID int64, hermesID common.Address, sessionID string, exchangeChan chan crypto.ExchangeMessage, price market.Price) (PaymentEngine, error)
   113  
   114  // PriceValidator allows to validate prices against those in discovery.
   115  type PriceValidator interface {
   116  	IsPriceValid(in market.Price, nodeType string, country string, serviceType string) bool
   117  }
   118  
   119  // PaymentEngine is responsible for interacting with the consumer in regard to payments.
   120  type PaymentEngine interface {
   121  	Start() error
   122  	WaitFirstInvoice(time.Duration) error
   123  	Stop()
   124  }
   125  
   126  // NATEventGetter lets us access the last known traversal event
   127  type NATEventGetter interface {
   128  	LastEvent() *event.Event
   129  }
   130  
   131  // NewSessionManager returns new session SessionManager
   132  func NewSessionManager(
   133  	service *Instance,
   134  	sessionStorage *SessionPool,
   135  	paymentEngineFactory PaymentEngineFactory,
   136  	publisher publisher,
   137  	channel p2p.Channel,
   138  	config Config,
   139  	priceValidator PriceValidator,
   140  ) *SessionManager {
   141  	return &SessionManager{
   142  		service:              service,
   143  		sessionStorage:       sessionStorage,
   144  		publisher:            publisher,
   145  		paymentEngineFactory: paymentEngineFactory,
   146  		paymentEngineChan:    make(chan crypto.ExchangeMessage, 1),
   147  		channel:              channel,
   148  		config:               config,
   149  		priceValidator:       priceValidator,
   150  	}
   151  }
   152  
   153  // SessionManager knows how to start and provision session
   154  type SessionManager struct {
   155  	service              *Instance
   156  	sessionStorage       *SessionPool
   157  	paymentEngineFactory PaymentEngineFactory
   158  	paymentEngineChan    chan crypto.ExchangeMessage
   159  	publisher            publisher
   160  	channel              p2p.Channel
   161  	config               Config
   162  	priceValidator       PriceValidator
   163  }
   164  
   165  // Start starts a session on the provider side for the given consumer.
   166  // Multiple sessions per peerID is possible in case different services are used
   167  func (manager *SessionManager) Start(request *pb.SessionRequest) (_ pb.SessionResponse, err error) {
   168  	session, err := NewSession(manager.service, request, manager.channel.Tracer())
   169  	if err != nil {
   170  		return pb.SessionResponse{}, fmt.Errorf("cannot create new session: %w", err)
   171  	}
   172  
   173  	prices := manager.remapPricing(request.Consumer.Pricing)
   174  
   175  	var validationError error
   176  	validationWG := sync.WaitGroup{}
   177  	validationWG.Add(1)
   178  	go func() {
   179  		trace := session.tracer.StartStage("Session validation")
   180  		validationError = manager.validateSession(session, prices)
   181  		session.tracer.EndStage(trace)
   182  		validationWG.Done()
   183  	}()
   184  
   185  	rt := reftracker.Singleton()
   186  	chID := "channel:" + manager.channel.ID()
   187  
   188  	if rt.Incr(chID) != nil {
   189  		return pb.SessionResponse{}, fmt.Errorf("unable to hold the channel: %w", err)
   190  	}
   191  	log.Info().Msgf("session ref incr for %q", chID)
   192  
   193  	session.addCleanup(func() error {
   194  		log.Info().Msgf("session ref decr for %q", chID)
   195  		return rt.Decr(chID)
   196  	})
   197  
   198  	defer func() {
   199  		if err != nil {
   200  			log.Err(err).Msg("Session failed, disconnecting")
   201  			session.Close()
   202  		}
   203  	}()
   204  
   205  	trace := session.tracer.StartStage("Provider session create")
   206  	defer func() {
   207  		session.tracer.EndStage(trace)
   208  		traceResult := session.tracer.Finish(manager.publisher, string(session.ID))
   209  		log.Debug().Msgf("Provider connection trace: %s", traceResult)
   210  	}()
   211  
   212  	validationWG.Wait()
   213  	if validationError != nil {
   214  		return pb.SessionResponse{}, validationError
   215  	}
   216  
   217  	if err = manager.startSession(session, prices); err != nil {
   218  		return pb.SessionResponse{}, err
   219  	}
   220  
   221  	if err = manager.paymentLoop(session, prices); err != nil {
   222  		return pb.SessionResponse{}, err
   223  	}
   224  
   225  	return manager.providerService(session, manager.channel)
   226  }
   227  
   228  func (manager *SessionManager) validatePrice(in market.Price, nodeType, country, serviceType string) error {
   229  	if !manager.priceValidator.IsPriceValid(in, nodeType, country, serviceType) {
   230  		return errors.New("consumer asking for invalid price")
   231  	}
   232  
   233  	return nil
   234  }
   235  
   236  func (manager *SessionManager) remapPricing(in *pb.Pricing) market.Price {
   237  	// This prevents panics in case of malicious consumers.
   238  	if in == nil || in.PerGib == nil || in.PerHour == nil {
   239  		return market.Price{
   240  			PricePerHour: big.NewInt(0),
   241  			PricePerGiB:  big.NewInt(0),
   242  		}
   243  	}
   244  
   245  	return market.Price{
   246  		PricePerHour: big.NewInt(0).SetBytes(in.PerHour),
   247  		PricePerGiB:  big.NewInt(0).SetBytes(in.PerGib),
   248  	}
   249  }
   250  
   251  // Acknowledge marks the session as successfully established as far as the consumer is concerned.
   252  func (manager *SessionManager) Acknowledge(consumerID identity.Identity, sessionID string) error {
   253  	session, found := manager.sessionStorage.Find(session.ID(sessionID))
   254  	if !found {
   255  		return ErrorSessionNotExists
   256  	}
   257  	if session.ConsumerID != consumerID {
   258  		return ErrorWrongSessionOwner
   259  	}
   260  
   261  	manager.publisher.Publish(sevent.AppTopicSession, session.toEvent(sevent.AcknowledgedStatus))
   262  	return nil
   263  }
   264  
   265  func (manager *SessionManager) startSession(session *Session, prices market.Price) error {
   266  	trace := session.tracer.StartStage("Provider session create (start)")
   267  	defer session.tracer.EndStage(trace)
   268  
   269  	manager.clearStaleSession(session.ConsumerID, manager.service.Type)
   270  
   271  	manager.sessionStorage.Add(session)
   272  	session.addCleanup(func() error {
   273  		manager.sessionStorage.Remove(session.ID)
   274  		return nil
   275  	})
   276  
   277  	go manager.keepAliveLoop(session, manager.channel)
   278  
   279  	return nil
   280  }
   281  
   282  func (manager *SessionManager) validateSession(session *Session, prices market.Price) error {
   283  	if !manager.service.PolicyProvider().IsIdentityAllowed(session.ConsumerID) {
   284  		return fmt.Errorf("consumer identity is not allowed: %s", session.ConsumerID.Address)
   285  	}
   286  
   287  	return manager.validatePrice(prices, manager.service.Proposal.Location.IPType, manager.service.Proposal.Location.Country, manager.service.Proposal.ServiceType)
   288  }
   289  
   290  func (manager *SessionManager) clearStaleSession(consumerID identity.Identity, serviceType string) {
   291  	// Reading stale session before starting the clean up in goroutine.
   292  	// This is required to make sure we are not cleaning the newly created session.
   293  	for _, session := range manager.sessionStorage.GetAll() {
   294  		if consumerID != session.ConsumerID {
   295  			continue
   296  		}
   297  		if serviceType != session.Proposal.ServiceType {
   298  			continue
   299  		}
   300  		log.Info().Msgf("Cleaning stale session %s for %s consumer", session.ID, consumerID.Address)
   301  		go session.Close()
   302  	}
   303  }
   304  
   305  // Destroy destroys session by given sessionID
   306  func (manager *SessionManager) Destroy(consumerID identity.Identity, sessionID string) error {
   307  	session, found := manager.sessionStorage.Find(session.ID(sessionID))
   308  	if !found {
   309  		return ErrorSessionNotExists
   310  	}
   311  	if session.ConsumerID != consumerID {
   312  		return ErrorWrongSessionOwner
   313  	}
   314  
   315  	session.Close()
   316  	return nil
   317  }
   318  
   319  func (manager *SessionManager) paymentLoop(session *Session, price market.Price) error {
   320  	trace := session.tracer.StartStage("Provider session create (payment)")
   321  	defer session.tracer.EndStage(trace)
   322  
   323  	log.Info().Msg("Using new payments")
   324  
   325  	chainID := config.GetInt64(config.FlagChainID)
   326  	engine, err := manager.paymentEngineFactory(manager.service.ProviderID, session.ConsumerID, chainID, session.HermesID, string(session.ID), manager.paymentEngineChan, price)
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	// stop the balance tracker once the session is finished
   332  	session.addCleanup(func() error {
   333  		engine.Stop()
   334  		return nil
   335  	})
   336  
   337  	go func() {
   338  		err := engine.Start()
   339  		if err != nil {
   340  			log.Error().Err(err).Msg("Payment engine error")
   341  			session.Close()
   342  		}
   343  	}()
   344  
   345  	log.Info().Msg("Waiting for a first invoice to be paid")
   346  	if err := engine.WaitFirstInvoice(30 * time.Second); err != nil {
   347  		return fmt.Errorf("first invoice was not paid: %w", err)
   348  	}
   349  
   350  	return nil
   351  }
   352  
   353  func (manager *SessionManager) providerService(session *Session, channel p2p.Channel) (pb.SessionResponse, error) {
   354  	trace := session.tracer.StartStage("Provider session create (configure)")
   355  	defer session.tracer.EndStage(trace)
   356  
   357  	config, err := manager.service.Service().ProvideConfig(string(session.ID), session.request.GetConfig(), channel.ServiceConn())
   358  	if err != nil {
   359  		return pb.SessionResponse{}, fmt.Errorf("cannot get provider config for session %s: %w", string(session.ID), err)
   360  	}
   361  
   362  	if config.SessionDestroyCallback != nil {
   363  		session.addCleanup(func() error {
   364  			config.SessionDestroyCallback()
   365  			return nil
   366  		})
   367  	}
   368  
   369  	data, err := json.Marshal(config.SessionServiceConfig)
   370  	if err != nil {
   371  		return pb.SessionResponse{}, fmt.Errorf("cannot pack session %s service config: %w", string(session.ID), err)
   372  	}
   373  
   374  	return pb.SessionResponse{
   375  		ID:          string(session.ID),
   376  		PaymentInfo: "v3",
   377  		Config:      data,
   378  	}, nil
   379  }
   380  
   381  func (manager *SessionManager) keepAliveLoop(sess *Session, channel p2p.Channel) {
   382  	// Register handler for handling p2p keep alive pings from consumer.
   383  	channel.Handle(p2p.TopicKeepAlive, func(c p2p.Context) error {
   384  		var ping pb.P2PKeepAlivePing
   385  		if err := c.Request().UnmarshalProto(&ping); err != nil {
   386  			return err
   387  		}
   388  
   389  		log.Debug().Msgf("Received p2p keepalive ping with SessionID=%s from %s", ping.SessionID, c.PeerID().ToCommonAddress())
   390  		return c.OK()
   391  	})
   392  
   393  	// Send pings to consumer.
   394  	var errCount int
   395  	for {
   396  		select {
   397  		case <-sess.Done():
   398  			return
   399  		case <-time.After(manager.config.KeepAlive.SendInterval):
   400  			if err := manager.sendKeepAlivePing(channel, sess.ID); err != nil {
   401  				log.Err(err).Msgf("Failed to send p2p keepalive ping. SessionID=%s", sess.ID)
   402  				errCount++
   403  				if errCount == manager.config.KeepAlive.MaxSendErrCount {
   404  					log.Error().Msgf("Max p2p keepalive err count reached, closing SessionID=%s", sess.ID)
   405  					sess.Close()
   406  					return
   407  				}
   408  			} else {
   409  				errCount = 0
   410  			}
   411  		}
   412  	}
   413  }
   414  
   415  func (manager *SessionManager) sendKeepAlivePing(channel p2p.Channel, sessionID session.ID) error {
   416  	ctx, cancel := context.WithTimeout(context.Background(), manager.config.KeepAlive.SendTimeout)
   417  	defer cancel()
   418  	msg := &pb.P2PKeepAlivePing{
   419  		SessionID: string(sessionID),
   420  	}
   421  
   422  	start := time.Now()
   423  	_, err := channel.Send(ctx, p2p.TopicKeepAlive, p2p.ProtoMessage(msg))
   424  	manager.publisher.Publish(quality.AppTopicProviderPingP2P, quality.PingEvent{
   425  		SessionID: string(sessionID),
   426  		Duration:  time.Since(start),
   427  	})
   428  
   429  	return err
   430  }