github.com/decred/dcrlnd@v0.7.6/amp/shard_tracker.go (about) 1 package amp 2 3 import ( 4 "crypto/rand" 5 "encoding/binary" 6 "fmt" 7 "sync" 8 9 "github.com/decred/dcrlnd/lntypes" 10 "github.com/decred/dcrlnd/lnwire" 11 "github.com/decred/dcrlnd/record" 12 "github.com/decred/dcrlnd/routing/shards" 13 ) 14 15 // Shard is an implementation of the shards.PaymentShards interface specific 16 // to AMP payments. 17 type Shard struct { 18 child *Child 19 mpp *record.MPP 20 amp *record.AMP 21 } 22 23 // A compile time check to ensure Shard implements the shards.PaymentShard 24 // interface. 25 var _ shards.PaymentShard = (*Shard)(nil) 26 27 // Hash returns the hash used for the HTLC representing this AMP shard. 28 func (s *Shard) Hash() lntypes.Hash { 29 return s.child.Hash 30 } 31 32 // MPP returns any extra MPP records that should be set for the final hop on 33 // the route used by this shard. 34 func (s *Shard) MPP() *record.MPP { 35 return s.mpp 36 } 37 38 // AMP returns any extra AMP records that should be set for the final hop on 39 // the route used by this shard. 40 func (s *Shard) AMP() *record.AMP { 41 return s.amp 42 } 43 44 // ShardTracker is an implementation of the shards.ShardTracker interface 45 // that is able to generate payment shards according to the AMP splitting 46 // algorithm. It can be used to generate new hashes to use for HTLCs, and also 47 // cancel shares used for failed payment shards. 48 type ShardTracker struct { 49 setID [32]byte 50 paymentAddr [32]byte 51 totalAmt lnwire.MilliAtom 52 53 sharer Sharer 54 55 shards map[uint64]*Child 56 sync.Mutex 57 } 58 59 // A compile time check to ensure ShardTracker implements the 60 // shards.ShardTracker interface. 61 var _ shards.ShardTracker = (*ShardTracker)(nil) 62 63 // NewShardTracker creates a new shard tracker to use for AMP payments. The 64 // root shard, setID, payment address and total amount must be correctly set in 65 // order for the TLV options to include with each shard to be created 66 // correctly. 67 func NewShardTracker(root, setID, payAddr [32]byte, 68 totalAmt lnwire.MilliAtom) *ShardTracker { 69 70 // Create a new seed sharer from this root. 71 rootShare := Share(root) 72 rootSharer := SeedSharerFromRoot(&rootShare) 73 74 return &ShardTracker{ 75 setID: setID, 76 paymentAddr: payAddr, 77 totalAmt: totalAmt, 78 sharer: rootSharer, 79 shards: make(map[uint64]*Child), 80 } 81 } 82 83 // NewShard registers a new attempt with the ShardTracker and returns a 84 // new shard representing this attempt. This attempt's shard should be canceled 85 // if it ends up not being used by the overall payment, i.e. if the attempt 86 // fails. 87 func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard, 88 error) { 89 90 s.Lock() 91 defer s.Unlock() 92 93 // Use a random child index. 94 var childIndex [4]byte 95 if _, err := rand.Read(childIndex[:]); err != nil { 96 return nil, err 97 } 98 idx := binary.BigEndian.Uint32(childIndex[:]) 99 100 // Depending on whether we are requesting the last shard or not, either 101 // split the current share into two, or get a Child directly from the 102 // current sharer. 103 var child *Child 104 if last { 105 child = s.sharer.Child(idx) 106 107 // If this was the last shard, set the current share to the 108 // zero share to indicate we cannot split it further. 109 s.sharer = s.sharer.Zero() 110 } else { 111 left, sharer, err := s.sharer.Split() 112 if err != nil { 113 return nil, err 114 } 115 116 s.sharer = sharer 117 child = left.Child(idx) 118 } 119 120 // Track the new child and return the shard. 121 s.shards[pid] = child 122 123 mpp := record.NewMPP(s.totalAmt, s.paymentAddr) 124 amp := record.NewAMP( 125 child.ChildDesc.Share, s.setID, child.ChildDesc.Index, 126 ) 127 128 return &Shard{ 129 child: child, 130 mpp: mpp, 131 amp: amp, 132 }, nil 133 } 134 135 // CancelShard cancel's the shard corresponding to the given attempt ID. 136 func (s *ShardTracker) CancelShard(pid uint64) error { 137 s.Lock() 138 defer s.Unlock() 139 140 c, ok := s.shards[pid] 141 if !ok { 142 return fmt.Errorf("pid not found") 143 } 144 delete(s.shards, pid) 145 146 // Now that we are canceling this shard, we XOR the share back into our 147 // current share. 148 s.sharer = s.sharer.Merge(c) 149 return nil 150 } 151 152 // GetHash retrieves the hash used by the shard of the given attempt ID. This 153 // will return an error if the attempt ID is unknown. 154 func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) { 155 s.Lock() 156 defer s.Unlock() 157 158 c, ok := s.shards[pid] 159 if !ok { 160 return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+ 161 "not found", pid) 162 } 163 164 return c.Hash, nil 165 }