github.com/tumi8/quic-go@v0.37.4-tum/http3/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  
    14  	"github.com/tumi8/quic-go"
    15  
    16  	"sync/atomic"
    17  
    18  
    19  	"golang.org/x/net/http/httpguts"
    20  )
    21  
    22  type roundTripCloser interface {
    23  	RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
    24  	HandshakeComplete() bool
    25  	io.Closer
    26  }
    27  
    28  type roundTripCloserWithCount struct {
    29  	roundTripCloser
    30  	useCount atomic.Int64
    31  }
    32  
    33  // RoundTripper implements the http.RoundTripper interface
    34  type RoundTripper struct {
    35  	mutex sync.Mutex
    36  
    37  	// DisableCompression, if true, prevents the Transport from
    38  	// requesting compression with an "Accept-Encoding: gzip"
    39  	// request header when the Request contains no existing
    40  	// Accept-Encoding value. If the Transport requests gzip on
    41  	// its own and gets a gzipped response, it's transparently
    42  	// decoded in the Response.Body. However, if the user
    43  	// explicitly requested gzip it is not automatically
    44  	// uncompressed.
    45  	DisableCompression bool
    46  
    47  	// TLSClientConfig specifies the TLS configuration to use with
    48  	// tls.Client. If nil, the default configuration is used.
    49  	TLSClientConfig *tls.Config
    50  
    51  	// QuicConfig is the quic.Config used for dialing new connections.
    52  	// If nil, reasonable default values will be used.
    53  	QuicConfig *quic.Config
    54  
    55  	// Enable support for HTTP/3 datagrams.
    56  	// If set to true, QuicConfig.EnableDatagram will be set.
    57  	// See https://www.ietf.org/archive/id/draft-schinazi-masque-h3-datagram-02.html.
    58  	EnableDatagrams bool
    59  
    60  	// Additional HTTP/3 settings.
    61  	// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
    62  	AdditionalSettings map[uint64]uint64
    63  
    64  	// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
    65  	// It is called right after parsing the frame type.
    66  	// If parsing the frame type fails, the error is passed to the callback.
    67  	// In that case, the frame type will not be set.
    68  	// Callers can either ignore the frame and return control of the stream back to HTTP/3
    69  	// (by returning hijacked false).
    70  	// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
    71  	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
    72  
    73  	// When set, this callback is called for unknown unidirectional stream of unknown stream type.
    74  	// If parsing the stream type fails, the error is passed to the callback.
    75  	// In that case, the stream type will not be set.
    76  	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
    77  
    78  	// Dial specifies an optional dial function for creating QUIC
    79  	// connections for requests.
    80  	// If Dial is nil, a UDPConn will be created at the first request
    81  	// and will be reused for subsequent connections to other servers.
    82  	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)
    83  
    84  	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
    85  	// allowed in the server's response header.
    86  	// Zero means to use a default limit.
    87  	MaxResponseHeaderBytes int64
    88  
    89  	newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
    90  	clients   map[string]*roundTripCloserWithCount
    91  	transport *quic.Transport
    92  }
    93  
    94  // RoundTripOpt are options for the Transport.RoundTripOpt method.
    95  type RoundTripOpt struct {
    96  	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
    97  	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
    98  	OnlyCachedConn bool
    99  	// DontCloseRequestStream controls whether the request stream is closed after sending the request.
   100  	// If set, context cancellations have no effect after the response headers are received.
   101  	DontCloseRequestStream bool
   102  }
   103  
   104  var (
   105  	_ http.RoundTripper = &RoundTripper{}
   106  	_ io.Closer         = &RoundTripper{}
   107  )
   108  
   109  // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
   110  var ErrNoCachedConn = errors.New("http3: no cached connection was available")
   111  
   112  // RoundTripOpt is like RoundTrip, but takes options.
   113  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   114  	if req.URL == nil {
   115  		closeRequestBody(req)
   116  		return nil, errors.New("http3: nil Request.URL")
   117  	}
   118  	if req.URL.Scheme != "https" {
   119  		closeRequestBody(req)
   120  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   121  	}
   122  	if req.URL.Host == "" {
   123  		closeRequestBody(req)
   124  		return nil, errors.New("http3: no Host in request URL")
   125  	}
   126  	if req.Header == nil {
   127  		closeRequestBody(req)
   128  		return nil, errors.New("http3: nil Request.Header")
   129  	}
   130  	for k, vv := range req.Header {
   131  		if !httpguts.ValidHeaderFieldName(k) {
   132  			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
   133  		}
   134  		for _, v := range vv {
   135  			if !httpguts.ValidHeaderFieldValue(v) {
   136  				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   137  			}
   138  		}
   139  	}
   140  
   141  	if req.Method != "" && !validMethod(req.Method) {
   142  		closeRequestBody(req)
   143  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   144  	}
   145  
   146  	hostname := authorityAddr("https", hostnameFromRequest(req))
   147  	cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	defer cl.useCount.Add(-1)
   152  	rsp, err := cl.RoundTripOpt(req, opt)
   153  	if err != nil {
   154  		r.removeClient(hostname)
   155  		if isReused {
   156  			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   157  				return r.RoundTripOpt(req, opt)
   158  			}
   159  		}
   160  	}
   161  	return rsp, err
   162  }
   163  
   164  // RoundTrip does a round trip.
   165  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   166  	return r.RoundTripOpt(req, RoundTripOpt{})
   167  }
   168  
   169  func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
   170  	r.mutex.Lock()
   171  	defer r.mutex.Unlock()
   172  
   173  	if r.clients == nil {
   174  		r.clients = make(map[string]*roundTripCloserWithCount)
   175  	}
   176  
   177  	client, ok := r.clients[hostname]
   178  	if !ok {
   179  		if onlyCached {
   180  			return nil, false, ErrNoCachedConn
   181  		}
   182  		var err error
   183  		newCl := newClient
   184  		if r.newClient != nil {
   185  			newCl = r.newClient
   186  		}
   187  		dial := r.Dial
   188  		if dial == nil {
   189  			if r.transport == nil {
   190  				udpConn, err := net.ListenUDP("udp", nil)
   191  				if err != nil {
   192  					return nil, false, err
   193  				}
   194  				r.transport = &quic.Transport{Conn: udpConn}
   195  			}
   196  			dial = r.makeDialer()
   197  		}
   198  		c, err := newCl(
   199  			hostname,
   200  			r.TLSClientConfig,
   201  			&roundTripperOpts{
   202  				EnableDatagram:     r.EnableDatagrams,
   203  				DisableCompression: r.DisableCompression,
   204  				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
   205  				StreamHijacker:     r.StreamHijacker,
   206  				UniStreamHijacker:  r.UniStreamHijacker,
   207  			},
   208  			r.QuicConfig,
   209  			dial,
   210  		)
   211  		if err != nil {
   212  			return nil, false, err
   213  		}
   214  		client = &roundTripCloserWithCount{roundTripCloser: c}
   215  		r.clients[hostname] = client
   216  	} else if client.HandshakeComplete() {
   217  		isReused = true
   218  	}
   219  	client.useCount.Add(1)
   220  	return client, isReused, nil
   221  }
   222  
   223  func (r *RoundTripper) removeClient(hostname string) {
   224  	r.mutex.Lock()
   225  	defer r.mutex.Unlock()
   226  	if r.clients == nil {
   227  		return
   228  	}
   229  	delete(r.clients, hostname)
   230  }
   231  
   232  // Close closes the QUIC connections that this RoundTripper has used.
   233  // It also closes the underlying UDPConn if it is not nil.
   234  func (r *RoundTripper) Close() error {
   235  	r.mutex.Lock()
   236  	defer r.mutex.Unlock()
   237  	for _, client := range r.clients {
   238  		if err := client.Close(); err != nil {
   239  			return err
   240  		}
   241  	}
   242  	r.clients = nil
   243  	if r.transport != nil {
   244  		if err := r.transport.Close(); err != nil {
   245  			return err
   246  		}
   247  		if err := r.transport.Conn.Close(); err != nil {
   248  			return err
   249  		}
   250  		r.transport = nil
   251  	}
   252  	return nil
   253  }
   254  
   255  func closeRequestBody(req *http.Request) {
   256  	if req.Body != nil {
   257  		req.Body.Close()
   258  	}
   259  }
   260  
   261  func validMethod(method string) bool {
   262  	/*
   263  				     Method         = "OPTIONS"                ; Section 9.2
   264  		   		                    | "GET"                    ; Section 9.3
   265  		   		                    | "HEAD"                   ; Section 9.4
   266  		   		                    | "POST"                   ; Section 9.5
   267  		   		                    | "PUT"                    ; Section 9.6
   268  		   		                    | "DELETE"                 ; Section 9.7
   269  		   		                    | "TRACE"                  ; Section 9.8
   270  		   		                    | "CONNECT"                ; Section 9.9
   271  		   		                    | extension-method
   272  		   		   extension-method = token
   273  		   		     token          = 1*<any CHAR except CTLs or separators>
   274  	*/
   275  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   276  }
   277  
   278  // copied from net/http/http.go
   279  func isNotToken(r rune) bool {
   280  	return !httpguts.IsTokenRune(r)
   281  }
   282  
   283  // makeDialer makes a QUIC dialer using r.udpConn.
   284  func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   285  	return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   286  		udpAddr, err := net.ResolveUDPAddr("udp", addr)
   287  		if err != nil {
   288  			return nil, err
   289  		}
   290  		return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
   291  	}
   292  }
   293  
   294  func (r *RoundTripper) CloseIdleConnections() {
   295  	r.mutex.Lock()
   296  	defer r.mutex.Unlock()
   297  	for hostname, client := range r.clients {
   298  		if client.useCount.Load() == 0 {
   299  			client.Close()
   300  			delete(r.clients, hostname)
   301  		}
   302  	}
   303  }