github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/http3/client.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"strconv"
    12  	"sync"
    13  
    14  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go"
    15  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    16  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qtls"
    17  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    18  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/quicvarint"
    19  	"github.com/marten-seemann/qpack"
    20  )
    21  
    22  // MethodGet0RTT allows a GET request to be sent using 0-RTT.
    23  // Note that 0-RTT data doesn't provide replay protection.
    24  const MethodGet0RTT = "GET_0RTT"
    25  
    26  const (
    27  	defaultUserAgent              = "quic-go HTTP/3"
    28  	defaultMaxResponseHeaderBytes = 10 * 1 << 20 // 10 MB
    29  )
    30  
    31  var defaultQuicConfig = &quic.Config{
    32  	MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams
    33  	KeepAlive:          true,
    34  	Versions:           []protocol.VersionNumber{protocol.VersionTLS},
    35  }
    36  
    37  var dialAddr = quic.DialAddrEarly
    38  
    39  type roundTripperOpts struct {
    40  	DisableCompression bool
    41  	EnableDatagram     bool
    42  	MaxHeaderBytes     int64
    43  }
    44  
    45  // client is a HTTP3 client doing requests
    46  type client struct {
    47  	tlsConf *tls.Config
    48  	config  *quic.Config
    49  	opts    *roundTripperOpts
    50  
    51  	dialOnce     sync.Once
    52  	dialer       func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
    53  	handshakeErr error
    54  
    55  	requestWriter *requestWriter
    56  
    57  	decoder *qpack.Decoder
    58  
    59  	hostname string
    60  
    61  	// [Psiphon]
    62  	// Enable Close to be called concurrently with dial.
    63  	sessionMutex sync.Mutex
    64  	closed       bool
    65  	session      quic.EarlySession
    66  
    67  	logger utils.Logger
    68  }
    69  
    70  func newClient(
    71  	hostname string,
    72  	tlsConf *tls.Config,
    73  	opts *roundTripperOpts,
    74  	quicConfig *quic.Config,
    75  	dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
    76  ) (*client, error) {
    77  	if quicConfig == nil {
    78  		quicConfig = defaultQuicConfig.Clone()
    79  	} else if len(quicConfig.Versions) == 0 {
    80  		quicConfig = quicConfig.Clone()
    81  		quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
    82  	}
    83  	if len(quicConfig.Versions) != 1 {
    84  		return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
    85  	}
    86  
    87  	quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
    88  	quicConfig.EnableDatagrams = opts.EnableDatagram
    89  
    90  	logger := utils.DefaultLogger.WithPrefix("h3 client")
    91  
    92  	if tlsConf == nil {
    93  		tlsConf = &tls.Config{}
    94  	} else {
    95  		tlsConf = tlsConf.Clone()
    96  	}
    97  	// Replace existing ALPNs by H3
    98  	tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])}
    99  
   100  	return &client{
   101  		hostname:      authorityAddr("https", hostname),
   102  		tlsConf:       tlsConf,
   103  		requestWriter: newRequestWriter(logger),
   104  		decoder:       qpack.NewDecoder(func(hf qpack.HeaderField) {}),
   105  		config:        quicConfig,
   106  		opts:          opts,
   107  		dialer:        dialer,
   108  		logger:        logger,
   109  	}, nil
   110  }
   111  
   112  func (c *client) dial() error {
   113  	var err error
   114  	var session quic.EarlySession
   115  	if c.dialer != nil {
   116  		session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
   117  	} else {
   118  		session, err = dialAddr(c.hostname, c.tlsConf, c.config)
   119  	}
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	// [Psiphon]
   125  	c.sessionMutex.Lock()
   126  	if c.closed {
   127  		session.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
   128  		err = errors.New("closed while dialing")
   129  	} else {
   130  		c.session = session
   131  	}
   132  	c.sessionMutex.Unlock()
   133  
   134  	// send the SETTINGs frame, using 0-RTT data, if possible
   135  	go func() {
   136  		if err := c.setupSession(); err != nil {
   137  			c.logger.Debugf("Setting up session failed: %s", err)
   138  			c.session.CloseWithError(quic.ApplicationErrorCode(errorInternalError), "")
   139  		}
   140  	}()
   141  
   142  	go c.handleUnidirectionalStreams()
   143  	return nil
   144  }
   145  
   146  func (c *client) setupSession() error {
   147  	// open the control stream
   148  	str, err := c.session.OpenUniStream()
   149  	if err != nil {
   150  		return err
   151  	}
   152  	buf := &bytes.Buffer{}
   153  	quicvarint.Write(buf, streamTypeControlStream)
   154  	// send the SETTINGS frame
   155  	(&settingsFrame{Datagram: c.opts.EnableDatagram}).Write(buf)
   156  	_, err = str.Write(buf.Bytes())
   157  	return err
   158  }
   159  
   160  func (c *client) handleUnidirectionalStreams() {
   161  	for {
   162  		str, err := c.session.AcceptUniStream(context.Background())
   163  		if err != nil {
   164  			c.logger.Debugf("accepting unidirectional stream failed: %s", err)
   165  			return
   166  		}
   167  
   168  		go func() {
   169  			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
   170  			if err != nil {
   171  				c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
   172  				return
   173  			}
   174  			// We're only interested in the control stream here.
   175  			switch streamType {
   176  			case streamTypeControlStream:
   177  			case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream:
   178  				// Our QPACK implementation doesn't use the dynamic table yet.
   179  				// TODO: check that only one stream of each type is opened.
   180  				return
   181  			case streamTypePushStream:
   182  				// We never increased the Push ID, so we don't expect any push streams.
   183  				c.session.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
   184  				return
   185  			default:
   186  				str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
   187  				return
   188  			}
   189  			f, err := parseNextFrame(str)
   190  			if err != nil {
   191  				c.session.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "")
   192  				return
   193  			}
   194  			sf, ok := f.(*settingsFrame)
   195  			if !ok {
   196  				c.session.CloseWithError(quic.ApplicationErrorCode(errorMissingSettings), "")
   197  				return
   198  			}
   199  			if !sf.Datagram {
   200  				return
   201  			}
   202  			// If datagram support was enabled on our side as well as on the server side,
   203  			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
   204  			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
   205  			if c.opts.EnableDatagram && !c.session.ConnectionState().SupportsDatagrams {
   206  				c.session.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
   207  			}
   208  		}()
   209  	}
   210  }
   211  
   212  func (c *client) Close() error {
   213  
   214  	// [Psiphon]
   215  	c.sessionMutex.Lock()
   216  	session := c.session
   217  	c.closed = true
   218  	c.sessionMutex.Unlock()
   219  
   220  	if session == nil {
   221  		return nil
   222  	}
   223  	return session.CloseWithError(quic.ApplicationErrorCode(errorNoError), "")
   224  }
   225  
   226  func (c *client) maxHeaderBytes() uint64 {
   227  	if c.opts.MaxHeaderBytes <= 0 {
   228  		return defaultMaxResponseHeaderBytes
   229  	}
   230  	return uint64(c.opts.MaxHeaderBytes)
   231  }
   232  
   233  // RoundTrip executes a request and returns a response
   234  func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
   235  	if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
   236  		return nil, fmt.Errorf("http3 client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
   237  	}
   238  
   239  	c.dialOnce.Do(func() {
   240  		c.handshakeErr = c.dial()
   241  	})
   242  
   243  	if c.handshakeErr != nil {
   244  		return nil, c.handshakeErr
   245  	}
   246  
   247  	// Immediately send out this request, if this is a 0-RTT request.
   248  	if req.Method == MethodGet0RTT {
   249  		req.Method = http.MethodGet
   250  	} else {
   251  		// wait for the handshake to complete
   252  		select {
   253  		case <-c.session.HandshakeComplete().Done():
   254  		case <-req.Context().Done():
   255  			return nil, req.Context().Err()
   256  		}
   257  	}
   258  
   259  	str, err := c.session.OpenStreamSync(req.Context())
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  
   264  	// Request Cancellation:
   265  	// This go routine keeps running even after RoundTrip() returns.
   266  	// It is shut down when the application is done processing the body.
   267  	reqDone := make(chan struct{})
   268  	go func() {
   269  		select {
   270  		case <-req.Context().Done():
   271  			str.CancelWrite(quic.StreamErrorCode(errorRequestCanceled))
   272  			str.CancelRead(quic.StreamErrorCode(errorRequestCanceled))
   273  		case <-reqDone:
   274  		}
   275  	}()
   276  
   277  	rsp, rerr := c.doRequest(req, str, reqDone)
   278  	if rerr.err != nil { // if any error occurred
   279  		close(reqDone)
   280  		if rerr.streamErr != 0 { // if it was a stream error
   281  			str.CancelWrite(quic.StreamErrorCode(rerr.streamErr))
   282  		}
   283  		if rerr.connErr != 0 { // if it was a connection error
   284  			var reason string
   285  			if rerr.err != nil {
   286  				reason = rerr.err.Error()
   287  			}
   288  			c.session.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
   289  		}
   290  	}
   291  	return rsp, rerr.err
   292  }
   293  
   294  func (c *client) doRequest(
   295  	req *http.Request,
   296  	str quic.Stream,
   297  	reqDone chan struct{},
   298  ) (*http.Response, requestError) {
   299  	var requestGzip bool
   300  	if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
   301  		requestGzip = true
   302  	}
   303  	if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
   304  		return nil, newStreamError(errorInternalError, err)
   305  	}
   306  
   307  	frame, err := parseNextFrame(str)
   308  	if err != nil {
   309  		return nil, newStreamError(errorFrameError, err)
   310  	}
   311  	hf, ok := frame.(*headersFrame)
   312  	if !ok {
   313  		return nil, newConnError(errorFrameUnexpected, errors.New("expected first frame to be a HEADERS frame"))
   314  	}
   315  	if hf.Length > c.maxHeaderBytes() {
   316  		return nil, newStreamError(errorFrameError, fmt.Errorf("HEADERS frame too large: %d bytes (max: %d)", hf.Length, c.maxHeaderBytes()))
   317  	}
   318  	headerBlock := make([]byte, hf.Length)
   319  	if _, err := io.ReadFull(str, headerBlock); err != nil {
   320  		return nil, newStreamError(errorRequestIncomplete, err)
   321  	}
   322  	hfs, err := c.decoder.DecodeFull(headerBlock)
   323  	if err != nil {
   324  		// TODO: use the right error code
   325  		return nil, newConnError(errorGeneralProtocolError, err)
   326  	}
   327  
   328  	connState := qtls.ToTLSConnectionState(c.session.ConnectionState().TLS)
   329  	res := &http.Response{
   330  		Proto:      "HTTP/3",
   331  		ProtoMajor: 3,
   332  		Header:     http.Header{},
   333  		TLS:        &connState,
   334  	}
   335  	for _, hf := range hfs {
   336  		switch hf.Name {
   337  		case ":status":
   338  			status, err := strconv.Atoi(hf.Value)
   339  			if err != nil {
   340  				return nil, newStreamError(errorGeneralProtocolError, errors.New("malformed non-numeric status pseudo header"))
   341  			}
   342  			res.StatusCode = status
   343  			res.Status = hf.Value + " " + http.StatusText(status)
   344  		default:
   345  			res.Header.Add(hf.Name, hf.Value)
   346  		}
   347  	}
   348  	respBody := newResponseBody(str, reqDone, func() {
   349  		c.session.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "")
   350  	})
   351  
   352  	// Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2.
   353  	_, hasTransferEncoding := res.Header["Transfer-Encoding"]
   354  	isInformational := res.StatusCode >= 100 && res.StatusCode < 200
   355  	isNoContent := res.StatusCode == 204
   356  	isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300
   357  	if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect {
   358  		res.ContentLength = -1
   359  		if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 {
   360  			if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
   361  				res.ContentLength = clen64
   362  			}
   363  		}
   364  	}
   365  
   366  	if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
   367  		res.Header.Del("Content-Encoding")
   368  		res.Header.Del("Content-Length")
   369  		res.ContentLength = -1
   370  		res.Body = newGzipReader(respBody)
   371  		res.Uncompressed = true
   372  	} else {
   373  		res.Body = respBody
   374  	}
   375  
   376  	return res, requestError{}
   377  }