github.com/decred/dcrlnd@v0.7.6/routing/mock_test.go (about) 1 package routing 2 3 import ( 4 "fmt" 5 "sync" 6 7 "github.com/decred/dcrd/dcrec/secp256k1/v4" 8 "github.com/decred/dcrlnd/channeldb" 9 "github.com/decred/dcrlnd/htlcswitch" 10 "github.com/decred/dcrlnd/lntypes" 11 "github.com/decred/dcrlnd/lnwire" 12 "github.com/decred/dcrlnd/routing/route" 13 "github.com/go-errors/errors" 14 "github.com/stretchr/testify/mock" 15 ) 16 17 type mockPaymentAttemptDispatcherOld struct { 18 onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error) 19 results map[uint64]*htlcswitch.PaymentResult 20 21 sync.Mutex 22 } 23 24 var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcherOld)(nil) 25 26 func (m *mockPaymentAttemptDispatcherOld) SendHTLC( 27 firstHop lnwire.ShortChannelID, pid uint64, 28 _ *lnwire.UpdateAddHTLC) error { 29 30 if m.onPayment == nil { 31 return nil 32 } 33 34 var result *htlcswitch.PaymentResult 35 preimage, err := m.onPayment(firstHop) 36 if err != nil { 37 rtErr, ok := err.(htlcswitch.ClearTextError) 38 if !ok { 39 return err 40 } 41 result = &htlcswitch.PaymentResult{ 42 Error: rtErr, 43 } 44 } else { 45 result = &htlcswitch.PaymentResult{Preimage: preimage} 46 } 47 48 m.Lock() 49 if m.results == nil { 50 m.results = make(map[uint64]*htlcswitch.PaymentResult) 51 } 52 53 m.results[pid] = result 54 m.Unlock() 55 56 return nil 57 } 58 59 func (m *mockPaymentAttemptDispatcherOld) GetPaymentResult(paymentID uint64, 60 _ lntypes.Hash, _ htlcswitch.ErrorDecrypter) ( 61 <-chan *htlcswitch.PaymentResult, error) { 62 63 c := make(chan *htlcswitch.PaymentResult, 1) 64 65 m.Lock() 66 res, ok := m.results[paymentID] 67 m.Unlock() 68 69 if !ok { 70 return nil, htlcswitch.ErrPaymentIDNotFound 71 } 72 c <- res 73 74 return c, nil 75 76 } 77 func (m *mockPaymentAttemptDispatcherOld) CleanStore( 78 map[uint64]struct{}) error { 79 80 return nil 81 } 82 83 func (m *mockPaymentAttemptDispatcherOld) setPaymentResult( 84 f func(firstHop lnwire.ShortChannelID) ([32]byte, error)) { 85 86 m.onPayment = f 87 } 88 89 type mockPaymentSessionSourceOld struct { 90 routes []*route.Route 91 routeRelease chan struct{} 92 } 93 94 var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) 95 96 func (m *mockPaymentSessionSourceOld) NewPaymentSession( 97 _ *LightningPayment) (PaymentSession, error) { 98 99 return &mockPaymentSessionOld{ 100 routes: m.routes, 101 release: m.routeRelease, 102 }, nil 103 } 104 105 func (m *mockPaymentSessionSourceOld) NewPaymentSessionForRoute( 106 preBuiltRoute *route.Route) PaymentSession { 107 return nil 108 } 109 110 func (m *mockPaymentSessionSourceOld) NewPaymentSessionEmpty() PaymentSession { 111 return &mockPaymentSessionOld{} 112 } 113 114 type mockMissionControlOld struct { 115 MissionControl 116 } 117 118 var _ MissionController = (*mockMissionControlOld)(nil) 119 120 func (m *mockMissionControlOld) ReportPaymentFail( 121 paymentID uint64, rt *route.Route, 122 failureSourceIdx *int, failure lnwire.FailureMessage) ( 123 *channeldb.FailureReason, error) { 124 125 // Report a permanent failure if this is an error caused 126 // by incorrect details. 127 if failure.Code() == lnwire.CodeIncorrectOrUnknownPaymentDetails { 128 reason := channeldb.FailureReasonPaymentDetails 129 return &reason, nil 130 } 131 132 return nil, nil 133 } 134 135 func (m *mockMissionControlOld) ReportPaymentSuccess(paymentID uint64, 136 rt *route.Route) error { 137 138 return nil 139 } 140 141 func (m *mockMissionControlOld) GetProbability(fromNode, toNode route.Vertex, 142 amt lnwire.MilliAtom) float64 { 143 144 return 0 145 } 146 147 type mockPaymentSessionOld struct { 148 routes []*route.Route 149 150 // release is a channel that optionally blocks requesting a route 151 // from our mock payment channel. If this value is nil, we will just 152 // release the route automatically. 153 release chan struct{} 154 } 155 156 var _ PaymentSession = (*mockPaymentSessionOld)(nil) 157 158 func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliAtom, 159 _, height uint32) (*route.Route, error) { 160 161 if m.release != nil { 162 m.release <- struct{}{} 163 } 164 165 if len(m.routes) == 0 { 166 return nil, errNoPathFound 167 } 168 169 r := m.routes[0] 170 m.routes = m.routes[1:] 171 172 return r, nil 173 } 174 175 func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, 176 _ *secp256k1.PublicKey, _ *channeldb.CachedEdgePolicy) bool { 177 178 return false 179 } 180 181 func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *secp256k1.PublicKey, 182 _ uint64) *channeldb.CachedEdgePolicy { 183 184 return nil 185 } 186 187 type mockPayerOld struct { 188 sendResult chan error 189 paymentResult chan *htlcswitch.PaymentResult 190 quit chan struct{} 191 } 192 193 var _ PaymentAttemptDispatcher = (*mockPayerOld)(nil) 194 195 func (m *mockPayerOld) SendHTLC(_ lnwire.ShortChannelID, 196 paymentID uint64, 197 _ *lnwire.UpdateAddHTLC) error { 198 199 select { 200 case res := <-m.sendResult: 201 return res 202 case <-m.quit: 203 return fmt.Errorf("test quitting") 204 } 205 206 } 207 208 func (m *mockPayerOld) GetPaymentResult(paymentID uint64, _ lntypes.Hash, 209 _ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) { 210 211 select { 212 case res, ok := <-m.paymentResult: 213 resChan := make(chan *htlcswitch.PaymentResult, 1) 214 if !ok { 215 close(resChan) 216 } else { 217 resChan <- res 218 } 219 220 return resChan, nil 221 222 case <-m.quit: 223 return nil, fmt.Errorf("test quitting") 224 } 225 } 226 227 func (m *mockPayerOld) CleanStore(pids map[uint64]struct{}) error { 228 return nil 229 } 230 231 type initArgs struct { 232 c *channeldb.PaymentCreationInfo 233 } 234 235 type registerAttemptArgs struct { 236 a *channeldb.HTLCAttemptInfo 237 } 238 239 type settleAttemptArgs struct { 240 preimg lntypes.Preimage 241 } 242 243 type failAttemptArgs struct { 244 reason *channeldb.HTLCFailInfo 245 } 246 247 type failPaymentArgs struct { 248 reason channeldb.FailureReason 249 } 250 251 type testPayment struct { 252 info channeldb.PaymentCreationInfo 253 attempts []channeldb.HTLCAttempt 254 } 255 256 type mockControlTowerOld struct { 257 payments map[lntypes.Hash]*testPayment 258 successful map[lntypes.Hash]struct{} 259 failed map[lntypes.Hash]channeldb.FailureReason 260 261 init chan initArgs 262 registerAttempt chan registerAttemptArgs 263 settleAttempt chan settleAttemptArgs 264 failAttempt chan failAttemptArgs 265 failPayment chan failPaymentArgs 266 fetchInFlight chan struct{} 267 268 sync.Mutex 269 } 270 271 var _ ControlTower = (*mockControlTowerOld)(nil) 272 273 func makeMockControlTower() *mockControlTowerOld { 274 return &mockControlTowerOld{ 275 payments: make(map[lntypes.Hash]*testPayment), 276 successful: make(map[lntypes.Hash]struct{}), 277 failed: make(map[lntypes.Hash]channeldb.FailureReason), 278 } 279 } 280 281 func (m *mockControlTowerOld) InitPayment(phash lntypes.Hash, 282 c *channeldb.PaymentCreationInfo) error { 283 284 if m.init != nil { 285 m.init <- initArgs{c} 286 } 287 288 m.Lock() 289 defer m.Unlock() 290 291 // Don't allow re-init a successful payment. 292 if _, ok := m.successful[phash]; ok { 293 return channeldb.ErrAlreadyPaid 294 } 295 296 _, failed := m.failed[phash] 297 _, ok := m.payments[phash] 298 299 // If the payment is known, only allow re-init if failed. 300 if ok && !failed { 301 return channeldb.ErrPaymentInFlight 302 } 303 304 delete(m.failed, phash) 305 m.payments[phash] = &testPayment{ 306 info: *c, 307 } 308 309 return nil 310 } 311 312 func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash, 313 a *channeldb.HTLCAttemptInfo) error { 314 315 if m.registerAttempt != nil { 316 m.registerAttempt <- registerAttemptArgs{a} 317 } 318 319 m.Lock() 320 defer m.Unlock() 321 322 // Lookup payment. 323 p, ok := m.payments[phash] 324 if !ok { 325 return channeldb.ErrPaymentNotInitiated 326 } 327 328 var inFlight bool 329 for _, a := range p.attempts { 330 if a.Settle != nil { 331 continue 332 } 333 334 if a.Failure != nil { 335 continue 336 } 337 338 inFlight = true 339 } 340 341 // Cannot register attempts for successful or failed payments. 342 _, settled := m.successful[phash] 343 _, failed := m.failed[phash] 344 345 if settled || failed { 346 return channeldb.ErrPaymentTerminal 347 } 348 349 if settled && !inFlight { 350 return channeldb.ErrPaymentAlreadySucceeded 351 } 352 353 if failed && !inFlight { 354 return channeldb.ErrPaymentAlreadyFailed 355 } 356 357 // Add attempt to payment. 358 p.attempts = append(p.attempts, channeldb.HTLCAttempt{ 359 HTLCAttemptInfo: *a, 360 }) 361 m.payments[phash] = p 362 363 return nil 364 } 365 366 func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash, 367 pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( 368 *channeldb.HTLCAttempt, error) { 369 370 if m.settleAttempt != nil { 371 m.settleAttempt <- settleAttemptArgs{settleInfo.Preimage} 372 } 373 374 m.Lock() 375 defer m.Unlock() 376 377 // Only allow setting attempts if the payment is known. 378 p, ok := m.payments[phash] 379 if !ok { 380 return nil, channeldb.ErrPaymentNotInitiated 381 } 382 383 // Find the attempt with this pid, and set the settle info. 384 for i, a := range p.attempts { 385 if a.AttemptID != pid { 386 continue 387 } 388 389 if a.Settle != nil { 390 return nil, channeldb.ErrAttemptAlreadySettled 391 } 392 if a.Failure != nil { 393 return nil, channeldb.ErrAttemptAlreadyFailed 394 } 395 396 p.attempts[i].Settle = settleInfo 397 398 // Mark the payment successful on first settled attempt. 399 m.successful[phash] = struct{}{} 400 return &channeldb.HTLCAttempt{ 401 Settle: settleInfo, 402 }, nil 403 } 404 405 return nil, fmt.Errorf("pid not found") 406 } 407 408 func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64, 409 failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { 410 411 if m.failAttempt != nil { 412 m.failAttempt <- failAttemptArgs{failInfo} 413 } 414 415 m.Lock() 416 defer m.Unlock() 417 418 // Only allow failing attempts if the payment is known. 419 p, ok := m.payments[phash] 420 if !ok { 421 return nil, channeldb.ErrPaymentNotInitiated 422 } 423 424 // Find the attempt with this pid, and set the failure info. 425 for i, a := range p.attempts { 426 if a.AttemptID != pid { 427 continue 428 } 429 430 if a.Settle != nil { 431 return nil, channeldb.ErrAttemptAlreadySettled 432 } 433 if a.Failure != nil { 434 return nil, channeldb.ErrAttemptAlreadyFailed 435 } 436 437 p.attempts[i].Failure = failInfo 438 return &channeldb.HTLCAttempt{ 439 Failure: failInfo, 440 }, nil 441 } 442 443 return nil, fmt.Errorf("pid not found") 444 } 445 446 func (m *mockControlTowerOld) Fail(phash lntypes.Hash, 447 reason channeldb.FailureReason) error { 448 449 m.Lock() 450 defer m.Unlock() 451 452 if m.failPayment != nil { 453 m.failPayment <- failPaymentArgs{reason} 454 } 455 456 // Payment must be known. 457 if _, ok := m.payments[phash]; !ok { 458 return channeldb.ErrPaymentNotInitiated 459 } 460 461 m.failed[phash] = reason 462 463 return nil 464 } 465 466 func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) ( 467 *channeldb.MPPayment, error) { 468 469 m.Lock() 470 defer m.Unlock() 471 472 return m.fetchPayment(phash) 473 } 474 475 func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) ( 476 *channeldb.MPPayment, error) { 477 478 p, ok := m.payments[phash] 479 if !ok { 480 return nil, channeldb.ErrPaymentNotInitiated 481 } 482 483 mp := &channeldb.MPPayment{ 484 Info: &p.info, 485 } 486 487 reason, ok := m.failed[phash] 488 if ok { 489 mp.FailureReason = &reason 490 } 491 492 // Return a copy of the current attempts. 493 mp.HTLCs = append(mp.HTLCs, p.attempts...) 494 return mp, nil 495 } 496 497 func (m *mockControlTowerOld) FetchInFlightPayments() ( 498 []*channeldb.MPPayment, error) { 499 500 if m.fetchInFlight != nil { 501 m.fetchInFlight <- struct{}{} 502 } 503 504 m.Lock() 505 defer m.Unlock() 506 507 // In flight are all payments not successful or failed. 508 var fl []*channeldb.MPPayment 509 for hash := range m.payments { 510 if _, ok := m.successful[hash]; ok { 511 continue 512 } 513 if _, ok := m.failed[hash]; ok { 514 continue 515 } 516 517 mp, err := m.fetchPayment(hash) 518 if err != nil { 519 return nil, err 520 } 521 522 fl = append(fl, mp) 523 } 524 525 return fl, nil 526 } 527 528 func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) ( 529 *ControlTowerSubscriber, error) { 530 531 return nil, errors.New("not implemented") 532 } 533 534 type mockPaymentAttemptDispatcher struct { 535 mock.Mock 536 537 resultChan chan *htlcswitch.PaymentResult 538 } 539 540 var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) 541 542 func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, 543 pid uint64, htlcAdd *lnwire.UpdateAddHTLC) error { 544 545 args := m.Called(firstHop, pid, htlcAdd) 546 return args.Error(0) 547 } 548 549 func (m *mockPaymentAttemptDispatcher) GetPaymentResult(attemptID uint64, 550 paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) ( 551 <-chan *htlcswitch.PaymentResult, error) { 552 553 m.Called(attemptID, paymentHash, deobfuscator) 554 555 // Instead of returning the mocked returned values, we need to return 556 // the chan resultChan so it can be converted into a read-only chan. 557 return m.resultChan, nil 558 } 559 560 func (m *mockPaymentAttemptDispatcher) CleanStore( 561 keepPids map[uint64]struct{}) error { 562 563 args := m.Called(keepPids) 564 return args.Error(0) 565 } 566 567 type mockPaymentSessionSource struct { 568 mock.Mock 569 } 570 571 var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) 572 573 func (m *mockPaymentSessionSource) NewPaymentSession( 574 payment *LightningPayment) (PaymentSession, error) { 575 576 args := m.Called(payment) 577 return args.Get(0).(PaymentSession), args.Error(1) 578 } 579 580 func (m *mockPaymentSessionSource) NewPaymentSessionForRoute( 581 preBuiltRoute *route.Route) PaymentSession { 582 583 args := m.Called(preBuiltRoute) 584 return args.Get(0).(PaymentSession) 585 } 586 587 func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession { 588 args := m.Called() 589 return args.Get(0).(PaymentSession) 590 } 591 592 type mockMissionControl struct { 593 mock.Mock 594 } 595 596 var _ MissionController = (*mockMissionControl)(nil) 597 598 func (m *mockMissionControl) ReportPaymentFail( 599 paymentID uint64, rt *route.Route, 600 failureSourceIdx *int, failure lnwire.FailureMessage) ( 601 *channeldb.FailureReason, error) { 602 603 args := m.Called(paymentID, rt, failureSourceIdx, failure) 604 605 // Type assertion on nil will fail, so we check and return here. 606 if args.Get(0) == nil { 607 return nil, args.Error(1) 608 } 609 610 return args.Get(0).(*channeldb.FailureReason), args.Error(1) 611 } 612 613 func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64, 614 rt *route.Route) error { 615 616 args := m.Called(paymentID, rt) 617 return args.Error(0) 618 } 619 620 func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex, 621 amt lnwire.MilliAtom) float64 { 622 623 args := m.Called(fromNode, toNode, amt) 624 return args.Get(0).(float64) 625 } 626 627 type mockPaymentSession struct { 628 mock.Mock 629 } 630 631 var _ PaymentSession = (*mockPaymentSession)(nil) 632 633 func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliAtom, 634 activeShards, height uint32) (*route.Route, error) { 635 args := m.Called(maxAmt, feeLimit, activeShards, height) 636 return args.Get(0).(*route.Route), args.Error(1) 637 } 638 639 func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, 640 pubKey *secp256k1.PublicKey, policy *channeldb.CachedEdgePolicy) bool { 641 642 args := m.Called(msg, pubKey, policy) 643 return args.Bool(0) 644 } 645 646 func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *secp256k1.PublicKey, 647 channelID uint64) *channeldb.CachedEdgePolicy { 648 649 args := m.Called(pubKey, channelID) 650 return args.Get(0).(*channeldb.CachedEdgePolicy) 651 } 652 653 type mockControlTower struct { 654 mock.Mock 655 sync.Mutex 656 } 657 658 var _ ControlTower = (*mockControlTower)(nil) 659 660 func (m *mockControlTower) InitPayment(phash lntypes.Hash, 661 c *channeldb.PaymentCreationInfo) error { 662 663 args := m.Called(phash, c) 664 return args.Error(0) 665 } 666 667 func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, 668 a *channeldb.HTLCAttemptInfo) error { 669 670 m.Lock() 671 defer m.Unlock() 672 673 args := m.Called(phash, a) 674 return args.Error(0) 675 } 676 677 func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, 678 pid uint64, settleInfo *channeldb.HTLCSettleInfo) ( 679 *channeldb.HTLCAttempt, error) { 680 681 m.Lock() 682 defer m.Unlock() 683 684 args := m.Called(phash, pid, settleInfo) 685 return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) 686 } 687 688 func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, 689 failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) { 690 691 m.Lock() 692 defer m.Unlock() 693 694 args := m.Called(phash, pid, failInfo) 695 return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1) 696 } 697 698 func (m *mockControlTower) Fail(phash lntypes.Hash, 699 reason channeldb.FailureReason) error { 700 701 m.Lock() 702 defer m.Unlock() 703 704 args := m.Called(phash, reason) 705 return args.Error(0) 706 } 707 708 func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( 709 *channeldb.MPPayment, error) { 710 711 m.Lock() 712 defer m.Unlock() 713 args := m.Called(phash) 714 715 // Type assertion on nil will fail, so we check and return here. 716 if args.Get(0) == nil { 717 return nil, args.Error(1) 718 } 719 720 // Make a copy of the payment here to avoid data race. 721 p := args.Get(0).(*channeldb.MPPayment) 722 payment := &channeldb.MPPayment{ 723 FailureReason: p.FailureReason, 724 } 725 payment.HTLCs = make([]channeldb.HTLCAttempt, len(p.HTLCs)) 726 copy(payment.HTLCs, p.HTLCs) 727 728 return payment, args.Error(1) 729 } 730 731 func (m *mockControlTower) FetchInFlightPayments() ( 732 []*channeldb.MPPayment, error) { 733 734 args := m.Called() 735 return args.Get(0).([]*channeldb.MPPayment), args.Error(1) 736 } 737 738 func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( 739 *ControlTowerSubscriber, error) { 740 741 args := m.Called(paymentHash) 742 return args.Get(0).(*ControlTowerSubscriber), args.Error(1) 743 } 744 745 type mockLink struct { 746 htlcswitch.ChannelLink 747 bandwidth lnwire.MilliAtom 748 mayAddOutgoingErr error 749 ineligible bool 750 } 751 752 // Bandwidth returns the bandwidth the mock was configured with. 753 func (m *mockLink) Bandwidth() lnwire.MilliAtom { 754 return m.bandwidth 755 } 756 757 // EligibleToForward returns the mock's configured eligibility. 758 func (m *mockLink) EligibleToForward() bool { 759 return !m.ineligible 760 } 761 762 // MayAddOutgoingHtlc returns the error configured in our mock. 763 func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliAtom) error { 764 return m.mayAddOutgoingErr 765 }