github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/core/connection/manager.go (about)

     1  /*
     2   * Copyright (C) 2017 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package connection
    19  
    20  import (
    21  	"context"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"math/big"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/ethereum/go-ethereum/common"
    30  	"github.com/gofrs/uuid"
    31  	"github.com/rs/zerolog/log"
    32  
    33  	"github.com/mysteriumnetwork/node/config"
    34  	"github.com/mysteriumnetwork/node/core/connection/connectionstate"
    35  	"github.com/mysteriumnetwork/node/core/discovery/proposal"
    36  	"github.com/mysteriumnetwork/node/core/ip"
    37  	"github.com/mysteriumnetwork/node/core/location"
    38  	"github.com/mysteriumnetwork/node/core/quality"
    39  	"github.com/mysteriumnetwork/node/eventbus"
    40  	"github.com/mysteriumnetwork/node/firewall"
    41  	"github.com/mysteriumnetwork/node/identity"
    42  	"github.com/mysteriumnetwork/node/market"
    43  	"github.com/mysteriumnetwork/node/p2p"
    44  	"github.com/mysteriumnetwork/node/pb"
    45  	"github.com/mysteriumnetwork/node/session"
    46  	"github.com/mysteriumnetwork/node/session/connectivity"
    47  	"github.com/mysteriumnetwork/node/trace"
    48  )
    49  
    50  const (
    51  	p2pDialTimeout = 60 * time.Second
    52  )
    53  
    54  var (
    55  	// ErrNoConnection error indicates that action applied to manager expects active connection (i.e. disconnect)
    56  	ErrNoConnection = errors.New("no connection exists")
    57  	// ErrAlreadyExists error indicates that action applied to manager expects no active connection (i.e. connect)
    58  	ErrAlreadyExists = errors.New("connection already exists")
    59  	// ErrConnectionCancelled indicates that connection in progress was cancelled by request of api user
    60  	ErrConnectionCancelled = errors.New("connection was cancelled")
    61  	// ErrConnectionFailed indicates that Connect method didn't reach "Connected" phase due to connection error
    62  	ErrConnectionFailed = errors.New("connection has failed")
    63  	// ErrUnsupportedServiceType indicates that target proposal contains unsupported service type
    64  	ErrUnsupportedServiceType = errors.New("unsupported service type in proposal")
    65  	// ErrInsufficientBalance indicates consumer has insufficient balance to connect to selected proposal
    66  	ErrInsufficientBalance = errors.New("insufficient balance")
    67  	// ErrUnlockRequired indicates that the consumer identity has not been unlocked yet
    68  	ErrUnlockRequired = errors.New("unlock required")
    69  )
    70  
    71  // IPCheckConfig contains common params for connection ip check.
    72  type IPCheckConfig struct {
    73  	MaxAttempts             int
    74  	SleepDurationAfterCheck time.Duration
    75  }
    76  
    77  // KeepAliveConfig contains keep alive options.
    78  type KeepAliveConfig struct {
    79  	SendInterval    time.Duration
    80  	SendTimeout     time.Duration
    81  	MaxSendErrCount int
    82  }
    83  
    84  // Config contains common configuration options for connection manager.
    85  type Config struct {
    86  	IPCheck   IPCheckConfig
    87  	KeepAlive KeepAliveConfig
    88  }
    89  
    90  // DefaultConfig returns default params.
    91  func DefaultConfig() Config {
    92  	return Config{
    93  		IPCheck: IPCheckConfig{
    94  			MaxAttempts:             6,
    95  			SleepDurationAfterCheck: 3 * time.Second,
    96  		},
    97  		KeepAlive: KeepAliveConfig{
    98  			SendInterval:    5 * time.Second,
    99  			SendTimeout:     5 * time.Second,
   100  			MaxSendErrCount: 3,
   101  		},
   102  	}
   103  }
   104  
   105  // Creator creates new connection by given options and uses state channel to report state changes
   106  type Creator func(serviceType string) (Connection, error)
   107  
   108  // ConnectionStart start new connection with a given options.
   109  type ConnectionStart func(context.Context, ConnectOptions) error
   110  
   111  // PaymentIssuer handles the payments for service
   112  type PaymentIssuer interface {
   113  	Start() error
   114  	SetSessionID(string)
   115  	Stop()
   116  }
   117  
   118  // PriceGetter fetches the current price.
   119  type PriceGetter interface {
   120  	GetCurrentPrice(nodeType string, country string) (market.Price, error)
   121  }
   122  
   123  type validator interface {
   124  	Validate(chainID int64, consumerID identity.Identity, p market.Price) error
   125  }
   126  
   127  // TimeGetter function returns current time
   128  type TimeGetter func() time.Time
   129  
   130  // PaymentEngineFactory creates a new payment issuer from the given params
   131  type PaymentEngineFactory func(senderUUID string, channel p2p.Channel, consumer, provider identity.Identity, hermes common.Address, proposal proposal.PricedServiceProposal, price market.Price) (PaymentIssuer, error)
   132  
   133  // ProposalLookup returns a service proposal based on predefined conditions.
   134  type ProposalLookup func() (proposal *proposal.PricedServiceProposal, err error)
   135  
   136  type connectionManager struct {
   137  	// These are passed on creation.
   138  	paymentEngineFactory PaymentEngineFactory
   139  	newConnection        Creator
   140  	eventBus             eventbus.EventBus
   141  	ipResolver           ip.Resolver
   142  	locationResolver     location.OriginResolver
   143  	config               Config
   144  	statsReportInterval  time.Duration
   145  	validator            validator
   146  	p2pDialer            p2p.Dialer
   147  	timeGetter           TimeGetter
   148  
   149  	// These are populated by Connect at runtime.
   150  	ctx                    context.Context
   151  	ctxLock                sync.RWMutex
   152  	status                 connectionstate.Status
   153  	statusLock             sync.RWMutex
   154  	cleanupLock            sync.Mutex
   155  	cleanup                []func() error
   156  	cleanupAfterDisconnect []func() error
   157  	cleanupFinished        chan struct{}
   158  	cleanupFinishedLock    sync.Mutex
   159  	acknowledge            func()
   160  	cancel                 func()
   161  	channel                p2p.Channel
   162  
   163  	preReconnect  func()
   164  	postReconnect func()
   165  
   166  	discoLock      sync.Mutex
   167  	connectOptions ConnectOptions
   168  
   169  	activeConnection Connection
   170  	statsTracker     statsTracker
   171  
   172  	uuid string
   173  }
   174  
   175  // NewManager creates connection manager with given dependencies
   176  func NewManager(
   177  	paymentEngineFactory PaymentEngineFactory,
   178  	connectionCreator Creator,
   179  	eventBus eventbus.EventBus,
   180  	ipResolver ip.Resolver,
   181  	locationResolver location.OriginResolver,
   182  	config Config,
   183  	statsReportInterval time.Duration,
   184  	validator validator,
   185  	p2pDialer p2p.Dialer,
   186  	preReconnect, postReconnect func(),
   187  ) *connectionManager {
   188  	uuid, err := uuid.NewV4()
   189  	if err != nil {
   190  		panic(err) // This should never happen.
   191  	}
   192  
   193  	m := &connectionManager{
   194  		newConnection:        connectionCreator,
   195  		status:               connectionstate.Status{State: connectionstate.NotConnected},
   196  		eventBus:             eventBus,
   197  		paymentEngineFactory: paymentEngineFactory,
   198  		cleanup:              make([]func() error, 0),
   199  		cleanupFinished:      make(chan struct{}, 1),
   200  		ipResolver:           ipResolver,
   201  		locationResolver:     locationResolver,
   202  		config:               config,
   203  		statsReportInterval:  statsReportInterval,
   204  		validator:            validator,
   205  		p2pDialer:            p2pDialer,
   206  		timeGetter:           time.Now,
   207  		preReconnect:         preReconnect,
   208  		postReconnect:        postReconnect,
   209  		uuid:                 uuid.String(),
   210  	}
   211  
   212  	m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold)
   213  
   214  	return m
   215  }
   216  
   217  func (m *connectionManager) chainID() int64 {
   218  	return config.GetInt64(config.FlagChainID)
   219  }
   220  
// Connect establishes a consumer connection to the provider of the proposal
// returned by proposalLookup.
//
// Steps, in order:
//  1. Look up the proposal. Errors are wrapped with "failed to lookup proposal".
//  2. Register an IP-cache cleanup, then return ErrAlreadyExists unless the
//     manager is NotConnected.
//  3. Validate the consumer against the proposal price for the configured chain.
//  4. Create a fresh cancellable connection context and switch to Connecting.
//  5. Create the Connection for the proposal's service type, set up the P2P
//     session (initSession), start the connection and wait for Connected.
//  6. Start the statistics tracker, keep consuming connection states, and
//     check in the background whether the public IP changed.
//
// Any error from step 4 onwards triggers disconnect through the deferred
// handler, which tears down everything registered via addCleanup and
// addCleanupAfterDisconnect. Errors from steps 1-3 return without a
// disconnect. The whole call is traced, and the trace is finished (and logged)
// with the session ID if one was obtained.
//
// NOTE(review): the NotConnected check (step 2) and statusConnecting (step 4)
// are not atomic, so two concurrent Connect calls could both pass the check.
// The IP-cache cleanup is also appended before that check, so a rejected call
// still adds one more entry to the current connection's cleanups.
func (m *connectionManager) Connect(consumerID identity.Identity, hermesID common.Address, proposalLookup ProposalLookup, params ConnectParams) (err error) {
	// Captured by the trace defer below; stays empty until initSession succeeds.
	var sessionID session.ID

	proposal, err := proposalLookup()
	if err != nil {
		return fmt.Errorf("failed to lookup proposal: %w", err)
	}

	tracer := trace.NewTracer("Consumer whole Connect")
	defer func() {
		traceResult := tracer.Finish(m.eventBus, string(sessionID))
		log.Debug().Msgf("Consumer connection trace: %s", traceResult)
	}()

	// Make sure the IP cache is cleared as part of disconnect, whatever stage
	// connect terminates at. We assume the IP resolver may have been used (and
	// may have cached the IP) before connect.
	m.addCleanup(func() error {
		m.clearIPCache()
		return nil
	})

	if m.Status().State != connectionstate.NotConnected {
		return ErrAlreadyExists
	}

	prc := m.priceFromProposal(*proposal)

	err = m.validator.Validate(m.chainID(), consumerID, prc)
	if err != nil {
		return err
	}

	// Fresh context for this connection; disconnect cancels it.
	m.ctxLock.Lock()
	m.ctx, m.cancel = context.WithCancel(context.Background())
	m.ctxLock.Unlock()

	m.statusConnecting(consumerID, hermesID, *proposal)
	// From here on, any failure (reported through the named err) tears down
	// whatever has been set up so far.
	defer func() {
		if err != nil {
			log.Err(err).Msg("Connect failed, disconnecting")
			m.disconnect()
		}
	}()

	m.connectOptions = ConnectOptions{
		ConsumerID:     consumerID,
		HermesID:       hermesID,
		Proposal:       *proposal,
		ProposalLookup: proposalLookup,
		Params:         params,
	}

	m.activeConnection, err = m.newConnection(proposal.ServiceType)
	if err != nil {
		return err
	}

	sessionID, err = m.initSession(tracer, prc)
	if err != nil {
		return err
	}

	// Resolve the public IP before the tunnel comes up so checkSessionIP can
	// compare against it.
	originalPublicIP := m.getPublicIP()

	err = m.startConnection(m.currentCtx(), m.activeConnection, m.activeConnection.Start, m.connectOptions, tracer)
	if err != nil {
		return m.handleStartError(sessionID, err)
	}

	err = m.waitForConnectedState(m.activeConnection.State())
	if err != nil {
		return m.handleStartError(sessionID, err)
	}

	m.statsTracker = newStatsTracker(m.eventBus, m.statsReportInterval)
	go m.statsTracker.start(m, m.activeConnection)
	m.addCleanup(func() error {
		log.Trace().Msg("Cleaning: stopping statistics publisher")
		defer log.Trace().Msg("Cleaning: stopping statistics publisher DONE")
		m.statsTracker.stop()
		return nil
	})

	go m.consumeConnectionStates(m.activeConnection.State())
	go m.checkSessionIP(m.channel, m.connectOptions.ConsumerID, m.connectOptions.SessionID, originalPublicIP)

	return nil
}
   309  
   310  func (m *connectionManager) autoReconnect() (err error) {
   311  	var sessionID session.ID
   312  
   313  	tracer := trace.NewTracer("Consumer whole autoReconnect")
   314  	defer func() {
   315  		traceResult := tracer.Finish(m.eventBus, string(sessionID))
   316  		log.Debug().Msgf("Consumer connection trace: %s", traceResult)
   317  	}()
   318  
   319  	proposal, err := m.connectOptions.ProposalLookup()
   320  	if err != nil {
   321  		return fmt.Errorf("failed to lookup proposal: %w", err)
   322  	}
   323  
   324  	m.connectOptions.Proposal = *proposal
   325  
   326  	sessionID, err = m.initSession(tracer, m.priceFromProposal(m.connectOptions.Proposal))
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	err = m.startConnection(m.currentCtx(), m.activeConnection, m.activeConnection.Reconnect, m.connectOptions, tracer)
   332  	if err != nil {
   333  		return m.handleStartError(sessionID, err)
   334  	}
   335  
   336  	return nil
   337  }
   338  
   339  func (m *connectionManager) priceFromProposal(proposal proposal.PricedServiceProposal) market.Price {
   340  	p := market.Price{
   341  		PricePerHour: proposal.Price.PricePerHour,
   342  		PricePerGiB:  proposal.Price.PricePerGiB,
   343  	}
   344  
   345  	if config.GetBool(config.FlagPaymentsDuringSessionDebug) {
   346  		log.Info().Msg("Payments debug bas been enabled, will use absurd amounts for the proposal price")
   347  		amount := config.GetUInt64(config.FlagPaymentsAmountDuringSessionDebug)
   348  		if amount == 0 {
   349  			amount = 5000000000000000000
   350  		}
   351  
   352  		p = market.Price{
   353  			PricePerHour: new(big.Int).SetUint64(amount),
   354  			PricePerGiB:  new(big.Int).SetUint64(amount),
   355  		}
   356  	}
   357  
   358  	return p
   359  }
   360  
   361  func (m *connectionManager) initSession(tracer *trace.Tracer, prc market.Price) (sessionID session.ID, err error) {
   362  	err = m.createP2PChannel(m.connectOptions, tracer)
   363  	if err != nil {
   364  		return sessionID, fmt.Errorf("could not create p2p channel during connect: %w", err)
   365  	}
   366  
   367  	m.connectOptions.ProviderNATConn = m.channel.ServiceConn()
   368  	m.connectOptions.ChannelConn = m.channel.Conn()
   369  
   370  	paymentSession, err := m.paymentLoop(m.connectOptions, prc)
   371  	if err != nil {
   372  		return sessionID, err
   373  	}
   374  
   375  	sessionDTO, err := m.createP2PSession(m.activeConnection, m.connectOptions, tracer, prc)
   376  	sessionID = session.ID(sessionDTO.GetID())
   377  	if err != nil {
   378  		m.sendSessionStatus(m.channel, m.connectOptions.ConsumerID, sessionID, connectivity.StatusSessionEstablishmentFailed, err)
   379  		return sessionID, err
   380  	}
   381  
   382  	traceStart := tracer.StartStage("Consumer session creation (start)")
   383  	go m.keepAliveLoop(m.channel, sessionID)
   384  	m.setStatus(func(status *connectionstate.Status) {
   385  		status.SessionID = sessionID
   386  	})
   387  	m.publishSessionCreate(sessionID)
   388  	paymentSession.SetSessionID(string(sessionID))
   389  	tracer.EndStage(traceStart)
   390  
   391  	m.connectOptions.SessionID = sessionID
   392  	m.connectOptions.SessionConfig = sessionDTO.GetConfig()
   393  
   394  	return sessionID, nil
   395  }
   396  
   397  func (m *connectionManager) handleStartError(sessionID session.ID, err error) error {
   398  	if errors.Is(err, context.Canceled) {
   399  		return ErrConnectionCancelled
   400  	}
   401  	m.addCleanupAfterDisconnect(func() error {
   402  		return m.sendSessionStatus(m.channel, m.connectOptions.ConsumerID, sessionID, connectivity.StatusConnectionFailed, err)
   403  	})
   404  	m.publishStateEvent(connectionstate.StateConnectionFailed)
   405  
   406  	log.Info().Err(err).Msg("Cancelling connection initiation: ")
   407  	m.Cancel()
   408  	return err
   409  }
   410  
   411  func (m *connectionManager) clearIPCache() {
   412  	if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) {
   413  		return
   414  	}
   415  
   416  	if cr, ok := m.ipResolver.(*ip.CachedResolver); ok {
   417  		cr.ClearCache()
   418  	}
   419  }
   420  
   421  // checkSessionIP checks if IP has changed after connection was established.
   422  func (m *connectionManager) checkSessionIP(channel p2p.Channel, consumerID identity.Identity, sessionID session.ID, originalPublicIP string) {
   423  	if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) {
   424  		return
   425  	}
   426  
   427  	for i := 1; i <= m.config.IPCheck.MaxAttempts; i++ {
   428  		// Skip check if not connected. This may happen when context was canceled via Disconnect.
   429  		if m.Status().State != connectionstate.Connected {
   430  			return
   431  		}
   432  
   433  		newPublicIP := m.getPublicIP()
   434  		// If ip is changed notify peer that connection is successful.
   435  		if originalPublicIP != newPublicIP {
   436  			m.sendSessionStatus(channel, consumerID, sessionID, connectivity.StatusConnectionOk, nil)
   437  			return
   438  		}
   439  
   440  		// Notify peer and quality oracle that ip is not changed after tunnel connection was established.
   441  		if i == m.config.IPCheck.MaxAttempts {
   442  			m.sendSessionStatus(channel, consumerID, sessionID, connectivity.StatusSessionIPNotChanged, nil)
   443  			m.publishStateEvent(connectionstate.StateIPNotChanged)
   444  			return
   445  		}
   446  
   447  		time.Sleep(m.config.IPCheck.SleepDurationAfterCheck)
   448  	}
   449  }
   450  
   451  // sendSessionStatus sends session connectivity status to other peer.
   452  func (m *connectionManager) sendSessionStatus(channel p2p.ChannelSender, consumerID identity.Identity, sessionID session.ID, code connectivity.StatusCode, errDetails error) error {
   453  	var errDetailsMsg string
   454  	if errDetails != nil {
   455  		errDetailsMsg = errDetails.Error()
   456  	}
   457  
   458  	sessionStatus := &pb.SessionStatus{
   459  		ConsumerID: consumerID.Address,
   460  		SessionID:  string(sessionID),
   461  		Code:       uint32(code),
   462  		Message:    errDetailsMsg,
   463  	}
   464  
   465  	log.Debug().Msgf("Sending session status P2P message to %q: %s", p2p.TopicSessionStatus, sessionStatus.String())
   466  
   467  	ctx, cancel := context.WithTimeout(m.currentCtx(), 20*time.Second)
   468  	defer cancel()
   469  	_, err := channel.Send(ctx, p2p.TopicSessionStatus, p2p.ProtoMessage(sessionStatus))
   470  	if err != nil {
   471  		return fmt.Errorf("could not send p2p session status message: %w", err)
   472  	}
   473  
   474  	return nil
   475  }
   476  
   477  func (m *connectionManager) getPublicIP() string {
   478  	currentPublicIP, err := m.ipResolver.GetPublicIP()
   479  	if err != nil {
   480  		log.Error().Err(err).Msg("Could not get current public IP")
   481  		return ""
   482  	}
   483  	return currentPublicIP
   484  }
   485  
   486  func (m *connectionManager) paymentLoop(opts ConnectOptions, price market.Price) (PaymentIssuer, error) {
   487  	payments, err := m.paymentEngineFactory(m.uuid, m.channel, opts.ConsumerID, identity.FromAddress(opts.Proposal.ProviderID), opts.HermesID, opts.Proposal, price)
   488  	if err != nil {
   489  		return nil, err
   490  	}
   491  	m.addCleanup(func() error {
   492  		log.Trace().Msg("Cleaning: payments")
   493  		defer log.Trace().Msg("Cleaning: payments DONE")
   494  		payments.Stop()
   495  		return nil
   496  	})
   497  
   498  	go func() {
   499  		err := payments.Start()
   500  		if err != nil {
   501  			log.Error().Err(err).Msg("Payment error")
   502  
   503  			if config.GetBool(config.FlagKeepConnectedOnFail) {
   504  				m.statusOnHold()
   505  			} else {
   506  				err = m.Disconnect()
   507  				if err != nil {
   508  					log.Error().Err(err).Msg("Could not disconnect gracefully")
   509  				}
   510  			}
   511  		}
   512  	}()
   513  	return payments, nil
   514  }
   515  
   516  func (m *connectionManager) cleanConnection() {
   517  	m.cleanupLock.Lock()
   518  	defer m.cleanupLock.Unlock()
   519  
   520  	for i := len(m.cleanup) - 1; i >= 0; i-- {
   521  		log.Trace().Msgf("Connection cleaning up: (%v/%v)", i+1, len(m.cleanup))
   522  		err := m.cleanup[i]()
   523  		if err != nil {
   524  			log.Warn().Err(err).Msg("Cleanup error")
   525  		}
   526  	}
   527  	m.cleanup = nil
   528  }
   529  
   530  func (m *connectionManager) cleanAfterDisconnect() {
   531  	m.cleanupLock.Lock()
   532  	defer m.cleanupLock.Unlock()
   533  
   534  	for i := len(m.cleanupAfterDisconnect) - 1; i >= 0; i-- {
   535  		log.Trace().Msgf("Connection cleaning up (after disconnect): (%v/%v)", i+1, len(m.cleanupAfterDisconnect))
   536  		err := m.cleanupAfterDisconnect[i]()
   537  		if err != nil {
   538  			log.Warn().Err(err).Msg("Cleanup error")
   539  		}
   540  	}
   541  	m.cleanupAfterDisconnect = nil
   542  }
   543  
   544  func (m *connectionManager) createP2PChannel(opts ConnectOptions, tracer *trace.Tracer) error {
   545  	trace := tracer.StartStage("Consumer P2P channel creation")
   546  	defer tracer.EndStage(trace)
   547  
   548  	contactDef, err := p2p.ParseContact(opts.Proposal.Contacts)
   549  	if err != nil {
   550  		return fmt.Errorf("provider does not support p2p communication: %w", err)
   551  	}
   552  
   553  	timeoutCtx, cancel := context.WithTimeout(m.currentCtx(), p2pDialTimeout)
   554  	defer cancel()
   555  
   556  	// TODO register all handlers before channel read/write loops
   557  	channel, err := m.p2pDialer.Dial(timeoutCtx, opts.ConsumerID, identity.FromAddress(opts.Proposal.ProviderID), opts.Proposal.ServiceType, contactDef, tracer)
   558  	if err != nil {
   559  		return fmt.Errorf("p2p dialer failed: %w", err)
   560  	}
   561  	m.addCleanupAfterDisconnect(func() error {
   562  		log.Trace().Msg("Cleaning: closing P2P communication channel")
   563  		defer log.Trace().Msg("Cleaning: P2P communication channel DONE")
   564  
   565  		return channel.Close()
   566  	})
   567  
   568  	m.channel = channel
   569  	return nil
   570  }
   571  
   572  func (m *connectionManager) addCleanupAfterDisconnect(fn func() error) {
   573  	m.cleanupLock.Lock()
   574  	defer m.cleanupLock.Unlock()
   575  	m.cleanupAfterDisconnect = append(m.cleanupAfterDisconnect, fn)
   576  }
   577  
   578  func (m *connectionManager) addCleanup(fn func() error) {
   579  	m.cleanupLock.Lock()
   580  	defer m.cleanupLock.Unlock()
   581  	m.cleanup = append(m.cleanup, fn)
   582  }
   583  
   584  func (m *connectionManager) createP2PSession(c Connection, opts ConnectOptions, tracer *trace.Tracer, requestedPrice market.Price) (*pb.SessionResponse, error) {
   585  	trace := tracer.StartStage("Consumer session creation")
   586  	defer tracer.EndStage(trace)
   587  
   588  	sessionCreateConfig, err := c.GetConfig()
   589  	if err != nil {
   590  		return nil, fmt.Errorf("could not get session config: %w", err)
   591  	}
   592  
   593  	config, err := json.Marshal(sessionCreateConfig)
   594  	if err != nil {
   595  		return nil, fmt.Errorf("could not marshal session config: %w", err)
   596  	}
   597  
   598  	sessionRequest := &pb.SessionRequest{
   599  		Consumer: &pb.ConsumerInfo{
   600  			Id:             opts.ConsumerID.Address,
   601  			HermesID:       opts.HermesID.Hex(),
   602  			PaymentVersion: "v3",
   603  			Location: &pb.LocationInfo{
   604  				Country: m.Status().ConsumerLocation.Country,
   605  			},
   606  			Pricing: &pb.Pricing{
   607  				PerGib:  requestedPrice.PricePerGiB.Bytes(),
   608  				PerHour: requestedPrice.PricePerHour.Bytes(),
   609  			},
   610  		},
   611  		ProposalID: opts.Proposal.ID,
   612  		Config:     config,
   613  	}
   614  	log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionCreate, sessionRequest.String())
   615  	ctx, cancel := context.WithTimeout(m.currentCtx(), 20*time.Second)
   616  	defer cancel()
   617  	res, err := m.channel.Send(ctx, p2p.TopicSessionCreate, p2p.ProtoMessage(sessionRequest))
   618  	if err != nil {
   619  		return nil, fmt.Errorf("could not send p2p session create request: %w", err)
   620  	}
   621  
   622  	var sessionResponse pb.SessionResponse
   623  	err = res.UnmarshalProto(&sessionResponse)
   624  	if err != nil {
   625  		return nil, fmt.Errorf("could not unmarshal session reply to proto: %w", err)
   626  	}
   627  
   628  	channel := m.channel
   629  	m.acknowledge = func() {
   630  		pc := &pb.SessionInfo{
   631  			ConsumerID: opts.ConsumerID.Address,
   632  			SessionID:  sessionResponse.GetID(),
   633  		}
   634  		log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionAcknowledge, pc.String())
   635  		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
   636  		defer cancel()
   637  		_, err := channel.Send(ctx, p2p.TopicSessionAcknowledge, p2p.ProtoMessage(pc))
   638  		if err != nil {
   639  			log.Warn().Err(err).Msg("Acknowledge failed")
   640  		}
   641  	}
   642  	m.addCleanupAfterDisconnect(func() error {
   643  		log.Trace().Msg("Cleaning: requesting session destroy")
   644  		defer log.Trace().Msg("Cleaning: requesting session destroy DONE")
   645  
   646  		sessionDestroy := &pb.SessionInfo{
   647  			ConsumerID: opts.ConsumerID.Address,
   648  			SessionID:  sessionResponse.GetID(),
   649  		}
   650  
   651  		log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionDestroy, sessionDestroy.String())
   652  		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   653  		defer cancel()
   654  		_, err := m.channel.Send(ctx, p2p.TopicSessionDestroy, p2p.ProtoMessage(sessionDestroy))
   655  		if err != nil {
   656  			return fmt.Errorf("could not send session destroy request: %w", err)
   657  		}
   658  
   659  		return nil
   660  	})
   661  
   662  	return &sessionResponse, nil
   663  }
   664  
   665  func (m *connectionManager) publishSessionCreate(sessionID session.ID) {
   666  	sessionInfo := m.Status()
   667  	// avoid printing IP address in logs
   668  	sessionInfo.ConsumerLocation.IP = ""
   669  
   670  	m.eventBus.Publish(connectionstate.AppTopicConnectionSession, connectionstate.AppEventConnectionSession{
   671  		Status:      connectionstate.SessionCreatedStatus,
   672  		SessionInfo: sessionInfo,
   673  	})
   674  
   675  	m.addCleanup(func() error {
   676  		log.Trace().Msg("Cleaning: publishing session ended status")
   677  		defer log.Trace().Msg("Cleaning: publishing session ended status DONE")
   678  
   679  		sessionInfo := m.Status()
   680  		// avoid printing IP address in logs
   681  		sessionInfo.ConsumerLocation.IP = ""
   682  
   683  		m.eventBus.Publish(connectionstate.AppTopicConnectionSession, connectionstate.AppEventConnectionSession{
   684  			Status:      connectionstate.SessionEndedStatus,
   685  			SessionInfo: sessionInfo,
   686  		})
   687  		return nil
   688  	})
   689  }
   690  
   691  func (m *connectionManager) startConnection(ctx context.Context, conn Connection, start ConnectionStart, connectOptions ConnectOptions, tracer *trace.Tracer) (err error) {
   692  	trace := tracer.StartStage("Consumer start connection")
   693  	defer tracer.EndStage(trace)
   694  
   695  	if err = start(ctx, connectOptions); err != nil {
   696  		return err
   697  	}
   698  	m.addCleanup(func() error {
   699  		log.Trace().Msg("Cleaning: stopping connection")
   700  		defer log.Trace().Msg("Cleaning: stopping connection DONE")
   701  		conn.Stop()
   702  		return nil
   703  	})
   704  
   705  	err = m.setupTrafficBlock(connectOptions.Params.DisableKillSwitch)
   706  	if err != nil {
   707  		return err
   708  	}
   709  
   710  	// Clear IP cache so session IP check can report that IP has really changed.
   711  	m.clearIPCache()
   712  
   713  	return nil
   714  }
   715  
   716  func (m *connectionManager) Status() connectionstate.Status {
   717  	m.statusLock.RLock()
   718  	defer m.statusLock.RUnlock()
   719  
   720  	return m.status
   721  }
   722  
   723  func (m *connectionManager) UUID() string {
   724  	m.statusLock.RLock()
   725  	defer m.statusLock.RUnlock()
   726  
   727  	return m.uuid
   728  }
   729  
   730  func (m *connectionManager) Stats() connectionstate.Statistics {
   731  	return m.statsTracker.stats()
   732  }
   733  
   734  func (m *connectionManager) setStatus(delta func(status *connectionstate.Status)) {
   735  	m.statusLock.Lock()
   736  	stateWas := m.status.State
   737  
   738  	delta(&m.status)
   739  
   740  	state := m.status.State
   741  	m.statusLock.Unlock()
   742  
   743  	if state != stateWas {
   744  		log.Info().Msgf("Connection state: %v -> %v", stateWas, state)
   745  		m.publishStateEvent(state)
   746  	}
   747  }
   748  
   749  func (m *connectionManager) statusConnecting(consumerID identity.Identity, accountantID common.Address, proposal proposal.PricedServiceProposal) {
   750  	m.setStatus(func(status *connectionstate.Status) {
   751  		*status = connectionstate.Status{
   752  			StartedAt:        m.timeGetter(),
   753  			ConsumerID:       consumerID,
   754  			ConsumerLocation: m.locationResolver.GetOrigin(),
   755  			HermesID:         accountantID,
   756  			Proposal:         proposal,
   757  			State:            connectionstate.Connecting,
   758  		}
   759  	})
   760  }
   761  
   762  func (m *connectionManager) statusConnected() {
   763  	m.setStatus(func(status *connectionstate.Status) {
   764  		status.State = connectionstate.Connected
   765  	})
   766  }
   767  
   768  func (m *connectionManager) statusReconnecting() {
   769  	m.setStatus(func(status *connectionstate.Status) {
   770  		status.State = connectionstate.Reconnecting
   771  	})
   772  }
   773  
   774  func (m *connectionManager) statusNotConnected() {
   775  	m.setStatus(func(status *connectionstate.Status) {
   776  		status.State = connectionstate.NotConnected
   777  	})
   778  }
   779  
   780  func (m *connectionManager) statusDisconnecting() {
   781  	m.setStatus(func(status *connectionstate.Status) {
   782  		status.State = connectionstate.Disconnecting
   783  	})
   784  }
   785  
   786  func (m *connectionManager) statusCanceled() {
   787  	m.setStatus(func(status *connectionstate.Status) {
   788  		status.State = connectionstate.Canceled
   789  	})
   790  }
   791  
   792  func (m *connectionManager) statusOnHold() {
   793  	m.setStatus(func(status *connectionstate.Status) {
   794  		status.State = connectionstate.StateOnHold
   795  	})
   796  }
   797  
   798  func (m *connectionManager) Cancel() {
   799  	m.statusCanceled()
   800  	logDisconnectError(m.Disconnect())
   801  }
   802  
   803  func (m *connectionManager) Disconnect() error {
   804  	if m.Status().State == connectionstate.NotConnected {
   805  		return ErrNoConnection
   806  	}
   807  
   808  	m.statusDisconnecting()
   809  	m.disconnect()
   810  
   811  	return nil
   812  }
   813  
   814  func (m *connectionManager) CheckChannel(ctx context.Context) error {
   815  	if err := m.sendKeepAlivePing(ctx, m.channel, m.Status().SessionID); err != nil {
   816  		return fmt.Errorf("keep alive ping failed: %w", err)
   817  	}
   818  	return nil
   819  }
   820  
// disconnect performs the actual teardown of the current connection. Calls are
// serialized by discoLock.
//
// During the teardown it holds cleanupFinishedLock and installs a fresh
// cleanupFinished channel. Deferred calls run in reverse order: first the
// channel is closed, then cleanupFinishedLock and discoLock are released.
// A goroutine that acquires cleanupFinishedLock afterwards and receives from
// cleanupFinished (as Reconnect does) therefore returns immediately.
func (m *connectionManager) disconnect() {
	m.discoLock.Lock()
	defer m.discoLock.Unlock()

	m.cleanupFinishedLock.Lock()
	defer m.cleanupFinishedLock.Unlock()
	m.cleanupFinished = make(chan struct{})
	// The channel value is captured now, so the one created above is closed.
	defer close(m.cleanupFinished)

	// m.cancel is invoked under ctxLock, the same lock currentCtx reads m.ctx
	// under. Presumably it cancels m.ctx, which stops loops selecting on
	// currentCtx().Done() — TODO confirm where cancel is assigned.
	m.ctxLock.Lock()
	m.cancel()
	m.ctxLock.Unlock()

	// NOTE(review): presumably runs the cleanups registered via addCleanup
	// (e.g. the traffic block rule) — verify.
	m.cleanConnection()
	m.statusNotConnected()

	m.cleanAfterDisconnect()
}
   839  
   840  func (m *connectionManager) waitForConnectedState(stateChannel <-chan connectionstate.State) error {
   841  	log.Debug().Msg("waiting for connected state")
   842  	for {
   843  		select {
   844  		case state, more := <-stateChannel:
   845  			if !more {
   846  				return ErrConnectionFailed
   847  			}
   848  
   849  			switch state {
   850  			case connectionstate.Connected:
   851  				log.Debug().Msg("Connected started event received")
   852  				if m.acknowledge != nil {
   853  					go m.acknowledge()
   854  				}
   855  				m.onStateChanged(state)
   856  				return nil
   857  			default:
   858  				m.onStateChanged(state)
   859  			}
   860  		case <-m.currentCtx().Done():
   861  			return m.currentCtx().Err()
   862  		}
   863  	}
   864  }
   865  
   866  func (m *connectionManager) consumeConnectionStates(stateChannel <-chan connectionstate.State) {
   867  	for state := range stateChannel {
   868  		m.onStateChanged(state)
   869  	}
   870  }
   871  
   872  func (m *connectionManager) onStateChanged(state connectionstate.State) {
   873  	log.Debug().Msgf("Connection state received: %s", state)
   874  
   875  	// React just to certain stains from connection. Because disconnect happens in connectionWaiter
   876  	switch state {
   877  	case connectionstate.Connected:
   878  		m.statusConnected()
   879  	case connectionstate.Reconnecting:
   880  		m.statusReconnecting()
   881  	}
   882  }
   883  
   884  func (m *connectionManager) setupTrafficBlock(disableKillSwitch bool) error {
   885  	if disableKillSwitch {
   886  		return nil
   887  	}
   888  
   889  	outboundIP, err := m.ipResolver.GetOutboundIP()
   890  	if err != nil {
   891  		return err
   892  	}
   893  
   894  	removeRule, err := firewall.BlockNonTunnelTraffic(firewall.Session, outboundIP)
   895  	if err != nil {
   896  		return err
   897  	}
   898  	m.addCleanup(func() error {
   899  		log.Trace().Msg("Cleaning: traffic block rule")
   900  		defer log.Trace().Msg("Cleaning: traffic block rule DONE")
   901  
   902  		removeRule()
   903  
   904  		return nil
   905  	})
   906  	return nil
   907  }
   908  
   909  func (m *connectionManager) reconnectOnHold(state connectionstate.AppEventConnectionState) {
   910  	if state.State != connectionstate.StateOnHold || !config.GetBool(config.FlagAutoReconnect) {
   911  		return
   912  	}
   913  
   914  	if m.channel != nil {
   915  		m.channel.Close()
   916  	}
   917  
   918  	m.preReconnect()
   919  	m.clearIPCache()
   920  
   921  	for err := m.autoReconnect(); err != nil; err = m.autoReconnect() {
   922  		select {
   923  		case <-m.currentCtx().Done():
   924  			log.Info().Err(m.currentCtx().Err()).Msg("Stopping reconnect")
   925  			return
   926  		default:
   927  			log.Error().Err(err).Msg("Failed to reconnect active session, will try again")
   928  		}
   929  	}
   930  	m.postReconnect()
   931  }
   932  
   933  func (m *connectionManager) publishStateEvent(state connectionstate.State) {
   934  	sessionInfo := m.Status()
   935  	// avoid printing IP address in logs
   936  	sessionInfo.ConsumerLocation.IP = ""
   937  
   938  	m.eventBus.Publish(connectionstate.AppTopicConnectionState, connectionstate.AppEventConnectionState{
   939  		UUID:        m.uuid,
   940  		State:       state,
   941  		SessionInfo: sessionInfo,
   942  	})
   943  }
   944  
   945  func (m *connectionManager) keepAliveLoop(channel p2p.Channel, sessionID session.ID) {
   946  	// Register handler for handling p2p keep alive pings from provider.
   947  	channel.Handle(p2p.TopicKeepAlive, func(c p2p.Context) error {
   948  		var ping pb.P2PKeepAlivePing
   949  		if err := c.Request().UnmarshalProto(&ping); err != nil {
   950  			return err
   951  		}
   952  
   953  		log.Debug().Msgf("Received p2p keepalive ping with SessionID=%s from %s", ping.SessionID, c.PeerID().ToCommonAddress())
   954  		return c.OK()
   955  	})
   956  
   957  	// Send pings to provider.
   958  	var errCount int
   959  	for {
   960  		select {
   961  		case <-m.currentCtx().Done():
   962  			log.Debug().Msgf("Stopping p2p keepalive: %v", m.currentCtx().Err())
   963  			return
   964  		case <-time.After(m.config.KeepAlive.SendInterval):
   965  			ctx, cancel := context.WithTimeout(context.Background(), m.config.KeepAlive.SendTimeout)
   966  			if err := m.sendKeepAlivePing(ctx, channel, sessionID); err != nil {
   967  				log.Err(err).Msgf("Failed to send p2p keepalive ping. SessionID=%s", sessionID)
   968  				errCount++
   969  				if errCount == m.config.KeepAlive.MaxSendErrCount {
   970  					log.Error().Msgf("Max p2p keepalive err count reached, disconnecting. SessionID=%s", sessionID)
   971  					if config.GetBool(config.FlagKeepConnectedOnFail) {
   972  						m.statusOnHold()
   973  					} else {
   974  						m.Disconnect()
   975  					}
   976  					cancel()
   977  					return
   978  				}
   979  			} else {
   980  				errCount = 0
   981  			}
   982  			cancel()
   983  		}
   984  	}
   985  }
   986  
   987  func (m *connectionManager) sendKeepAlivePing(ctx context.Context, channel p2p.Channel, sessionID session.ID) error {
   988  	msg := &pb.P2PKeepAlivePing{
   989  		SessionID: string(sessionID),
   990  	}
   991  
   992  	start := time.Now()
   993  	_, err := channel.Send(ctx, p2p.TopicKeepAlive, p2p.ProtoMessage(msg))
   994  	if err != nil {
   995  		return err
   996  	}
   997  
   998  	m.eventBus.Publish(quality.AppTopicConsumerPingP2P, quality.PingEvent{
   999  		SessionID: string(sessionID),
  1000  		Duration:  time.Since(start),
  1001  	})
  1002  
  1003  	return nil
  1004  }
  1005  
  1006  func (m *connectionManager) currentCtx() context.Context {
  1007  	m.ctxLock.RLock()
  1008  	defer m.ctxLock.RUnlock()
  1009  
  1010  	return m.ctx
  1011  }
  1012  
// Reconnect tears down the current session (if any) and connects again using
// the options stored in m.connectOptions. Presumably these are saved by the
// previous Connect call — verify.
//
// Errors are logged, not returned. A failed Disconnect, including
// ErrNoConnection when nothing was connected, is logged and the reconnect is
// still attempted.
func (m *connectionManager) Reconnect() {
	err := m.Disconnect()
	if err != nil {
		log.Error().Err(err).Msgf("Failed to disconnect stale session")
	}
	log.Info().Msg("Waiting for previous session to cleanup")

	// disconnect() holds cleanupFinishedLock while it runs and closes
	// cleanupFinished before releasing it. Acquiring the lock therefore waits
	// for any in-flight teardown, and the receive below then completes on the
	// closed channel.
	// NOTE(review): if Disconnect returned early (ErrNoConnection) and
	// cleanupFinished was never assigned, this would receive from a nil channel
	// and block forever — confirm it is initialized in the constructor.
	m.cleanupFinishedLock.Lock()
	defer m.cleanupFinishedLock.Unlock()
	<-m.cleanupFinished
	// NOTE(review): cleanupFinishedLock is still held while Connect runs. If
	// Connect can reach disconnect(), which also takes this lock, this would
	// deadlock — verify.
	err = m.Connect(m.connectOptions.ConsumerID, m.connectOptions.HermesID, m.connectOptions.ProposalLookup, m.connectOptions.Params)
	if err != nil {
		log.Error().Err(err).Msgf("Failed to reconnect")
	}
}
  1028  
  1029  func logDisconnectError(err error) {
  1030  	if err != nil && err != ErrNoConnection {
  1031  		log.Error().Err(err).Msg("Disconnect error")
  1032  	}
  1033  }