// github.com/decred/dcrlnd@v0.7.6/routing/shards/shard_tracker.go

     1  package shards
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/decred/dcrlnd/lntypes"
     8  	"github.com/decred/dcrlnd/record"
     9  )
    10  
    11  // PaymentShard is an interface representing a shard tracked by the
    12  // ShardTracker. It contains options that are specific to the given shard that
    13  // might differ from the overall payment.
    14  type PaymentShard interface {
    15  	// Hash returns the hash used for the HTLC representing this shard.
    16  	Hash() lntypes.Hash
    17  
    18  	// MPP returns any extra MPP records that should be set for the final
    19  	// hop on the route used by this shard.
    20  	MPP() *record.MPP
    21  
    22  	// AMP returns any extra AMP records that should be set for the final
    23  	// hop on the route used by this shard.
    24  	AMP() *record.AMP
    25  }
    26  
    27  // ShardTracker is an interfae representing a tracker that keeps track of the
    28  // inflight shards of a payment, and is able to assign new shards the correct
    29  // options such as hash and extra records.
    30  type ShardTracker interface {
    31  	// NewShard registers a new attempt with the ShardTracker and returns a
    32  	// new shard representing this attempt. This attempt's shard should be
    33  	// canceled if it ends up not being used by the overall payment, i.e.
    34  	// if the attempt fails.
    35  	NewShard(uint64, bool) (PaymentShard, error)
    36  
    37  	// CancelShard cancel's the shard corresponding to the given attempt
    38  	// ID. This lets the ShardTracker free up any slots used by this shard,
    39  	// and in case of AMP payments return the share used by this shard to
    40  	// the root share.
    41  	CancelShard(uint64) error
    42  
    43  	// GetHash retrieves the hash used by the shard of the given attempt
    44  	// ID. This wil return an error if the attempt ID is unknown.
    45  	GetHash(uint64) (lntypes.Hash, error)
    46  }
    47  
    48  // Shard is a struct used for simple shards where we obly need to keep map it
    49  // to a single hash.
    50  type Shard struct {
    51  	hash lntypes.Hash
    52  }
    53  
    54  // Hash returns the hash used for the HTLC representing this shard.
    55  func (s *Shard) Hash() lntypes.Hash {
    56  	return s.hash
    57  }
    58  
    59  // MPP returns any extra MPP records that should be set for the final hop on
    60  // the route used by this shard.
    61  func (s *Shard) MPP() *record.MPP {
    62  	return nil
    63  }
    64  
    65  // AMP returns any extra AMP records that should be set for the final hop on
    66  // the route used by this shard.
    67  func (s *Shard) AMP() *record.AMP {
    68  	return nil
    69  }
    70  
// SimpleShardTracker is an implementation of the ShardTracker interface that
// simply maps attempt IDs to hashes. New shards will be given a static payment
// hash. This should be used for regular and MPP payments, in addition to
// resumed payments where all the attempt's hashes have already been created.
//
// Fields:
//   - hash: the static hash recorded for, and handed to, every shard created
//     by NewShard.
//   - shards: attempt ID -> hash of that attempt's shard. The methods in this
//     file only touch it while holding the embedded mutex.
//   - sync.Mutex: guards shards. Because it is embedded, Lock and Unlock are
//     also exported methods of *SimpleShardTracker.
type SimpleShardTracker struct {
	hash   lntypes.Hash
	shards map[uint64]lntypes.Hash
	sync.Mutex
}

// A compile time check to ensure SimpleShardTracker implements the
// ShardTracker interface.
var _ ShardTracker = (*SimpleShardTracker)(nil)
    84  
    85  // NewSimpleShardTracker creates a new intance of the SimpleShardTracker with
    86  // the given payment hash and existing attempts.
    87  func NewSimpleShardTracker(paymentHash lntypes.Hash,
    88  	shards map[uint64]lntypes.Hash) ShardTracker {
    89  
    90  	if shards == nil {
    91  		shards = make(map[uint64]lntypes.Hash)
    92  	}
    93  
    94  	return &SimpleShardTracker{
    95  		hash:   paymentHash,
    96  		shards: shards,
    97  	}
    98  }
    99  
   100  // NewShard registers a new attempt with the ShardTracker and returns a
   101  // new shard representing this attempt. This attempt's shard should be canceled
   102  // if it ends up not being used by the overall payment, i.e. if the attempt
   103  // fails.
   104  func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
   105  	m.Lock()
   106  	m.shards[id] = m.hash
   107  	m.Unlock()
   108  
   109  	return &Shard{
   110  		hash: m.hash,
   111  	}, nil
   112  }
   113  
   114  // CancelShard cancel's the shard corresponding to the given attempt ID.
   115  func (m *SimpleShardTracker) CancelShard(id uint64) error {
   116  	m.Lock()
   117  	delete(m.shards, id)
   118  	m.Unlock()
   119  
   120  	return nil
   121  }
   122  
   123  // GetHash retrieves the hash used by the shard of the given attempt ID. This
   124  // will return an error if the attempt ID is unknown.
   125  func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
   126  	m.Lock()
   127  	hash, ok := m.shards[id]
   128  	m.Unlock()
   129  	if !ok {
   130  		return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
   131  			"not found", id)
   132  	}
   133  
   134  	return hash, nil
   135  }