github.com/decred/dcrlnd@v0.7.6/lntest/itest/utils.go (about)

     1  package itest
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"time"
     9  
    10  	"github.com/decred/dcrd/chaincfg/chainhash"
    11  	"github.com/decred/dcrd/dcrutil/v4"
    12  	"github.com/decred/dcrd/rpcclient/v8"
    13  	"github.com/decred/dcrd/txscript/v4/stdscript"
    14  	"github.com/decred/dcrd/wire"
    15  	"github.com/decred/dcrlnd/input"
    16  	"github.com/decred/dcrlnd/lnrpc"
    17  	"github.com/decred/dcrlnd/lnrpc/routerrpc"
    18  	"github.com/decred/dcrlnd/lntest"
    19  	"github.com/decred/dcrlnd/lntest/wait"
    20  	"github.com/decred/dcrlnd/lnwallet"
    21  	"github.com/decred/dcrlnd/lnwallet/chainfee"
    22  	"github.com/decred/dcrlnd/lnwire"
    23  	"github.com/go-errors/errors"
    24  	"github.com/stretchr/testify/require"
    25  	"matheusd.com/testctx"
    26  )
    27  
    28  // completePaymentRequests sends payments from a lightning node to complete all
    29  // payment requests. If the awaitResponse parameter is true, this function
    30  // does not return until all payments successfully complete without errors.
    31  func completePaymentRequests(client lnrpc.LightningClient,
    32  	routerClient routerrpc.RouterClient, paymentRequests []string,
    33  	awaitResponse bool) error {
    34  
    35  	ctxb := context.Background()
    36  	ctx, cancel := context.WithTimeout(ctxb, defaultTimeout)
    37  	defer cancel()
    38  
    39  	// We start by getting the current state of the client's channels. This
    40  	// is needed to ensure the payments actually have been committed before
    41  	// we return.
    42  	req := &lnrpc.ListChannelsRequest{}
    43  	listResp, err := client.ListChannels(ctx, req)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	// send sends a payment and returns an error if it doesn't succeeded.
    49  	send := func(payReq string) error {
    50  		ctxc, cancel := context.WithCancel(ctx)
    51  		defer cancel()
    52  
    53  		payStream, err := routerClient.SendPaymentV2(
    54  			ctxc,
    55  			&routerrpc.SendPaymentRequest{
    56  				PaymentRequest: payReq,
    57  				TimeoutSeconds: 60,
    58  				FeeLimitMAtoms: noFeeLimitMAtoms,
    59  			},
    60  		)
    61  		if err != nil {
    62  			return err
    63  		}
    64  
    65  		resp, err := getPaymentResult(payStream)
    66  		if err != nil {
    67  			return err
    68  		}
    69  		if resp.Status != lnrpc.Payment_SUCCEEDED {
    70  			return errors.New(resp.FailureReason)
    71  		}
    72  
    73  		return nil
    74  	}
    75  
    76  	// Launch all payments simultaneously.
    77  	results := make(chan error)
    78  	for _, payReq := range paymentRequests {
    79  		payReqCopy := payReq
    80  		go func() {
    81  			err := send(payReqCopy)
    82  			if awaitResponse {
    83  				results <- err
    84  			}
    85  		}()
    86  	}
    87  
    88  	// If awaiting a response, verify that all payments succeeded.
    89  	if awaitResponse {
    90  		for range paymentRequests {
    91  			err := <-results
    92  			if err != nil {
    93  				return err
    94  			}
    95  		}
    96  		return nil
    97  	}
    98  
    99  	// We are not waiting for feedback in the form of a response, but we
   100  	// should still wait long enough for the server to receive and handle
   101  	// the send before cancelling the request. We wait for the number of
   102  	// updates to one of our channels has increased before we return.
   103  	err = wait.Predicate(func() bool {
   104  		newListResp, err := client.ListChannels(ctx, req)
   105  		if err != nil {
   106  			return false
   107  		}
   108  
   109  		// If the number of open channels is now lower than before
   110  		// attempting the payments, it means one of the payments
   111  		// triggered a force closure (for example, due to an incorrect
   112  		// preimage). Return early since it's clear the payment was
   113  		// attempted.
   114  		if len(newListResp.Channels) < len(listResp.Channels) {
   115  			return true
   116  		}
   117  
   118  		for _, c1 := range listResp.Channels {
   119  			for _, c2 := range newListResp.Channels {
   120  				if c1.ChannelPoint != c2.ChannelPoint {
   121  					continue
   122  				}
   123  
   124  				// If this channel has an increased numbr of
   125  				// updates, we assume the payments are
   126  				// committed, and we can return.
   127  				if c2.NumUpdates > c1.NumUpdates {
   128  					return true
   129  				}
   130  			}
   131  		}
   132  
   133  		return false
   134  	}, defaultTimeout)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  // makeFakePayHash creates random pre image hash
   143  func makeFakePayHash(t *harnessTest) []byte {
   144  	randBuf := make([]byte, 32)
   145  
   146  	if _, err := rand.Read(randBuf); err != nil {
   147  		t.Fatalf("internal error, cannot generate random string: %v", err)
   148  	}
   149  
   150  	return randBuf
   151  }
   152  
   153  // createPayReqs is a helper method that will create a slice of payment
   154  // requests for the given node.
   155  func createPayReqs(node *lntest.HarnessNode, paymentAmt dcrutil.Amount,
   156  	numInvoices int) ([]string, [][]byte, []*lnrpc.Invoice, error) {
   157  
   158  	payReqs := make([]string, numInvoices)
   159  	rHashes := make([][]byte, numInvoices)
   160  	invoices := make([]*lnrpc.Invoice, numInvoices)
   161  	for i := 0; i < numInvoices; i++ {
   162  		preimage := make([]byte, 32)
   163  		_, err := rand.Read(preimage)
   164  		if err != nil {
   165  			return nil, nil, nil, fmt.Errorf("unable to generate "+
   166  				"preimage: %v", err)
   167  		}
   168  		invoice := &lnrpc.Invoice{
   169  			Memo:      "testing",
   170  			RPreimage: preimage,
   171  			Value:     int64(paymentAmt),
   172  
   173  			// Historically, integration tests assumed this check never happened,
   174  			// so disable by default. There are specific tests for asserting the
   175  			// behavior when this flag is false.
   176  			IgnoreMaxInboundAmt: true,
   177  		}
   178  		ctxt, _ := context.WithTimeout(
   179  			context.Background(), defaultTimeout,
   180  		)
   181  		resp, err := node.AddInvoice(ctxt, invoice)
   182  		if err != nil {
   183  			return nil, nil, nil, fmt.Errorf("unable to add "+
   184  				"invoice: %v", err)
   185  		}
   186  
   187  		// Set the payment address in the invoice so the caller can
   188  		// properly use it.
   189  		invoice.PaymentAddr = resp.PaymentAddr
   190  
   191  		payReqs[i] = resp.PaymentRequest
   192  		rHashes[i] = resp.RHash
   193  		invoices[i] = invoice
   194  	}
   195  	return payReqs, rHashes, invoices, nil
   196  }
   197  
   198  // getChanInfo is a helper method for getting channel info for a node's sole
   199  // channel.
   200  func getChanInfo(node *lntest.HarnessNode) (*lnrpc.Channel, error) {
   201  
   202  	ctxb := context.Background()
   203  	ctx, cancel := context.WithTimeout(ctxb, defaultTimeout)
   204  	defer cancel()
   205  
   206  	req := &lnrpc.ListChannelsRequest{}
   207  	channelInfo, err := node.ListChannels(ctx, req)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  	if len(channelInfo.Channels) != 1 {
   212  		return nil, fmt.Errorf("node should only have a single "+
   213  			"channel, instead it has %v", len(channelInfo.Channels))
   214  	}
   215  
   216  	return channelInfo.Channels[0], nil
   217  }
   218  
   219  // commitTypeHasAnchors returns whether commitType uses anchor outputs.
   220  func commitTypeHasAnchors(commitType lnrpc.CommitmentType) bool {
   221  	switch commitType {
   222  	case lnrpc.CommitmentType_ANCHORS,
   223  		lnrpc.CommitmentType_SCRIPT_ENFORCED_LEASE:
   224  		return true
   225  	default:
   226  		return false
   227  	}
   228  }
   229  
   230  // nodeArgsForCommitType returns the command line flag to supply to enable this
   231  // commitment type.
   232  func nodeArgsForCommitType(commitType lnrpc.CommitmentType) []string {
   233  	switch commitType {
   234  	case lnrpc.CommitmentType_LEGACY:
   235  		return []string{"--protocol.legacy.committweak"}
   236  	case lnrpc.CommitmentType_STATIC_REMOTE_KEY:
   237  		return []string{}
   238  	case lnrpc.CommitmentType_ANCHORS:
   239  		return []string{"--protocol.anchors"}
   240  	case lnrpc.CommitmentType_SCRIPT_ENFORCED_LEASE:
   241  		return []string{
   242  			"--protocol.anchors",
   243  			"--protocol.script-enforced-lease",
   244  		}
   245  	}
   246  
   247  	return nil
   248  }
   249  
   250  // calcStaticFee calculates appropriate fees for commitment transactions.  This
   251  // function provides a simple way to allow test balance assertions to take fee
   252  // calculations into account.
   253  func calcStaticFee(c lnrpc.CommitmentType, numHTLCs int) dcrutil.Amount {
   254  	const htlcSize = input.HTLCOutputSize
   255  	var (
   256  		feePerKB   = chainfee.AtomPerKByte(1e4)
   257  		commitSize = input.CommitmentTxSize
   258  		anchors    = dcrutil.Amount(0)
   259  	)
   260  
   261  	// The anchor commitment type is slightly heavier, and we must also add
   262  	// the value of the two anchors to the resulting fee the initiator
   263  	// pays. In addition the fee rate is capped at 10 sat/vbyte for anchor
   264  	// channels.
   265  	if commitTypeHasAnchors(c) {
   266  		feePerKB = chainfee.AtomPerKByte(
   267  			lnwallet.DefaultAnchorsCommitMaxFeeRateAtomsPerByte * 1000,
   268  		)
   269  		commitSize = input.CommitmentWithAnchorsTxSize
   270  		anchors = 2 * anchorSize
   271  	}
   272  
   273  	return feePerKB.FeeForSize(commitSize+htlcSize*int64(numHTLCs)) +
   274  		anchors
   275  }
   276  
   277  // channelCommitType retrieves the active channel commitment type for the given
   278  // chan point.
   279  func channelCommitType(node *lntest.HarnessNode,
   280  	chanPoint *lnrpc.ChannelPoint) (lnrpc.CommitmentType, error) {
   281  
   282  	ctxb := context.Background()
   283  	ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
   284  
   285  	req := &lnrpc.ListChannelsRequest{}
   286  	channels, err := node.ListChannels(ctxt, req)
   287  	if err != nil {
   288  		return 0, fmt.Errorf("listchannels failed: %v", err)
   289  	}
   290  
   291  	for _, c := range channels.Channels {
   292  		if c.ChannelPoint == txStr(chanPoint) {
   293  			return c.CommitmentType, nil
   294  		}
   295  	}
   296  
   297  	return 0, fmt.Errorf("channel point %v not found", chanPoint)
   298  }
   299  
   300  // calculateMaxHtlc re-implements the RequiredRemoteChannelReserve of the
   301  // funding manager's config, which corresponds to the maximum MaxHTLC value we
   302  // allow users to set when updating a channel policy.
   303  func calculateMaxHtlc(chanCap dcrutil.Amount) uint64 {
   304  	reserve := lnwire.NewMAtomsFromAtoms(chanCap / 100)
   305  	max := lnwire.NewMAtomsFromAtoms(chanCap) - reserve
   306  	return uint64(max)
   307  }
   308  
   309  // waitForNodeBlockHeight queries the node for its current block height until
   310  // it reaches the passed height.
   311  func waitForNodeBlockHeight(node *lntest.HarnessNode, height int64) error {
   312  
   313  	ctxb := context.Background()
   314  	ctx, cancel := context.WithTimeout(ctxb, defaultTimeout)
   315  	defer cancel()
   316  
   317  	var predErr error
   318  	err := wait.Predicate(func() bool {
   319  		info, err := node.GetInfo(ctx, &lnrpc.GetInfoRequest{})
   320  		if err != nil {
   321  			predErr = err
   322  			return false
   323  		}
   324  
   325  		if int64(info.BlockHeight) != height {
   326  			predErr = fmt.Errorf("expected block height to "+
   327  				"be %v, was %v", height, info.BlockHeight)
   328  			return false
   329  		}
   330  		return true
   331  	}, defaultTimeout)
   332  	if err != nil {
   333  		return predErr
   334  	}
   335  	return nil
   336  }
   337  
   338  // getNTxsFromMempool polls until finding the desired number of transactions in
   339  // the provided miner's mempool and returns the full transactions to the caller.
   340  func getNTxsFromMempool(miner *rpcclient.Client, n int,
   341  	timeout time.Duration) ([]*wire.MsgTx, error) {
   342  
   343  	txids, err := waitForNTxsInMempool(miner, n, timeout)
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	var txes []*wire.MsgTx
   349  	for _, txid := range txids {
   350  		ctxt, cancel := context.WithTimeout(context.Background(), timeout)
   351  		defer cancel()
   352  		tx, err := miner.GetRawTransaction(ctxt, txid)
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  		txes = append(txes, tx.MsgTx())
   357  	}
   358  	return txes, nil
   359  }
   360  
   361  // getTxFee retrieves parent transactions and reconstructs the fee paid.
   362  func getTxFee(miner *rpcclient.Client, tx *wire.MsgTx) (dcrutil.Amount, error) {
   363  	var balance dcrutil.Amount
   364  	for _, in := range tx.TxIn {
   365  		parentHash := in.PreviousOutPoint.Hash
   366  		ctxt, cancel := context.WithTimeout(context.Background(), time.Second)
   367  		defer cancel()
   368  		rawTx, err := miner.GetRawTransaction(ctxt, &parentHash)
   369  		if err != nil {
   370  			return 0, err
   371  		}
   372  		parent := rawTx.MsgTx()
   373  		balance += dcrutil.Amount(
   374  			parent.TxOut[in.PreviousOutPoint.Index].Value,
   375  		)
   376  	}
   377  
   378  	for _, out := range tx.TxOut {
   379  		balance -= dcrutil.Amount(out.Value)
   380  	}
   381  
   382  	return balance, nil
   383  }
   384  
   385  // channelSubscription houses the proxied update and error chans for a node's
   386  // channel subscriptions.
   387  type channelSubscription struct {
   388  	updateChan chan *lnrpc.ChannelEventUpdate
   389  	errChan    chan error
   390  	quit       chan struct{}
   391  }
   392  
   393  // subscribeChannelNotifications subscribes to channel updates and launches a
   394  // goroutine that forwards these to the returned channel.
   395  func subscribeChannelNotifications(ctxb context.Context, t *harnessTest,
   396  	node *lntest.HarnessNode) channelSubscription {
   397  
   398  	// We'll first start by establishing a notification client which will
   399  	// send us notifications upon channels becoming active, inactive or
   400  	// closed.
   401  	req := &lnrpc.ChannelEventSubscription{}
   402  	ctx, cancelFunc := context.WithCancel(ctxb)
   403  
   404  	chanUpdateClient, err := node.SubscribeChannelEvents(ctx, req)
   405  	if err != nil {
   406  		t.Fatalf("unable to create channel update client: %v", err)
   407  	}
   408  
   409  	// We'll launch a goroutine that will be responsible for proxying all
   410  	// notifications recv'd from the client into the channel below.
   411  	errChan := make(chan error, 1)
   412  	quit := make(chan struct{})
   413  	chanUpdates := make(chan *lnrpc.ChannelEventUpdate, 20)
   414  	go func() {
   415  		defer cancelFunc()
   416  		for {
   417  			select {
   418  			case <-quit:
   419  				return
   420  			default:
   421  				chanUpdate, err := chanUpdateClient.Recv()
   422  				select {
   423  				case <-quit:
   424  					return
   425  				default:
   426  				}
   427  
   428  				if err == io.EOF {
   429  					return
   430  				} else if err != nil {
   431  					select {
   432  					case errChan <- err:
   433  					case <-quit:
   434  					}
   435  					return
   436  				}
   437  
   438  				select {
   439  				case chanUpdates <- chanUpdate:
   440  				case <-quit:
   441  					return
   442  				}
   443  			}
   444  		}
   445  	}()
   446  
   447  	return channelSubscription{
   448  		updateChan: chanUpdates,
   449  		errChan:    errChan,
   450  		quit:       quit,
   451  	}
   452  }
   453  
   454  // findTxAtHeight gets all of the transactions that a node's wallet has a record
   455  // of at the target height, and finds and returns the tx with the target txid,
   456  // failing if it is not found.
   457  func findTxAtHeight(t *harnessTest, height int64,
   458  	target string, node *lntest.HarnessNode) *lnrpc.Transaction {
   459  
   460  	ctxb := context.Background()
   461  	ctx, cancel := context.WithTimeout(ctxb, defaultTimeout)
   462  	defer cancel()
   463  
   464  	txns, err := node.LightningClient.GetTransactions(
   465  		ctx, &lnrpc.GetTransactionsRequest{
   466  			StartHeight: int32(height),
   467  			EndHeight:   int32(height),
   468  		},
   469  	)
   470  	require.NoError(t.t, err, "could not get transactions")
   471  
   472  	for _, tx := range txns.Transactions {
   473  		if tx.TxHash == target {
   474  			return tx
   475  		}
   476  	}
   477  
   478  	return nil
   479  }
   480  
   481  // getOutputIndex returns the output index of the given address in the given
   482  // transaction.
   483  func getOutputIndex(t *harnessTest, miner *lntest.HarnessMiner,
   484  	txid *chainhash.Hash, addr string) int {
   485  
   486  	t.t.Helper()
   487  
   488  	// We'll then extract the raw transaction from the mempool in order to
   489  	// determine the index of the p2tr output.
   490  	tx, err := miner.Client.GetRawTransaction(testctx.New(t), txid)
   491  	require.NoError(t.t, err)
   492  
   493  	p2trOutputIndex := -1
   494  	for i, txOut := range tx.MsgTx().TxOut {
   495  		_, addrs := stdscript.ExtractAddrs(
   496  			txOut.Version, txOut.PkScript, miner.ActiveNet,
   497  		)
   498  
   499  		if len(addrs) > 0 && addrs[0].String() == addr {
   500  			p2trOutputIndex = i
   501  		}
   502  	}
   503  	require.Greater(t.t, p2trOutputIndex, -1)
   504  
   505  	return p2trOutputIndex
   506  }