github.com/decred/dcrlnd@v0.7.6/routing/payment_session_test.go (about)

     1  package routing
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/decred/dcrlnd/channeldb"
     8  	"github.com/decred/dcrlnd/lntypes"
     9  	"github.com/decred/dcrlnd/lnwire"
    10  	"github.com/decred/dcrlnd/routing/route"
    11  	"github.com/decred/dcrlnd/zpay32"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func TestValidateCLTVLimit(t *testing.T) {
    16  	t.Parallel()
    17  
    18  	testCases := []struct {
    19  		name           string
    20  		cltvLimit      uint32
    21  		finalCltvDelta uint16
    22  		includePadding bool
    23  		expectError    bool
    24  	}{
    25  		{
    26  			name:           "bad limit with padding",
    27  			cltvLimit:      uint32(103),
    28  			finalCltvDelta: uint16(100),
    29  			includePadding: true,
    30  			expectError:    true,
    31  		},
    32  		{
    33  			name:           "good limit with padding",
    34  			cltvLimit:      uint32(104),
    35  			finalCltvDelta: uint16(100),
    36  			includePadding: true,
    37  			expectError:    false,
    38  		},
    39  		{
    40  			name:           "bad limit no padding",
    41  			cltvLimit:      uint32(100),
    42  			finalCltvDelta: uint16(100),
    43  			includePadding: false,
    44  			expectError:    true,
    45  		},
    46  		{
    47  			name:           "good limit no padding",
    48  			cltvLimit:      uint32(101),
    49  			finalCltvDelta: uint16(100),
    50  			includePadding: false,
    51  			expectError:    false,
    52  		},
    53  	}
    54  
    55  	for _, testCase := range testCases {
    56  		testCase := testCase
    57  
    58  		success := t.Run(testCase.name, func(t *testing.T) {
    59  			err := ValidateCLTVLimit(
    60  				testCase.cltvLimit, testCase.finalCltvDelta,
    61  				testCase.includePadding,
    62  			)
    63  
    64  			if testCase.expectError {
    65  				require.NotEmpty(t, err)
    66  			} else {
    67  				require.NoError(t, err)
    68  			}
    69  		})
    70  		if !success {
    71  			break
    72  		}
    73  	}
    74  }
    75  
    76  // TestUpdateAdditionalEdge checks that we can update the additional edges as
    77  // expected.
    78  func TestUpdateAdditionalEdge(t *testing.T) {
    79  
    80  	var (
    81  		testChannelID    = uint64(12345)
    82  		oldFeeBaseMAtoms = uint32(1000)
    83  		newFeeBaseMAtoms = uint32(1100)
    84  		oldExpiryDelta   = uint16(100)
    85  		newExpiryDelta   = uint16(120)
    86  
    87  		payHash lntypes.Hash
    88  	)
    89  
    90  	// Create a minimal test node using the private key priv1.
    91  	pub := priv1.PubKey().SerializeCompressed()
    92  	testNode := &channeldb.LightningNode{}
    93  	copy(testNode.PubKeyBytes[:], pub)
    94  
    95  	nodeID, err := testNode.PubKey()
    96  	require.NoError(t, err, "failed to get node id")
    97  
    98  	// Create a payment with a route hint.
    99  	payment := &LightningPayment{
   100  		Target: testNode.PubKeyBytes,
   101  		Amount: 1000,
   102  		RouteHints: [][]zpay32.HopHint{{
   103  			zpay32.HopHint{
   104  				// The nodeID is actually the target itself. It
   105  				// doesn't matter as we are not doing routing
   106  				// in this test.
   107  				NodeID:          nodeID,
   108  				ChannelID:       testChannelID,
   109  				FeeBaseMAtoms:   oldFeeBaseMAtoms,
   110  				CLTVExpiryDelta: oldExpiryDelta,
   111  			},
   112  		}},
   113  		paymentHash: &payHash,
   114  	}
   115  
   116  	// Create the paymentsession.
   117  	session, err := newPaymentSession(
   118  		payment,
   119  		func(routingGraph) (bandwidthHints, error) {
   120  			return &mockBandwidthHints{}, nil
   121  		},
   122  		func() (routingGraph, func(), error) {
   123  			return &sessionGraph{}, func() {}, nil
   124  		},
   125  		&MissionControl{},
   126  		PathFindingConfig{},
   127  	)
   128  	require.NoError(t, err, "failed to create payment session")
   129  
   130  	// We should have 1 additional edge.
   131  	require.Equal(t, 1, len(session.additionalEdges))
   132  
   133  	// The edge should use nodeID as key, and its value should have 1 edge
   134  	// policy.
   135  	vertex := route.NewVertex(nodeID)
   136  	policies, ok := session.additionalEdges[vertex]
   137  	require.True(t, ok, "cannot find policy")
   138  	require.Equal(t, 1, len(policies), "should have 1 edge policy")
   139  
   140  	// Check that the policy has been created as expected.
   141  	policy := policies[0]
   142  	require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch")
   143  	require.Equal(t,
   144  		oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch",
   145  	)
   146  	require.Equal(t,
   147  		lnwire.MilliAtom(oldFeeBaseMAtoms),
   148  		policy.FeeBaseMAtoms, "fee base msat mismatch",
   149  	)
   150  
   151  	// Create the channel update message and sign.
   152  	msg := &lnwire.ChannelUpdate{
   153  		ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID),
   154  		Timestamp:      uint32(time.Now().Unix()),
   155  		BaseFee:        newFeeBaseMAtoms,
   156  		TimeLockDelta:  newExpiryDelta,
   157  	}
   158  	signErrChanUpdate(t, priv1, msg)
   159  
   160  	// Apply the update.
   161  	require.True(t,
   162  		session.UpdateAdditionalEdge(msg, nodeID, policy),
   163  		"failed to update additional edge",
   164  	)
   165  
   166  	// Check that the policy has been updated as expected.
   167  	require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch")
   168  	require.Equal(t,
   169  		newExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch",
   170  	)
   171  	require.Equal(t,
   172  		lnwire.MilliAtom(newFeeBaseMAtoms),
   173  		policy.FeeBaseMAtoms, "fee base msat mismatch",
   174  	)
   175  }
   176  
   177  func TestRequestRoute(t *testing.T) {
   178  	const (
   179  		height = 10
   180  	)
   181  
   182  	cltvLimit := uint32(30)
   183  	finalCltvDelta := uint16(8)
   184  
   185  	payment := &LightningPayment{
   186  		CltvLimit:      cltvLimit,
   187  		FinalCLTVDelta: finalCltvDelta,
   188  		Amount:         1000,
   189  		FeeLimit:       1000,
   190  	}
   191  
   192  	var paymentHash [32]byte
   193  	if err := payment.SetPaymentHash(paymentHash); err != nil {
   194  		t.Fatal(err)
   195  	}
   196  
   197  	session, err := newPaymentSession(
   198  		payment,
   199  		func(routingGraph) (bandwidthHints, error) {
   200  			return &mockBandwidthHints{}, nil
   201  		},
   202  		func() (routingGraph, func(), error) {
   203  			return &sessionGraph{}, func() {}, nil
   204  		},
   205  		&MissionControl{},
   206  		PathFindingConfig{},
   207  	)
   208  	if err != nil {
   209  		t.Fatal(err)
   210  	}
   211  
   212  	// Override pathfinder with a mock.
   213  	session.pathFinder = func(
   214  		g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
   215  		source, target route.Vertex, amt lnwire.MilliAtom,
   216  		finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
   217  
   218  		// We expect find path to receive a cltv limit excluding the
   219  		// final cltv delta (including the block padding).
   220  		if r.CltvLimit != 22-uint32(BlockPadding) {
   221  			t.Fatal("wrong cltv limit")
   222  		}
   223  
   224  		path := []*channeldb.CachedEdgePolicy{
   225  			{
   226  				ToNodePubKey: func() route.Vertex {
   227  					return route.Vertex{}
   228  				},
   229  				ToNodeFeatures: lnwire.NewFeatureVector(
   230  					nil, nil,
   231  				),
   232  			},
   233  		}
   234  
   235  		return path, nil
   236  	}
   237  
   238  	route, err := session.RequestRoute(
   239  		payment.Amount, payment.FeeLimit, 0, height,
   240  	)
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  
   245  	// We expect an absolute route lock value of height + finalCltvDelta
   246  	// + BlockPadding.
   247  	if route.TotalTimeLock != 18+uint32(BlockPadding) {
   248  		t.Fatalf("unexpected total time lock of %v",
   249  			route.TotalTimeLock)
   250  	}
   251  }
   252  
   253  type sessionGraph struct {
   254  	routingGraph
   255  }
   256  
   257  func (g *sessionGraph) sourceNode() route.Vertex {
   258  	return route.Vertex{}
   259  }