github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  
    15  	"golang.org/x/net/http/httpguts"
    16  
    17  	"github.com/danielpfeifer02/quic-go-prio-packs"
    18  )
    19  
// roundTripCloser is the interface a per-host client must satisfy so that
// RoundTripper can cache it, send requests through it, and close it.
type roundTripCloser interface {
	// RoundTripOpt sends a single request using the given options.
	RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
	// HandshakeComplete is consulted by getClient: a cached client only counts
	// as "reused" (enabling the timeout retry in RoundTripper.RoundTripOpt)
	// when this returns true. Presumably reports whether the QUIC handshake
	// has finished — confirm against the implementation.
	HandshakeComplete() bool
	io.Closer
}

// roundTripCloserWithCount pairs a cached client with the number of requests
// currently using it.
type roundTripCloserWithCount struct {
	roundTripCloser
	// useCount is incremented by getClient for every request handed the
	// client and decremented by RoundTripper.RoundTripOpt when that request
	// returns. CloseIdleConnections only closes clients whose count is zero.
	useCount atomic.Int64
}
    30  
// RoundTripper implements the http.RoundTripper interface for HTTP/3.
//
// It caches one client per request authority and reuses it for later
// requests to the same authority. The zero value is ready to use: the client
// cache and, when Dial is nil, a shared UDP socket are created lazily on the
// first request. Call Close to release them.
type RoundTripper struct {
	// mutex is held by getClient, removeClient, Close and
	// CloseIdleConnections while they read or modify clients and transport.
	mutex sync.Mutex

	// DisableCompression, if true, prevents the Transport from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the Transport requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QuicConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QuicConfig *quic.Config

	// Enable support for HTTP/3 datagrams (RFC 9297).
	// If a QuicConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
	EnableDatagrams bool

	// Additional HTTP/3 settings.
	// It is invalid to specify any settings defined by HTTP/3 (RFC 9114) or by HTTP Datagrams (RFC 9297).
	AdditionalSettings map[uint64]uint64

	// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
	// It is called right after parsing the frame type.
	// If parsing the frame type fails, the error is passed to the callback.
	// In that case, the frame type will not be set.
	// Callers can either ignore the frame and return control of the stream back to HTTP/3
	// (by returning hijacked false).
	// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

	// When set, this callback is called for unidirectional streams of an unknown stream type.
	// If parsing the stream type fails, the error is passed to the callback.
	// In that case, the stream type will not be set.
	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, a UDPConn will be created at the first request
	// and will be reused for subsequent connections to other servers.
	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// newClient, if non-nil, is used by getClient instead of the package-level
	// newClient constructor.
	newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
	// clients maps the authority key computed in RoundTripOpt to its cached
	// client. It is allocated lazily by getClient and reset to nil by Close.
	clients map[string]*roundTripCloserWithCount
	// transport wraps the UDP socket that getClient creates when Dial is nil.
	// Close closes it, closes its socket, and resets it to nil.
	transport *quic.Transport
}
    90  
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
	OnlyCachedConn bool
	// DontCloseRequestStream controls whether the request stream is closed after sending the request.
	// If set, context cancellations have no effect after the response headers are received.
	DontCloseRequestStream bool
}
   100  
   101  var (
   102  	_ http.RoundTripper = &RoundTripper{}
   103  	_ io.Closer         = &RoundTripper{}
   104  )
   105  
   106  // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
   107  var ErrNoCachedConn = errors.New("http3: no cached connection was available")
   108  
   109  // RoundTripOpt is like RoundTrip, but takes options.
   110  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   111  	if req.URL == nil {
   112  		closeRequestBody(req)
   113  		return nil, errors.New("http3: nil Request.URL")
   114  	}
   115  	if req.URL.Scheme != "https" {
   116  		closeRequestBody(req)
   117  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   118  	}
   119  	if req.URL.Host == "" {
   120  		closeRequestBody(req)
   121  		return nil, errors.New("http3: no Host in request URL")
   122  	}
   123  	if req.Header == nil {
   124  		closeRequestBody(req)
   125  		return nil, errors.New("http3: nil Request.Header")
   126  	}
   127  	for k, vv := range req.Header {
   128  		if !httpguts.ValidHeaderFieldName(k) {
   129  			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
   130  		}
   131  		for _, v := range vv {
   132  			if !httpguts.ValidHeaderFieldValue(v) {
   133  				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   134  			}
   135  		}
   136  	}
   137  
   138  	if req.Method != "" && !validMethod(req.Method) {
   139  		closeRequestBody(req)
   140  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   141  	}
   142  
   143  	hostname := authorityAddr("https", hostnameFromRequest(req))
   144  	cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	defer cl.useCount.Add(-1)
   149  	rsp, err := cl.RoundTripOpt(req, opt)
   150  	if err != nil {
   151  		r.removeClient(hostname)
   152  		if isReused {
   153  			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   154  				return r.RoundTripOpt(req, opt)
   155  			}
   156  		}
   157  	}
   158  	return rsp, err
   159  }
   160  
   161  // RoundTrip does a round trip.
   162  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   163  	return r.RoundTripOpt(req, RoundTripOpt{})
   164  }
   165  
   166  func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
   167  	r.mutex.Lock()
   168  	defer r.mutex.Unlock()
   169  
   170  	if r.clients == nil {
   171  		r.clients = make(map[string]*roundTripCloserWithCount)
   172  	}
   173  
   174  	client, ok := r.clients[hostname]
   175  	if !ok {
   176  		if onlyCached {
   177  			return nil, false, ErrNoCachedConn
   178  		}
   179  		var err error
   180  		newCl := newClient
   181  		if r.newClient != nil {
   182  			newCl = r.newClient
   183  		}
   184  		dial := r.Dial
   185  		if dial == nil {
   186  			if r.transport == nil {
   187  				udpConn, err := net.ListenUDP("udp", nil)
   188  				if err != nil {
   189  					return nil, false, err
   190  				}
   191  				r.transport = &quic.Transport{Conn: udpConn}
   192  			}
   193  			dial = r.makeDialer()
   194  		}
   195  		c, err := newCl(
   196  			hostname,
   197  			r.TLSClientConfig,
   198  			&roundTripperOpts{
   199  				EnableDatagram:     r.EnableDatagrams,
   200  				DisableCompression: r.DisableCompression,
   201  				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
   202  				StreamHijacker:     r.StreamHijacker,
   203  				UniStreamHijacker:  r.UniStreamHijacker,
   204  				AdditionalSettings: r.AdditionalSettings,
   205  			},
   206  			r.QuicConfig,
   207  			dial,
   208  		)
   209  		if err != nil {
   210  			return nil, false, err
   211  		}
   212  		client = &roundTripCloserWithCount{roundTripCloser: c}
   213  		r.clients[hostname] = client
   214  	} else if client.HandshakeComplete() {
   215  		isReused = true
   216  	}
   217  	client.useCount.Add(1)
   218  	return client, isReused, nil
   219  }
   220  
   221  func (r *RoundTripper) removeClient(hostname string) {
   222  	r.mutex.Lock()
   223  	defer r.mutex.Unlock()
   224  	if r.clients == nil {
   225  		return
   226  	}
   227  	delete(r.clients, hostname)
   228  }
   229  
   230  // Close closes the QUIC connections that this RoundTripper has used.
   231  // It also closes the underlying UDPConn if it is not nil.
   232  func (r *RoundTripper) Close() error {
   233  	r.mutex.Lock()
   234  	defer r.mutex.Unlock()
   235  	for _, client := range r.clients {
   236  		if err := client.Close(); err != nil {
   237  			return err
   238  		}
   239  	}
   240  	r.clients = nil
   241  	if r.transport != nil {
   242  		if err := r.transport.Close(); err != nil {
   243  			return err
   244  		}
   245  		if err := r.transport.Conn.Close(); err != nil {
   246  			return err
   247  		}
   248  		r.transport = nil
   249  	}
   250  	return nil
   251  }
   252  
   253  func closeRequestBody(req *http.Request) {
   254  	if req.Body != nil {
   255  		req.Body.Close()
   256  	}
   257  }
   258  
   259  func validMethod(method string) bool {
   260  	/*
   261  				     Method         = "OPTIONS"                ; Section 9.2
   262  		   		                    | "GET"                    ; Section 9.3
   263  		   		                    | "HEAD"                   ; Section 9.4
   264  		   		                    | "POST"                   ; Section 9.5
   265  		   		                    | "PUT"                    ; Section 9.6
   266  		   		                    | "DELETE"                 ; Section 9.7
   267  		   		                    | "TRACE"                  ; Section 9.8
   268  		   		                    | "CONNECT"                ; Section 9.9
   269  		   		                    | extension-method
   270  		   		   extension-method = token
   271  		   		     token          = 1*<any CHAR except CTLs or separators>
   272  	*/
   273  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   274  }
   275  
   276  // copied from net/http/http.go
   277  func isNotToken(r rune) bool {
   278  	return !httpguts.IsTokenRune(r)
   279  }
   280  
   281  // makeDialer makes a QUIC dialer using r.udpConn.
   282  func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   283  	return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   284  		udpAddr, err := net.ResolveUDPAddr("udp", addr)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  		return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
   289  	}
   290  }
   291  
   292  func (r *RoundTripper) CloseIdleConnections() {
   293  	r.mutex.Lock()
   294  	defer r.mutex.Unlock()
   295  	for hostname, client := range r.clients {
   296  		if client.useCount.Load() == 0 {
   297  			client.Close()
   298  			delete(r.clients, hostname)
   299  		}
   300  	}
   301  }