github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/client.go (about) 1 package http3 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "strconv" 12 "sync" 13 "sync/atomic" 14 "time" 15 16 "github.com/danielpfeifer02/quic-go-prio-packs" 17 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 18 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 19 "github.com/danielpfeifer02/quic-go-prio-packs/quicvarint" 20 21 "github.com/quic-go/qpack" 22 ) 23 24 // MethodGet0RTT allows a GET request to be sent using 0-RTT. 25 // Note that 0-RTT data doesn't provide replay protection. 26 const MethodGet0RTT = "GET_0RTT" 27 28 const ( 29 defaultUserAgent = "quic-go HTTP/3" 30 defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB 31 ) 32 33 var defaultQuicConfig = &quic.Config{ 34 MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams 35 KeepAlivePeriod: 10 * time.Second, 36 } 37 38 type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) 39 40 var dialAddr dialFunc = quic.DialAddrEarly 41 42 type roundTripperOpts struct { 43 DisableCompression bool 44 EnableDatagram bool 45 MaxHeaderBytes int64 46 AdditionalSettings map[uint64]uint64 47 StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) 48 UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) 49 } 50 51 // client is a HTTP3 client doing requests 52 type client struct { 53 tlsConf *tls.Config 54 config *quic.Config 55 opts *roundTripperOpts 56 57 dialOnce sync.Once 58 dialer dialFunc 59 handshakeErr error 60 61 requestWriter *requestWriter 62 63 decoder *qpack.Decoder 64 65 hostname string 66 conn atomic.Pointer[quic.EarlyConnection] 67 68 logger utils.Logger 69 } 70 71 var _ roundTripCloser = &client{} 72 73 func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { 74 if conf == nil { 75 conf = defaultQuicConfig.Clone() 76 conf.EnableDatagrams = opts.EnableDatagram 77 } 78 if opts.EnableDatagram && !conf.EnableDatagrams { 79 return nil, errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") 80 } 81 if len(conf.Versions) == 0 { 82 conf = conf.Clone() 83 conf.Versions = []quic.Version{protocol.SupportedVersions[0]} 84 } 85 if len(conf.Versions) != 1 { 86 return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") 87 } 88 if conf.MaxIncomingStreams == 0 { 89 conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams 90 } 91 logger := utils.DefaultLogger.WithPrefix("h3 client") 92 93 if tlsConf == nil { 94 tlsConf = &tls.Config{} 95 } else { 96 tlsConf = tlsConf.Clone() 97 } 98 if tlsConf.ServerName == "" { 99 sni, _, err := net.SplitHostPort(hostname) 100 if err != nil { 101 // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. 102 sni = hostname 103 } 104 tlsConf.ServerName = sni 105 } 106 // Replace existing ALPNs by H3 107 tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} 108 109 return &client{ 110 hostname: authorityAddr("https", hostname), 111 tlsConf: tlsConf, 112 requestWriter: newRequestWriter(logger), 113 decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), 114 config: conf, 115 opts: opts, 116 dialer: dialer, 117 logger: logger, 118 }, nil 119 } 120 121 func (c *client) dial(ctx context.Context) error { 122 var err error 123 var conn quic.EarlyConnection 124 if c.dialer != nil { 125 conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config) 126 } else { 127 conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) 128 } 129 if err != nil { 130 return err 131 } 132 c.conn.Store(&conn) 133 134 // send the SETTINGs frame, using 0-RTT data, if possible 135 go func() { 136 if err := c.setupConn(conn); err != nil { 137 c.logger.Debugf("Setting up connection failed: %s", err) 138 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "") 139 } 140 }() 141 142 if c.opts.StreamHijacker != nil { 143 go c.handleBidirectionalStreams(conn) 144 } 145 go c.handleUnidirectionalStreams(conn) 146 return nil 147 } 148 149 func (c *client) setupConn(conn quic.EarlyConnection) error { 150 // open the control stream 151 str, err := conn.OpenUniStream() 152 if err != nil { 153 return err 154 } 155 b := make([]byte, 0, 64) 156 b = quicvarint.Append(b, streamTypeControlStream) 157 // send the SETTINGS frame 158 b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b) 159 _, err = str.Write(b) 160 return err 161 } 162 163 func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) { 164 for { 165 str, err := conn.AcceptStream(context.Background()) 166 if err != nil { 167 c.logger.Debugf("accepting bidirectional stream failed: %s", err) 168 return 169 } 170 go func(str quic.Stream) { 171 _, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) { 172 return c.opts.StreamHijacker(ft, conn, str, e) 173 }) 174 if err == errHijacked { 175 return 176 } 177 if err != nil { 178 c.logger.Debugf("error handling stream: %s", err) 179 } 180 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream") 181 }(str) 182 } 183 } 184 185 func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) { 186 var rcvdControlStream atomic.Bool 187 188 for { 189 str, err := conn.AcceptUniStream(context.Background()) 190 if err != nil { 191 c.logger.Debugf("accepting unidirectional stream failed: %s", err) 192 return 193 } 194 195 go func(str quic.ReceiveStream) { 196 streamType, err := quicvarint.Read(quicvarint.NewReader(str)) 197 if err != nil { 198 if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) { 199 return 200 } 201 c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) 202 return 203 } 204 // We're only interested in the control stream here. 205 switch streamType { 206 case streamTypeControlStream: 207 case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: 208 // Our QPACK implementation doesn't use the dynamic table yet. 209 // TODO: check that only one stream of each type is opened. 210 return 211 case streamTypePushStream: 212 // We never increased the Push ID, so we don't expect any push streams. 213 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "") 214 return 215 default: 216 if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) { 217 return 218 } 219 str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 220 return 221 } 222 // Only a single control stream is allowed. 223 if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr { 224 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") 225 return 226 } 227 f, err := parseNextFrame(str, nil) 228 if err != nil { 229 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "") 230 return 231 } 232 sf, ok := f.(*settingsFrame) 233 if !ok { 234 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "") 235 return 236 } 237 if !sf.Datagram { 238 return 239 } 240 // If datagram support was enabled on our side as well as on the server side, 241 // we can expect it to have been negotiated both on the transport and on the HTTP/3 layer. 242 // Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT). 243 if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams { 244 conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support") 245 } 246 }(str) 247 } 248 } 249 250 func (c *client) Close() error { 251 conn := c.conn.Load() 252 if conn == nil { 253 return nil 254 } 255 return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "") 256 } 257 258 func (c *client) maxHeaderBytes() uint64 { 259 if c.opts.MaxHeaderBytes <= 0 { 260 return defaultMaxResponseHeaderBytes 261 } 262 return uint64(c.opts.MaxHeaderBytes) 263 } 264 265 // RoundTripOpt executes a request and returns a response 266 func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 267 rsp, err := c.roundTripOpt(req, opt) 268 if err != nil && req.Context().Err() != nil { 269 // if the context was canceled, return the context cancellation error 270 err = req.Context().Err() 271 } 272 return rsp, err 273 } 274 275 func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 276 if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { 277 return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host) 278 } 279 280 c.dialOnce.Do(func() { 281 c.handshakeErr = c.dial(req.Context()) 282 }) 283 if c.handshakeErr != nil { 284 return nil, c.handshakeErr 285 } 286 287 // At this point, c.conn is guaranteed to be set. 288 conn := *c.conn.Load() 289 290 // Immediately send out this request, if this is a 0-RTT request. 291 if req.Method == MethodGet0RTT { 292 req.Method = http.MethodGet 293 } else { 294 // wait for the handshake to complete 295 select { 296 case <-conn.HandshakeComplete(): 297 case <-req.Context().Done(): 298 return nil, req.Context().Err() 299 } 300 } 301 302 str, err := conn.OpenStreamSync(req.Context()) 303 if err != nil { 304 return nil, err 305 } 306 307 // Request Cancellation: 308 // This go routine keeps running even after RoundTripOpt() returns. 309 // It is shut down when the application is done processing the body. 310 reqDone := make(chan struct{}) 311 done := make(chan struct{}) 312 go func() { 313 defer close(done) 314 select { 315 case <-req.Context().Done(): 316 str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 317 str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)) 318 case <-reqDone: 319 } 320 }() 321 322 doneChan := reqDone 323 if opt.DontCloseRequestStream { 324 doneChan = nil 325 } 326 rsp, rerr := c.doRequest(req, conn, str, opt, doneChan) 327 if rerr.err != nil { // if any error occurred 328 close(reqDone) 329 <-done 330 if rerr.streamErr != 0 { // if it was a stream error 331 str.CancelWrite(quic.StreamErrorCode(rerr.streamErr)) 332 } 333 if rerr.connErr != 0 { // if it was a connection error 334 var reason string 335 if rerr.err != nil { 336 reason = rerr.err.Error() 337 } 338 conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) 339 } 340 return nil, maybeReplaceError(rerr.err) 341 } 342 if opt.DontCloseRequestStream { 343 close(reqDone) 344 <-done 345 } 346 return rsp, maybeReplaceError(rerr.err) 347 } 348 349 // cancelingReader reads from the io.Reader. 350 // It cancels writing on the stream if any error other than io.EOF occurs. 351 type cancelingReader struct { 352 r io.Reader 353 str Stream 354 } 355 356 func (r *cancelingReader) Read(b []byte) (int, error) { 357 n, err := r.r.Read(b) 358 if err != nil && err != io.EOF { 359 r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 360 } 361 return n, err 362 } 363 364 func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error { 365 defer body.Close() 366 buf := make([]byte, bodyCopyBufferSize) 367 sr := &cancelingReader{str: str, r: body} 368 if contentLength == -1 { 369 _, err := io.CopyBuffer(str, sr, buf) 370 return err 371 } 372 373 // make sure we don't send more bytes than the content length 374 n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf) 375 if err != nil { 376 return err 377 } 378 var extra int64 379 extra, err = io.CopyBuffer(io.Discard, sr, buf) 380 n += extra 381 if n > contentLength { 382 str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 383 return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n) 384 } 385 return err 386 } 387 388 func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) { 389 var requestGzip bool 390 if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { 391 requestGzip = true 392 } 393 if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil { 394 return nil, newStreamError(ErrCodeInternalError, err) 395 } 396 397 if req.Body == nil && !opt.DontCloseRequestStream { 398 str.Close() 399 } 400 401 hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") }) 402 if req.Body != nil { 403 // send the request body asynchronously 404 go func() { 405 contentLength := int64(-1) 406 // According to the documentation for http.Request.ContentLength, 407 // a value of 0 with a non-nil Body is also treated as unknown content length. 408 if req.ContentLength > 0 { 409 contentLength = req.ContentLength 410 } 411 if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil { 412 c.logger.Errorf("Error writing request: %s", err) 413 } 414 if !opt.DontCloseRequestStream { 415 hstr.Close() 416 } 417 }() 418 } 419 420 frame, err := parseNextFrame(str, nil) 421 if err != nil { 422 return nil, newStreamError(ErrCodeFrameError, err) 423 } 424 hf, ok := frame.(*headersFrame) 425 if !ok { 426 return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame")) 427 } 428 if hf.Length > c.maxHeaderBytes() { 429 return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes())) 430 } 431 headerBlock := make([]byte, hf.Length) 432 if _, err := io.ReadFull(str, headerBlock); err != nil { 433 return nil, newStreamError(ErrCodeRequestIncomplete, err) 434 } 435 hfs, err := c.decoder.DecodeFull(headerBlock) 436 if err != nil { 437 // TODO: use the right error code 438 return nil, newConnError(ErrCodeGeneralProtocolError, err) 439 } 440 441 res, err := responseFromHeaders(hfs) 442 if err != nil { 443 return nil, newStreamError(ErrCodeMessageError, err) 444 } 445 connState := conn.ConnectionState().TLS 446 res.TLS = &connState 447 res.Request = req 448 // Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set). 449 // See section 4.1.2 of RFC 9114. 450 var httpStr Stream 451 if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 { 452 httpStr = newLengthLimitedStream(hstr, res.ContentLength) 453 } else { 454 httpStr = hstr 455 } 456 respBody := newResponseBody(httpStr, conn, reqDone) 457 458 // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. 459 _, hasTransferEncoding := res.Header["Transfer-Encoding"] 460 isInformational := res.StatusCode >= 100 && res.StatusCode < 200 461 isNoContent := res.StatusCode == http.StatusNoContent 462 isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 463 if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { 464 res.ContentLength = -1 465 if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { 466 if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { 467 res.ContentLength = clen64 468 } 469 } 470 } 471 472 if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { 473 res.Header.Del("Content-Encoding") 474 res.Header.Del("Content-Length") 475 res.ContentLength = -1 476 res.Body = newGzipReader(respBody) 477 res.Uncompressed = true 478 } else { 479 res.Body = respBody 480 } 481 482 return res, requestError{} 483 } 484 485 func (c *client) HandshakeComplete() bool { 486 conn := c.conn.Load() 487 if conn == nil { 488 return false 489 } 490 select { 491 case <-(*conn).HandshakeComplete(): 492 return true 493 default: 494 return false 495 } 496 }