github.com/decred/dcrlnd@v0.7.6/invoices/test_utils_test.go (about)

     1  package invoices
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"os"
    10  	"runtime/pprof"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/decred/dcrd/chaincfg/chainhash"
    16  	"github.com/decred/dcrd/chaincfg/v3"
    17  	"github.com/decred/dcrd/dcrec/secp256k1/v4"
    18  	"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
    19  	"github.com/decred/dcrlnd/chainntnfs"
    20  	"github.com/decred/dcrlnd/channeldb"
    21  	"github.com/decred/dcrlnd/clock"
    22  	"github.com/decred/dcrlnd/lntypes"
    23  	"github.com/decred/dcrlnd/lnwire"
    24  	"github.com/decred/dcrlnd/record"
    25  	"github.com/decred/dcrlnd/zpay32"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  type mockPayload struct {
    30  	mpp           *record.MPP
    31  	amp           *record.AMP
    32  	customRecords record.CustomSet
    33  }
    34  
    35  func (p *mockPayload) MultiPath() *record.MPP {
    36  	return p.mpp
    37  }
    38  
    39  func (p *mockPayload) AMPRecord() *record.AMP {
    40  	return p.amp
    41  }
    42  
    43  func (p *mockPayload) CustomRecords() record.CustomSet {
    44  	// This function should always return a map instance, but for mock
    45  	// configuration we do accept nil.
    46  	if p.customRecords == nil {
    47  		return make(record.CustomSet)
    48  	}
    49  
    50  	return p.customRecords
    51  }
    52  
    53  const (
    54  	testHtlcExpiry = uint32(5)
    55  
    56  	testInvoiceCltvDelta = uint32(4)
    57  
    58  	testFinalCltvRejectDelta = int32(4)
    59  
    60  	testCurrentHeight = int32(1)
    61  )
    62  
    63  var (
    64  	testTimeout = 5 * time.Second
    65  
    66  	testTime = time.Date(2018, time.February, 2, 14, 0, 0, 0, time.UTC)
    67  
    68  	testInvoicePreimage = lntypes.Preimage{1}
    69  
    70  	testInvoicePaymentHash = testInvoicePreimage.Hash()
    71  
    72  	testPrivKeyBytes, _ = hex.DecodeString(
    73  		"e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734")
    74  
    75  	testPrivKey = secp256k1.PrivKeyFromBytes(
    76  		testPrivKeyBytes)
    77  
    78  	testInvoiceDescription = "coffee"
    79  
    80  	testInvoiceAmount = lnwire.MilliAtom(100000)
    81  
    82  	testNetParams = chaincfg.MainNetParams()
    83  
    84  	testMessageSigner = zpay32.MessageSigner{
    85  		SignCompact: func(msg []byte) ([]byte, error) {
    86  			hash := chainhash.HashB(msg)
    87  			sig := ecdsa.SignCompact(testPrivKey, hash, true)
    88  			return sig, nil
    89  		},
    90  	}
    91  
    92  	testFeatures = lnwire.NewFeatureVector(
    93  		nil, lnwire.Features,
    94  	)
    95  
    96  	testPayload = &mockPayload{}
    97  
    98  	testInvoiceCreationDate = testTime
    99  )
   100  
   101  var (
   102  	testInvoiceAmt = lnwire.MilliAtom(100000)
   103  	testInvoice    = &channeldb.Invoice{
   104  		Terms: channeldb.ContractTerm{
   105  			PaymentPreimage: &testInvoicePreimage,
   106  			Value:           testInvoiceAmt,
   107  			Expiry:          time.Hour,
   108  			Features:        testFeatures,
   109  		},
   110  		CreationDate: testInvoiceCreationDate,
   111  	}
   112  
   113  	testPayAddrReqInvoice = &channeldb.Invoice{
   114  		Terms: channeldb.ContractTerm{
   115  			PaymentPreimage: &testInvoicePreimage,
   116  			Value:           testInvoiceAmt,
   117  			Expiry:          time.Hour,
   118  			Features: lnwire.NewFeatureVector(
   119  				lnwire.NewRawFeatureVector(
   120  					lnwire.TLVOnionPayloadOptional,
   121  					lnwire.PaymentAddrRequired,
   122  				),
   123  				lnwire.Features,
   124  			),
   125  		},
   126  		CreationDate: testInvoiceCreationDate,
   127  	}
   128  
   129  	testPayAddrOptionalInvoice = &channeldb.Invoice{
   130  		Terms: channeldb.ContractTerm{
   131  			PaymentPreimage: &testInvoicePreimage,
   132  			Value:           testInvoiceAmt,
   133  			Expiry:          time.Hour,
   134  			Features: lnwire.NewFeatureVector(
   135  				lnwire.NewRawFeatureVector(
   136  					lnwire.TLVOnionPayloadOptional,
   137  					lnwire.PaymentAddrOptional,
   138  				),
   139  				lnwire.Features,
   140  			),
   141  		},
   142  		CreationDate: testInvoiceCreationDate,
   143  	}
   144  
   145  	testHodlInvoice = &channeldb.Invoice{
   146  		Terms: channeldb.ContractTerm{
   147  			Value:    testInvoiceAmt,
   148  			Expiry:   time.Hour,
   149  			Features: testFeatures,
   150  		},
   151  		CreationDate: testInvoiceCreationDate,
   152  		HodlInvoice:  true,
   153  	}
   154  )
   155  
   156  func newTestChannelDB(clock clock.Clock) (*channeldb.DB, func(), error) {
   157  	// First, create a temporary directory to be used for the duration of
   158  	// this test.
   159  	tempDirName, err := ioutil.TempDir("", "channeldb")
   160  	if err != nil {
   161  		return nil, nil, err
   162  	}
   163  
   164  	// Next, create channeldb for the first time.
   165  	cdb, err := channeldb.Open(
   166  		tempDirName, channeldb.OptionClock(clock),
   167  	)
   168  	if err != nil {
   169  		os.RemoveAll(tempDirName)
   170  		return nil, nil, err
   171  	}
   172  
   173  	cleanUp := func() {
   174  		cdb.Close()
   175  		os.RemoveAll(tempDirName)
   176  	}
   177  
   178  	return cdb, cleanUp, nil
   179  }
   180  
   181  type testContext struct {
   182  	cdb      *channeldb.DB
   183  	registry *InvoiceRegistry
   184  	notifier *mockChainNotifier
   185  	clock    *clock.TestClock
   186  
   187  	cleanup func()
   188  	t       *testing.T
   189  }
   190  
   191  func newTestContext(t *testing.T) *testContext {
   192  	clock := clock.NewTestClock(testTime)
   193  
   194  	cdb, cleanup, err := newTestChannelDB(clock)
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  
   199  	notifier := newMockNotifier()
   200  
   201  	expiryWatcher := NewInvoiceExpiryWatcher(
   202  		clock, 0, uint32(testCurrentHeight), nil, notifier,
   203  	)
   204  
   205  	// Instantiate and start the invoice ctx.registry.
   206  	cfg := RegistryConfig{
   207  		FinalCltvRejectDelta: testFinalCltvRejectDelta,
   208  		HtlcHoldDuration:     30 * time.Second,
   209  		Clock:                clock,
   210  	}
   211  	registry := NewRegistry(cdb, expiryWatcher, &cfg)
   212  
   213  	err = registry.Start()
   214  	if err != nil {
   215  		cleanup()
   216  		t.Fatal(err)
   217  	}
   218  
   219  	ctx := testContext{
   220  		cdb:      cdb,
   221  		registry: registry,
   222  		notifier: notifier,
   223  		clock:    clock,
   224  		t:        t,
   225  		cleanup: func() {
   226  			if err = registry.Stop(); err != nil {
   227  				t.Fatalf("failed to stop invoice registry: %v", err)
   228  			}
   229  			cleanup()
   230  		},
   231  	}
   232  
   233  	return &ctx
   234  }
   235  
   236  func getCircuitKey(htlcID uint64) channeldb.CircuitKey {
   237  	return channeldb.CircuitKey{
   238  		ChanID: lnwire.ShortChannelID{
   239  			BlockHeight: 1, TxIndex: 2, TxPosition: 3,
   240  		},
   241  		HtlcID: htlcID,
   242  	}
   243  }
   244  
   245  func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
   246  	timestamp time.Time, expiry time.Duration) *channeldb.Invoice {
   247  
   248  	if expiry == 0 {
   249  		expiry = time.Hour
   250  	}
   251  
   252  	var payAddr [32]byte
   253  	if _, err := rand.Read(payAddr[:]); err != nil {
   254  		t.Fatalf("unable to generate payment addr: %v", err)
   255  	}
   256  
   257  	rawInvoice, err := zpay32.NewInvoice(
   258  		testNetParams,
   259  		preimage.Hash(),
   260  		timestamp,
   261  		zpay32.Amount(testInvoiceAmount),
   262  		zpay32.Description(testInvoiceDescription),
   263  		zpay32.Expiry(expiry),
   264  		zpay32.PaymentAddr(payAddr),
   265  	)
   266  	if err != nil {
   267  		t.Fatalf("Error while creating new invoice: %v", err)
   268  	}
   269  
   270  	paymentRequest, err := rawInvoice.Encode(testMessageSigner)
   271  
   272  	if err != nil {
   273  		t.Fatalf("Error while encoding payment request: %v", err)
   274  	}
   275  
   276  	return &channeldb.Invoice{
   277  		Terms: channeldb.ContractTerm{
   278  			PaymentPreimage: &preimage,
   279  			PaymentAddr:     payAddr,
   280  			Value:           testInvoiceAmount,
   281  			Expiry:          expiry,
   282  			Features:        testFeatures,
   283  		},
   284  		PaymentRequest: []byte(paymentRequest),
   285  		CreationDate:   timestamp,
   286  	}
   287  }
   288  
   289  // timeout implements a test level timeout.
   290  func timeout() func() {
   291  	done := make(chan struct{})
   292  
   293  	go func() {
   294  		select {
   295  		case <-time.After(5 * time.Second):
   296  			err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
   297  			if err != nil {
   298  				panic(fmt.Sprintf("error writing to std out after timeout: %v", err))
   299  			}
   300  			panic("timeout")
   301  		case <-done:
   302  		}
   303  	}()
   304  
   305  	return func() {
   306  		close(done)
   307  	}
   308  }
   309  
   310  // invoiceExpiryTestData simply holds generated expired and pending invoices.
   311  type invoiceExpiryTestData struct {
   312  	expiredInvoices map[lntypes.Hash]*channeldb.Invoice
   313  	pendingInvoices map[lntypes.Hash]*channeldb.Invoice
   314  }
   315  
   316  // generateInvoiceExpiryTestData generates the specified number of fake expired
   317  // and pending invoices anchored to the passed now timestamp.
   318  func generateInvoiceExpiryTestData(
   319  	t *testing.T, now time.Time,
   320  	offset, numExpired, numPending int) invoiceExpiryTestData {
   321  
   322  	var testData invoiceExpiryTestData
   323  
   324  	testData.expiredInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
   325  	testData.pendingInvoices = make(map[lntypes.Hash]*channeldb.Invoice)
   326  
   327  	expiredCreationDate := now.Add(-24 * time.Hour)
   328  
   329  	for i := 1; i <= numExpired; i++ {
   330  		var preimage lntypes.Preimage
   331  		binary.BigEndian.PutUint32(preimage[:4], uint32(offset+i))
   332  		expiry := time.Duration((i+offset)%24) * time.Hour
   333  		invoice := newTestInvoice(t, preimage, expiredCreationDate, expiry)
   334  		testData.expiredInvoices[preimage.Hash()] = invoice
   335  	}
   336  
   337  	for i := 1; i <= numPending; i++ {
   338  		var preimage lntypes.Preimage
   339  		binary.BigEndian.PutUint32(preimage[4:], uint32(offset+i))
   340  		expiry := time.Duration((i+offset)%24) * time.Hour
   341  		invoice := newTestInvoice(t, preimage, now, expiry)
   342  		testData.pendingInvoices[preimage.Hash()] = invoice
   343  	}
   344  
   345  	return testData
   346  }
   347  
   348  // checkSettleResolution asserts the resolution is a settle with the correct
   349  // preimage. If successful, the HtlcSettleResolution is returned in case further
   350  // checks are desired.
   351  func checkSettleResolution(t *testing.T, res HtlcResolution,
   352  	expPreimage lntypes.Preimage) *HtlcSettleResolution {
   353  
   354  	t.Helper()
   355  
   356  	settleResolution, ok := res.(*HtlcSettleResolution)
   357  	require.True(t, ok)
   358  	require.Equal(t, expPreimage, settleResolution.Preimage)
   359  
   360  	return settleResolution
   361  }
   362  
   363  // checkFailResolution asserts the resolution is a fail with the correct reason.
   364  // If successful, the HtlcFailResolutionis returned in case further checks are
   365  // desired.
   366  func checkFailResolution(t *testing.T, res HtlcResolution,
   367  	expOutcome FailResolutionResult) *HtlcFailResolution {
   368  
   369  	t.Helper()
   370  	failResolution, ok := res.(*HtlcFailResolution)
   371  	require.True(t, ok)
   372  	require.Equal(t, expOutcome, failResolution.Outcome)
   373  
   374  	return failResolution
   375  }
   376  
   377  type hodlExpiryTest struct {
   378  	hash         lntypes.Hash
   379  	state        channeldb.ContractState
   380  	stateLock    sync.Mutex
   381  	mockNotifier *mockChainNotifier
   382  	mockClock    *clock.TestClock
   383  	cancelChan   chan lntypes.Hash
   384  	watcher      *InvoiceExpiryWatcher
   385  }
   386  
   387  func (h *hodlExpiryTest) setState(state channeldb.ContractState) {
   388  	h.stateLock.Lock()
   389  	defer h.stateLock.Unlock()
   390  
   391  	h.state = state
   392  }
   393  
   394  func (h *hodlExpiryTest) announceBlock(t *testing.T, height uint32) {
   395  	select {
   396  	case h.mockNotifier.blockChan <- &chainntnfs.BlockEpoch{
   397  		Height: int32(height),
   398  	}:
   399  
   400  	case <-time.After(testTimeout):
   401  		t.Fatalf("block %v not consumed", height)
   402  	}
   403  }
   404  
   405  func (h *hodlExpiryTest) assertCanceled(t *testing.T, expected lntypes.Hash) {
   406  	select {
   407  	case actual := <-h.cancelChan:
   408  		require.Equal(t, expected, actual)
   409  
   410  	case <-time.After(testTimeout):
   411  		t.Fatalf("invoice: %v not canceled", h.hash)
   412  	}
   413  }
   414  
   415  // setupHodlExpiry creates a hodl invoice in our expiry watcher and runs an
   416  // arbitrary update function which advances the invoices's state.
   417  func setupHodlExpiry(t *testing.T, creationDate time.Time,
   418  	expiry time.Duration, heightDelta uint32,
   419  	startState channeldb.ContractState,
   420  	startHtlcs []*channeldb.InvoiceHTLC) *hodlExpiryTest {
   421  
   422  	mockNotifier := newMockNotifier()
   423  	mockClock := clock.NewTestClock(testTime)
   424  
   425  	test := &hodlExpiryTest{
   426  		state: startState,
   427  		watcher: NewInvoiceExpiryWatcher(
   428  			mockClock, heightDelta, uint32(testCurrentHeight), nil,
   429  			mockNotifier,
   430  		),
   431  		cancelChan:   make(chan lntypes.Hash),
   432  		mockNotifier: mockNotifier,
   433  		mockClock:    mockClock,
   434  	}
   435  
   436  	// Use an unbuffered channel to block on cancel calls so that the test
   437  	// does not exit before we've processed all the invoices we expect.
   438  	cancelImpl := func(paymentHash lntypes.Hash, force bool) error {
   439  		test.stateLock.Lock()
   440  		currentState := test.state
   441  		test.stateLock.Unlock()
   442  
   443  		if currentState != channeldb.ContractOpen && !force {
   444  			return nil
   445  		}
   446  
   447  		select {
   448  		case test.cancelChan <- paymentHash:
   449  		case <-time.After(testTimeout):
   450  		}
   451  
   452  		return nil
   453  	}
   454  
   455  	require.NoError(t, test.watcher.Start(cancelImpl))
   456  
   457  	// We set preimage and hash so that we can use our existing test
   458  	// helpers. In practice we would only have the hash, but this does not
   459  	// affect what we're testing at all.
   460  	preimage := lntypes.Preimage{1}
   461  	test.hash = preimage.Hash()
   462  
   463  	invoice := newTestInvoice(t, preimage, creationDate, expiry)
   464  	invoice.State = startState
   465  	invoice.HodlInvoice = true
   466  	invoice.Htlcs = make(map[channeldb.CircuitKey]*channeldb.InvoiceHTLC)
   467  
   468  	// If we have any htlcs, add them with unique circult keys.
   469  	for i, htlc := range startHtlcs {
   470  		key := channeldb.CircuitKey{
   471  			HtlcID: uint64(i),
   472  		}
   473  
   474  		invoice.Htlcs[key] = htlc
   475  	}
   476  
   477  	// Create an expiry entry for our invoice in its starting state. This
   478  	// mimics adding invoices to the watcher on start.
   479  	entry := makeInvoiceExpiry(test.hash, invoice)
   480  	test.watcher.AddInvoices(entry)
   481  
   482  	return test
   483  }