github.com/decred/dcrlnd@v0.7.6/routing/shards/shard_tracker.go (about) 1 package shards 2 3 import ( 4 "fmt" 5 "sync" 6 7 "github.com/decred/dcrlnd/lntypes" 8 "github.com/decred/dcrlnd/record" 9 ) 10 11 // PaymentShard is an interface representing a shard tracked by the 12 // ShardTracker. It contains options that are specific to the given shard that 13 // might differ from the overall payment. 14 type PaymentShard interface { 15 // Hash returns the hash used for the HTLC representing this shard. 16 Hash() lntypes.Hash 17 18 // MPP returns any extra MPP records that should be set for the final 19 // hop on the route used by this shard. 20 MPP() *record.MPP 21 22 // AMP returns any extra AMP records that should be set for the final 23 // hop on the route used by this shard. 24 AMP() *record.AMP 25 } 26 27 // ShardTracker is an interfae representing a tracker that keeps track of the 28 // inflight shards of a payment, and is able to assign new shards the correct 29 // options such as hash and extra records. 30 type ShardTracker interface { 31 // NewShard registers a new attempt with the ShardTracker and returns a 32 // new shard representing this attempt. This attempt's shard should be 33 // canceled if it ends up not being used by the overall payment, i.e. 34 // if the attempt fails. 35 NewShard(uint64, bool) (PaymentShard, error) 36 37 // CancelShard cancel's the shard corresponding to the given attempt 38 // ID. This lets the ShardTracker free up any slots used by this shard, 39 // and in case of AMP payments return the share used by this shard to 40 // the root share. 41 CancelShard(uint64) error 42 43 // GetHash retrieves the hash used by the shard of the given attempt 44 // ID. This wil return an error if the attempt ID is unknown. 45 GetHash(uint64) (lntypes.Hash, error) 46 } 47 48 // Shard is a struct used for simple shards where we obly need to keep map it 49 // to a single hash. 50 type Shard struct { 51 hash lntypes.Hash 52 } 53 54 // Hash returns the hash used for the HTLC representing this shard. 55 func (s *Shard) Hash() lntypes.Hash { 56 return s.hash 57 } 58 59 // MPP returns any extra MPP records that should be set for the final hop on 60 // the route used by this shard. 61 func (s *Shard) MPP() *record.MPP { 62 return nil 63 } 64 65 // AMP returns any extra AMP records that should be set for the final hop on 66 // the route used by this shard. 67 func (s *Shard) AMP() *record.AMP { 68 return nil 69 } 70 71 // SimpleShardTracker is an implementation of the ShardTracker interface that 72 // simply maps attempt IDs to hashes. New shards will be given a static payment 73 // hash. This should be used for regular and MPP payments, in addition to 74 // resumed payments where all the attempt's hashes have already been created. 75 type SimpleShardTracker struct { 76 hash lntypes.Hash 77 shards map[uint64]lntypes.Hash 78 sync.Mutex 79 } 80 81 // A compile time check to ensure SimpleShardTracker implements the 82 // ShardTracker interface. 83 var _ ShardTracker = (*SimpleShardTracker)(nil) 84 85 // NewSimpleShardTracker creates a new intance of the SimpleShardTracker with 86 // the given payment hash and existing attempts. 87 func NewSimpleShardTracker(paymentHash lntypes.Hash, 88 shards map[uint64]lntypes.Hash) ShardTracker { 89 90 if shards == nil { 91 shards = make(map[uint64]lntypes.Hash) 92 } 93 94 return &SimpleShardTracker{ 95 hash: paymentHash, 96 shards: shards, 97 } 98 } 99 100 // NewShard registers a new attempt with the ShardTracker and returns a 101 // new shard representing this attempt. This attempt's shard should be canceled 102 // if it ends up not being used by the overall payment, i.e. if the attempt 103 // fails. 104 func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) { 105 m.Lock() 106 m.shards[id] = m.hash 107 m.Unlock() 108 109 return &Shard{ 110 hash: m.hash, 111 }, nil 112 } 113 114 // CancelShard cancel's the shard corresponding to the given attempt ID. 115 func (m *SimpleShardTracker) CancelShard(id uint64) error { 116 m.Lock() 117 delete(m.shards, id) 118 m.Unlock() 119 120 return nil 121 } 122 123 // GetHash retrieves the hash used by the shard of the given attempt ID. This 124 // will return an error if the attempt ID is unknown. 125 func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) { 126 m.Lock() 127 hash, ok := m.shards[id] 128 m.Unlock() 129 if !ok { 130 return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+ 131 "not found", id) 132 } 133 134 return hash, nil 135 }