github.com/MerlinKodo/quic-go@v0.39.2/http3/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  
    15  	"golang.org/x/net/http/httpguts"
    16  
    17  	"github.com/MerlinKodo/quic-go"
    18  )
    19  
    20  type roundTripCloser interface {
    21  	RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error)
    22  	HandshakeComplete() bool
    23  	io.Closer
    24  }
    25  
// roundTripCloserWithCount pairs a cached per-host client with the number of
// RoundTripOpt calls currently holding it. getClient increments useCount while
// holding RoundTripper.mutex, RoundTripOpt decrements it (via defer) once the
// underlying call returns, and CloseIdleConnections only closes clients whose
// count is zero. The counter is atomic because the decrement happens without
// the RoundTripper's mutex.
type roundTripCloserWithCount struct {
	roundTripCloser
	useCount atomic.Int64
}
    30  
// RoundTripper implements the http.RoundTripper interface for HTTP/3.
//
// The zero value is ready to use: the client cache is allocated on the first
// request and, when Dial is nil, a single UDP socket is opened lazily and then
// shared by all connections. One client is cached per request authority and
// reused by later requests to it. A RoundTripper contains a sync.Mutex and must
// not be copied after first use.
type RoundTripper struct {
	// mutex guards clients and transport.
	mutex sync.Mutex

	// DisableCompression, if true, prevents the RoundTripper from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the RoundTripper requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QuicConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QuicConfig *quic.Config

	// EnableDatagrams enables support for HTTP/3 datagrams.
	// If set to true, QuicConfig.EnableDatagram will be set.
	// See https://datatracker.ietf.org/doc/html/rfc9297.
	EnableDatagrams bool

	// AdditionalSettings specifies additional HTTP/3 settings.
	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3)
	// or RFC 9297 (HTTP Datagrams).
	AdditionalSettings map[uint64]uint64

	// StreamHijacker, when set, is called for the first unknown frame parsed
	// on a bidirectional stream, right after the frame type has been parsed.
	// If parsing the frame type fails, the error is passed to the callback,
	// and the frame type is not set.
	// The callback can either ignore the frame and hand the stream back to
	// HTTP/3 (by returning hijacked = false), or take over the QUIC stream
	// (by returning hijacked = true).
	StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

	// UniStreamHijacker, when set, is called for unidirectional streams of an
	// unknown stream type. If parsing the stream type fails, the error is
	// passed to the callback, and the stream type is not set.
	UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, a UDPConn is created at the first request
	// and is reused for subsequent connections to other servers.
	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// newClient, when non-nil, replaces the package-level newClient constructor.
	// clients caches one client per authority (the key getClient is called with).
	// transport is the shared QUIC transport created by getClient when Dial is
	// nil; Close closes it together with its UDP socket.
	newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
	clients   map[string]*roundTripCloserWithCount
	transport *quic.Transport
}
    91  
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
// The zero value gives the behavior of RoundTripper.RoundTrip.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set to true and no cached connection is available, RoundTripOpt returns ErrNoCachedConn.
	OnlyCachedConn bool
	// DontCloseRequestStream controls whether the request stream is closed after sending the request.
	// If set, context cancellations have no effect after the response headers are received.
	DontCloseRequestStream bool
}
   101  
   102  var (
   103  	_ http.RoundTripper = &RoundTripper{}
   104  	_ io.Closer         = &RoundTripper{}
   105  )
   106  
   107  // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
   108  var ErrNoCachedConn = errors.New("http3: no cached connection was available")
   109  
   110  // RoundTripOpt is like RoundTrip, but takes options.
   111  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   112  	if req.URL == nil {
   113  		closeRequestBody(req)
   114  		return nil, errors.New("http3: nil Request.URL")
   115  	}
   116  	if req.URL.Scheme != "https" {
   117  		closeRequestBody(req)
   118  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   119  	}
   120  	if req.URL.Host == "" {
   121  		closeRequestBody(req)
   122  		return nil, errors.New("http3: no Host in request URL")
   123  	}
   124  	if req.Header == nil {
   125  		closeRequestBody(req)
   126  		return nil, errors.New("http3: nil Request.Header")
   127  	}
   128  	for k, vv := range req.Header {
   129  		if !httpguts.ValidHeaderFieldName(k) {
   130  			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
   131  		}
   132  		for _, v := range vv {
   133  			if !httpguts.ValidHeaderFieldValue(v) {
   134  				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   135  			}
   136  		}
   137  	}
   138  
   139  	if req.Method != "" && !validMethod(req.Method) {
   140  		closeRequestBody(req)
   141  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   142  	}
   143  
   144  	hostname := authorityAddr("https", hostnameFromRequest(req))
   145  	cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  	defer cl.useCount.Add(-1)
   150  	rsp, err := cl.RoundTripOpt(req, opt)
   151  	if err != nil {
   152  		r.removeClient(hostname)
   153  		if isReused {
   154  			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   155  				return r.RoundTripOpt(req, opt)
   156  			}
   157  		}
   158  	}
   159  	return rsp, err
   160  }
   161  
   162  // RoundTrip does a round trip.
   163  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   164  	return r.RoundTripOpt(req, RoundTripOpt{})
   165  }
   166  
   167  func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
   168  	r.mutex.Lock()
   169  	defer r.mutex.Unlock()
   170  
   171  	if r.clients == nil {
   172  		r.clients = make(map[string]*roundTripCloserWithCount)
   173  	}
   174  
   175  	client, ok := r.clients[hostname]
   176  	if !ok {
   177  		if onlyCached {
   178  			return nil, false, ErrNoCachedConn
   179  		}
   180  		var err error
   181  		newCl := newClient
   182  		if r.newClient != nil {
   183  			newCl = r.newClient
   184  		}
   185  		dial := r.Dial
   186  		if dial == nil {
   187  			if r.transport == nil {
   188  				udpConn, err := net.ListenUDP("udp", nil)
   189  				if err != nil {
   190  					return nil, false, err
   191  				}
   192  				r.transport = &quic.Transport{Conn: udpConn}
   193  			}
   194  			dial = r.makeDialer()
   195  		}
   196  		c, err := newCl(
   197  			hostname,
   198  			r.TLSClientConfig,
   199  			&roundTripperOpts{
   200  				EnableDatagram:     r.EnableDatagrams,
   201  				DisableCompression: r.DisableCompression,
   202  				MaxHeaderBytes:     r.MaxResponseHeaderBytes,
   203  				StreamHijacker:     r.StreamHijacker,
   204  				UniStreamHijacker:  r.UniStreamHijacker,
   205  			},
   206  			r.QuicConfig,
   207  			dial,
   208  		)
   209  		if err != nil {
   210  			return nil, false, err
   211  		}
   212  		client = &roundTripCloserWithCount{roundTripCloser: c}
   213  		r.clients[hostname] = client
   214  	} else if client.HandshakeComplete() {
   215  		isReused = true
   216  	}
   217  	client.useCount.Add(1)
   218  	return client, isReused, nil
   219  }
   220  
   221  func (r *RoundTripper) removeClient(hostname string) {
   222  	r.mutex.Lock()
   223  	defer r.mutex.Unlock()
   224  	if r.clients == nil {
   225  		return
   226  	}
   227  	delete(r.clients, hostname)
   228  }
   229  
   230  // Close closes the QUIC connections that this RoundTripper has used.
   231  // It also closes the underlying UDPConn if it is not nil.
   232  func (r *RoundTripper) Close() error {
   233  	r.mutex.Lock()
   234  	defer r.mutex.Unlock()
   235  	for _, client := range r.clients {
   236  		if err := client.Close(); err != nil {
   237  			return err
   238  		}
   239  	}
   240  	r.clients = nil
   241  	if r.transport != nil {
   242  		if err := r.transport.Close(); err != nil {
   243  			return err
   244  		}
   245  		if err := r.transport.Conn.Close(); err != nil {
   246  			return err
   247  		}
   248  		r.transport = nil
   249  	}
   250  	return nil
   251  }
   252  
   253  func closeRequestBody(req *http.Request) {
   254  	if req.Body != nil {
   255  		req.Body.Close()
   256  	}
   257  }
   258  
   259  func validMethod(method string) bool {
   260  	/*
   261  				     Method         = "OPTIONS"                ; Section 9.2
   262  		   		                    | "GET"                    ; Section 9.3
   263  		   		                    | "HEAD"                   ; Section 9.4
   264  		   		                    | "POST"                   ; Section 9.5
   265  		   		                    | "PUT"                    ; Section 9.6
   266  		   		                    | "DELETE"                 ; Section 9.7
   267  		   		                    | "TRACE"                  ; Section 9.8
   268  		   		                    | "CONNECT"                ; Section 9.9
   269  		   		                    | extension-method
   270  		   		   extension-method = token
   271  		   		     token          = 1*<any CHAR except CTLs or separators>
   272  	*/
   273  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   274  }
   275  
   276  // copied from net/http/http.go
   277  func isNotToken(r rune) bool {
   278  	return !httpguts.IsTokenRune(r)
   279  }
   280  
   281  // makeDialer makes a QUIC dialer using r.udpConn.
   282  func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   283  	return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   284  		udpAddr, err := net.ResolveUDPAddr("udp", addr)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  		return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
   289  	}
   290  }
   291  
   292  func (r *RoundTripper) CloseIdleConnections() {
   293  	r.mutex.Lock()
   294  	defer r.mutex.Unlock()
   295  	for hostname, client := range r.clients {
   296  		if client.useCount.Load() == 0 {
   297  			client.Close()
   298  			delete(r.clients, hostname)
   299  		}
   300  	}
   301  }