github.com/sagernet/quic-go@v0.43.1-beta.1/http3_ech/client.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptrace"
    10  	"net/textproto"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/quic-go/qpack"
    15  	"github.com/sagernet/quic-go/ech"
    16  	"github.com/sagernet/quic-go/internal/protocol"
    17  	"github.com/sagernet/quic-go/quicvarint"
    18  	"golang.org/x/exp/slog"
    19  )
    20  
    21  const (
    22  	// MethodGet0RTT allows a GET request to be sent using 0-RTT.
    23  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
    24  	MethodGet0RTT = "GET_0RTT"
    25  	// MethodHead0RTT allows a HEAD request to be sent using 0-RTT.
    26  	// Note that 0-RTT doesn't provide replay protection and should only be used for idempotent requests.
    27  	MethodHead0RTT = "HEAD_0RTT"
    28  )
    29  
    30  const (
    31  	defaultUserAgent              = "quic-go HTTP/3"
    32  	defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
    33  )
    34  
    35  var defaultQuicConfig = &quic.Config{
    36  	MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
    37  	KeepAlivePeriod:    10 * time.Second,
    38  }
    39  
    40  // SingleDestinationRoundTripper is an HTTP/3 client doing requests to a single remote server.
    41  type SingleDestinationRoundTripper struct {
    42  	Connection quic.Connection
    43  
    44  	// Enable support for HTTP/3 datagrams (RFC 9297).
    45  	// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
    46  	EnableDatagrams bool
    47  
    48  	// Additional HTTP/3 settings.
    49  	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
    50  	AdditionalSettings map[uint64]uint64
    51  	StreamHijacker     func(FrameType, quic.ConnectionTracingID, quic.Stream, error) (hijacked bool, err error)
    52  	UniStreamHijacker  func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)
    53  
    54  	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
    55  	// allowed in the server's response header.
    56  	// Zero means to use a default limit.
    57  	MaxResponseHeaderBytes int64
    58  
    59  	// DisableCompression, if true, prevents the Transport from requesting compression with an
    60  	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
    61  	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
    62  	// decoded in the Response.Body.
    63  	// However, if the user explicitly requested gzip it is not automatically uncompressed.
    64  	DisableCompression bool
    65  
    66  	Logger *slog.Logger
    67  
    68  	initOnce      sync.Once
    69  	hconn         *connection
    70  	requestWriter *requestWriter
    71  	decoder       *qpack.Decoder
    72  }
    73  
    74  var _ http.RoundTripper = &SingleDestinationRoundTripper{}
    75  
    76  func (c *SingleDestinationRoundTripper) Start() Connection {
    77  	c.initOnce.Do(func() { c.init() })
    78  	return c.hconn
    79  }
    80  
    81  func (c *SingleDestinationRoundTripper) init() {
    82  	c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
    83  	c.requestWriter = newRequestWriter()
    84  	c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
    85  	// send the SETTINGs frame, using 0-RTT data, if possible
    86  	go func() {
    87  		if err := c.setupConn(c.hconn); err != nil {
    88  			if c.Logger != nil {
    89  				c.Logger.Debug("Setting up connection failed", "error", err)
    90  			}
    91  			c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeInternalError), "")
    92  		}
    93  	}()
    94  	if c.StreamHijacker != nil {
    95  		go c.handleBidirectionalStreams()
    96  	}
    97  	go c.hconn.HandleUnidirectionalStreams(c.UniStreamHijacker)
    98  }
    99  
   100  func (c *SingleDestinationRoundTripper) setupConn(conn *connection) error {
   101  	// open the control stream
   102  	str, err := conn.OpenUniStream()
   103  	if err != nil {
   104  		return err
   105  	}
   106  	b := make([]byte, 0, 64)
   107  	b = quicvarint.Append(b, streamTypeControlStream)
   108  	// send the SETTINGS frame
   109  	b = (&settingsFrame{Datagram: c.EnableDatagrams, Other: c.AdditionalSettings}).Append(b)
   110  	_, err = str.Write(b)
   111  	return err
   112  }
   113  
   114  func (c *SingleDestinationRoundTripper) handleBidirectionalStreams() {
   115  	for {
   116  		str, err := c.hconn.AcceptStream(context.Background())
   117  		if err != nil {
   118  			if c.Logger != nil {
   119  				c.Logger.Debug("accepting bidirectional stream failed", "error", err)
   120  			}
   121  			return
   122  		}
   123  		go func(str quic.Stream) {
   124  			_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
   125  				id := c.hconn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   126  				return c.StreamHijacker(ft, id, str, e)
   127  			})
   128  			if err == errHijacked {
   129  				return
   130  			}
   131  			if err != nil {
   132  				if c.Logger != nil {
   133  					c.Logger.Debug("error handling stream", "error", err)
   134  				}
   135  			}
   136  			c.hconn.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
   137  		}(str)
   138  	}
   139  }
   140  
   141  func (c *SingleDestinationRoundTripper) maxHeaderBytes() uint64 {
   142  	if c.MaxResponseHeaderBytes <= 0 {
   143  		return defaultMaxResponseHeaderBytes
   144  	}
   145  	return uint64(c.MaxResponseHeaderBytes)
   146  }
   147  
   148  // RoundTrip executes a request and returns a response
   149  func (c *SingleDestinationRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   150  	c.initOnce.Do(func() { c.init() })
   151  
   152  	rsp, err := c.roundTrip(req)
   153  	if err != nil && req.Context().Err() != nil {
   154  		// if the context was canceled, return the context cancellation error
   155  		err = req.Context().Err()
   156  	}
   157  	return rsp, err
   158  }
   159  
   160  func (c *SingleDestinationRoundTripper) roundTrip(req *http.Request) (*http.Response, error) {
   161  	// Immediately send out this request, if this is a 0-RTT request.
   162  	switch req.Method {
   163  	case MethodGet0RTT:
   164  		// don't modify the original request
   165  		reqCopy := *req
   166  		req = &reqCopy
   167  		req.Method = http.MethodGet
   168  	case MethodHead0RTT:
   169  		// don't modify the original request
   170  		reqCopy := *req
   171  		req = &reqCopy
   172  		req.Method = http.MethodHead
   173  	default:
   174  		// wait for the handshake to complete
   175  		earlyConn, ok := c.Connection.(quic.EarlyConnection)
   176  		if ok {
   177  			select {
   178  			case <-earlyConn.HandshakeComplete():
   179  			case <-req.Context().Done():
   180  				return nil, req.Context().Err()
   181  			}
   182  		}
   183  	}
   184  
   185  	// It is only possible to send an Extended CONNECT request once the SETTINGS were received.
   186  	// See section 3 of RFC 8441.
   187  	if isExtendedConnectRequest(req) {
   188  		connCtx := c.Connection.Context()
   189  		// wait for the server's SETTINGS frame to arrive
   190  		select {
   191  		case <-c.hconn.ReceivedSettings():
   192  		case <-connCtx.Done():
   193  			return nil, context.Cause(connCtx)
   194  		}
   195  		if !c.hconn.Settings().EnableExtendedConnect {
   196  			return nil, errors.New("http3: server didn't enable Extended CONNECT")
   197  		}
   198  	}
   199  
   200  	reqDone := make(chan struct{})
   201  	str, err := c.hconn.openRequestStream(req.Context(), c.requestWriter, reqDone, c.DisableCompression, c.maxHeaderBytes())
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	// Request Cancellation:
   207  	// This go routine keeps running even after RoundTripOpt() returns.
   208  	// It is shut down when the application is done processing the body.
   209  	done := make(chan struct{})
   210  	go func() {
   211  		defer close(done)
   212  		select {
   213  		case <-req.Context().Done():
   214  			str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   215  			str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
   216  		case <-reqDone:
   217  		}
   218  	}()
   219  
   220  	rsp, err := c.doRequest(req, str)
   221  	if err != nil { // if any error occurred
   222  		close(reqDone)
   223  		<-done
   224  		return nil, maybeReplaceError(err)
   225  	}
   226  	return rsp, maybeReplaceError(err)
   227  }
   228  
   229  func (c *SingleDestinationRoundTripper) OpenRequestStream(ctx context.Context) (RequestStream, error) {
   230  	c.initOnce.Do(func() { c.init() })
   231  
   232  	return c.hconn.openRequestStream(ctx, c.requestWriter, nil, c.DisableCompression, c.maxHeaderBytes())
   233  }
   234  
   235  // cancelingReader reads from the io.Reader.
   236  // It cancels writing on the stream if any error other than io.EOF occurs.
   237  type cancelingReader struct {
   238  	r   io.Reader
   239  	str Stream
   240  }
   241  
   242  func (r *cancelingReader) Read(b []byte) (int, error) {
   243  	n, err := r.r.Read(b)
   244  	if err != nil && err != io.EOF {
   245  		r.str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   246  	}
   247  	return n, err
   248  }
   249  
   250  func (c *SingleDestinationRoundTripper) sendRequestBody(str Stream, body io.ReadCloser, contentLength int64) error {
   251  	defer body.Close()
   252  	buf := make([]byte, bodyCopyBufferSize)
   253  	sr := &cancelingReader{str: str, r: body}
   254  	if contentLength == -1 {
   255  		_, err := io.CopyBuffer(str, sr, buf)
   256  		return err
   257  	}
   258  
   259  	// make sure we don't send more bytes than the content length
   260  	n, err := io.CopyBuffer(str, io.LimitReader(sr, contentLength), buf)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	var extra int64
   265  	extra, err = io.CopyBuffer(io.Discard, sr, buf)
   266  	n += extra
   267  	if n > contentLength {
   268  		str.CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
   269  		return fmt.Errorf("http: ContentLength=%d with Body length %d", contentLength, n)
   270  	}
   271  	return err
   272  }
   273  
   274  func (c *SingleDestinationRoundTripper) doRequest(req *http.Request, str *requestStream) (*http.Response, error) {
   275  	if err := str.SendRequestHeader(req); err != nil {
   276  		return nil, err
   277  	}
   278  	if req.Body == nil {
   279  		str.Close()
   280  	} else {
   281  		// send the request body asynchronously
   282  		go func() {
   283  			contentLength := int64(-1)
   284  			// According to the documentation for http.Request.ContentLength,
   285  			// a value of 0 with a non-nil Body is also treated as unknown content length.
   286  			if req.ContentLength > 0 {
   287  				contentLength = req.ContentLength
   288  			}
   289  			if err := c.sendRequestBody(str, req.Body, contentLength); err != nil {
   290  				if c.Logger != nil {
   291  					c.Logger.Debug("error writing request", "error", err)
   292  				}
   293  			}
   294  			str.Close()
   295  		}()
   296  	}
   297  
   298  	// copy from net/http: support 1xx responses
   299  	trace := httptrace.ContextClientTrace(req.Context())
   300  	num1xx := 0               // number of informational 1xx headers received
   301  	const max1xxResponses = 5 // arbitrary bound on number of informational responses
   302  
   303  	var res *http.Response
   304  	for {
   305  		var err error
   306  		res, err = str.ReadResponse()
   307  		if err != nil {
   308  			return nil, err
   309  		}
   310  		resCode := res.StatusCode
   311  		is1xx := 100 <= resCode && resCode <= 199
   312  		// treat 101 as a terminal status, see https://github.com/golang/go/issues/26161
   313  		is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
   314  		if is1xxNonTerminal {
   315  			num1xx++
   316  			if num1xx > max1xxResponses {
   317  				return nil, errors.New("http: too many 1xx informational responses")
   318  			}
   319  			if trace != nil && trace.Got1xxResponse != nil {
   320  				if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(res.Header)); err != nil {
   321  					return nil, err
   322  				}
   323  			}
   324  			continue
   325  		}
   326  		break
   327  	}
   328  	connState := c.hconn.ConnectionState().TLS
   329  	res.TLS = &connState
   330  	res.Request = req
   331  	return res, nil
   332  }