github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/h2quic/roundtrip.go (about)

     1  package h2quic
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"sync"
    11  
    12  	quic "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go"
    13  
    14  	"golang.org/x/net/http/httpguts"
    15  )
    16  
    17  type roundTripCloser interface {
    18  	http.RoundTripper
    19  	io.Closer
    20  }
    21  
    22  // RoundTripper implements the http.RoundTripper interface
    23  type RoundTripper struct {
    24  	mutex sync.Mutex
    25  
    26  	// DisableCompression, if true, prevents the Transport from
    27  	// requesting compression with an "Accept-Encoding: gzip"
    28  	// request header when the Request contains no existing
    29  	// Accept-Encoding value. If the Transport requests gzip on
    30  	// its own and gets a gzipped response, it's transparently
    31  	// decoded in the Response.Body. However, if the user
    32  	// explicitly requested gzip it is not automatically
    33  	// uncompressed.
    34  	DisableCompression bool
    35  
    36  	// TLSClientConfig specifies the TLS configuration to use with
    37  	// tls.Client. If nil, the default configuration is used.
    38  	TLSClientConfig *tls.Config
    39  
    40  	// QuicConfig is the quic.Config used for dialing new connections.
    41  	// If nil, reasonable default values will be used.
    42  	QuicConfig *quic.Config
    43  
    44  	// Dial specifies an optional dial function for creating QUIC
    45  	// connections for requests.
    46  	// If Dial is nil, quic.DialAddr will be used.
    47  	Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
    48  
    49  	clients map[string]roundTripCloser
    50  }
    51  
    52  // RoundTripOpt are options for the Transport.RoundTripOpt method.
    53  type RoundTripOpt struct {
    54  	// OnlyCachedConn controls whether the RoundTripper may
    55  	// create a new QUIC connection. If set true and
    56  	// no cached connection is available, RoundTrip
    57  	// will return ErrNoCachedConn.
    58  	OnlyCachedConn bool
    59  }
    60  
    61  var _ roundTripCloser = &RoundTripper{}
    62  
    63  // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
    64  var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
    65  
    66  // RoundTripOpt is like RoundTrip, but takes options.
    67  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
    68  	if req.URL == nil {
    69  		closeRequestBody(req)
    70  		return nil, errors.New("quic: nil Request.URL")
    71  	}
    72  	if req.URL.Host == "" {
    73  		closeRequestBody(req)
    74  		return nil, errors.New("quic: no Host in request URL")
    75  	}
    76  	if req.Header == nil {
    77  		closeRequestBody(req)
    78  		return nil, errors.New("quic: nil Request.Header")
    79  	}
    80  
    81  	if req.URL.Scheme == "https" {
    82  		for k, vv := range req.Header {
    83  			if !httpguts.ValidHeaderFieldName(k) {
    84  				return nil, fmt.Errorf("quic: invalid http header field name %q", k)
    85  			}
    86  			for _, v := range vv {
    87  				if !httpguts.ValidHeaderFieldValue(v) {
    88  					return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
    89  				}
    90  			}
    91  		}
    92  	} else {
    93  		closeRequestBody(req)
    94  		return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
    95  	}
    96  
    97  	if req.Method != "" && !validMethod(req.Method) {
    98  		closeRequestBody(req)
    99  		return nil, fmt.Errorf("quic: invalid method %q", req.Method)
   100  	}
   101  
   102  	hostname := authorityAddr("https", hostnameFromRequest(req))
   103  	cl, err := r.getClient(hostname, opt.OnlyCachedConn)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	return cl.RoundTrip(req)
   108  }
   109  
   110  // RoundTrip does a round trip.
   111  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   112  	return r.RoundTripOpt(req, RoundTripOpt{})
   113  }
   114  
   115  func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
   116  	r.mutex.Lock()
   117  	defer r.mutex.Unlock()
   118  
   119  	if r.clients == nil {
   120  		r.clients = make(map[string]roundTripCloser)
   121  	}
   122  
   123  	client, ok := r.clients[hostname]
   124  	if !ok {
   125  		if onlyCached {
   126  			return nil, ErrNoCachedConn
   127  		}
   128  		client = newClient(
   129  			hostname,
   130  			r.TLSClientConfig,
   131  			&roundTripperOpts{DisableCompression: r.DisableCompression},
   132  			r.QuicConfig,
   133  			r.Dial,
   134  		)
   135  		r.clients[hostname] = client
   136  	}
   137  	return client, nil
   138  }
   139  
   140  // Close closes the QUIC connections that this RoundTripper has used
   141  func (r *RoundTripper) Close() error {
   142  	r.mutex.Lock()
   143  	defer r.mutex.Unlock()
   144  	for _, client := range r.clients {
   145  		if err := client.Close(); err != nil {
   146  			return err
   147  		}
   148  	}
   149  	r.clients = nil
   150  	return nil
   151  }
   152  
   153  func closeRequestBody(req *http.Request) {
   154  	if req.Body != nil {
   155  		req.Body.Close()
   156  	}
   157  }
   158  
   159  func validMethod(method string) bool {
   160  	/*
   161  				     Method         = "OPTIONS"                ; Section 9.2
   162  		   		                    | "GET"                    ; Section 9.3
   163  		   		                    | "HEAD"                   ; Section 9.4
   164  		   		                    | "POST"                   ; Section 9.5
   165  		   		                    | "PUT"                    ; Section 9.6
   166  		   		                    | "DELETE"                 ; Section 9.7
   167  		   		                    | "TRACE"                  ; Section 9.8
   168  		   		                    | "CONNECT"                ; Section 9.9
   169  		   		                    | extension-method
   170  		   		   extension-method = token
   171  		   		     token          = 1*<any CHAR except CTLs or separators>
   172  	*/
   173  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   174  }
   175  
   176  // copied from net/http/http.go
   177  func isNotToken(r rune) bool {
   178  	return !httpguts.IsTokenRune(r)
   179  }