github.com/decred/dcrlnd@v0.7.6/rpcperms/middleware_handler.go (about) 1 package rpcperms 2 3 import ( 4 "context" 5 "encoding/hex" 6 "errors" 7 "fmt" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/decred/dcrd/chaincfg/v3" 13 "github.com/decred/dcrlnd/lnrpc" 14 "github.com/decred/dcrlnd/macaroons" 15 "google.golang.org/protobuf/proto" 16 "google.golang.org/protobuf/reflect/protoreflect" 17 "google.golang.org/protobuf/reflect/protoregistry" 18 "gopkg.in/macaroon.v2" 19 ) 20 21 var ( 22 // ErrShuttingDown is the error that's returned when the server is 23 // shutting down and a request cannot be served anymore. 24 ErrShuttingDown = errors.New("server shutting down") 25 26 // ErrTimeoutReached is the error that's returned if any of the 27 // middleware's tasks is not completed in the given time. 28 ErrTimeoutReached = errors.New("intercept timeout reached") 29 30 // errClientQuit is the error that's returned if the client closes the 31 // middleware communication stream before a request was fully handled. 32 errClientQuit = errors.New("interceptor RPC client quit") 33 ) 34 35 // MiddlewareHandler is a type that communicates with a middleware over the 36 // established bi-directional RPC stream. It sends messages to the middleware 37 // whenever the custom business logic implemented there should give feedback to 38 // a request or response that's happening on the main gRPC server. 39 type MiddlewareHandler struct { 40 // lastMsgID is the ID of the last intercept message that was forwarded 41 // to the middleware. 42 // 43 // NOTE: Must be used atomically! 44 lastMsgID uint64 45 46 middlewareName string 47 48 readOnly bool 49 50 customCaveatName string 51 52 receive func() (*lnrpc.RPCMiddlewareResponse, error) 53 54 send func(request *lnrpc.RPCMiddlewareRequest) error 55 56 interceptRequests chan *interceptRequest 57 58 timeout time.Duration 59 60 // params are our current chain params. 61 params *chaincfg.Params 62 63 // done is closed when the rpc client terminates. 64 done chan struct{} 65 66 // quit is closed when lnd is shutting down. 67 quit chan struct{} 68 69 wg sync.WaitGroup 70 } 71 72 // NewMiddlewareHandler creates a new handler for the middleware with the given 73 // name and custom caveat name. 74 func NewMiddlewareHandler(name, customCaveatName string, readOnly bool, 75 receive func() (*lnrpc.RPCMiddlewareResponse, error), 76 send func(request *lnrpc.RPCMiddlewareRequest) error, 77 timeout time.Duration, params *chaincfg.Params, 78 quit chan struct{}) *MiddlewareHandler { 79 80 // We explicitly want to log this as a warning since intercepting any 81 // gRPC messages can also be used for malicious purposes and the user 82 // should be made aware of the risks. 83 log.Warnf("A new gRPC middleware with the name '%s' was registered "+ 84 " with custom_macaroon_caveat='%s', read_only=%v. Make sure "+ 85 "you trust the middleware author since that code will be able "+ 86 "to intercept and possibly modify and gRPC messages sent/"+ 87 "received to/from a client that has a macaroon with that "+ 88 "custom caveat.", name, customCaveatName, readOnly) 89 90 return &MiddlewareHandler{ 91 middlewareName: name, 92 customCaveatName: customCaveatName, 93 readOnly: readOnly, 94 receive: receive, 95 send: send, 96 interceptRequests: make(chan *interceptRequest), 97 timeout: timeout, 98 params: params, 99 done: make(chan struct{}), 100 quit: quit, 101 } 102 } 103 104 // intercept handles the full interception lifecycle of a single middleware 105 // event (stream authentication, request interception or response interception). 106 // The lifecycle consists of sending a message to the middleware, receiving a 107 // feedback on it and sending the feedback to the appropriate channel. All steps 108 // are guarded by the configured timeout to make sure a middleware cannot slow 109 // down requests too much. 110 func (h *MiddlewareHandler) intercept(requestID uint64, 111 req *InterceptionRequest) (*interceptResponse, error) { 112 113 respChan := make(chan *interceptResponse, 1) 114 115 newRequest := &interceptRequest{ 116 requestID: requestID, 117 request: req, 118 response: respChan, 119 } 120 121 // timeout is the time after which intercept requests expire. 122 timeout := time.After(h.timeout) 123 124 // Send the request to the interceptRequests channel for the main 125 // goroutine to be picked up. 126 select { 127 case h.interceptRequests <- newRequest: 128 129 case <-timeout: 130 log.Errorf("MiddlewareHandler returned error - reached "+ 131 "timeout of %v for request interception", h.timeout) 132 133 return nil, ErrTimeoutReached 134 135 case <-h.done: 136 return nil, errClientQuit 137 138 case <-h.quit: 139 return nil, ErrShuttingDown 140 } 141 142 // Receive the response and return it. If no response has been received 143 // in AcceptorTimeout, then return false. 144 select { 145 case resp := <-respChan: 146 return resp, nil 147 148 case <-timeout: 149 log.Errorf("MiddlewareHandler returned error - reached "+ 150 "timeout of %v for response interception", h.timeout) 151 return nil, ErrTimeoutReached 152 153 case <-h.done: 154 return nil, errClientQuit 155 156 case <-h.quit: 157 return nil, ErrShuttingDown 158 } 159 } 160 161 // Run is the main loop for the middleware handler. This function will block 162 // until it receives the signal that lnd is shutting down, or the rpc stream is 163 // cancelled by the client. 164 func (h *MiddlewareHandler) Run() error { 165 // Wait for our goroutines to exit before we return. 166 defer h.wg.Wait() 167 defer log.Debugf("Exiting middleware run loop for %s", h.middlewareName) 168 169 // Create a channel that responses from middlewares are sent into. 170 responses := make(chan *lnrpc.RPCMiddlewareResponse) 171 172 // errChan is used by the receive loop to signal any errors that occur 173 // during reading from the stream. This is primarily used to shutdown 174 // the send loop in the case of an RPC client disconnecting. 175 errChan := make(chan error, 1) 176 177 // Start a goroutine to receive responses from the interceptor. We 178 // expect the receive function to block, so it must be run in a 179 // goroutine (otherwise we could not send more than one intercept 180 // request to the client). 181 h.wg.Add(1) 182 go func() { 183 h.receiveResponses(errChan, responses) 184 h.wg.Done() 185 }() 186 187 return h.sendInterceptRequests(errChan, responses) 188 } 189 190 // receiveResponses receives responses for our intercept requests and dispatches 191 // them into the responses channel provided, sending any errors that occur into 192 // the error channel provided. 193 func (h *MiddlewareHandler) receiveResponses(errChan chan error, 194 responses chan *lnrpc.RPCMiddlewareResponse) { 195 196 for { 197 resp, err := h.receive() 198 if err != nil { 199 errChan <- err 200 return 201 } 202 203 select { 204 case responses <- resp: 205 206 case <-h.done: 207 return 208 209 case <-h.quit: 210 return 211 } 212 } 213 } 214 215 // sendInterceptRequests handles intercept requests sent to us by our Accept() 216 // function, dispatching them to our acceptor stream and coordinating return of 217 // responses to their callers. 218 func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error, 219 responses chan *lnrpc.RPCMiddlewareResponse) error { 220 221 // Close the done channel to indicate that the interceptor is no longer 222 // listening and any in-progress requests should be terminated. 223 defer close(h.done) 224 225 interceptRequests := make(map[uint64]*interceptRequest) 226 227 for { 228 select { 229 // Consume requests passed to us from our Accept() function and 230 // send them into our stream. 231 case newRequest := <-h.interceptRequests: 232 msgID := atomic.AddUint64(&h.lastMsgID, 1) 233 234 req := newRequest.request 235 interceptRequests[msgID] = newRequest 236 237 interceptReq, err := req.ToRPC( 238 newRequest.requestID, msgID, 239 ) 240 if err != nil { 241 return err 242 } 243 244 if err := h.send(interceptReq); err != nil { 245 return err 246 } 247 248 // Process newly received responses from our interceptor, 249 // looking the original request up in our map of requests and 250 // dispatching the response. 251 case resp := <-responses: 252 requestInfo, ok := interceptRequests[resp.RefMsgId] 253 if !ok { 254 continue 255 } 256 257 response := &interceptResponse{} 258 switch msg := resp.GetMiddlewareMessage().(type) { 259 case *lnrpc.RPCMiddlewareResponse_Feedback: 260 t := msg.Feedback 261 if t.Error != "" { 262 response.err = fmt.Errorf("%s", t.Error) 263 break 264 } 265 266 // For intercepted responses we also allow the 267 // content itself to be overwritten. 268 if requestInfo.request.Type == TypeResponse && 269 t.ReplaceResponse { 270 271 response.replace = true 272 protoMsg, err := parseProto( 273 requestInfo.request.ProtoTypeName, 274 t.ReplacementSerialized, 275 ) 276 277 if err != nil { 278 response.err = err 279 280 break 281 } 282 283 response.replacement = protoMsg 284 } 285 286 default: 287 return fmt.Errorf("unknown middleware "+ 288 "message: %v", msg) 289 } 290 291 select { 292 case requestInfo.response <- response: 293 case <-h.quit: 294 } 295 296 delete(interceptRequests, resp.RefMsgId) 297 298 // If we failed to receive from our middleware, we exit. 299 case err := <-errChan: 300 log.Errorf("Received an error: %v, shutting down", err) 301 return err 302 303 // Exit if we are shutting down. 304 case <-h.quit: 305 return ErrShuttingDown 306 } 307 } 308 } 309 310 // InterceptType defines the different types of intercept messages a middleware 311 // can receive. 312 type InterceptType uint8 313 314 const ( 315 // TypeStreamAuth is the type of intercept message that is sent when a 316 // client or streaming RPC is initialized. A message with this type will 317 // be sent out during stream initialization so a middleware can 318 // accept/deny the whole stream instead of only single messages on the 319 // stream. 320 TypeStreamAuth InterceptType = 1 321 322 // TypeRequest is the type of intercept message that is sent when an RPC 323 // request message is sent to lnd. For client-streaming RPCs a new 324 // message of this type is sent for each individual RPC request sent to 325 // the stream. 326 TypeRequest InterceptType = 2 327 328 // TypeResponse is the type of intercept message that is sent when an 329 // RPC response message is sent from lnd to a client. For 330 // server-streaming RPCs a new message of this type is sent for each 331 // individual RPC response sent to the stream. Middleware has the option 332 // to modify a response message before it is sent out to the client. 333 TypeResponse InterceptType = 3 334 ) 335 336 // InterceptionRequest is a struct holding all information that is sent to a 337 // middleware whenever there is something to intercept (auth, request, 338 // response). 339 type InterceptionRequest struct { 340 // Type is the type of the interception message. 341 Type InterceptType 342 343 // StreamRPC is set to true if the invoked RPC method is client or 344 // server streaming. 345 StreamRPC bool 346 347 // Macaroon holds the macaroon that the client sent to lnd. 348 Macaroon *macaroon.Macaroon 349 350 // RawMacaroon holds the raw binary serialized macaroon that the client 351 // sent to lnd. 352 RawMacaroon []byte 353 354 // CustomCaveatName is the name of the custom caveat that the middleware 355 // was intercepting for. 356 CustomCaveatName string 357 358 // CustomCaveatCondition is the condition of the custom caveat that the 359 // middleware was intercepting for. This can be empty for custom caveats 360 // that only have a name (marker caveats). 361 CustomCaveatCondition string 362 363 // FullURI is the full RPC method URI that was invoked. 364 FullURI string 365 366 // ProtoSerialized is the full request or response object in the 367 // protobuf binary serialization format. 368 ProtoSerialized []byte 369 370 // ProtoTypeName is the fully qualified name of the protobuf type of the 371 // request or response message that is serialized in the field above. 372 ProtoTypeName string 373 } 374 375 // NewMessageInterceptionRequest creates a new interception request for either 376 // a request or response message. 377 func NewMessageInterceptionRequest(ctx context.Context, 378 authType InterceptType, isStream bool, fullMethod string, 379 m interface{}) (*InterceptionRequest, error) { 380 381 mac, rawMacaroon, err := macaroonFromContext(ctx) 382 if err != nil { 383 return nil, err 384 } 385 386 rpcReq, ok := m.(proto.Message) 387 if !ok { 388 return nil, fmt.Errorf("msg is not proto message: %v", m) 389 } 390 rawRequest, err := proto.Marshal(rpcReq) 391 if err != nil { 392 return nil, fmt.Errorf("cannot marshal proto msg: %v", err) 393 } 394 395 return &InterceptionRequest{ 396 Type: authType, 397 StreamRPC: isStream, 398 Macaroon: mac, 399 RawMacaroon: rawMacaroon, 400 FullURI: fullMethod, 401 ProtoSerialized: rawRequest, 402 ProtoTypeName: string(proto.MessageName(rpcReq)), 403 }, nil 404 } 405 406 // NewStreamAuthInterceptionRequest creates a new interception request for a 407 // stream authentication message. 408 func NewStreamAuthInterceptionRequest(ctx context.Context, 409 fullMethod string) (*InterceptionRequest, error) { 410 411 mac, rawMacaroon, err := macaroonFromContext(ctx) 412 if err != nil { 413 return nil, err 414 } 415 416 return &InterceptionRequest{ 417 Type: TypeStreamAuth, 418 StreamRPC: true, 419 Macaroon: mac, 420 RawMacaroon: rawMacaroon, 421 FullURI: fullMethod, 422 }, nil 423 } 424 425 // macaroonFromContext tries to extract the macaroon from the incoming context. 426 // If there is no macaroon, a nil error is returned since some RPCs might not 427 // require a macaroon. But in case there is something in the macaroon header 428 // field that cannot be parsed, a non-nil error is returned. 429 func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte, 430 error) { 431 432 macHex, err := macaroons.RawMacaroonFromContext(ctx) 433 if err != nil { 434 // If there is no macaroon, we continue anyway as it might be an 435 // RPC that doesn't require a macaroon. 436 return nil, nil, nil 437 } 438 439 macBytes, err := hex.DecodeString(macHex) 440 if err != nil { 441 return nil, nil, err 442 } 443 444 mac := &macaroon.Macaroon{} 445 if err := mac.UnmarshalBinary(macBytes); err != nil { 446 return nil, nil, err 447 } 448 449 return mac, macBytes, nil 450 } 451 452 // ToRPC converts the interception request to its RPC counterpart. 453 func (r *InterceptionRequest) ToRPC(requestID, 454 msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) { 455 456 rpcRequest := &lnrpc.RPCMiddlewareRequest{ 457 RequestId: requestID, 458 MsgId: msgID, 459 RawMacaroon: r.RawMacaroon, 460 CustomCaveatCondition: r.CustomCaveatCondition, 461 } 462 463 switch r.Type { 464 case TypeStreamAuth: 465 rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_StreamAuth{ 466 StreamAuth: &lnrpc.StreamAuth{ 467 MethodFullUri: r.FullURI, 468 }, 469 } 470 471 case TypeRequest: 472 rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Request{ 473 Request: &lnrpc.RPCMessage{ 474 MethodFullUri: r.FullURI, 475 StreamRpc: r.StreamRPC, 476 TypeName: r.ProtoTypeName, 477 Serialized: r.ProtoSerialized, 478 }, 479 } 480 481 case TypeResponse: 482 rpcRequest.InterceptType = &lnrpc.RPCMiddlewareRequest_Response{ 483 Response: &lnrpc.RPCMessage{ 484 MethodFullUri: r.FullURI, 485 StreamRpc: r.StreamRPC, 486 TypeName: r.ProtoTypeName, 487 Serialized: r.ProtoSerialized, 488 }, 489 } 490 491 default: 492 return nil, fmt.Errorf("unknown intercept type %v", r.Type) 493 } 494 495 return rpcRequest, nil 496 } 497 498 // interceptRequest is a struct that keeps track of an interception request sent 499 // out to a middleware and the response that is eventually sent back by the 500 // middleware. 501 type interceptRequest struct { 502 requestID uint64 503 request *InterceptionRequest 504 response chan *interceptResponse 505 } 506 507 // interceptResponse is the response a middleware sends back for each 508 // intercepted message. 509 type interceptResponse struct { 510 err error 511 replace bool 512 replacement interface{} 513 } 514 515 // parseProto parses a proto serialized message of the given type into its 516 // native version. 517 func parseProto(typeName string, serialized []byte) (proto.Message, error) { 518 messageType, err := protoregistry.GlobalTypes.FindMessageByName( 519 protoreflect.FullName(typeName), 520 ) 521 if err != nil { 522 return nil, err 523 } 524 msg := messageType.New() 525 err = proto.Unmarshal(serialized, msg.Interface()) 526 if err != nil { 527 return nil, err 528 } 529 530 return msg.Interface(), nil 531 }