github.com/decred/dcrlnd@v0.7.6/htlcswitch/test_utils.go (about)

     1  package htlcswitch
     2  
     3  import (
     4  	"bytes"
     5  	crand "crypto/rand"
     6  	"crypto/sha256"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"net"
    12  	"os"
    13  	"runtime"
    14  	"runtime/pprof"
    15  	"sync/atomic"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/decred/dcrd/chaincfg/chainhash"
    20  	"github.com/decred/dcrd/chaincfg/v3"
    21  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    22  	"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
    23  	"github.com/decred/dcrd/dcrutil/v4"
    24  	"github.com/decred/dcrd/wire"
    25  	"github.com/decred/dcrlnd/channeldb"
    26  	"github.com/decred/dcrlnd/contractcourt"
    27  	"github.com/decred/dcrlnd/htlcswitch/hop"
    28  	"github.com/decred/dcrlnd/input"
    29  	"github.com/decred/dcrlnd/keychain"
    30  	"github.com/decred/dcrlnd/kvdb"
    31  	"github.com/decred/dcrlnd/lnpeer"
    32  	"github.com/decred/dcrlnd/lntest/channels"
    33  	"github.com/decred/dcrlnd/lntest/mock"
    34  	"github.com/decred/dcrlnd/lntest/wait"
    35  	"github.com/decred/dcrlnd/lntypes"
    36  	"github.com/decred/dcrlnd/lnwallet"
    37  	"github.com/decred/dcrlnd/lnwallet/chainfee"
    38  	"github.com/decred/dcrlnd/lnwire"
    39  	"github.com/decred/dcrlnd/shachain"
    40  	"github.com/decred/dcrlnd/ticker"
    41  	sphinx "github.com/decred/lightning-onion/v4"
    42  	"github.com/go-errors/errors"
    43  )
    44  
    45  func modNScalar(b []byte) *secp256k1.ModNScalar {
    46  	var m secp256k1.ModNScalar
    47  	m.SetByteSlice(b)
    48  	return &m
    49  }
    50  
    51  var (
    52  	alicePrivKey = []byte("alice priv key")
    53  	bobPrivKey   = []byte("bob priv key")
    54  	carolPrivKey = []byte("carol priv key")
    55  
    56  	rBytes, _ = hex.DecodeString("6372440660162918006277497454296753625158993" +
    57  		"5445068131219452686511677818569431")
    58  	sBytes, _ = hex.DecodeString("1880105606924982582529128710493133386286603" +
    59  		"3135609736119018462340006816851118")
    60  	testSig    = ecdsa.NewSignature(modNScalar(rBytes), modNScalar(sBytes))
    61  	wireSig, _ = lnwire.NewSigFromSignature(testSig)
    62  
    63  	testBatchTimeout = 50 * time.Millisecond
    64  )
    65  
    66  var idSeqNum uint64
    67  
    68  // genID generates a unique tuple to identify a test channel.
    69  func genID() (lnwire.ChannelID, lnwire.ShortChannelID) {
    70  	id := atomic.AddUint64(&idSeqNum, 1)
    71  
    72  	var scratch [8]byte
    73  
    74  	binary.BigEndian.PutUint64(scratch[:], id)
    75  	hash1, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4))
    76  
    77  	chanPoint1 := wire.NewOutPoint(hash1, uint32(id), wire.TxTreeRegular)
    78  	chanID1 := lnwire.NewChanIDFromOutPoint(chanPoint1)
    79  	aliceChanID := lnwire.NewShortChanIDFromInt(id)
    80  
    81  	return chanID1, aliceChanID
    82  }
    83  
    84  // genIDs generates ids for two test channels.
    85  func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID,
    86  	lnwire.ShortChannelID) {
    87  
    88  	chanID1, aliceChanID := genID()
    89  	chanID2, bobChanID := genID()
    90  
    91  	return chanID1, chanID2, aliceChanID, bobChanID
    92  }
    93  
    94  // mockGetChanUpdateMessage helper function which returns topology update of
    95  // the channel
    96  func mockGetChanUpdateMessage(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
    97  	return &lnwire.ChannelUpdate{
    98  		Signature: wireSig,
    99  	}, nil
   100  }
   101  
   102  // generateRandomBytes returns securely generated random bytes.
   103  // It will return an error if the system's secure random
   104  // number generator fails to function correctly, in which
   105  // case the caller should not continue.
   106  func generateRandomBytes(n int) ([]byte, error) {
   107  	b := make([]byte, n)
   108  
   109  	// TODO(roasbeef): should use counter in tests (atomic) rather than
   110  	// this
   111  
   112  	_, err := crand.Read(b)
   113  	// Note that Err == nil only if we read len(b) bytes.
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	return b, nil
   119  }
   120  
   121  type testLightningChannel struct {
   122  	channel *lnwallet.LightningChannel
   123  	restore func() (*lnwallet.LightningChannel, error)
   124  }
   125  
   126  // createTestChannel creates the channel and returns our and remote channels
   127  // representations.
   128  //
   129  // TODO(roasbeef): need to factor out, similar func re-used in many parts of codebase
   130  func createTestChannel(alicePrivKey, bobPrivKey []byte,
   131  	aliceAmount, bobAmount, aliceReserve, bobReserve dcrutil.Amount,
   132  	chanID lnwire.ShortChannelID) (*testLightningChannel,
   133  	*testLightningChannel, func(), error) {
   134  
   135  	netParams := chaincfg.RegNetParams()
   136  
   137  	aliceKeyPriv := secp256k1.PrivKeyFromBytes(alicePrivKey)
   138  	aliceKeyPub := aliceKeyPriv.PubKey()
   139  	bobKeyPriv := secp256k1.PrivKeyFromBytes(bobPrivKey)
   140  	bobKeyPub := bobKeyPriv.PubKey()
   141  
   142  	channelCapacity := aliceAmount + bobAmount
   143  	csvTimeoutAlice := uint32(5)
   144  	csvTimeoutBob := uint32(4)
   145  	isAliceInitiator := true
   146  
   147  	aliceConstraints := &channeldb.ChannelConstraints{
   148  		DustLimit: dcrutil.Amount(6030),
   149  		MaxPendingAmount: lnwire.NewMAtomsFromAtoms(
   150  			channelCapacity),
   151  		ChanReserve:      aliceReserve,
   152  		MinHTLC:          0,
   153  		MaxAcceptedHtlcs: input.MaxHTLCNumber / 2,
   154  		CsvDelay:         uint16(csvTimeoutAlice),
   155  	}
   156  
   157  	bobConstraints := &channeldb.ChannelConstraints{
   158  		DustLimit: dcrutil.Amount(12060),
   159  		MaxPendingAmount: lnwire.NewMAtomsFromAtoms(
   160  			channelCapacity),
   161  		ChanReserve:      bobReserve,
   162  		MinHTLC:          0,
   163  		MaxAcceptedHtlcs: input.MaxHTLCNumber / 2,
   164  		CsvDelay:         uint16(csvTimeoutBob),
   165  	}
   166  
   167  	var hash [chainhash.HashSize]byte
   168  	randomSeed, err := generateRandomBytes(chainhash.HashSize)
   169  	if err != nil {
   170  		return nil, nil, nil, err
   171  	}
   172  	copy(hash[:], randomSeed)
   173  
   174  	prevOut := &wire.OutPoint{
   175  		Hash:  chainhash.Hash(hash),
   176  		Index: 0,
   177  	}
   178  	fundingTxIn := wire.NewTxIn(prevOut, int64(channelCapacity), nil)
   179  
   180  	aliceCfg := channeldb.ChannelConfig{
   181  		ChannelConstraints: *aliceConstraints,
   182  		MultiSigKey: keychain.KeyDescriptor{
   183  			PubKey: aliceKeyPub,
   184  		},
   185  		RevocationBasePoint: keychain.KeyDescriptor{
   186  			PubKey: aliceKeyPub,
   187  		},
   188  		PaymentBasePoint: keychain.KeyDescriptor{
   189  			PubKey: aliceKeyPub,
   190  		},
   191  		DelayBasePoint: keychain.KeyDescriptor{
   192  			PubKey: aliceKeyPub,
   193  		},
   194  		HtlcBasePoint: keychain.KeyDescriptor{
   195  			PubKey: aliceKeyPub,
   196  		},
   197  	}
   198  	bobCfg := channeldb.ChannelConfig{
   199  		ChannelConstraints: *bobConstraints,
   200  		MultiSigKey: keychain.KeyDescriptor{
   201  			PubKey: bobKeyPub,
   202  		},
   203  		RevocationBasePoint: keychain.KeyDescriptor{
   204  			PubKey: bobKeyPub,
   205  		},
   206  		PaymentBasePoint: keychain.KeyDescriptor{
   207  			PubKey: bobKeyPub,
   208  		},
   209  		DelayBasePoint: keychain.KeyDescriptor{
   210  			PubKey: bobKeyPub,
   211  		},
   212  		HtlcBasePoint: keychain.KeyDescriptor{
   213  			PubKey: bobKeyPub,
   214  		},
   215  	}
   216  
   217  	bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
   218  	if err != nil {
   219  		return nil, nil, nil, err
   220  	}
   221  	bobPreimageProducer := shachain.NewRevocationProducer(shachain.ShaHash(*bobRoot))
   222  	bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
   223  	if err != nil {
   224  		return nil, nil, nil, err
   225  	}
   226  	bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
   227  
   228  	aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
   229  	if err != nil {
   230  		return nil, nil, nil, err
   231  	}
   232  	alicePreimageProducer := shachain.NewRevocationProducer(shachain.ShaHash(*aliceRoot))
   233  	aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
   234  	if err != nil {
   235  		return nil, nil, nil, err
   236  	}
   237  	aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
   238  
   239  	aliceCommitTx, bobCommitTx, err := lnwallet.CreateCommitmentTxns(
   240  		aliceAmount, bobAmount, &aliceCfg, &bobCfg, aliceCommitPoint,
   241  		bobCommitPoint, *fundingTxIn, channeldb.SingleFunderTweaklessBit,
   242  		isAliceInitiator, 0, netParams,
   243  	)
   244  	if err != nil {
   245  		return nil, nil, nil, err
   246  	}
   247  
   248  	alicePath, err := ioutil.TempDir("", "alicedb")
   249  	if err != nil {
   250  		return nil, nil, nil, err
   251  	}
   252  	dbAlice, err := channeldb.Open(alicePath)
   253  	if err != nil {
   254  		return nil, nil, nil, err
   255  	}
   256  
   257  	bobPath, err := ioutil.TempDir("", "bobdb")
   258  	if err != nil {
   259  		return nil, nil, nil, err
   260  	}
   261  	dbBob, err := channeldb.Open(bobPath)
   262  	if err != nil {
   263  		return nil, nil, nil, err
   264  	}
   265  
   266  	estimator := chainfee.NewStaticEstimator(1e4, 0)
   267  	feePerKB, err := estimator.EstimateFeePerKB(1)
   268  	if err != nil {
   269  		return nil, nil, nil, err
   270  	}
   271  	commitFee := feePerKB.FeeForSize(input.CommitmentTxSize)
   272  
   273  	const broadcastHeight = 1
   274  	bobAddr := &net.TCPAddr{
   275  		IP:   net.ParseIP("127.0.0.1"),
   276  		Port: 18555,
   277  	}
   278  
   279  	aliceAddr := &net.TCPAddr{
   280  		IP:   net.ParseIP("127.0.0.1"),
   281  		Port: 18556,
   282  	}
   283  
   284  	aliceCommit := channeldb.ChannelCommitment{
   285  		CommitHeight:  0,
   286  		LocalBalance:  lnwire.NewMAtomsFromAtoms(aliceAmount - commitFee),
   287  		RemoteBalance: lnwire.NewMAtomsFromAtoms(bobAmount),
   288  		CommitFee:     commitFee,
   289  		FeePerKB:      dcrutil.Amount(feePerKB),
   290  		CommitTx:      aliceCommitTx,
   291  		CommitSig:     bytes.Repeat([]byte{1}, 71),
   292  	}
   293  	bobCommit := channeldb.ChannelCommitment{
   294  		CommitHeight:  0,
   295  		LocalBalance:  lnwire.NewMAtomsFromAtoms(bobAmount),
   296  		RemoteBalance: lnwire.NewMAtomsFromAtoms(aliceAmount - commitFee),
   297  		CommitFee:     commitFee,
   298  		FeePerKB:      dcrutil.Amount(feePerKB),
   299  		CommitTx:      bobCommitTx,
   300  		CommitSig:     bytes.Repeat([]byte{1}, 71),
   301  	}
   302  
   303  	aliceChannelState := &channeldb.OpenChannel{
   304  		LocalChanCfg:            aliceCfg,
   305  		RemoteChanCfg:           bobCfg,
   306  		IdentityPub:             aliceKeyPub,
   307  		FundingOutpoint:         *prevOut,
   308  		ChanType:                channeldb.SingleFunderTweaklessBit,
   309  		IsInitiator:             isAliceInitiator,
   310  		Capacity:                channelCapacity,
   311  		RemoteCurrentRevocation: bobCommitPoint,
   312  		RevocationProducer:      alicePreimageProducer,
   313  		RevocationStore:         shachain.NewRevocationStore(),
   314  		LocalCommitment:         aliceCommit,
   315  		RemoteCommitment:        aliceCommit,
   316  		ShortChannelID:          chanID,
   317  		Db:                      dbAlice.ChannelStateDB(),
   318  		Packager:                channeldb.NewChannelPackager(chanID),
   319  		FundingTxn:              channels.TestFundingTx,
   320  	}
   321  
   322  	bobChannelState := &channeldb.OpenChannel{
   323  		LocalChanCfg:            bobCfg,
   324  		RemoteChanCfg:           aliceCfg,
   325  		IdentityPub:             bobKeyPub,
   326  		FundingOutpoint:         *prevOut,
   327  		ChanType:                channeldb.SingleFunderTweaklessBit,
   328  		IsInitiator:             !isAliceInitiator,
   329  		Capacity:                channelCapacity,
   330  		RemoteCurrentRevocation: aliceCommitPoint,
   331  		RevocationProducer:      bobPreimageProducer,
   332  		RevocationStore:         shachain.NewRevocationStore(),
   333  		LocalCommitment:         bobCommit,
   334  		RemoteCommitment:        bobCommit,
   335  		ShortChannelID:          chanID,
   336  		Db:                      dbBob.ChannelStateDB(),
   337  		Packager:                channeldb.NewChannelPackager(chanID),
   338  	}
   339  
   340  	if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil {
   341  		return nil, nil, nil, err
   342  	}
   343  
   344  	if err := bobChannelState.SyncPending(aliceAddr, broadcastHeight); err != nil {
   345  		return nil, nil, nil, err
   346  	}
   347  
   348  	cleanUpFunc := func() {
   349  		dbAlice.Close()
   350  		dbBob.Close()
   351  		os.RemoveAll(bobPath)
   352  		os.RemoveAll(alicePath)
   353  	}
   354  
   355  	aliceSigner := &mock.SingleSigner{Privkey: aliceKeyPriv}
   356  	bobSigner := &mock.SingleSigner{Privkey: bobKeyPriv}
   357  
   358  	alicePool := lnwallet.NewSigPool(runtime.NumCPU(), aliceSigner)
   359  	channelAlice, err := lnwallet.NewLightningChannel(
   360  		aliceSigner, aliceChannelState, alicePool, netParams,
   361  	)
   362  	if err != nil {
   363  		return nil, nil, nil, err
   364  	}
   365  	alicePool.Start()
   366  
   367  	bobPool := lnwallet.NewSigPool(runtime.NumCPU(), bobSigner)
   368  	channelBob, err := lnwallet.NewLightningChannel(
   369  		bobSigner, bobChannelState, bobPool, netParams,
   370  	)
   371  	if err != nil {
   372  		return nil, nil, nil, err
   373  	}
   374  	bobPool.Start()
   375  
   376  	// Now that the channel are open, simulate the start of a session by
   377  	// having Alice and Bob extend their revocation windows to each other.
   378  	aliceNextRevoke, err := channelAlice.NextRevocationKey()
   379  	if err != nil {
   380  		return nil, nil, nil, err
   381  	}
   382  	if err := channelBob.InitNextRevocation(aliceNextRevoke); err != nil {
   383  		return nil, nil, nil, err
   384  	}
   385  
   386  	bobNextRevoke, err := channelBob.NextRevocationKey()
   387  	if err != nil {
   388  		return nil, nil, nil, err
   389  	}
   390  	if err := channelAlice.InitNextRevocation(bobNextRevoke); err != nil {
   391  		return nil, nil, nil, err
   392  	}
   393  
   394  	restoreAlice := func() (*lnwallet.LightningChannel, error) {
   395  		aliceStoredChannels, err := dbAlice.ChannelStateDB().
   396  			FetchOpenChannels(aliceKeyPub)
   397  		switch err {
   398  		case nil:
   399  		case kvdb.ErrDatabaseNotOpen:
   400  			dbAlice, err = channeldb.Open(dbAlice.Path())
   401  			if err != nil {
   402  				return nil, errors.Errorf("unable to reopen alice "+
   403  					"db: %v", err)
   404  			}
   405  
   406  			aliceStoredChannels, err = dbAlice.ChannelStateDB().
   407  				FetchOpenChannels(aliceKeyPub)
   408  			if err != nil {
   409  				return nil, errors.Errorf("unable to fetch alice "+
   410  					"channel: %v", err)
   411  			}
   412  		default:
   413  			return nil, errors.Errorf("unable to fetch alice channel: "+
   414  				"%v", err)
   415  		}
   416  
   417  		var aliceStoredChannel *channeldb.OpenChannel
   418  		for _, channel := range aliceStoredChannels {
   419  			if channel.FundingOutpoint.String() == prevOut.String() {
   420  				aliceStoredChannel = channel
   421  				break
   422  			}
   423  		}
   424  
   425  		if aliceStoredChannel == nil {
   426  			return nil, errors.New("unable to find stored alice channel")
   427  		}
   428  
   429  		newAliceChannel, err := lnwallet.NewLightningChannel(
   430  			aliceSigner, aliceStoredChannel, alicePool, netParams,
   431  		)
   432  		if err != nil {
   433  			return nil, errors.Errorf("unable to create new channel: %v",
   434  				err)
   435  		}
   436  
   437  		return newAliceChannel, nil
   438  	}
   439  
   440  	restoreBob := func() (*lnwallet.LightningChannel, error) {
   441  		bobStoredChannels, err := dbBob.ChannelStateDB().
   442  			FetchOpenChannels(bobKeyPub)
   443  		switch err {
   444  		case nil:
   445  		case kvdb.ErrDatabaseNotOpen:
   446  			dbBob, err = channeldb.Open(dbBob.Path())
   447  			if err != nil {
   448  				return nil, errors.Errorf("unable to reopen bob "+
   449  					"db: %v", err)
   450  			}
   451  
   452  			bobStoredChannels, err = dbBob.ChannelStateDB().
   453  				FetchOpenChannels(bobKeyPub)
   454  			if err != nil {
   455  				return nil, errors.Errorf("unable to fetch bob "+
   456  					"channel: %v", err)
   457  			}
   458  		default:
   459  			return nil, errors.Errorf("unable to fetch bob channel: "+
   460  				"%v", err)
   461  		}
   462  
   463  		var bobStoredChannel *channeldb.OpenChannel
   464  		for _, channel := range bobStoredChannels {
   465  			if channel.FundingOutpoint.String() == prevOut.String() {
   466  				bobStoredChannel = channel
   467  				break
   468  			}
   469  		}
   470  
   471  		if bobStoredChannel == nil {
   472  			return nil, errors.New("unable to find stored bob channel")
   473  		}
   474  
   475  		newBobChannel, err := lnwallet.NewLightningChannel(
   476  			bobSigner, bobStoredChannel, bobPool, netParams,
   477  		)
   478  		if err != nil {
   479  			return nil, errors.Errorf("unable to create new channel: %v",
   480  				err)
   481  		}
   482  		return newBobChannel, nil
   483  	}
   484  
   485  	testLightningChannelAlice := &testLightningChannel{
   486  		channel: channelAlice,
   487  		restore: restoreAlice,
   488  	}
   489  
   490  	testLightningChannelBob := &testLightningChannel{
   491  		channel: channelBob,
   492  		restore: restoreBob,
   493  	}
   494  
   495  	return testLightningChannelAlice, testLightningChannelBob, cleanUpFunc,
   496  		nil
   497  }
   498  
   499  // getChanID retrieves the channel point from an lnnwire message.
   500  func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) {
   501  	var chanID lnwire.ChannelID
   502  	switch msg := msg.(type) {
   503  	case *lnwire.UpdateAddHTLC:
   504  		chanID = msg.ChanID
   505  	case *lnwire.UpdateFulfillHTLC:
   506  		chanID = msg.ChanID
   507  	case *lnwire.UpdateFailHTLC:
   508  		chanID = msg.ChanID
   509  	case *lnwire.RevokeAndAck:
   510  		chanID = msg.ChanID
   511  	case *lnwire.CommitSig:
   512  		chanID = msg.ChanID
   513  	case *lnwire.ChannelReestablish:
   514  		chanID = msg.ChanID
   515  	case *lnwire.FundingLocked:
   516  		chanID = msg.ChanID
   517  	case *lnwire.UpdateFee:
   518  		chanID = msg.ChanID
   519  	default:
   520  		return chanID, fmt.Errorf("unknown type: %T", msg)
   521  	}
   522  
   523  	return chanID, nil
   524  }
   525  
   526  // generateHoldPayment generates the htlc add request by given path blob and
   527  // invoice which should be added by destination peer.
   528  func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliAtom,
   529  	timelock uint32, blob [lnwire.OnionPacketSize]byte,
   530  	preimage *lntypes.Preimage, rhash, payAddr [32]byte) (
   531  	*channeldb.Invoice, *lnwire.UpdateAddHTLC, uint64, error) {
   532  
   533  	// Create the db invoice. Normally the payment requests needs to be set,
   534  	// because it is decoded in InvoiceRegistry to obtain the cltv expiry.
   535  	// But because the mock registry used in tests is mocking the decode
   536  	// step and always returning the value of testInvoiceCltvExpiry, we
   537  	// don't need to bother here with creating and signing a payment
   538  	// request.
   539  
   540  	invoice := &channeldb.Invoice{
   541  		CreationDate: time.Now(),
   542  		Terms: channeldb.ContractTerm{
   543  			FinalCltvDelta:  testInvoiceCltvExpiry,
   544  			Value:           invoiceAmt,
   545  			PaymentPreimage: preimage,
   546  			PaymentAddr:     payAddr,
   547  			Features: lnwire.NewFeatureVector(
   548  				nil, lnwire.Features,
   549  			),
   550  		},
   551  		HodlInvoice: preimage == nil,
   552  	}
   553  
   554  	htlc := &lnwire.UpdateAddHTLC{
   555  		PaymentHash: rhash,
   556  		Amount:      htlcAmt,
   557  		Expiry:      timelock,
   558  		OnionBlob:   blob,
   559  	}
   560  
   561  	pid, err := generateRandomBytes(8)
   562  	if err != nil {
   563  		return nil, nil, 0, err
   564  	}
   565  	paymentID := binary.BigEndian.Uint64(pid)
   566  
   567  	return invoice, htlc, paymentID, nil
   568  }
   569  
   570  // generatePayment generates the htlc add request by given path blob and
   571  // invoice which should be added by destination peer.
   572  func generatePayment(invoiceAmt, htlcAmt lnwire.MilliAtom, timelock uint32,
   573  	blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice,
   574  	*lnwire.UpdateAddHTLC, uint64, error) {
   575  
   576  	var preimage lntypes.Preimage
   577  	r, err := generateRandomBytes(lntypes.HashSize)
   578  	if err != nil {
   579  		return nil, nil, 0, err
   580  	}
   581  	copy(preimage[:], r)
   582  
   583  	rhash := preimage.Hash()
   584  
   585  	var payAddr [sha256.Size]byte
   586  	r, err = generateRandomBytes(sha256.Size)
   587  	if err != nil {
   588  		return nil, nil, 0, err
   589  	}
   590  	copy(payAddr[:], r)
   591  
   592  	return generatePaymentWithPreimage(
   593  		invoiceAmt, htlcAmt, timelock, blob, &preimage, rhash, payAddr,
   594  	)
   595  }
   596  
   597  // generateRoute generates the path blob by given array of peers.
   598  func generateRoute(hops ...*hop.Payload) (
   599  	[lnwire.OnionPacketSize]byte, error) {
   600  
   601  	var blob [lnwire.OnionPacketSize]byte
   602  	if len(hops) == 0 {
   603  		return blob, errors.New("empty path")
   604  	}
   605  
   606  	iterator := newMockHopIterator(hops...)
   607  
   608  	w := bytes.NewBuffer(blob[0:0])
   609  	if err := iterator.EncodeNextHop(w); err != nil {
   610  		return blob, err
   611  	}
   612  
   613  	return blob, nil
   614  
   615  }
   616  
   617  // threeHopNetwork is used for managing the created cluster of 3 hops.
   618  type threeHopNetwork struct {
   619  	aliceServer       *mockServer
   620  	aliceChannelLink  *channelLink
   621  	aliceOnionDecoder *mockIteratorDecoder
   622  
   623  	bobServer            *mockServer
   624  	firstBobChannelLink  *channelLink
   625  	secondBobChannelLink *channelLink
   626  	bobOnionDecoder      *mockIteratorDecoder
   627  
   628  	carolServer       *mockServer
   629  	carolChannelLink  *channelLink
   630  	carolOnionDecoder *mockIteratorDecoder
   631  
   632  	hopNetwork
   633  }
   634  
   635  // generateHops creates the per hop payload, the total amount to be sent, and
   636  // also the time lock value needed to route an HTLC with the target amount over
   637  // the specified path.
   638  func generateHops(payAmt lnwire.MilliAtom, startingHeight uint32,
   639  	path ...*channelLink) (lnwire.MilliAtom, uint32, []*hop.Payload) {
   640  
   641  	totalTimelock := startingHeight
   642  	runningAmt := payAmt
   643  
   644  	hops := make([]*hop.Payload, len(path))
   645  	for i := len(path) - 1; i >= 0; i-- {
   646  		// If this is the last hop, then the next hop is the special
   647  		// "exit node". Otherwise, we look to the "prior" hop.
   648  		nextHop := hop.Exit
   649  		if i != len(path)-1 {
   650  			nextHop = path[i+1].channel.ShortChanID()
   651  		}
   652  
   653  		var timeLock uint32
   654  		// If this is the last, hop, then the time lock will be their
   655  		// specified delta policy plus our starting height.
   656  		if i == len(path)-1 {
   657  			totalTimelock += testInvoiceCltvExpiry
   658  			timeLock = totalTimelock
   659  		} else {
   660  			// Otherwise, the outgoing time lock should be the
   661  			// incoming timelock minus their specified delta.
   662  			delta := path[i+1].cfg.FwrdingPolicy.TimeLockDelta
   663  			totalTimelock += delta
   664  			timeLock = totalTimelock - delta
   665  		}
   666  
   667  		// Finally, we'll need to calculate the amount to forward. For
   668  		// the last hop, it's just the payment amount.
   669  		amount := payAmt
   670  		if i != len(path)-1 {
   671  			prevHop := hops[i+1]
   672  			prevAmount := prevHop.ForwardingInfo().AmountToForward
   673  
   674  			fee := ExpectedFee(path[i].cfg.FwrdingPolicy, prevAmount)
   675  			runningAmt += fee
   676  
   677  			// Otherwise, for a node to forward an HTLC, then
   678  			// following inequality most hold true:
   679  			//     * amt_in - fee >= amt_to_forward
   680  			amount = runningAmt - fee
   681  		}
   682  
   683  		var nextHopBytes [8]byte
   684  		binary.BigEndian.PutUint64(nextHopBytes[:], nextHop.ToUint64())
   685  
   686  		hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
   687  			Realm:         [1]byte{2}, // hop.coinNetwork
   688  			NextAddress:   nextHopBytes,
   689  			ForwardAmount: uint64(amount),
   690  			OutgoingCltv:  timeLock,
   691  		})
   692  	}
   693  
   694  	return runningAmt, totalTimelock, hops
   695  }
   696  
   697  type paymentResponse struct {
   698  	rhash lntypes.Hash
   699  	err   chan error
   700  }
   701  
   702  func (r *paymentResponse) Wait(d time.Duration) (lntypes.Hash, error) {
   703  	return r.rhash, waitForPaymentResult(r.err, d)
   704  }
   705  
   706  // waitForPaymentResult waits for either an error to be received on c or a
   707  // timeout.
   708  func waitForPaymentResult(c chan error, d time.Duration) error {
   709  	select {
   710  	case err := <-c:
   711  		close(c)
   712  		return err
   713  	case <-time.After(d):
   714  		return errors.New("htlc was not settled in time")
   715  	}
   716  }
   717  
   718  // waitForPayFuncResult executes the given function and waits for a result with
   719  // a timeout.
   720  func waitForPayFuncResult(payFunc func() error, d time.Duration) error {
   721  	errChan := make(chan error)
   722  	go func() {
   723  		errChan <- payFunc()
   724  	}()
   725  
   726  	return waitForPaymentResult(errChan, d)
   727  }
   728  
   729  // makePayment takes the destination node and amount as input, sends the
   730  // payment and returns the error channel to wait for error to be received and
   731  // invoice in order to check its status after the payment finished.
   732  //
   733  // With this function you can send payments:
   734  // * from Alice to Bob
   735  // * from Alice to Carol through the Bob
   736  // * from Alice to some another peer through the Bob
   737  func makePayment(sendingPeer, receivingPeer lnpeer.Peer,
   738  	firstHop lnwire.ShortChannelID, hops []*hop.Payload,
   739  	invoiceAmt, htlcAmt lnwire.MilliAtom,
   740  	timelock uint32) *paymentResponse {
   741  
   742  	paymentErr := make(chan error, 1)
   743  	var rhash lntypes.Hash
   744  
   745  	invoice, payFunc, err := preparePayment(sendingPeer, receivingPeer,
   746  		firstHop, hops, invoiceAmt, htlcAmt, timelock,
   747  	)
   748  	if err != nil {
   749  		paymentErr <- err
   750  		return &paymentResponse{
   751  			rhash: rhash,
   752  			err:   paymentErr,
   753  		}
   754  	}
   755  
   756  	rhash = invoice.Terms.PaymentPreimage.Hash()
   757  
   758  	// Send payment and expose err channel.
   759  	go func() {
   760  		paymentErr <- payFunc()
   761  	}()
   762  
   763  	return &paymentResponse{
   764  		rhash: rhash,
   765  		err:   paymentErr,
   766  	}
   767  }
   768  
   769  // preparePayment creates an invoice at the receivingPeer and returns a function
   770  // that, when called, launches the payment from the sendingPeer.
   771  func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
   772  	firstHop lnwire.ShortChannelID, hops []*hop.Payload,
   773  	invoiceAmt, htlcAmt lnwire.MilliAtom,
   774  	timelock uint32) (*channeldb.Invoice, func() error, error) {
   775  
   776  	sender := sendingPeer.(*mockServer)
   777  	receiver := receivingPeer.(*mockServer)
   778  
   779  	// Generate route convert it to blob, and return next destination for
   780  	// htlc add request.
   781  	blob, err := generateRoute(hops...)
   782  	if err != nil {
   783  		return nil, nil, err
   784  	}
   785  
   786  	// Generate payment: invoice and htlc.
   787  	invoice, htlc, pid, err := generatePayment(
   788  		invoiceAmt, htlcAmt, timelock, blob,
   789  	)
   790  	if err != nil {
   791  		return nil, nil, err
   792  	}
   793  
   794  	// Check who is last in the route and add invoice to server registry.
   795  	hash := invoice.Terms.PaymentPreimage.Hash()
   796  	if err := receiver.registry.AddInvoice(*invoice, hash); err != nil {
   797  		return nil, nil, err
   798  	}
   799  
   800  	// Send payment and expose err channel.
   801  	return invoice, func() error {
   802  		err := sender.htlcSwitch.SendHTLC(
   803  			firstHop, pid, htlc,
   804  		)
   805  		if err != nil {
   806  			return err
   807  		}
   808  		resultChan, err := sender.htlcSwitch.GetPaymentResult(
   809  			pid, hash, newMockDeobfuscator(),
   810  		)
   811  		if err != nil {
   812  			return err
   813  		}
   814  
   815  		result, ok := <-resultChan
   816  		if !ok {
   817  			return fmt.Errorf("shutting down")
   818  		}
   819  
   820  		if result.Error != nil {
   821  			return result.Error
   822  		}
   823  
   824  		return nil
   825  	}, nil
   826  }
   827  
   828  // start starts the three hop network alice,bob,carol servers.
   829  func (n *threeHopNetwork) start() error {
   830  	if err := n.aliceServer.Start(); err != nil {
   831  		return err
   832  	}
   833  	if err := n.bobServer.Start(); err != nil {
   834  		return err
   835  	}
   836  	if err := n.carolServer.Start(); err != nil {
   837  		return err
   838  	}
   839  
   840  	return waitLinksEligible(map[string]*channelLink{
   841  		"alice":      n.aliceChannelLink,
   842  		"bob first":  n.firstBobChannelLink,
   843  		"bob second": n.secondBobChannelLink,
   844  		"carol":      n.carolChannelLink,
   845  	})
   846  }
   847  
   848  // stop stops nodes and cleanup its databases.
   849  func (n *threeHopNetwork) stop() {
   850  	done := make(chan struct{})
   851  	go func() {
   852  		n.aliceServer.Stop()
   853  		done <- struct{}{}
   854  	}()
   855  
   856  	go func() {
   857  		n.bobServer.Stop()
   858  		done <- struct{}{}
   859  	}()
   860  
   861  	go func() {
   862  		n.carolServer.Stop()
   863  		done <- struct{}{}
   864  	}()
   865  
   866  	for i := 0; i < 3; i++ {
   867  		<-done
   868  	}
   869  }
   870  
   871  type clusterChannels struct {
   872  	aliceToBob *lnwallet.LightningChannel
   873  	bobToAlice *lnwallet.LightningChannel
   874  	bobToCarol *lnwallet.LightningChannel
   875  	carolToBob *lnwallet.LightningChannel
   876  }
   877  
   878  // createClusterChannels creates lightning channels which are needed for
   879  // network cluster to be initialized.
   880  func createClusterChannels(aliceToBob, bobToCarol dcrutil.Amount) (
   881  	*clusterChannels, func(), func() (*clusterChannels, error), error) {
   882  
   883  	_, _, firstChanID, secondChanID := genIDs()
   884  
   885  	// Create lightning channels between Alice<->Bob and Bob<->Carol
   886  	aliceChannel, firstBobChannel, cleanAliceBob, err :=
   887  		createTestChannel(alicePrivKey, bobPrivKey, aliceToBob,
   888  			aliceToBob, 0, 0, firstChanID)
   889  	if err != nil {
   890  		return nil, nil, nil, errors.Errorf("unable to create "+
   891  			"alice<->bob channel: %v", err)
   892  	}
   893  
   894  	secondBobChannel, carolChannel, cleanBobCarol, err :=
   895  		createTestChannel(bobPrivKey, carolPrivKey, bobToCarol,
   896  			bobToCarol, 0, 0, secondChanID)
   897  	if err != nil {
   898  		cleanAliceBob()
   899  		return nil, nil, nil, errors.Errorf("unable to create "+
   900  			"bob<->carol channel: %v", err)
   901  	}
   902  
   903  	cleanUp := func() {
   904  		cleanAliceBob()
   905  		cleanBobCarol()
   906  	}
   907  
   908  	restoreFromDb := func() (*clusterChannels, error) {
   909  
   910  		a2b, err := aliceChannel.restore()
   911  		if err != nil {
   912  			return nil, err
   913  		}
   914  
   915  		b2a, err := firstBobChannel.restore()
   916  		if err != nil {
   917  			return nil, err
   918  		}
   919  
   920  		b2c, err := secondBobChannel.restore()
   921  		if err != nil {
   922  			return nil, err
   923  		}
   924  
   925  		c2b, err := carolChannel.restore()
   926  		if err != nil {
   927  			return nil, err
   928  		}
   929  
   930  		return &clusterChannels{
   931  			aliceToBob: a2b,
   932  			bobToAlice: b2a,
   933  			bobToCarol: b2c,
   934  			carolToBob: c2b,
   935  		}, nil
   936  	}
   937  
   938  	return &clusterChannels{
   939  		aliceToBob: aliceChannel.channel,
   940  		bobToAlice: firstBobChannel.channel,
   941  		bobToCarol: secondBobChannel.channel,
   942  		carolToBob: carolChannel.channel,
   943  	}, cleanUp, restoreFromDb, nil
   944  }
   945  
   946  // newThreeHopNetwork function creates the following topology and returns the
   947  // control object to manage this cluster:
   948  //
   949  //		alice			   bob				   carol
   950  //		server - <-connection-> - server - - <-connection-> - - - server
   951  //		 |		   	  |				   |
   952  //	  alice htlc			bob htlc		    carol htlc
   953  //	    switch			switch	\		    switch
   954  //		|			 |       \			|
   955  //		|			 |        \			|
   956  //
   957  // alice                   first bob    second bob              carol
   958  // channel link	    	  channel link   channel link		channel link
   959  //
   960  // This function takes server options which can be used to apply custom
   961  // settings to alice, bob and carol.
   962  func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
   963  	secondBobChannel, carolChannel *lnwallet.LightningChannel,
   964  	startingHeight uint32, opts ...serverOption) *threeHopNetwork {
   965  
   966  	aliceDb := aliceChannel.State().Db.GetParentDB()
   967  	bobDb := firstBobChannel.State().Db.GetParentDB()
   968  	carolDb := carolChannel.State().Db.GetParentDB()
   969  
   970  	hopNetwork := newHopNetwork()
   971  
   972  	// Create three peers/servers.
   973  	aliceServer, err := newMockServer(
   974  		t, "alice", startingHeight, aliceDb, hopNetwork.defaultDelta,
   975  	)
   976  	if err != nil {
   977  		t.Fatalf("unable to create alice server: %v", err)
   978  	}
   979  	bobServer, err := newMockServer(
   980  		t, "bob", startingHeight, bobDb, hopNetwork.defaultDelta,
   981  	)
   982  	if err != nil {
   983  		t.Fatalf("unable to create bob server: %v", err)
   984  	}
   985  	carolServer, err := newMockServer(
   986  		t, "carol", startingHeight, carolDb, hopNetwork.defaultDelta,
   987  	)
   988  	if err != nil {
   989  		t.Fatalf("unable to create carol server: %v", err)
   990  	}
   991  
   992  	// Apply all additional functional options to the servers before
   993  	// creating any links.
   994  	for _, option := range opts {
   995  		option(aliceServer, bobServer, carolServer)
   996  	}
   997  
   998  	// Create mock decoder instead of sphinx one in order to mock the route
   999  	// which htlc should follow.
  1000  	aliceDecoder := newMockIteratorDecoder()
  1001  	bobDecoder := newMockIteratorDecoder()
  1002  	carolDecoder := newMockIteratorDecoder()
  1003  
  1004  	aliceChannelLink, err := hopNetwork.createChannelLink(aliceServer,
  1005  		bobServer, aliceChannel, aliceDecoder,
  1006  	)
  1007  	if err != nil {
  1008  		t.Fatal(err)
  1009  	}
  1010  
  1011  	firstBobChannelLink, err := hopNetwork.createChannelLink(bobServer,
  1012  		aliceServer, firstBobChannel, bobDecoder)
  1013  	if err != nil {
  1014  		t.Fatal(err)
  1015  	}
  1016  
  1017  	secondBobChannelLink, err := hopNetwork.createChannelLink(bobServer,
  1018  		carolServer, secondBobChannel, bobDecoder)
  1019  	if err != nil {
  1020  		t.Fatal(err)
  1021  	}
  1022  
  1023  	carolChannelLink, err := hopNetwork.createChannelLink(carolServer,
  1024  		bobServer, carolChannel, carolDecoder)
  1025  	if err != nil {
  1026  		t.Fatal(err)
  1027  	}
  1028  
  1029  	return &threeHopNetwork{
  1030  		aliceServer:       aliceServer,
  1031  		aliceChannelLink:  aliceChannelLink.(*channelLink),
  1032  		aliceOnionDecoder: aliceDecoder,
  1033  
  1034  		bobServer:            bobServer,
  1035  		firstBobChannelLink:  firstBobChannelLink.(*channelLink),
  1036  		secondBobChannelLink: secondBobChannelLink.(*channelLink),
  1037  		bobOnionDecoder:      bobDecoder,
  1038  
  1039  		carolServer:       carolServer,
  1040  		carolChannelLink:  carolChannelLink.(*channelLink),
  1041  		carolOnionDecoder: carolDecoder,
  1042  
  1043  		hopNetwork: *hopNetwork,
  1044  	}
  1045  }
  1046  
// serverOption is a function which alters the three servers created for
// a three hop network to allow custom settings on each server. In
// newThreeHopNetwork, options run after the servers are created but before
// any channel link is attached to them.
type serverOption func(aliceServer, bobServer, carolServer *mockServer)
  1050  
  1051  // serverOptionWithHtlcNotifier is a functional option for the creation of
  1052  // three hop network servers which allows setting of htlc notifiers.
  1053  // Note that these notifiers should be started and stopped by the calling
  1054  // function.
  1055  func serverOptionWithHtlcNotifier(alice, bob,
  1056  	carol *HtlcNotifier) serverOption {
  1057  
  1058  	return func(aliceServer, bobServer, carolServer *mockServer) {
  1059  		aliceServer.htlcSwitch.cfg.HtlcNotifier = alice
  1060  		bobServer.htlcSwitch.cfg.HtlcNotifier = bob
  1061  		carolServer.htlcSwitch.cfg.HtlcNotifier = carol
  1062  	}
  1063  }
  1064  
  1065  // serverOptionRejectHtlc is the functional option for setting the reject
  1066  // htlc config option in each server's switch.
  1067  func serverOptionRejectHtlc(alice, bob, carol bool) serverOption {
  1068  	return func(aliceServer, bobServer, carolServer *mockServer) {
  1069  		aliceServer.htlcSwitch.cfg.RejectHTLC = alice
  1070  		bobServer.htlcSwitch.cfg.RejectHTLC = bob
  1071  		carolServer.htlcSwitch.cfg.RejectHTLC = carol
  1072  	}
  1073  }
  1074  
  1075  // createTwoClusterChannels creates lightning channels which are needed for
  1076  // a 2 hop network cluster to be initialized.
  1077  func createTwoClusterChannels(aliceToBob, bobToCarol dcrutil.Amount) (
  1078  	*testLightningChannel, *testLightningChannel,
  1079  	func(), error) {
  1080  
  1081  	_, _, firstChanID, _ := genIDs()
  1082  
  1083  	// Create lightning channels between Alice<->Bob and Bob<->Carol
  1084  	alice, bob, cleanAliceBob, err :=
  1085  		createTestChannel(alicePrivKey, bobPrivKey, aliceToBob,
  1086  			aliceToBob, 0, 0, firstChanID)
  1087  	if err != nil {
  1088  		return nil, nil, nil, errors.Errorf("unable to create "+
  1089  			"alice<->bob channel: %v", err)
  1090  	}
  1091  
  1092  	return alice, bob, cleanAliceBob, nil
  1093  }
  1094  
// hopNetwork is the base struct for two and three hop networks. It holds the
// settings shared by every channel link created through createChannelLink.
type hopNetwork struct {
	// feeEstimator is handed to every link as its FeeEstimator.
	feeEstimator *mockFeeEstimator
	// globalPolicy is the forwarding policy given to every link.
	globalPolicy ForwardingPolicy
	// obfuscator is what every link's ExtractErrorEncrypter returns.
	obfuscator   hop.ErrorEncrypter

	// defaultDelta is passed to newMockServer by the network constructors
	// and is also used as globalPolicy's TimeLockDelta by newHopNetwork.
	defaultDelta uint32
}
  1103  
  1104  func newHopNetwork() *hopNetwork {
  1105  	defaultDelta := uint32(6)
  1106  
  1107  	globalPolicy := ForwardingPolicy{
  1108  		MinHTLCOut:    lnwire.NewMAtomsFromAtoms(5),
  1109  		BaseFee:       lnwire.NewMAtomsFromAtoms(1),
  1110  		TimeLockDelta: defaultDelta,
  1111  	}
  1112  	obfuscator := NewMockObfuscator()
  1113  
  1114  	return &hopNetwork{
  1115  		feeEstimator: newMockFeeEstimator(),
  1116  		globalPolicy: globalPolicy,
  1117  		obfuscator:   obfuscator,
  1118  		defaultDelta: defaultDelta,
  1119  	}
  1120  }
  1121  
  1122  func (h *hopNetwork) createChannelLink(server, peer *mockServer,
  1123  	channel *lnwallet.LightningChannel,
  1124  	decoder *mockIteratorDecoder) (ChannelLink, error) {
  1125  
  1126  	const (
  1127  		fwdPkgTimeout       = 15 * time.Second
  1128  		minFeeUpdateTimeout = 30 * time.Minute
  1129  		maxFeeUpdateTimeout = 40 * time.Minute
  1130  	)
  1131  
  1132  	link := NewChannelLink(
  1133  		ChannelLinkConfig{
  1134  			Switch:             server.htlcSwitch,
  1135  			BestHeight:         server.htlcSwitch.BestHeight,
  1136  			FwrdingPolicy:      h.globalPolicy,
  1137  			Peer:               peer,
  1138  			Circuits:           server.htlcSwitch.CircuitModifier(),
  1139  			ForwardPackets:     server.htlcSwitch.ForwardPackets,
  1140  			DecodeHopIterators: decoder.DecodeHopIterators,
  1141  			ExtractErrorEncrypter: func(*secp256k1.PublicKey) (
  1142  				hop.ErrorEncrypter, lnwire.FailCode) {
  1143  				return h.obfuscator, lnwire.CodeNone
  1144  			},
  1145  			FetchLastChannelUpdate: mockGetChanUpdateMessage,
  1146  			Registry:               server.registry,
  1147  			FeeEstimator:           h.feeEstimator,
  1148  			PreimageCache:          server.pCache,
  1149  			UpdateContractSignals: func(*contractcourt.ContractSignals) error {
  1150  				return nil
  1151  			},
  1152  			ChainEvents:             &contractcourt.ChainEventSubscription{},
  1153  			SyncStates:              true,
  1154  			BatchSize:               10,
  1155  			BatchTicker:             ticker.NewForce(testBatchTimeout),
  1156  			FwdPkgGCTicker:          ticker.NewForce(fwdPkgTimeout),
  1157  			PendingCommitTicker:     ticker.NewForce(2 * time.Minute),
  1158  			MinFeeUpdateTimeout:     minFeeUpdateTimeout,
  1159  			MaxFeeUpdateTimeout:     maxFeeUpdateTimeout,
  1160  			OnChannelFailure:        func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {},
  1161  			OutgoingCltvRejectDelta: 3,
  1162  			MaxOutgoingCltvExpiry:   DefaultMaxOutgoingCltvExpiry,
  1163  			MaxFeeAllocation:        DefaultMaxLinkFeeAllocation,
  1164  			MaxAnchorsCommitFeeRate: chainfee.AtomPerKByte(10 * 1000),
  1165  			NotifyActiveLink:        func(wire.OutPoint) {},
  1166  			NotifyActiveChannel:     func(wire.OutPoint) {},
  1167  			NotifyInactiveChannel:   func(wire.OutPoint) {},
  1168  			HtlcNotifier:            server.htlcSwitch.cfg.HtlcNotifier,
  1169  
  1170  			ResetChanReestablishWaitTime: func(chanID lnwire.ShortChannelID) error { return nil },
  1171  			AddToChanReestablishWaitTime: func(chanID lnwire.ShortChannelID, waitTime time.Duration) error { return nil },
  1172  		},
  1173  		channel,
  1174  	)
  1175  	if err := server.htlcSwitch.AddLink(link); err != nil {
  1176  		return nil, fmt.Errorf("unable to add channel link: %v", err)
  1177  	}
  1178  
  1179  	go func() {
  1180  		for {
  1181  			select {
  1182  			case <-link.(*channelLink).htlcUpdates:
  1183  			case <-link.(*channelLink).quit:
  1184  				return
  1185  			}
  1186  		}
  1187  	}()
  1188  
  1189  	return link, nil
  1190  }
  1191  
// twoHopNetwork is used for managing the created cluster of 2 hops: Alice and
// Bob, each with a mock server and a single channel link towards the other.
type twoHopNetwork struct {
	// hopNetwork carries the link settings shared by both links.
	hopNetwork

	// aliceServer and aliceChannelLink are Alice's node and her link to
	// Bob.
	aliceServer      *mockServer
	aliceChannelLink *channelLink

	// bobServer and bobChannelLink are Bob's node and his link to Alice.
	bobServer      *mockServer
	bobChannelLink *channelLink
}
  1202  
  1203  // newTwoHopNetwork function creates the following topology and returns the
  1204  // control object to manage this cluster:
  1205  //
  1206  //		alice			   bob
  1207  //		server - <-connection-> - server
  1208  //		 |		   	    |
  1209  //	  alice htlc		  	 bob htlc
  1210  //	    switch			 switch
  1211  //		|			    |
  1212  //		|			    |
  1213  //
  1214  // alice                           bob
  1215  // channel link	    	       channel link
  1216  func newTwoHopNetwork(t testing.TB,
  1217  	aliceChannel, bobChannel *lnwallet.LightningChannel,
  1218  	startingHeight uint32) *twoHopNetwork {
  1219  
  1220  	aliceDb := aliceChannel.State().Db.GetParentDB()
  1221  	bobDb := bobChannel.State().Db.GetParentDB()
  1222  
  1223  	hopNetwork := newHopNetwork()
  1224  
  1225  	// Create two peers/servers.
  1226  	aliceServer, err := newMockServer(
  1227  		t, "alice", startingHeight, aliceDb, hopNetwork.defaultDelta,
  1228  	)
  1229  	if err != nil {
  1230  		t.Fatalf("unable to create alice server: %v", err)
  1231  	}
  1232  	bobServer, err := newMockServer(
  1233  		t, "bob", startingHeight, bobDb, hopNetwork.defaultDelta,
  1234  	)
  1235  	if err != nil {
  1236  		t.Fatalf("unable to create bob server: %v", err)
  1237  	}
  1238  
  1239  	// Create mock decoder instead of sphinx one in order to mock the route
  1240  	// which htlc should follow.
  1241  	aliceDecoder := newMockIteratorDecoder()
  1242  	bobDecoder := newMockIteratorDecoder()
  1243  
  1244  	aliceChannelLink, err := hopNetwork.createChannelLink(
  1245  		aliceServer, bobServer, aliceChannel, aliceDecoder,
  1246  	)
  1247  	if err != nil {
  1248  		t.Fatal(err)
  1249  	}
  1250  
  1251  	bobChannelLink, err := hopNetwork.createChannelLink(
  1252  		bobServer, aliceServer, bobChannel, bobDecoder,
  1253  	)
  1254  	if err != nil {
  1255  		t.Fatal(err)
  1256  	}
  1257  
  1258  	return &twoHopNetwork{
  1259  		aliceServer:      aliceServer,
  1260  		aliceChannelLink: aliceChannelLink.(*channelLink),
  1261  
  1262  		bobServer:      bobServer,
  1263  		bobChannelLink: bobChannelLink.(*channelLink),
  1264  
  1265  		hopNetwork: *hopNetwork,
  1266  	}
  1267  }
  1268  
  1269  // start starts the two hop network alice,bob servers.
  1270  func (n *twoHopNetwork) start() error {
  1271  	if err := n.aliceServer.Start(); err != nil {
  1272  		return err
  1273  	}
  1274  	if err := n.bobServer.Start(); err != nil {
  1275  		n.aliceServer.Stop()
  1276  		return err
  1277  	}
  1278  
  1279  	return waitLinksEligible(map[string]*channelLink{
  1280  		"alice": n.aliceChannelLink,
  1281  		"bob":   n.bobChannelLink,
  1282  	})
  1283  }
  1284  
  1285  // stop stops nodes and cleanup its databases.
  1286  func (n *twoHopNetwork) stop() {
  1287  	done := make(chan struct{})
  1288  	go func() {
  1289  		n.aliceServer.Stop()
  1290  		done <- struct{}{}
  1291  	}()
  1292  
  1293  	go func() {
  1294  		n.bobServer.Stop()
  1295  		done <- struct{}{}
  1296  	}()
  1297  
  1298  	for i := 0; i < 2; i++ {
  1299  		<-done
  1300  	}
  1301  }
  1302  
  1303  func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
  1304  	firstHop lnwire.ShortChannelID, hops []*hop.Payload,
  1305  	invoiceAmt, htlcAmt lnwire.MilliAtom,
  1306  	timelock uint32, preimage lntypes.Preimage) chan error {
  1307  
  1308  	paymentErr := make(chan error, 1)
  1309  
  1310  	sender := sendingPeer.(*mockServer)
  1311  	receiver := receivingPeer.(*mockServer)
  1312  
  1313  	// Generate route convert it to blob, and return next destination for
  1314  	// htlc add request.
  1315  	blob, err := generateRoute(hops...)
  1316  	if err != nil {
  1317  		paymentErr <- err
  1318  		return paymentErr
  1319  	}
  1320  
  1321  	rhash := preimage.Hash()
  1322  
  1323  	var payAddr [32]byte
  1324  	if _, err := crand.Read(payAddr[:]); err != nil {
  1325  		panic(err)
  1326  	}
  1327  
  1328  	// Generate payment: invoice and htlc.
  1329  	invoice, htlc, pid, err := generatePaymentWithPreimage(
  1330  		invoiceAmt, htlcAmt, timelock, blob,
  1331  		nil, rhash, payAddr,
  1332  	)
  1333  	if err != nil {
  1334  		paymentErr <- err
  1335  		return paymentErr
  1336  	}
  1337  
  1338  	// Check who is last in the route and add invoice to server registry.
  1339  	if err := receiver.registry.AddInvoice(*invoice, rhash); err != nil {
  1340  		paymentErr <- err
  1341  		return paymentErr
  1342  	}
  1343  
  1344  	// Send payment and expose err channel.
  1345  	err = sender.htlcSwitch.SendHTLC(firstHop, pid, htlc)
  1346  	if err != nil {
  1347  		paymentErr <- err
  1348  		return paymentErr
  1349  	}
  1350  
  1351  	go func() {
  1352  		resultChan, err := sender.htlcSwitch.GetPaymentResult(
  1353  			pid, rhash, newMockDeobfuscator(),
  1354  		)
  1355  		if err != nil {
  1356  			paymentErr <- err
  1357  			return
  1358  		}
  1359  
  1360  		result, ok := <-resultChan
  1361  		if !ok {
  1362  			paymentErr <- fmt.Errorf("shutting down")
  1363  			return
  1364  		}
  1365  
  1366  		if result.Error != nil {
  1367  			paymentErr <- result.Error
  1368  			return
  1369  		}
  1370  		paymentErr <- nil
  1371  	}()
  1372  
  1373  	return paymentErr
  1374  }
  1375  
  1376  // waitLinksEligible blocks until all links the provided name-to-link map are
  1377  // eligible to forward HTLCs.
  1378  func waitLinksEligible(links map[string]*channelLink) error {
  1379  	return wait.NoError(func() error {
  1380  		for name, link := range links {
  1381  			if link.EligibleToForward() {
  1382  				continue
  1383  			}
  1384  			return fmt.Errorf("%s channel link not eligible", name)
  1385  		}
  1386  		return nil
  1387  	}, 3*time.Second)
  1388  }
  1389  
  1390  // timeout implements a test level timeout.
  1391  func timeout(t *testing.T) func() {
  1392  	done := make(chan struct{})
  1393  	go func() {
  1394  		select {
  1395  		case <-time.After(20 * time.Second):
  1396  			pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
  1397  
  1398  			panic("test timeout")
  1399  		case <-done:
  1400  		}
  1401  	}()
  1402  
  1403  	return func() {
  1404  		close(done)
  1405  	}
  1406  }