github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/runtime/mux.go (about) 1 package runtime 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net/http" 8 "net/textproto" 9 "regexp" 10 "strings" 11 12 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule" 13 "google.golang.org/grpc/codes" 14 "google.golang.org/grpc/grpclog" 15 "google.golang.org/grpc/health/grpc_health_v1" 16 "google.golang.org/grpc/metadata" 17 "google.golang.org/grpc/status" 18 "google.golang.org/protobuf/proto" 19 ) 20 21 // UnescapingMode defines the behavior of ServeMux when unescaping path parameters. 22 type UnescapingMode int 23 24 const ( 25 // UnescapingModeLegacy is the default V2 behavior, which escapes the entire 26 // path string before doing any routing. 27 UnescapingModeLegacy UnescapingMode = iota 28 29 // UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570 30 // reserved characters. 31 UnescapingModeAllExceptReserved 32 33 // UnescapingModeAllExceptSlash unescapes URL path parameters except path 34 // separators, which will be left as "%2F". 35 UnescapingModeAllExceptSlash 36 37 // UnescapingModeAllCharacters unescapes all URL path parameters. 38 UnescapingModeAllCharacters 39 40 // UnescapingModeDefault is the default escaping type. 41 // TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's 42 // reference implementation 43 UnescapingModeDefault = UnescapingModeLegacy 44 ) 45 46 var encodedPathSplitter = regexp.MustCompile("(/|%2F)") 47 48 // A HandlerFunc handles a specific pair of path pattern and HTTP method. 49 type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) 50 51 // ServeMux is a request multiplexer for grpc-gateway. 52 // It matches http requests to patterns and invokes the corresponding handler. 53 type ServeMux struct { 54 // handlers maps HTTP method to a list of handlers. 
55 handlers map[string][]handler 56 forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error 57 marshalers marshalerRegistry 58 incomingHeaderMatcher HeaderMatcherFunc 59 outgoingHeaderMatcher HeaderMatcherFunc 60 outgoingTrailerMatcher HeaderMatcherFunc 61 metadataAnnotators []func(context.Context, *http.Request) metadata.MD 62 errorHandler ErrorHandlerFunc 63 streamErrorHandler StreamErrorHandlerFunc 64 routingErrorHandler RoutingErrorHandlerFunc 65 disablePathLengthFallback bool 66 unescapingMode UnescapingMode 67 } 68 69 // ServeMuxOption is an option that can be given to a ServeMux on construction. 70 type ServeMuxOption func(*ServeMux) 71 72 // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption. 73 // 74 // forwardResponseOption is an option that will be called on the relevant context.Context, 75 // http.ResponseWriter, and proto.Message before every forwarded response. 76 // 77 // The message may be nil in the case where just a header is being sent. 78 func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption { 79 return func(serveMux *ServeMux) { 80 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption) 81 } 82 } 83 84 // WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode 85 // for more information. 86 func WithUnescapingMode(mode UnescapingMode) ServeMuxOption { 87 return func(serveMux *ServeMux) { 88 serveMux.unescapingMode = mode 89 } 90 } 91 92 // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters. 93 // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be 94 // done with careful consideration. 
func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
	return func(_ *ServeMux) {
		// NOTE: this assigns the package-level currentQueryParser, not a field
		// of the mux being configured, so the setting is not scoped to the
		// ServeMux this option is passed to.
		currentQueryParser = queryParameterParser
	}
}

// HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
// It returns the (possibly transformed) key to use and whether the header is forwarded at all.
type HeaderMatcherFunc func(string) (string, bool)

// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context.
//
// The key is first canonicalized with textproto.CanonicalMIMEHeaderKey. Permanent
// HTTP header keys (as specified by the IANA, e.g: Accept, Cookie, Host) are added
// to the gRPC metadata with the grpcgateway- prefix; if you want to know which
// headers are considered permanent, see the isPermanentHTTPHeader function.
// HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after
// removing the prefix 'Grpc-Metadata-'. Other headers are not added to the gRPC metadata.
func DefaultHeaderMatcher(key string) (string, bool) {
	key = textproto.CanonicalMIMEHeaderKey(key)
	switch {
	case isPermanentHTTPHeader(key):
		return MetadataPrefix + key, true
	case strings.HasPrefix(key, MetadataHeaderPrefix):
		return key[len(MetadataHeaderPrefix):], true
	default:
		return "", false
	}
}

// defaultOutgoingHeaderMatcher forwards every response header metadata key,
// prefixed with MetadataHeaderPrefix.
func defaultOutgoingHeaderMatcher(key string) (string, bool) {
	return MetadataHeaderPrefix + key, true
}

// defaultOutgoingTrailerMatcher forwards every response trailer metadata key,
// prefixed with MetadataTrailerPrefix.
func defaultOutgoingTrailerMatcher(key string) (string, bool) {
	return MetadataTrailerPrefix + key, true
}

// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
//
// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header.
130 func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption { 131 for _, header := range fn.matchedMalformedHeaders() { 132 grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header) 133 } 134 135 return func(mux *ServeMux) { 136 mux.incomingHeaderMatcher = fn 137 } 138 } 139 140 // matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server. 141 func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string { 142 if fn == nil { 143 return nil 144 } 145 headers := make([]string, 0) 146 for header := range malformedHTTPHeaders { 147 out, accept := fn(header) 148 if accept && isMalformedHTTPHeader(out) { 149 headers = append(headers, out) 150 } 151 } 152 return headers 153 } 154 155 // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway. 156 // 157 // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be 158 // passed to http response returned from gateway. To transform the header before passing to response, 159 // matcher should return the modified header. 160 func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption { 161 return func(mux *ServeMux) { 162 mux.outgoingHeaderMatcher = fn 163 } 164 } 165 166 // WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway. 167 // 168 // This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be 169 // passed to http response returned from gateway. To transform the header before passing to response, 170 // matcher should return the modified header. 
func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
	return ServeMuxOption(func(m *ServeMux) {
		m.outgoingTrailerMatcher = fn
	})
}

// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
//
// Services that need to read from the http.Request and enrich the gRPC context can use this; a
// typical example is reading a token from a cookie and adding it to the gRPC context. Each call
// appends another annotator, in the order the options are applied.
func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
	return ServeMuxOption(func(m *ServeMux) {
		m.metadataAnnotators = append(m.metadataAnnotators, annotator)
	})
}

// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
//
// Use it to control how error responses are rendered.
func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
	return ServeMuxOption(func(m *ServeMux) {
		m.errorHandler = fn
	})
}

// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
// error handler, which allows for customizing the error trailer for server-streaming
// calls.
//
// Errors raised by a stream before anything has been written are still handed to the
// mux's ErrorHandler. Once data has been written, though, the error can only be reported
// inside the response body: the body's final message carries the error details returned
// by the stream error handler.
func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
	return ServeMuxOption(func(m *ServeMux) {
		m.streamErrorHandler = fn
	})
}

// WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
//
// The handler is called for errors that occur before a gRPC route has been selected or executed.
213 // The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest 214 func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption { 215 return func(serveMux *ServeMux) { 216 serveMux.routingErrorHandler = fn 217 } 218 } 219 220 // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback. 221 func WithDisablePathLengthFallback() ServeMuxOption { 222 return func(serveMux *ServeMux) { 223 serveMux.disablePathLengthFallback = true 224 } 225 } 226 227 // WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath. 228 // When called the handler will forward the request to the upstream grpc service health check (defined in the 229 // gRPC Health Checking Protocol). 230 // 231 // See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how 232 // to setup the protocol in the grpc server. 233 // 234 // If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest. 
func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
	return func(s *ServeMux) {
		healthHandler := func(w http.ResponseWriter, r *http.Request, _ map[string]string) {
			_, outboundMarshaler := MarshalerForRequest(s, r)

			resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
				Service: r.URL.Query().Get("service"),
			})
			if err != nil {
				s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
				return
			}

			// Advertise the content type of the marshaler that actually encodes
			// the body rather than assuming JSON.
			w.Header().Set("Content-Type", outboundMarshaler.ContentType(resp))

			switch resp.GetStatus() {
			case grpc_health_v1.HealthCheckResponse_SERVING:
				// Best effort: the status line may already be sent, so the
				// failure can only be logged.
				if encErr := outboundMarshaler.NewEncoder(w).Encode(resp); encErr != nil {
					grpclog.Infof("Failed to encode health check response: %v", encErr)
				}
				return
			case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
				err = status.Error(codes.NotFound, resp.String())
			default:
				// NOT_SERVING, UNKNOWN and any status value not listed here.
				// Without this default an unrecognised status would reach the
				// error handler with a nil error.
				err = status.Error(codes.Unavailable, resp.String())
			}
			s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
		}

		// endpointPath comes from the caller, so it is not guaranteed to be a
		// valid pattern. If it is not, the endpoint is never registered; report
		// that instead of discarding the error.
		if err := s.HandlePath(http.MethodGet, endpointPath, healthHandler); err != nil {
			grpclog.Errorf("Failed to register health endpoint at %q: %v", endpointPath, err)
		}
	}
}

// WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux.
//
// See WithHealthEndpointAt for the general implementation.
func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
	return WithHealthEndpointAt(healthCheckClient, "/healthz")
}

// NewServeMux returns a new ServeMux whose internal mapping is empty.
func NewServeMux(opts ...ServeMuxOption) *ServeMux {
	mux := &ServeMux{
		handlers:               make(map[string][]handler),
		forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
		marshalers:             makeMarshalerMIMERegistry(),
		errorHandler:           DefaultHTTPErrorHandler,
		streamErrorHandler:     DefaultStreamErrorHandler,
		routingErrorHandler:    DefaultRoutingErrorHandler,
		unescapingMode:         UnescapingModeDefault,
	}
	for _, apply := range opts {
		apply(mux)
	}

	// Header matchers that no option configured fall back to the package defaults.
	setDefault := func(dst *HeaderMatcherFunc, def HeaderMatcherFunc) {
		if *dst == nil {
			*dst = def
		}
	}
	setDefault(&mux.incomingHeaderMatcher, DefaultHeaderMatcher)
	setDefault(&mux.outgoingHeaderMatcher, defaultOutgoingHeaderMatcher)
	setDefault(&mux.outgoingTrailerMatcher, defaultOutgoingTrailerMatcher)

	return mux
}

// Handle associates "h" to the pair of HTTP method and path pattern.
// The new handler is placed in front of those already registered for meth,
// so the most recent registration is tried first.
func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
	existing := s.handlers[meth]
	updated := make([]handler, 0, len(existing)+1)
	updated = append(updated, handler{pat: pat, h: h})
	s.handlers[meth] = append(updated, existing...)
}

// HandlePath allows users to configure custom path handlers.
// It parses and compiles pathPattern and registers h for it via Handle.
// refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
	compiler, err := httprule.Parse(pathPattern)
	if err != nil {
		return fmt.Errorf("parsing path pattern: %w", err)
	}
	tmpl := compiler.Compile()
	pattern, err := NewPattern(tmpl.Version, tmpl.OpCodes, tmpl.Pool, tmpl.Verb)
	if err != nil {
		return fmt.Errorf("creating new pattern: %w", err)
	}
	s.Handle(meth, pattern, h)
	return nil
}

// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path.
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Dispatch order:
	//  1. handlers registered for r.Method (after any X-HTTP-Method-Override);
	//  2. if the request is eligible for the path length fallback, the GET
	//     handlers;
	//  3. otherwise, a pattern match under any other method yields 405, and no
	//     match at all yields 404.
	// If a matching pattern rejects a path parameter as a malformed escape
	// sequence, a 400 response is written and dispatch stops there, so at most
	// one response is ever written.
	ctx := r.Context()

	// writeRoutingError and writeError send an error response using the
	// marshaler negotiated for r at the time of the call.
	writeRoutingError := func(code int) {
		_, outboundMarshaler := MarshalerForRequest(s, r)
		s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, code)
	}
	writeError := func(err error) {
		_, outboundMarshaler := MarshalerForRequest(s, r)
		s.errorHandler(ctx, s, outboundMarshaler, w, r, err)
	}

	path := r.URL.Path
	if !strings.HasPrefix(path, "/") {
		writeRoutingError(http.StatusBadRequest)
		return
	}

	// Outside legacy mode, route on the escaped form of the path so that path
	// parameters are unescaped exactly once, by the pattern matcher.
	// EscapedPath returns r.URL.RawPath when it is a valid encoding of
	// r.URL.Path and re-escapes r.URL.Path otherwise. Using the decoded
	// r.URL.Path whenever RawPath is empty (which happens when the client's
	// encoding matches Go's default, e.g. "/v1/100%25") would unescape a second
	// time and turn a literal "%" into a malformed sequence.
	// TODO(v3): remove UnescapingModeLegacy
	if s.unescapingMode != UnescapingModeLegacy {
		path = r.URL.EscapedPath()
	}

	if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
		if err := r.ParseForm(); err != nil {
			writeError(status.Error(codes.InvalidArgument, err.Error()))
			return
		}
		r.Method = strings.ToUpper(override)
	}

	// In UnescapingModeLegacy the path has already been fully unescaped, so
	// also splitting on "%2F" would unescape twice. In
	// UnescapingModeAllCharacters the path is still escaped, so encoded slashes
	// are treated as separators too. This means the default behavior of this
	// function will change when UnescapingModeDefault moves from
	// UnescapingModeLegacy to UnescapingModeAllExceptReserved.
	var pathComponents []string
	if s.unescapingMode == UnescapingModeAllCharacters {
		pathComponents = encodedPathSplitter.Split(path[1:], -1)
	} else {
		pathComponents = strings.Split(path[1:], "/")
	}
	lastPathComponent := pathComponents[len(pathComponents)-1]

	// splitVerb prepares a private copy of the path components for matching
	// against h. If h's pattern has a verb and the last component ends in
	// ":"+verb, that suffix is cut off the copy and returned as the verb.
	// Looking for the exact expected suffix, rather than the last ':', handles
	// verbs that themselves contain a colon; some ambiguities the parser cannot
	// resolve remain. The int result is the index of the separating ':' in
	// lastPathComponent, or -1 when there is no such suffix.
	splitVerb := func(h handler) ([]string, string, int) {
		colon := -1
		if patVerb := h.pat.Verb(); patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
			colon = len(lastPathComponent) - len(patVerb) - 1
		}
		// Copy so that trimming the verb never mutates pathComponents, which
		// every candidate handler shares.
		comps := make([]string, len(pathComponents))
		copy(comps, pathComponents)
		var verb string
		if colon > 0 {
			comps[len(comps)-1], verb = lastPathComponent[:colon], lastPathComponent[colon+1:]
		}
		return comps, verb, colon
	}

	// match runs h's pattern against comps/verb. The first bool reports a
	// match. The second reports that a 400 response has already been written
	// because a path parameter held a malformed escape sequence; the caller must
	// then stop dispatching, since any further write would corrupt the response.
	match := func(h handler, comps []string, verb string) (map[string]string, bool, bool) {
		pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
		if err == nil {
			return pathParams, true, false
		}
		var mse MalformedSequenceError
		if errors.As(err, &mse) {
			writeError(&HTTPStatusError{
				HTTPStatus: http.StatusBadRequest,
				Err:        mse,
			})
			return nil, false, true
		}
		return nil, false, false
	}

	for _, h := range s.handlers[r.Method] {
		comps, verb, colon := splitVerb(h)
		if colon == 0 {
			// The last component consists of nothing but ":"+verb.
			writeRoutingError(http.StatusNotFound)
			return
		}
		pathParams, ok, handled := match(h, comps, verb)
		if handled {
			return
		}
		if ok {
			h.h(w, r, pathParams)
			return
		}
	}

	// No handler for r.Method matched. X-HTTP-Method-Override is optional:
	// eligible requests may always fall back to GET. Only POST -> GET is
	// considered, to avoid falling back to potentially dangerous operations
	// like DELETE. GET is tried before any other method so the outcome does
	// not depend on map iteration order.
	fallback := s.isPathLengthFallback(r)
	if fallback {
		for _, h := range s.handlers[http.MethodGet] {
			comps, verb, _ := splitVerb(h)
			pathParams, ok, handled := match(h, comps, verb)
			if handled {
				return
			}
			if !ok {
				continue
			}
			if err := r.ParseForm(); err != nil {
				writeError(status.Error(codes.InvalidArgument, err.Error()))
				return
			}
			h.h(w, r, pathParams)
			return
		}
	}

	// A pattern that matches under some other method means the path exists
	// but the method is wrong.
	for m, handlers := range s.handlers {
		if m == r.Method || (fallback && m == http.MethodGet) {
			continue
		}
		for _, h := range handlers {
			comps, verb, _ := splitVerb(h)
			_, ok, handled := match(h, comps, verb)
			if handled {
				return
			}
			if ok {
				writeRoutingError(http.StatusMethodNotAllowed)
				return
			}
		}
	}

	writeRoutingError(http.StatusNotFound)
}

// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
475 func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error { 476 return s.forwardResponseOptions 477 } 478 479 func (s *ServeMux) isPathLengthFallback(r *http.Request) bool { 480 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" 481 } 482 483 type handler struct { 484 pat Pattern 485 h HandlerFunc 486 }