github.com/decred/dcrlnd@v0.7.6/lnrpc/routerrpc/router_backend_test.go (about)

     1  package routerrpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/hex"
     7  	"testing"
     8  
     9  	"github.com/decred/dcrd/dcrutil/v4"
    10  	"github.com/decred/dcrlnd/channeldb"
    11  	"github.com/decred/dcrlnd/lnwire"
    12  	"github.com/decred/dcrlnd/record"
    13  	"github.com/decred/dcrlnd/routing"
    14  	"github.com/decred/dcrlnd/routing/route"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/decred/dcrlnd/lnrpc"
    18  )
    19  
    20  const (
    21  	destKey       = "0286098b97bc843372b4426d4b276cea9aa2f48f0428d6f5b66ae101befc14f8b4"
    22  	ignoreNodeKey = "02f274f48f3c0d590449a6776e3ce8825076ac376e470e992246eebc565ef8bb2a"
    23  	hintNodeKey   = "0274e7fb33eafd74fe1acb6db7680bb4aa78e9c839a6e954e38abfad680f645ef7"
    24  
    25  	testMissionControlProb = 0.5
    26  )
    27  
    28  var (
    29  	sourceKey = route.Vertex{1, 2, 3}
    30  
    31  	node1 = route.Vertex{10}
    32  
    33  	node2 = route.Vertex{11}
    34  )
    35  
    36  // TestQueryRoutes asserts that query routes rpc parameters are properly parsed
    37  // and passed onto path finding.
    38  func TestQueryRoutes(t *testing.T) {
    39  	t.Run("no mission control", func(t *testing.T) {
    40  		testQueryRoutes(t, false, false, true)
    41  	})
    42  	t.Run("no mission control and msat", func(t *testing.T) {
    43  		testQueryRoutes(t, false, true, true)
    44  	})
    45  	t.Run("with mission control", func(t *testing.T) {
    46  		testQueryRoutes(t, true, false, true)
    47  	})
    48  	t.Run("no mission control bad cltv limit", func(t *testing.T) {
    49  		testQueryRoutes(t, false, false, false)
    50  	})
    51  }
    52  
    53  func testQueryRoutes(t *testing.T, useMissionControl bool, useMAtoms bool,
    54  	setTimelock bool) {
    55  
    56  	ignoreNodeBytes, err := hex.DecodeString(ignoreNodeKey)
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  
    61  	var ignoreNodeVertex route.Vertex
    62  	copy(ignoreNodeVertex[:], ignoreNodeBytes)
    63  
    64  	destNodeBytes, err := hex.DecodeString(destKey)
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  
    69  	var (
    70  		lastHop      = route.Vertex{64}
    71  		outgoingChan = uint64(383322)
    72  	)
    73  
    74  	hintNode, err := route.NewVertexFromStr(hintNodeKey)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  
    79  	rpcRouteHints := []*lnrpc.RouteHint{
    80  		{
    81  			HopHints: []*lnrpc.HopHint{
    82  				{
    83  					ChanId: 38484,
    84  					NodeId: hintNodeKey,
    85  				},
    86  			},
    87  		},
    88  	}
    89  
    90  	request := &lnrpc.QueryRoutesRequest{
    91  		PubKey:         destKey,
    92  		FinalCltvDelta: 100,
    93  		IgnoredNodes:   [][]byte{ignoreNodeBytes},
    94  		IgnoredEdges: []*lnrpc.EdgeLocator{{
    95  			ChannelId:        555,
    96  			DirectionReverse: true,
    97  		}},
    98  		IgnoredPairs: []*lnrpc.NodePair{{
    99  			From: node1[:],
   100  			To:   node2[:],
   101  		}},
   102  		UseMissionControl: useMissionControl,
   103  		LastHopPubkey:     lastHop[:],
   104  		OutgoingChanId:    outgoingChan,
   105  		DestFeatures:      []lnrpc.FeatureBit{lnrpc.FeatureBit_MPP_OPT},
   106  		RouteHints:        rpcRouteHints,
   107  	}
   108  
   109  	amtAtoms := int64(100000)
   110  	if useMAtoms {
   111  		request.AmtMAtoms = amtAtoms * 1000
   112  		request.FeeLimit = &lnrpc.FeeLimit{
   113  			Limit: &lnrpc.FeeLimit_FixedMAtoms{
   114  				FixedMAtoms: 250000,
   115  			},
   116  		}
   117  	} else {
   118  		request.Amt = amtAtoms
   119  		request.FeeLimit = &lnrpc.FeeLimit{
   120  			Limit: &lnrpc.FeeLimit_Fixed{
   121  				Fixed: 250,
   122  			},
   123  		}
   124  	}
   125  
   126  	findRoute := func(source, target route.Vertex,
   127  		amt lnwire.MilliAtom, restrictions *routing.RestrictParams,
   128  		_ record.CustomSet,
   129  		routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
   130  		finalExpiry uint16) (*route.Route, error) {
   131  
   132  		if int64(amt) != amtAtoms*1000 {
   133  			t.Fatal("unexpected amount")
   134  		}
   135  
   136  		if source != sourceKey {
   137  			t.Fatal("unexpected source key")
   138  		}
   139  
   140  		if !bytes.Equal(target[:], destNodeBytes) {
   141  			t.Fatal("unexpected target key")
   142  		}
   143  
   144  		if restrictions.FeeLimit != 250*1000 {
   145  			t.Fatal("unexpected fee limit")
   146  		}
   147  
   148  		if restrictions.ProbabilitySource(route.Vertex{2},
   149  			route.Vertex{1}, 0,
   150  		) != 0 {
   151  			t.Fatal("expecting 0% probability for ignored edge")
   152  		}
   153  
   154  		if restrictions.ProbabilitySource(ignoreNodeVertex,
   155  			route.Vertex{6}, 0,
   156  		) != 0 {
   157  			t.Fatal("expecting 0% probability for ignored node")
   158  		}
   159  
   160  		if restrictions.ProbabilitySource(node1, node2, 0) != 0 {
   161  			t.Fatal("expecting 0% probability for ignored pair")
   162  		}
   163  
   164  		if *restrictions.LastHop != lastHop {
   165  			t.Fatal("unexpected last hop")
   166  		}
   167  
   168  		if restrictions.OutgoingChannelIDs[0] != outgoingChan {
   169  			t.Fatal("unexpected outgoing channel id")
   170  		}
   171  
   172  		if !restrictions.DestFeatures.HasFeature(lnwire.MPPOptional) {
   173  			t.Fatal("unexpected dest features")
   174  		}
   175  
   176  		if _, ok := routeHints[hintNode]; !ok {
   177  			t.Fatal("expected route hint")
   178  		}
   179  
   180  		expectedProb := 1.0
   181  		if useMissionControl {
   182  			expectedProb = testMissionControlProb
   183  		}
   184  		if restrictions.ProbabilitySource(route.Vertex{4},
   185  			route.Vertex{5}, 0,
   186  		) != expectedProb {
   187  			t.Fatal("expecting 100% probability")
   188  		}
   189  
   190  		hops := []*route.Hop{{}}
   191  		return route.NewRouteFromHops(amt, 144, source, hops)
   192  	}
   193  
   194  	backend := &RouterBackend{
   195  		MaxPaymentMAtoms: lnwire.NewMAtomsFromAtoms(1000000),
   196  		FindRoute:        findRoute,
   197  		SelfNode:         route.Vertex{1, 2, 3},
   198  		FetchChannelCapacity: func(chanID uint64) (
   199  			dcrutil.Amount, error) {
   200  
   201  			return 1, nil
   202  		},
   203  		MissionControl: &mockMissionControl{},
   204  		FetchChannelEndpoints: func(chanID uint64) (route.Vertex,
   205  			route.Vertex, error) {
   206  
   207  			if chanID != 555 {
   208  				t.Fatalf("expected endpoints to be fetched for "+
   209  					"channel 555, but got %v instead",
   210  					chanID)
   211  			}
   212  			return route.Vertex{1}, route.Vertex{2}, nil
   213  		},
   214  	}
   215  
   216  	// If this is set, we'll populate MaxTotalTimelock. If this is not set,
   217  	// the test will fail as CltvLimit will be 0.
   218  	if setTimelock {
   219  		backend.MaxTotalTimelock = 1000
   220  	}
   221  
   222  	resp, err := backend.QueryRoutes(context.Background(), request)
   223  
   224  	// If no MaxTotalTimelock was set for the QueryRoutes request, make
   225  	// sure an error was returned.
   226  	if !setTimelock {
   227  		require.NotEmpty(t, err)
   228  		return
   229  	}
   230  
   231  	if err != nil {
   232  		t.Fatal(err)
   233  	}
   234  	if len(resp.Routes) != 1 {
   235  		t.Fatal("expected a single route response")
   236  	}
   237  }
   238  
   239  type mockMissionControl struct {
   240  	MissionControl
   241  }
   242  
   243  func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex,
   244  	amt lnwire.MilliAtom) float64 {
   245  
   246  	return testMissionControlProb
   247  }
   248  
   249  func (m *mockMissionControl) ResetHistory() error {
   250  	return nil
   251  }
   252  
   253  func (m *mockMissionControl) GetHistorySnapshot() *routing.MissionControlSnapshot {
   254  	return nil
   255  }
   256  
   257  func (m *mockMissionControl) GetPairHistorySnapshot(fromNode,
   258  	toNode route.Vertex) routing.TimedPairResult {
   259  
   260  	return routing.TimedPairResult{}
   261  }
   262  
   263  type recordParseOutcome byte
   264  
   265  const (
   266  	valid recordParseOutcome = iota
   267  	invalid
   268  	norecord
   269  )
   270  
   271  type unmarshalMPPTest struct {
   272  	name    string
   273  	mpp     *lnrpc.MPPRecord
   274  	outcome recordParseOutcome
   275  }
   276  
   277  // TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
   278  // assert that an MPP record is only returned when both fields are properly
   279  // specified. It also asserts that zero-values for both inputs is also valid,
   280  // but returns a nil record.
   281  func TestUnmarshalMPP(t *testing.T) {
   282  	tests := []unmarshalMPPTest{
   283  		{
   284  			name:    "nil record",
   285  			mpp:     nil,
   286  			outcome: norecord,
   287  		},
   288  		{
   289  			name: "invalid total or addr",
   290  			mpp: &lnrpc.MPPRecord{
   291  				PaymentAddr:    nil,
   292  				TotalAmtMAtoms: 0,
   293  			},
   294  			outcome: invalid,
   295  		},
   296  		{
   297  			name: "valid total only",
   298  			mpp: &lnrpc.MPPRecord{
   299  				PaymentAddr:    nil,
   300  				TotalAmtMAtoms: 8,
   301  			},
   302  			outcome: invalid,
   303  		},
   304  		{
   305  			name: "valid addr only",
   306  			mpp: &lnrpc.MPPRecord{
   307  				PaymentAddr:    bytes.Repeat([]byte{0x02}, 32),
   308  				TotalAmtMAtoms: 0,
   309  			},
   310  			outcome: invalid,
   311  		},
   312  		{
   313  			name: "valid total and invalid addr",
   314  			mpp: &lnrpc.MPPRecord{
   315  				PaymentAddr:    []byte{0x02},
   316  				TotalAmtMAtoms: 8,
   317  			},
   318  			outcome: invalid,
   319  		},
   320  		{
   321  			name: "valid total and valid addr",
   322  			mpp: &lnrpc.MPPRecord{
   323  				PaymentAddr:    bytes.Repeat([]byte{0x02}, 32),
   324  				TotalAmtMAtoms: 8,
   325  			},
   326  			outcome: valid,
   327  		},
   328  	}
   329  
   330  	for _, test := range tests {
   331  		test := test
   332  		t.Run(test.name, func(t *testing.T) {
   333  			testUnmarshalMPP(t, test)
   334  		})
   335  	}
   336  }
   337  
   338  func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
   339  	mpp, err := UnmarshalMPP(test.mpp)
   340  	switch test.outcome {
   341  
   342  	// Valid arguments should result in no error, a non-nil MPP record, and
   343  	// the fields should be set correctly.
   344  	case valid:
   345  		if err != nil {
   346  			t.Fatalf("unable to parse mpp record: %v", err)
   347  		}
   348  		if mpp == nil {
   349  			t.Fatalf("mpp payload should be non-nil")
   350  		}
   351  		if int64(mpp.TotalMAtoms()) != test.mpp.TotalAmtMAtoms {
   352  			t.Fatalf("incorrect total msat")
   353  		}
   354  		addr := mpp.PaymentAddr()
   355  		if !bytes.Equal(addr[:], test.mpp.PaymentAddr) {
   356  			t.Fatalf("incorrect payment addr")
   357  		}
   358  
   359  	// Invalid arguments should produce a failure and nil MPP record.
   360  	case invalid:
   361  		if err == nil {
   362  			t.Fatalf("expected failure for invalid mpp")
   363  		}
   364  		if mpp != nil {
   365  			t.Fatalf("mpp payload should be nil for failure")
   366  		}
   367  
   368  	// Arguments that produce no MPP field should return no error and no MPP
   369  	// record.
   370  	case norecord:
   371  		if err != nil {
   372  			t.Fatalf("failure for args resulting for no-mpp")
   373  		}
   374  		if mpp != nil {
   375  			t.Fatalf("mpp payload should be nil for no-mpp")
   376  		}
   377  
   378  	default:
   379  		t.Fatalf("test case has non-standard outcome")
   380  	}
   381  }
   382  
   383  type unmarshalAMPTest struct {
   384  	name    string
   385  	amp     *lnrpc.AMPRecord
   386  	outcome recordParseOutcome
   387  }
   388  
   389  // TestUnmarshalAMP asserts the behavior of decoding an RPC AMPRecord.
   390  func TestUnmarshalAMP(t *testing.T) {
   391  	rootShare := bytes.Repeat([]byte{0x01}, 32)
   392  	setID := bytes.Repeat([]byte{0x02}, 32)
   393  
   394  	// All child indexes are valid.
   395  	childIndex := uint32(3)
   396  
   397  	tests := []unmarshalAMPTest{
   398  		{
   399  			name:    "nil record",
   400  			amp:     nil,
   401  			outcome: norecord,
   402  		},
   403  		{
   404  			name: "invalid root share invalid set id",
   405  			amp: &lnrpc.AMPRecord{
   406  				RootShare:  []byte{0x01},
   407  				SetId:      []byte{0x02},
   408  				ChildIndex: childIndex,
   409  			},
   410  			outcome: invalid,
   411  		},
   412  		{
   413  			name: "valid root share invalid set id",
   414  			amp: &lnrpc.AMPRecord{
   415  				RootShare:  rootShare,
   416  				SetId:      []byte{0x02},
   417  				ChildIndex: childIndex,
   418  			},
   419  			outcome: invalid,
   420  		},
   421  		{
   422  			name: "invalid root share valid set id",
   423  			amp: &lnrpc.AMPRecord{
   424  				RootShare:  []byte{0x01},
   425  				SetId:      setID,
   426  				ChildIndex: childIndex,
   427  			},
   428  			outcome: invalid,
   429  		},
   430  		{
   431  			name: "valid root share valid set id",
   432  			amp: &lnrpc.AMPRecord{
   433  				RootShare:  rootShare,
   434  				SetId:      setID,
   435  				ChildIndex: childIndex,
   436  			},
   437  			outcome: valid,
   438  		},
   439  	}
   440  
   441  	for _, test := range tests {
   442  		test := test
   443  		t.Run(test.name, func(t *testing.T) {
   444  			testUnmarshalAMP(t, test)
   445  		})
   446  	}
   447  }
   448  
   449  func testUnmarshalAMP(t *testing.T, test unmarshalAMPTest) {
   450  	amp, err := UnmarshalAMP(test.amp)
   451  	switch test.outcome {
   452  	case valid:
   453  		require.NoError(t, err)
   454  		require.NotNil(t, amp)
   455  
   456  		rootShare := amp.RootShare()
   457  		setID := amp.SetID()
   458  		require.Equal(t, test.amp.RootShare, rootShare[:])
   459  		require.Equal(t, test.amp.SetId, setID[:])
   460  		require.Equal(t, test.amp.ChildIndex, amp.ChildIndex())
   461  
   462  	case invalid:
   463  		require.Error(t, err)
   464  		require.Nil(t, amp)
   465  
   466  	case norecord:
   467  		require.NoError(t, err)
   468  		require.Nil(t, amp)
   469  
   470  	default:
   471  		t.Fatalf("test case has non-standard outcome")
   472  	}
   473  }