github.com/decred/dcrlnd@v0.7.6/htlcswitch/mock.go (about) 1 package htlcswitch 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "encoding/binary" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net" 11 "os" 12 "sync" 13 "sync/atomic" 14 "testing" 15 "time" 16 17 "github.com/decred/dcrd/dcrec/secp256k1/v4" 18 "github.com/decred/dcrd/dcrutil/v4" 19 "github.com/decred/dcrd/wire" 20 "github.com/decred/dcrlnd/chainntnfs" 21 "github.com/decred/dcrlnd/channeldb" 22 "github.com/decred/dcrlnd/clock" 23 "github.com/decred/dcrlnd/contractcourt" 24 "github.com/decred/dcrlnd/htlcswitch/hop" 25 "github.com/decred/dcrlnd/invoices" 26 "github.com/decred/dcrlnd/lnpeer" 27 "github.com/decred/dcrlnd/lntest/mock" 28 "github.com/decred/dcrlnd/lntypes" 29 "github.com/decred/dcrlnd/lnwallet/chainfee" 30 "github.com/decred/dcrlnd/lnwire" 31 "github.com/decred/dcrlnd/ticker" 32 sphinx "github.com/decred/lightning-onion/v4" 33 "github.com/go-errors/errors" 34 ) 35 36 type mockPreimageCache struct { 37 sync.Mutex 38 preimageMap map[lntypes.Hash]lntypes.Preimage 39 } 40 41 func newMockPreimageCache() *mockPreimageCache { 42 return &mockPreimageCache{ 43 preimageMap: make(map[lntypes.Hash]lntypes.Preimage), 44 } 45 } 46 47 func (m *mockPreimageCache) LookupPreimage( 48 hash lntypes.Hash) (lntypes.Preimage, bool) { 49 50 m.Lock() 51 defer m.Unlock() 52 53 p, ok := m.preimageMap[hash] 54 return p, ok 55 } 56 57 func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error { 58 m.Lock() 59 defer m.Unlock() 60 61 for _, preimage := range preimages { 62 m.preimageMap[preimage.Hash()] = preimage 63 } 64 65 return nil 66 } 67 68 func (m *mockPreimageCache) SubscribeUpdates() *contractcourt.WitnessSubscription { 69 return nil 70 } 71 72 type mockFeeEstimator struct { 73 byteFeeIn chan chainfee.AtomPerKByte 74 relayFee chan chainfee.AtomPerKByte 75 76 quit chan struct{} 77 } 78 79 func newMockFeeEstimator() *mockFeeEstimator { 80 return &mockFeeEstimator{ 81 byteFeeIn: make(chan chainfee.AtomPerKByte), 82 relayFee: make(chan chainfee.AtomPerKByte), 83 quit: make(chan struct{}), 84 } 85 } 86 87 func (m *mockFeeEstimator) EstimateFeePerKB( 88 numBlocks uint32) (chainfee.AtomPerKByte, error) { 89 90 select { 91 case feeRate := <-m.byteFeeIn: 92 return feeRate, nil 93 case <-m.quit: 94 return 0, fmt.Errorf("exiting") 95 } 96 } 97 98 func (m *mockFeeEstimator) RelayFeePerKB() chainfee.AtomPerKByte { 99 select { 100 case feeRate := <-m.relayFee: 101 return feeRate 102 case <-m.quit: 103 return 0 104 } 105 } 106 107 func (m *mockFeeEstimator) Start() error { 108 return nil 109 } 110 func (m *mockFeeEstimator) Stop() error { 111 close(m.quit) 112 return nil 113 } 114 115 var _ chainfee.Estimator = (*mockFeeEstimator)(nil) 116 117 type mockForwardingLog struct { 118 sync.Mutex 119 120 events map[time.Time]channeldb.ForwardingEvent 121 } 122 123 func (m *mockForwardingLog) AddForwardingEvents(events []channeldb.ForwardingEvent) error { 124 m.Lock() 125 defer m.Unlock() 126 127 for _, event := range events { 128 m.events[event.Timestamp] = event 129 } 130 131 return nil 132 } 133 134 type mockServer struct { 135 started int32 // To be used atomically. 136 shutdown int32 // To be used atomically. 137 wg sync.WaitGroup 138 quit chan struct{} 139 140 t testing.TB 141 142 name string 143 messages chan lnwire.Message 144 145 id [33]byte 146 htlcSwitch *Switch 147 148 registry *mockInvoiceRegistry 149 pCache *mockPreimageCache 150 interceptorFuncs []messageInterceptor 151 } 152 153 var _ lnpeer.Peer = (*mockServer)(nil) 154 155 func initDB() (*channeldb.DB, error) { 156 tempPath, err := ioutil.TempDir("", "switchdb") 157 if err != nil { 158 return nil, err 159 } 160 161 db, err := channeldb.Open(tempPath) 162 if err != nil { 163 return nil, err 164 } 165 166 return db, err 167 } 168 169 func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { 170 var err error 171 172 if db == nil { 173 db, err = initDB() 174 if err != nil { 175 return nil, err 176 } 177 } 178 179 cfg := Config{ 180 DB: db, 181 FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels, 182 FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels, 183 SwitchPackager: channeldb.NewSwitchPackager(), 184 FwdingLog: &mockForwardingLog{ 185 events: make(map[time.Time]channeldb.ForwardingEvent), 186 }, 187 FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { 188 return nil, nil 189 }, 190 Notifier: &mock.ChainNotifier{ 191 SpendChan: make(chan *chainntnfs.SpendDetail), 192 EpochChan: make(chan *chainntnfs.BlockEpoch), 193 ConfChan: make(chan *chainntnfs.TxConfirmation), 194 }, 195 FwdEventTicker: ticker.NewForce(DefaultFwdEventInterval), 196 LogEventTicker: ticker.NewForce(DefaultLogInterval), 197 AckEventTicker: ticker.NewForce(DefaultAckInterval), 198 HtlcNotifier: &mockHTLCNotifier{}, 199 Clock: clock.NewDefaultClock(), 200 HTLCExpiry: time.Hour, 201 DustThreshold: DefaultDustThreshold, 202 } 203 204 return New(cfg, startingHeight) 205 } 206 207 func newMockServer(t testing.TB, name string, startingHeight uint32, 208 db *channeldb.DB, defaultDelta uint32) (*mockServer, error) { 209 210 var id [33]byte 211 h := sha256.Sum256([]byte(name)) 212 copy(id[:], h[:]) 213 214 pCache := newMockPreimageCache() 215 216 htlcSwitch, err := initSwitchWithDB(startingHeight, db) 217 if err != nil { 218 return nil, err 219 } 220 221 registry := newMockRegistry(defaultDelta) 222 223 return &mockServer{ 224 t: t, 225 id: id, 226 name: name, 227 messages: make(chan lnwire.Message, 3000), 228 quit: make(chan struct{}), 229 registry: registry, 230 htlcSwitch: htlcSwitch, 231 pCache: pCache, 232 interceptorFuncs: make([]messageInterceptor, 0), 233 }, nil 234 } 235 236 func (s *mockServer) Start() error { 237 if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { 238 return errors.New("mock server already started") 239 } 240 241 if err := s.htlcSwitch.Start(); err != nil { 242 return err 243 } 244 245 s.wg.Add(1) 246 go func() { 247 defer s.wg.Done() 248 249 defer func() { 250 s.htlcSwitch.Stop() 251 }() 252 253 for { 254 select { 255 case msg := <-s.messages: 256 var shouldSkip bool 257 258 for _, interceptor := range s.interceptorFuncs { 259 skip, err := interceptor(msg) 260 if err != nil { 261 s.t.Fatalf("%v: error in the "+ 262 "interceptor: %v", s.name, err) 263 return 264 } 265 shouldSkip = shouldSkip || skip 266 } 267 268 if shouldSkip { 269 continue 270 } 271 272 if err := s.readHandler(msg); err != nil { 273 s.t.Fatal(err) 274 return 275 } 276 case <-s.quit: 277 return 278 } 279 } 280 }() 281 282 return nil 283 } 284 285 func (s *mockServer) QuitSignal() <-chan struct{} { 286 return s.quit 287 } 288 289 // mockHopIterator represents the test version of hop iterator which instead 290 // of encrypting the path in onion blob just stores the path as a list of hops. 291 type mockHopIterator struct { 292 hops []*hop.Payload 293 } 294 295 func newMockHopIterator(hops ...*hop.Payload) hop.Iterator { 296 return &mockHopIterator{hops: hops} 297 } 298 299 func (r *mockHopIterator) HopPayload() (*hop.Payload, error) { 300 h := r.hops[0] 301 r.hops = r.hops[1:] 302 return h, nil 303 } 304 305 func (r *mockHopIterator) ExtraOnionBlob() []byte { 306 return nil 307 } 308 309 func (r *mockHopIterator) ExtractErrorEncrypter( 310 extracter hop.ErrorEncrypterExtracter) (hop.ErrorEncrypter, 311 lnwire.FailCode) { 312 313 return extracter(nil) 314 } 315 316 func (r *mockHopIterator) EncodeNextHop(w io.Writer) error { 317 var hopLength [4]byte 318 binary.BigEndian.PutUint32(hopLength[:], uint32(len(r.hops))) 319 320 if _, err := w.Write(hopLength[:]); err != nil { 321 return err 322 } 323 324 for _, hop := range r.hops { 325 fwdInfo := hop.ForwardingInfo() 326 if err := encodeFwdInfo(w, &fwdInfo); err != nil { 327 return err 328 } 329 } 330 331 return nil 332 } 333 334 func encodeFwdInfo(w io.Writer, f *hop.ForwardingInfo) error { 335 if _, err := w.Write([]byte{byte(f.Network)}); err != nil { 336 return err 337 } 338 339 if err := binary.Write(w, binary.BigEndian, f.NextHop); err != nil { 340 return err 341 } 342 343 if err := binary.Write(w, binary.BigEndian, f.AmountToForward); err != nil { 344 return err 345 } 346 347 if err := binary.Write(w, binary.BigEndian, f.OutgoingCTLV); err != nil { 348 return err 349 } 350 351 return nil 352 } 353 354 var _ hop.Iterator = (*mockHopIterator)(nil) 355 356 // mockObfuscator mock implementation of the failure obfuscator which only 357 // encodes the failure and do not makes any onion obfuscation. 358 type mockObfuscator struct { 359 ogPacket *sphinx.OnionPacket 360 failure lnwire.FailureMessage 361 } 362 363 // NewMockObfuscator initializes a dummy mockObfuscator used for testing. 364 func NewMockObfuscator() hop.ErrorEncrypter { 365 return &mockObfuscator{} 366 } 367 368 func (o *mockObfuscator) OnionPacket() *sphinx.OnionPacket { 369 return o.ogPacket 370 } 371 372 func (o *mockObfuscator) Type() hop.EncrypterType { 373 return hop.EncrypterTypeMock 374 } 375 376 func (o *mockObfuscator) Encode(w io.Writer) error { 377 return nil 378 } 379 380 func (o *mockObfuscator) Decode(r io.Reader) error { 381 return nil 382 } 383 384 func (o *mockObfuscator) Reextract( 385 extracter hop.ErrorEncrypterExtracter) error { 386 387 return nil 388 } 389 390 func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( 391 lnwire.OpaqueReason, error) { 392 393 o.failure = failure 394 395 var b bytes.Buffer 396 if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { 397 return nil, err 398 } 399 return b.Bytes(), nil 400 } 401 402 func (o *mockObfuscator) IntermediateEncrypt(reason lnwire.OpaqueReason) lnwire.OpaqueReason { 403 return reason 404 } 405 406 func (o *mockObfuscator) EncryptMalformedError(reason lnwire.OpaqueReason) lnwire.OpaqueReason { 407 return reason 408 } 409 410 // mockDeobfuscator mock implementation of the failure deobfuscator which 411 // only decodes the failure do not makes any onion obfuscation. 412 type mockDeobfuscator struct{} 413 414 func newMockDeobfuscator() ErrorDecrypter { 415 return &mockDeobfuscator{} 416 } 417 418 func (o *mockDeobfuscator) DecryptError(reason lnwire.OpaqueReason) (*ForwardingError, error) { 419 420 r := bytes.NewReader(reason) 421 failure, err := lnwire.DecodeFailure(r, 0) 422 if err != nil { 423 return nil, err 424 } 425 426 return NewForwardingError(failure, 1), nil 427 } 428 429 var _ ErrorDecrypter = (*mockDeobfuscator)(nil) 430 431 // mockIteratorDecoder test version of hop iterator decoder which decodes the 432 // encoded array of hops. 433 type mockIteratorDecoder struct { 434 mu sync.RWMutex 435 436 responses map[[32]byte][]hop.DecodeHopIteratorResponse 437 438 decodeFail bool 439 } 440 441 func newMockIteratorDecoder() *mockIteratorDecoder { 442 return &mockIteratorDecoder{ 443 responses: make(map[[32]byte][]hop.DecodeHopIteratorResponse), 444 } 445 } 446 447 func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte, 448 cltv uint32) (hop.Iterator, lnwire.FailCode) { 449 450 var b [4]byte 451 _, err := r.Read(b[:]) 452 if err != nil { 453 return nil, lnwire.CodeTemporaryChannelFailure 454 } 455 hopLength := binary.BigEndian.Uint32(b[:]) 456 457 hops := make([]*hop.Payload, hopLength) 458 for i := uint32(0); i < hopLength; i++ { 459 var f hop.ForwardingInfo 460 if err := decodeFwdInfo(r, &f); err != nil { 461 return nil, lnwire.CodeTemporaryChannelFailure 462 } 463 464 var nextHopBytes [8]byte 465 binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64()) 466 467 hops[i] = hop.NewLegacyPayload(&sphinx.HopData{ 468 Realm: [1]byte{}, // hop.BitcoinNetwork 469 NextAddress: nextHopBytes, 470 ForwardAmount: uint64(f.AmountToForward), 471 OutgoingCltv: f.OutgoingCTLV, 472 }) 473 } 474 475 return newMockHopIterator(hops...), lnwire.CodeNone 476 } 477 478 func (p *mockIteratorDecoder) DecodeHopIterators(id []byte, 479 reqs []hop.DecodeHopIteratorRequest) ( 480 []hop.DecodeHopIteratorResponse, error) { 481 482 idHash := sha256.Sum256(id) 483 484 p.mu.RLock() 485 if resps, ok := p.responses[idHash]; ok { 486 p.mu.RUnlock() 487 return resps, nil 488 } 489 p.mu.RUnlock() 490 491 batchSize := len(reqs) 492 493 resps := make([]hop.DecodeHopIteratorResponse, 0, batchSize) 494 for _, req := range reqs { 495 iterator, failcode := p.DecodeHopIterator( 496 req.OnionReader, req.RHash, req.IncomingCltv, 497 ) 498 499 if p.decodeFail { 500 failcode = lnwire.CodeTemporaryChannelFailure 501 } 502 503 resp := hop.DecodeHopIteratorResponse{ 504 HopIterator: iterator, 505 FailCode: failcode, 506 } 507 resps = append(resps, resp) 508 } 509 510 p.mu.Lock() 511 p.responses[idHash] = resps 512 p.mu.Unlock() 513 514 return resps, nil 515 } 516 517 func decodeFwdInfo(r io.Reader, f *hop.ForwardingInfo) error { 518 var net [1]byte 519 if _, err := r.Read(net[:]); err != nil { 520 return err 521 } 522 f.Network = hop.Network(net[0]) 523 524 if err := binary.Read(r, binary.BigEndian, &f.NextHop); err != nil { 525 return err 526 } 527 528 if err := binary.Read(r, binary.BigEndian, &f.AmountToForward); err != nil { 529 return err 530 } 531 532 if err := binary.Read(r, binary.BigEndian, &f.OutgoingCTLV); err != nil { 533 return err 534 } 535 536 return nil 537 } 538 539 // messageInterceptor is function that handles the incoming peer messages and 540 // may decide should the peer skip the message or not. 541 type messageInterceptor func(m lnwire.Message) (bool, error) 542 543 // Record is used to set the function which will be triggered when new 544 // lnwire message was received. 545 func (s *mockServer) intersect(f messageInterceptor) { 546 s.interceptorFuncs = append(s.interceptorFuncs, f) 547 } 548 549 func (s *mockServer) SendMessage(sync bool, msgs ...lnwire.Message) error { 550 551 for _, msg := range msgs { 552 select { 553 case s.messages <- msg: 554 case <-s.quit: 555 return errors.New("server is stopped") 556 } 557 } 558 559 return nil 560 } 561 562 func (s *mockServer) SendMessageLazy(sync bool, msgs ...lnwire.Message) error { 563 panic("not implemented") 564 } 565 566 func (s *mockServer) readHandler(message lnwire.Message) error { 567 var targetChan lnwire.ChannelID 568 569 switch msg := message.(type) { 570 case *lnwire.UpdateAddHTLC: 571 targetChan = msg.ChanID 572 case *lnwire.UpdateFulfillHTLC: 573 targetChan = msg.ChanID 574 case *lnwire.UpdateFailHTLC: 575 targetChan = msg.ChanID 576 case *lnwire.UpdateFailMalformedHTLC: 577 targetChan = msg.ChanID 578 case *lnwire.RevokeAndAck: 579 targetChan = msg.ChanID 580 case *lnwire.CommitSig: 581 targetChan = msg.ChanID 582 case *lnwire.FundingLocked: 583 // Ignore 584 return nil 585 case *lnwire.ChannelReestablish: 586 targetChan = msg.ChanID 587 case *lnwire.UpdateFee: 588 targetChan = msg.ChanID 589 default: 590 return fmt.Errorf("unknown message type: %T", msg) 591 } 592 593 // Dispatch the commitment update message to the proper channel link 594 // dedicated to this channel. If the link is not found, we will discard 595 // the message. 596 link, err := s.htlcSwitch.GetLink(targetChan) 597 if err != nil { 598 return nil 599 } 600 601 // Create goroutine for this, in order to be able to properly stop 602 // the server when handler stacked (server unavailable) 603 link.HandleChannelUpdate(message) 604 605 return nil 606 } 607 608 func (s *mockServer) PubKey() [33]byte { 609 return s.id 610 } 611 612 func (s *mockServer) IdentityKey() *secp256k1.PublicKey { 613 pubkey, _ := secp256k1.ParsePubKey(s.id[:]) 614 return pubkey 615 } 616 617 func (s *mockServer) Address() net.Addr { 618 return nil 619 } 620 621 func (s *mockServer) Inbound() bool { 622 return false 623 } 624 625 func (s *mockServer) AddNewChannel(channel *channeldb.OpenChannel, 626 cancel <-chan struct{}) error { 627 628 return nil 629 } 630 631 func (s *mockServer) WipeChannel(*wire.OutPoint) {} 632 633 func (s *mockServer) LocalFeatures() *lnwire.FeatureVector { 634 return nil 635 } 636 637 func (s *mockServer) RemoteFeatures() *lnwire.FeatureVector { 638 return nil 639 } 640 641 func (s *mockServer) Stop() error { 642 if !atomic.CompareAndSwapInt32(&s.shutdown, 0, 1) { 643 return nil 644 } 645 646 close(s.quit) 647 s.wg.Wait() 648 649 return nil 650 } 651 652 func (s *mockServer) String() string { 653 return s.name 654 } 655 656 type mockChannelLink struct { 657 htlcSwitch *Switch 658 659 shortChanID lnwire.ShortChannelID 660 661 chanID lnwire.ChannelID 662 663 peer lnpeer.Peer 664 665 mailBox MailBox 666 667 packets chan *htlcPacket 668 669 eligible bool 670 671 htlcID uint64 672 673 checkHtlcTransitResult *LinkError 674 675 checkHtlcForwardResult *LinkError 676 } 677 678 // completeCircuit is a helper method for adding the finalized payment circuit 679 // to the switch's circuit map. In testing, this should be executed after 680 // receiving an htlc from the downstream packets channel. 681 func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error { 682 switch htlc := pkt.htlc.(type) { 683 case *lnwire.UpdateAddHTLC: 684 pkt.outgoingChanID = f.shortChanID 685 pkt.outgoingHTLCID = f.htlcID 686 htlc.ID = f.htlcID 687 688 keystone := Keystone{pkt.inKey(), pkt.outKey()} 689 err := f.htlcSwitch.circuits.OpenCircuits(keystone) 690 if err != nil { 691 return err 692 } 693 694 f.htlcID++ 695 696 case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC: 697 err := f.htlcSwitch.teardownCircuit(pkt) 698 if err != nil { 699 return err 700 } 701 } 702 703 f.mailBox.AckPacket(pkt.inKey()) 704 705 return nil 706 } 707 708 func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error { 709 return f.htlcSwitch.circuits.DeleteCircuits(pkt.inKey()) 710 } 711 712 func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, 713 shortChanID lnwire.ShortChannelID, peer lnpeer.Peer, eligible bool, 714 ) *mockChannelLink { 715 716 return &mockChannelLink{ 717 htlcSwitch: htlcSwitch, 718 chanID: chanID, 719 shortChanID: shortChanID, 720 peer: peer, 721 eligible: eligible, 722 } 723 } 724 725 func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error { 726 f.mailBox.AddPacket(pkt) 727 return nil 728 } 729 730 func (f *mockChannelLink) handleLocalAddPacket(pkt *htlcPacket) error { 731 _ = f.mailBox.AddPacket(pkt) 732 return nil 733 } 734 735 func (f *mockChannelLink) getDustSum(remote bool) lnwire.MilliAtom { 736 return 0 737 } 738 739 func (f *mockChannelLink) getFeeRate() chainfee.AtomPerKByte { 740 return 0 741 } 742 743 func (f *mockChannelLink) getDustClosure() dustClosure { 744 dustLimit := dcrutil.Amount(6030) 745 return dustHelper( 746 channeldb.SingleFunderTweaklessBit, dustLimit, dustLimit, 747 ) 748 } 749 750 func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) { 751 } 752 753 func (f *mockChannelLink) UpdateForwardingPolicy(_ ForwardingPolicy) { 754 } 755 func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliAtom, 756 lnwire.MilliAtom, uint32, uint32, uint32) *LinkError { 757 758 return f.checkHtlcForwardResult 759 } 760 761 func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, 762 amt lnwire.MilliAtom, timeout uint32, 763 heightNow uint32) *LinkError { 764 765 return f.checkHtlcTransitResult 766 } 767 768 func (f *mockChannelLink) Stats() (uint64, lnwire.MilliAtom, lnwire.MilliAtom) { 769 return 0, 0, 0 770 } 771 772 func (f *mockChannelLink) AttachMailBox(mailBox MailBox) { 773 f.mailBox = mailBox 774 f.packets = mailBox.PacketOutBox() 775 mailBox.SetDustClosure(f.getDustClosure()) 776 } 777 778 func (f *mockChannelLink) Start() error { 779 f.mailBox.ResetMessages() 780 f.mailBox.ResetPackets() 781 return nil 782 } 783 784 func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID } 785 func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID } 786 func (f *mockChannelLink) Bandwidth() lnwire.MilliAtom { return 99999999 } 787 func (f *mockChannelLink) Peer() lnpeer.Peer { return f.peer } 788 func (f *mockChannelLink) ChannelPoint() *wire.OutPoint { return &wire.OutPoint{} } 789 func (f *mockChannelLink) Stop() {} 790 func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } 791 func (f *mockChannelLink) MayAddOutgoingHtlc(lnwire.MilliAtom) error { return nil } 792 func (f *mockChannelLink) ShutdownIfChannelClean() error { return nil } 793 func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } 794 func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { 795 f.eligible = true 796 return f.shortChanID, nil 797 } 798 799 var _ ChannelLink = (*mockChannelLink)(nil) 800 801 func newDB() (*channeldb.DB, func(), error) { 802 // First, create a temporary directory to be used for the duration of 803 // this test. 804 tempDirName, err := ioutil.TempDir("", "channeldb") 805 if err != nil { 806 return nil, nil, err 807 } 808 809 // Next, create channeldb for the first time. 810 cdb, err := channeldb.Open(tempDirName) 811 if err != nil { 812 os.RemoveAll(tempDirName) 813 return nil, nil, err 814 } 815 816 cleanUp := func() { 817 cdb.Close() 818 os.RemoveAll(tempDirName) 819 } 820 821 return cdb, cleanUp, nil 822 } 823 824 const testInvoiceCltvExpiry = 6 825 826 type mockInvoiceRegistry struct { 827 settleChan chan lntypes.Hash 828 829 registry *invoices.InvoiceRegistry 830 831 cleanup func() 832 } 833 834 type mockChainNotifier struct { 835 chainntnfs.ChainNotifier 836 } 837 838 // RegisterBlockEpochNtfn mocks a successful call to register block 839 // notifications. 840 func (m *mockChainNotifier) RegisterBlockEpochNtfn(*chainntnfs.BlockEpoch) ( 841 *chainntnfs.BlockEpochEvent, error) { 842 843 return &chainntnfs.BlockEpochEvent{ 844 Cancel: func() {}, 845 }, nil 846 } 847 848 func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { 849 cdb, cleanup, err := newDB() 850 if err != nil { 851 panic(err) 852 } 853 854 registry := invoices.NewRegistry( 855 cdb, 856 invoices.NewInvoiceExpiryWatcher( 857 clock.NewDefaultClock(), 0, 0, nil, 858 &mockChainNotifier{}, 859 ), 860 &invoices.RegistryConfig{ 861 FinalCltvRejectDelta: 5, 862 }, 863 ) 864 registry.Start() 865 866 return &mockInvoiceRegistry{ 867 registry: registry, 868 cleanup: cleanup, 869 } 870 } 871 872 func (i *mockInvoiceRegistry) LookupInvoice(rHash lntypes.Hash) ( 873 channeldb.Invoice, error) { 874 875 return i.registry.LookupInvoice(rHash) 876 } 877 878 func (i *mockInvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error { 879 return i.registry.SettleHodlInvoice(preimage) 880 } 881 882 func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash, 883 amt lnwire.MilliAtom, expiry uint32, currentHeight int32, 884 circuitKey channeldb.CircuitKey, hodlChan chan<- interface{}, 885 payload invoices.Payload) (invoices.HtlcResolution, error) { 886 887 event, err := i.registry.NotifyExitHopHtlc( 888 rhash, amt, expiry, currentHeight, circuitKey, hodlChan, 889 payload, 890 ) 891 if err != nil { 892 return nil, err 893 } 894 if i.settleChan != nil { 895 i.settleChan <- rhash 896 } 897 898 return event, nil 899 } 900 901 func (i *mockInvoiceRegistry) CancelInvoice(payHash lntypes.Hash) error { 902 return i.registry.CancelInvoice(payHash) 903 } 904 905 func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice, 906 paymentHash lntypes.Hash) error { 907 908 _, err := i.registry.AddInvoice(&invoice, paymentHash) 909 return err 910 } 911 912 func (i *mockInvoiceRegistry) HodlUnsubscribeAll(subscriber chan<- interface{}) { 913 i.registry.HodlUnsubscribeAll(subscriber) 914 } 915 916 var _ InvoiceDatabase = (*mockInvoiceRegistry)(nil) 917 918 type mockCircuitMap struct { 919 lookup chan *PaymentCircuit 920 } 921 922 var _ CircuitMap = (*mockCircuitMap)(nil) 923 924 func (m *mockCircuitMap) OpenCircuits(...Keystone) error { 925 return nil 926 } 927 928 func (m *mockCircuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID, 929 start uint64) error { 930 return nil 931 } 932 933 func (m *mockCircuitMap) DeleteCircuits(inKeys ...CircuitKey) error { 934 return nil 935 } 936 937 func (m *mockCircuitMap) CommitCircuits( 938 circuit ...*PaymentCircuit) (*CircuitFwdActions, error) { 939 940 return nil, nil 941 } 942 943 func (m *mockCircuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit, 944 error) { 945 return nil, nil 946 } 947 948 func (m *mockCircuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit, 949 error) { 950 return nil, nil 951 } 952 953 func (m *mockCircuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit { 954 return <-m.lookup 955 } 956 957 func (m *mockCircuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit { 958 return nil 959 } 960 961 func (m *mockCircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit { 962 return nil 963 } 964 965 func (m *mockCircuitMap) NumPending() int { 966 return 0 967 } 968 969 func (m *mockCircuitMap) NumOpen() int { 970 return 0 971 } 972 973 type mockOnionErrorDecryptor struct { 974 sourceIdx int 975 message []byte 976 err error 977 } 978 979 func (m *mockOnionErrorDecryptor) DecryptError(encryptedData []byte) ( 980 *sphinx.DecryptedError, error) { 981 982 return &sphinx.DecryptedError{ 983 SenderIdx: m.sourceIdx, 984 Message: m.message, 985 }, m.err 986 } 987 988 var _ htlcNotifier = (*mockHTLCNotifier)(nil) 989 990 type mockHTLCNotifier struct{} 991 992 func (h *mockHTLCNotifier) NotifyForwardingEvent(key HtlcKey, info HtlcInfo, 993 eventType HtlcEventType) { 994 } 995 996 func (h *mockHTLCNotifier) NotifyLinkFailEvent(key HtlcKey, info HtlcInfo, 997 eventType HtlcEventType, linkErr *LinkError, incoming bool) { 998 } 999 1000 func (h *mockHTLCNotifier) NotifyForwardingFailEvent(key HtlcKey, 1001 eventType HtlcEventType) { 1002 } 1003 1004 func (h *mockHTLCNotifier) NotifySettleEvent(key HtlcKey, 1005 preimage lntypes.Preimage, eventType HtlcEventType) { 1006 }