github.com/quic-go/quic-go@v0.44.0/http3/client.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log/slog"
     9  	"net/http"
    10  	"net/http/httptrace"
    11  	"net/textproto"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/quic-go/quic-go"
    16  	"github.com/quic-go/quic-go/internal/protocol"
    17  	"github.com/quic-go/quic-go/quicvarint"
    18  
    19  	"github.com/quic-go/qpack"
    20  )
    21  
    22  const (
    23  	// MethodGet0RTT allows a GET request to be sent using 0-RTT.
    24  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
    25  	MethodGet0RTT = "GET_0RTT"
    26  	// MethodHead0RTT allows a HEAD request to be sent using 0-RTT.
    27  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
    28  	MethodHead0RTT = "HEAD_0RTT"
    29  )
    30  
    31  const (
    32  	defaultUserAgent              = "quic-go HTTP/3"
    33  	defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
    34  )
    35  
    36  var defaultQuicConfig = &quic.Config{
    37  	MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
    38  	KeepAlivePeriod:    10 * time.Second,
    39  }
    40  
// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server.
// All requests are sent over the QUIC connection in the Connection field.
// The internal state (HTTP/3 connection wrapper, request writer, QPACK decoder) is
// created exactly once, by the first call to Start, RoundTrip or OpenRequestStream.
type SingleDestinationRoundTripper struct {
	// Connection is the QUIC connection that requests are sent on.
	// If it is a quic.EarlyConnection, requests other than MethodGet0RTT and
	// MethodHead0RTT wait for the handshake to complete before being sent.
	Connection quic.Connection

	// Enable support for HTTP/3 datagrams (RFC 9297).
	// Datagram support also needs to be enabled on the QUIC layer, by setting
	// EnableDatagrams in the quic.Config used to establish Connection.
	// The value is advertised in the SETTINGS frame sent on the control stream.
	EnableDatagrams bool

	// Additional HTTP/3 settings, sent in the SETTINGS frame on the control stream.
	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
	AdditionalSettings map[uint64]uint64
	// StreamHijacker, if set, is offered frames of unknown type received on
	// bidirectional streams opened by the server. Server-initiated bidirectional
	// streams are only accepted at all when it is set. It receives the frame
	// type, the connection's tracing ID, the stream and the parse error (if any).
	// Unless the stream ends up hijacked (presumably signaled by returning
	// hijacked=true), the connection is closed with ErrCodeFrameUnexpected.
	StreamHijacker     func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
	// UniStreamHijacker, if set, is handed to the connection's unidirectional
	// stream handler (HandleUnidirectionalStreams).
	// NOTE(review): its exact invocation rules live in that handler — confirm there.
	UniStreamHijacker  func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero (or any non-positive value) means to use a default limit of 10 MB.
	MaxResponseHeaderBytes int64

	// DisableCompression, if true, prevents the Transport from requesting compression with an
	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body.
	// However, if the user explicitly requested gzip it is not automatically uncompressed.
	DisableCompression bool

	// Logger, if non-nil, receives debug messages about connection setup,
	// stream handling and request body errors. A nil Logger disables them.
	Logger *slog.Logger

	// The fields below are set up lazily by init, which runs once via initOnce.
	initOnce      sync.Once
	hconn         *connection
	requestWriter *requestWriter
	decoder       *qpack.Decoder // NOTE(review): created in init, but not read anywhere in this file chunk
}
    74  
    75  var _ http.RoundTripper = &SingleDestinationRoundTripper{}
    76  
    77  func (c *SingleDestinationRoundTripper) Start() Connection {
    78  	c.initOnce.Do(func() { c.init() })
    79  	return c.hconn
    80  }
    81  
    82  func (c *SingleDestinationRoundTripper) init() {
    83  	c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
    84  	c.requestWriter = newRequestWriter()
    85  	c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
    86  	// send the SETTINGs frame, using 0-RTT data, if possible
    87  	go func() {
    88  		if err := c.setupConn(c.hconn); err != nil {
    89  			if c.Logger != nil {
    90  				c.Logger.Debug("Setting up connection failed", "error", err)
    91  			}
    92  			c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
    93  		}
    94  	}()
    95  	if c.StreamHijacker != nil {
    96  		go c.handleBidirectionalStreams()
    97  	}
    98  	go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker)
    99  }
   100  
   101  func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error {
   102  	// open the control stream
   103  	str, err := conn.OpenUniStream()
   104  	if err != nil {
   105  		return err
   106  	}
   107  	b := make([]byte, 0, 64)
   108  	b = quicvarint.Append(b, streamTypeControlStream)
   109  	// send the SETTINGS frame
   110  	b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b)
   111  	_, err = str.Write(b)
   112  	return err
   113  }
   114  
   115  func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
   116  	for {
   117  		str, err := c.hconn.AcceptStream(context.Background())
   118  		if err != nil {
   119  			if c.Logger != nil {
   120  				c.Logger.Debug("accepting bidirectional stream failed", "error", err)
   121  			}
   122  			return
   123  		}
   124  		fp := &frameParser{
   125  			r:    str,
   126  			conn: c.hconn,
   127  			unknownFrameHandler: func(ft FrameType, e error) (processed bool, err error) {
   128  				id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   129  				return c.StreamHijacker(ft, id, str, e)
   130  			},
   131  		}
   132  		go func() {
   133  			if _, err := fp.ParseNext(); err == errHijacked {
   134  				return
   135  			}
   136  			if err != nil {
   137  				if c.Logger != nil {
   138  					c.Logger.Debug("error handling stream", "error", err)
   139  				}
   140  			}
   141  			c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
   142  		}()
   143  	}
   144  }
   145  
   146  func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 {
   147  	if c.MaxResponseHeaderBytes <= 0 {
   148  		return defaultMaxResponseHeaderBytes
   149  	}
   150  	return uint64(c.MaxResponseHeaderBytes)
   151  }
   152  
   153  // RoundTrip executes a request and returns a response
   154  func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   155  	c.initOnce.Do(func() { c.init() })
   156  
   157  	rsp, err := c.roundTrip(req)
   158  	if err != nil && req.Context().Err() != nil {
   159  		// if the context was canceled, return the context cancellation error
   160  		err = req.Context().Err()
   161  	}
   162  	return rsp, err
   163  }
   164  
   165  func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) {
   166  	// Immediately send out this request, if this is a 0-RTT request.
   167  	switch req.Method {
   168  	case MethodGet0RTT:
   169  		// don't modify the original request
   170  		reqCopy := *req
   171  		req = &reqCopy
   172  		req.Method = http.MethodGet
   173  	case MethodHead0RTT:
   174  		// don't modify the original request
   175  		reqCopy := *req
   176  		req = &reqCopy
   177  		req.Method = http.MethodHead
   178  	default:
   179  		// wait for the handshake to complete
   180  		earlyConn, ok := c.Connection.(quic.EarlyConnection)
   181  		if ok {
   182  			select {
   183  			case <-earlyConn.HandshakeComplete():
   184  			case <-req.Context().Done():
   185  				return nil, req.Context().Err()
   186  			}
   187  		}
   188  	}
   189  
   190  	// It is only possible to send an Extended CONNECT request once the SETTINGS were received.
   191  	// See section 3 of RFC 8441.
   192  	if isExtendedConnectRequest(req) {
   193  		connCtx := c.Connection.Context()
   194  		// wait for the server's SETTINGS frame to arrive
   195  		select {
   196  		case <-c.hconn.ReceivedSettings():
   197  		case <-connCtx.Done():
   198  			return nil, context.Cause(connCtx)
   199  		}
   200  		if !c.hconn.Settings().EnableExtendedConnect {
   201  			return nil, errors.New("http3: server didn't enable Extended CONNECT")
   202  		}
   203  	}
   204  
   205  	reqDone := make(chan struct{})
   206  	str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes())
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	// Request Cancellation:
   212  	// This go routine keeps running even after RoundTripOpt() returns.
   213  	// It is shut down when the application is done processing the body.
   214  	done := make(chan struct{})
   215  	go func() {
   216  		defer close(done)
   217  		select {
   218  		case <-req.Context().Done():
   219  			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   220  			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
   221  		case <-reqDone:
   222  		}
   223  	}()
   224  
   225  	rsp, err := c.doRequest(req, str)
   226  	if err != nil { // if any error occurred
   227  		close(reqDone)
   228  		<-done
   229  		return nil, maybeReplaceError(err)
   230  	}
   231  	return rsp, maybeReplaceError(err)
   232  }
   233  
   234  func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) {
   235  	c.initOnce.Do(func() { c.init() })
   236  
   237  	return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes())
   238  }
   239  
   240  // cancelingReader reads from the io.Reader.
   241  // It cancels writing on the stream if any error other than io.EOF occurs.
   242  type cancelingReader struct {
   243  	r   io.Reader
   244  	str Stream
   245  }
   246  
   247  func (r *cancelingReader) Read(b []byte) (int, error) {
   248  	n, err := r.r.Read(b)
   249  	if err != nil && err != io.EOF {
   250  		r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   251  	}
   252  	return n, err
   253  }
   254  
   255  func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
   256  	defer body.Close()
   257  	buf := make([]byte, bodyCopyBufferSize)
   258  	sr := &cancelingReader{str: str, r: body}
   259  	if contentLength == -1 {
   260  		_, err := io.CopyBuffer(str, sr, buf)
   261  		return err
   262  	}
   263  
   264  	// make sure we don't send more bytes than the content length
   265  	n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	var extra int64
   270  	extra, err = io.CopyBuffer(io.Discard, sr, buf)
   271  	n += extra
   272  	if n > contentLength {
   273  		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   274  		return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
   275  	}
   276  	return err
   277  }
   278  
   279  func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
   280  	if err := str.SendRequestHeader(req); err != nil {
   281  		return nil, err
   282  	}
   283  	if req.Body == nil {
   284  		str.Close()
   285  	} else {
   286  		// send the request body asynchronously
   287  		go func() {
   288  			contentLength := int64(-1)
   289  			// According to the documentation for http.Request.ContentLength,
   290  			// a value of 0 with a non-nil Body is also treated as unknown content length.
   291  			if req.ContentLength > 0 {
   292  				contentLength = req.ContentLength
   293  			}
   294  			if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
   295  				if c.Logger != nil {
   296  					c.Logger.Debug("error writing request", "error", err)
   297  				}
   298  			}
   299  			str.Close()
   300  		}()
   301  	}
   302  
   303  	// copy from net/http: support 1xx responses
   304  	trace := httptrace.ContextClientTrace(req.Context())
   305  	num1xx := 0               // number of informational 1xx headers received
   306  	const max1xxResponses = 5 // arbitrary bound on number of informational responses
   307  
   308  	var res *http.Response
   309  	for {
   310  		var err error
   311  		res, err = str.ReadResponse()
   312  		if err != nil {
   313  			return nil, err
   314  		}
   315  		resCode := res.StatusCode
   316  		is1xx := 100 <= resCode && resCode <= 199
   317  		// treat 101 as a terminal status, see https://github.com/golang/go/issues/26161
   318  		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
   319  		if is1xxNonTerminal {
   320  			num1xx++
   321  			if num1xx > max1xxResponses {
   322  				return nil, errors.New("http: too many 1xx informational responses")
   323  			}
   324  			if trace != nil && trace.Got1xxResponse != nil {
   325  				if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(res.Header)); err != nil {
   326  					return nil, err
   327  				}
   328  			}
   329  			continue
   330  		}
   331  		break
   332  	}
   333  	connState := c.hconn.ConnectionState().TLS
   334  	res.TLS = &connState
   335  	res.Request = req
   336  	return res, nil
   337  }