github.com/decred/dcrlnd@v0.7.6/invoices/invoice_expiry_watcher.go (about)

     1  package invoices
     2  
import (
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/decred/dcrd/chaincfg/chainhash"
	"github.com/decred/dcrlnd/chainntnfs"
	"github.com/decred/dcrlnd/channeldb"
	"github.com/decred/dcrlnd/clock"
	"github.com/decred/dcrlnd/lntypes"
	"github.com/decred/dcrlnd/queue"
	"github.com/decred/dcrlnd/zpay32"
)
    16  
    17  // invoiceExpiry is a vanity interface for different invoice expiry types
    18  // which implement the priority queue item interface, used to improve code
    19  // readability.
    20  type invoiceExpiry queue.PriorityQueueItem
    21  
    22  // Compile time assertion that invoiceExpiryTs implements invoiceExpiry.
    23  var _ invoiceExpiry = (*invoiceExpiryTs)(nil)
    24  
    25  // invoiceExpiryTs holds and invoice's payment hash and its expiry. This
    26  // is used to order invoices by their expiry time for cancellation.
    27  type invoiceExpiryTs struct {
    28  	PaymentHash lntypes.Hash
    29  	Expiry      time.Time
    30  	Keysend     bool
    31  }
    32  
    33  // Less implements PriorityQueueItem.Less such that the top item in the
    34  // priorty queue will be the one that expires next.
    35  func (e invoiceExpiryTs) Less(other queue.PriorityQueueItem) bool {
    36  	return e.Expiry.Before(other.(*invoiceExpiryTs).Expiry)
    37  }
    38  
    39  // Compile time assertion that invoiceExpiryHeight implements invoiceExpiry.
    40  var _ invoiceExpiry = (*invoiceExpiryHeight)(nil)
    41  
    42  // invoiceExpiryHeight holds information about an invoice which can be used to
    43  // cancel it based on its expiry height.
    44  type invoiceExpiryHeight struct {
    45  	paymentHash  lntypes.Hash
    46  	expiryHeight uint32
    47  }
    48  
    49  // Less implements PriorityQueueItem.Less such that the top item in the
    50  // priority queue is the lowest block height.
    51  func (b invoiceExpiryHeight) Less(other queue.PriorityQueueItem) bool {
    52  	return b.expiryHeight < other.(*invoiceExpiryHeight).expiryHeight
    53  }
    54  
    55  // expired returns a boolean that indicates whether this entry has expired,
    56  // taking our expiry delta into account.
    57  func (b invoiceExpiryHeight) expired(currentHeight, delta uint32) bool {
    58  	return currentHeight+delta >= b.expiryHeight
    59  }
    60  
// InvoiceExpiryWatcher handles automatic cancellation of expired invoices.
// Invoices to watch are handed to it through AddInvoices and end up in one of
// its two watching queues: a timestamp based queue for open invoices and a
// block height based queue for invoices with accepted htlcs. When a watched
// invoice expires, it is removed from its queue and cancelled through the
// cancelInvoice callback supplied to Start (presumably backed by the
// InvoiceRegistry — confirm at the call site).
type InvoiceExpiryWatcher struct {
	// Mutex serializes Start and Stop, which read and write started.
	sync.Mutex

	// started is set by Start and cleared by Stop.
	started bool

	// clock is the clock implementation that InvoiceExpiryWatcher uses.
	// It is useful for testing.
	clock clock.Clock

	// notifier provides us with block height updates.
	notifier chainntnfs.ChainNotifier

	// blockExpiryDelta is the number of blocks before an htlc's expiry at
	// which we expire the invoice based on expiry height. We use a delta
	// so that the invoice is cancelled some blocks before the htlc times
	// out, to prevent force closes.
	blockExpiryDelta uint32

	// currentHeight is the current block height. It is seeded by the
	// constructor and updated by the main loop as new blocks arrive.
	currentHeight uint32

	// currentHash is the block hash for our current height.
	currentHash *chainhash.Hash

	// cancelInvoice is the callback supplied to Start that cancels an
	// expired invoice by payment hash; its bool argument requests a
	// forced cancellation.
	cancelInvoice func(lntypes.Hash, bool) error

	// timestampExpiryQueue holds *invoiceExpiryTs items and is used to
	// find the next invoice to expire by time.
	timestampExpiryQueue queue.PriorityQueue

	// blockExpiryQueue holds *invoiceExpiryHeight items and is used to
	// find the next invoice to expire based on block height. Only
	// invoices with active htlcs are added to this queue, because they
	// require manual cancellation when the htlc is going to time out.
	// Items in this queue may already be in the timestampExpiryQueue;
	// this is ok because they will not be expired based on timestamp if
	// they have active htlcs.
	blockExpiryQueue queue.PriorityQueue

	// newInvoices is used to hand newly added invoices to the main loop,
	// waking it up.
	newInvoices chan []invoiceExpiry

	// wg tracks the main loop goroutine so that Stop can wait for it to
	// exit.
	wg sync.WaitGroup

	// quit is closed by Stop to signal InvoiceExpiryWatcher to stop.
	quit chan struct{}
}
   116  
   117  // NewInvoiceExpiryWatcher creates a new InvoiceExpiryWatcher instance.
   118  func NewInvoiceExpiryWatcher(clock clock.Clock,
   119  	expiryDelta, startHeight uint32, startHash *chainhash.Hash,
   120  	notifier chainntnfs.ChainNotifier) *InvoiceExpiryWatcher {
   121  
   122  	return &InvoiceExpiryWatcher{
   123  		clock:            clock,
   124  		notifier:         notifier,
   125  		blockExpiryDelta: expiryDelta,
   126  		currentHeight:    startHeight,
   127  		currentHash:      startHash,
   128  		newInvoices:      make(chan []invoiceExpiry),
   129  		quit:             make(chan struct{}),
   130  	}
   131  }
   132  
   133  // Start starts the the subscription handler and the main loop. Start() will
   134  // return with error if InvoiceExpiryWatcher is already started. Start()
   135  // expects a cancellation function passed that will be use to cancel expired
   136  // invoices by their payment hash.
   137  func (ew *InvoiceExpiryWatcher) Start(
   138  	cancelInvoice func(lntypes.Hash, bool) error) error {
   139  
   140  	ew.Lock()
   141  	defer ew.Unlock()
   142  
   143  	if ew.started {
   144  		return fmt.Errorf("InvoiceExpiryWatcher already started")
   145  	}
   146  
   147  	ew.started = true
   148  	ew.cancelInvoice = cancelInvoice
   149  
   150  	ntfn, err := ew.notifier.RegisterBlockEpochNtfn(&chainntnfs.BlockEpoch{
   151  		Height: int32(ew.currentHeight),
   152  		Hash:   ew.currentHash,
   153  	})
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	ew.wg.Add(1)
   159  	go ew.mainLoop(ntfn)
   160  
   161  	return nil
   162  }
   163  
   164  // Stop quits the expiry handler loop and waits for InvoiceExpiryWatcher to
   165  // fully stop.
   166  func (ew *InvoiceExpiryWatcher) Stop() {
   167  	ew.Lock()
   168  	defer ew.Unlock()
   169  
   170  	if ew.started {
   171  		// Signal subscriptionHandler to quit and wait for it to return.
   172  		close(ew.quit)
   173  		ew.wg.Wait()
   174  		ew.started = false
   175  	}
   176  }
   177  
   178  // makeInvoiceExpiry checks if the passed invoice may be canceled and calculates
   179  // the expiry time and creates a slimmer invoiceExpiry implementation.
   180  func makeInvoiceExpiry(paymentHash lntypes.Hash,
   181  	invoice *channeldb.Invoice) invoiceExpiry {
   182  
   183  	switch invoice.State {
   184  	// If we have an open invoice with no htlcs, we want to expire the
   185  	// invoice based on timestamp
   186  	case channeldb.ContractOpen:
   187  		return makeTimestampExpiry(paymentHash, invoice)
   188  
   189  	// If an invoice has active htlcs, we want to expire it based on block
   190  	// height. We only do this for hodl invoices, since regular invoices
   191  	// should resolve themselves automatically.
   192  	case channeldb.ContractAccepted:
   193  		if !invoice.HodlInvoice {
   194  			log.Debugf("Invoice in accepted state not added to "+
   195  				"expiry watcher: %v", paymentHash)
   196  
   197  			return nil
   198  		}
   199  
   200  		var minHeight uint32
   201  		for _, htlc := range invoice.Htlcs {
   202  			// We only care about accepted htlcs, since they will
   203  			// trigger force-closes.
   204  			if htlc.State != channeldb.HtlcStateAccepted {
   205  				continue
   206  			}
   207  
   208  			if minHeight == 0 || htlc.Expiry < minHeight {
   209  				minHeight = htlc.Expiry
   210  			}
   211  		}
   212  
   213  		return makeHeightExpiry(paymentHash, minHeight)
   214  
   215  	default:
   216  		log.Debugf("Invoice not added to expiry watcher: %v",
   217  			paymentHash)
   218  
   219  		return nil
   220  	}
   221  }
   222  
   223  // makeTimestampExpiry creates a timestamp-based expiry entry.
   224  func makeTimestampExpiry(paymentHash lntypes.Hash,
   225  	invoice *channeldb.Invoice) *invoiceExpiryTs {
   226  
   227  	if invoice.State != channeldb.ContractOpen {
   228  		return nil
   229  	}
   230  
   231  	realExpiry := invoice.Terms.Expiry
   232  	if realExpiry == 0 {
   233  		realExpiry = zpay32.DefaultInvoiceExpiry
   234  	}
   235  
   236  	expiry := invoice.CreationDate.Add(realExpiry)
   237  	return &invoiceExpiryTs{
   238  		PaymentHash: paymentHash,
   239  		Expiry:      expiry,
   240  		Keysend:     len(invoice.PaymentRequest) == 0,
   241  	}
   242  }
   243  
   244  // makeHeightExpiry creates height-based expiry for an invoice based on its
   245  // lowest htlc expiry height.
   246  func makeHeightExpiry(paymentHash lntypes.Hash,
   247  	minHeight uint32) *invoiceExpiryHeight {
   248  
   249  	if minHeight == 0 {
   250  		log.Warnf("make height expiry called with 0 height")
   251  		return nil
   252  	}
   253  
   254  	return &invoiceExpiryHeight{
   255  		paymentHash:  paymentHash,
   256  		expiryHeight: minHeight,
   257  	}
   258  }
   259  
   260  // AddInvoices adds invoices to the InvoiceExpiryWatcher.
   261  func (ew *InvoiceExpiryWatcher) AddInvoices(invoices ...invoiceExpiry) {
   262  	if len(invoices) == 0 {
   263  		return
   264  	}
   265  
   266  	select {
   267  	case ew.newInvoices <- invoices:
   268  		log.Debugf("Added %d invoices to the expiry watcher",
   269  			len(invoices))
   270  
   271  	// Select on quit too so that callers won't get blocked in case
   272  	// of concurrent shutdown.
   273  	case <-ew.quit:
   274  	}
   275  }
   276  
   277  // nextTimestampExpiry returns a Time chan to wait on until the next invoice
   278  // expires. If there are no active invoices, then it'll simply wait
   279  // indefinitely.
   280  func (ew *InvoiceExpiryWatcher) nextTimestampExpiry() <-chan time.Time {
   281  	if !ew.timestampExpiryQueue.Empty() {
   282  		top := ew.timestampExpiryQueue.Top().(*invoiceExpiryTs)
   283  		return ew.clock.TickAfter(top.Expiry.Sub(ew.clock.Now()))
   284  	}
   285  
   286  	return nil
   287  }
   288  
   289  // nextHeightExpiry returns a channel that will immediately be read from if
   290  // the top item on our queue has expired.
   291  func (ew *InvoiceExpiryWatcher) nextHeightExpiry() <-chan uint32 {
   292  	if ew.blockExpiryQueue.Empty() {
   293  		return nil
   294  	}
   295  
   296  	top := ew.blockExpiryQueue.Top().(*invoiceExpiryHeight)
   297  	if !top.expired(ew.currentHeight, ew.blockExpiryDelta) {
   298  		return nil
   299  	}
   300  
   301  	blockChan := make(chan uint32, 1)
   302  	blockChan <- top.expiryHeight
   303  	return blockChan
   304  }
   305  
   306  // cancelNextExpiredInvoice will cancel the next expired invoice and removes
   307  // it from the expiry queue.
   308  func (ew *InvoiceExpiryWatcher) cancelNextExpiredInvoice() {
   309  	if !ew.timestampExpiryQueue.Empty() {
   310  		top := ew.timestampExpiryQueue.Top().(*invoiceExpiryTs)
   311  		if !top.Expiry.Before(ew.clock.Now()) {
   312  			return
   313  		}
   314  
   315  		// Don't force-cancel already accepted invoices. An exception to
   316  		// this are auto-generated keysend invoices. Because those move
   317  		// to the Accepted state directly after being opened, the expiry
   318  		// field would never be used. Enabling cancellation for accepted
   319  		// keysend invoices creates a safety mechanism that can prevents
   320  		// channel force-closes.
   321  		ew.expireInvoice(top.PaymentHash, top.Keysend)
   322  		ew.timestampExpiryQueue.Pop()
   323  	}
   324  }
   325  
   326  // cancelNextHeightExpiredInvoice looks at our height based queue and expires
   327  // the next invoice if we have reached its expiry block.
   328  func (ew *InvoiceExpiryWatcher) cancelNextHeightExpiredInvoice() {
   329  	if ew.blockExpiryQueue.Empty() {
   330  		return
   331  	}
   332  
   333  	top := ew.blockExpiryQueue.Top().(*invoiceExpiryHeight)
   334  	if !top.expired(ew.currentHeight, ew.blockExpiryDelta) {
   335  		return
   336  	}
   337  
   338  	// We always force-cancel block-based expiry so that we can
   339  	// cancel invoices that have been accepted but not yet resolved.
   340  	// This helps us avoid force closes.
   341  	ew.expireInvoice(top.paymentHash, true)
   342  	ew.blockExpiryQueue.Pop()
   343  }
   344  
   345  // expireInvoice attempts to expire an invoice and logs an error if we get an
   346  // unexpected error.
   347  func (ew *InvoiceExpiryWatcher) expireInvoice(hash lntypes.Hash, force bool) {
   348  	err := ew.cancelInvoice(hash, force)
   349  	switch err {
   350  	case nil:
   351  
   352  	case channeldb.ErrInvoiceAlreadyCanceled:
   353  
   354  	case channeldb.ErrInvoiceAlreadySettled:
   355  
   356  	case channeldb.ErrInvoiceNotFound:
   357  		// It's possible that the user has manually canceled the invoice
   358  		// which will then be deleted by the garbage collector resulting
   359  		// in an ErrInvoiceNotFound error.
   360  
   361  	default:
   362  		log.Errorf("Unable to cancel invoice: %v: %v", hash, err)
   363  	}
   364  }
   365  
   366  // pushInvoices adds invoices to be expired to their relevant queue.
   367  func (ew *InvoiceExpiryWatcher) pushInvoices(invoices []invoiceExpiry) {
   368  	for _, inv := range invoices {
   369  		// Switch on the type of entry we have. We need to check nil
   370  		// on the implementation of the interface because the interface
   371  		// itself is non-nil.
   372  		switch expiry := inv.(type) {
   373  		case *invoiceExpiryTs:
   374  			if expiry != nil {
   375  				ew.timestampExpiryQueue.Push(expiry)
   376  			}
   377  
   378  		case *invoiceExpiryHeight:
   379  			if expiry != nil {
   380  				ew.blockExpiryQueue.Push(expiry)
   381  			}
   382  
   383  		default:
   384  			log.Errorf("unexpected queue item: %T", inv)
   385  		}
   386  	}
   387  }
   388  
// mainLoop is a goroutine that receives new invoices and handles cancellation
// of expired invoices. It also follows the chain tip through blockNtfns so
// that height based expiries can be evaluated against the current height.
// The loop returns when quit is closed or the block epoch channel is closed.
// On exit it cancels the block epoch subscription and marks the wait group as
// done.
func (ew *InvoiceExpiryWatcher) mainLoop(blockNtfns *chainntnfs.BlockEpochEvent) {
	defer func() {
		blockNtfns.Cancel()
		ew.wg.Done()
	}()

	// We have two different queues, so we use a different cancel method
	// depending on which expiry condition we have hit. Starting with time
	// based expiry is an arbitrary choice to start off.
	cancelNext := ew.cancelNextExpiredInvoice

	for {
		// Cancel any invoices that may have expired. At most one entry
		// is cancelled per iteration. While expired height entries
		// remain, nextHeightExpiry below is immediately ready, so the
		// loop keeps draining them. For timestamp entries this
		// presumably relies on clock.TickAfter firing promptly for a
		// non-positive duration (TODO confirm for the clock in use).
		cancelNext()

		select {

		case newInvoices := <-ew.newInvoices:
			// Take newly forwarded invoices with higher priority
			// in order to not block the newInvoices channel.
			ew.pushInvoices(newInvoices)
			continue

		default:
			// No invoices pending: block until an expiry fires,
			// new invoices arrive, a block is connected or we are
			// asked to quit. A nil channel returned by the
			// next*Expiry helpers never becomes ready, so an empty
			// (or not yet expired) queue simply doesn't wake us.
			select {

			// Wait until the next invoice expires.
			case <-ew.nextTimestampExpiry():
				cancelNext = ew.cancelNextExpiredInvoice
				continue

			case <-ew.nextHeightExpiry():
				cancelNext = ew.cancelNextHeightExpiredInvoice
				continue

			case newInvoices := <-ew.newInvoices:
				ew.pushInvoices(newInvoices)

			// Consume new blocks. Updating currentHeight lets the
			// next nextHeightExpiry call detect newly expired
			// height based entries.
			case block, ok := <-blockNtfns.Epochs:
				if !ok {
					log.Debugf("block notifications " +
						"canceled")
					return
				}

				ew.currentHeight = uint32(block.Height)
				ew.currentHash = block.Hash

			case <-ew.quit:
				return
			}
		}
	}
}