github.com/decred/dcrlnd@v0.7.6/rpcperms/middleware_handler.go (about)

     1  package rpcperms
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"errors"
     7  	"fmt"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/decred/dcrd/chaincfg/v3"
    13  	"github.com/decred/dcrlnd/lnrpc"
    14  	"github.com/decred/dcrlnd/macaroons"
    15  	"google.golang.org/protobuf/proto"
    16  	"google.golang.org/protobuf/reflect/protoreflect"
    17  	"google.golang.org/protobuf/reflect/protoregistry"
    18  	"gopkg.in/macaroon.v2"
    19  )
    20  
    21  var (
    22  	// ErrShuttingDown is the error that's returned when the server is
    23  	// shutting down and a request cannot be served anymore.
    24  	ErrShuttingDown = errors.New("server shutting down")
    25  
    26  	// ErrTimeoutReached is the error that's returned if any of the
    27  	// middleware's tasks is not completed in the given time.
    28  	ErrTimeoutReached = errors.New("intercept timeout reached")
    29  
    30  	// errClientQuit is the error that's returned if the client closes the
    31  	// middleware communication stream before a request was fully handled.
    32  	errClientQuit = errors.New("interceptor RPC client quit")
    33  )
    34  
    35  // MiddlewareHandler is a type that communicates with a middleware over the
    36  // established bi-directional RPC stream. It sends messages to the middleware
    37  // whenever the custom business logic implemented there should give feedback to
    38  // a request or response that's happening on the main gRPC server.
    39  type MiddlewareHandler struct {
    40  	// lastMsgID is the ID of the last intercept message that was forwarded
    41  	// to the middleware.
    42  	//
    43  	// NOTE: Must be used atomically!
    44  	lastMsgID uint64
    45  
    46  	middlewareName string
    47  
    48  	readOnly bool
    49  
    50  	customCaveatName string
    51  
    52  	receive func() (*lnrpc.RPCMiddlewareResponse, error)
    53  
    54  	send func(request *lnrpc.RPCMiddlewareRequest) error
    55  
    56  	interceptRequests chan *interceptRequest
    57  
    58  	timeout time.Duration
    59  
    60  	// params are our current chain params.
    61  	params *chaincfg.Params
    62  
    63  	// done is closed when the rpc client terminates.
    64  	done chan struct{}
    65  
    66  	// quit is closed when lnd is shutting down.
    67  	quit chan struct{}
    68  
    69  	wg sync.WaitGroup
    70  }
    71  
    72  // NewMiddlewareHandler creates a new handler for the middleware with the given
    73  // name and custom caveat name.
    74  func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
    75  	receive func() (*lnrpc.RPCMiddlewareResponse, error),
    76  	send func(request *lnrpc.RPCMiddlewareRequest) error,
    77  	timeout time.Duration, params *chaincfg.Params,
    78  	quit chan struct{}) *MiddlewareHandler {
    79  
    80  	// We explicitly want to log this as a warning since intercepting any
    81  	// gRPC messages can also be used for malicious purposes and the user
    82  	// should be made aware of the risks.
    83  	log.Warnf("A new gRPC middleware with the name '%s' was registered "+
    84  		" with custom_macaroon_caveat='%s', read_only=%v. Make sure "+
    85  		"you trust the middleware author since that code will be able "+
    86  		"to intercept and possibly modify and gRPC messages sent/"+
    87  		"received to/from a client that has a macaroon with that "+
    88  		"custom caveat.", name, customCaveatName, readOnly)
    89  
    90  	return &MiddlewareHandler{
    91  		middlewareName:    name,
    92  		customCaveatName:  customCaveatName,
    93  		readOnly:          readOnly,
    94  		receive:           receive,
    95  		send:              send,
    96  		interceptRequests: make(chan *interceptRequest),
    97  		timeout:           timeout,
    98  		params:            params,
    99  		done:              make(chan struct{}),
   100  		quit:              quit,
   101  	}
   102  }
   103  
   104  // intercept handles the full interception lifecycle of a single middleware
   105  // event (stream authentication, request interception or response interception).
   106  // The lifecycle consists of sending a message to the middleware, receiving a
   107  // feedback on it and sending the feedback to the appropriate channel. All steps
   108  // are guarded by the configured timeout to make sure a middleware cannot slow
   109  // down requests too much.
   110  func (h *MiddlewareHandler) intercept(requestID uint64,
   111  	req *InterceptionRequest) (*interceptResponse, error) {
   112  
   113  	respChan := make(chan *interceptResponse, 1)
   114  
   115  	newRequest := &interceptRequest{
   116  		requestID: requestID,
   117  		request:   req,
   118  		response:  respChan,
   119  	}
   120  
   121  	// timeout is the time after which intercept requests expire.
   122  	timeout := time.After(h.timeout)
   123  
   124  	// Send the request to the interceptRequests channel for the main
   125  	// goroutine to be picked up.
   126  	select {
   127  	case h.interceptRequests <- newRequest:
   128  
   129  	case <-timeout:
   130  		log.Errorf("MiddlewareHandler returned error - reached "+
   131  			"timeout of %v for request interception", h.timeout)
   132  
   133  		return nil, ErrTimeoutReached
   134  
   135  	case <-h.done:
   136  		return nil, errClientQuit
   137  
   138  	case <-h.quit:
   139  		return nil, ErrShuttingDown
   140  	}
   141  
   142  	// Receive the response and return it. If no response has been received
   143  	// in AcceptorTimeout, then return false.
   144  	select {
   145  	case resp := <-respChan:
   146  		return resp, nil
   147  
   148  	case <-timeout:
   149  		log.Errorf("MiddlewareHandler returned error - reached "+
   150  			"timeout of %v for response interception", h.timeout)
   151  		return nil, ErrTimeoutReached
   152  
   153  	case <-h.done:
   154  		return nil, errClientQuit
   155  
   156  	case <-h.quit:
   157  		return nil, ErrShuttingDown
   158  	}
   159  }
   160  
   161  // Run is the main loop for the middleware handler. This function will block
   162  // until it receives the signal that lnd is shutting down, or the rpc stream is
   163  // cancelled by the client.
   164  func (h *MiddlewareHandler) Run() error {
   165  	// Wait for our goroutines to exit before we return.
   166  	defer h.wg.Wait()
   167  	defer log.Debugf("Exiting middleware run loop for %s", h.middlewareName)
   168  
   169  	// Create a channel that responses from middlewares are sent into.
   170  	responses := make(chan *lnrpc.RPCMiddlewareResponse)
   171  
   172  	// errChan is used by the receive loop to signal any errors that occur
   173  	// during reading from the stream. This is primarily used to shutdown
   174  	// the send loop in the case of an RPC client disconnecting.
   175  	errChan := make(chan error, 1)
   176  
   177  	// Start a goroutine to receive responses from the interceptor. We
   178  	// expect the receive function to block, so it must be run in a
   179  	// goroutine (otherwise we could not send more than one intercept
   180  	// request to the client).
   181  	h.wg.Add(1)
   182  	go func() {
   183  		h.receiveResponses(errChan, responses)
   184  		h.wg.Done()
   185  	}()
   186  
   187  	return h.sendInterceptRequests(errChan, responses)
   188  }
   189  
   190  // receiveResponses receives responses for our intercept requests and dispatches
   191  // them into the responses channel provided, sending any errors that occur into
   192  // the error channel provided.
   193  func (h *MiddlewareHandler) receiveResponses(errChan chan error,
   194  	responses chan *lnrpc.RPCMiddlewareResponse) {
   195  
   196  	for {
   197  		resp, err := h.receive()
   198  		if err != nil {
   199  			errChan <- err
   200  			return
   201  		}
   202  
   203  		select {
   204  		case responses <- resp:
   205  
   206  		case <-h.done:
   207  			return
   208  
   209  		case <-h.quit:
   210  			return
   211  		}
   212  	}
   213  }
   214  
   215  // sendInterceptRequests handles intercept requests sent to us by our Accept()
   216  // function, dispatching them to our acceptor stream and coordinating return of
   217  // responses to their callers.
   218  func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
   219  	responses chan *lnrpc.RPCMiddlewareResponse) error {
   220  
   221  	// Close the done channel to indicate that the interceptor is no longer
   222  	// listening and any in-progress requests should be terminated.
   223  	defer close(h.done)
   224  
   225  	interceptRequests := make(map[uint64]*interceptRequest)
   226  
   227  	for {
   228  		select {
   229  		// Consume requests passed to us from our Accept() function and
   230  		// send them into our stream.
   231  		case newRequest := <-h.interceptRequests:
   232  			msgID := atomic.AddUint64(&h.lastMsgID, 1)
   233  
   234  			req := newRequest.request
   235  			interceptRequests[msgID] = newRequest
   236  
   237  			interceptReq, err := req.ToRPC(
   238  				newRequest.requestID, msgID,
   239  			)
   240  			if err != nil {
   241  				return err
   242  			}
   243  
   244  			if err := h.send(interceptReq); err != nil {
   245  				return err
   246  			}
   247  
   248  		// Process newly received responses from our interceptor,
   249  		// looking the original request up in our map of requests and
   250  		// dispatching the response.
   251  		case resp := <-responses:
   252  			requestInfo, ok := interceptRequests[resp.RefMsgId]
   253  			if !ok {
   254  				continue
   255  			}
   256  
   257  			response := &interceptResponse{}
   258  			switch msg := resp.GetMiddlewareMessage().(type) {
   259  			case *lnrpc.RPCMiddlewareResponse_Feedback:
   260  				t := msg.Feedback
   261  				if t.Error != "" {
   262  					response.err = fmt.Errorf("%s", t.Error)
   263  					break
   264  				}
   265  
   266  				// For intercepted responses we also allow the
   267  				// content itself to be overwritten.
   268  				if requestInfo.request.Type == TypeResponse &&
   269  					t.ReplaceResponse {
   270  
   271  					response.replace = true
   272  					protoMsg, err := parseProto(
   273  						requestInfo.request.ProtoTypeName,
   274  						t.ReplacementSerialized,
   275  					)
   276  
   277  					if err != nil {
   278  						response.err = err
   279  
   280  						break
   281  					}
   282  
   283  					response.replacement = protoMsg
   284  				}
   285  
   286  			default:
   287  				return fmt.Errorf("unknown middleware "+
   288  					"message: %v", msg)
   289  			}
   290  
   291  			select {
   292  			case requestInfo.response <- response:
   293  			case <-h.quit:
   294  			}
   295  
   296  			delete(interceptRequests, resp.RefMsgId)
   297  
   298  		// If we failed to receive from our middleware, we exit.
   299  		case err := <-errChan:
   300  			log.Errorf("Received an error: %v, shutting down", err)
   301  			return err
   302  
   303  		// Exit if we are shutting down.
   304  		case <-h.quit:
   305  			return ErrShuttingDown
   306  		}
   307  	}
   308  }
   309  
   310  // InterceptType defines the different types of intercept messages a middleware
   311  // can receive.
   312  type InterceptType uint8
   313  
   314  const (
   315  	// TypeStreamAuth is the type of intercept message that is sent when a
   316  	// client or streaming RPC is initialized. A message with this type will
   317  	// be sent out during stream initialization so a middleware can
   318  	// accept/deny the whole stream instead of only single messages on the
   319  	// stream.
   320  	TypeStreamAuth InterceptType = 1
   321  
   322  	// TypeRequest is the type of intercept message that is sent when an RPC
   323  	// request message is sent to lnd. For client-streaming RPCs a new
   324  	// message of this type is sent for each individual RPC request sent to
   325  	// the stream.
   326  	TypeRequest InterceptType = 2
   327  
   328  	// TypeResponse is the type of intercept message that is sent when an
   329  	// RPC response message is sent from lnd to a client. For
   330  	// server-streaming RPCs a new message of this type is sent for each
   331  	// individual RPC response sent to the stream. Middleware has the option
   332  	// to modify a response message before it is sent out to the client.
   333  	TypeResponse InterceptType = 3
   334  )
   335  
   336  // InterceptionRequest is a struct holding all information that is sent to a
   337  // middleware whenever there is something to intercept (auth, request,
   338  // response).
   339  type InterceptionRequest struct {
   340  	// Type is the type of the interception message.
   341  	Type InterceptType
   342  
   343  	// StreamRPC is set to true if the invoked RPC method is client or
   344  	// server streaming.
   345  	StreamRPC bool
   346  
   347  	// Macaroon holds the macaroon that the client sent to lnd.
   348  	Macaroon *macaroon.Macaroon
   349  
   350  	// RawMacaroon holds the raw binary serialized macaroon that the client
   351  	// sent to lnd.
   352  	RawMacaroon []byte
   353  
   354  	// CustomCaveatName is the name of the custom caveat that the middleware
   355  	// was intercepting for.
   356  	CustomCaveatName string
   357  
   358  	// CustomCaveatCondition is the condition of the custom caveat that the
   359  	// middleware was intercepting for. This can be empty for custom caveats
   360  	// that only have a name (marker caveats).
   361  	CustomCaveatCondition string
   362  
   363  	// FullURI is the full RPC method URI that was invoked.
   364  	FullURI string
   365  
   366  	// ProtoSerialized is the full request or response object in the
   367  	// protobuf binary serialization format.
   368  	ProtoSerialized []byte
   369  
   370  	// ProtoTypeName is the fully qualified name of the protobuf type of the
   371  	// request or response message that is serialized in the field above.
   372  	ProtoTypeName string
   373  }
   374  
   375  // NewMessageInterceptionRequest creates a new interception request for either
   376  // a request or response message.
   377  func NewMessageInterceptionRequest(ctx context.Context,
   378  	authType InterceptType, isStream bool, fullMethod string,
   379  	m interface{}) (*InterceptionRequest, error) {
   380  
   381  	mac, rawMacaroon, err := macaroonFromContext(ctx)
   382  	if err != nil {
   383  		return nil, err
   384  	}
   385  
   386  	rpcReq, ok := m.(proto.Message)
   387  	if !ok {
   388  		return nil, fmt.Errorf("msg is not proto message: %v", m)
   389  	}
   390  	rawRequest, err := proto.Marshal(rpcReq)
   391  	if err != nil {
   392  		return nil, fmt.Errorf("cannot marshal proto msg: %v", err)
   393  	}
   394  
   395  	return &InterceptionRequest{
   396  		Type:            authType,
   397  		StreamRPC:       isStream,
   398  		Macaroon:        mac,
   399  		RawMacaroon:     rawMacaroon,
   400  		FullURI:         fullMethod,
   401  		ProtoSerialized: rawRequest,
   402  		ProtoTypeName:   string(proto.MessageName(rpcReq)),
   403  	}, nil
   404  }
   405  
   406  // NewStreamAuthInterceptionRequest creates a new interception request for a
   407  // stream authentication message.
   408  func NewStreamAuthInterceptionRequest(ctx context.Context,
   409  	fullMethod string) (*InterceptionRequest, error) {
   410  
   411  	mac, rawMacaroon, err := macaroonFromContext(ctx)
   412  	if err != nil {
   413  		return nil, err
   414  	}
   415  
   416  	return &InterceptionRequest{
   417  		Type:        TypeStreamAuth,
   418  		StreamRPC:   true,
   419  		Macaroon:    mac,
   420  		RawMacaroon: rawMacaroon,
   421  		FullURI:     fullMethod,
   422  	}, nil
   423  }
   424  
   425  // macaroonFromContext tries to extract the macaroon from the incoming context.
   426  // If there is no macaroon, a nil error is returned since some RPCs might not
   427  // require a macaroon. But in case there is something in the macaroon header
   428  // field that cannot be parsed, a non-nil error is returned.
   429  func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
   430  	error) {
   431  
   432  	macHex, err := macaroons.RawMacaroonFromContext(ctx)
   433  	if err != nil {
   434  		// If there is no macaroon, we continue anyway as it might be an
   435  		// RPC that doesn't require a macaroon.
   436  		return nil, nil, nil
   437  	}
   438  
   439  	macBytes, err := hex.DecodeString(macHex)
   440  	if err != nil {
   441  		return nil, nil, err
   442  	}
   443  
   444  	mac := &macaroon.Macaroon{}
   445  	if err := mac.UnmarshalBinary(macBytes); err != nil {
   446  		return nil, nil, err
   447  	}
   448  
   449  	return mac, macBytes, nil
   450  }
   451  
   452  // ToRPC converts the interception request to its RPC counterpart.
   453  func (r *InterceptionRequest) ToRPC(requestID,
   454  	msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
   455  
   456  	rpcRequest := &lnrpc.RPCMiddlewareRequest{
   457  		RequestId:             requestID,
   458  		MsgId:                 msgID,
   459  		RawMacaroon:           r.RawMacaroon,
   460  		CustomCaveatCondition: r.CustomCaveatCondition,
   461  	}
   462  
   463  	switch r.Type {
   464  	case TypeStreamAuth:
   465  		rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_StreamAuth{
   466  			StreamAuth: &lnrpc.StreamAuth{
   467  				MethodFullUri: r.FullURI,
   468  			},
   469  		}
   470  
   471  	case TypeRequest:
   472  		rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Request{
   473  			Request: &lnrpc.RPCMessage{
   474  				MethodFullUri: r.FullURI,
   475  				StreamRpc:     r.StreamRPC,
   476  				TypeName:      r.ProtoTypeName,
   477  				Serialized:    r.ProtoSerialized,
   478  			},
   479  		}
   480  
   481  	case TypeResponse:
   482  		rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Response{
   483  			Response: &lnrpc.RPCMessage{
   484  				MethodFullUri: r.FullURI,
   485  				StreamRpc:     r.StreamRPC,
   486  				TypeName:      r.ProtoTypeName,
   487  				Serialized:    r.ProtoSerialized,
   488  			},
   489  		}
   490  
   491  	default:
   492  		return nil, fmt.Errorf("unknown intercept type %v", r.Type)
   493  	}
   494  
   495  	return rpcRequest, nil
   496  }
   497  
   498  // interceptRequest is a struct that keeps track of an interception request sent
   499  // out to a middleware and the response that is eventually sent back by the
   500  // middleware.
   501  type interceptRequest struct {
   502  	requestID uint64
   503  	request   *InterceptionRequest
   504  	response  chan *interceptResponse
   505  }
   506  
   507  // interceptResponse is the response a middleware sends back for each
   508  // intercepted message.
   509  type interceptResponse struct {
   510  	err         error
   511  	replace     bool
   512  	replacement interface{}
   513  }
   514  
   515  // parseProto parses a proto serialized message of the given type into its
   516  // native version.
   517  func parseProto(typeName string, serialized []byte) (proto.Message, error) {
   518  	messageType, err := protoregistry.GlobalTypes.FindMessageByName(
   519  		protoreflect.FullName(typeName),
   520  	)
   521  	if err != nil {
   522  		return nil, err
   523  	}
   524  	msg := messageType.New()
   525  	err = proto.Unmarshal(serialized, msg.Interface())
   526  	if err != nil {
   527  		return nil, err
   528  	}
   529  
   530  	return msg.Interface(), nil
   531  }