github.com/decred/dcrlnd@v0.7.6/lnrpc/routerrpc/forward_interceptor.go (about)

     1  package routerrpc
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/decred/dcrlnd/channeldb"
     9  	"github.com/decred/dcrlnd/htlcswitch"
    10  	"github.com/decred/dcrlnd/lntypes"
    11  	"github.com/decred/dcrlnd/lnwire"
    12  )
    13  
    14  var (
    15  	// ErrFwdNotExists is an error returned when the caller tries to resolve
    16  	// a forward that doesn't exist anymore.
    17  	ErrFwdNotExists = errors.New("forward does not exist")
    18  
    19  	// ErrMissingPreimage is an error returned when the caller tries to settle
    20  	// a forward and doesn't provide a preimage.
    21  	ErrMissingPreimage = errors.New("missing preimage")
    22  )
    23  
    24  // forwardInterceptor is a helper struct that handles the lifecycle of an rpc
    25  // interceptor streaming session.
    26  // It is created when the stream opens and disconnects when the stream closes.
    27  type forwardInterceptor struct {
    28  	// server is the Server reference
    29  	server *Server
    30  
    31  	// holdForwards is a map of current hold forwards and their corresponding
    32  	// ForwardResolver.
    33  	holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward
    34  
    35  	// stream is the bidirectional RPC stream
    36  	stream Router_HtlcInterceptorServer
    37  
    38  	// quit is a channel that is closed when this forwardInterceptor is shutting
    39  	// down.
    40  	quit chan struct{}
    41  
    42  	// intercepted is where we stream all intercepted packets coming from
    43  	// the switch.
    44  	intercepted chan htlcswitch.InterceptedForward
    45  
    46  	wg sync.WaitGroup
    47  }
    48  
    49  // newForwardInterceptor creates a new forwardInterceptor.
    50  func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) *forwardInterceptor {
    51  	return &forwardInterceptor{
    52  		server: server,
    53  		stream: stream,
    54  		holdForwards: make(
    55  			map[channeldb.CircuitKey]htlcswitch.InterceptedForward),
    56  		quit:        make(chan struct{}),
    57  		intercepted: make(chan htlcswitch.InterceptedForward),
    58  	}
    59  }
    60  
    61  // run sends the intercepted packets to the client and receives the
    62  // corersponding responses. On one hand it regsitered itself as an interceptor
    63  // that receives the switch packets and on the other hand launches a go routine
    64  // to read from the client stream.
    65  // To coordinate all this and make sure it is safe for concurrent access all
    66  // packets are sent to the main where they are handled.
    67  func (r *forwardInterceptor) run() error {
    68  	// make sure we disconnect and resolves all remaining packets if any.
    69  	defer r.onDisconnect()
    70  
    71  	// Register our interceptor so we receive all forwarded packets.
    72  	interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder
    73  	interceptableForwarder.SetInterceptor(r.onIntercept)
    74  	defer interceptableForwarder.SetInterceptor(nil)
    75  
    76  	// start a go routine that reads client resolutions.
    77  	errChan := make(chan error)
    78  	resolutionRequests := make(chan *ForwardHtlcInterceptResponse)
    79  	r.wg.Add(1)
    80  	go r.readClientResponses(resolutionRequests, errChan)
    81  
    82  	// run the main loop that synchronizes both sides input into one go routine.
    83  	for {
    84  		select {
    85  		case intercepted := <-r.intercepted:
    86  			log.Tracef("sending intercepted packet to client %v", intercepted)
    87  			// in case we couldn't forward we exit the loop and drain the
    88  			// current interceptor as this indicates on a connection problem.
    89  			if err := r.holdAndForwardToClient(intercepted); err != nil {
    90  				return err
    91  			}
    92  		case resolution := <-resolutionRequests:
    93  			log.Tracef("resolving intercepted packet %v", resolution)
    94  			// in case we couldn't resolve we just add a log line since this
    95  			// does not indicate on any connection problem.
    96  			if err := r.resolveFromClient(resolution); err != nil {
    97  				log.Warnf("client resolution of intercepted "+
    98  					"packet failed %v", err)
    99  			}
   100  		case err := <-errChan:
   101  			return err
   102  		case <-r.server.quit:
   103  			return nil
   104  		}
   105  	}
   106  }
   107  
   108  // onIntercept is the function that is called by the switch for every forwarded
   109  // packet. Our interceptor makes sure we hold the packet and then signal to the
   110  // main loop to handle the packet. We only return true if we were able
   111  // to deliver the packet to the main loop.
   112  func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool {
   113  	select {
   114  	case r.intercepted <- p:
   115  		return true
   116  	case <-r.quit:
   117  		return false
   118  	case <-r.server.quit:
   119  		return false
   120  	}
   121  }
   122  
   123  func (r *forwardInterceptor) readClientResponses(
   124  	resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) {
   125  
   126  	defer r.wg.Done()
   127  	for {
   128  		resp, err := r.stream.Recv()
   129  		if err != nil {
   130  			errChan <- err
   131  			return
   132  		}
   133  
   134  		// Now that we have the response from the RPC client, send it to
   135  		// the responses chan.
   136  		select {
   137  		case resolutionChan <- resp:
   138  		case <-r.quit:
   139  			return
   140  		case <-r.server.quit:
   141  			return
   142  		}
   143  	}
   144  }
   145  
   146  // holdAndForwardToClient forwards the intercepted htlc to the client.
   147  func (r *forwardInterceptor) holdAndForwardToClient(
   148  	forward htlcswitch.InterceptedForward) error {
   149  
   150  	htlc := forward.Packet()
   151  	inKey := htlc.IncomingCircuit
   152  
   153  	// Ignore already held htlcs.
   154  	if _, ok := r.holdForwards[inKey]; ok {
   155  		return nil
   156  	}
   157  
   158  	// First hold the forward, then send to client.
   159  	r.holdForwards[inKey] = forward
   160  	interceptionRequest := &ForwardHtlcInterceptRequest{
   161  		IncomingCircuitKey: &CircuitKey{
   162  			ChanId: inKey.ChanID.ToUint64(),
   163  			HtlcId: inKey.HtlcID,
   164  		},
   165  		OutgoingRequestedChanId: htlc.OutgoingChanID.ToUint64(),
   166  		PaymentHash:             htlc.Hash[:],
   167  		OutgoingAmountMAtoms:    uint64(htlc.OutgoingAmount),
   168  		OutgoingExpiry:          htlc.OutgoingExpiry,
   169  		IncomingAmountMAtoms:    uint64(htlc.IncomingAmount),
   170  		IncomingExpiry:          htlc.IncomingExpiry,
   171  		CustomRecords:           htlc.CustomRecords,
   172  		OnionBlob:               htlc.OnionBlob[:],
   173  	}
   174  
   175  	return r.stream.Send(interceptionRequest)
   176  }
   177  
   178  // resolveFromClient handles a resolution arrived from the client.
   179  func (r *forwardInterceptor) resolveFromClient(
   180  	in *ForwardHtlcInterceptResponse) error {
   181  
   182  	circuitKey := channeldb.CircuitKey{
   183  		ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId),
   184  		HtlcID: in.IncomingCircuitKey.HtlcId,
   185  	}
   186  	var interceptedForward htlcswitch.InterceptedForward
   187  	interceptedForward, ok := r.holdForwards[circuitKey]
   188  	if !ok {
   189  		return ErrFwdNotExists
   190  	}
   191  	delete(r.holdForwards, circuitKey)
   192  
   193  	switch in.Action {
   194  	case ResolveHoldForwardAction_RESUME:
   195  		return interceptedForward.Resume()
   196  	case ResolveHoldForwardAction_FAIL:
   197  		return interceptedForward.Fail()
   198  	case ResolveHoldForwardAction_SETTLE:
   199  		if in.Preimage == nil {
   200  			return ErrMissingPreimage
   201  		}
   202  		preimage, err := lntypes.MakePreimage(in.Preimage)
   203  		if err != nil {
   204  			return err
   205  		}
   206  		return interceptedForward.Settle(preimage)
   207  	default:
   208  		return fmt.Errorf("unrecognized resolve action %v", in.Action)
   209  	}
   210  }
   211  
   212  // onDisconnect removes all previousely held forwards from
   213  // the store. Before they are removed it ensure to resume as the default
   214  // behavior.
   215  func (r *forwardInterceptor) onDisconnect() {
   216  	// Then close the channel so all go routine will exit.
   217  	close(r.quit)
   218  
   219  	log.Infof("RPC interceptor disconnected, resolving held packets")
   220  	for key, forward := range r.holdForwards {
   221  		if err := forward.Resume(); err != nil {
   222  			log.Errorf("failed to resume hold forward %v", err)
   223  		}
   224  		delete(r.holdForwards, key)
   225  	}
   226  	r.wg.Wait()
   227  }