github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/client.go (about) 1 package http3 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "strconv" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 "github.com/daeuniverse/quic-go" 17 "github.com/daeuniverse/quic-go/internal/protocol" 18 "github.com/daeuniverse/quic-go/internal/utils" 19 "github.com/daeuniverse/quic-go/quicvarint" 20 21 "github.com/quic-go/qpack" 22 ) 23 24 // MethodGet0RTT allows a GET request to be sent using 0-RTT. 25 // Note that 0-RTT data doesn't provide replay protection. 26 const MethodGet0RTT = "GET_0RTT" 27 28 const ( 29 defaultUserAgent = "quic-go HTTP/3" 30 defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB 31 ) 32 33 var defaultQuicConfig = &quic.Config{ 34 MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams 35 KeepAlivePeriod: 10 * time.Second, 36 } 37 38 type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) 39 40 var dialAddr dialFunc = quic.DialAddrEarly 41 42 type roundTripperOpts struct { 43 DisableCompression bool 44 EnableDatagram bool 45 MaxHeaderBytes int64 46 AdditionalSettings map[uint64]uint64 47 StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) 48 UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) 49 } 50 51 // client is a HTTP3 client doing requests 52 type client struct { 53 tlsConf *tls.Config 54 config *quic.Config 55 opts *roundTripperOpts 56 57 dialOnce sync.Once 58 dialer dialFunc 59 handshakeErr error 60 61 receivedSettings chan struct{} // closed once the server's SETTINGS frame was processed 62 settings *Settings // set once receivedSettings is closed 63 64 requestWriter *requestWriter 65 66 decoder *qpack.Decoder 67 68 hostname string 69 conn atomic.Pointer[quic.EarlyConnection] 70 71 logger utils.Logger 72 } 73 74 var _ roundTripCloser = &client{} 75 76 func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { 77 if conf == nil { 78 conf = defaultQuicConfig.Clone() 79 conf.EnableDatagrams = opts.EnableDatagram 80 } 81 if opts.EnableDatagram && !conf.EnableDatagrams { 82 return nil, errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") 83 } 84 if len(conf.Versions) == 0 { 85 conf = conf.Clone() 86 conf.Versions = []quic.Version{protocol.SupportedVersions[0]} 87 } 88 if len(conf.Versions) != 1 { 89 return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") 90 } 91 if conf.MaxIncomingStreams == 0 { 92 conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams 93 } 94 logger := utils.DefaultLogger.WithPrefix("h3 client") 95 96 if tlsConf == nil { 97 tlsConf = &tls.Config{} 98 } else { 99 tlsConf = tlsConf.Clone() 100 } 101 if tlsConf.ServerName == "" { 102 sni, _, err := net.SplitHostPort(hostname) 103 if err != nil { 104 // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. 105 sni = hostname 106 } 107 tlsConf.ServerName = sni 108 } 109 // Replace existing ALPNs by H3 110 tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} 111 112 return &client{ 113 hostname: authorityAddr("https", hostname), 114 tlsConf: tlsConf, 115 requestWriter: newRequestWriter(logger), 116 receivedSettings: make(chan struct{}), 117 decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), 118 config: conf, 119 opts: opts, 120 dialer: dialer, 121 logger: logger, 122 }, nil 123 } 124 125 func (c *client) dial(ctx context.Context) error { 126 var err error 127 var conn quic.EarlyConnection 128 if c.dialer != nil { 129 conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) 130 } else { 131 conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) 132 } 133 if err != nil { 134 return err 135 } 136 c.conn.Store(&conn) 137 138 // send the SETTINGs frame, using 0-RTT data, if possible 139 go func() { 140 if err := c.setupConn(conn); err != nil { 141 c.logger.Debugf("Setting up connection failed: %s", err) 142 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") 143 } 144 }() 145 146 if c.opts.StreamHijacker != nil { 147 go c.handleBidirectionalStreams(conn) 148 } 149 go c.handleUnidirectionalStreams(conn) 150 return nil 151 } 152 153 func (c *client) setupConn(conn quic.EarlyConnection) error { 154 // open the control stream 155 str, err := conn.OpenUniStream() 156 if err != nil { 157 return err 158 } 159 b := make([]byte, 0, 64) 160 b = quicvarint.Append(b, streamTypeControlStream) 161 // send the SETTINGS frame 162 b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) 163 _, err = str.Write(b) 164 return err 165 } 166 167 func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { 168 for { 169 str, err := conn.AcceptStream(context.Background()) 170 if err != nil { 171 c.logger.Debugf("accepting bidirectional stream failed: %s", err) 172 return 173 } 174 go func(str quic.Stream) { 175 _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { 176 return c.opts.StreamHijacker(ft, conn, str, e) 177 }) 178 if err == errHijacked { 179 return 180 } 181 if err != nil { 182 c.logger.Debugf("error handling stream: %s", err) 183 } 184 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") 185 }(str) 186 } 187 } 188 189 func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { 190 var rcvdControlStream atomic.Bool 191 192 for { 193 str, err := conn.AcceptUniStream(context.Background()) 194 if err != nil { 195 c.logger.Debugf("accepting unidirectional stream failed: %s", err) 196 return 197 } 198 199 go func(str quic.ReceiveStream) { 200 streamType, err := quicvarint.Read(quicvarint.NewReader(str)) 201 if err != nil { 202 if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { 203 return 204 } 205 c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) 206 return 207 } 208 // We're only interested in the control stream here. 209 switch streamType { 210 case streamTypeControlStream: 211 case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: 212 // Our QPACK implementation doesn't use the dynamic table yet. 213 // TODO: check that only one stream of each type is opened. 214 return 215 case streamTypePushStream: 216 // We never increased the Push ID, so we don't expect any push streams. 217 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") 218 return 219 default: 220 if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { 221 return 222 } 223 str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 224 return 225 } 226 // Only a single control stream is allowed. 227 if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr { 228 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") 229 return 230 } 231 f, err := parseNextFrame(str, nil) 232 if err != nil { 233 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") 234 return 235 } 236 sf, ok := f.(*settingsFrame) 237 if !ok { 238 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") 239 return 240 } 241 c.settings = &Settings{ 242 EnableDatagram: sf.Datagram, 243 EnableExtendedConnect: sf.ExtendedConnect, 244 Other: sf.Other, 245 } 246 close(c.receivedSettings) 247 if !sf.Datagram { 248 return 249 } 250 // If datagram support was enabled on our side as well as on the server side, 251 // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. 252 // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). 253 if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { 254 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") 255 } 256 }(str) 257 } 258 } 259 260 func (c *client) Close() error { 261 conn := c.conn.Load() 262 if conn == nil { 263 return nil 264 } 265 return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") 266 } 267 268 func (c *client) maxHeaderBytes() uint64 { 269 if c.opts.MaxHeaderBytes <= 0 { 270 return defaultMaxResponseHeaderBytes 271 } 272 return uint64(c.opts.MaxHeaderBytes) 273 } 274 275 // RoundTripOpt executes a request and returns a response 276 func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 277 rsp, err := c.roundTripOpt(req, opt) 278 if err != nil && req.Context().Err() != nil { 279 // if the context was canceled, return the context cancellation error 280 err = req.Context().Err() 281 } 282 return rsp, err 283 } 284 285 func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 286 if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { 287 return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) 288 } 289 290 c.dialOnce.Do(func() { 291 c.handshakeErr = c.dial(req.Context()) 292 }) 293 if c.handshakeErr != nil { 294 return nil, c.handshakeErr 295 } 296 297 // At this point, c.conn is guaranteed to be set. 298 conn := *c.conn.Load() 299 300 // Immediately send out this request, if this is a 0-RTT request. 301 if req.Method == MethodGet0RTT { 302 req.Method = http.MethodGet 303 } else { 304 // wait for the handshake to complete 305 select { 306 case <-conn.HandshakeComplete(): 307 case <-req.Context().Done(): 308 return nil, req.Context().Err() 309 } 310 } 311 312 if opt.CheckSettings != nil { 313 // wait for the server's SETTINGS frame to arrive 314 select { 315 case <-c.receivedSettings: 316 case <-conn.Context().Done(): 317 return nil, context.Cause(conn.Context()) 318 } 319 if err := opt.CheckSettings(*c.settings); err != nil { 320 return nil, err 321 } 322 } 323 324 str, err := conn.OpenStreamSync(req.Context()) 325 if err != nil { 326 return nil, err 327 } 328 329 // Request Cancellation: 330 // This go routine keeps running even after RoundTripOpt() returns. 331 // It is shut down when the application is done processing the body. 332 reqDone := make(chan struct{}) 333 done := make(chan struct{}) 334 go func() { 335 defer close(done) 336 select { 337 case <-req.Context().Done(): 338 str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 339 str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) 340 case <-reqDone: 341 } 342 }() 343 344 doneChan := reqDone 345 if opt.DontCloseRequestStream { 346 doneChan = nil 347 } 348 rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) 349 if rerr.err != nil { // if any error occurred 350 close(reqDone) 351 <-done 352 if rerr.streamErr != 0 { // if it was a stream error 353 str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) 354 } 355 if rerr.connErr != 0 { // if it was a connection error 356 var reason string 357 if rerr.err != nil { 358 reason = rerr.err.Error() 359 } 360 conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) 361 } 362 return nil, maybeReplaceError(rerr.err) 363 } 364 if opt.DontCloseRequestStream { 365 close(reqDone) 366 <-done 367 } 368 return rsp, maybeReplaceError(rerr.err) 369 } 370 371 // cancelingReader reads from the io.Reader. 372 // It cancels writing on the stream if any error other than io.EOF occurs. 373 type cancelingReader struct { 374 r io.Reader 375 str Stream 376 } 377 378 func (r *cancelingReader) Read(b []byte) (int, error) { 379 n, err := r.r.Read(b) 380 if err != nil && err != io.EOF { 381 r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 382 } 383 return n, err 384 } 385 386 func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { 387 defer body.Close() 388 buf := make([]byte, bodyCopyBufferSize) 389 sr := &cancelingReader{str: str, r: body} 390 if contentLength == -1 { 391 _, err := io.CopyBuffer(str, sr, buf) 392 return err 393 } 394 395 // make sure we don't send more bytes than the content length 396 n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf) 397 if err != nil { 398 return err 399 } 400 var extra int64 401 extra, err = io.CopyBuffer(io.Discard, sr, buf) 402 n += extra 403 if n > contentLength { 404 str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 405 return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n) 406 } 407 return err 408 } 409 410 func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { 411 var requestGzip bool 412 if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { 413 requestGzip = true 414 } 415 if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil { 416 return nil, newStreamError(ErrCodeInternalError, err) 417 } 418 419 if req.Body == nil && !opt.DontCloseRequestStream { 420 str.Close() 421 } 422 423 hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) 424 if req.Body != nil { 425 // send the request body asynchronously 426 go func() { 427 contentLength := int64(-1) 428 // According to the documentation for http.Request.ContentLength, 429 // a value of 0 with a non-nil Body is also treated as unknown content length. 430 if req.ContentLength > 0 { 431 contentLength = req.ContentLength 432 } 433 if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil { 434 c.logger.Errorf("Error writing request: %s", err) 435 } 436 if !opt.DontCloseRequestStream { 437 hstr.Close() 438 } 439 }() 440 } 441 442 frame, err := parseNextFrame(str, nil) 443 if err != nil { 444 return nil, newStreamError(ErrCodeFrameError, err) 445 } 446 hf, ok := frame.(*headersFrame) 447 if !ok { 448 return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) 449 } 450 if hf.Length > c.maxHeaderBytes() { 451 return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) 452 } 453 headerBlock := make([]byte, hf.Length) 454 if _, err := io.ReadFull(str, headerBlock); err != nil { 455 return nil, newStreamError(ErrCodeRequestIncomplete, err) 456 } 457 hfs, err := c.decoder.DecodeFull(headerBlock) 458 if err != nil { 459 // TODO: use the right error code 460 return nil, newConnError(ErrCodeGeneralProtocolError, err) 461 } 462 463 res, err := responseFromHeaders(hfs) 464 if err != nil { 465 return nil, newStreamError(ErrCodeMessageError, err) 466 } 467 connState := conn.ConnectionState().TLS 468 res.TLS = &connState 469 res.Request = req 470 // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). 471 // See section 4.1.2 of RFC 9114. 472 var httpStr Stream 473 if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 { 474 httpStr = newLengthLimitedStream(hstr, res.ContentLength) 475 } else { 476 httpStr = hstr 477 } 478 respBody := newResponseBody(httpStr, conn, reqDone) 479 480 // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. 481 _, hasTransferEncoding := res.Header["Transfer-Encoding"] 482 isInformational := res.StatusCode >= 100 && res.StatusCode < 200 483 isNoContent := res.StatusCode == http.StatusNoContent 484 isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 485 if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { 486 res.ContentLength = -1 487 if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { 488 if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { 489 res.ContentLength = clen64 490 } 491 } 492 } 493 494 if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { 495 res.Header.Del("Content-Encoding") 496 res.Header.Del("Content-Length") 497 res.ContentLength = -1 498 res.Body = newGzipReader(respBody) 499 res.Uncompressed = true 500 } else { 501 res.Body = respBody 502 } 503 504 return res, requestError{} 505 } 506 507 func (c *client) HandshakeComplete() bool { 508 conn := c.conn.Load() 509 if conn == nil { 510 return false 511 } 512 select { 513 case <-(*conn).HandshakeComplete(): 514 return true 515 default: 516 return false 517 } 518 }