github.com/decred/dcrlnd@v0.7.6/lntest/itest/utils.go (about) 1 package itest 2 3 import ( 4 "context" 5 "crypto/rand" 6 "fmt" 7 "io" 8 "time" 9 10 "github.com/decred/dcrd/chaincfg/chainhash" 11 "github.com/decred/dcrd/dcrutil/v4" 12 "github.com/decred/dcrd/rpcclient/v8" 13 "github.com/decred/dcrd/txscript/v4/stdscript" 14 "github.com/decred/dcrd/wire" 15 "github.com/decred/dcrlnd/input" 16 "github.com/decred/dcrlnd/lnrpc" 17 "github.com/decred/dcrlnd/lnrpc/routerrpc" 18 "github.com/decred/dcrlnd/lntest" 19 "github.com/decred/dcrlnd/lntest/wait" 20 "github.com/decred/dcrlnd/lnwallet" 21 "github.com/decred/dcrlnd/lnwallet/chainfee" 22 "github.com/decred/dcrlnd/lnwire" 23 "github.com/go-errors/errors" 24 "github.com/stretchr/testify/require" 25 "matheusd.com/testctx" 26 ) 27 28 // completePaymentRequests sends payments from a lightning node to complete all 29 // payment requests. If the awaitResponse parameter is true, this function 30 // does not return until all payments successfully complete without errors. 31 func completePaymentRequests(client lnrpc.LightningClient, 32 routerClient routerrpc.RouterClient, paymentRequests []string, 33 awaitResponse bool) error { 34 35 ctxb := context.Background() 36 ctx, cancel := context.WithTimeout(ctxb, defaultTimeout) 37 defer cancel() 38 39 // We start by getting the current state of the client's channels. This 40 // is needed to ensure the payments actually have been committed before 41 // we return. 42 req := &lnrpc.ListChannelsRequest{} 43 listResp, err := client.ListChannels(ctx, req) 44 if err != nil { 45 return err 46 } 47 48 // send sends a payment and returns an error if it doesn't succeeded. 49 send := func(payReq string) error { 50 ctxc, cancel := context.WithCancel(ctx) 51 defer cancel() 52 53 payStream, err := routerClient.SendPaymentV2( 54 ctxc, 55 &routerrpc.SendPaymentRequest{ 56 PaymentRequest: payReq, 57 TimeoutSeconds: 60, 58 FeeLimitMAtoms: noFeeLimitMAtoms, 59 }, 60 ) 61 if err != nil { 62 return err 63 } 64 65 resp, err := getPaymentResult(payStream) 66 if err != nil { 67 return err 68 } 69 if resp.Status != lnrpc.Payment_SUCCEEDED { 70 return errors.New(resp.FailureReason) 71 } 72 73 return nil 74 } 75 76 // Launch all payments simultaneously. 77 results := make(chan error) 78 for _, payReq := range paymentRequests { 79 payReqCopy := payReq 80 go func() { 81 err := send(payReqCopy) 82 if awaitResponse { 83 results <- err 84 } 85 }() 86 } 87 88 // If awaiting a response, verify that all payments succeeded. 89 if awaitResponse { 90 for range paymentRequests { 91 err := <-results 92 if err != nil { 93 return err 94 } 95 } 96 return nil 97 } 98 99 // We are not waiting for feedback in the form of a response, but we 100 // should still wait long enough for the server to receive and handle 101 // the send before cancelling the request. We wait for the number of 102 // updates to one of our channels has increased before we return. 103 err = wait.Predicate(func() bool { 104 newListResp, err := client.ListChannels(ctx, req) 105 if err != nil { 106 return false 107 } 108 109 // If the number of open channels is now lower than before 110 // attempting the payments, it means one of the payments 111 // triggered a force closure (for example, due to an incorrect 112 // preimage). Return early since it's clear the payment was 113 // attempted. 114 if len(newListResp.Channels) < len(listResp.Channels) { 115 return true 116 } 117 118 for _, c1 := range listResp.Channels { 119 for _, c2 := range newListResp.Channels { 120 if c1.ChannelPoint != c2.ChannelPoint { 121 continue 122 } 123 124 // If this channel has an increased numbr of 125 // updates, we assume the payments are 126 // committed, and we can return. 127 if c2.NumUpdates > c1.NumUpdates { 128 return true 129 } 130 } 131 } 132 133 return false 134 }, defaultTimeout) 135 if err != nil { 136 return err 137 } 138 139 return nil 140 } 141 142 // makeFakePayHash creates random pre image hash 143 func makeFakePayHash(t *harnessTest) []byte { 144 randBuf := make([]byte, 32) 145 146 if _, err := rand.Read(randBuf); err != nil { 147 t.Fatalf("internal error, cannot generate random string: %v", err) 148 } 149 150 return randBuf 151 } 152 153 // createPayReqs is a helper method that will create a slice of payment 154 // requests for the given node. 155 func createPayReqs(node *lntest.HarnessNode, paymentAmt dcrutil.Amount, 156 numInvoices int) ([]string, [][]byte, []*lnrpc.Invoice, error) { 157 158 payReqs := make([]string, numInvoices) 159 rHashes := make([][]byte, numInvoices) 160 invoices := make([]*lnrpc.Invoice, numInvoices) 161 for i := 0; i < numInvoices; i++ { 162 preimage := make([]byte, 32) 163 _, err := rand.Read(preimage) 164 if err != nil { 165 return nil, nil, nil, fmt.Errorf("unable to generate "+ 166 "preimage: %v", err) 167 } 168 invoice := &lnrpc.Invoice{ 169 Memo: "testing", 170 RPreimage: preimage, 171 Value: int64(paymentAmt), 172 173 // Historically, integration tests assumed this check never happened, 174 // so disable by default. There are specific tests for asserting the 175 // behavior when this flag is false. 176 IgnoreMaxInboundAmt: true, 177 } 178 ctxt, _ := context.WithTimeout( 179 context.Background(), defaultTimeout, 180 ) 181 resp, err := node.AddInvoice(ctxt, invoice) 182 if err != nil { 183 return nil, nil, nil, fmt.Errorf("unable to add "+ 184 "invoice: %v", err) 185 } 186 187 // Set the payment address in the invoice so the caller can 188 // properly use it. 189 invoice.PaymentAddr = resp.PaymentAddr 190 191 payReqs[i] = resp.PaymentRequest 192 rHashes[i] = resp.RHash 193 invoices[i] = invoice 194 } 195 return payReqs, rHashes, invoices, nil 196 } 197 198 // getChanInfo is a helper method for getting channel info for a node's sole 199 // channel. 200 func getChanInfo(node *lntest.HarnessNode) (*lnrpc.Channel, error) { 201 202 ctxb := context.Background() 203 ctx, cancel := context.WithTimeout(ctxb, defaultTimeout) 204 defer cancel() 205 206 req := &lnrpc.ListChannelsRequest{} 207 channelInfo, err := node.ListChannels(ctx, req) 208 if err != nil { 209 return nil, err 210 } 211 if len(channelInfo.Channels) != 1 { 212 return nil, fmt.Errorf("node should only have a single "+ 213 "channel, instead it has %v", len(channelInfo.Channels)) 214 } 215 216 return channelInfo.Channels[0], nil 217 } 218 219 // commitTypeHasAnchors returns whether commitType uses anchor outputs. 220 func commitTypeHasAnchors(commitType lnrpc.CommitmentType) bool { 221 switch commitType { 222 case lnrpc.CommitmentType_ANCHORS, 223 lnrpc.CommitmentType_SCRIPT_ENFORCED_LEASE: 224 return true 225 default: 226 return false 227 } 228 } 229 230 // nodeArgsForCommitType returns the command line flag to supply to enable this 231 // commitment type. 232 func nodeArgsForCommitType(commitType lnrpc.CommitmentType) []string { 233 switch commitType { 234 case lnrpc.CommitmentType_LEGACY: 235 return []string{"--protocol.legacy.committweak"} 236 case lnrpc.CommitmentType_STATIC_REMOTE_KEY: 237 return []string{} 238 case lnrpc.CommitmentType_ANCHORS: 239 return []string{"--protocol.anchors"} 240 case lnrpc.CommitmentType_SCRIPT_ENFORCED_LEASE: 241 return []string{ 242 "--protocol.anchors", 243 "--protocol.script-enforced-lease", 244 } 245 } 246 247 return nil 248 } 249 250 // calcStaticFee calculates appropriate fees for commitment transactions. This 251 // function provides a simple way to allow test balance assertions to take fee 252 // calculations into account. 253 func calcStaticFee(c lnrpc.CommitmentType, numHTLCs int) dcrutil.Amount { 254 const htlcSize = input.HTLCOutputSize 255 var ( 256 feePerKB = chainfee.AtomPerKByte(1e4) 257 commitSize = input.CommitmentTxSize 258 anchors = dcrutil.Amount(0) 259 ) 260 261 // The anchor commitment type is slightly heavier, and we must also add 262 // the value of the two anchors to the resulting fee the initiator 263 // pays. In addition the fee rate is capped at 10 sat/vbyte for anchor 264 // channels. 265 if commitTypeHasAnchors(c) { 266 feePerKB = chainfee.AtomPerKByte( 267 lnwallet.DefaultAnchorsCommitMaxFeeRateAtomsPerByte * 1000, 268 ) 269 commitSize = input.CommitmentWithAnchorsTxSize 270 anchors = 2 * anchorSize 271 } 272 273 return feePerKB.FeeForSize(commitSize+htlcSize*int64(numHTLCs)) + 274 anchors 275 } 276 277 // channelCommitType retrieves the active channel commitment type for the given 278 // chan point. 279 func channelCommitType(node *lntest.HarnessNode, 280 chanPoint *lnrpc.ChannelPoint) (lnrpc.CommitmentType, error) { 281 282 ctxb := context.Background() 283 ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) 284 285 req := &lnrpc.ListChannelsRequest{} 286 channels, err := node.ListChannels(ctxt, req) 287 if err != nil { 288 return 0, fmt.Errorf("listchannels failed: %v", err) 289 } 290 291 for _, c := range channels.Channels { 292 if c.ChannelPoint == txStr(chanPoint) { 293 return c.CommitmentType, nil 294 } 295 } 296 297 return 0, fmt.Errorf("channel point %v not found", chanPoint) 298 } 299 300 // calculateMaxHtlc re-implements the RequiredRemoteChannelReserve of the 301 // funding manager's config, which corresponds to the maximum MaxHTLC value we 302 // allow users to set when updating a channel policy. 303 func calculateMaxHtlc(chanCap dcrutil.Amount) uint64 { 304 reserve := lnwire.NewMAtomsFromAtoms(chanCap / 100) 305 max := lnwire.NewMAtomsFromAtoms(chanCap) - reserve 306 return uint64(max) 307 } 308 309 // waitForNodeBlockHeight queries the node for its current block height until 310 // it reaches the passed height. 311 func waitForNodeBlockHeight(node *lntest.HarnessNode, height int64) error { 312 313 ctxb := context.Background() 314 ctx, cancel := context.WithTimeout(ctxb, defaultTimeout) 315 defer cancel() 316 317 var predErr error 318 err := wait.Predicate(func() bool { 319 info, err := node.GetInfo(ctx, &lnrpc.GetInfoRequest{}) 320 if err != nil { 321 predErr = err 322 return false 323 } 324 325 if int64(info.BlockHeight) != height { 326 predErr = fmt.Errorf("expected block height to "+ 327 "be %v, was %v", height, info.BlockHeight) 328 return false 329 } 330 return true 331 }, defaultTimeout) 332 if err != nil { 333 return predErr 334 } 335 return nil 336 } 337 338 // getNTxsFromMempool polls until finding the desired number of transactions in 339 // the provided miner's mempool and returns the full transactions to the caller. 340 func getNTxsFromMempool(miner *rpcclient.Client, n int, 341 timeout time.Duration) ([]*wire.MsgTx, error) { 342 343 txids, err := waitForNTxsInMempool(miner, n, timeout) 344 if err != nil { 345 return nil, err 346 } 347 348 var txes []*wire.MsgTx 349 for _, txid := range txids { 350 ctxt, cancel := context.WithTimeout(context.Background(), timeout) 351 defer cancel() 352 tx, err := miner.GetRawTransaction(ctxt, txid) 353 if err != nil { 354 return nil, err 355 } 356 txes = append(txes, tx.MsgTx()) 357 } 358 return txes, nil 359 } 360 361 // getTxFee retrieves parent transactions and reconstructs the fee paid. 362 func getTxFee(miner *rpcclient.Client, tx *wire.MsgTx) (dcrutil.Amount, error) { 363 var balance dcrutil.Amount 364 for _, in := range tx.TxIn { 365 parentHash := in.PreviousOutPoint.Hash 366 ctxt, cancel := context.WithTimeout(context.Background(), time.Second) 367 defer cancel() 368 rawTx, err := miner.GetRawTransaction(ctxt, &parentHash) 369 if err != nil { 370 return 0, err 371 } 372 parent := rawTx.MsgTx() 373 balance += dcrutil.Amount( 374 parent.TxOut[in.PreviousOutPoint.Index].Value, 375 ) 376 } 377 378 for _, out := range tx.TxOut { 379 balance -= dcrutil.Amount(out.Value) 380 } 381 382 return balance, nil 383 } 384 385 // channelSubscription houses the proxied update and error chans for a node's 386 // channel subscriptions. 387 type channelSubscription struct { 388 updateChan chan *lnrpc.ChannelEventUpdate 389 errChan chan error 390 quit chan struct{} 391 } 392 393 // subscribeChannelNotifications subscribes to channel updates and launches a 394 // goroutine that forwards these to the returned channel. 395 func subscribeChannelNotifications(ctxb context.Context, t *harnessTest, 396 node *lntest.HarnessNode) channelSubscription { 397 398 // We'll first start by establishing a notification client which will 399 // send us notifications upon channels becoming active, inactive or 400 // closed. 401 req := &lnrpc.ChannelEventSubscription{} 402 ctx, cancelFunc := context.WithCancel(ctxb) 403 404 chanUpdateClient, err := node.SubscribeChannelEvents(ctx, req) 405 if err != nil { 406 t.Fatalf("unable to create channel update client: %v", err) 407 } 408 409 // We'll launch a goroutine that will be responsible for proxying all 410 // notifications recv'd from the client into the channel below. 411 errChan := make(chan error, 1) 412 quit := make(chan struct{}) 413 chanUpdates := make(chan *lnrpc.ChannelEventUpdate, 20) 414 go func() { 415 defer cancelFunc() 416 for { 417 select { 418 case <-quit: 419 return 420 default: 421 chanUpdate, err := chanUpdateClient.Recv() 422 select { 423 case <-quit: 424 return 425 default: 426 } 427 428 if err == io.EOF { 429 return 430 } else if err != nil { 431 select { 432 case errChan <- err: 433 case <-quit: 434 } 435 return 436 } 437 438 select { 439 case chanUpdates <- chanUpdate: 440 case <-quit: 441 return 442 } 443 } 444 } 445 }() 446 447 return channelSubscription{ 448 updateChan: chanUpdates, 449 errChan: errChan, 450 quit: quit, 451 } 452 } 453 454 // findTxAtHeight gets all of the transactions that a node's wallet has a record 455 // of at the target height, and finds and returns the tx with the target txid, 456 // failing if it is not found. 457 func findTxAtHeight(t *harnessTest, height int64, 458 target string, node *lntest.HarnessNode) *lnrpc.Transaction { 459 460 ctxb := context.Background() 461 ctx, cancel := context.WithTimeout(ctxb, defaultTimeout) 462 defer cancel() 463 464 txns, err := node.LightningClient.GetTransactions( 465 ctx, &lnrpc.GetTransactionsRequest{ 466 StartHeight: int32(height), 467 EndHeight: int32(height), 468 }, 469 ) 470 require.NoError(t.t, err, "could not get transactions") 471 472 for _, tx := range txns.Transactions { 473 if tx.TxHash == target { 474 return tx 475 } 476 } 477 478 return nil 479 } 480 481 // getOutputIndex returns the output index of the given address in the given 482 // transaction. 483 func getOutputIndex(t *harnessTest, miner *lntest.HarnessMiner, 484 txid *chainhash.Hash, addr string) int { 485 486 t.t.Helper() 487 488 // We'll then extract the raw transaction from the mempool in order to 489 // determine the index of the p2tr output. 490 tx, err := miner.Client.GetRawTransaction(testctx.New(t), txid) 491 require.NoError(t.t, err) 492 493 p2trOutputIndex := -1 494 for i, txOut := range tx.MsgTx().TxOut { 495 _, addrs := stdscript.ExtractAddrs( 496 txOut.Version, txOut.PkScript, miner.ActiveNet, 497 ) 498 499 if len(addrs) > 0 && addrs[0].String() == addr { 500 p2trOutputIndex = i 501 } 502 } 503 require.Greater(t.t, p2trOutputIndex, -1) 504 505 return p2trOutputIndex 506 }