github.com/decred/dcrlnd@v0.7.6/lnrpc/routerrpc/router_backend_test.go (about) 1 package routerrpc 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/hex" 7 "testing" 8 9 "github.com/decred/dcrd/dcrutil/v4" 10 "github.com/decred/dcrlnd/channeldb" 11 "github.com/decred/dcrlnd/lnwire" 12 "github.com/decred/dcrlnd/record" 13 "github.com/decred/dcrlnd/routing" 14 "github.com/decred/dcrlnd/routing/route" 15 "github.com/stretchr/testify/require" 16 17 "github.com/decred/dcrlnd/lnrpc" 18 ) 19 20 const ( 21 destKey = "0286098b97bc843372b4426d4b276cea9aa2f48f0428d6f5b66ae101befc14f8b4" 22 ignoreNodeKey = "02f274f48f3c0d590449a6776e3ce8825076ac376e470e992246eebc565ef8bb2a" 23 hintNodeKey = "0274e7fb33eafd74fe1acb6db7680bb4aa78e9c839a6e954e38abfad680f645ef7" 24 25 testMissionControlProb = 0.5 26 ) 27 28 var ( 29 sourceKey = route.Vertex{1, 2, 3} 30 31 node1 = route.Vertex{10} 32 33 node2 = route.Vertex{11} 34 ) 35 36 // TestQueryRoutes asserts that query routes rpc parameters are properly parsed 37 // and passed onto path finding. 38 func TestQueryRoutes(t *testing.T) { 39 t.Run("no mission control", func(t *testing.T) { 40 testQueryRoutes(t, false, false, true) 41 }) 42 t.Run("no mission control and msat", func(t *testing.T) { 43 testQueryRoutes(t, false, true, true) 44 }) 45 t.Run("with mission control", func(t *testing.T) { 46 testQueryRoutes(t, true, false, true) 47 }) 48 t.Run("no mission control bad cltv limit", func(t *testing.T) { 49 testQueryRoutes(t, false, false, false) 50 }) 51 } 52 53 func testQueryRoutes(t *testing.T, useMissionControl bool, useMAtoms bool, 54 setTimelock bool) { 55 56 ignoreNodeBytes, err := hex.DecodeString(ignoreNodeKey) 57 if err != nil { 58 t.Fatal(err) 59 } 60 61 var ignoreNodeVertex route.Vertex 62 copy(ignoreNodeVertex[:], ignoreNodeBytes) 63 64 destNodeBytes, err := hex.DecodeString(destKey) 65 if err != nil { 66 t.Fatal(err) 67 } 68 69 var ( 70 lastHop = route.Vertex{64} 71 outgoingChan = uint64(383322) 72 ) 73 74 hintNode, err := route.NewVertexFromStr(hintNodeKey) 75 if err != nil { 76 t.Fatal(err) 77 } 78 79 rpcRouteHints := []*lnrpc.RouteHint{ 80 { 81 HopHints: []*lnrpc.HopHint{ 82 { 83 ChanId: 38484, 84 NodeId: hintNodeKey, 85 }, 86 }, 87 }, 88 } 89 90 request := &lnrpc.QueryRoutesRequest{ 91 PubKey: destKey, 92 FinalCltvDelta: 100, 93 IgnoredNodes: [][]byte{ignoreNodeBytes}, 94 IgnoredEdges: []*lnrpc.EdgeLocator{{ 95 ChannelId: 555, 96 DirectionReverse: true, 97 }}, 98 IgnoredPairs: []*lnrpc.NodePair{{ 99 From: node1[:], 100 To: node2[:], 101 }}, 102 UseMissionControl: useMissionControl, 103 LastHopPubkey: lastHop[:], 104 OutgoingChanId: outgoingChan, 105 DestFeatures: []lnrpc.FeatureBit{lnrpc.FeatureBit_MPP_OPT}, 106 RouteHints: rpcRouteHints, 107 } 108 109 amtAtoms := int64(100000) 110 if useMAtoms { 111 request.AmtMAtoms = amtAtoms * 1000 112 request.FeeLimit = &lnrpc.FeeLimit{ 113 Limit: &lnrpc.FeeLimit_FixedMAtoms{ 114 FixedMAtoms: 250000, 115 }, 116 } 117 } else { 118 request.Amt = amtAtoms 119 request.FeeLimit = &lnrpc.FeeLimit{ 120 Limit: &lnrpc.FeeLimit_Fixed{ 121 Fixed: 250, 122 }, 123 } 124 } 125 126 findRoute := func(source, target route.Vertex, 127 amt lnwire.MilliAtom, restrictions *routing.RestrictParams, 128 _ record.CustomSet, 129 routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy, 130 finalExpiry uint16) (*route.Route, error) { 131 132 if int64(amt) != amtAtoms*1000 { 133 t.Fatal("unexpected amount") 134 } 135 136 if source != sourceKey { 137 t.Fatal("unexpected source key") 138 } 139 140 if !bytes.Equal(target[:], destNodeBytes) { 141 t.Fatal("unexpected target key") 142 } 143 144 if restrictions.FeeLimit != 250*1000 { 145 t.Fatal("unexpected fee limit") 146 } 147 148 if restrictions.ProbabilitySource(route.Vertex{2}, 149 route.Vertex{1}, 0, 150 ) != 0 { 151 t.Fatal("expecting 0% probability for ignored edge") 152 } 153 154 if restrictions.ProbabilitySource(ignoreNodeVertex, 155 route.Vertex{6}, 0, 156 ) != 0 { 157 t.Fatal("expecting 0% probability for ignored node") 158 } 159 160 if restrictions.ProbabilitySource(node1, node2, 0) != 0 { 161 t.Fatal("expecting 0% probability for ignored pair") 162 } 163 164 if *restrictions.LastHop != lastHop { 165 t.Fatal("unexpected last hop") 166 } 167 168 if restrictions.OutgoingChannelIDs[0] != outgoingChan { 169 t.Fatal("unexpected outgoing channel id") 170 } 171 172 if !restrictions.DestFeatures.HasFeature(lnwire.MPPOptional) { 173 t.Fatal("unexpected dest features") 174 } 175 176 if _, ok := routeHints[hintNode]; !ok { 177 t.Fatal("expected route hint") 178 } 179 180 expectedProb := 1.0 181 if useMissionControl { 182 expectedProb = testMissionControlProb 183 } 184 if restrictions.ProbabilitySource(route.Vertex{4}, 185 route.Vertex{5}, 0, 186 ) != expectedProb { 187 t.Fatal("expecting 100% probability") 188 } 189 190 hops := []*route.Hop{{}} 191 return route.NewRouteFromHops(amt, 144, source, hops) 192 } 193 194 backend := &RouterBackend{ 195 MaxPaymentMAtoms: lnwire.NewMAtomsFromAtoms(1000000), 196 FindRoute: findRoute, 197 SelfNode: route.Vertex{1, 2, 3}, 198 FetchChannelCapacity: func(chanID uint64) ( 199 dcrutil.Amount, error) { 200 201 return 1, nil 202 }, 203 MissionControl: &mockMissionControl{}, 204 FetchChannelEndpoints: func(chanID uint64) (route.Vertex, 205 route.Vertex, error) { 206 207 if chanID != 555 { 208 t.Fatalf("expected endpoints to be fetched for "+ 209 "channel 555, but got %v instead", 210 chanID) 211 } 212 return route.Vertex{1}, route.Vertex{2}, nil 213 }, 214 } 215 216 // If this is set, we'll populate MaxTotalTimelock. If this is not set, 217 // the test will fail as CltvLimit will be 0. 218 if setTimelock { 219 backend.MaxTotalTimelock = 1000 220 } 221 222 resp, err := backend.QueryRoutes(context.Background(), request) 223 224 // If no MaxTotalTimelock was set for the QueryRoutes request, make 225 // sure an error was returned. 226 if !setTimelock { 227 require.NotEmpty(t, err) 228 return 229 } 230 231 if err != nil { 232 t.Fatal(err) 233 } 234 if len(resp.Routes) != 1 { 235 t.Fatal("expected a single route response") 236 } 237 } 238 239 type mockMissionControl struct { 240 MissionControl 241 } 242 243 func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex, 244 amt lnwire.MilliAtom) float64 { 245 246 return testMissionControlProb 247 } 248 249 func (m *mockMissionControl) ResetHistory() error { 250 return nil 251 } 252 253 func (m *mockMissionControl) GetHistorySnapshot() *routing.MissionControlSnapshot { 254 return nil 255 } 256 257 func (m *mockMissionControl) GetPairHistorySnapshot(fromNode, 258 toNode route.Vertex) routing.TimedPairResult { 259 260 return routing.TimedPairResult{} 261 } 262 263 type recordParseOutcome byte 264 265 const ( 266 valid recordParseOutcome = iota 267 invalid 268 norecord 269 ) 270 271 type unmarshalMPPTest struct { 272 name string 273 mpp *lnrpc.MPPRecord 274 outcome recordParseOutcome 275 } 276 277 // TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to 278 // assert that an MPP record is only returned when both fields are properly 279 // specified. It also asserts that zero-values for both inputs is also valid, 280 // but returns a nil record. 281 func TestUnmarshalMPP(t *testing.T) { 282 tests := []unmarshalMPPTest{ 283 { 284 name: "nil record", 285 mpp: nil, 286 outcome: norecord, 287 }, 288 { 289 name: "invalid total or addr", 290 mpp: &lnrpc.MPPRecord{ 291 PaymentAddr: nil, 292 TotalAmtMAtoms: 0, 293 }, 294 outcome: invalid, 295 }, 296 { 297 name: "valid total only", 298 mpp: &lnrpc.MPPRecord{ 299 PaymentAddr: nil, 300 TotalAmtMAtoms: 8, 301 }, 302 outcome: invalid, 303 }, 304 { 305 name: "valid addr only", 306 mpp: &lnrpc.MPPRecord{ 307 PaymentAddr: bytes.Repeat([]byte{0x02}, 32), 308 TotalAmtMAtoms: 0, 309 }, 310 outcome: invalid, 311 }, 312 { 313 name: "valid total and invalid addr", 314 mpp: &lnrpc.MPPRecord{ 315 PaymentAddr: []byte{0x02}, 316 TotalAmtMAtoms: 8, 317 }, 318 outcome: invalid, 319 }, 320 { 321 name: "valid total and valid addr", 322 mpp: &lnrpc.MPPRecord{ 323 PaymentAddr: bytes.Repeat([]byte{0x02}, 32), 324 TotalAmtMAtoms: 8, 325 }, 326 outcome: valid, 327 }, 328 } 329 330 for _, test := range tests { 331 test := test 332 t.Run(test.name, func(t *testing.T) { 333 testUnmarshalMPP(t, test) 334 }) 335 } 336 } 337 338 func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) { 339 mpp, err := UnmarshalMPP(test.mpp) 340 switch test.outcome { 341 342 // Valid arguments should result in no error, a non-nil MPP record, and 343 // the fields should be set correctly. 344 case valid: 345 if err != nil { 346 t.Fatalf("unable to parse mpp record: %v", err) 347 } 348 if mpp == nil { 349 t.Fatalf("mpp payload should be non-nil") 350 } 351 if int64(mpp.TotalMAtoms()) != test.mpp.TotalAmtMAtoms { 352 t.Fatalf("incorrect total msat") 353 } 354 addr := mpp.PaymentAddr() 355 if !bytes.Equal(addr[:], test.mpp.PaymentAddr) { 356 t.Fatalf("incorrect payment addr") 357 } 358 359 // Invalid arguments should produce a failure and nil MPP record. 360 case invalid: 361 if err == nil { 362 t.Fatalf("expected failure for invalid mpp") 363 } 364 if mpp != nil { 365 t.Fatalf("mpp payload should be nil for failure") 366 } 367 368 // Arguments that produce no MPP field should return no error and no MPP 369 // record. 370 case norecord: 371 if err != nil { 372 t.Fatalf("failure for args resulting for no-mpp") 373 } 374 if mpp != nil { 375 t.Fatalf("mpp payload should be nil for no-mpp") 376 } 377 378 default: 379 t.Fatalf("test case has non-standard outcome") 380 } 381 } 382 383 type unmarshalAMPTest struct { 384 name string 385 amp *lnrpc.AMPRecord 386 outcome recordParseOutcome 387 } 388 389 // TestUnmarshalAMP asserts the behavior of decoding an RPC AMPRecord. 390 func TestUnmarshalAMP(t *testing.T) { 391 rootShare := bytes.Repeat([]byte{0x01}, 32) 392 setID := bytes.Repeat([]byte{0x02}, 32) 393 394 // All child indexes are valid. 395 childIndex := uint32(3) 396 397 tests := []unmarshalAMPTest{ 398 { 399 name: "nil record", 400 amp: nil, 401 outcome: norecord, 402 }, 403 { 404 name: "invalid root share invalid set id", 405 amp: &lnrpc.AMPRecord{ 406 RootShare: []byte{0x01}, 407 SetId: []byte{0x02}, 408 ChildIndex: childIndex, 409 }, 410 outcome: invalid, 411 }, 412 { 413 name: "valid root share invalid set id", 414 amp: &lnrpc.AMPRecord{ 415 RootShare: rootShare, 416 SetId: []byte{0x02}, 417 ChildIndex: childIndex, 418 }, 419 outcome: invalid, 420 }, 421 { 422 name: "invalid root share valid set id", 423 amp: &lnrpc.AMPRecord{ 424 RootShare: []byte{0x01}, 425 SetId: setID, 426 ChildIndex: childIndex, 427 }, 428 outcome: invalid, 429 }, 430 { 431 name: "valid root share valid set id", 432 amp: &lnrpc.AMPRecord{ 433 RootShare: rootShare, 434 SetId: setID, 435 ChildIndex: childIndex, 436 }, 437 outcome: valid, 438 }, 439 } 440 441 for _, test := range tests { 442 test := test 443 t.Run(test.name, func(t *testing.T) { 444 testUnmarshalAMP(t, test) 445 }) 446 } 447 } 448 449 func testUnmarshalAMP(t *testing.T, test unmarshalAMPTest) { 450 amp, err := UnmarshalAMP(test.amp) 451 switch test.outcome { 452 case valid: 453 require.NoError(t, err) 454 require.NotNil(t, amp) 455 456 rootShare := amp.RootShare() 457 setID := amp.SetID() 458 require.Equal(t, test.amp.RootShare, rootShare[:]) 459 require.Equal(t, test.amp.SetId, setID[:]) 460 require.Equal(t, test.amp.ChildIndex, amp.ChildIndex()) 461 462 case invalid: 463 require.Error(t, err) 464 require.Nil(t, amp) 465 466 case norecord: 467 require.NoError(t, err) 468 require.Nil(t, amp) 469 470 default: 471 t.Fatalf("test case has non-standard outcome") 472 } 473 }