github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/client.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log/slog"
     9  	"net/http"
    10  	"net/http/httptrace"
    11  	"net/textproto"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/apernet/quic-go"
    16  	"github.com/apernet/quic-go/internal/protocol"
    17  	"github.com/apernet/quic-go/quicvarint"
    18  
    19  	"github.com/quic-go/qpack"
    20  )
    21  
    22  const (
    23  	// MethodGet0RTT allows a GET request to be sent using 0-RTT.
    24  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
    25  	MethodGet0RTT = "GET_0RTT"
    26  	// MethodHead0RTT allows a HEAD request to be sent using 0-RTT.
    27  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
    28  	MethodHead0RTT = "HEAD_0RTT"
    29  )
    30  
    31  const (
    32  	defaultUserAgent              = "quic-go HTTP/3"
    33  	defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
    34  )
    35  
    36  var defaultQuicConfig = &quic.Config{
    37  	MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
    38  	KeepAlivePeriod:    10 * time.Second,
    39  }
    40  
// SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server.
//
// It is initialized lazily, exactly once, by the first call to Start, RoundTrip
// or OpenRequestStream. That initialization reads the exported fields, so they
// should presumably be configured before the first call — NOTE(review): confirm
// whether changing them afterwards is meant to be supported.
type SingleDestinationRoundTripper struct {
	// Connection is the QUIC connection to the server on which all requests are sent.
	Connection quic.Connection

	// Enable support for HTTP/3 datagrams (RFC 9297).
	// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
	// NOTE(review): this type has no QUICConfig field; presumably this refers to
	// the quic.Config that was used to establish Connection.
	EnableDatagrams bool

	// Additional HTTP/3 settings.
	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
	// They are sent in the client's SETTINGS frame on the control stream.
	AdditionalSettings map[uint64]uint64

	// StreamHijacker, if non-nil, is invoked for bidirectional streams opened by
	// the server, with the frame type reported while parsing the stream's first
	// frame (or the parse error). If the stream is not hijacked, the whole
	// connection is closed with ErrCodeFrameUnexpected. When nil,
	// server-initiated bidirectional streams are not accepted by this type.
	StreamHijacker func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)

	// UniStreamHijacker, if non-nil, is handed to the connection's handler for
	// incoming unidirectional streams.
	UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	// (Negative values are treated like zero.)
	MaxResponseHeaderBytes int64

	// DisableCompression, if true, prevents the Transport from requesting compression with an
	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body.
	// However, if the user explicitly requested gzip it is not automatically uncompressed.
	DisableCompression bool

	// Logger, if non-nil, receives debug-level messages about failures in
	// background work (connection setup, stream handling, request body upload).
	Logger *slog.Logger

	// Internal state, set up once by init (guarded by initOnce).
	initOnce      sync.Once
	hconn         *connection
	requestWriter *requestWriter
	decoder       *qpack.Decoder
}

// Compile-time check that SingleDestinationRoundTripper implements http.RoundTripper.
var _ http.RoundTripper = &SingleDestinationRoundTripper{}
    76  
    77  func (c *SingleDestinationRoundTripper) Start() Connection {
    78  	c.initOnce.Do(func() { c.init() })
    79  	return c.hconn
    80  }
    81  
    82  func (c *SingleDestinationRoundTripper) init() {
    83  	c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
    84  	c.requestWriter = newRequestWriter()
    85  	c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
    86  	// send the SETTINGs frame, using 0-RTT data, if possible
    87  	go func() {
    88  		if err := c.setupConn(c.hconn); err != nil {
    89  			if c.Logger != nil {
    90  				c.Logger.Debug("Setting up connection failed", "error", err)
    91  			}
    92  			c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
    93  		}
    94  	}()
    95  	if c.StreamHijacker != nil {
    96  		go c.handleBidirectionalStreams()
    97  	}
    98  	go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker)
    99  }
   100  
   101  func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error {
   102  	// open the control stream
   103  	str, err := conn.OpenUniStream()
   104  	if err != nil {
   105  		return err
   106  	}
   107  	b := make([]byte, 0, 64)
   108  	b = quicvarint.Append(b, streamTypeControlStream)
   109  	// send the SETTINGS frame
   110  	b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b)
   111  	_, err = str.Write(b)
   112  	return err
   113  }
   114  
   115  func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
   116  	for {
   117  		str, err := c.hconn.AcceptStream(context.Background())
   118  		if err != nil {
   119  			if c.Logger != nil {
   120  				c.Logger.Debug("accepting bidirectional stream failed", "error", err)
   121  			}
   122  			return
   123  		}
   124  		go func(str quic.Stream) {
   125  			_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
   126  				id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   127  				return c.StreamHijacker(ft, id, str, e)
   128  			})
   129  			if err == errHijacked {
   130  				return
   131  			}
   132  			if err != nil {
   133  				if c.Logger != nil {
   134  					c.Logger.Debug("error handling stream", "error", err)
   135  				}
   136  			}
   137  			c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
   138  		}(str)
   139  	}
   140  }
   141  
   142  func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 {
   143  	if c.MaxResponseHeaderBytes <= 0 {
   144  		return defaultMaxResponseHeaderBytes
   145  	}
   146  	return uint64(c.MaxResponseHeaderBytes)
   147  }
   148  
   149  // RoundTrip executes a request and returns a response
   150  func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   151  	c.initOnce.Do(func() { c.init() })
   152  
   153  	rsp, err := c.roundTrip(req)
   154  	if err != nil && req.Context().Err() != nil {
   155  		// if the context was canceled, return the context cancellation error
   156  		err = req.Context().Err()
   157  	}
   158  	return rsp, err
   159  }
   160  
   161  func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) {
   162  	// Immediately send out this request, if this is a 0-RTT request.
   163  	switch req.Method {
   164  	case MethodGet0RTT:
   165  		// don't modify the original request
   166  		reqCopy := *req
   167  		req = &reqCopy
   168  		req.Method = http.MethodGet
   169  	case MethodHead0RTT:
   170  		// don't modify the original request
   171  		reqCopy := *req
   172  		req = &reqCopy
   173  		req.Method = http.MethodHead
   174  	default:
   175  		// wait for the handshake to complete
   176  		earlyConn, ok := c.Connection.(quic.EarlyConnection)
   177  		if ok {
   178  			select {
   179  			case <-earlyConn.HandshakeComplete():
   180  			case <-req.Context().Done():
   181  				return nil, req.Context().Err()
   182  			}
   183  		}
   184  	}
   185  
   186  	// It is only possible to send an Extended CONNECT request once the SETTINGS were received.
   187  	// See section 3 of RFC 8441.
   188  	if isExtendedConnectRequest(req) {
   189  		connCtx := c.Connection.Context()
   190  		// wait for the server's SETTINGS frame to arrive
   191  		select {
   192  		case <-c.hconn.ReceivedSettings():
   193  		case <-connCtx.Done():
   194  			return nil, context.Cause(connCtx)
   195  		}
   196  		if !c.hconn.Settings().EnableExtendedConnect {
   197  			return nil, errors.New("http3: server didn't enable Extended CONNECT")
   198  		}
   199  	}
   200  
   201  	reqDone := make(chan struct{})
   202  	str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes())
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	// Request Cancellation:
   208  	// This go routine keeps running even after RoundTripOpt() returns.
   209  	// It is shut down when the application is done processing the body.
   210  	done := make(chan struct{})
   211  	go func() {
   212  		defer close(done)
   213  		select {
   214  		case <-req.Context().Done():
   215  			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   216  			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
   217  		case <-reqDone:
   218  		}
   219  	}()
   220  
   221  	rsp, err := c.doRequest(req, str)
   222  	if err != nil { // if any error occurred
   223  		close(reqDone)
   224  		<-done
   225  		return nil, maybeReplaceError(err)
   226  	}
   227  	return rsp, maybeReplaceError(err)
   228  }
   229  
   230  func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) {
   231  	c.initOnce.Do(func() { c.init() })
   232  
   233  	return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes())
   234  }
   235  
   236  // cancelingReader reads from the io.Reader.
   237  // It cancels writing on the stream if any error other than io.EOF occurs.
   238  type cancelingReader struct {
   239  	r   io.Reader
   240  	str Stream
   241  }
   242  
   243  func (r *cancelingReader) Read(b []byte) (int, error) {
   244  	n, err := r.r.Read(b)
   245  	if err != nil && err != io.EOF {
   246  		r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   247  	}
   248  	return n, err
   249  }
   250  
   251  func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
   252  	defer body.Close()
   253  	buf := make([]byte, bodyCopyBufferSize)
   254  	sr := &cancelingReader{str: str, r: body}
   255  	if contentLength == -1 {
   256  		_, err := io.CopyBuffer(str, sr, buf)
   257  		return err
   258  	}
   259  
   260  	// make sure we don't send more bytes than the content length
   261  	n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
   262  	if err != nil {
   263  		return err
   264  	}
   265  	var extra int64
   266  	extra, err = io.CopyBuffer(io.Discard, sr, buf)
   267  	n += extra
   268  	if n > contentLength {
   269  		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   270  		return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
   271  	}
   272  	return err
   273  }
   274  
   275  func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
   276  	if err := str.SendRequestHeader(req); err != nil {
   277  		return nil, err
   278  	}
   279  	if req.Body == nil {
   280  		str.Close()
   281  	} else {
   282  		// send the request body asynchronously
   283  		go func() {
   284  			contentLength := int64(-1)
   285  			// According to the documentation for http.Request.ContentLength,
   286  			// a value of 0 with a non-nil Body is also treated as unknown content length.
   287  			if req.ContentLength > 0 {
   288  				contentLength = req.ContentLength
   289  			}
   290  			if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
   291  				if c.Logger != nil {
   292  					c.Logger.Debug("error writing request", "error", err)
   293  				}
   294  			}
   295  			str.Close()
   296  		}()
   297  	}
   298  
   299  	// copy from net/http: support 1xx responses
   300  	trace := httptrace.ContextClientTrace(req.Context())
   301  	num1xx := 0               // number of informational 1xx headers received
   302  	const max1xxResponses = 5 // arbitrary bound on number of informational responses
   303  
   304  	var res *http.Response
   305  	for {
   306  		var err error
   307  		res, err = str.ReadResponse()
   308  		if err != nil {
   309  			return nil, err
   310  		}
   311  		resCode := res.StatusCode
   312  		is1xx := 100 <= resCode && resCode <= 199
   313  		// treat 101 as a terminal status, see https://github.com/golang/go/issues/26161
   314  		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
   315  		if is1xxNonTerminal {
   316  			num1xx++
   317  			if num1xx > max1xxResponses {
   318  				return nil, errors.New("http: too many 1xx informational responses")
   319  			}
   320  			if trace != nil && trace.Got1xxResponse != nil {
   321  				if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(res.Header)); err != nil {
   322  					return nil, err
   323  				}
   324  			}
   325  			continue
   326  		}
   327  		break
   328  	}
   329  	connState := c.hconn.ConnectionState().TLS
   330  	res.TLS = &connState
   331  	res.Request = req
   332  	return res, nil
   333  }