github.com/decred/dcrlnd@v0.7.6/routing/mock_test.go (about)

     1  package routing
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
     8  	"github.com/decred/dcrlnd/channeldb"
     9  	"github.com/decred/dcrlnd/htlcswitch"
    10  	"github.com/decred/dcrlnd/lntypes"
    11  	"github.com/decred/dcrlnd/lnwire"
    12  	"github.com/decred/dcrlnd/routing/route"
    13  	"github.com/go-errors/errors"
    14  	"github.com/stretchr/testify/mock"
    15  )
    16  
// mockPaymentAttemptDispatcherOld is a hand-rolled PaymentAttemptDispatcher
// that decides each HTLC's outcome synchronously in SendHTLC (through the
// onPayment callback) and replays it later from GetPaymentResult.
type mockPaymentAttemptDispatcherOld struct {
	// onPayment, when set, computes the outcome of every SendHTLC call from
	// the first hop. When nil, SendHTLC succeeds without recording a result.
	// It is read and written without holding the mutex.
	onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)

	// results maps payment (attempt) IDs to the outcome recorded by
	// SendHTLC. It is created lazily and guarded by the embedded mutex.
	results map[uint64]*htlcswitch.PaymentResult

	sync.Mutex
}

var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcherOld)(nil)
    25  
    26  func (m *mockPaymentAttemptDispatcherOld) SendHTLC(
    27  	firstHop lnwire.ShortChannelID, pid uint64,
    28  	_ *lnwire.UpdateAddHTLC) error {
    29  
    30  	if m.onPayment == nil {
    31  		return nil
    32  	}
    33  
    34  	var result *htlcswitch.PaymentResult
    35  	preimage, err := m.onPayment(firstHop)
    36  	if err != nil {
    37  		rtErr, ok := err.(htlcswitch.ClearTextError)
    38  		if !ok {
    39  			return err
    40  		}
    41  		result = &htlcswitch.PaymentResult{
    42  			Error: rtErr,
    43  		}
    44  	} else {
    45  		result = &htlcswitch.PaymentResult{Preimage: preimage}
    46  	}
    47  
    48  	m.Lock()
    49  	if m.results == nil {
    50  		m.results = make(map[uint64]*htlcswitch.PaymentResult)
    51  	}
    52  
    53  	m.results[pid] = result
    54  	m.Unlock()
    55  
    56  	return nil
    57  }
    58  
    59  func (m *mockPaymentAttemptDispatcherOld) GetPaymentResult(paymentID uint64,
    60  	_ lntypes.Hash, _ htlcswitch.ErrorDecrypter) (
    61  	<-chan *htlcswitch.PaymentResult, error) {
    62  
    63  	c := make(chan *htlcswitch.PaymentResult, 1)
    64  
    65  	m.Lock()
    66  	res, ok := m.results[paymentID]
    67  	m.Unlock()
    68  
    69  	if !ok {
    70  		return nil, htlcswitch.ErrPaymentIDNotFound
    71  	}
    72  	c <- res
    73  
    74  	return c, nil
    75  
    76  }
    77  func (m *mockPaymentAttemptDispatcherOld) CleanStore(
    78  	map[uint64]struct{}) error {
    79  
    80  	return nil
    81  }
    82  
    83  func (m *mockPaymentAttemptDispatcherOld) setPaymentResult(
    84  	f func(firstHop lnwire.ShortChannelID) ([32]byte, error)) {
    85  
    86  	m.onPayment = f
    87  }
    88  
// mockPaymentSessionSourceOld creates mockPaymentSessionOld sessions that
// replay a fixed list of routes.
type mockPaymentSessionSourceOld struct {
	// routes is handed to every session created by NewPaymentSession.
	routes []*route.Route

	// routeRelease, if non-nil, is given to each session created by
	// NewPaymentSession. Those sessions' RequestRoute calls then block until
	// the test receives from this channel.
	routeRelease chan struct{}
}

var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil)
    95  
    96  func (m *mockPaymentSessionSourceOld) NewPaymentSession(
    97  	_ *LightningPayment) (PaymentSession, error) {
    98  
    99  	return &mockPaymentSessionOld{
   100  		routes:  m.routes,
   101  		release: m.routeRelease,
   102  	}, nil
   103  }
   104  
   105  func (m *mockPaymentSessionSourceOld) NewPaymentSessionForRoute(
   106  	preBuiltRoute *route.Route) PaymentSession {
   107  	return nil
   108  }
   109  
   110  func (m *mockPaymentSessionSourceOld) NewPaymentSessionEmpty() PaymentSession {
   111  	return &mockPaymentSessionOld{}
   112  }
   113  
// mockMissionControlOld overrides a few MissionController methods with canned
// behavior. Any other methods are promoted from the embedded MissionControl
// value.
type mockMissionControlOld struct {
	MissionControl
}

var _ MissionController = (*mockMissionControlOld)(nil)
   119  
   120  func (m *mockMissionControlOld) ReportPaymentFail(
   121  	paymentID uint64, rt *route.Route,
   122  	failureSourceIdx *int, failure lnwire.FailureMessage) (
   123  	*channeldb.FailureReason, error) {
   124  
   125  	// Report a permanent failure if this is an error caused
   126  	// by incorrect details.
   127  	if failure.Code() == lnwire.CodeIncorrectOrUnknownPaymentDetails {
   128  		reason := channeldb.FailureReasonPaymentDetails
   129  		return &reason, nil
   130  	}
   131  
   132  	return nil, nil
   133  }
   134  
   135  func (m *mockMissionControlOld) ReportPaymentSuccess(paymentID uint64,
   136  	rt *route.Route) error {
   137  
   138  	return nil
   139  }
   140  
   141  func (m *mockMissionControlOld) GetProbability(fromNode, toNode route.Vertex,
   142  	amt lnwire.MilliAtom) float64 {
   143  
   144  	return 0
   145  }
   146  
// mockPaymentSessionOld is a PaymentSession that returns a fixed sequence of
// routes from RequestRoute.
type mockPaymentSessionOld struct {
	// routes are handed out by RequestRoute in order, one per call.
	routes []*route.Route

	// release is a channel that optionally blocks requesting a route from
	// this mock payment session. When non-nil, RequestRoute sends on it and
	// therefore waits until the test receives. If this value is nil, the
	// route is released automatically.
	release chan struct{}
}

var _ PaymentSession = (*mockPaymentSessionOld)(nil)
   157  
   158  func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliAtom,
   159  	_, height uint32) (*route.Route, error) {
   160  
   161  	if m.release != nil {
   162  		m.release <- struct{}{}
   163  	}
   164  
   165  	if len(m.routes) == 0 {
   166  		return nil, errNoPathFound
   167  	}
   168  
   169  	r := m.routes[0]
   170  	m.routes = m.routes[1:]
   171  
   172  	return r, nil
   173  }
   174  
   175  func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate,
   176  	_ *secp256k1.PublicKey, _ *channeldb.CachedEdgePolicy) bool {
   177  
   178  	return false
   179  }
   180  
   181  func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *secp256k1.PublicKey,
   182  	_ uint64) *channeldb.CachedEdgePolicy {
   183  
   184  	return nil
   185  }
   186  
// mockPayerOld is a PaymentAttemptDispatcher whose results are fed in by the
// test through channels.
type mockPayerOld struct {
	// sendResult supplies the error returned by each SendHTLC call.
	sendResult chan error

	// paymentResult supplies the result delivered by each GetPaymentResult
	// call. Closing it makes GetPaymentResult return a closed channel.
	paymentResult chan *htlcswitch.PaymentResult

	// quit aborts a pending SendHTLC or GetPaymentResult call with a
	// "test quitting" error once it is closed or signalled.
	quit chan struct{}
}

var _ PaymentAttemptDispatcher = (*mockPayerOld)(nil)
   194  
   195  func (m *mockPayerOld) SendHTLC(_ lnwire.ShortChannelID,
   196  	paymentID uint64,
   197  	_ *lnwire.UpdateAddHTLC) error {
   198  
   199  	select {
   200  	case res := <-m.sendResult:
   201  		return res
   202  	case <-m.quit:
   203  		return fmt.Errorf("test quitting")
   204  	}
   205  
   206  }
   207  
   208  func (m *mockPayerOld) GetPaymentResult(paymentID uint64, _ lntypes.Hash,
   209  	_ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {
   210  
   211  	select {
   212  	case res, ok := <-m.paymentResult:
   213  		resChan := make(chan *htlcswitch.PaymentResult, 1)
   214  		if !ok {
   215  			close(resChan)
   216  		} else {
   217  			resChan <- res
   218  		}
   219  
   220  		return resChan, nil
   221  
   222  	case <-m.quit:
   223  		return nil, fmt.Errorf("test quitting")
   224  	}
   225  }
   226  
   227  func (m *mockPayerOld) CleanStore(pids map[uint64]struct{}) error {
   228  	return nil
   229  }
   230  
// initArgs carries the creation info passed to mockControlTowerOld's
// InitPayment on its init channel.
type initArgs struct {
	c *channeldb.PaymentCreationInfo
}

// registerAttemptArgs carries the attempt info passed to RegisterAttempt.
type registerAttemptArgs struct {
	a *channeldb.HTLCAttemptInfo
}

// settleAttemptArgs carries the preimage passed to SettleAttempt.
type settleAttemptArgs struct {
	preimg lntypes.Preimage
}

// failAttemptArgs carries the failure info passed to FailAttempt.
type failAttemptArgs struct {
	reason *channeldb.HTLCFailInfo
}

// failPaymentArgs carries the failure reason passed to Fail.
type failPaymentArgs struct {
	reason channeldb.FailureReason
}

// testPayment is mockControlTowerOld's record of a single payment.
type testPayment struct {
	// info is the creation info supplied to InitPayment.
	info channeldb.PaymentCreationInfo

	// attempts holds every registered HTLC attempt in registration order,
	// together with its settle or failure outcome.
	attempts []channeldb.HTLCAttempt
}
   255  
// mockControlTowerOld is an in-memory ControlTower that tracks payments in
// maps guarded by the embedded mutex. Each optional channel, when non-nil,
// is sent a notification by the corresponding method, so tests can observe
// the router's calls.
type mockControlTowerOld struct {
	// payments holds every initialized payment, keyed by payment hash.
	payments map[lntypes.Hash]*testPayment

	// successful is the set of payments with at least one settled attempt.
	successful map[lntypes.Hash]struct{}

	// failed maps payments marked as failed via Fail to their reason.
	failed map[lntypes.Hash]channeldb.FailureReason

	init            chan initArgs
	registerAttempt chan registerAttemptArgs
	settleAttempt   chan settleAttemptArgs
	failAttempt     chan failAttemptArgs
	failPayment     chan failPaymentArgs
	fetchInFlight   chan struct{}

	sync.Mutex
}

var _ ControlTower = (*mockControlTowerOld)(nil)
   272  
   273  func makeMockControlTower() *mockControlTowerOld {
   274  	return &mockControlTowerOld{
   275  		payments:   make(map[lntypes.Hash]*testPayment),
   276  		successful: make(map[lntypes.Hash]struct{}),
   277  		failed:     make(map[lntypes.Hash]channeldb.FailureReason),
   278  	}
   279  }
   280  
   281  func (m *mockControlTowerOld) InitPayment(phash lntypes.Hash,
   282  	c *channeldb.PaymentCreationInfo) error {
   283  
   284  	if m.init != nil {
   285  		m.init <- initArgs{c}
   286  	}
   287  
   288  	m.Lock()
   289  	defer m.Unlock()
   290  
   291  	// Don't allow re-init a successful payment.
   292  	if _, ok := m.successful[phash]; ok {
   293  		return channeldb.ErrAlreadyPaid
   294  	}
   295  
   296  	_, failed := m.failed[phash]
   297  	_, ok := m.payments[phash]
   298  
   299  	// If the payment is known, only allow re-init if failed.
   300  	if ok && !failed {
   301  		return channeldb.ErrPaymentInFlight
   302  	}
   303  
   304  	delete(m.failed, phash)
   305  	m.payments[phash] = &testPayment{
   306  		info: *c,
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash,
   313  	a *channeldb.HTLCAttemptInfo) error {
   314  
   315  	if m.registerAttempt != nil {
   316  		m.registerAttempt <- registerAttemptArgs{a}
   317  	}
   318  
   319  	m.Lock()
   320  	defer m.Unlock()
   321  
   322  	// Lookup payment.
   323  	p, ok := m.payments[phash]
   324  	if !ok {
   325  		return channeldb.ErrPaymentNotInitiated
   326  	}
   327  
   328  	var inFlight bool
   329  	for _, a := range p.attempts {
   330  		if a.Settle != nil {
   331  			continue
   332  		}
   333  
   334  		if a.Failure != nil {
   335  			continue
   336  		}
   337  
   338  		inFlight = true
   339  	}
   340  
   341  	// Cannot register attempts for successful or failed payments.
   342  	_, settled := m.successful[phash]
   343  	_, failed := m.failed[phash]
   344  
   345  	if settled || failed {
   346  		return channeldb.ErrPaymentTerminal
   347  	}
   348  
   349  	if settled && !inFlight {
   350  		return channeldb.ErrPaymentAlreadySucceeded
   351  	}
   352  
   353  	if failed && !inFlight {
   354  		return channeldb.ErrPaymentAlreadyFailed
   355  	}
   356  
   357  	// Add attempt to payment.
   358  	p.attempts = append(p.attempts, channeldb.HTLCAttempt{
   359  		HTLCAttemptInfo: *a,
   360  	})
   361  	m.payments[phash] = p
   362  
   363  	return nil
   364  }
   365  
   366  func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash,
   367  	pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
   368  	*channeldb.HTLCAttempt, error) {
   369  
   370  	if m.settleAttempt != nil {
   371  		m.settleAttempt <- settleAttemptArgs{settleInfo.Preimage}
   372  	}
   373  
   374  	m.Lock()
   375  	defer m.Unlock()
   376  
   377  	// Only allow setting attempts if the payment is known.
   378  	p, ok := m.payments[phash]
   379  	if !ok {
   380  		return nil, channeldb.ErrPaymentNotInitiated
   381  	}
   382  
   383  	// Find the attempt with this pid, and set the settle info.
   384  	for i, a := range p.attempts {
   385  		if a.AttemptID != pid {
   386  			continue
   387  		}
   388  
   389  		if a.Settle != nil {
   390  			return nil, channeldb.ErrAttemptAlreadySettled
   391  		}
   392  		if a.Failure != nil {
   393  			return nil, channeldb.ErrAttemptAlreadyFailed
   394  		}
   395  
   396  		p.attempts[i].Settle = settleInfo
   397  
   398  		// Mark the payment successful on first settled attempt.
   399  		m.successful[phash] = struct{}{}
   400  		return &channeldb.HTLCAttempt{
   401  			Settle: settleInfo,
   402  		}, nil
   403  	}
   404  
   405  	return nil, fmt.Errorf("pid not found")
   406  }
   407  
   408  func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64,
   409  	failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
   410  
   411  	if m.failAttempt != nil {
   412  		m.failAttempt <- failAttemptArgs{failInfo}
   413  	}
   414  
   415  	m.Lock()
   416  	defer m.Unlock()
   417  
   418  	// Only allow failing attempts if the payment is known.
   419  	p, ok := m.payments[phash]
   420  	if !ok {
   421  		return nil, channeldb.ErrPaymentNotInitiated
   422  	}
   423  
   424  	// Find the attempt with this pid, and set the failure info.
   425  	for i, a := range p.attempts {
   426  		if a.AttemptID != pid {
   427  			continue
   428  		}
   429  
   430  		if a.Settle != nil {
   431  			return nil, channeldb.ErrAttemptAlreadySettled
   432  		}
   433  		if a.Failure != nil {
   434  			return nil, channeldb.ErrAttemptAlreadyFailed
   435  		}
   436  
   437  		p.attempts[i].Failure = failInfo
   438  		return &channeldb.HTLCAttempt{
   439  			Failure: failInfo,
   440  		}, nil
   441  	}
   442  
   443  	return nil, fmt.Errorf("pid not found")
   444  }
   445  
   446  func (m *mockControlTowerOld) Fail(phash lntypes.Hash,
   447  	reason channeldb.FailureReason) error {
   448  
   449  	m.Lock()
   450  	defer m.Unlock()
   451  
   452  	if m.failPayment != nil {
   453  		m.failPayment <- failPaymentArgs{reason}
   454  	}
   455  
   456  	// Payment must be known.
   457  	if _, ok := m.payments[phash]; !ok {
   458  		return channeldb.ErrPaymentNotInitiated
   459  	}
   460  
   461  	m.failed[phash] = reason
   462  
   463  	return nil
   464  }
   465  
   466  func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) (
   467  	*channeldb.MPPayment, error) {
   468  
   469  	m.Lock()
   470  	defer m.Unlock()
   471  
   472  	return m.fetchPayment(phash)
   473  }
   474  
   475  func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) (
   476  	*channeldb.MPPayment, error) {
   477  
   478  	p, ok := m.payments[phash]
   479  	if !ok {
   480  		return nil, channeldb.ErrPaymentNotInitiated
   481  	}
   482  
   483  	mp := &channeldb.MPPayment{
   484  		Info: &p.info,
   485  	}
   486  
   487  	reason, ok := m.failed[phash]
   488  	if ok {
   489  		mp.FailureReason = &reason
   490  	}
   491  
   492  	// Return a copy of the current attempts.
   493  	mp.HTLCs = append(mp.HTLCs, p.attempts...)
   494  	return mp, nil
   495  }
   496  
   497  func (m *mockControlTowerOld) FetchInFlightPayments() (
   498  	[]*channeldb.MPPayment, error) {
   499  
   500  	if m.fetchInFlight != nil {
   501  		m.fetchInFlight <- struct{}{}
   502  	}
   503  
   504  	m.Lock()
   505  	defer m.Unlock()
   506  
   507  	// In flight are all payments not successful or failed.
   508  	var fl []*channeldb.MPPayment
   509  	for hash := range m.payments {
   510  		if _, ok := m.successful[hash]; ok {
   511  			continue
   512  		}
   513  		if _, ok := m.failed[hash]; ok {
   514  			continue
   515  		}
   516  
   517  		mp, err := m.fetchPayment(hash)
   518  		if err != nil {
   519  			return nil, err
   520  		}
   521  
   522  		fl = append(fl, mp)
   523  	}
   524  
   525  	return fl, nil
   526  }
   527  
   528  func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
   529  	*ControlTowerSubscriber, error) {
   530  
   531  	return nil, errors.New("not implemented")
   532  }
   533  
// mockPaymentAttemptDispatcher is a testify-based PaymentAttemptDispatcher.
type mockPaymentAttemptDispatcher struct {
	mock.Mock

	// resultChan is returned by every GetPaymentResult call in place of
	// any return values configured on the mock.
	resultChan chan *htlcswitch.PaymentResult
}

var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
   541  
   542  func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
   543  	pid uint64, htlcAdd *lnwire.UpdateAddHTLC) error {
   544  
   545  	args := m.Called(firstHop, pid, htlcAdd)
   546  	return args.Error(0)
   547  }
   548  
   549  func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64,
   550  	paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
   551  	<-chan *htlcswitch.PaymentResult, error) {
   552  
   553  	m.Called(attemptID, paymentHash, deobfuscator)
   554  
   555  	// Instead of returning the mocked returned values, we need to return
   556  	// the chan resultChan so it can be converted into a read-only chan.
   557  	return m.resultChan, nil
   558  }
   559  
   560  func (m *mockPaymentAttemptDispatcher) CleanStore(
   561  	keepPids map[uint64]struct{}) error {
   562  
   563  	args := m.Called(keepPids)
   564  	return args.Error(0)
   565  }
   566  
// mockPaymentSessionSource is a testify-based PaymentSessionSource.
type mockPaymentSessionSource struct {
	mock.Mock
}

var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
   572  
   573  func (m *mockPaymentSessionSource) NewPaymentSession(
   574  	payment *LightningPayment) (PaymentSession, error) {
   575  
   576  	args := m.Called(payment)
   577  	return args.Get(0).(PaymentSession), args.Error(1)
   578  }
   579  
   580  func (m *mockPaymentSessionSource) NewPaymentSessionForRoute(
   581  	preBuiltRoute *route.Route) PaymentSession {
   582  
   583  	args := m.Called(preBuiltRoute)
   584  	return args.Get(0).(PaymentSession)
   585  }
   586  
   587  func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
   588  	args := m.Called()
   589  	return args.Get(0).(PaymentSession)
   590  }
   591  
// mockMissionControl is a testify-based MissionController.
type mockMissionControl struct {
	mock.Mock
}

var _ MissionController = (*mockMissionControl)(nil)
   597  
   598  func (m *mockMissionControl) ReportPaymentFail(
   599  	paymentID uint64, rt *route.Route,
   600  	failureSourceIdx *int, failure lnwire.FailureMessage) (
   601  	*channeldb.FailureReason, error) {
   602  
   603  	args := m.Called(paymentID, rt, failureSourceIdx, failure)
   604  
   605  	// Type assertion on nil will fail, so we check and return here.
   606  	if args.Get(0) == nil {
   607  		return nil, args.Error(1)
   608  	}
   609  
   610  	return args.Get(0).(*channeldb.FailureReason), args.Error(1)
   611  }
   612  
   613  func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64,
   614  	rt *route.Route) error {
   615  
   616  	args := m.Called(paymentID, rt)
   617  	return args.Error(0)
   618  }
   619  
   620  func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex,
   621  	amt lnwire.MilliAtom) float64 {
   622  
   623  	args := m.Called(fromNode, toNode, amt)
   624  	return args.Get(0).(float64)
   625  }
   626  
// mockPaymentSession is a testify-based PaymentSession.
type mockPaymentSession struct {
	mock.Mock
}

var _ PaymentSession = (*mockPaymentSession)(nil)
   632  
   633  func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliAtom,
   634  	activeShards, height uint32) (*route.Route, error) {
   635  	args := m.Called(maxAmt, feeLimit, activeShards, height)
   636  	return args.Get(0).(*route.Route), args.Error(1)
   637  }
   638  
   639  func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
   640  	pubKey *secp256k1.PublicKey, policy *channeldb.CachedEdgePolicy) bool {
   641  
   642  	args := m.Called(msg, pubKey, policy)
   643  	return args.Bool(0)
   644  }
   645  
   646  func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *secp256k1.PublicKey,
   647  	channelID uint64) *channeldb.CachedEdgePolicy {
   648  
   649  	args := m.Called(pubKey, channelID)
   650  	return args.Get(0).(*channeldb.CachedEdgePolicy)
   651  }
   652  
// mockControlTower is a testify-based ControlTower. The methods that take
// the embedded mutex (RegisterAttempt, SettleAttempt, FailAttempt, Fail and
// FetchPayment) are serialized with respect to each other.
type mockControlTower struct {
	mock.Mock
	sync.Mutex
}

var _ ControlTower = (*mockControlTower)(nil)
   659  
   660  func (m *mockControlTower) InitPayment(phash lntypes.Hash,
   661  	c *channeldb.PaymentCreationInfo) error {
   662  
   663  	args := m.Called(phash, c)
   664  	return args.Error(0)
   665  }
   666  
   667  func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
   668  	a *channeldb.HTLCAttemptInfo) error {
   669  
   670  	m.Lock()
   671  	defer m.Unlock()
   672  
   673  	args := m.Called(phash, a)
   674  	return args.Error(0)
   675  }
   676  
   677  func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
   678  	pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
   679  	*channeldb.HTLCAttempt, error) {
   680  
   681  	m.Lock()
   682  	defer m.Unlock()
   683  
   684  	args := m.Called(phash, pid, settleInfo)
   685  	return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
   686  }
   687  
   688  func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
   689  	failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {
   690  
   691  	m.Lock()
   692  	defer m.Unlock()
   693  
   694  	args := m.Called(phash, pid, failInfo)
   695  	return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
   696  }
   697  
   698  func (m *mockControlTower) Fail(phash lntypes.Hash,
   699  	reason channeldb.FailureReason) error {
   700  
   701  	m.Lock()
   702  	defer m.Unlock()
   703  
   704  	args := m.Called(phash, reason)
   705  	return args.Error(0)
   706  }
   707  
   708  func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
   709  	*channeldb.MPPayment, error) {
   710  
   711  	m.Lock()
   712  	defer m.Unlock()
   713  	args := m.Called(phash)
   714  
   715  	// Type assertion on nil will fail, so we check and return here.
   716  	if args.Get(0) == nil {
   717  		return nil, args.Error(1)
   718  	}
   719  
   720  	// Make a copy of the payment here to avoid data race.
   721  	p := args.Get(0).(*channeldb.MPPayment)
   722  	payment := &channeldb.MPPayment{
   723  		FailureReason: p.FailureReason,
   724  	}
   725  	payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs))
   726  	copy(payment.HTLCs, p.HTLCs)
   727  
   728  	return payment, args.Error(1)
   729  }
   730  
   731  func (m *mockControlTower) FetchInFlightPayments() (
   732  	[]*channeldb.MPPayment, error) {
   733  
   734  	args := m.Called()
   735  	return args.Get(0).([]*channeldb.MPPayment), args.Error(1)
   736  }
   737  
   738  func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
   739  	*ControlTowerSubscriber, error) {
   740  
   741  	args := m.Called(paymentHash)
   742  	return args.Get(0).(*ControlTowerSubscriber), args.Error(1)
   743  }
   744  
// mockLink is a ChannelLink stub whose bandwidth, forwarding eligibility and
// MayAddOutgoingHtlc result are fixed by its fields. Other ChannelLink
// methods are promoted from the embedded interface value; if that is left
// nil, calling them panics.
type mockLink struct {
	htlcswitch.ChannelLink
	bandwidth         lnwire.MilliAtom
	mayAddOutgoingErr error
	ineligible        bool
}
   751  
   752  // Bandwidth returns the bandwidth the mock was configured with.
   753  func (m *mockLink) Bandwidth() lnwire.MilliAtom {
   754  	return m.bandwidth
   755  }
   756  
   757  // EligibleToForward returns the mock's configured eligibility.
   758  func (m *mockLink) EligibleToForward() bool {
   759  	return !m.ineligible
   760  }
   761  
   762  // MayAddOutgoingHtlc returns the error configured in our mock.
   763  func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliAtom) error {
   764  	return m.mayAddOutgoingErr
   765  }