github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/core/service/session_manager_test.go (about)

     1  /*
     2   * Copyright (C) 2017 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package service
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"math/big"
    25  	"net"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/mysteriumnetwork/node/core/policy/localcopy"
    32  
    33  	"github.com/ethereum/go-ethereum/common"
    34  	"github.com/stretchr/testify/assert"
    35  
    36  	"github.com/mysteriumnetwork/node/core/service/servicestate"
    37  	"github.com/mysteriumnetwork/node/identity"
    38  	"github.com/mysteriumnetwork/node/market"
    39  	"github.com/mysteriumnetwork/node/mocks"
    40  	"github.com/mysteriumnetwork/node/p2p"
    41  	"github.com/mysteriumnetwork/node/pb"
    42  	sessionEvent "github.com/mysteriumnetwork/node/session/event"
    43  	"github.com/mysteriumnetwork/node/trace"
    44  	"github.com/mysteriumnetwork/node/utils/reftracker"
    45  	"github.com/mysteriumnetwork/payments/crypto"
    46  )
    47  
    48  var (
    49  	currentProposalID = 68
    50  	currentProposal   = market.NewProposal("0x1", "mockservice", market.NewProposalOpts{})
    51  
    52  	mockTrustOracle = httptest.NewServer(
    53  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
    54  	)
    55  	currentService = NewInstance(
    56  		identity.FromAddress(currentProposal.ProviderID),
    57  		currentProposal.ServiceType,
    58  		struct{}{},
    59  		currentProposal,
    60  		servicestate.Running,
    61  		&mockService{},
    62  		localcopy.NewRepository(),
    63  		&mockDiscovery{},
    64  	)
    65  	consumerID = identity.FromAddress("deadbeef")
    66  	hermesID   = common.HexToAddress("0x1")
    67  )
    68  
    69  type mockBalanceTracker struct {
    70  	paymentError      error
    71  	firstPaymentError error
    72  }
    73  
    74  func (m mockBalanceTracker) Start() error {
    75  	return m.paymentError
    76  }
    77  
    78  func (m mockBalanceTracker) Stop() {
    79  }
    80  
    81  func (m mockBalanceTracker) WaitFirstInvoice(time.Duration) error {
    82  	return m.firstPaymentError
    83  }
    84  
    85  type mockP2PChannel struct {
    86  	tracer *trace.Tracer
    87  }
    88  
    89  func (m *mockP2PChannel) Send(_ context.Context, _ string, _ *p2p.Message) (*p2p.Message, error) {
    90  	return nil, nil
    91  }
    92  
    93  func (m *mockP2PChannel) Handle(topic string, handler p2p.HandlerFunc) {
    94  }
    95  
    96  func (m *mockP2PChannel) Tracer() *trace.Tracer {
    97  	return m.tracer
    98  }
    99  
   100  func (m *mockP2PChannel) ServiceConn() *net.UDPConn { return nil }
   101  
   102  func (m *mockP2PChannel) Conn() *net.UDPConn { return nil }
   103  
   104  func (m *mockP2PChannel) Close() error { return nil }
   105  
   106  func (m *mockP2PChannel) ID() string { return fmt.Sprintf("%p", m) }
   107  
   108  func TestManager_Start_StoresSession(t *testing.T) {
   109  	publisher := mocks.NewEventBus()
   110  	sessionStore := NewSessionPool(publisher)
   111  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{}, true)
   112  
   113  	_, err := manager.Start(&pb.SessionRequest{
   114  		Consumer: &pb.ConsumerInfo{
   115  			Id:       consumerID.Address,
   116  			HermesID: hermesID.String(),
   117  			Pricing: &pb.Pricing{
   118  				PerGib:  big.NewInt(1).Bytes(),
   119  				PerHour: big.NewInt(1).Bytes(),
   120  			},
   121  		},
   122  		ProposalID: int64(currentProposalID),
   123  	})
   124  	assert.NoError(t, err)
   125  
   126  	session := sessionStore.GetAll()[0]
   127  	assert.Equal(t, consumerID, session.ConsumerID)
   128  
   129  	assert.Eventually(t, func() bool {
   130  		history := publisher.GetEventHistory()
   131  		if len(history) != 7 {
   132  			return false
   133  		}
   134  
   135  		startEvent := appTopicSession(history, sessionEvent.CreatedStatus)
   136  		assert.Equal(t, sessionEvent.CreatedStatus, startEvent.Status)
   137  		assert.Equal(t, consumerID, startEvent.Session.ConsumerID)
   138  		assert.Equal(t, hermesID, startEvent.Session.HermesID)
   139  		assert.Equal(t, currentProposal, startEvent.Session.Proposal)
   140  
   141  		for _, key := range []string{
   142  			"Provider connect",
   143  			"Provider session create",
   144  			"Session validation",
   145  			"Provider session create (start)",
   146  			"Provider session create (payment)",
   147  			"Provider session create (configure)",
   148  		} {
   149  			e := appTopicTraceEvent(history, key)
   150  			assert.Equal(t, key, e.Key)
   151  		}
   152  
   153  		return true
   154  	}, 2*time.Second, 10*time.Millisecond)
   155  }
   156  
   157  func TestManager_Start_DisconnectsOnPaymentError(t *testing.T) {
   158  	publisher := mocks.NewEventBus()
   159  	sessionStore := NewSessionPool(publisher)
   160  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{
   161  		firstPaymentError: errors.New("sorry, your money ended"),
   162  	}, true)
   163  
   164  	_, err := manager.Start(&pb.SessionRequest{
   165  		Consumer: &pb.ConsumerInfo{
   166  			Id:       consumerID.Address,
   167  			HermesID: hermesID.String(),
   168  			Pricing: &pb.Pricing{
   169  				PerGib:  big.NewInt(1).Bytes(),
   170  				PerHour: big.NewInt(1).Bytes(),
   171  			},
   172  		},
   173  		ProposalID: int64(currentProposalID),
   174  	})
   175  	assert.EqualError(t, err, "first invoice was not paid: sorry, your money ended")
   176  	assert.Eventually(t, func() bool {
   177  		history := publisher.GetEventHistory()
   178  		if len(history) != 7 {
   179  			return false
   180  		}
   181  
   182  		startEvent := appTopicSession(history, sessionEvent.CreatedStatus)
   183  		assert.Equal(t, sessionEvent.CreatedStatus, startEvent.Status)
   184  		assert.Equal(t, consumerID, startEvent.Session.ConsumerID)
   185  		assert.Equal(t, hermesID, startEvent.Session.HermesID)
   186  		assert.Equal(t, currentProposal, startEvent.Session.Proposal)
   187  
   188  		for _, key := range []string{
   189  			"Provider connect",
   190  			"Provider session create",
   191  			"Session validation",
   192  			"Provider session create (start)",
   193  			"Provider session create (payment)",
   194  		} {
   195  			e := appTopicTraceEvent(history, key)
   196  			assert.Equal(t, key, e.Key)
   197  		}
   198  
   199  		closeEvent := appTopicSession(history, sessionEvent.RemovedStatus)
   200  		assert.Equal(t, sessionEvent.RemovedStatus, closeEvent.Status)
   201  		assert.Equal(t, consumerID, closeEvent.Session.ConsumerID)
   202  		assert.Equal(t, hermesID, closeEvent.Session.HermesID)
   203  		assert.Equal(t, currentProposal, closeEvent.Session.Proposal)
   204  
   205  		return true
   206  	}, 2*time.Second, 10*time.Millisecond)
   207  }
   208  
   209  func appTopicSession(history []mocks.EventBusEntry, status sessionEvent.Status) sessionEvent.AppEventSession {
   210  	for _, h := range history {
   211  		if h.Topic == sessionEvent.AppTopicSession {
   212  			e := h.Event.(sessionEvent.AppEventSession)
   213  			if e.Status == status {
   214  				return e
   215  			}
   216  		}
   217  	}
   218  	return sessionEvent.AppEventSession{}
   219  }
   220  
   221  func appTopicTraceEvent(history []mocks.EventBusEntry, key string) trace.Event {
   222  	for _, h := range history {
   223  		if h.Topic == trace.AppTopicTraceEvent {
   224  			e := h.Event.(trace.Event)
   225  			if e.Key == key {
   226  				return e
   227  			}
   228  		}
   229  	}
   230  	return trace.Event{}
   231  }
   232  
   233  func TestManager_Start_Second_Session_Destroy_Stale_Session(t *testing.T) {
   234  	sessionRequest := &pb.SessionRequest{
   235  		Consumer: &pb.ConsumerInfo{
   236  			Id:       consumerID.Address,
   237  			HermesID: hermesID.String(),
   238  			Pricing: &pb.Pricing{
   239  				PerGib:  big.NewInt(1).Bytes(),
   240  				PerHour: big.NewInt(1).Bytes(),
   241  			},
   242  		},
   243  		ProposalID: int64(currentProposalID),
   244  	}
   245  
   246  	publisher := mocks.NewEventBus()
   247  	sessionStore := NewSessionPool(publisher)
   248  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{}, true)
   249  
   250  	_, err := manager.Start(sessionRequest)
   251  	assert.NoError(t, err)
   252  
   253  	sessionOld := sessionStore.GetAll()[0]
   254  	assert.Equal(t, consumerID, sessionOld.ConsumerID)
   255  
   256  	_, err = manager.Start(sessionRequest)
   257  	assert.NoError(t, err)
   258  
   259  	assert.NoError(t, err)
   260  	assert.Eventuallyf(t, func() bool {
   261  		_, found := sessionStore.Find(sessionOld.ID)
   262  		return !found
   263  	}, 2*time.Second, 10*time.Millisecond, "Waiting for session destroy")
   264  }
   265  
   266  func TestManager_AcknowledgeSession_RejectsUnknown(t *testing.T) {
   267  	publisher := mocks.NewEventBus()
   268  	sessionStore := NewSessionPool(publisher)
   269  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{}, true)
   270  
   271  	err := manager.Acknowledge(consumerID, "")
   272  	assert.Exactly(t, err, ErrorSessionNotExists)
   273  }
   274  
   275  func TestManager_AcknowledgeSession_RejectsBadClient(t *testing.T) {
   276  	publisher := mocks.NewEventBus()
   277  	sessionStore := NewSessionPool(mocks.NewEventBus())
   278  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{}, true)
   279  
   280  	session, err := manager.Start(&pb.SessionRequest{
   281  		Consumer: &pb.ConsumerInfo{
   282  			Id:       consumerID.Address,
   283  			HermesID: hermesID.String(),
   284  			Pricing: &pb.Pricing{
   285  				PerGib:  big.NewInt(1).Bytes(),
   286  				PerHour: big.NewInt(1).Bytes(),
   287  			},
   288  		},
   289  		ProposalID: int64(currentProposalID),
   290  	})
   291  	assert.Nil(t, err)
   292  
   293  	err = manager.Acknowledge(identity.FromAddress("some other id"), string(session.ID))
   294  	assert.Exactly(t, ErrorWrongSessionOwner, err)
   295  }
   296  
   297  func TestManager_AcknowledgeSession_PublishesEvent(t *testing.T) {
   298  	publisher := mocks.NewEventBus()
   299  
   300  	sessionStore := NewSessionPool(publisher)
   301  	session, _ := NewSession(
   302  		currentService,
   303  		&pb.SessionRequest{Consumer: &pb.ConsumerInfo{Id: consumerID.Address}},
   304  		trace.NewTracer(""),
   305  	)
   306  	sessionStore.Add(session)
   307  
   308  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{}, true)
   309  
   310  	err := manager.Acknowledge(consumerID, string(session.ID))
   311  	assert.Nil(t, err)
   312  	assert.Eventually(t, func() bool {
   313  		// Check that state event with StateIPNotChanged status was called.
   314  		history := publisher.GetEventHistory()
   315  		for _, v := range history {
   316  			if v.Topic == sessionEvent.AppTopicSession && v.Event.(sessionEvent.AppEventSession).Status == sessionEvent.AcknowledgedStatus {
   317  				return true
   318  			}
   319  		}
   320  		return false
   321  	}, 2*time.Second, 10*time.Millisecond)
   322  }
   323  
   324  func newManager(service *Instance, sessions *SessionPool, publisher publisher, paymentEngine PaymentEngine, isPriceValid bool) *SessionManager {
   325  	ch := &mockP2PChannel{tracer: trace.NewTracer("Provider connect")}
   326  	m := NewSessionManager(
   327  		service,
   328  		sessions,
   329  		func(_, _ identity.Identity, _ int64, _ common.Address, _ string, _ chan crypto.ExchangeMessage, price market.Price) (PaymentEngine, error) {
   330  			return paymentEngine, nil
   331  		},
   332  		publisher,
   333  		ch,
   334  		DefaultConfig(),
   335  		&mockPriceValidator{
   336  			toReturn: isPriceValid,
   337  		},
   338  	)
   339  	reftracker.Singleton().Put("channel:"+ch.ID(), 10*time.Second, func() { ch.Close() })
   340  	return m
   341  }
   342  
   343  func TestManager_Start_RejectsInvalidPricing(t *testing.T) {
   344  	publisher := mocks.NewEventBus()
   345  	sessionStore := NewSessionPool(publisher)
   346  	manager := newManager(currentService, sessionStore, publisher, &mockBalanceTracker{}, false)
   347  
   348  	_, err := manager.Start(&pb.SessionRequest{
   349  		Consumer: &pb.ConsumerInfo{
   350  			Id:       consumerID.Address,
   351  			HermesID: hermesID.String(),
   352  			Pricing: &pb.Pricing{
   353  				PerGib:  big.NewInt(1).Bytes(),
   354  				PerHour: big.NewInt(1).Bytes(),
   355  			},
   356  		},
   357  		ProposalID: int64(currentProposalID),
   358  	})
   359  	assert.Error(t, err)
   360  	assert.Equal(t, "consumer asking for invalid price", err.Error())
   361  }
   362  
   363  type mockPriceValidator struct {
   364  	toReturn bool
   365  }
   366  
   367  func (mpv *mockPriceValidator) IsPriceValid(in market.Price, nodeType, country, ServiceType string) bool {
   368  	return mpv.toReturn
   369  }