/*
 *
 * Copyright 2014 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package transport

import (
	"bufio"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"math"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
	"google.golang.org/grpc/codes"
)

const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame. newFramer
	// passes it to SetMaxReadFrameSize on the underlying http2.Framer.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the table size newFramer hands to
	// hpack.NewDecoder. See
	// https://httpwg.org/specs/rfc7540.html#SettingValues
	http2InitHeaderTableSize = 4096
)

var (
	// clientPreface is http2.ClientPreface converted to a byte slice.
	clientPreface = []byte(http2.ClientPreface)
	// http2ErrConvTab is the HTTP/2 error code to gRPC error code conversion
	// table.
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.Internal,
	}
	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
	// Only the statuses listed here have an entry.
	HTTPStatusConvTab = map[int]codes.Code{
		// 400 Bad Request - INTERNAL.
		http.StatusBadRequest: codes.Internal,
		// 401 Unauthorized - UNAUTHENTICATED.
		http.StatusUnauthorized: codes.Unauthenticated,
		// 403 Forbidden - PERMISSION_DENIED.
		http.StatusForbidden: codes.PermissionDenied,
		// 404 Not Found - UNIMPLEMENTED.
		http.StatusNotFound: codes.Unimplemented,
		// 429 Too Many Requests - UNAVAILABLE.
		http.StatusTooManyRequests: codes.Unavailable,
		// 502 Bad Gateway - UNAVAILABLE.
		http.StatusBadGateway: codes.Unavailable,
		// 503 Service Unavailable - UNAVAILABLE.
		http.StatusServiceUnavailable: codes.Unavailable,
		// 504 Gateway timeout - UNAVAILABLE.
		http.StatusGatewayTimeout: codes.Unavailable,
	}
)

// grpcStatusDetailsBinHeader holds the header key "grpc-status-details-bin".
// NOTE(review): it is a var rather than a const and is not referenced in this
// file; confirm how other files in the package use it before changing that.
var grpcStatusDetailsBinHeader = "grpc-status-details-bin"

// isReservedHeader checks whether hdr belongs to HTTP2 headers
// reserved by gRPC protocol. Any other headers are classified as the
// user-specified metadata.
func isReservedHeader(hdr string) bool {
	// Every HTTP/2 pseudo-header (leading ':') is reserved.
	if hdr != "" && hdr[0] == ':' {
		return true
	}
	switch hdr {
	case "content-type",
		"user-agent",
		"grpc-message-type",
		"grpc-encoding",
		"grpc-message",
		"grpc-status",
		"grpc-timeout",
		// Intentionally exclude grpc-previous-rpc-attempts and
		// grpc-retry-pushback-ms, which are "reserved", but their API
		// intentionally works via metadata.
		"te":
		return true
	default:
		return false
	}
}

// isWhitelistedHeader checks whether hdr should be propagated into metadata
// visible to users, even though it is classified as "reserved", above.
func isWhitelistedHeader(hdr string) bool {
	switch hdr {
	case ":authority", "user-agent":
		return true
	default:
		return false
	}
}

// binHdrSuffix is the key suffix that marks a metadata header as carrying a
// base64-encoded binary value.
const binHdrSuffix = "-bin"

// encodeBinHeader returns v encoded as unpadded standard base64.
func encodeBinHeader(v []byte) string {
	return base64.RawStdEncoding.EncodeToString(v)
}

// decodeBinHeader decodes a standard base64 value, accepting both the padded
// and the unpadded form.
func decodeBinHeader(v string) ([]byte, error) {
	if len(v)%4 == 0 {
		// Input was padded, or padding was not necessary.
		return base64.StdEncoding.DecodeString(v)
	}
	return base64.RawStdEncoding.DecodeString(v)
}

// encodeMetadataHeader returns the wire form of the value v for the key k:
// base64-encoded when k ends in "-bin", unchanged otherwise.
func encodeMetadataHeader(k, v string) string {
	if strings.HasSuffix(k, binHdrSuffix) {
		return encodeBinHeader(([]byte)(v))
	}
	return v
}

// decodeMetadataHeader is the inverse of encodeMetadataHeader. It returns an
// error only when k ends in "-bin" and v is not valid base64.
func decodeMetadataHeader(k, v string) (string, error) {
	if strings.HasSuffix(k, binHdrSuffix) {
		b, err := decodeBinHeader(v)
		return string(b), err
	}
	return v, nil
}

// timeoutUnit is the single-character unit suffix of a timeout string parsed
// by decodeTimeout.
type timeoutUnit uint8

const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)

// timeoutUnitToDuration returns the duration of one unit of u. ok is false
// when u is not one of the recognized units.
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
	switch u {
	case hour:
		return time.Hour, true
	case minute:
		return time.Minute, true
	case second:
		return time.Second, true
	case millisecond:
		return time.Millisecond, true
	case microsecond:
		return time.Microsecond, true
	case nanosecond:
		return time.Nanosecond, true
	default:
		return 0, false
	}
}

// decodeTimeout parses a timeout string of the form "<digits><unit>", where
// <digits> is one to eight ASCII digits and <unit> is one of H, M, S, m, u, n.
//
// It returns an error when s is too short or too long, when the unit is not
// recognized, or when the value is not an unsigned decimal integer. Signed
// values such as "-1S" or "+1S" are rejected. An hour count that would
// overflow time.Duration is clamped to math.MaxInt64.
func decodeTimeout(s string) (time.Duration, error) {
	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
	}
	if size > 9 {
		// Spec allows for 8 digits plus the unit.
		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
	}
	unit := timeoutUnit(s[size-1])
	d, ok := timeoutUnitToDuration(unit)
	if !ok {
		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
	}
	// ParseUint, unlike ParseInt, refuses a leading '+' or '-', so only a
	// plain run of digits is accepted.
	t, err := strconv.ParseUint(s[:size-1], 10, 64)
	if err != nil {
		return 0, err
	}
	const maxHours = math.MaxInt64 / uint64(time.Hour)
	if d == time.Hour && t > maxHours {
		// This timeout would overflow math.MaxInt64; clamp it.
		return time.Duration(math.MaxInt64), nil
	}
	// With at most 8 digits, only the hour unit can overflow, and that case
	// was handled above.
	return d * time.Duration(t), nil
}

const (
	spaceByte   = ' '
	tildeByte   = '~'
	percentByte = '%'
)

// encodeGrpcMessage is used to encode status code in header field
// "grpc-message". It does percent encoding and also replaces invalid utf-8
// characters with Unicode replacement character.
//
// It checks to see if each individual byte in msg is an allowable byte, and
// then either percent encoding or passing it through. When percent encoding,
// the byte is converted into hexadecimal notation with a '%' prepended.
func encodeGrpcMessage(msg string) string {
	if msg == "" {
		return ""
	}
	lenMsg := len(msg)
	for i := 0; i < lenMsg; i++ {
		c := msg[i]
		if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
			// At least one byte needs escaping; take the slow path.
			return encodeGrpcMessageUnchecked(msg)
		}
	}
	// Fast path: nothing to escape, so msg is returned as is.
	return msg
}

// encodeGrpcMessageUnchecked percent-encodes msg without first checking
// whether any byte needs escaping. Printable ASCII other than '%' is copied
// through; every other byte is written as '%' followed by two upper-case hex
// digits.
func encodeGrpcMessageUnchecked(msg string) string {
	const upperHex = "0123456789ABCDEF"
	var sb strings.Builder
	// The output is never shorter than the input.
	sb.Grow(len(msg))
	for len(msg) > 0 {
		r, size := utf8.DecodeRuneInString(msg)
		// Ranging over the re-encoded rune is necessary even if size == 1,
		// because r could be utf8.RuneError, whose encoding is three bytes
		// long and must be escaped byte by byte.
		for _, b := range []byte(string(r)) {
			// If size > 1, r is not ascii and is always percent encoded.
			if size == 1 && b >= spaceByte && b <= tildeByte && b != percentByte {
				sb.WriteByte(b)
				continue
			}
			sb.WriteByte(percentByte)
			sb.WriteByte(upperHex[b>>4])
			sb.WriteByte(upperHex[b&0x0F])
		}
		msg = msg[size:]
	}
	return sb.String()
}

// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
265 func decodeGrpcMessage(msg string) string { 266 if msg == "" { 267 return "" 268 } 269 lenMsg := len(msg) 270 for i := 0; i < lenMsg; i++ { 271 if msg[i] == percentByte && i+2 < lenMsg { 272 return decodeGrpcMessageUnchecked(msg) 273 } 274 } 275 return msg 276 } 277 278 func decodeGrpcMessageUnchecked(msg string) string { 279 var sb strings.Builder 280 lenMsg := len(msg) 281 for i := 0; i < lenMsg; i++ { 282 c := msg[i] 283 if c == percentByte && i+2 < lenMsg { 284 parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8) 285 if err != nil { 286 sb.WriteByte(c) 287 } else { 288 sb.WriteByte(byte(parsed)) 289 i += 2 290 } 291 } else { 292 sb.WriteByte(c) 293 } 294 } 295 return sb.String() 296 } 297 298 type bufWriter struct { 299 pool *sync.Pool 300 buf []byte 301 offset int 302 batchSize int 303 conn net.Conn 304 err error 305 } 306 307 func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter { 308 w := &bufWriter{ 309 batchSize: batchSize, 310 conn: conn, 311 pool: pool, 312 } 313 // this indicates that we should use non shared buf 314 if pool == nil { 315 w.buf = make([]byte, batchSize) 316 } 317 return w 318 } 319 320 func (w *bufWriter) Write(b []byte) (n int, err error) { 321 if w.err != nil { 322 return 0, w.err 323 } 324 if w.batchSize == 0 { // Buffer has been disabled. 
325 n, err = w.conn.Write(b) 326 return n, toIOError(err) 327 } 328 if w.buf == nil { 329 b := w.pool.Get().(*[]byte) 330 w.buf = *b 331 } 332 for len(b) > 0 { 333 nn := copy(w.buf[w.offset:], b) 334 b = b[nn:] 335 w.offset += nn 336 n += nn 337 if w.offset >= w.batchSize { 338 err = w.flushKeepBuffer() 339 } 340 } 341 return n, err 342 } 343 344 func (w *bufWriter) Flush() error { 345 err := w.flushKeepBuffer() 346 // Only release the buffer if we are in a "shared" mode 347 if w.buf != nil && w.pool != nil { 348 b := w.buf 349 w.pool.Put(&b) 350 w.buf = nil 351 } 352 return err 353 } 354 355 func (w *bufWriter) flushKeepBuffer() error { 356 if w.err != nil { 357 return w.err 358 } 359 if w.offset == 0 { 360 return nil 361 } 362 _, w.err = w.conn.Write(w.buf[:w.offset]) 363 w.err = toIOError(w.err) 364 w.offset = 0 365 return w.err 366 } 367 368 type ioError struct { 369 error 370 } 371 372 func (i ioError) Unwrap() error { 373 return i.error 374 } 375 376 func isIOError(err error) bool { 377 return errors.As(err, &ioError{}) 378 } 379 380 func toIOError(err error) error { 381 if err == nil { 382 return nil 383 } 384 return ioError{error: err} 385 } 386 387 type framer struct { 388 writer *bufWriter 389 fr *http2.Framer 390 } 391 392 var writeBufferPoolMap map[int]*sync.Pool = make(map[int]*sync.Pool) 393 var writeBufferMutex sync.Mutex 394 395 func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer { 396 if writeBufferSize < 0 { 397 writeBufferSize = 0 398 } 399 var r io.Reader = conn 400 if readBufferSize > 0 { 401 r = bufio.NewReaderSize(r, readBufferSize) 402 } 403 var pool *sync.Pool 404 if sharedWriteBuffer { 405 pool = getWriteBufferPool(writeBufferSize) 406 } 407 w := newBufWriter(conn, writeBufferSize, pool) 408 f := &framer{ 409 writer: w, 410 fr: http2.NewFramer(w, r), 411 } 412 f.fr.SetMaxReadFrameSize(http2MaxFrameLen) 413 // Opt-in to Frame reuse API on framer to reduce garbage. 
414 // Frames aren't safe to read from after a subsequent call to ReadFrame. 415 f.fr.SetReuseFrames() 416 f.fr.MaxHeaderListSize = maxHeaderListSize 417 f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) 418 return f 419 } 420 421 func getWriteBufferPool(writeBufferSize int) *sync.Pool { 422 writeBufferMutex.Lock() 423 defer writeBufferMutex.Unlock() 424 size := writeBufferSize * 2 425 pool, ok := writeBufferPoolMap[size] 426 if ok { 427 return pool 428 } 429 pool = &sync.Pool{ 430 New: func() any { 431 b := make([]byte, size) 432 return &b 433 }, 434 } 435 writeBufferPoolMap[size] = pool 436 return pool 437 } 438 439 // parseDialTarget returns the network and address to pass to dialer. 440 func parseDialTarget(target string) (string, string) { 441 net := "tcp" 442 m1 := strings.Index(target, ":") 443 m2 := strings.Index(target, ":/") 444 // handle unix:addr which will fail with url.Parse 445 if m1 >= 0 && m2 < 0 { 446 if n := target[0:m1]; n == "unix" { 447 return n, target[m1+1:] 448 } 449 } 450 if m2 >= 0 { 451 t, err := url.Parse(target) 452 if err != nil { 453 return net, target 454 } 455 scheme := t.Scheme 456 addr := t.Path 457 if scheme == "unix" { 458 if addr == "" { 459 addr = t.Host 460 } 461 return scheme, addr 462 } 463 } 464 return net, target 465 }