github.com/decred/dcrlnd@v0.7.6/invoices/test_utils_test.go (about) 1 package invoices 2 3 import ( 4 "crypto/rand" 5 "encoding/binary" 6 "encoding/hex" 7 "fmt" 8 "io/ioutil" 9 "os" 10 "runtime/pprof" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/decred/dcrd/chaincfg/chainhash" 16 "github.com/decred/dcrd/chaincfg/v3" 17 "github.com/decred/dcrd/dcrec/secp256k1/v4" 18 "github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa" 19 "github.com/decred/dcrlnd/chainntnfs" 20 "github.com/decred/dcrlnd/channeldb" 21 "github.com/decred/dcrlnd/clock" 22 "github.com/decred/dcrlnd/lntypes" 23 "github.com/decred/dcrlnd/lnwire" 24 "github.com/decred/dcrlnd/record" 25 "github.com/decred/dcrlnd/zpay32" 26 "github.com/stretchr/testify/require" 27 ) 28 29 type mockPayload struct { 30 mpp *record.MPP 31 amp *record.AMP 32 customRecords record.CustomSet 33 } 34 35 func (p *mockPayload) MultiPath() *record.MPP { 36 return p.mpp 37 } 38 39 func (p *mockPayload) AMPRecord() *record.AMP { 40 return p.amp 41 } 42 43 func (p *mockPayload) CustomRecords() record.CustomSet { 44 // This function should always return a map instance, but for mock 45 // configuration we do accept nil. 46 if p.customRecords == nil { 47 return make(record.CustomSet) 48 } 49 50 return p.customRecords 51 } 52 53 const ( 54 testHtlcExpiry = uint32(5) 55 56 testInvoiceCltvDelta = uint32(4) 57 58 testFinalCltvRejectDelta = int32(4) 59 60 testCurrentHeight = int32(1) 61 ) 62 63 var ( 64 testTimeout = 5 * time.Second 65 66 testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC) 67 68 testInvoicePreimage = lntypes.Preimage{1} 69 70 testInvoicePaymentHash = testInvoicePreimage.Hash() 71 72 testPrivKeyBytes, _ = hex.DecodeString( 73 "e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734") 74 75 testPrivKey = secp256k1.PrivKeyFromBytes( 76 testPrivKeyBytes) 77 78 testInvoiceDescription = "coffee" 79 80 testInvoiceAmount = lnwire.MilliAtom(100000) 81 82 testNetParams = chaincfg.MainNetParams() 83 84 testMessageSigner = zpay32.MessageSigner{ 85 SignCompact: func(msg []byte) ([]byte, error) { 86 hash := chainhash.HashB(msg) 87 sig := ecdsa.SignCompact(testPrivKey, hash, true) 88 return sig, nil 89 }, 90 } 91 92 testFeatures = lnwire.NewFeatureVector( 93 nil, lnwire.Features, 94 ) 95 96 testPayload = &mockPayload{} 97 98 testInvoiceCreationDate = testTime 99 ) 100 101 var ( 102 testInvoiceAmt = lnwire.MilliAtom(100000) 103 testInvoice = &channeldb.Invoice{ 104 Terms: channeldb.ContractTerm{ 105 PaymentPreimage: &testInvoicePreimage, 106 Value: testInvoiceAmt, 107 Expiry: time.Hour, 108 Features: testFeatures, 109 }, 110 CreationDate: testInvoiceCreationDate, 111 } 112 113 testPayAddrReqInvoice = &channeldb.Invoice{ 114 Terms: channeldb.ContractTerm{ 115 PaymentPreimage: &testInvoicePreimage, 116 Value: testInvoiceAmt, 117 Expiry: time.Hour, 118 Features: lnwire.NewFeatureVector( 119 lnwire.NewRawFeatureVector( 120 lnwire.TLVOnionPayloadOptional, 121 lnwire.PaymentAddrRequired, 122 ), 123 lnwire.Features, 124 ), 125 }, 126 CreationDate: testInvoiceCreationDate, 127 } 128 129 testPayAddrOptionalInvoice = &channeldb.Invoice{ 130 Terms: channeldb.ContractTerm{ 131 PaymentPreimage: &testInvoicePreimage, 132 Value: testInvoiceAmt, 133 Expiry: time.Hour, 134 Features: lnwire.NewFeatureVector( 135 lnwire.NewRawFeatureVector( 136 lnwire.TLVOnionPayloadOptional, 137 lnwire.PaymentAddrOptional, 138 ), 139 lnwire.Features, 140 ), 141 }, 142 CreationDate: testInvoiceCreationDate, 143 } 144 145 testHodlInvoice = &channeldb.Invoice{ 146 Terms: channeldb.ContractTerm{ 147 Value: testInvoiceAmt, 148 Expiry: time.Hour, 149 Features: testFeatures, 150 }, 151 CreationDate: testInvoiceCreationDate, 152 HodlInvoice: true, 153 } 154 ) 155 156 func newTestChannelDB(clock clock.Clock) (*channeldb.DB, func(), error) { 157 // First, create a temporary directory to be used for the duration of 158 // this test. 159 tempDirName, err := ioutil.TempDir("", "channeldb") 160 if err != nil { 161 return nil, nil, err 162 } 163 164 // Next, create channeldb for the first time. 165 cdb, err := channeldb.Open( 166 tempDirName, channeldb.OptionClock(clock), 167 ) 168 if err != nil { 169 os.RemoveAll(tempDirName) 170 return nil, nil, err 171 } 172 173 cleanUp := func() { 174 cdb.Close() 175 os.RemoveAll(tempDirName) 176 } 177 178 return cdb, cleanUp, nil 179 } 180 181 type testContext struct { 182 cdb *channeldb.DB 183 registry *InvoiceRegistry 184 notifier *mockChainNotifier 185 clock *clock.TestClock 186 187 cleanup func() 188 t *testing.T 189 } 190 191 func newTestContext(t *testing.T) *testContext { 192 clock := clock.NewTestClock(testTime) 193 194 cdb, cleanup, err := newTestChannelDB(clock) 195 if err != nil { 196 t.Fatal(err) 197 } 198 199 notifier := newMockNotifier() 200 201 expiryWatcher := NewInvoiceExpiryWatcher( 202 clock, 0, uint32(testCurrentHeight), nil, notifier, 203 ) 204 205 // Instantiate and start the invoice ctx.registry. 206 cfg := RegistryConfig{ 207 FinalCltvRejectDelta: testFinalCltvRejectDelta, 208 HtlcHoldDuration: 30 * time.Second, 209 Clock: clock, 210 } 211 registry := NewRegistry(cdb, expiryWatcher, &cfg) 212 213 err = registry.Start() 214 if err != nil { 215 cleanup() 216 t.Fatal(err) 217 } 218 219 ctx := testContext{ 220 cdb: cdb, 221 registry: registry, 222 notifier: notifier, 223 clock: clock, 224 t: t, 225 cleanup: func() { 226 if err = registry.Stop(); err != nil { 227 t.Fatalf("failed to stop invoice registry: %v", err) 228 } 229 cleanup() 230 }, 231 } 232 233 return &ctx 234 } 235 236 func getCircuitKey(htlcID uint64) channeldb.CircuitKey { 237 return channeldb.CircuitKey{ 238 ChanID: lnwire.ShortChannelID{ 239 BlockHeight: 1, TxIndex: 2, TxPosition: 3, 240 }, 241 HtlcID: htlcID, 242 } 243 } 244 245 func newTestInvoice(t *testing.T, preimage lntypes.Preimage, 246 timestamp time.Time, expiry time.Duration) *channeldb.Invoice { 247 248 if expiry == 0 { 249 expiry = time.Hour 250 } 251 252 var payAddr [32]byte 253 if _, err := rand.Read(payAddr[:]); err != nil { 254 t.Fatalf("unable to generate payment addr: %v", err) 255 } 256 257 rawInvoice, err := zpay32.NewInvoice( 258 testNetParams, 259 preimage.Hash(), 260 timestamp, 261 zpay32.Amount(testInvoiceAmount), 262 zpay32.Description(testInvoiceDescription), 263 zpay32.Expiry(expiry), 264 zpay32.PaymentAddr(payAddr), 265 ) 266 if err != nil { 267 t.Fatalf("Error while creating new invoice: %v", err) 268 } 269 270 paymentRequest, err := rawInvoice.Encode(testMessageSigner) 271 272 if err != nil { 273 t.Fatalf("Error while encoding payment request: %v", err) 274 } 275 276 return &channeldb.Invoice{ 277 Terms: channeldb.ContractTerm{ 278 PaymentPreimage: &preimage, 279 PaymentAddr: payAddr, 280 Value: testInvoiceAmount, 281 Expiry: expiry, 282 Features: testFeatures, 283 }, 284 PaymentRequest: []byte(paymentRequest), 285 CreationDate: timestamp, 286 } 287 } 288 289 // timeout implements a test level timeout. 290 func timeout() func() { 291 done := make(chan struct{}) 292 293 go func() { 294 select { 295 case <-time.After(5 * time.Second): 296 err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) 297 if err != nil { 298 panic(fmt.Sprintf("error writing to std out after timeout: %v", err)) 299 } 300 panic("timeout") 301 case <-done: 302 } 303 }() 304 305 return func() { 306 close(done) 307 } 308 } 309 310 // invoiceExpiryTestData simply holds generated expired and pending invoices. 311 type invoiceExpiryTestData struct { 312 expiredInvoices map[lntypes.Hash]*channeldb.Invoice 313 pendingInvoices map[lntypes.Hash]*channeldb.Invoice 314 } 315 316 // generateInvoiceExpiryTestData generates the specified number of fake expired 317 // and pending invoices anchored to the passed now timestamp. 318 func generateInvoiceExpiryTestData( 319 t *testing.T, now time.Time, 320 offset, numExpired, numPending int) invoiceExpiryTestData { 321 322 var testData invoiceExpiryTestData 323 324 testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice) 325 testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice) 326 327 expiredCreationDate := now.Add(-24 * time.Hour) 328 329 for i := 1; i <= numExpired; i++ { 330 var preimage lntypes.Preimage 331 binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i)) 332 expiry := time.Duration((i+offset)%24) * time.Hour 333 invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry) 334 testData.expiredInvoices[preimage.Hash()] = invoice 335 } 336 337 for i := 1; i <= numPending; i++ { 338 var preimage lntypes.Preimage 339 binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i)) 340 expiry := time.Duration((i+offset)%24) * time.Hour 341 invoice := newTestInvoice(t, preimage, now, expiry) 342 testData.pendingInvoices[preimage.Hash()] = invoice 343 } 344 345 return testData 346 } 347 348 // checkSettleResolution asserts the resolution is a settle with the correct 349 // preimage. If successful, the HtlcSettleResolution is returned in case further 350 // checks are desired. 351 func checkSettleResolution(t *testing.T, res HtlcResolution, 352 expPreimage lntypes.Preimage) *HtlcSettleResolution { 353 354 t.Helper() 355 356 settleResolution, ok := res.(*HtlcSettleResolution) 357 require.True(t, ok) 358 require.Equal(t, expPreimage, settleResolution.Preimage) 359 360 return settleResolution 361 } 362 363 // checkFailResolution asserts the resolution is a fail with the correct reason. 364 // If successful, the HtlcFailResolutionis returned in case further checks are 365 // desired. 366 func checkFailResolution(t *testing.T, res HtlcResolution, 367 expOutcome FailResolutionResult) *HtlcFailResolution { 368 369 t.Helper() 370 failResolution, ok := res.(*HtlcFailResolution) 371 require.True(t, ok) 372 require.Equal(t, expOutcome, failResolution.Outcome) 373 374 return failResolution 375 } 376 377 type hodlExpiryTest struct { 378 hash lntypes.Hash 379 state channeldb.ContractState 380 stateLock sync.Mutex 381 mockNotifier *mockChainNotifier 382 mockClock *clock.TestClock 383 cancelChan chan lntypes.Hash 384 watcher *InvoiceExpiryWatcher 385 } 386 387 func (h *hodlExpiryTest) setState(state channeldb.ContractState) { 388 h.stateLock.Lock() 389 defer h.stateLock.Unlock() 390 391 h.state = state 392 } 393 394 func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) { 395 select { 396 case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{ 397 Height: int32(height), 398 }: 399 400 case <-time.After(testTimeout): 401 t.Fatalf("block %v not consumed", height) 402 } 403 } 404 405 func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) { 406 select { 407 case actual := <-h.cancelChan: 408 require.Equal(t, expected, actual) 409 410 case <-time.After(testTimeout): 411 t.Fatalf("invoice: %v not canceled", h.hash) 412 } 413 } 414 415 // setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an 416 // arbitrary update function which advances the invoices's state. 417 func setupHodlExpiry(t *testing.T, creationDate time.Time, 418 expiry time.Duration, heightDelta uint32, 419 startState channeldb.ContractState, 420 startHtlcs []*channeldb.InvoiceHTLC) *hodlExpiryTest { 421 422 mockNotifier := newMockNotifier() 423 mockClock := clock.NewTestClock(testTime) 424 425 test := &hodlExpiryTest{ 426 state: startState, 427 watcher: NewInvoiceExpiryWatcher( 428 mockClock, heightDelta, uint32(testCurrentHeight), nil, 429 mockNotifier, 430 ), 431 cancelChan: make(chan lntypes.Hash), 432 mockNotifier: mockNotifier, 433 mockClock: mockClock, 434 } 435 436 // Use an unbuffered channel to block on cancel calls so that the test 437 // does not exit before we've processed all the invoices we expect. 438 cancelImpl := func(paymentHash lntypes.Hash, force bool) error { 439 test.stateLock.Lock() 440 currentState := test.state 441 test.stateLock.Unlock() 442 443 if currentState != channeldb.ContractOpen && !force { 444 return nil 445 } 446 447 select { 448 case test.cancelChan <- paymentHash: 449 case <-time.After(testTimeout): 450 } 451 452 return nil 453 } 454 455 require.NoError(t, test.watcher.Start(cancelImpl)) 456 457 // We set preimage and hash so that we can use our existing test 458 // helpers. In practice we would only have the hash, but this does not 459 // affect what we're testing at all. 460 preimage := lntypes.Preimage{1} 461 test.hash = preimage.Hash() 462 463 invoice := newTestInvoice(t, preimage, creationDate, expiry) 464 invoice.State = startState 465 invoice.HodlInvoice = true 466 invoice.Htlcs = make(map[channeldb.CircuitKey]*channeldb.InvoiceHTLC) 467 468 // If we have any htlcs, add them with unique circult keys. 469 for i, htlc := range startHtlcs { 470 key := channeldb.CircuitKey{ 471 HtlcID: uint64(i), 472 } 473 474 invoice.Htlcs[key] = htlc 475 } 476 477 // Create an expiry entry for our invoice in its starting state. This 478 // mimics adding invoices to the watcher on start. 479 entry := makeInvoiceExpiry(test.hash, invoice) 480 test.watcher.AddInvoices(entry) 481 482 return test 483 }