github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/runtime/context.go (about) 1 package runtime 2 3 import ( 4 "context" 5 "encoding/base64" 6 "fmt" 7 "net" 8 "net/http" 9 "net/textproto" 10 "strconv" 11 "strings" 12 "sync" 13 "time" 14 15 "google.golang.org/grpc/codes" 16 "google.golang.org/grpc/grpclog" 17 "google.golang.org/grpc/metadata" 18 "google.golang.org/grpc/status" 19 ) 20 21 // MetadataHeaderPrefix is the http prefix that represents custom metadata 22 // parameters to or from a gRPC call. 23 const MetadataHeaderPrefix = "Grpc-Metadata-" 24 25 // MetadataPrefix is prepended to permanent HTTP header keys (as specified 26 // by the IANA) when added to the gRPC context. 27 const MetadataPrefix = "grpcgateway-" 28 29 // MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to 30 // HTTP headers in a response handled by grpc-gateway 31 const MetadataTrailerPrefix = "Grpc-Trailer-" 32 33 const metadataGrpcTimeout = "Grpc-Timeout" 34 const metadataHeaderBinarySuffix = "-Bin" 35 36 const xForwardedFor = "X-Forwarded-For" 37 const xForwardedHost = "X-Forwarded-Host" 38 39 // DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound 40 // header isn't present. If the value is 0 the sent `context` will not have a timeout. 41 var DefaultContextTimeout = 0 * time.Second 42 43 // malformedHTTPHeaders lists the headers that the gRPC server may reject outright as malformed. 44 // See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more context. 45 var malformedHTTPHeaders = map[string]struct{}{ 46 "connection": {}, 47 } 48 49 type ( 50 rpcMethodKey struct{} 51 httpPathPatternKey struct{} 52 53 AnnotateContextOption func(ctx context.Context) context.Context 54 ) 55 56 func WithHTTPPathPattern(pattern string) AnnotateContextOption { 57 return func(ctx context.Context) context.Context { 58 return withHTTPPathPattern(ctx, pattern) 59 } 60 } 61 62 func decodeBinHeader(v string) ([]byte, error) { 63 if len(v)%4 == 0 { 64 // Input was padded, or padding was not necessary. 65 return base64.StdEncoding.DecodeString(v) 66 } 67 return base64.RawStdEncoding.DecodeString(v) 68 } 69 70 /* 71 AnnotateContext adds context information such as metadata from the request. 72 73 At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For", 74 except that the forwarded destination is not another HTTP service but rather 75 a gRPC service. 76 */ 77 func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) { 78 ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...) 79 if err != nil { 80 return nil, err 81 } 82 if md == nil { 83 return ctx, nil 84 } 85 86 return metadata.NewOutgoingContext(ctx, md), nil 87 } 88 89 // AnnotateIncomingContext adds context information such as metadata from the request. 90 // Attach metadata as incoming context. 91 func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) { 92 ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...) 93 if err != nil { 94 return nil, err 95 } 96 if md == nil { 97 return ctx, nil 98 } 99 100 return metadata.NewIncomingContext(ctx, md), nil 101 } 102 103 func isValidGRPCMetadataKey(key string) bool { 104 // Must be a valid gRPC "Header-Name" as defined here: 105 // https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md 106 // This means 0-9 a-z _ - . 107 // Only lowercase letters are valid in the wire protocol, but the client library will normalize 108 // uppercase ASCII to lowercase, so uppercase ASCII is also acceptable. 109 bytes := []byte(key) // gRPC validates strings on the byte level, not Unicode. 110 for _, ch := range bytes { 111 validLowercaseLetter := ch >= 'a' && ch <= 'z' 112 validUppercaseLetter := ch >= 'A' && ch <= 'Z' 113 validDigit := ch >= '0' && ch <= '9' 114 validOther := ch == '.' || ch == '-' || ch == '_' 115 if !validLowercaseLetter && !validUppercaseLetter && !validDigit && !validOther { 116 return false 117 } 118 } 119 return true 120 } 121 122 func isValidGRPCMetadataTextValue(textValue string) bool { 123 // Must be a valid gRPC "ASCII-Value" as defined here: 124 // https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md 125 // This means printable ASCII (including/plus spaces); 0x20 to 0x7E inclusive. 126 bytes := []byte(textValue) // gRPC validates strings on the byte level, not Unicode. 127 for _, ch := range bytes { 128 if ch < 0x20 || ch > 0x7E { 129 return false 130 } 131 } 132 return true 133 } 134 135 func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) { 136 ctx = withRPCMethod(ctx, rpcMethodName) 137 for _, o := range options { 138 ctx = o(ctx) 139 } 140 timeout := DefaultContextTimeout 141 if tm := req.Header.Get(metadataGrpcTimeout); tm != "" { 142 var err error 143 timeout, err = timeoutDecode(tm) 144 if err != nil { 145 return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm) 146 } 147 } 148 var pairs []string 149 for key, vals := range req.Header { 150 key = textproto.CanonicalMIMEHeaderKey(key) 151 for _, val := range vals { 152 // For backwards-compatibility, pass through 'authorization' header with no prefix. 153 if key == "Authorization" { 154 pairs = append(pairs, "authorization", val) 155 } 156 if h, ok := mux.incomingHeaderMatcher(key); ok { 157 if !isValidGRPCMetadataKey(h) { 158 grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h) 159 continue 160 } 161 // Handles "-bin" metadata in grpc, since grpc will do another base64 162 // encode before sending to server, we need to decode it first. 163 if strings.HasSuffix(key, metadataHeaderBinarySuffix) { 164 b, err := decodeBinHeader(val) 165 if err != nil { 166 return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err) 167 } 168 169 val = string(b) 170 } else if !isValidGRPCMetadataTextValue(val) { 171 grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h) 172 continue 173 } 174 pairs = append(pairs, h, val) 175 } 176 } 177 } 178 if host := req.Header.Get(xForwardedHost); host != "" { 179 pairs = append(pairs, strings.ToLower(xForwardedHost), host) 180 } else if req.Host != "" { 181 pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host) 182 } 183 184 if addr := req.RemoteAddr; addr != "" { 185 if remoteIP, _, err := net.SplitHostPort(addr); err == nil { 186 if fwd := req.Header.Get(xForwardedFor); fwd == "" { 187 pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP) 188 } else { 189 pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP)) 190 } 191 } 192 } 193 194 if timeout != 0 { 195 //nolint:govet // The context outlives this function 196 ctx, _ = context.WithTimeout(ctx, timeout) 197 } 198 if len(pairs) == 0 { 199 return ctx, nil, nil 200 } 201 md := metadata.Pairs(pairs...) 202 for _, mda := range mux.metadataAnnotators { 203 md = metadata.Join(md, mda(ctx, req)) 204 } 205 return ctx, md, nil 206 } 207 208 // ServerMetadata consists of metadata sent from gRPC server. 209 type ServerMetadata struct { 210 HeaderMD metadata.MD 211 TrailerMD metadata.MD 212 } 213 214 type serverMetadataKey struct{} 215 216 // NewServerMetadataContext creates a new context with ServerMetadata 217 func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context { 218 if ctx == nil { 219 ctx = context.Background() 220 } 221 return context.WithValue(ctx, serverMetadataKey{}, md) 222 } 223 224 // ServerMetadataFromContext returns the ServerMetadata in ctx 225 func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) { 226 if ctx == nil { 227 return md, false 228 } 229 md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata) 230 return 231 } 232 233 // ServerTransportStream implements grpc.ServerTransportStream. 234 // It should only be used by the generated files to support grpc.SendHeader 235 // outside of gRPC server use. 236 type ServerTransportStream struct { 237 mu sync.Mutex 238 header metadata.MD 239 trailer metadata.MD 240 } 241 242 // Method returns the method for the stream. 243 func (s *ServerTransportStream) Method() string { 244 return "" 245 } 246 247 // Header returns the header metadata of the stream. 248 func (s *ServerTransportStream) Header() metadata.MD { 249 s.mu.Lock() 250 defer s.mu.Unlock() 251 return s.header.Copy() 252 } 253 254 // SetHeader sets the header metadata. 255 func (s *ServerTransportStream) SetHeader(md metadata.MD) error { 256 if md.Len() == 0 { 257 return nil 258 } 259 260 s.mu.Lock() 261 s.header = metadata.Join(s.header, md) 262 s.mu.Unlock() 263 return nil 264 } 265 266 // SendHeader sets the header metadata. 267 func (s *ServerTransportStream) SendHeader(md metadata.MD) error { 268 return s.SetHeader(md) 269 } 270 271 // Trailer returns the cached trailer metadata. 272 func (s *ServerTransportStream) Trailer() metadata.MD { 273 s.mu.Lock() 274 defer s.mu.Unlock() 275 return s.trailer.Copy() 276 } 277 278 // SetTrailer sets the trailer metadata. 279 func (s *ServerTransportStream) SetTrailer(md metadata.MD) error { 280 if md.Len() == 0 { 281 return nil 282 } 283 284 s.mu.Lock() 285 s.trailer = metadata.Join(s.trailer, md) 286 s.mu.Unlock() 287 return nil 288 } 289 290 func timeoutDecode(s string) (time.Duration, error) { 291 size := len(s) 292 if size < 2 { 293 return 0, fmt.Errorf("timeout string is too short: %q", s) 294 } 295 d, ok := timeoutUnitToDuration(s[size-1]) 296 if !ok { 297 return 0, fmt.Errorf("timeout unit is not recognized: %q", s) 298 } 299 t, err := strconv.ParseInt(s[:size-1], 10, 64) 300 if err != nil { 301 return 0, err 302 } 303 return d * time.Duration(t), nil 304 } 305 306 func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) { 307 switch u { 308 case 'H': 309 return time.Hour, true 310 case 'M': 311 return time.Minute, true 312 case 'S': 313 return time.Second, true 314 case 'm': 315 return time.Millisecond, true 316 case 'u': 317 return time.Microsecond, true 318 case 'n': 319 return time.Nanosecond, true 320 default: 321 return 322 } 323 } 324 325 // isPermanentHTTPHeader checks whether hdr belongs to the list of 326 // permanent request headers maintained by IANA. 327 // http://www.iana.org/assignments/message-headers/message-headers.xml 328 func isPermanentHTTPHeader(hdr string) bool { 329 switch hdr { 330 case 331 "Accept", 332 "Accept-Charset", 333 "Accept-Language", 334 "Accept-Ranges", 335 "Authorization", 336 "Cache-Control", 337 "Content-Type", 338 "Cookie", 339 "Date", 340 "Expect", 341 "From", 342 "Host", 343 "If-Match", 344 "If-Modified-Since", 345 "If-None-Match", 346 "If-Schedule-Tag-Match", 347 "If-Unmodified-Since", 348 "Max-Forwards", 349 "Origin", 350 "Pragma", 351 "Referer", 352 "User-Agent", 353 "Via", 354 "Warning": 355 return true 356 } 357 return false 358 } 359 360 // isMalformedHTTPHeader checks whether header belongs to the list of 361 // "malformed headers" and would be rejected by the gRPC server. 362 func isMalformedHTTPHeader(header string) bool { 363 _, isMalformed := malformedHTTPHeaders[strings.ToLower(header)] 364 return isMalformed 365 } 366 367 // RPCMethod returns the method string for the server context. The returned 368 // string is in the format of "/package.service/method". 369 func RPCMethod(ctx context.Context) (string, bool) { 370 m := ctx.Value(rpcMethodKey{}) 371 if m == nil { 372 return "", false 373 } 374 ms, ok := m.(string) 375 if !ok { 376 return "", false 377 } 378 return ms, true 379 } 380 381 func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context { 382 return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName) 383 } 384 385 // HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists. 386 // The format of the returned string is defined by the google.api.http path template type. 387 func HTTPPathPattern(ctx context.Context) (string, bool) { 388 m := ctx.Value(httpPathPatternKey{}) 389 if m == nil { 390 return "", false 391 } 392 ms, ok := m.(string) 393 if !ok { 394 return "", false 395 } 396 return ms, true 397 } 398 399 func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context { 400 return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern) 401 }