github.com/decred/dcrlnd@v0.7.6/amp/shard_tracker.go (about)

     1  package amp
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"sync"
     8  
     9  	"github.com/decred/dcrlnd/lntypes"
    10  	"github.com/decred/dcrlnd/lnwire"
    11  	"github.com/decred/dcrlnd/record"
    12  	"github.com/decred/dcrlnd/routing/shards"
    13  )
    14  
// Shard is an implementation of the shards.PaymentShard interface specific
// to AMP payments.
//
// A Shard is produced by ShardTracker.NewShard. It bundles the child derived
// for the attempt, whose hash is exposed via Hash. It also carries the MPP and
// AMP records, returned by MPP and AMP, that are to be set on the final hop of
// the shard's route.
type Shard struct {
	child *Child
	mpp   *record.MPP
	amp   *record.AMP
}

// A compile time check to ensure Shard implements the shards.PaymentShard
// interface.
var _ shards.PaymentShard = (*Shard)(nil)
    26  
    27  // Hash returns the hash used for the HTLC representing this AMP shard.
    28  func (s *Shard) Hash() lntypes.Hash {
    29  	return s.child.Hash
    30  }
    31  
    32  // MPP returns any extra MPP records that should be set for the final hop on
    33  // the route used by this shard.
    34  func (s *Shard) MPP() *record.MPP {
    35  	return s.mpp
    36  }
    37  
    38  // AMP returns any extra AMP records that should be set for the final hop on
    39  // the route used by this shard.
    40  func (s *Shard) AMP() *record.AMP {
    41  	return s.amp
    42  }
    43  
// ShardTracker is an implementation of the shards.ShardTracker interface
// that is able to generate payment shards according to the AMP splitting
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
// cancel shares used for failed payment shards.
//
// Every method defined on ShardTracker in this file acquires the embedded
// mutex before reading or modifying sharer or shards. Construct instances with
// NewShardTracker: the zero value has a nil shards map and a nil sharer.
type ShardTracker struct {
	// setID is placed, together with each child's share and index, in the
	// AMP record attached to every shard.
	setID       [32]byte
	// paymentAddr is placed, together with totalAmt, in the MPP record
	// attached to every shard.
	paymentAddr [32]byte
	// totalAmt is the total amount reported in every shard's MPP record.
	totalAmt    lnwire.MilliAtom

	// sharer holds the current, not-yet-assigned share. NewShard splits it
	// (or hands it out whole and replaces it with the zero share), and
	// CancelShard merges a canceled child back into it.
	sharer Sharer

	// shards maps an attempt ID (pid) to the child issued for it.
	shards map[uint64]*Child
	// Embedded mutex guarding sharer and shards. NOTE: because it is
	// embedded, Lock and Unlock are promoted into ShardTracker's exported
	// method set.
	sync.Mutex
}

// A compile time check to ensure ShardTracker implements the
// shards.ShardTracker interface.
var _ shards.ShardTracker = (*ShardTracker)(nil)
    62  
    63  // NewShardTracker creates a new shard tracker to use for AMP payments. The
    64  // root shard, setID, payment address and total amount must be correctly set in
    65  // order for the TLV options to include with each shard to be created
    66  // correctly.
    67  func NewShardTracker(root, setID, payAddr [32]byte,
    68  	totalAmt lnwire.MilliAtom) *ShardTracker {
    69  
    70  	// Create a new seed sharer from this root.
    71  	rootShare := Share(root)
    72  	rootSharer := SeedSharerFromRoot(&rootShare)
    73  
    74  	return &ShardTracker{
    75  		setID:       setID,
    76  		paymentAddr: payAddr,
    77  		totalAmt:    totalAmt,
    78  		sharer:      rootSharer,
    79  		shards:      make(map[uint64]*Child),
    80  	}
    81  }
    82  
    83  // NewShard registers a new attempt with the ShardTracker and returns a
    84  // new shard representing this attempt. This attempt's shard should be canceled
    85  // if it ends up not being used by the overall payment, i.e. if the attempt
    86  // fails.
    87  func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard,
    88  	error) {
    89  
    90  	s.Lock()
    91  	defer s.Unlock()
    92  
    93  	// Use a random child index.
    94  	var childIndex [4]byte
    95  	if _, err := rand.Read(childIndex[:]); err != nil {
    96  		return nil, err
    97  	}
    98  	idx := binary.BigEndian.Uint32(childIndex[:])
    99  
   100  	// Depending on whether we are requesting the last shard or not, either
   101  	// split the current share into two, or get a Child directly from the
   102  	// current sharer.
   103  	var child *Child
   104  	if last {
   105  		child = s.sharer.Child(idx)
   106  
   107  		// If this was the last shard, set the current share to the
   108  		// zero share to indicate we cannot split it further.
   109  		s.sharer = s.sharer.Zero()
   110  	} else {
   111  		left, sharer, err := s.sharer.Split()
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  
   116  		s.sharer = sharer
   117  		child = left.Child(idx)
   118  	}
   119  
   120  	// Track the new child and return the shard.
   121  	s.shards[pid] = child
   122  
   123  	mpp := record.NewMPP(s.totalAmt, s.paymentAddr)
   124  	amp := record.NewAMP(
   125  		child.ChildDesc.Share, s.setID, child.ChildDesc.Index,
   126  	)
   127  
   128  	return &Shard{
   129  		child: child,
   130  		mpp:   mpp,
   131  		amp:   amp,
   132  	}, nil
   133  }
   134  
   135  // CancelShard cancel's the shard corresponding to the given attempt ID.
   136  func (s *ShardTracker) CancelShard(pid uint64) error {
   137  	s.Lock()
   138  	defer s.Unlock()
   139  
   140  	c, ok := s.shards[pid]
   141  	if !ok {
   142  		return fmt.Errorf("pid not found")
   143  	}
   144  	delete(s.shards, pid)
   145  
   146  	// Now that we are canceling this shard, we XOR the share back into our
   147  	// current share.
   148  	s.sharer = s.sharer.Merge(c)
   149  	return nil
   150  }
   151  
   152  // GetHash retrieves the hash used by the shard of the given attempt ID. This
   153  // will return an error if the attempt ID is unknown.
   154  func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) {
   155  	s.Lock()
   156  	defer s.Unlock()
   157  
   158  	c, ok := s.shards[pid]
   159  	if !ok {
   160  		return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+
   161  			"not found", pid)
   162  	}
   163  
   164  	return c.Hash, nil
   165  }