github.com/decred/dcrlnd@v0.7.6/lntest/itest/lnd_mpp_test.go (about)

     1  package itest
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/decred/dcrd/dcrutil/v4"
    10  	"github.com/decred/dcrd/wire"
    11  	"github.com/decred/dcrlnd/chainreg"
    12  	"github.com/decred/dcrlnd/lnrpc"
    13  	"github.com/decred/dcrlnd/lnrpc/routerrpc"
    14  	"github.com/decred/dcrlnd/lntest"
    15  	"github.com/decred/dcrlnd/routing/route"
    16  )
    17  
    18  // testSendToRouteMultiPath tests that we are able to successfully route a
    19  // payment using multiple shards across different paths, by using SendToRoute.
    20  func testSendToRouteMultiPath(net *lntest.NetworkHarness, t *harnessTest) {
    21  	ctxb := context.Background()
    22  
    23  	ctx := newMppTestContext(t, net)
    24  	defer ctx.shutdownNodes()
    25  
    26  	// To ensure the payment goes through separate paths, we'll set a
    27  	// channel size that can only carry one shard at a time. We'll divide
    28  	// the payment into 3 shards.
    29  	const (
    30  		paymentAmt = dcrutil.Amount(300000)
    31  		shardAmt   = paymentAmt / 3
    32  		chanAmt    = shardAmt * 3 / 2
    33  	)
    34  
    35  	// Set up a network with three different paths Alice <-> Bob.
    36  	//              _ Eve _
    37  	//             /       \
    38  	// Alice -- Carol ---- Bob
    39  	//      \              /
    40  	//       \__ Dave ____/
    41  	//
    42  	ctx.openChannel(ctx.carol, ctx.bob, chanAmt)
    43  	ctx.openChannel(ctx.dave, ctx.bob, chanAmt)
    44  	ctx.openChannel(ctx.alice, ctx.dave, chanAmt)
    45  	ctx.openChannel(ctx.eve, ctx.bob, chanAmt)
    46  	ctx.openChannel(ctx.carol, ctx.eve, chanAmt)
    47  
    48  	// Since the channel Alice-> Carol will have to carry two
    49  	// shards, we make it larger.
    50  	ctx.openChannel(ctx.alice, ctx.carol, chanAmt+shardAmt)
    51  
    52  	defer ctx.closeChannels()
    53  
    54  	ctx.waitForChannels()
    55  
    56  	// Make Bob create an invoice for Alice to pay.
    57  	payReqs, rHashes, invoices, err := createPayReqs(
    58  		ctx.bob, paymentAmt, 1,
    59  	)
    60  	if err != nil {
    61  		t.Fatalf("unable to create pay reqs: %v", err)
    62  	}
    63  
    64  	rHash := rHashes[0]
    65  	payReq := payReqs[0]
    66  
    67  	ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
    68  	decodeResp, err := ctx.bob.DecodePayReq(
    69  		ctxt, &lnrpc.PayReqString{PayReq: payReq},
    70  	)
    71  	if err != nil {
    72  		t.Fatalf("decode pay req: %v", err)
    73  	}
    74  
    75  	payAddr := decodeResp.PaymentAddr
    76  
    77  	// We'll send shards along three routes from Alice.
    78  	sendRoutes := [][]*lntest.HarnessNode{
    79  		{ctx.carol, ctx.bob},
    80  		{ctx.dave, ctx.bob},
    81  		{ctx.carol, ctx.eve, ctx.bob},
    82  	}
    83  
    84  	responses := make(chan *lnrpc.HTLCAttempt, len(sendRoutes))
    85  	for _, hops := range sendRoutes {
    86  		// Build a route for the specified hops.
    87  		r, err := ctx.buildRoute(ctxb, shardAmt, ctx.alice, hops)
    88  		if err != nil {
    89  			t.Fatalf("unable to build route: %v", err)
    90  		}
    91  
    92  		// Set the MPP records to indicate this is a payment shard.
    93  		hop := r.Hops[len(r.Hops)-1]
    94  		hop.TlvPayload = true
    95  		hop.MppRecord = &lnrpc.MPPRecord{
    96  			PaymentAddr:    payAddr,
    97  			TotalAmtMAtoms: int64(paymentAmt * 1000),
    98  		}
    99  
   100  		// Send the shard.
   101  		sendReq := &routerrpc.SendToRouteRequest{
   102  			PaymentHash: rHash,
   103  			Route:       r,
   104  		}
   105  
   106  		// We'll send all shards in their own goroutine, since SendToRoute will
   107  		// block as long as the payment is in flight.
   108  		go func() {
   109  			ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
   110  			resp, err := ctx.alice.RouterClient.SendToRouteV2(ctxt, sendReq)
   111  			if err != nil {
   112  				t.Fatalf("unable to send payment: %v", err)
   113  			}
   114  
   115  			responses <- resp
   116  		}()
   117  	}
   118  
   119  	// Wait for all responses to be back, and check that they all
   120  	// succeeded.
   121  	for range sendRoutes {
   122  		var resp *lnrpc.HTLCAttempt
   123  		select {
   124  		case resp = <-responses:
   125  		case <-time.After(defaultTimeout):
   126  			t.Fatalf("response not received")
   127  		}
   128  
   129  		if resp.Failure != nil {
   130  			t.Fatalf("received payment failure : %v", resp.Failure)
   131  		}
   132  
   133  		// All shards should come back with the preimage.
   134  		if !bytes.Equal(resp.Preimage, invoices[0].RPreimage) {
   135  			t.Fatalf("preimage doesn't match")
   136  		}
   137  	}
   138  
   139  	// assertNumHtlcs is a helper that checks the node's latest payment,
   140  	// and asserts it was split into num shards.
   141  	assertNumHtlcs := func(node *lntest.HarnessNode, num int) {
   142  		req := &lnrpc.ListPaymentsRequest{
   143  			IncludeIncomplete: true,
   144  		}
   145  		ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
   146  		paymentsResp, err := node.ListPayments(ctxt, req)
   147  		if err != nil {
   148  			t.Fatalf("error when obtaining payments: %v",
   149  				err)
   150  		}
   151  
   152  		payments := paymentsResp.Payments
   153  		if len(payments) == 0 {
   154  			t.Fatalf("no payments found")
   155  		}
   156  
   157  		payment := payments[len(payments)-1]
   158  		htlcs := payment.Htlcs
   159  		if len(htlcs) == 0 {
   160  			t.Fatalf("no htlcs")
   161  		}
   162  
   163  		succeeded := 0
   164  		for _, htlc := range htlcs {
   165  			if htlc.Status == lnrpc.HTLCAttempt_SUCCEEDED {
   166  				succeeded++
   167  			}
   168  		}
   169  
   170  		if succeeded != num {
   171  			t.Fatalf("expected %v succussful HTLCs, got %v", num,
   172  				succeeded)
   173  		}
   174  	}
   175  
   176  	// assertSettledInvoice checks that the invoice for the given payment
   177  	// hash is settled, and has been paid using num HTLCs.
   178  	assertSettledInvoice := func(node *lntest.HarnessNode, rhash []byte,
   179  		num int) {
   180  
   181  		found := false
   182  		offset := uint64(0)
   183  		for !found {
   184  			ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
   185  			invoicesResp, err := node.ListInvoices(
   186  				ctxt, &lnrpc.ListInvoiceRequest{
   187  					IndexOffset: offset,
   188  				},
   189  			)
   190  			if err != nil {
   191  				t.Fatalf("error when obtaining payments: %v",
   192  					err)
   193  			}
   194  
   195  			if len(invoicesResp.Invoices) == 0 {
   196  				break
   197  			}
   198  
   199  			for _, inv := range invoicesResp.Invoices {
   200  				if !bytes.Equal(inv.RHash, rhash) {
   201  					continue
   202  				}
   203  
   204  				// Assert that the amount paid to the invoice is
   205  				// correct.
   206  				if inv.AmtPaidAtoms != int64(paymentAmt) {
   207  					t.Fatalf("incorrect payment amt for "+
   208  						"invoicewant: %d, got %d",
   209  						paymentAmt, inv.AmtPaidAtoms)
   210  				}
   211  
   212  				if inv.State != lnrpc.Invoice_SETTLED {
   213  					t.Fatalf("Invoice not settled: %v",
   214  						inv.State)
   215  				}
   216  
   217  				if len(inv.Htlcs) != num {
   218  					t.Fatalf("expected invoice to be "+
   219  						"settled with %v HTLCs, had %v",
   220  						num, len(inv.Htlcs))
   221  				}
   222  
   223  				found = true
   224  				break
   225  			}
   226  
   227  			offset = invoicesResp.LastIndexOffset
   228  		}
   229  
   230  		if !found {
   231  			t.Fatalf("invoice not found")
   232  		}
   233  	}
   234  
   235  	// Finally check that the payment shows up with three settled HTLCs in
   236  	// Alice's list of payments...
   237  	assertNumHtlcs(ctx.alice, 3)
   238  
   239  	// ...and in Bob's list of paid invoices.
   240  	assertSettledInvoice(ctx.bob, rHash, 3)
   241  }
   242  
   243  type mppTestContext struct {
   244  	t   *harnessTest
   245  	net *lntest.NetworkHarness
   246  
   247  	// Keep a list of all our active channels.
   248  	networkChans      []*lnrpc.ChannelPoint
   249  	closeChannelFuncs []func()
   250  
   251  	alice, bob, carol, dave, eve *lntest.HarnessNode
   252  	nodes                        []*lntest.HarnessNode
   253  }
   254  
   255  func newMppTestContext(t *harnessTest,
   256  	net *lntest.NetworkHarness) *mppTestContext {
   257  
   258  	alice := net.NewNode(t.t, "alice", nil)
   259  	bob := net.NewNode(t.t, "bob", []string{"--accept-amp"})
   260  
   261  	// Create a five-node context consisting of Alice, Bob and three new
   262  	// nodes.
   263  	carol := net.NewNode(t.t, "carol", nil)
   264  	dave := net.NewNode(t.t, "dave", nil)
   265  	eve := net.NewNode(t.t, "eve", nil)
   266  
   267  	// Connect nodes to ensure propagation of channels.
   268  	nodes := []*lntest.HarnessNode{alice, bob, carol, dave, eve}
   269  	for i := 0; i < len(nodes); i++ {
   270  		for j := i + 1; j < len(nodes); j++ {
   271  			net.EnsureConnected(t.t, nodes[i], nodes[j])
   272  		}
   273  	}
   274  
   275  	ctx := mppTestContext{
   276  		t:     t,
   277  		net:   net,
   278  		alice: alice,
   279  		bob:   bob,
   280  		carol: carol,
   281  		dave:  dave,
   282  		eve:   eve,
   283  		nodes: nodes,
   284  	}
   285  
   286  	return &ctx
   287  }
   288  
   289  // openChannel is a helper to open a channel from->to.
   290  func (c *mppTestContext) openChannel(from, to *lntest.HarnessNode,
   291  	chanSize dcrutil.Amount) {
   292  
   293  	c.net.SendCoins(c.t.t, dcrutil.AtomsPerCoin, from)
   294  
   295  	chanPoint := openChannelAndAssert(
   296  		c.t, c.net, from, to,
   297  		lntest.OpenChannelParams{Amt: chanSize},
   298  	)
   299  
   300  	c.closeChannelFuncs = append(c.closeChannelFuncs, func() {
   301  		closeChannelAndAssert(c.t, c.net, from, chanPoint, false)
   302  	})
   303  
   304  	c.networkChans = append(c.networkChans, chanPoint)
   305  }
   306  
   307  func (c *mppTestContext) closeChannels() {
   308  	for _, f := range c.closeChannelFuncs {
   309  		f()
   310  	}
   311  }
   312  
   313  func (c *mppTestContext) shutdownNodes() {
   314  	shutdownAndAssert(c.net, c.t, c.alice)
   315  	shutdownAndAssert(c.net, c.t, c.bob)
   316  	shutdownAndAssert(c.net, c.t, c.carol)
   317  	shutdownAndAssert(c.net, c.t, c.dave)
   318  	shutdownAndAssert(c.net, c.t, c.eve)
   319  }
   320  
   321  func (c *mppTestContext) waitForChannels() {
   322  	// Wait for all nodes to have seen all channels.
   323  	for _, chanPoint := range c.networkChans {
   324  		for _, node := range c.nodes {
   325  			txid, err := lnrpc.GetChanPointFundingTxid(chanPoint)
   326  			if err != nil {
   327  				c.t.Fatalf("unable to get txid: %v", err)
   328  			}
   329  			point := wire.OutPoint{
   330  				Hash:  *txid,
   331  				Index: chanPoint.OutputIndex,
   332  			}
   333  
   334  			err = node.WaitForNetworkChannelOpen(chanPoint)
   335  			if err != nil {
   336  				c.t.Fatalf("(%v:%d): timeout waiting for "+
   337  					"channel(%s) open: %v",
   338  					node.Cfg.Name, node.NodeID, point, err)
   339  			}
   340  		}
   341  	}
   342  }
   343  
   344  // Helper function for Alice to build a route from pubkeys.
   345  func (c *mppTestContext) buildRoute(ctxb context.Context, amt dcrutil.Amount,
   346  	sender *lntest.HarnessNode, hops []*lntest.HarnessNode) (*lnrpc.Route,
   347  	error) {
   348  
   349  	rpcHops := make([][]byte, 0, len(hops))
   350  	for _, hop := range hops {
   351  		k := hop.PubKeyStr
   352  		pubkey, err := route.NewVertexFromStr(k)
   353  		if err != nil {
   354  			return nil, fmt.Errorf("error parsing %v: %v",
   355  				k, err)
   356  		}
   357  		rpcHops = append(rpcHops, pubkey[:])
   358  	}
   359  
   360  	req := &routerrpc.BuildRouteRequest{
   361  		AmtMAtoms:      int64(amt * 1000),
   362  		FinalCltvDelta: chainreg.DefaultDecredTimeLockDelta,
   363  		HopPubkeys:     rpcHops,
   364  	}
   365  
   366  	ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
   367  	routeResp, err := sender.RouterClient.BuildRoute(ctxt, req)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  
   372  	return routeResp.Route, nil
   373  }