github.com/decred/dcrlnd@v0.7.6/routing/payment_session_test.go (about) 1 package routing 2 3 import ( 4 "testing" 5 "time" 6 7 "github.com/decred/dcrlnd/channeldb" 8 "github.com/decred/dcrlnd/lntypes" 9 "github.com/decred/dcrlnd/lnwire" 10 "github.com/decred/dcrlnd/routing/route" 11 "github.com/decred/dcrlnd/zpay32" 12 "github.com/stretchr/testify/require" 13 ) 14 15 func TestValidateCLTVLimit(t *testing.T) { 16 t.Parallel() 17 18 testCases := []struct { 19 name string 20 cltvLimit uint32 21 finalCltvDelta uint16 22 includePadding bool 23 expectError bool 24 }{ 25 { 26 name: "bad limit with padding", 27 cltvLimit: uint32(103), 28 finalCltvDelta: uint16(100), 29 includePadding: true, 30 expectError: true, 31 }, 32 { 33 name: "good limit with padding", 34 cltvLimit: uint32(104), 35 finalCltvDelta: uint16(100), 36 includePadding: true, 37 expectError: false, 38 }, 39 { 40 name: "bad limit no padding", 41 cltvLimit: uint32(100), 42 finalCltvDelta: uint16(100), 43 includePadding: false, 44 expectError: true, 45 }, 46 { 47 name: "good limit no padding", 48 cltvLimit: uint32(101), 49 finalCltvDelta: uint16(100), 50 includePadding: false, 51 expectError: false, 52 }, 53 } 54 55 for _, testCase := range testCases { 56 testCase := testCase 57 58 success := t.Run(testCase.name, func(t *testing.T) { 59 err := ValidateCLTVLimit( 60 testCase.cltvLimit, testCase.finalCltvDelta, 61 testCase.includePadding, 62 ) 63 64 if testCase.expectError { 65 require.NotEmpty(t, err) 66 } else { 67 require.NoError(t, err) 68 } 69 }) 70 if !success { 71 break 72 } 73 } 74 } 75 76 // TestUpdateAdditionalEdge checks that we can update the additional edges as 77 // expected. 78 func TestUpdateAdditionalEdge(t *testing.T) { 79 80 var ( 81 testChannelID = uint64(12345) 82 oldFeeBaseMAtoms = uint32(1000) 83 newFeeBaseMAtoms = uint32(1100) 84 oldExpiryDelta = uint16(100) 85 newExpiryDelta = uint16(120) 86 87 payHash lntypes.Hash 88 ) 89 90 // Create a minimal test node using the private key priv1. 91 pub := priv1.PubKey().SerializeCompressed() 92 testNode := &channeldb.LightningNode{} 93 copy(testNode.PubKeyBytes[:], pub) 94 95 nodeID, err := testNode.PubKey() 96 require.NoError(t, err, "failed to get node id") 97 98 // Create a payment with a route hint. 99 payment := &LightningPayment{ 100 Target: testNode.PubKeyBytes, 101 Amount: 1000, 102 RouteHints: [][]zpay32.HopHint{{ 103 zpay32.HopHint{ 104 // The nodeID is actually the target itself. It 105 // doesn't matter as we are not doing routing 106 // in this test. 107 NodeID: nodeID, 108 ChannelID: testChannelID, 109 FeeBaseMAtoms: oldFeeBaseMAtoms, 110 CLTVExpiryDelta: oldExpiryDelta, 111 }, 112 }}, 113 paymentHash: &payHash, 114 } 115 116 // Create the paymentsession. 117 session, err := newPaymentSession( 118 payment, 119 func(routingGraph) (bandwidthHints, error) { 120 return &mockBandwidthHints{}, nil 121 }, 122 func() (routingGraph, func(), error) { 123 return &sessionGraph{}, func() {}, nil 124 }, 125 &MissionControl{}, 126 PathFindingConfig{}, 127 ) 128 require.NoError(t, err, "failed to create payment session") 129 130 // We should have 1 additional edge. 131 require.Equal(t, 1, len(session.additionalEdges)) 132 133 // The edge should use nodeID as key, and its value should have 1 edge 134 // policy. 135 vertex := route.NewVertex(nodeID) 136 policies, ok := session.additionalEdges[vertex] 137 require.True(t, ok, "cannot find policy") 138 require.Equal(t, 1, len(policies), "should have 1 edge policy") 139 140 // Check that the policy has been created as expected. 141 policy := policies[0] 142 require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch") 143 require.Equal(t, 144 oldExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch", 145 ) 146 require.Equal(t, 147 lnwire.MilliAtom(oldFeeBaseMAtoms), 148 policy.FeeBaseMAtoms, "fee base msat mismatch", 149 ) 150 151 // Create the channel update message and sign. 152 msg := &lnwire.ChannelUpdate{ 153 ShortChannelID: lnwire.NewShortChanIDFromInt(testChannelID), 154 Timestamp: uint32(time.Now().Unix()), 155 BaseFee: newFeeBaseMAtoms, 156 TimeLockDelta: newExpiryDelta, 157 } 158 signErrChanUpdate(t, priv1, msg) 159 160 // Apply the update. 161 require.True(t, 162 session.UpdateAdditionalEdge(msg, nodeID, policy), 163 "failed to update additional edge", 164 ) 165 166 // Check that the policy has been updated as expected. 167 require.Equal(t, testChannelID, policy.ChannelID, "channel ID mismatch") 168 require.Equal(t, 169 newExpiryDelta, policy.TimeLockDelta, "timelock delta mismatch", 170 ) 171 require.Equal(t, 172 lnwire.MilliAtom(newFeeBaseMAtoms), 173 policy.FeeBaseMAtoms, "fee base msat mismatch", 174 ) 175 } 176 177 func TestRequestRoute(t *testing.T) { 178 const ( 179 height = 10 180 ) 181 182 cltvLimit := uint32(30) 183 finalCltvDelta := uint16(8) 184 185 payment := &LightningPayment{ 186 CltvLimit: cltvLimit, 187 FinalCLTVDelta: finalCltvDelta, 188 Amount: 1000, 189 FeeLimit: 1000, 190 } 191 192 var paymentHash [32]byte 193 if err := payment.SetPaymentHash(paymentHash); err != nil { 194 t.Fatal(err) 195 } 196 197 session, err := newPaymentSession( 198 payment, 199 func(routingGraph) (bandwidthHints, error) { 200 return &mockBandwidthHints{}, nil 201 }, 202 func() (routingGraph, func(), error) { 203 return &sessionGraph{}, func() {}, nil 204 }, 205 &MissionControl{}, 206 PathFindingConfig{}, 207 ) 208 if err != nil { 209 t.Fatal(err) 210 } 211 212 // Override pathfinder with a mock. 213 session.pathFinder = func( 214 g *graphParams, r *RestrictParams, cfg *PathFindingConfig, 215 source, target route.Vertex, amt lnwire.MilliAtom, 216 finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { 217 218 // We expect find path to receive a cltv limit excluding the 219 // final cltv delta (including the block padding). 220 if r.CltvLimit != 22-uint32(BlockPadding) { 221 t.Fatal("wrong cltv limit") 222 } 223 224 path := []*channeldb.CachedEdgePolicy{ 225 { 226 ToNodePubKey: func() route.Vertex { 227 return route.Vertex{} 228 }, 229 ToNodeFeatures: lnwire.NewFeatureVector( 230 nil, nil, 231 ), 232 }, 233 } 234 235 return path, nil 236 } 237 238 route, err := session.RequestRoute( 239 payment.Amount, payment.FeeLimit, 0, height, 240 ) 241 if err != nil { 242 t.Fatal(err) 243 } 244 245 // We expect an absolute route lock value of height + finalCltvDelta 246 // + BlockPadding. 247 if route.TotalTimeLock != 18+uint32(BlockPadding) { 248 t.Fatalf("unexpected total time lock of %v", 249 route.TotalTimeLock) 250 } 251 } 252 253 type sessionGraph struct { 254 routingGraph 255 } 256 257 func (g *sessionGraph) sourceNode() route.Vertex { 258 return route.Vertex{} 259 }