github.com/decred/dcrlnd@v0.7.6/amp/shard_tracker_test.go (about) 1 package amp_test 2 3 import ( 4 "crypto/rand" 5 "testing" 6 7 "github.com/decred/dcrlnd/amp" 8 "github.com/decred/dcrlnd/lnwire" 9 "github.com/decred/dcrlnd/routing/shards" 10 "github.com/stretchr/testify/require" 11 ) 12 13 // TestAMPShardTracker tests that we can derive and cancel shards at will using 14 // the AMP shard tracker. 15 func TestAMPShardTracker(t *testing.T) { 16 var root, setID, payAddr [32]byte 17 _, err := rand.Read(root[:]) 18 require.NoError(t, err) 19 20 _, err = rand.Read(setID[:]) 21 require.NoError(t, err) 22 23 _, err = rand.Read(payAddr[:]) 24 require.NoError(t, err) 25 26 var totalAmt lnwire.MilliAtom = 1000 27 28 // Create an AMP shard tracker using the random data we just generated. 29 tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt) 30 31 // Trying to retrieve a hash for id 0 should result in an error. 32 _, err = tracker.GetHash(0) 33 require.Error(t, err) 34 35 // We start by creating 20 shards. 36 const numShards = 20 37 38 var shards []shards.PaymentShard 39 for i := uint64(0); i < numShards; i++ { 40 s, err := tracker.NewShard(i, i == numShards-1) 41 require.NoError(t, err) 42 43 // Check that the shards have their payloads set as expected. 44 require.Equal(t, setID, s.AMP().SetID()) 45 require.Equal(t, totalAmt, s.MPP().TotalMAtoms()) 46 require.Equal(t, payAddr, s.MPP().PaymentAddr()) 47 48 shards = append(shards, s) 49 } 50 51 // Make sure we can retrieve the hash for all of them. 52 for i := uint64(0); i < numShards; i++ { 53 hash, err := tracker.GetHash(i) 54 require.NoError(t, err) 55 require.Equal(t, shards[i].Hash(), hash) 56 } 57 58 // Now cancel half of the shards. 59 j := 0 60 for i := uint64(0); i < numShards; i++ { 61 if i%2 == 0 { 62 err := tracker.CancelShard(i) 63 require.NoError(t, err) 64 continue 65 } 66 67 // Keep shard. 68 shards[j] = shards[i] 69 j++ 70 } 71 shards = shards[:j] 72 73 // Get a new last shard. 74 s, err := tracker.NewShard(numShards, true) 75 require.NoError(t, err) 76 shards = append(shards, s) 77 78 // Finally make sure these shards together can be used to reconstruct 79 // the children. 80 childDescs := make([]amp.ChildDesc, len(shards)) 81 for i, s := range shards { 82 childDescs[i] = amp.ChildDesc{ 83 Share: s.AMP().RootShare(), 84 Index: s.AMP().ChildIndex(), 85 } 86 } 87 88 // Using the child descriptors, reconstruct the children. 89 children := amp.ReconstructChildren(childDescs...) 90 91 // Validate that the derived child preimages match the hash of each shard. 92 for i, child := range children { 93 require.Equal(t, shards[i].Hash(), child.Hash) 94 } 95 }