github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  
    15  	"golang.org/x/net/http/httpguts"
    16  
    17  	"github.com/daeuniverse/quic-go"
    18  )
    19  
// Settings are HTTP/3 settings that apply to the underlying connection.
// They describe what a peer announced in its SETTINGS frame; see
// RoundTripOpt.CheckSettings for how a caller can inspect them.
type Settings struct {
	// Support for HTTP/3 datagrams (RFC 9297)
	EnableDatagram bool
	// Extended CONNECT, RFC 9220
	EnableExtendedConnect bool
	// Other settings, defined by the application, keyed by setting identifier.
	Other map[uint64]uint64
}
    29  
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
// The zero value is what RoundTripper.RoundTrip uses.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
	OnlyCachedConn bool
	// DontCloseRequestStream controls whether the request stream is closed after sending the request.
	// If set, context cancellations have no effect after the response headers are received.
	DontCloseRequestStream bool
	// CheckSettings is run before the request is sent to the server.
	// If not yet received, it blocks until the server's SETTINGS frame is received.
	// If an error is returned, the request won't be sent to the server, and the error is returned.
	CheckSettings func(Settings) error
}
    43  
// roundTripCloser is the per-authority client that RoundTripper caches and
// dispatches requests to.
type roundTripCloser interface {
	// RoundTripOpt sends a single request using the given options.
	RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
	// HandshakeComplete is consulted by RoundTripper.getClient to decide
	// whether a cached client counts as "reused" (which enables the
	// timeout retry in RoundTripper.RoundTripOpt).
	HandshakeComplete() bool
	io.Closer
}
    49  
// roundTripCloserWithCount wraps a cached client together with the number
// of RoundTripOpt calls currently using it.
type roundTripCloserWithCount struct {
	roundTripCloser
	// useCount is incremented by getClient and decremented when the
	// corresponding RoundTripOpt call returns; CloseIdleConnections only
	// closes clients whose count is zero.
	useCount atomic.Int64
}
    54  
// RoundTripper implements the http.RoundTripper interface over HTTP/3.
// It caches one client per request authority and reuses it for later
// requests to the same authority.
type RoundTripper struct {
	// mutex guards clients and transport.
	mutex sync.Mutex

	// DisableCompression, if true, prevents the Transport from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the Transport requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QuicConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QuicConfig *quic.Config

	// Enable support for HTTP/3 datagrams (RFC 9297).
	// If a QuicConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
	EnableDatagrams bool

	// Additional HTTP/3 settings.
	// It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft.
	AdditionalSettings map[uint64]uint64

	// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
	// It is called right after parsing the frame type.
	// If parsing the frame type fails, the error is passed to the callback.
	// In that case, the frame type will not be set.
	// Callers can either ignore the frame and return control of the stream back to HTTP/3
	// (by returning hijacked false).
	// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

	// When set, this callback is called for unidirectional streams of an unknown stream type.
	// If parsing the stream type fails, the error is passed to the callback.
	// In that case, the stream type will not be set.
	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, a UDPConn will be created at the first request
	// and will be reused for subsequent connections to other servers.
	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
	// clients caches one client per authority string (as computed in
	// RoundTripOpt); lazily allocated by getClient, reset to nil by Close.
	clients   map[string]*roundTripCloserWithCount
	// transport is created lazily by getClient when Dial is nil and is
	// shared by all clients; Close closes it and its UDP socket.
	transport *quic.Transport
}
   114  
   115  var (
   116  	_ http.RoundTripper = &RoundTripper{}
   117  	_ io.Closer         = &RoundTripper{}
   118  )
   119  
// ErrNoCachedConn is returned when RoundTripOpt.OnlyCachedConn is set and no
// cached connection to the request's authority is available.
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
   122  
   123  // RoundTripOpt is like RoundTrip, but takes options.
   124  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   125  	if req.URL == nil {
   126  		closeRequestBody(req)
   127  		return nil, errors.New("http3: nil Request.URL")
   128  	}
   129  	if req.URL.Scheme != "https" {
   130  		closeRequestBody(req)
   131  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   132  	}
   133  	if req.URL.Host == "" {
   134  		closeRequestBody(req)
   135  		return nil, errors.New("http3: no Host in request URL")
   136  	}
   137  	if req.Header == nil {
   138  		closeRequestBody(req)
   139  		return nil, errors.New("http3: nil Request.Header")
   140  	}
   141  	for k, vv := range req.Header {
   142  		if !httpguts.ValidHeaderFieldName(k) {
   143  			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
   144  		}
   145  		for _, v := range vv {
   146  			if !httpguts.ValidHeaderFieldValue(v) {
   147  				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   148  			}
   149  		}
   150  	}
   151  
   152  	if req.Method != "" && !validMethod(req.Method) {
   153  		closeRequestBody(req)
   154  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   155  	}
   156  
   157  	hostname := authorityAddr("https", hostnameFromRequest(req))
   158  	cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	defer cl.useCount.Add(-1)
   163  	rsp, err := cl.RoundTripOpt(req, opt)
   164  	if err != nil {
   165  		r.removeClient(hostname)
   166  		if isReused {
   167  			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   168  				return r.RoundTripOpt(req, opt)
   169  			}
   170  		}
   171  	}
   172  	return rsp, err
   173  }
   174  
   175  // RoundTrip does a round trip.
   176  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   177  	return r.RoundTripOpt(req, RoundTripOpt{})
   178  }
   179  
   180  func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
   181  	r.mutex.Lock()
   182  	defer r.mutex.Unlock()
   183  
   184  	if r.clients == nil {
   185  		r.clients = make(map[string]*roundTripCloserWithCount)
   186  	}
   187  
   188  	client, ok := r.clients[hostname]
   189  	if !ok {
   190  		if onlyCached {
   191  			return nil, false, ErrNoCachedConn
   192  		}
   193  		var err error
   194  		newCl := newClient
   195  		if r.newClient != nil {
   196  			newCl = r.newClient
   197  		}
   198  		dial := r.Dial
   199  		if dial == nil {
   200  			if r.transport == nil {
   201  				udpConn, err := net.ListenUDP("udp", nil)
   202  				if err != nil {
   203  					return nil, false, err
   204  				}
   205  				r.transport = &quic.Transport{Conn: udpConn}
   206  			}
   207  			dial = r.makeDialer()
   208  		}
   209  		c, err := newCl(
   210  			hostname,
   211  			r.TLSClientConfig,
   212  			&roundTripperOpts{
   213  				EnableDatagram:     r.EnableDatagrams,
   214  				DisableCompression: r.DisableCompression,
   215  				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
   216  				StreamHijacker:     r.StreamHijacker,
   217  				UniStreamHijacker:  r.UniStreamHijacker,
   218  				AdditionalSettings: r.AdditionalSettings,
   219  			},
   220  			r.QuicConfig,
   221  			dial,
   222  		)
   223  		if err != nil {
   224  			return nil, false, err
   225  		}
   226  		client = &roundTripCloserWithCount{roundTripCloser: c}
   227  		r.clients[hostname] = client
   228  	} else if client.HandshakeComplete() {
   229  		isReused = true
   230  	}
   231  	client.useCount.Add(1)
   232  	return client, isReused, nil
   233  }
   234  
   235  func (r *RoundTripper) removeClient(hostname string) {
   236  	r.mutex.Lock()
   237  	defer r.mutex.Unlock()
   238  	if r.clients == nil {
   239  		return
   240  	}
   241  	delete(r.clients, hostname)
   242  }
   243  
   244  // Close closes the QUIC connections that this RoundTripper has used.
   245  // It also closes the underlying UDPConn if it is not nil.
   246  func (r *RoundTripper) Close() error {
   247  	r.mutex.Lock()
   248  	defer r.mutex.Unlock()
   249  	for _, client := range r.clients {
   250  		if err := client.Close(); err != nil {
   251  			return err
   252  		}
   253  	}
   254  	r.clients = nil
   255  	if r.transport != nil {
   256  		if err := r.transport.Close(); err != nil {
   257  			return err
   258  		}
   259  		if err := r.transport.Conn.Close(); err != nil {
   260  			return err
   261  		}
   262  		r.transport = nil
   263  	}
   264  	return nil
   265  }
   266  
   267  func closeRequestBody(req *http.Request) {
   268  	if req.Body != nil {
   269  		req.Body.Close()
   270  	}
   271  }
   272  
   273  func validMethod(method string) bool {
   274  	/*
   275  				     Method         = "OPTIONS"                ; Section 9.2
   276  		   		                    | "GET"                    ; Section 9.3
   277  		   		                    | "HEAD"                   ; Section 9.4
   278  		   		                    | "POST"                   ; Section 9.5
   279  		   		                    | "PUT"                    ; Section 9.6
   280  		   		                    | "DELETE"                 ; Section 9.7
   281  		   		                    | "TRACE"                  ; Section 9.8
   282  		   		                    | "CONNECT"                ; Section 9.9
   283  		   		                    | extension-method
   284  		   		   extension-method = token
   285  		   		     token          = 1*<any CHAR except CTLs or separators>
   286  	*/
   287  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   288  }
   289  
   290  // copied from net/http/http.go
   291  func isNotToken(r rune) bool {
   292  	return !httpguts.IsTokenRune(r)
   293  }
   294  
   295  // makeDialer makes a QUIC dialer using r.udpConn.
   296  func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   297  	return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   298  		udpAddr, err := net.ResolveUDPAddr("udp", addr)
   299  		if err != nil {
   300  			return nil, err
   301  		}
   302  		return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
   303  	}
   304  }
   305  
   306  func (r *RoundTripper) CloseIdleConnections() {
   307  	r.mutex.Lock()
   308  	defer r.mutex.Unlock()
   309  	for hostname, client := range r.clients {
   310  		if client.useCount.Load() == 0 {
   311  			client.Close()
   312  			delete(r.clients, hostname)
   313  		}
   314  	}
   315  }