github.com/decred/dcrlnd@v0.7.6/htlcswitch/mock.go (about)

     1  package htlcswitch
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net"
    11  	"os"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    18  	"github.com/decred/dcrd/dcrutil/v4"
    19  	"github.com/decred/dcrd/wire"
    20  	"github.com/decred/dcrlnd/chainntnfs"
    21  	"github.com/decred/dcrlnd/channeldb"
    22  	"github.com/decred/dcrlnd/clock"
    23  	"github.com/decred/dcrlnd/contractcourt"
    24  	"github.com/decred/dcrlnd/htlcswitch/hop"
    25  	"github.com/decred/dcrlnd/invoices"
    26  	"github.com/decred/dcrlnd/lnpeer"
    27  	"github.com/decred/dcrlnd/lntest/mock"
    28  	"github.com/decred/dcrlnd/lntypes"
    29  	"github.com/decred/dcrlnd/lnwallet/chainfee"
    30  	"github.com/decred/dcrlnd/lnwire"
    31  	"github.com/decred/dcrlnd/ticker"
    32  	sphinx "github.com/decred/lightning-onion/v4"
    33  	"github.com/go-errors/errors"
    34  )
    35  
    36  type mockPreimageCache struct {
    37  	sync.Mutex
    38  	preimageMap map[lntypes.Hash]lntypes.Preimage
    39  }
    40  
    41  func newMockPreimageCache() *mockPreimageCache {
    42  	return &mockPreimageCache{
    43  		preimageMap: make(map[lntypes.Hash]lntypes.Preimage),
    44  	}
    45  }
    46  
    47  func (m *mockPreimageCache) LookupPreimage(
    48  	hash lntypes.Hash) (lntypes.Preimage, bool) {
    49  
    50  	m.Lock()
    51  	defer m.Unlock()
    52  
    53  	p, ok := m.preimageMap[hash]
    54  	return p, ok
    55  }
    56  
    57  func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error {
    58  	m.Lock()
    59  	defer m.Unlock()
    60  
    61  	for _, preimage := range preimages {
    62  		m.preimageMap[preimage.Hash()] = preimage
    63  	}
    64  
    65  	return nil
    66  }
    67  
    68  func (m *mockPreimageCache) SubscribeUpdates() *contractcourt.WitnessSubscription {
    69  	return nil
    70  }
    71  
    72  type mockFeeEstimator struct {
    73  	byteFeeIn chan chainfee.AtomPerKByte
    74  	relayFee  chan chainfee.AtomPerKByte
    75  
    76  	quit chan struct{}
    77  }
    78  
    79  func newMockFeeEstimator() *mockFeeEstimator {
    80  	return &mockFeeEstimator{
    81  		byteFeeIn: make(chan chainfee.AtomPerKByte),
    82  		relayFee:  make(chan chainfee.AtomPerKByte),
    83  		quit:      make(chan struct{}),
    84  	}
    85  }
    86  
    87  func (m *mockFeeEstimator) EstimateFeePerKB(
    88  	numBlocks uint32) (chainfee.AtomPerKByte, error) {
    89  
    90  	select {
    91  	case feeRate := <-m.byteFeeIn:
    92  		return feeRate, nil
    93  	case <-m.quit:
    94  		return 0, fmt.Errorf("exiting")
    95  	}
    96  }
    97  
    98  func (m *mockFeeEstimator) RelayFeePerKB() chainfee.AtomPerKByte {
    99  	select {
   100  	case feeRate := <-m.relayFee:
   101  		return feeRate
   102  	case <-m.quit:
   103  		return 0
   104  	}
   105  }
   106  
   107  func (m *mockFeeEstimator) Start() error {
   108  	return nil
   109  }
   110  func (m *mockFeeEstimator) Stop() error {
   111  	close(m.quit)
   112  	return nil
   113  }
   114  
   115  var _ chainfee.Estimator = (*mockFeeEstimator)(nil)
   116  
   117  type mockForwardingLog struct {
   118  	sync.Mutex
   119  
   120  	events map[time.Time]channeldb.ForwardingEvent
   121  }
   122  
   123  func (m *mockForwardingLog) AddForwardingEvents(events []channeldb.ForwardingEvent) error {
   124  	m.Lock()
   125  	defer m.Unlock()
   126  
   127  	for _, event := range events {
   128  		m.events[event.Timestamp] = event
   129  	}
   130  
   131  	return nil
   132  }
   133  
// mockServer is a test peer that owns its own Switch. Messages handed to it
// through SendMessage are queued on messages, passed through the registered
// interceptors by the loop started in Start, and then dispatched by
// readHandler to the link serving the target channel.
type mockServer struct {
	started  int32 // To be used atomically.
	shutdown int32 // To be used atomically.
	wg       sync.WaitGroup
	quit     chan struct{}

	// t is the test used to report errors raised in the message loop.
	t testing.TB

	name     string
	messages chan lnwire.Message

	// id is filled from the SHA-256 of name by newMockServer.
	id         [33]byte
	htlcSwitch *Switch

	registry         *mockInvoiceRegistry
	pCache           *mockPreimageCache
	interceptorFuncs []messageInterceptor
}

// Compile-time check that mockServer satisfies lnpeer.Peer.
var _ lnpeer.Peer = (*mockServer)(nil)
   154  
   155  func initDB() (*channeldb.DB, error) {
   156  	tempPath, err := ioutil.TempDir("", "switchdb")
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	db, err := channeldb.Open(tempPath)
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	return db, err
   167  }
   168  
   169  func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) {
   170  	var err error
   171  
   172  	if db == nil {
   173  		db, err = initDB()
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  	}
   178  
   179  	cfg := Config{
   180  		DB:                   db,
   181  		FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
   182  		FetchClosedChannels:  db.ChannelStateDB().FetchClosedChannels,
   183  		SwitchPackager:       channeldb.NewSwitchPackager(),
   184  		FwdingLog: &mockForwardingLog{
   185  			events: make(map[time.Time]channeldb.ForwardingEvent),
   186  		},
   187  		FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
   188  			return nil, nil
   189  		},
   190  		Notifier: &mock.ChainNotifier{
   191  			SpendChan: make(chan *chainntnfs.SpendDetail),
   192  			EpochChan: make(chan *chainntnfs.BlockEpoch),
   193  			ConfChan:  make(chan *chainntnfs.TxConfirmation),
   194  		},
   195  		FwdEventTicker: ticker.NewForce(DefaultFwdEventInterval),
   196  		LogEventTicker: ticker.NewForce(DefaultLogInterval),
   197  		AckEventTicker: ticker.NewForce(DefaultAckInterval),
   198  		HtlcNotifier:   &mockHTLCNotifier{},
   199  		Clock:          clock.NewDefaultClock(),
   200  		HTLCExpiry:     time.Hour,
   201  		DustThreshold:  DefaultDustThreshold,
   202  	}
   203  
   204  	return New(cfg, startingHeight)
   205  }
   206  
   207  func newMockServer(t testing.TB, name string, startingHeight uint32,
   208  	db *channeldb.DB, defaultDelta uint32) (*mockServer, error) {
   209  
   210  	var id [33]byte
   211  	h := sha256.Sum256([]byte(name))
   212  	copy(id[:], h[:])
   213  
   214  	pCache := newMockPreimageCache()
   215  
   216  	htlcSwitch, err := initSwitchWithDB(startingHeight, db)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  
   221  	registry := newMockRegistry(defaultDelta)
   222  
   223  	return &mockServer{
   224  		t:                t,
   225  		id:               id,
   226  		name:             name,
   227  		messages:         make(chan lnwire.Message, 3000),
   228  		quit:             make(chan struct{}),
   229  		registry:         registry,
   230  		htlcSwitch:       htlcSwitch,
   231  		pCache:           pCache,
   232  		interceptorFuncs: make([]messageInterceptor, 0),
   233  	}, nil
   234  }
   235  
   236  func (s *mockServer) Start() error {
   237  	if !atomic.CompareAndSwapInt32(&s.started, 0, 1) {
   238  		return errors.New("mock server already started")
   239  	}
   240  
   241  	if err := s.htlcSwitch.Start(); err != nil {
   242  		return err
   243  	}
   244  
   245  	s.wg.Add(1)
   246  	go func() {
   247  		defer s.wg.Done()
   248  
   249  		defer func() {
   250  			s.htlcSwitch.Stop()
   251  		}()
   252  
   253  		for {
   254  			select {
   255  			case msg := <-s.messages:
   256  				var shouldSkip bool
   257  
   258  				for _, interceptor := range s.interceptorFuncs {
   259  					skip, err := interceptor(msg)
   260  					if err != nil {
   261  						s.t.Fatalf("%v: error in the "+
   262  							"interceptor: %v", s.name, err)
   263  						return
   264  					}
   265  					shouldSkip = shouldSkip || skip
   266  				}
   267  
   268  				if shouldSkip {
   269  					continue
   270  				}
   271  
   272  				if err := s.readHandler(msg); err != nil {
   273  					s.t.Fatal(err)
   274  					return
   275  				}
   276  			case <-s.quit:
   277  				return
   278  			}
   279  		}
   280  	}()
   281  
   282  	return nil
   283  }
   284  
   285  func (s *mockServer) QuitSignal() <-chan struct{} {
   286  	return s.quit
   287  }
   288  
   289  // mockHopIterator represents the test version of hop iterator which instead
   290  // of encrypting the path in onion blob just stores the path as a list of hops.
   291  type mockHopIterator struct {
   292  	hops []*hop.Payload
   293  }
   294  
   295  func newMockHopIterator(hops ...*hop.Payload) hop.Iterator {
   296  	return &mockHopIterator{hops: hops}
   297  }
   298  
   299  func (r *mockHopIterator) HopPayload() (*hop.Payload, error) {
   300  	h := r.hops[0]
   301  	r.hops = r.hops[1:]
   302  	return h, nil
   303  }
   304  
   305  func (r *mockHopIterator) ExtraOnionBlob() []byte {
   306  	return nil
   307  }
   308  
   309  func (r *mockHopIterator) ExtractErrorEncrypter(
   310  	extracter hop.ErrorEncrypterExtracter) (hop.ErrorEncrypter,
   311  	lnwire.FailCode) {
   312  
   313  	return extracter(nil)
   314  }
   315  
   316  func (r *mockHopIterator) EncodeNextHop(w io.Writer) error {
   317  	var hopLength [4]byte
   318  	binary.BigEndian.PutUint32(hopLength[:], uint32(len(r.hops)))
   319  
   320  	if _, err := w.Write(hopLength[:]); err != nil {
   321  		return err
   322  	}
   323  
   324  	for _, hop := range r.hops {
   325  		fwdInfo := hop.ForwardingInfo()
   326  		if err := encodeFwdInfo(w, &fwdInfo); err != nil {
   327  			return err
   328  		}
   329  	}
   330  
   331  	return nil
   332  }
   333  
   334  func encodeFwdInfo(w io.Writer, f *hop.ForwardingInfo) error {
   335  	if _, err := w.Write([]byte{byte(f.Network)}); err != nil {
   336  		return err
   337  	}
   338  
   339  	if err := binary.Write(w, binary.BigEndian, f.NextHop); err != nil {
   340  		return err
   341  	}
   342  
   343  	if err := binary.Write(w, binary.BigEndian, f.AmountToForward); err != nil {
   344  		return err
   345  	}
   346  
   347  	if err := binary.Write(w, binary.BigEndian, f.OutgoingCTLV); err != nil {
   348  		return err
   349  	}
   350  
   351  	return nil
   352  }
   353  
   354  var _ hop.Iterator = (*mockHopIterator)(nil)
   355  
   356  // mockObfuscator mock implementation of the failure obfuscator which only
   357  // encodes the failure and do not makes any onion obfuscation.
   358  type mockObfuscator struct {
   359  	ogPacket *sphinx.OnionPacket
   360  	failure  lnwire.FailureMessage
   361  }
   362  
   363  // NewMockObfuscator initializes a dummy mockObfuscator used for testing.
   364  func NewMockObfuscator() hop.ErrorEncrypter {
   365  	return &mockObfuscator{}
   366  }
   367  
   368  func (o *mockObfuscator) OnionPacket() *sphinx.OnionPacket {
   369  	return o.ogPacket
   370  }
   371  
   372  func (o *mockObfuscator) Type() hop.EncrypterType {
   373  	return hop.EncrypterTypeMock
   374  }
   375  
   376  func (o *mockObfuscator) Encode(w io.Writer) error {
   377  	return nil
   378  }
   379  
   380  func (o *mockObfuscator) Decode(r io.Reader) error {
   381  	return nil
   382  }
   383  
   384  func (o *mockObfuscator) Reextract(
   385  	extracter hop.ErrorEncrypterExtracter) error {
   386  
   387  	return nil
   388  }
   389  
   390  func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) (
   391  	lnwire.OpaqueReason, error) {
   392  
   393  	o.failure = failure
   394  
   395  	var b bytes.Buffer
   396  	if err := lnwire.EncodeFailure(&b, failure, 0); err != nil {
   397  		return nil, err
   398  	}
   399  	return b.Bytes(), nil
   400  }
   401  
   402  func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason {
   403  	return reason
   404  }
   405  
   406  func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason {
   407  	return reason
   408  }
   409  
   410  // mockDeobfuscator mock implementation of the failure deobfuscator which
   411  // only decodes the failure do not makes any onion obfuscation.
   412  type mockDeobfuscator struct{}
   413  
   414  func newMockDeobfuscator() ErrorDecrypter {
   415  	return &mockDeobfuscator{}
   416  }
   417  
   418  func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) (*ForwardingError, error) {
   419  
   420  	r := bytes.NewReader(reason)
   421  	failure, err := lnwire.DecodeFailure(r, 0)
   422  	if err != nil {
   423  		return nil, err
   424  	}
   425  
   426  	return NewForwardingError(failure, 1), nil
   427  }
   428  
   429  var _ ErrorDecrypter = (*mockDeobfuscator)(nil)
   430  
   431  // mockIteratorDecoder test version of hop iterator decoder which decodes the
   432  // encoded array of hops.
   433  type mockIteratorDecoder struct {
   434  	mu sync.RWMutex
   435  
   436  	responses map[[32]byte][]hop.DecodeHopIteratorResponse
   437  
   438  	decodeFail bool
   439  }
   440  
   441  func newMockIteratorDecoder() *mockIteratorDecoder {
   442  	return &mockIteratorDecoder{
   443  		responses: make(map[[32]byte][]hop.DecodeHopIteratorResponse),
   444  	}
   445  }
   446  
   447  func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
   448  	cltv uint32) (hop.Iterator, lnwire.FailCode) {
   449  
   450  	var b [4]byte
   451  	_, err := r.Read(b[:])
   452  	if err != nil {
   453  		return nil, lnwire.CodeTemporaryChannelFailure
   454  	}
   455  	hopLength := binary.BigEndian.Uint32(b[:])
   456  
   457  	hops := make([]*hop.Payload, hopLength)
   458  	for i := uint32(0); i < hopLength; i++ {
   459  		var f hop.ForwardingInfo
   460  		if err := decodeFwdInfo(r, &f); err != nil {
   461  			return nil, lnwire.CodeTemporaryChannelFailure
   462  		}
   463  
   464  		var nextHopBytes [8]byte
   465  		binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64())
   466  
   467  		hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
   468  			Realm:         [1]byte{}, // hop.BitcoinNetwork
   469  			NextAddress:   nextHopBytes,
   470  			ForwardAmount: uint64(f.AmountToForward),
   471  			OutgoingCltv:  f.OutgoingCTLV,
   472  		})
   473  	}
   474  
   475  	return newMockHopIterator(hops...), lnwire.CodeNone
   476  }
   477  
   478  func (p *mockIteratorDecoder) DecodeHopIterators(id []byte,
   479  	reqs []hop.DecodeHopIteratorRequest) (
   480  	[]hop.DecodeHopIteratorResponse, error) {
   481  
   482  	idHash := sha256.Sum256(id)
   483  
   484  	p.mu.RLock()
   485  	if resps, ok := p.responses[idHash]; ok {
   486  		p.mu.RUnlock()
   487  		return resps, nil
   488  	}
   489  	p.mu.RUnlock()
   490  
   491  	batchSize := len(reqs)
   492  
   493  	resps := make([]hop.DecodeHopIteratorResponse, 0, batchSize)
   494  	for _, req := range reqs {
   495  		iterator, failcode := p.DecodeHopIterator(
   496  			req.OnionReader, req.RHash, req.IncomingCltv,
   497  		)
   498  
   499  		if p.decodeFail {
   500  			failcode = lnwire.CodeTemporaryChannelFailure
   501  		}
   502  
   503  		resp := hop.DecodeHopIteratorResponse{
   504  			HopIterator: iterator,
   505  			FailCode:    failcode,
   506  		}
   507  		resps = append(resps, resp)
   508  	}
   509  
   510  	p.mu.Lock()
   511  	p.responses[idHash] = resps
   512  	p.mu.Unlock()
   513  
   514  	return resps, nil
   515  }
   516  
   517  func decodeFwdInfo(r io.Reader, f *hop.ForwardingInfo) error {
   518  	var net [1]byte
   519  	if _, err := r.Read(net[:]); err != nil {
   520  		return err
   521  	}
   522  	f.Network = hop.Network(net[0])
   523  
   524  	if err := binary.Read(r, binary.BigEndian, &f.NextHop); err != nil {
   525  		return err
   526  	}
   527  
   528  	if err := binary.Read(r, binary.BigEndian, &f.AmountToForward); err != nil {
   529  		return err
   530  	}
   531  
   532  	if err := binary.Read(r, binary.BigEndian, &f.OutgoingCTLV); err != nil {
   533  		return err
   534  	}
   535  
   536  	return nil
   537  }
   538  
   539  // messageInterceptor is function that handles the incoming peer messages and
   540  // may decide should the peer skip the message or not.
   541  type messageInterceptor func(m lnwire.Message) (bool, error)
   542  
   543  // Record is used to set the function which will be triggered when new
   544  // lnwire message was received.
   545  func (s *mockServer) intersect(f messageInterceptor) {
   546  	s.interceptorFuncs = append(s.interceptorFuncs, f)
   547  }
   548  
   549  func (s *mockServer) SendMessage(sync bool, msgs ...lnwire.Message) error {
   550  
   551  	for _, msg := range msgs {
   552  		select {
   553  		case s.messages <- msg:
   554  		case <-s.quit:
   555  			return errors.New("server is stopped")
   556  		}
   557  	}
   558  
   559  	return nil
   560  }
   561  
   562  func (s *mockServer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error {
   563  	panic("not implemented")
   564  }
   565  
   566  func (s *mockServer) readHandler(message lnwire.Message) error {
   567  	var targetChan lnwire.ChannelID
   568  
   569  	switch msg := message.(type) {
   570  	case *lnwire.UpdateAddHTLC:
   571  		targetChan = msg.ChanID
   572  	case *lnwire.UpdateFulfillHTLC:
   573  		targetChan = msg.ChanID
   574  	case *lnwire.UpdateFailHTLC:
   575  		targetChan = msg.ChanID
   576  	case *lnwire.UpdateFailMalformedHTLC:
   577  		targetChan = msg.ChanID
   578  	case *lnwire.RevokeAndAck:
   579  		targetChan = msg.ChanID
   580  	case *lnwire.CommitSig:
   581  		targetChan = msg.ChanID
   582  	case *lnwire.FundingLocked:
   583  		// Ignore
   584  		return nil
   585  	case *lnwire.ChannelReestablish:
   586  		targetChan = msg.ChanID
   587  	case *lnwire.UpdateFee:
   588  		targetChan = msg.ChanID
   589  	default:
   590  		return fmt.Errorf("unknown message type: %T", msg)
   591  	}
   592  
   593  	// Dispatch the commitment update message to the proper channel link
   594  	// dedicated to this channel. If the link is not found, we will discard
   595  	// the message.
   596  	link, err := s.htlcSwitch.GetLink(targetChan)
   597  	if err != nil {
   598  		return nil
   599  	}
   600  
   601  	// Create goroutine for this, in order to be able to properly stop
   602  	// the server when handler stacked (server unavailable)
   603  	link.HandleChannelUpdate(message)
   604  
   605  	return nil
   606  }
   607  
   608  func (s *mockServer) PubKey() [33]byte {
   609  	return s.id
   610  }
   611  
   612  func (s *mockServer) IdentityKey() *secp256k1.PublicKey {
   613  	pubkey, _ := secp256k1.ParsePubKey(s.id[:])
   614  	return pubkey
   615  }
   616  
   617  func (s *mockServer) Address() net.Addr {
   618  	return nil
   619  }
   620  
   621  func (s *mockServer) Inbound() bool {
   622  	return false
   623  }
   624  
   625  func (s *mockServer) AddNewChannel(channel *channeldb.OpenChannel,
   626  	cancel <-chan struct{}) error {
   627  
   628  	return nil
   629  }
   630  
   631  func (s *mockServer) WipeChannel(*wire.OutPoint) {}
   632  
   633  func (s *mockServer) LocalFeatures() *lnwire.FeatureVector {
   634  	return nil
   635  }
   636  
   637  func (s *mockServer) RemoteFeatures() *lnwire.FeatureVector {
   638  	return nil
   639  }
   640  
   641  func (s *mockServer) Stop() error {
   642  	if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) {
   643  		return nil
   644  	}
   645  
   646  	close(s.quit)
   647  	s.wg.Wait()
   648  
   649  	return nil
   650  }
   651  
   652  func (s *mockServer) String() string {
   653  	return s.name
   654  }
   655  
   656  type mockChannelLink struct {
   657  	htlcSwitch *Switch
   658  
   659  	shortChanID lnwire.ShortChannelID
   660  
   661  	chanID lnwire.ChannelID
   662  
   663  	peer lnpeer.Peer
   664  
   665  	mailBox MailBox
   666  
   667  	packets chan *htlcPacket
   668  
   669  	eligible bool
   670  
   671  	htlcID uint64
   672  
   673  	checkHtlcTransitResult *LinkError
   674  
   675  	checkHtlcForwardResult *LinkError
   676  }
   677  
   678  // completeCircuit is a helper method for adding the finalized payment circuit
   679  // to the switch's circuit map. In testing, this should be executed after
   680  // receiving an htlc from the downstream packets channel.
   681  func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error {
   682  	switch htlc := pkt.htlc.(type) {
   683  	case *lnwire.UpdateAddHTLC:
   684  		pkt.outgoingChanID = f.shortChanID
   685  		pkt.outgoingHTLCID = f.htlcID
   686  		htlc.ID = f.htlcID
   687  
   688  		keystone := Keystone{pkt.inKey(), pkt.outKey()}
   689  		err := f.htlcSwitch.circuits.OpenCircuits(keystone)
   690  		if err != nil {
   691  			return err
   692  		}
   693  
   694  		f.htlcID++
   695  
   696  	case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
   697  		err := f.htlcSwitch.teardownCircuit(pkt)
   698  		if err != nil {
   699  			return err
   700  		}
   701  	}
   702  
   703  	f.mailBox.AckPacket(pkt.inKey())
   704  
   705  	return nil
   706  }
   707  
   708  func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error {
   709  	return f.htlcSwitch.circuits.DeleteCircuits(pkt.inKey())
   710  }
   711  
   712  func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
   713  	shortChanID lnwire.ShortChannelID, peer lnpeer.Peer, eligible bool,
   714  ) *mockChannelLink {
   715  
   716  	return &mockChannelLink{
   717  		htlcSwitch:  htlcSwitch,
   718  		chanID:      chanID,
   719  		shortChanID: shortChanID,
   720  		peer:        peer,
   721  		eligible:    eligible,
   722  	}
   723  }
   724  
   725  func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error {
   726  	f.mailBox.AddPacket(pkt)
   727  	return nil
   728  }
   729  
   730  func (f *mockChannelLink) handleLocalAddPacket(pkt *htlcPacket) error {
   731  	_ = f.mailBox.AddPacket(pkt)
   732  	return nil
   733  }
   734  
   735  func (f *mockChannelLink) getDustSum(remote bool) lnwire.MilliAtom {
   736  	return 0
   737  }
   738  
   739  func (f *mockChannelLink) getFeeRate() chainfee.AtomPerKByte {
   740  	return 0
   741  }
   742  
   743  func (f *mockChannelLink) getDustClosure() dustClosure {
   744  	dustLimit := dcrutil.Amount(6030)
   745  	return dustHelper(
   746  		channeldb.SingleFunderTweaklessBit, dustLimit, dustLimit,
   747  	)
   748  }
   749  
   750  func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
   751  }
   752  
   753  func (f *mockChannelLink) UpdateForwardingPolicy(_ ForwardingPolicy) {
   754  }
   755  func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliAtom,
   756  	lnwire.MilliAtom, uint32, uint32, uint32) *LinkError {
   757  
   758  	return f.checkHtlcForwardResult
   759  }
   760  
   761  func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
   762  	amt lnwire.MilliAtom, timeout uint32,
   763  	heightNow uint32) *LinkError {
   764  
   765  	return f.checkHtlcTransitResult
   766  }
   767  
   768  func (f *mockChannelLink) Stats() (uint64, lnwire.MilliAtom, lnwire.MilliAtom) {
   769  	return 0, 0, 0
   770  }
   771  
   772  func (f *mockChannelLink) AttachMailBox(mailBox MailBox) {
   773  	f.mailBox = mailBox
   774  	f.packets = mailBox.PacketOutBox()
   775  	mailBox.SetDustClosure(f.getDustClosure())
   776  }
   777  
   778  func (f *mockChannelLink) Start() error {
   779  	f.mailBox.ResetMessages()
   780  	f.mailBox.ResetPackets()
   781  	return nil
   782  }
   783  
   784  func (f *mockChannelLink) ChanID() lnwire.ChannelID                     { return f.chanID }
   785  func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID           { return f.shortChanID }
   786  func (f *mockChannelLink) Bandwidth() lnwire.MilliAtom                  { return 99999999 }
   787  func (f *mockChannelLink) Peer() lnpeer.Peer                            { return f.peer }
   788  func (f *mockChannelLink) ChannelPoint() *wire.OutPoint                 { return &wire.OutPoint{} }
   789  func (f *mockChannelLink) Stop()                                        {}
   790  func (f *mockChannelLink) EligibleToForward() bool                      { return f.eligible }
   791  func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliAtom) error    { return nil }
   792  func (f *mockChannelLink) ShutdownIfChannelClean() error                { return nil }
   793  func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
   794  func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
   795  	f.eligible = true
   796  	return f.shortChanID, nil
   797  }
   798  
   799  var _ ChannelLink = (*mockChannelLink)(nil)
   800  
   801  func newDB() (*channeldb.DB, func(), error) {
   802  	// First, create a temporary directory to be used for the duration of
   803  	// this test.
   804  	tempDirName, err := ioutil.TempDir("", "channeldb")
   805  	if err != nil {
   806  		return nil, nil, err
   807  	}
   808  
   809  	// Next, create channeldb for the first time.
   810  	cdb, err := channeldb.Open(tempDirName)
   811  	if err != nil {
   812  		os.RemoveAll(tempDirName)
   813  		return nil, nil, err
   814  	}
   815  
   816  	cleanUp := func() {
   817  		cdb.Close()
   818  		os.RemoveAll(tempDirName)
   819  	}
   820  
   821  	return cdb, cleanUp, nil
   822  }
   823  
   824  const testInvoiceCltvExpiry = 6
   825  
   826  type mockInvoiceRegistry struct {
   827  	settleChan chan lntypes.Hash
   828  
   829  	registry *invoices.InvoiceRegistry
   830  
   831  	cleanup func()
   832  }
   833  
   834  type mockChainNotifier struct {
   835  	chainntnfs.ChainNotifier
   836  }
   837  
   838  // RegisterBlockEpochNtfn mocks a successful call to register block
   839  // notifications.
   840  func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
   841  	*chainntnfs.BlockEpochEvent, error) {
   842  
   843  	return &chainntnfs.BlockEpochEvent{
   844  		Cancel: func() {},
   845  	}, nil
   846  }
   847  
   848  func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
   849  	cdb, cleanup, err := newDB()
   850  	if err != nil {
   851  		panic(err)
   852  	}
   853  
   854  	registry := invoices.NewRegistry(
   855  		cdb,
   856  		invoices.NewInvoiceExpiryWatcher(
   857  			clock.NewDefaultClock(), 0, 0, nil,
   858  			&mockChainNotifier{},
   859  		),
   860  		&invoices.RegistryConfig{
   861  			FinalCltvRejectDelta: 5,
   862  		},
   863  	)
   864  	registry.Start()
   865  
   866  	return &mockInvoiceRegistry{
   867  		registry: registry,
   868  		cleanup:  cleanup,
   869  	}
   870  }
   871  
   872  func (i *mockInvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (
   873  	channeldb.Invoice, error) {
   874  
   875  	return i.registry.LookupInvoice(rHash)
   876  }
   877  
   878  func (i *mockInvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
   879  	return i.registry.SettleHodlInvoice(preimage)
   880  }
   881  
   882  func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash,
   883  	amt lnwire.MilliAtom, expiry uint32, currentHeight int32,
   884  	circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
   885  	payload invoices.Payload) (invoices.HtlcResolution, error) {
   886  
   887  	event, err := i.registry.NotifyExitHopHtlc(
   888  		rhash, amt, expiry, currentHeight, circuitKey, hodlChan,
   889  		payload,
   890  	)
   891  	if err != nil {
   892  		return nil, err
   893  	}
   894  	if i.settleChan != nil {
   895  		i.settleChan <- rhash
   896  	}
   897  
   898  	return event, nil
   899  }
   900  
   901  func (i *mockInvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error {
   902  	return i.registry.CancelInvoice(payHash)
   903  }
   904  
   905  func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice,
   906  	paymentHash lntypes.Hash) error {
   907  
   908  	_, err := i.registry.AddInvoice(&invoice, paymentHash)
   909  	return err
   910  }
   911  
   912  func (i *mockInvoiceRegistry) HodlUnsubscribeAll(subscriber chan<- interface{}) {
   913  	i.registry.HodlUnsubscribeAll(subscriber)
   914  }
   915  
   916  var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil)
   917  
   918  type mockCircuitMap struct {
   919  	lookup chan *PaymentCircuit
   920  }
   921  
   922  var _ CircuitMap = (*mockCircuitMap)(nil)
   923  
   924  func (m *mockCircuitMap) OpenCircuits(...Keystone) error {
   925  	return nil
   926  }
   927  
   928  func (m *mockCircuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
   929  	start uint64) error {
   930  	return nil
   931  }
   932  
   933  func (m *mockCircuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
   934  	return nil
   935  }
   936  
   937  func (m *mockCircuitMap) CommitCircuits(
   938  	circuit ...*PaymentCircuit) (*CircuitFwdActions, error) {
   939  
   940  	return nil, nil
   941  }
   942  
   943  func (m *mockCircuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit,
   944  	error) {
   945  	return nil, nil
   946  }
   947  
   948  func (m *mockCircuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit,
   949  	error) {
   950  	return nil, nil
   951  }
   952  
   953  func (m *mockCircuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit {
   954  	return <-m.lookup
   955  }
   956  
   957  func (m *mockCircuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit {
   958  	return nil
   959  }
   960  
   961  func (m *mockCircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
   962  	return nil
   963  }
   964  
   965  func (m *mockCircuitMap) NumPending() int {
   966  	return 0
   967  }
   968  
   969  func (m *mockCircuitMap) NumOpen() int {
   970  	return 0
   971  }
   972  
   973  type mockOnionErrorDecryptor struct {
   974  	sourceIdx int
   975  	message   []byte
   976  	err       error
   977  }
   978  
   979  func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) (
   980  	*sphinx.DecryptedError, error) {
   981  
   982  	return &sphinx.DecryptedError{
   983  		SenderIdx: m.sourceIdx,
   984  		Message:   m.message,
   985  	}, m.err
   986  }
   987  
   988  var _ htlcNotifier = (*mockHTLCNotifier)(nil)
   989  
   990  type mockHTLCNotifier struct{}
   991  
   992  func (h *mockHTLCNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo,
   993  	eventType HtlcEventType) {
   994  }
   995  
   996  func (h *mockHTLCNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo,
   997  	eventType HtlcEventType, linkErr *LinkError, incoming bool) {
   998  }
   999  
  1000  func (h *mockHTLCNotifier) NotifyForwardingFailEvent(key HtlcKey,
  1001  	eventType HtlcEventType) {
  1002  }
  1003  
  1004  func (h *mockHTLCNotifier) NotifySettleEvent(key HtlcKey,
  1005  	preimage lntypes.Preimage, eventType HtlcEventType) {
  1006  }