github.com/decred/dcrlnd@v0.7.6/invoices/invoice_expiry_watcher_test.go (about)

     1  package invoices
     2  
     3  import (
     4  	"sync"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/decred/dcrlnd/chainntnfs"
     9  	"github.com/decred/dcrlnd/channeldb"
    10  	"github.com/decred/dcrlnd/clock"
    11  	"github.com/decred/dcrlnd/lntypes"
    12  )
    13  
    14  // invoiceExpiryWatcherTest holds a test fixture and implements checks
    15  // for InvoiceExpiryWatcher tests.
    16  type invoiceExpiryWatcherTest struct {
    17  	t                *testing.T
    18  	wg               sync.WaitGroup
    19  	watcher          *InvoiceExpiryWatcher
    20  	testData         invoiceExpiryTestData
    21  	canceledInvoices []lntypes.Hash
    22  }
    23  
    24  type mockChainNotifier struct {
    25  	chainntnfs.ChainNotifier
    26  
    27  	blockChan chan *chainntnfs.BlockEpoch
    28  }
    29  
    30  func newMockNotifier() *mockChainNotifier {
    31  	return &mockChainNotifier{
    32  		blockChan: make(chan *chainntnfs.BlockEpoch),
    33  	}
    34  }
    35  
    36  // RegisterBlockEpochNtfn mocks a block epoch notification, using the mock's
    37  // block channel to deliver blocks to the client.
    38  func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) (
    39  	*chainntnfs.BlockEpochEvent, error) {
    40  
    41  	return &chainntnfs.BlockEpochEvent{
    42  		Epochs: m.blockChan,
    43  		Cancel: func() {},
    44  	}, nil
    45  }
    46  
    47  // newInvoiceExpiryWatcherTest creates a new InvoiceExpiryWatcher test fixture
    48  // and sets up the test environment.
    49  func newInvoiceExpiryWatcherTest(t *testing.T, now time.Time,
    50  	numExpiredInvoices, numPendingInvoices int) *invoiceExpiryWatcherTest {
    51  
    52  	mockNotifier := newMockNotifier()
    53  	test := &invoiceExpiryWatcherTest{
    54  		watcher: NewInvoiceExpiryWatcher(
    55  			clock.NewTestClock(testTime), 0,
    56  			uint32(testCurrentHeight), nil, mockNotifier,
    57  		),
    58  		testData: generateInvoiceExpiryTestData(
    59  			t, now, 0, numExpiredInvoices, numPendingInvoices,
    60  		),
    61  	}
    62  
    63  	test.wg.Add(numExpiredInvoices)
    64  
    65  	err := test.watcher.Start(func(paymentHash lntypes.Hash,
    66  		force bool) error {
    67  
    68  		test.canceledInvoices = append(
    69  			test.canceledInvoices, paymentHash,
    70  		)
    71  		test.wg.Done()
    72  		return nil
    73  	})
    74  
    75  	if err != nil {
    76  		t.Fatalf("cannot start InvoiceExpiryWatcher: %v", err)
    77  	}
    78  
    79  	return test
    80  }
    81  
    82  func (t *invoiceExpiryWatcherTest) waitForFinish(timeout time.Duration) {
    83  	done := make(chan struct{})
    84  
    85  	// Wait for all cancels.
    86  	go func() {
    87  		t.wg.Wait()
    88  		close(done)
    89  	}()
    90  
    91  	select {
    92  	case <-done:
    93  	case <-time.After(timeout):
    94  		t.t.Fatalf("test timeout")
    95  	}
    96  }
    97  
    98  func (t *invoiceExpiryWatcherTest) checkExpectations() {
    99  	// Check that invoices that got canceled during the test are the ones
   100  	// that expired.
   101  	if len(t.canceledInvoices) != len(t.testData.expiredInvoices) {
   102  		t.t.Fatalf("expected %v cancellations, got %v",
   103  			len(t.testData.expiredInvoices),
   104  			len(t.canceledInvoices))
   105  	}
   106  
   107  	for i := range t.canceledInvoices {
   108  		if _, ok := t.testData.expiredInvoices[t.canceledInvoices[i]]; !ok {
   109  			t.t.Fatalf("wrong invoice canceled")
   110  		}
   111  	}
   112  }
   113  
   114  // Tests that InvoiceExpiryWatcher can be started and stopped.
   115  func TestInvoiceExpiryWatcherStartStop(t *testing.T) {
   116  	watcher := NewInvoiceExpiryWatcher(
   117  		clock.NewTestClock(testTime), 0, uint32(testCurrentHeight), nil,
   118  		newMockNotifier(),
   119  	)
   120  	cancel := func(lntypes.Hash, bool) error {
   121  		t.Fatalf("unexpected call")
   122  		return nil
   123  	}
   124  
   125  	if err := watcher.Start(cancel); err != nil {
   126  		t.Fatalf("unexpected error upon start: %v", err)
   127  	}
   128  
   129  	if err := watcher.Start(cancel); err == nil {
   130  		t.Fatalf("expected error upon second start")
   131  	}
   132  
   133  	watcher.Stop()
   134  
   135  	if err := watcher.Start(cancel); err != nil {
   136  		t.Fatalf("unexpected error upon start: %v", err)
   137  	}
   138  }
   139  
   140  // Tests that no invoices will expire from an empty InvoiceExpiryWatcher.
   141  func TestInvoiceExpiryWithNoInvoices(t *testing.T) {
   142  	t.Parallel()
   143  
   144  	test := newInvoiceExpiryWatcherTest(t, testTime, 0, 0)
   145  
   146  	test.waitForFinish(testTimeout)
   147  	test.watcher.Stop()
   148  	test.checkExpectations()
   149  }
   150  
   151  // Tests that if all add invoices are expired, then all invoices
   152  // will be canceled.
   153  func TestInvoiceExpiryWithOnlyExpiredInvoices(t *testing.T) {
   154  	t.Parallel()
   155  
   156  	test := newInvoiceExpiryWatcherTest(t, testTime, 0, 5)
   157  
   158  	for paymentHash, invoice := range test.testData.pendingInvoices {
   159  		test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice))
   160  	}
   161  
   162  	test.waitForFinish(testTimeout)
   163  	test.watcher.Stop()
   164  	test.checkExpectations()
   165  }
   166  
   167  // Tests that if some invoices are expired, then those invoices
   168  // will be canceled.
   169  func TestInvoiceExpiryWithPendingAndExpiredInvoices(t *testing.T) {
   170  	t.Parallel()
   171  
   172  	test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
   173  
   174  	for paymentHash, invoice := range test.testData.expiredInvoices {
   175  		test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice))
   176  	}
   177  
   178  	for paymentHash, invoice := range test.testData.pendingInvoices {
   179  		test.watcher.AddInvoices(makeInvoiceExpiry(paymentHash, invoice))
   180  	}
   181  
   182  	test.waitForFinish(testTimeout)
   183  	test.watcher.Stop()
   184  	test.checkExpectations()
   185  }
   186  
   187  // Tests adding multiple invoices at once.
   188  func TestInvoiceExpiryWhenAddingMultipleInvoices(t *testing.T) {
   189  	t.Parallel()
   190  
   191  	test := newInvoiceExpiryWatcherTest(t, testTime, 5, 5)
   192  	var invoices []invoiceExpiry
   193  
   194  	for hash, invoice := range test.testData.expiredInvoices {
   195  		invoices = append(invoices, makeInvoiceExpiry(hash, invoice))
   196  	}
   197  
   198  	for hash, invoice := range test.testData.pendingInvoices {
   199  		invoices = append(invoices, makeInvoiceExpiry(hash, invoice))
   200  	}
   201  
   202  	test.watcher.AddInvoices(invoices...)
   203  	test.waitForFinish(testTimeout)
   204  	test.watcher.Stop()
   205  	test.checkExpectations()
   206  }
   207  
   208  // TestExpiredHodlInv tests expiration of an already-expired hodl invoice
   209  // which has no htlcs.
   210  func TestExpiredHodlInv(t *testing.T) {
   211  	t.Parallel()
   212  
   213  	creationDate := testTime.Add(time.Hour * -24)
   214  	expiry := time.Hour
   215  
   216  	test := setupHodlExpiry(
   217  		t, creationDate, expiry, 0, channeldb.ContractOpen, nil,
   218  	)
   219  
   220  	test.assertCanceled(t, test.hash)
   221  	test.watcher.Stop()
   222  }
   223  
   224  // TestAcceptedHodlNotExpired tests that hodl invoices which are in an accepted
   225  // state are not expired once their time-based expiry elapses, using a regular
   226  // invoice that expires at the same time as a control to ensure that invoices
   227  // with that timestamp would otherwise be expired.
   228  func TestAcceptedHodlNotExpired(t *testing.T) {
   229  	t.Parallel()
   230  
   231  	creationDate := testTime
   232  	expiry := time.Hour
   233  
   234  	test := setupHodlExpiry(
   235  		t, creationDate, expiry, 0, channeldb.ContractAccepted, nil,
   236  	)
   237  	defer test.watcher.Stop()
   238  
   239  	// Add another invoice that will expire at our expiry time as a control
   240  	// value.
   241  	tsExpires := &invoiceExpiryTs{
   242  		PaymentHash: lntypes.Hash{1, 2, 3},
   243  		Expiry:      creationDate.Add(expiry),
   244  		Keysend:     true,
   245  	}
   246  	test.watcher.AddInvoices(tsExpires)
   247  
   248  	test.mockClock.SetTime(creationDate.Add(expiry + 1))
   249  
   250  	// Assert that only the ts expiry invoice is expired.
   251  	test.assertCanceled(t, tsExpires.PaymentHash)
   252  }
   253  
   254  // TestHeightAlreadyExpired tests the case where we add an invoice with htlcs
   255  // that have already expired to the expiry watcher.
   256  func TestHeightAlreadyExpired(t *testing.T) {
   257  	t.Parallel()
   258  
   259  	expiredHtlc := []*channeldb.InvoiceHTLC{
   260  		{
   261  			State:  channeldb.HtlcStateAccepted,
   262  			Expiry: uint32(testCurrentHeight),
   263  		},
   264  	}
   265  
   266  	test := setupHodlExpiry(
   267  		t, testTime, time.Hour, 0, channeldb.ContractAccepted,
   268  		expiredHtlc,
   269  	)
   270  	defer test.watcher.Stop()
   271  
   272  	test.assertCanceled(t, test.hash)
   273  }
   274  
   275  // TestExpiryHeightArrives tests the case where we add a hodl invoice to the
   276  // expiry watcher when it has no htlcs, htlcs are added and then they finally
   277  // expire. We use a non-zero delta for this test to check that we expire with
   278  // sufficient buffer.
   279  func TestExpiryHeightArrives(t *testing.T) {
   280  	var (
   281  		creationDate        = testTime
   282  		expiry              = time.Hour * 2
   283  		delta        uint32 = 1
   284  	)
   285  
   286  	// Start out with a hodl invoice that is open, and has no htlcs.
   287  	test := setupHodlExpiry(
   288  		t, creationDate, expiry, delta, channeldb.ContractOpen, nil,
   289  	)
   290  	defer test.watcher.Stop()
   291  
   292  	htlc1 := uint32(testCurrentHeight + 10)
   293  	expiry1 := makeHeightExpiry(test.hash, htlc1)
   294  
   295  	// Add htlcs to our invoice and progress its state to accepted.
   296  	test.watcher.AddInvoices(expiry1)
   297  	test.setState(channeldb.ContractAccepted)
   298  
   299  	// Progress time so that our expiry has elapsed. We no longer expect
   300  	// this invoice to be canceled because it has been accepted.
   301  	test.mockClock.SetTime(creationDate.Add(expiry))
   302  
   303  	// Tick our mock block subscription with the next block, we don't
   304  	// expect anything to happen.
   305  	currentHeight := uint32(testCurrentHeight + 1)
   306  	test.announceBlock(t, currentHeight)
   307  
   308  	// Now, we add another htlc to the invoice. This one has a lower expiry
   309  	// height than our current ones.
   310  	htlc2 := currentHeight + 5
   311  	expiry2 := makeHeightExpiry(test.hash, htlc2)
   312  	test.watcher.AddInvoices(expiry2)
   313  
   314  	// Announce our lowest htlc expiry block minus our delta, the invoice
   315  	// should be expired now.
   316  	test.announceBlock(t, htlc2-delta)
   317  	test.assertCanceled(t, test.hash)
   318  }