github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/http3/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"sync"
    11  
    12  	quic "github.com/ooni/psiphon/tunnel-core/oovendor/quic-go"
    13  
    14  	"golang.org/x/net/http/httpguts"
    15  )
    16  
    17  type roundTripCloser interface {
    18  	http.RoundTripper
    19  	io.Closer
    20  }
    21  
// RoundTripper implements the http.RoundTripper interface for HTTP/3.
//
// It keeps a cache of clients keyed by the request's authority address,
// creating a client the first time an authority is used and reusing it for
// later requests. Close closes and discards every cached client.
//
// The zero value is ready to use. A RoundTripper contains a sync.Mutex and
// must therefore not be copied after first use.
type RoundTripper struct {
	// mutex guards clients.
	mutex sync.Mutex

	// DisableCompression, if true, prevents the RoundTripper from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the RoundTripper requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QuicConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QuicConfig *quic.Config

	// Enable support for HTTP/3 datagrams.
	// If set to true, QuicConfig.EnableDatagram will be set.
	// See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html.
	EnableDatagrams bool

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, quic.DialAddrEarly will be used.
	Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// clients caches one client per authority address. It is lazily
	// allocated by getClient and reset to nil by Close; guarded by mutex.
	clients map[string]roundTripCloser
}
    61  
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
	OnlyCachedConn bool
	// SkipSchemeCheck controls whether we check if the scheme is https.
	// This allows the use of different schemes, e.g. masque://target.example.com:443/.
	// Note that header field names/values are only validated for https
	// requests, so requests admitted via this option skip that validation.
	SkipSchemeCheck bool
}
    71  
    72  var _ roundTripCloser = &RoundTripper{}
    73  
    74  // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
    75  var ErrNoCachedConn = errors.New("http3: no cached connection was available")
    76  
    77  // RoundTripOpt is like RoundTrip, but takes options.
    78  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
    79  	if req.URL == nil {
    80  		closeRequestBody(req)
    81  		return nil, errors.New("http3: nil Request.URL")
    82  	}
    83  	if req.URL.Host == "" {
    84  		closeRequestBody(req)
    85  		return nil, errors.New("http3: no Host in request URL")
    86  	}
    87  	if req.Header == nil {
    88  		closeRequestBody(req)
    89  		return nil, errors.New("http3: nil Request.Header")
    90  	}
    91  
    92  	if req.URL.Scheme == "https" {
    93  		for k, vv := range req.Header {
    94  			if !httpguts.ValidHeaderFieldName(k) {
    95  				return nil, fmt.Errorf("http3: invalid http header field name %q", k)
    96  			}
    97  			for _, v := range vv {
    98  				if !httpguts.ValidHeaderFieldValue(v) {
    99  					return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   100  				}
   101  			}
   102  		}
   103  	} else if !opt.SkipSchemeCheck {
   104  		closeRequestBody(req)
   105  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   106  	}
   107  
   108  	if req.Method != "" && !validMethod(req.Method) {
   109  		closeRequestBody(req)
   110  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   111  	}
   112  
   113  	hostname := authorityAddr("https", hostnameFromRequest(req))
   114  	cl, err := r.getClient(hostname, opt.OnlyCachedConn)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return cl.RoundTrip(req)
   119  }
   120  
   121  // RoundTrip does a round trip.
   122  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   123  	return r.RoundTripOpt(req, RoundTripOpt{})
   124  }
   125  
   126  func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
   127  	r.mutex.Lock()
   128  	defer r.mutex.Unlock()
   129  
   130  	if r.clients == nil {
   131  		r.clients = make(map[string]roundTripCloser)
   132  	}
   133  
   134  	client, ok := r.clients[hostname]
   135  	if !ok {
   136  		if onlyCached {
   137  			return nil, ErrNoCachedConn
   138  		}
   139  		var err error
   140  		client, err = newClient(
   141  			hostname,
   142  			r.TLSClientConfig,
   143  			&roundTripperOpts{
   144  				EnableDatagram:     r.EnableDatagrams,
   145  				DisableCompression: r.DisableCompression,
   146  				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
   147  			},
   148  			r.QuicConfig,
   149  			r.Dial,
   150  		)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  		r.clients[hostname] = client
   155  	}
   156  	return client, nil
   157  }
   158  
   159  // Close closes the QUIC connections that this RoundTripper has used
   160  func (r *RoundTripper) Close() error {
   161  	r.mutex.Lock()
   162  	defer r.mutex.Unlock()
   163  	for _, client := range r.clients {
   164  		if err := client.Close(); err != nil {
   165  			return err
   166  		}
   167  	}
   168  	r.clients = nil
   169  	return nil
   170  }
   171  
   172  func closeRequestBody(req *http.Request) {
   173  	if req.Body != nil {
   174  		req.Body.Close()
   175  	}
   176  }
   177  
   178  func validMethod(method string) bool {
   179  	/*
   180  				     Method         = "OPTIONS"                ; Section 9.2
   181  		   		                    | "GET"                    ; Section 9.3
   182  		   		                    | "HEAD"                   ; Section 9.4
   183  		   		                    | "POST"                   ; Section 9.5
   184  		   		                    | "PUT"                    ; Section 9.6
   185  		   		                    | "DELETE"                 ; Section 9.7
   186  		   		                    | "TRACE"                  ; Section 9.8
   187  		   		                    | "CONNECT"                ; Section 9.9
   188  		   		                    | extension-method
   189  		   		   extension-method = token
   190  		   		     token          = 1*<any CHAR except CTLs or separators>
   191  	*/
   192  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   193  }
   194  
   195  // copied from net/http/http.go
   196  func isNotToken(r rune) bool {
   197  	return !httpguts.IsTokenRune(r)
   198  }