github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/client.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strconv"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/daeuniverse/quic-go"
    17  	"github.com/daeuniverse/quic-go/internal/protocol"
    18  	"github.com/daeuniverse/quic-go/internal/utils"
    19  	"github.com/daeuniverse/quic-go/quicvarint"
    20  
    21  	"github.com/quic-go/qpack"
    22  )
    23  
    24  // MethodGet0RTT allows a GET request to be sent using 0-RTT.
    25  // Note that 0-RTT data doesn't provide replay protection.
    26  const MethodGet0RTT = "GET_0RTT"
    27  
    28  const (
    29  	defaultUserAgent              = "quic-go HTTP/3"
    30  	defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
    31  )
    32  
    33  var defaultQuicConfig = &quic.Config{
    34  	MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
    35  	KeepAlivePeriod:    10 * time.Second,
    36  }
    37  
    38  type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
    39  
    40  var dialAddr dialFunc = quic.DialAddrEarly
    41  
    42  type roundTripperOpts struct {
    43  	DisableCompression bool
    44  	EnableDatagram     bool
    45  	MaxHeaderBytes     int64
    46  	AdditionalSettings map[uint64]uint64
    47  	StreamHijacker     func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
    48  	UniStreamHijacker  func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
    49  }
    50  
    51  // client is a HTTP3 client doing requests
    52  type client struct {
    53  	tlsConf *tls.Config
    54  	config  *quic.Config
    55  	opts    *roundTripperOpts
    56  
    57  	dialOnce     sync.Once
    58  	dialer       dialFunc
    59  	handshakeErr error
    60  
    61  	receivedSettings chan struct{} // closed once the server's SETTINGS frame was processed
    62  	settings         *Settings     // set once receivedSettings is closed
    63  
    64  	requestWriter *requestWriter
    65  
    66  	decoder *qpack.Decoder
    67  
    68  	hostname string
    69  	conn     atomic.Pointer[quic.EarlyConnection]
    70  
    71  	logger utils.Logger
    72  }
    73  
    74  var _ roundTripCloser = &client{}
    75  
    76  func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) {
    77  	if conf == nil {
    78  		conf = defaultQuicConfig.Clone()
    79  		conf.EnableDatagrams = opts.EnableDatagram
    80  	}
    81  	if opts.EnableDatagram && !conf.EnableDatagrams {
    82  		return nil, errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
    83  	}
    84  	if len(conf.Versions) == 0 {
    85  		conf = conf.Clone()
    86  		conf.Versions = []quic.Version{protocol.SupportedVersions[0]}
    87  	}
    88  	if len(conf.Versions) != 1 {
    89  		return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
    90  	}
    91  	if conf.MaxIncomingStreams == 0 {
    92  		conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
    93  	}
    94  	logger := utils.DefaultLogger.WithPrefix("h3 client")
    95  
    96  	if tlsConf == nil {
    97  		tlsConf = &tls.Config{}
    98  	} else {
    99  		tlsConf = tlsConf.Clone()
   100  	}
   101  	if tlsConf.ServerName == "" {
   102  		sni, _, err := net.SplitHostPort(hostname)
   103  		if err != nil {
   104  			// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
   105  			sni = hostname
   106  		}
   107  		tlsConf.ServerName = sni
   108  	}
   109  	// Replace existing ALPNs by H3
   110  	tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}
   111  
   112  	return &client{
   113  		hostname:         authorityAddr("https", hostname),
   114  		tlsConf:          tlsConf,
   115  		requestWriter:    newRequestWriter(logger),
   116  		receivedSettings: make(chan struct{}),
   117  		decoder:          qpack.NewDecoder(func(hf qpack.HeaderField) {}),
   118  		config:           conf,
   119  		opts:             opts,
   120  		dialer:           dialer,
   121  		logger:           logger,
   122  	}, nil
   123  }
   124  
   125  func (c *client) dial(ctx context.Context) error {
   126  	var err error
   127  	var conn quic.EarlyConnection
   128  	if c.dialer != nil {
   129  		conn, err = c.dialer(ctx, c.hostname, c.tlsConf, c.config)
   130  	} else {
   131  		conn, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
   132  	}
   133  	if err != nil {
   134  		return err
   135  	}
   136  	c.conn.Store(&conn)
   137  
   138  	// send the SETTINGs frame, using 0-RTT data, if possible
   139  	go func() {
   140  		if err := c.setupConn(conn); err != nil {
   141  			c.logger.Debugf("Setting up connection failed: %s", err)
   142  			conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
   143  		}
   144  	}()
   145  
   146  	if c.opts.StreamHijacker != nil {
   147  		go c.handleBidirectionalStreams(conn)
   148  	}
   149  	go c.handleUnidirectionalStreams(conn)
   150  	return nil
   151  }
   152  
   153  func (c *client) setupConn(conn quic.EarlyConnection) error {
   154  	// open the control stream
   155  	str, err := conn.OpenUniStream()
   156  	if err != nil {
   157  		return err
   158  	}
   159  	b := make([]byte, 0, 64)
   160  	b = quicvarint.Append(b, streamTypeControlStream)
   161  	// send the SETTINGS frame
   162  	b = (&settingsFrame{Datagram: c.opts.EnableDatagram, Other: c.opts.AdditionalSettings}).Append(b)
   163  	_, err = str.Write(b)
   164  	return err
   165  }
   166  
   167  func (c *client) handleBidirectionalStreams(conn quic.EarlyConnection) {
   168  	for {
   169  		str, err := conn.AcceptStream(context.Background())
   170  		if err != nil {
   171  			c.logger.Debugf("accepting bidirectional stream failed: %s", err)
   172  			return
   173  		}
   174  		go func(str quic.Stream) {
   175  			_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
   176  				return c.opts.StreamHijacker(ft, conn, str, e)
   177  			})
   178  			if err == errHijacked {
   179  				return
   180  			}
   181  			if err != nil {
   182  				c.logger.Debugf("error handling stream: %s", err)
   183  			}
   184  			conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
   185  		}(str)
   186  	}
   187  }
   188  
   189  func (c *client) handleUnidirectionalStreams(conn quic.EarlyConnection) {
   190  	var rcvdControlStream atomic.Bool
   191  
   192  	for {
   193  		str, err := conn.AcceptUniStream(context.Background())
   194  		if err != nil {
   195  			c.logger.Debugf("accepting unidirectional stream failed: %s", err)
   196  			return
   197  		}
   198  
   199  		go func(str quic.ReceiveStream) {
   200  			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
   201  			if err != nil {
   202  				if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, err) {
   203  					return
   204  				}
   205  				c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
   206  				return
   207  			}
   208  			// We're only interested in the control stream here.
   209  			switch streamType {
   210  			case streamTypeControlStream:
   211  			case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
   212  				// Our QPACK implementation doesn't use the dynamic table yet.
   213  				// TODO: check that only one stream of each type is opened.
   214  				return
   215  			case streamTypePushStream:
   216  				// We never increased the Push ID, so we don't expect any push streams.
   217  				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
   218  				return
   219  			default:
   220  				if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
   221  					return
   222  				}
   223  				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   224  				return
   225  			}
   226  			// Only a single control stream is allowed.
   227  			if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr {
   228  				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
   229  				return
   230  			}
   231  			f, err := parseNextFrame(str, nil)
   232  			if err != nil {
   233  				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
   234  				return
   235  			}
   236  			sf, ok := f.(*settingsFrame)
   237  			if !ok {
   238  				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
   239  				return
   240  			}
   241  			c.settings = &Settings{
   242  				EnableDatagram:        sf.Datagram,
   243  				EnableExtendedConnect: sf.ExtendedConnect,
   244  				Other:                 sf.Other,
   245  			}
   246  			close(c.receivedSettings)
   247  			if !sf.Datagram {
   248  				return
   249  			}
   250  			// If datagram support was enabled on our side as well as on the server side,
   251  			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
   252  			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
   253  			if c.opts.EnableDatagram && !conn.ConnectionState().SupportsDatagrams {
   254  				conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
   255  			}
   256  		}(str)
   257  	}
   258  }
   259  
   260  func (c *client) Close() error {
   261  	conn := c.conn.Load()
   262  	if conn == nil {
   263  		return nil
   264  	}
   265  	return (*conn).CloseWithError(quic.ApplicationErrorCode(ErrCodeNoError), "")
   266  }
   267  
   268  func (c *client) maxHeaderBytes() uint64 {
   269  	if c.opts.MaxHeaderBytes <= 0 {
   270  		return defaultMaxResponseHeaderBytes
   271  	}
   272  	return uint64(c.opts.MaxHeaderBytes)
   273  }
   274  
   275  // RoundTripOpt executes a request and returns a response
   276  func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   277  	rsp, err := c.roundTripOpt(req, opt)
   278  	if err != nil && req.Context().Err() != nil {
   279  		// if the context was canceled, return the context cancellation error
   280  		err = req.Context().Err()
   281  	}
   282  	return rsp, err
   283  }
   284  
   285  func (c *client) roundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   286  	if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
   287  		return nil, fmt.Errorf("http3 client BUG: RoundTripOpt called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
   288  	}
   289  
   290  	c.dialOnce.Do(func() {
   291  		c.handshakeErr = c.dial(req.Context())
   292  	})
   293  	if c.handshakeErr != nil {
   294  		return nil, c.handshakeErr
   295  	}
   296  
   297  	// At this point, c.conn is guaranteed to be set.
   298  	conn := *c.conn.Load()
   299  
   300  	// Immediately send out this request, if this is a 0-RTT request.
   301  	if req.Method == MethodGet0RTT {
   302  		req.Method = http.MethodGet
   303  	} else {
   304  		// wait for the handshake to complete
   305  		select {
   306  		case <-conn.HandshakeComplete():
   307  		case <-req.Context().Done():
   308  			return nil, req.Context().Err()
   309  		}
   310  	}
   311  
   312  	if opt.CheckSettings != nil {
   313  		// wait for the server's SETTINGS frame to arrive
   314  		select {
   315  		case <-c.receivedSettings:
   316  		case <-conn.Context().Done():
   317  			return nil, context.Cause(conn.Context())
   318  		}
   319  		if err := opt.CheckSettings(*c.settings); err != nil {
   320  			return nil, err
   321  		}
   322  	}
   323  
   324  	str, err := conn.OpenStreamSync(req.Context())
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	// Request Cancellation:
   330  	// This go routine keeps running even after RoundTripOpt() returns.
   331  	// It is shut down when the application is done processing the body.
   332  	reqDone := make(chan struct{})
   333  	done := make(chan struct{})
   334  	go func() {
   335  		defer close(done)
   336  		select {
   337  		case <-req.Context().Done():
   338  			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   339  			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
   340  		case <-reqDone:
   341  		}
   342  	}()
   343  
   344  	doneChan := reqDone
   345  	if opt.DontCloseRequestStream {
   346  		doneChan = nil
   347  	}
   348  	rsp, rerr := c.doRequest(req, conn, str, opt, doneChan)
   349  	if rerr.err != nil { // if any error occurred
   350  		close(reqDone)
   351  		<-done
   352  		if rerr.streamErr != 0 { // if it was a stream error
   353  			str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
   354  		}
   355  		if rerr.connErr != 0 { // if it was a connection error
   356  			var reason string
   357  			if rerr.err != nil {
   358  				reason = rerr.err.Error()
   359  			}
   360  			conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
   361  		}
   362  		return nil, maybeReplaceError(rerr.err)
   363  	}
   364  	if opt.DontCloseRequestStream {
   365  		close(reqDone)
   366  		<-done
   367  	}
   368  	return rsp, maybeReplaceError(rerr.err)
   369  }
   370  
   371  // cancelingReader reads from the io.Reader.
   372  // It cancels writing on the stream if any error other than io.EOF occurs.
   373  type cancelingReader struct {
   374  	r   io.Reader
   375  	str Stream
   376  }
   377  
   378  func (r *cancelingReader) Read(b []byte) (int, error) {
   379  	n, err := r.r.Read(b)
   380  	if err != nil && err != io.EOF {
   381  		r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   382  	}
   383  	return n, err
   384  }
   385  
   386  func (c *client) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
   387  	defer body.Close()
   388  	buf := make([]byte, bodyCopyBufferSize)
   389  	sr := &cancelingReader{str: str, r: body}
   390  	if contentLength == -1 {
   391  		_, err := io.CopyBuffer(str, sr, buf)
   392  		return err
   393  	}
   394  
   395  	// make sure we don't send more bytes than the content length
   396  	n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
   397  	if err != nil {
   398  		return err
   399  	}
   400  	var extra int64
   401  	extra, err = io.CopyBuffer(io.Discard, sr, buf)
   402  	n += extra
   403  	if n > contentLength {
   404  		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   405  		return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
   406  	}
   407  	return err
   408  }
   409  
   410  func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str quic.Stream, opt RoundTripOpt, reqDone chan<- struct{}) (*http.Response, requestError) {
   411  	var requestGzip bool
   412  	if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
   413  		requestGzip = true
   414  	}
   415  	if err := c.requestWriter.WriteRequestHeader(str, req, requestGzip); err != nil {
   416  		return nil, newStreamError(ErrCodeInternalError, err)
   417  	}
   418  
   419  	if req.Body == nil && !opt.DontCloseRequestStream {
   420  		str.Close()
   421  	}
   422  
   423  	hstr := newStream(str, func() { conn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "") })
   424  	if req.Body != nil {
   425  		// send the request body asynchronously
   426  		go func() {
   427  			contentLength := int64(-1)
   428  			// According to the documentation for http.Request.ContentLength,
   429  			// a value of 0 with a non-nil Body is also treated as unknown content length.
   430  			if req.ContentLength > 0 {
   431  				contentLength = req.ContentLength
   432  			}
   433  			if err := c.sendRequestBody(hstr, req.Body, contentLength); err != nil {
   434  				c.logger.Errorf("Error writing request: %s", err)
   435  			}
   436  			if !opt.DontCloseRequestStream {
   437  				hstr.Close()
   438  			}
   439  		}()
   440  	}
   441  
   442  	frame, err := parseNextFrame(str, nil)
   443  	if err != nil {
   444  		return nil, newStreamError(ErrCodeFrameError, err)
   445  	}
   446  	hf, ok := frame.(*headersFrame)
   447  	if !ok {
   448  		return nil, newConnError(ErrCodeFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
   449  	}
   450  	if hf.Length > c.maxHeaderBytes() {
   451  		return nil, newStreamError(ErrCodeFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
   452  	}
   453  	headerBlock := make([]byte, hf.Length)
   454  	if _, err := io.ReadFull(str, headerBlock); err != nil {
   455  		return nil, newStreamError(ErrCodeRequestIncomplete, err)
   456  	}
   457  	hfs, err := c.decoder.DecodeFull(headerBlock)
   458  	if err != nil {
   459  		// TODO: use the right error code
   460  		return nil, newConnError(ErrCodeGeneralProtocolError, err)
   461  	}
   462  
   463  	res, err := responseFromHeaders(hfs)
   464  	if err != nil {
   465  		return nil, newStreamError(ErrCodeMessageError, err)
   466  	}
   467  	connState := conn.ConnectionState().TLS
   468  	res.TLS = &connState
   469  	res.Request = req
   470  	// Check that the server doesn't send more data in DATA frames than indicated by the Content-Length header (if set).
   471  	// See section 4.1.2 of RFC 9114.
   472  	var httpStr Stream
   473  	if _, ok := res.Header["Content-Length"]; ok && res.ContentLength >= 0 {
   474  		httpStr = newLengthLimitedStream(hstr, res.ContentLength)
   475  	} else {
   476  		httpStr = hstr
   477  	}
   478  	respBody := newResponseBody(httpStr, conn, reqDone)
   479  
   480  	// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
   481  	_, hasTransferEncoding := res.Header["Transfer-Encoding"]
   482  	isInformational := res.StatusCode >= 100 && res.StatusCode < 200
   483  	isNoContent := res.StatusCode == http.StatusNoContent
   484  	isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
   485  	if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
   486  		res.ContentLength = -1
   487  		if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
   488  			if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
   489  				res.ContentLength = clen64
   490  			}
   491  		}
   492  	}
   493  
   494  	if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
   495  		res.Header.Del("Content-Encoding")
   496  		res.Header.Del("Content-Length")
   497  		res.ContentLength = -1
   498  		res.Body = newGzipReader(respBody)
   499  		res.Uncompressed = true
   500  	} else {
   501  		res.Body = respBody
   502  	}
   503  
   504  	return res, requestError{}
   505  }
   506  
   507  func (c *client) HandshakeComplete() bool {
   508  	conn := c.conn.Load()
   509  	if conn == nil {
   510  		return false
   511  	}
   512  	select {
   513  	case <-(*conn).HandshakeComplete():
   514  		return true
   515  	default:
   516  		return false
   517  	}
   518  }