github.com/tumi8/quic-go@v0.37.4-tum/http3/client.go (about) 1 package http3 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "strconv" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 "github.com/tumi8/quic-go" 17 "github.com/tumi8/quic-go/noninternal/protocol" 18 "github.com/tumi8/quic-go/noninternal/utils" 19 "github.com/tumi8/quic-go/quicvarint" 20 21 "github.com/quic-go/qpack" 22 ) 23 24 // MethodGet0RTT allows a GET request to be sent using 0-RTT. 25 // Note that 0-RTT data doesn't provide replay protection. 26 const MethodGet0RTT = "GET_0RTT" 27 28 const ( 29 defaultUserAgent = "quic-go HTTP/3" 30 defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB 31 ) 32 33 var defaultQuicConfig = &quic.Config{ 34 MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams 35 KeepAlivePeriod: 10 * time.Second, 36 } 37 38 type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) 39 40 var dialAddr dialFunc = quic.DialAddrEarly 41 42 type roundTripperOpts struct { 43 DisableCompression bool 44 EnableDatagram bool 45 MaxHeaderBytes int64 46 AdditionalSettings map[uint64]uint64 47 StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) 48 UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) 49 } 50 51 // client is a HTTP3 client doing requests 52 type client struct { 53 tlsConf *tls.Config 54 config *quic.Config 55 opts *roundTripperOpts 56 57 dialOnce sync.Once 58 dialer dialFunc 59 handshakeErr error 60 61 requestWriter *requestWriter 62 63 decoder *qpack.Decoder 64 65 hostname string 66 conn atomic.Pointer[quic.EarlyConnection] 67 68 logger utils.Logger 69 } 70 71 var _ roundTripCloser = &client{} 72 73 func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { 74 if conf == nil { 75 conf = defaultQuicConfig.Clone() 76 } 77 if len(conf.Versions) == 0 { 78 conf = conf.Clone() 79 conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]} 80 } 81 if len(conf.Versions) != 1 { 82 return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") 83 } 84 if conf.MaxIncomingStreams == 0 { 85 conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams 86 } 87 conf.EnableDatagrams = opts.EnableDatagram 88 logger := utils.DefaultLogger.WithPrefix("h3 client") 89 90 if tlsConf == nil { 91 tlsConf = &tls.Config{} 92 } else { 93 tlsConf = tlsConf.Clone() 94 } 95 if tlsConf.ServerName == "" { 96 sni, _, err := net.SplitHostPort(hostname) 97 if err != nil { 98 // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. 99 sni = hostname 100 } 101 tlsConf.ServerName = sni 102 } 103 // Replace existing ALPNs by H3 104 tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} 105 106 return &client{ 107 hostname: authorityAddr("https", hostname), 108 tlsConf: tlsConf, 109 requestWriter: newRequestWriter(logger), 110 decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), 111 config: conf, 112 opts: opts, 113 dialer: dialer, 114 logger: logger, 115 }, nil 116 } 117 118 func (c *client) dial(ctx context.Context) error { 119 var err error 120 var conn quic.EarlyConnection 121 if c.dialer != nil { 122 conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) 123 } else { 124 conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) 125 } 126 if err != nil { 127 return err 128 } 129 c.conn.Store(&conn) 130 131 // send the SETTINGs frame, using 0-RTT data, if possible 132 go func() { 133 if err := c.setupConn(conn); err != nil { 134 c.logger.Debugf("Setting up connection failed: %s", err) 135 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") 136 } 137 }() 138 139 if c.opts.StreamHijacker != nil { 140 go c.handleBidirectionalStreams(conn) 141 } 142 go c.handleUnidirectionalStreams(conn) 143 return nil 144 } 145 146 func (c *client) setupConn(conn quic.EarlyConnection) error { 147 // open the control stream 148 str, err := conn.OpenUniStream() 149 if err != nil { 150 return err 151 } 152 b := make([]byte, 0, 64) 153 b = quicvarint.Append(b, streamTypeControlStream) 154 // send the SETTINGS frame 155 b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) 156 _, err = str.Write(b) 157 return err 158 } 159 160 func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { 161 for { 162 str, err := conn.AcceptStream(context.Background()) 163 if err != nil { 164 c.logger.Debugf("accepting bidirectional stream failed: %s", err) 165 return 166 } 167 go func(str quic.Stream) { 168 _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { 169 return c.opts.StreamHijacker(ft, conn, str, e) 170 }) 171 if err == errHijacked { 172 return 173 } 174 if err != nil { 175 c.logger.Debugf("error handling stream: %s", err) 176 } 177 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") 178 }(str) 179 } 180 } 181 182 func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { 183 for { 184 str, err := conn.AcceptUniStream(context.Background()) 185 if err != nil { 186 c.logger.Debugf("accepting unidirectional stream failed: %s", err) 187 return 188 } 189 190 go func(str quic.ReceiveStream) { 191 streamType, err := quicvarint.Read(quicvarint.NewReader(str)) 192 if err != nil { 193 if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { 194 return 195 } 196 c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) 197 return 198 } 199 // We're only interested in the control stream here. 200 switch streamType { 201 case streamTypeControlStream: 202 case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: 203 // Our QPACK implementation doesn't use the dynamic table yet. 204 // TODO: check that only one stream of each type is opened. 205 return 206 case streamTypePushStream: 207 // We never increased the Push ID, so we don't expect any push streams. 208 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") 209 return 210 default: 211 if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { 212 return 213 } 214 str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 215 return 216 } 217 f, err := parseNextFrame(str, nil) 218 if err != nil { 219 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") 220 return 221 } 222 sf, ok := f.(*settingsFrame) 223 if !ok { 224 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") 225 return 226 } 227 if !sf.Datagram { 228 return 229 } 230 // If datagram support was enabled on our side as well as on the server side, 231 // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. 232 // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). 233 if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { 234 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") 235 } 236 }(str) 237 } 238 } 239 240 func (c *client) Close() error { 241 conn := c.conn.Load() 242 if conn == nil { 243 return nil 244 } 245 return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") 246 } 247 248 func (c *client) maxHeaderBytes() uint64 { 249 if c.opts.MaxHeaderBytes <= 0 { 250 return defaultMaxResponseHeaderBytes 251 } 252 return uint64(c.opts.MaxHeaderBytes) 253 } 254 255 // RoundTripOpt executes a request and returns a response 256 func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 257 if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { 258 return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) 259 } 260 261 c.dialOnce.Do(func() { 262 c.handshakeErr = c.dial(req.Context()) 263 }) 264 if c.handshakeErr != nil { 265 return nil, c.handshakeErr 266 } 267 268 // At this point, c.conn is guaranteed to be set. 269 conn := *c.conn.Load() 270 271 // Immediately send out this request, if this is a 0-RTT request. 272 if req.Method == MethodGet0RTT { 273 req.Method = http.MethodGet 274 } else { 275 // wait for the handshake to complete 276 select { 277 case <-conn.HandshakeComplete(): 278 case <-req.Context().Done(): 279 return nil, req.Context().Err() 280 } 281 } 282 283 str, err := conn.OpenStreamSync(req.Context()) 284 if err != nil { 285 return nil, err 286 } 287 288 // Request Cancellation: 289 // This go routine keeps running even after RoundTripOpt() returns. 290 // It is shut down when the application is done processing the body. 291 reqDone := make(chan struct{}) 292 done := make(chan struct{}) 293 go func() { 294 defer close(done) 295 select { 296 case <-req.Context().Done(): 297 str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 298 str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) 299 case <-reqDone: 300 } 301 }() 302 303 doneChan := reqDone 304 if opt.DontCloseRequestStream { 305 doneChan = nil 306 } 307 rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) 308 if rerr.err != nil { // if any error occurred 309 close(reqDone) 310 <-done 311 if rerr.streamErr != 0 { // if it was a stream error 312 str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) 313 } 314 if rerr.connErr != 0 { // if it was a connection error 315 var reason string 316 if rerr.err != nil { 317 reason = rerr.err.Error() 318 } 319 conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) 320 } 321 return nil, rerr.err 322 } 323 if opt.DontCloseRequestStream { 324 close(reqDone) 325 <-done 326 } 327 return rsp, rerr.err 328 } 329 330 // cancelingReader reads from the io.Reader. 331 // It cancels writing on the stream if any error other than io.EOF occurs. 332 type cancelingReader struct { 333 r io.Reader 334 str Stream 335 } 336 337 func (r *cancelingReader) Read(b []byte) (int, error) { 338 n, err := r.r.Read(b) 339 if err != nil && err != io.EOF { 340 r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 341 } 342 return n, err 343 } 344 345 func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { 346 defer body.Close() 347 buf := make([]byte, bodyCopyBufferSize) 348 sr := &cancelingReader{str: str, r: body} 349 if contentLength == -1 { 350 _, err := io.CopyBuffer(str, sr, buf) 351 return err 352 } 353 354 // make sure we don't send more bytes than the content length 355 n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf) 356 if err != nil { 357 return err 358 } 359 var extra int64 360 extra, err = io.CopyBuffer(io.Discard, sr, buf) 361 n += extra 362 if n > contentLength { 363 str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 364 return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n) 365 } 366 return err 367 } 368 369 func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { 370 var requestGzip bool 371 if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { 372 requestGzip = true 373 } 374 if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil { 375 return nil, newStreamError(ErrCodeInternalError, err) 376 } 377 378 if req.Body == nil && !opt.DontCloseRequestStream { 379 str.Close() 380 } 381 382 hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) 383 if req.Body != nil { 384 // send the request body asynchronously 385 go func() { 386 contentLength := int64(-1) 387 // According to the documentation for http.Request.ContentLength, 388 // a value of 0 with a non-nil Body is also treated as unknown content length. 389 if req.ContentLength > 0 { 390 contentLength = req.ContentLength 391 } 392 if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil { 393 c.logger.Errorf("Error writing request: %s", err) 394 } 395 if !opt.DontCloseRequestStream { 396 hstr.Close() 397 } 398 }() 399 } 400 401 frame, err := parseNextFrame(str, nil) 402 if err != nil { 403 return nil, newStreamError(ErrCodeFrameError, err) 404 } 405 hf, ok := frame.(*headersFrame) 406 if !ok { 407 return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) 408 } 409 if hf.Length > c.maxHeaderBytes() { 410 return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) 411 } 412 headerBlock := make([]byte, hf.Length) 413 if _, err := io.ReadFull(str, headerBlock); err != nil { 414 return nil, newStreamError(ErrCodeRequestIncomplete, err) 415 } 416 hfs, err := c.decoder.DecodeFull(headerBlock) 417 if err != nil { 418 // TODO: use the right error code 419 return nil, newConnError(ErrCodeGeneralProtocolError, err) 420 } 421 422 res, err := responseFromHeaders(hfs) 423 if err != nil { 424 return nil, newStreamError(ErrCodeMessageError, err) 425 } 426 connState := conn.ConnectionState().TLS 427 res.TLS = &connState 428 res.Request = req 429 // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). 430 // See section 4.1.2 of RFC 9114. 431 var httpStr Stream 432 if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 { 433 httpStr = newLengthLimitedStream(hstr, res.ContentLength) 434 } else { 435 httpStr = hstr 436 } 437 respBody := newResponseBody(httpStr, conn, reqDone) 438 439 // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. 440 _, hasTransferEncoding := res.Header["Transfer-Encoding"] 441 isInformational := res.StatusCode >= 100 && res.StatusCode < 200 442 isNoContent := res.StatusCode == http.StatusNoContent 443 isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 444 if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { 445 res.ContentLength = -1 446 if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { 447 if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { 448 res.ContentLength = clen64 449 } 450 } 451 } 452 453 if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { 454 res.Header.Del("Content-Encoding") 455 res.Header.Del("Content-Length") 456 res.ContentLength = -1 457 res.Body = newGzipReader(respBody) 458 res.Uncompressed = true 459 } else { 460 res.Body = respBody 461 } 462 463 return res, requestError{} 464 } 465 466 func (c *client) HandshakeComplete() bool { 467 conn := c.conn.Load() 468 if conn == nil { 469 return false 470 } 471 select { 472 case <-(*conn).HandshakeComplete(): 473 return true 474 default: 475 return false 476 } 477 }