github.com/decred/dcrlnd@v0.7.6/routing/integrated_routing_context_test.go (about) 1 package routing 2 3 import ( 4 "fmt" 5 "io/ioutil" 6 "math" 7 "os" 8 "testing" 9 "time" 10 11 "github.com/decred/dcrlnd/kvdb" 12 "github.com/decred/dcrlnd/lnwire" 13 "github.com/decred/dcrlnd/routing/route" 14 ) 15 16 const ( 17 sourceNodeID = 1 18 targetNodeID = 2 19 ) 20 21 type mockBandwidthHints struct { 22 hints map[uint64]lnwire.MilliAtom 23 } 24 25 func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64, 26 _ lnwire.MilliAtom) (lnwire.MilliAtom, bool) { 27 28 if m.hints == nil { 29 return 0, false 30 } 31 32 balance, ok := m.hints[channelID] 33 return balance, ok 34 } 35 36 // integratedRoutingContext defines the context in which integrated routing 37 // tests run. 38 type integratedRoutingContext struct { 39 graph *mockGraph 40 t *testing.T 41 42 source *mockNode 43 target *mockNode 44 45 amt lnwire.MilliAtom 46 maxShardAmt *lnwire.MilliAtom 47 finalExpiry int32 48 49 mcCfg MissionControlConfig 50 pathFindingCfg PathFindingConfig 51 } 52 53 // newIntegratedRoutingContext instantiates a new integrated routing test 54 // context with a source and a target node. 55 func newIntegratedRoutingContext(t *testing.T) *integratedRoutingContext { 56 // Instantiate a mock graph. 57 source := newMockNode(sourceNodeID) 58 target := newMockNode(targetNodeID) 59 60 graph := newMockGraph(t) 61 graph.addNode(source) 62 graph.addNode(target) 63 graph.source = source 64 65 // Initiate the test context with a set of default configuration values. 66 // We don't use the lnd defaults here, because otherwise changing the 67 // defaults would break the unit tests. The actual values picked aren't 68 // critical to excite certain behavior, but do need to be aligned with 69 // the test case assertions. 70 ctx := integratedRoutingContext{ 71 t: t, 72 graph: graph, 73 amt: 100000, 74 finalExpiry: 40, 75 76 mcCfg: MissionControlConfig{ 77 ProbabilityEstimatorCfg: ProbabilityEstimatorCfg{ 78 PenaltyHalfLife: 30 * time.Minute, 79 AprioriHopProbability: 0.6, 80 AprioriWeight: 0.5, 81 }, 82 }, 83 84 pathFindingCfg: PathFindingConfig{ 85 AttemptCost: 1000, 86 MinProbability: 0.01, 87 }, 88 89 source: source, 90 target: target, 91 } 92 93 return &ctx 94 } 95 96 // htlcAttempt records the route and outcome of an attempted htlc. 97 type htlcAttempt struct { 98 route *route.Route 99 success bool 100 } 101 102 func (h htlcAttempt) String() string { 103 return fmt.Sprintf("success=%v, route=%v", h.success, h.route) 104 } 105 106 // testPayment launches a test payment and asserts that it is completed after 107 // the expected number of attempts. 108 func (c *integratedRoutingContext) testPayment(maxParts uint32, 109 destFeatureBits ...lnwire.FeatureBit) ([]htlcAttempt, error) { 110 111 // We start out with the base set of MPP feature bits. If the caller 112 // overrides this set of bits, then we'll use their feature bits 113 // entirely. 114 baseFeatureBits := mppFeatures 115 if len(destFeatureBits) != 0 { 116 baseFeatureBits = lnwire.NewRawFeatureVector(destFeatureBits...) 117 } 118 119 var ( 120 nextPid uint64 121 attempts []htlcAttempt 122 ) 123 124 // Create temporary database for mission control. 125 file, err := ioutil.TempFile("", "*.db") 126 if err != nil { 127 c.t.Fatal(err) 128 } 129 130 dbPath := file.Name() 131 defer os.Remove(dbPath) 132 133 db, err := kvdb.Open( 134 kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout, 135 ) 136 if err != nil { 137 c.t.Fatal(err) 138 } 139 defer db.Close() 140 141 // Instantiate a new mission control with the current configuration 142 // values. 143 mc, err := NewMissionControl(db, c.source.pubkey, &c.mcCfg) 144 if err != nil { 145 c.t.Fatal(err) 146 } 147 148 getBandwidthHints := func(_ routingGraph) (bandwidthHints, error) { 149 // Create bandwidth hints based on local channel balances. 150 bandwidthHints := map[uint64]lnwire.MilliAtom{} 151 for _, ch := range c.graph.nodes[c.source.pubkey].channels { 152 bandwidthHints[ch.id] = ch.balance 153 } 154 155 return &mockBandwidthHints{ 156 hints: bandwidthHints, 157 }, nil 158 } 159 160 var paymentAddr [32]byte 161 payment := LightningPayment{ 162 FinalCLTVDelta: uint16(c.finalExpiry), 163 FeeLimit: lnwire.MaxMilliAtom, 164 Target: c.target.pubkey, 165 PaymentAddr: &paymentAddr, 166 DestFeatures: lnwire.NewFeatureVector(baseFeatureBits, nil), 167 Amount: c.amt, 168 CltvLimit: math.MaxUint32, 169 MaxParts: maxParts, 170 } 171 172 var paymentHash [32]byte 173 if err := payment.SetPaymentHash(paymentHash); err != nil { 174 return nil, err 175 } 176 177 if c.maxShardAmt != nil { 178 payment.MaxShardAmt = c.maxShardAmt 179 } 180 181 session, err := newPaymentSession( 182 &payment, getBandwidthHints, 183 func() (routingGraph, func(), error) { 184 return c.graph, func() {}, nil 185 }, 186 mc, c.pathFindingCfg, 187 ) 188 if err != nil { 189 c.t.Fatal(err) 190 } 191 192 // Override default minimum shard amount. 193 session.minShardAmt = lnwire.NewMAtomsFromAtoms(5000) 194 195 // Now the payment control loop starts. It will keep trying routes until 196 // the payment succeeds. 197 var ( 198 amtRemaining = payment.Amount 199 inFlightHtlcs uint32 200 ) 201 for { 202 // Create bandwidth hints based on local channel balances. 203 bandwidthHints := map[uint64]lnwire.MilliAtom{} 204 for _, ch := range c.graph.nodes[c.source.pubkey].channels { 205 bandwidthHints[ch.id] = ch.balance 206 } 207 208 // Find a route. 209 route, err := session.RequestRoute( 210 amtRemaining, lnwire.MaxMilliAtom, inFlightHtlcs, 0, 211 ) 212 if err != nil { 213 return attempts, err 214 } 215 216 // Send out the htlc on the mock graph. 217 pid := nextPid 218 nextPid++ 219 htlcResult, err := c.graph.sendHtlc(route) 220 if err != nil { 221 c.t.Fatal(err) 222 } 223 224 success := htlcResult.failure == nil 225 attempts = append(attempts, htlcAttempt{ 226 route: route, 227 success: success, 228 }) 229 230 // Process the result. In normal Lightning operations, the 231 // sender doesn't get an acknowledgement from the recipient that 232 // the htlc arrived. In integrated routing tests, this 233 // acknowledgement is available. It is a simplification of 234 // reality that still allows certain classes of tests to be 235 // performed. 236 if success { 237 inFlightHtlcs++ 238 239 err := mc.ReportPaymentSuccess(pid, route) 240 if err != nil { 241 c.t.Fatal(err) 242 } 243 244 amtRemaining -= route.ReceiverAmt() 245 246 // If the full amount has been paid, the payment is 247 // successful and the control loop can be terminated. 248 if amtRemaining == 0 { 249 break 250 } 251 252 // Otherwise try to send the remaining amount. 253 continue 254 } 255 256 // Failure, update mission control and retry. 257 finalResult, err := mc.ReportPaymentFail( 258 pid, route, 259 getNodeIndex(route, htlcResult.failureSource), 260 htlcResult.failure, 261 ) 262 if err != nil { 263 c.t.Fatal(err) 264 } 265 266 if finalResult != nil { 267 break 268 } 269 } 270 271 return attempts, nil 272 } 273 274 // getNodeIndex returns the zero-based index of the given node in the route. 275 func getNodeIndex(route *route.Route, failureSource route.Vertex) *int { 276 if failureSource == route.SourcePubKey { 277 idx := 0 278 return &idx 279 } 280 281 for i, h := range route.Hops { 282 if h.PubKeyBytes == failureSource { 283 idx := i + 1 284 return &idx 285 } 286 } 287 return nil 288 }