github.com/decred/dcrlnd@v0.7.6/routing/integrated_routing_context_test.go (about)

     1  package routing
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"math"
     7  	"os"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/decred/dcrlnd/kvdb"
    12  	"github.com/decred/dcrlnd/lnwire"
    13  	"github.com/decred/dcrlnd/routing/route"
    14  )
    15  
    16  const (
    17  	sourceNodeID = 1
    18  	targetNodeID = 2
    19  )
    20  
    21  type mockBandwidthHints struct {
    22  	hints map[uint64]lnwire.MilliAtom
    23  }
    24  
    25  func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64,
    26  	_ lnwire.MilliAtom) (lnwire.MilliAtom, bool) {
    27  
    28  	if m.hints == nil {
    29  		return 0, false
    30  	}
    31  
    32  	balance, ok := m.hints[channelID]
    33  	return balance, ok
    34  }
    35  
// integratedRoutingContext defines the context in which integrated routing
// tests run: a mock graph with a source and a target node, the parameters of
// the payment to make, and the mission control and path finding
// configuration that testPayment hands to those components.
type integratedRoutingContext struct {
	// graph is the mock network payments are routed over; t is the test
	// that fatal harness errors are reported on.
	graph *mockGraph
	t     *testing.T

	// source and target are the sending and receiving nodes of the payment.
	source *mockNode
	target *mockNode

	// amt is the total amount to pay. maxShardAmt, when non-nil, is copied
	// into the payment's MaxShardAmt. finalExpiry is used as the payment's
	// final CLTV delta.
	amt         lnwire.MilliAtom
	maxShardAmt *lnwire.MilliAtom
	finalExpiry int32

	// mcCfg configures the mission control instance and pathFindingCfg the
	// payment session created for each test payment.
	mcCfg          MissionControlConfig
	pathFindingCfg PathFindingConfig
}
    52  
    53  // newIntegratedRoutingContext instantiates a new integrated routing test
    54  // context with a source and a target node.
    55  func newIntegratedRoutingContext(t *testing.T) *integratedRoutingContext {
    56  	// Instantiate a mock graph.
    57  	source := newMockNode(sourceNodeID)
    58  	target := newMockNode(targetNodeID)
    59  
    60  	graph := newMockGraph(t)
    61  	graph.addNode(source)
    62  	graph.addNode(target)
    63  	graph.source = source
    64  
    65  	// Initiate the test context with a set of default configuration values.
    66  	// We don't use the lnd defaults here, because otherwise changing the
    67  	// defaults would break the unit tests. The actual values picked aren't
    68  	// critical to excite certain behavior, but do need to be aligned with
    69  	// the test case assertions.
    70  	ctx := integratedRoutingContext{
    71  		t:           t,
    72  		graph:       graph,
    73  		amt:         100000,
    74  		finalExpiry: 40,
    75  
    76  		mcCfg: MissionControlConfig{
    77  			ProbabilityEstimatorCfg: ProbabilityEstimatorCfg{
    78  				PenaltyHalfLife:       30 * time.Minute,
    79  				AprioriHopProbability: 0.6,
    80  				AprioriWeight:         0.5,
    81  			},
    82  		},
    83  
    84  		pathFindingCfg: PathFindingConfig{
    85  			AttemptCost:    1000,
    86  			MinProbability: 0.01,
    87  		},
    88  
    89  		source: source,
    90  		target: target,
    91  	}
    92  
    93  	return &ctx
    94  }
    95  
    96  // htlcAttempt records the route and outcome of an attempted htlc.
    97  type htlcAttempt struct {
    98  	route   *route.Route
    99  	success bool
   100  }
   101  
   102  func (h htlcAttempt) String() string {
   103  	return fmt.Sprintf("success=%v, route=%v", h.success, h.route)
   104  }
   105  
   106  // testPayment launches a test payment and asserts that it is completed after
   107  // the expected number of attempts.
   108  func (c *integratedRoutingContext) testPayment(maxParts uint32,
   109  	destFeatureBits ...lnwire.FeatureBit) ([]htlcAttempt, error) {
   110  
   111  	// We start out with the base set of MPP feature bits. If the caller
   112  	// overrides this set of bits, then we'll use their feature bits
   113  	// entirely.
   114  	baseFeatureBits := mppFeatures
   115  	if len(destFeatureBits) != 0 {
   116  		baseFeatureBits = lnwire.NewRawFeatureVector(destFeatureBits...)
   117  	}
   118  
   119  	var (
   120  		nextPid  uint64
   121  		attempts []htlcAttempt
   122  	)
   123  
   124  	// Create temporary database for mission control.
   125  	file, err := ioutil.TempFile("", "*.db")
   126  	if err != nil {
   127  		c.t.Fatal(err)
   128  	}
   129  
   130  	dbPath := file.Name()
   131  	defer os.Remove(dbPath)
   132  
   133  	db, err := kvdb.Open(
   134  		kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout,
   135  	)
   136  	if err != nil {
   137  		c.t.Fatal(err)
   138  	}
   139  	defer db.Close()
   140  
   141  	// Instantiate a new mission control with the current configuration
   142  	// values.
   143  	mc, err := NewMissionControl(db, c.source.pubkey, &c.mcCfg)
   144  	if err != nil {
   145  		c.t.Fatal(err)
   146  	}
   147  
   148  	getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) {
   149  		// Create bandwidth hints based on local channel balances.
   150  		bandwidthHints := map[uint64]lnwire.MilliAtom{}
   151  		for _, ch := range c.graph.nodes[c.source.pubkey].channels {
   152  			bandwidthHints[ch.id] = ch.balance
   153  		}
   154  
   155  		return &mockBandwidthHints{
   156  			hints: bandwidthHints,
   157  		}, nil
   158  	}
   159  
   160  	var paymentAddr [32]byte
   161  	payment := LightningPayment{
   162  		FinalCLTVDelta: uint16(c.finalExpiry),
   163  		FeeLimit:       lnwire.MaxMilliAtom,
   164  		Target:         c.target.pubkey,
   165  		PaymentAddr:    &paymentAddr,
   166  		DestFeatures:   lnwire.NewFeatureVector(baseFeatureBits, nil),
   167  		Amount:         c.amt,
   168  		CltvLimit:      math.MaxUint32,
   169  		MaxParts:       maxParts,
   170  	}
   171  
   172  	var paymentHash [32]byte
   173  	if err := payment.SetPaymentHash(paymentHash); err != nil {
   174  		return nil, err
   175  	}
   176  
   177  	if c.maxShardAmt != nil {
   178  		payment.MaxShardAmt = c.maxShardAmt
   179  	}
   180  
   181  	session, err := newPaymentSession(
   182  		&payment, getBandwidthHints,
   183  		func() (routingGraph, func(), error) {
   184  			return c.graph, func() {}, nil
   185  		},
   186  		mc, c.pathFindingCfg,
   187  	)
   188  	if err != nil {
   189  		c.t.Fatal(err)
   190  	}
   191  
   192  	// Override default minimum shard amount.
   193  	session.minShardAmt = lnwire.NewMAtomsFromAtoms(5000)
   194  
   195  	// Now the payment control loop starts. It will keep trying routes until
   196  	// the payment succeeds.
   197  	var (
   198  		amtRemaining  = payment.Amount
   199  		inFlightHtlcs uint32
   200  	)
   201  	for {
   202  		// Create bandwidth hints based on local channel balances.
   203  		bandwidthHints := map[uint64]lnwire.MilliAtom{}
   204  		for _, ch := range c.graph.nodes[c.source.pubkey].channels {
   205  			bandwidthHints[ch.id] = ch.balance
   206  		}
   207  
   208  		// Find a route.
   209  		route, err := session.RequestRoute(
   210  			amtRemaining, lnwire.MaxMilliAtom, inFlightHtlcs, 0,
   211  		)
   212  		if err != nil {
   213  			return attempts, err
   214  		}
   215  
   216  		// Send out the htlc on the mock graph.
   217  		pid := nextPid
   218  		nextPid++
   219  		htlcResult, err := c.graph.sendHtlc(route)
   220  		if err != nil {
   221  			c.t.Fatal(err)
   222  		}
   223  
   224  		success := htlcResult.failure == nil
   225  		attempts = append(attempts, htlcAttempt{
   226  			route:   route,
   227  			success: success,
   228  		})
   229  
   230  		// Process the result. In normal Lightning operations, the
   231  		// sender doesn't get an acknowledgement from the recipient that
   232  		// the htlc arrived. In integrated routing tests, this
   233  		// acknowledgement is available. It is a simplification of
   234  		// reality that still allows certain classes of tests to be
   235  		// performed.
   236  		if success {
   237  			inFlightHtlcs++
   238  
   239  			err := mc.ReportPaymentSuccess(pid, route)
   240  			if err != nil {
   241  				c.t.Fatal(err)
   242  			}
   243  
   244  			amtRemaining -= route.ReceiverAmt()
   245  
   246  			// If the full amount has been paid, the payment is
   247  			// successful and the control loop can be terminated.
   248  			if amtRemaining == 0 {
   249  				break
   250  			}
   251  
   252  			// Otherwise try to send the remaining amount.
   253  			continue
   254  		}
   255  
   256  		// Failure, update mission control and retry.
   257  		finalResult, err := mc.ReportPaymentFail(
   258  			pid, route,
   259  			getNodeIndex(route, htlcResult.failureSource),
   260  			htlcResult.failure,
   261  		)
   262  		if err != nil {
   263  			c.t.Fatal(err)
   264  		}
   265  
   266  		if finalResult != nil {
   267  			break
   268  		}
   269  	}
   270  
   271  	return attempts, nil
   272  }
   273  
   274  // getNodeIndex returns the zero-based index of the given node in the route.
   275  func getNodeIndex(route *route.Route, failureSource route.Vertex) *int {
   276  	if failureSource == route.SourcePubKey {
   277  		idx := 0
   278  		return &idx
   279  	}
   280  
   281  	for i, h := range route.Hops {
   282  		if h.PubKeyBytes == failureSource {
   283  			idx := i + 1
   284  			return &idx
   285  		}
   286  	}
   287  	return nil
   288  }