github.com/decred/dcrlnd@v0.7.6/amp/shard_tracker_test.go (about)

     1  package amp_test
     2  
     3  import (
     4  	"crypto/rand"
     5  	"testing"
     6  
     7  	"github.com/decred/dcrlnd/amp"
     8  	"github.com/decred/dcrlnd/lnwire"
     9  	"github.com/decred/dcrlnd/routing/shards"
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  // TestAMPShardTracker tests that we can derive and cancel shards at will using
    14  // the AMP shard tracker.
    15  func TestAMPShardTracker(t *testing.T) {
    16  	var root, setID, payAddr [32]byte
    17  	_, err := rand.Read(root[:])
    18  	require.NoError(t, err)
    19  
    20  	_, err = rand.Read(setID[:])
    21  	require.NoError(t, err)
    22  
    23  	_, err = rand.Read(payAddr[:])
    24  	require.NoError(t, err)
    25  
    26  	var totalAmt lnwire.MilliAtom = 1000
    27  
    28  	// Create an AMP shard tracker using the random data we just generated.
    29  	tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt)
    30  
    31  	// Trying to retrieve a hash for id 0 should result in an error.
    32  	_, err = tracker.GetHash(0)
    33  	require.Error(t, err)
    34  
    35  	// We start by creating 20 shards.
    36  	const numShards = 20
    37  
    38  	var shards []shards.PaymentShard
    39  	for i := uint64(0); i < numShards; i++ {
    40  		s, err := tracker.NewShard(i, i == numShards-1)
    41  		require.NoError(t, err)
    42  
    43  		// Check that the shards have their payloads set as expected.
    44  		require.Equal(t, setID, s.AMP().SetID())
    45  		require.Equal(t, totalAmt, s.MPP().TotalMAtoms())
    46  		require.Equal(t, payAddr, s.MPP().PaymentAddr())
    47  
    48  		shards = append(shards, s)
    49  	}
    50  
    51  	// Make sure we can retrieve the hash for all of them.
    52  	for i := uint64(0); i < numShards; i++ {
    53  		hash, err := tracker.GetHash(i)
    54  		require.NoError(t, err)
    55  		require.Equal(t, shards[i].Hash(), hash)
    56  	}
    57  
    58  	// Now cancel half of the shards.
    59  	j := 0
    60  	for i := uint64(0); i < numShards; i++ {
    61  		if i%2 == 0 {
    62  			err := tracker.CancelShard(i)
    63  			require.NoError(t, err)
    64  			continue
    65  		}
    66  
    67  		// Keep shard.
    68  		shards[j] = shards[i]
    69  		j++
    70  	}
    71  	shards = shards[:j]
    72  
    73  	// Get a new last shard.
    74  	s, err := tracker.NewShard(numShards, true)
    75  	require.NoError(t, err)
    76  	shards = append(shards, s)
    77  
    78  	// Finally make sure these shards together can be used to reconstruct
    79  	// the children.
    80  	childDescs := make([]amp.ChildDesc, len(shards))
    81  	for i, s := range shards {
    82  		childDescs[i] = amp.ChildDesc{
    83  			Share: s.AMP().RootShare(),
    84  			Index: s.AMP().ChildIndex(),
    85  		}
    86  	}
    87  
    88  	// Using the child descriptors, reconstruct the children.
    89  	children := amp.ReconstructChildren(childDescs...)
    90  
    91  	// Validate that the derived child preimages match the hash of each shard.
    92  	for i, child := range children {
    93  		require.Equal(t, shards[i].Hash(), child.Hash)
    94  	}
    95  }