github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/http3/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  
    15  	"golang.org/x/net/http/httpguts"
    16  
    17  	"github.com/metacubex/quic-go"
    18  	"github.com/metacubex/quic-go/internal/protocol"
    19  )
    20  
// Settings are HTTP/3 settings that apply to the underlying connection.
type Settings struct {
	// EnableDatagrams indicates support for HTTP/3 datagrams (RFC 9297).
	EnableDatagrams bool
	// EnableExtendedConnect indicates support for Extended CONNECT (RFC 9220).
	EnableExtendedConnect bool
	// Other holds further settings defined by the application.
	Other map[uint64]uint64
}
    30  
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn.
	OnlyCachedConn bool
}
    37  
// singleRoundTripper is the per-connection round tripper that RoundTripper
// caches for each host. RoundTripper.newClient builds one for every newly
// dialed connection; by default it is a *SingleDestinationRoundTripper.
type singleRoundTripper interface {
	OpenRequestStream(context.Context) (RequestStream, error)
	RoundTrip(*http.Request) (*http.Response, error)
}
    42  
// roundTripperWithCount is one entry of RoundTripper.clients. It tracks a
// (possibly still running) dial to a single host, the dial's outcome, and how
// many requests currently hold the entry.
//
// dialErr, conn and rt are written by the dial goroutine before it closes
// dialing, so they must only be read after receiving from dialing.
type roundTripperWithCount struct {
	cancel  context.CancelFunc   // cancels the context used for the dial
	dialing chan struct{}        // closed as soon as quic.Dial(Early) returned
	dialErr error                // dial failure, if any
	conn    quic.EarlyConnection // set when the dial succeeded
	rt      singleRoundTripper   // round tripper built on top of conn

	// useCount is the number of requests currently using this entry;
	// CloseIdleConnections only closes entries whose count is zero.
	useCount atomic.Int64
}
    52  
    53  func (r *roundTripperWithCount) Close() error {
    54  	r.cancel()
    55  	<-r.dialing
    56  	if r.conn != nil {
    57  		return r.conn.CloseWithError(0, "")
    58  	}
    59  	return nil
    60  }
    61  
// RoundTripper implements the http.RoundTripper interface for HTTP/3.
// It caches one client per request authority and also implements io.Closer.
type RoundTripper struct {
	// mutex guards clients (and transport while Close runs).
	mutex sync.Mutex

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QUICConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QUICConfig *quic.Config

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, a UDPConn will be created at the first request
	// and will be reused for subsequent connections to other servers.
	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)

	// EnableDatagrams enables support for HTTP/3 datagrams (RFC 9297).
	// If a QUICConfig is set, datagram support also needs to be enabled on
	// the QUIC layer by setting its EnableDatagrams field; otherwise the
	// first request fails with an initialization error.
	EnableDatagrams bool

	// AdditionalSettings are additional HTTP/3 settings.
	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) or RFC 9297 (HTTP Datagrams).
	AdditionalSettings map[uint64]uint64

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// DisableCompression, if true, prevents the Transport from requesting compression with an
	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body.
	// However, if the user explicitly requested gzip it is not automatically uncompressed.
	DisableCompression bool

	// initOnce runs init on the first RoundTripOpt call; initErr stores its
	// result and is returned by every subsequent call.
	initOnce sync.Once
	initErr  error

	// newClient wraps a freshly dialed connection in a singleRoundTripper.
	// If nil, init installs a constructor for *SingleDestinationRoundTripper.
	newClient func(quic.EarlyConnection) singleRoundTripper

	// clients caches one entry per authority computed in RoundTripOpt and is
	// guarded by mutex. transport is the quic.Transport that dial creates
	// lazily when Dial is nil.
	// NOTE(review): dial reads and assigns transport without holding mutex.
	clients   map[string]*roundTripperWithCount
	transport *quic.Transport
}
   108  
   109  var (
   110  	_ http.RoundTripper = &RoundTripper{}
   111  	_ io.Closer         = &RoundTripper{}
   112  )
   113  
// ErrNoCachedConn is returned by RoundTripper.RoundTripOpt when
// RoundTripOpt.OnlyCachedConn is set and no cached connection to the
// request's host is available.
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
   116  
   117  // RoundTripOpt is like RoundTrip, but takes options.
   118  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   119  	r.initOnce.Do(func() { r.initErr = r.init() })
   120  	if r.initErr != nil {
   121  		return nil, r.initErr
   122  	}
   123  
   124  	if req.URL == nil {
   125  		closeRequestBody(req)
   126  		return nil, errors.New("http3: nil Request.URL")
   127  	}
   128  	if req.URL.Scheme != "https" {
   129  		closeRequestBody(req)
   130  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   131  	}
   132  	if req.URL.Host == "" {
   133  		closeRequestBody(req)
   134  		return nil, errors.New("http3: no Host in request URL")
   135  	}
   136  	if req.Header == nil {
   137  		closeRequestBody(req)
   138  		return nil, errors.New("http3: nil Request.Header")
   139  	}
   140  	for k, vv := range req.Header {
   141  		if !httpguts.ValidHeaderFieldName(k) {
   142  			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
   143  		}
   144  		for _, v := range vv {
   145  			if !httpguts.ValidHeaderFieldValue(v) {
   146  				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   147  			}
   148  		}
   149  	}
   150  
   151  	if req.Method != "" && !validMethod(req.Method) {
   152  		closeRequestBody(req)
   153  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   154  	}
   155  
   156  	hostname := authorityAddr(hostnameFromURL(req.URL))
   157  	cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	select {
   163  	case <-cl.dialing:
   164  	case <-req.Context().Done():
   165  		return nil, context.Cause(req.Context())
   166  	}
   167  
   168  	if cl.dialErr != nil {
   169  		return nil, cl.dialErr
   170  	}
   171  	defer cl.useCount.Add(-1)
   172  	rsp, err := cl.rt.RoundTrip(req)
   173  	if err != nil {
   174  		// non-nil errors on roundtrip are likely due to a problem with the connection
   175  		// so we remove the client from the cache so that subsequent trips reconnect
   176  		// context cancelation is excluded as is does not signify a connection error
   177  		if !errors.Is(err, context.Canceled) {
   178  			r.removeClient(hostname)
   179  		}
   180  
   181  		if isReused {
   182  			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   183  				return r.RoundTripOpt(req, opt)
   184  			}
   185  		}
   186  	}
   187  	return rsp, err
   188  }
   189  
   190  // RoundTrip does a round trip.
   191  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   192  	return r.RoundTripOpt(req, RoundTripOpt{})
   193  }
   194  
   195  func (r *RoundTripper) init() error {
   196  	if r.newClient == nil {
   197  		r.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
   198  			return &SingleDestinationRoundTripper{
   199  				Connection:             conn,
   200  				EnableDatagrams:        r.EnableDatagrams,
   201  				DisableCompression:     r.DisableCompression,
   202  				AdditionalSettings:     r.AdditionalSettings,
   203  				MaxResponseHeaderBytes: r.MaxResponseHeaderBytes,
   204  			}
   205  		}
   206  	}
   207  	if r.QUICConfig == nil {
   208  		r.QUICConfig = defaultQuicConfig.Clone()
   209  		r.QUICConfig.EnableDatagrams = r.EnableDatagrams
   210  	}
   211  	if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams {
   212  		return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
   213  	}
   214  	if len(r.QUICConfig.Versions) == 0 {
   215  		r.QUICConfig = r.QUICConfig.Clone()
   216  		r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
   217  	}
   218  	if len(r.QUICConfig.Versions) != 1 {
   219  		return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
   220  	}
   221  	if r.QUICConfig.MaxIncomingStreams == 0 {
   222  		r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
   223  	}
   224  	return nil
   225  }
   226  
   227  func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
   228  	r.mutex.Lock()
   229  	defer r.mutex.Unlock()
   230  
   231  	if r.clients == nil {
   232  		r.clients = make(map[string]*roundTripperWithCount)
   233  	}
   234  
   235  	cl, ok := r.clients[hostname]
   236  	if !ok {
   237  		if onlyCached {
   238  			return nil, false, ErrNoCachedConn
   239  		}
   240  		ctx, cancel := context.WithCancel(ctx)
   241  		cl = &roundTripperWithCount{
   242  			dialing: make(chan struct{}),
   243  			cancel:  cancel,
   244  		}
   245  		go func() {
   246  			defer close(cl.dialing)
   247  			defer cancel()
   248  			conn, rt, err := r.dial(ctx, hostname)
   249  			if err != nil {
   250  				cl.dialErr = err
   251  				return
   252  			}
   253  			cl.conn = conn
   254  			cl.rt = rt
   255  		}()
   256  		r.clients[hostname] = cl
   257  	}
   258  	select {
   259  	case <-cl.dialing:
   260  		if cl.dialErr != nil {
   261  			return nil, false, cl.dialErr
   262  		}
   263  		select {
   264  		case <-cl.conn.HandshakeComplete():
   265  			isReused = true
   266  		default:
   267  		}
   268  	default:
   269  	}
   270  	cl.useCount.Add(1)
   271  	return cl, isReused, nil
   272  }
   273  
   274  func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
   275  	var tlsConf *tls.Config
   276  	if r.TLSClientConfig == nil {
   277  		tlsConf = &tls.Config{}
   278  	} else {
   279  		tlsConf = r.TLSClientConfig.Clone()
   280  	}
   281  	if tlsConf.ServerName == "" {
   282  		sni, _, err := net.SplitHostPort(hostname)
   283  		if err != nil {
   284  			// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
   285  			sni = hostname
   286  		}
   287  		tlsConf.ServerName = sni
   288  	}
   289  	// Replace existing ALPNs by H3
   290  	tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])}
   291  
   292  	dial := r.Dial
   293  	if dial == nil {
   294  		if r.transport == nil {
   295  			udpConn, err := net.ListenUDP("udp", nil)
   296  			if err != nil {
   297  				return nil, nil, err
   298  			}
   299  			r.transport = &quic.Transport{Conn: udpConn}
   300  		}
   301  		dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   302  			udpAddr, err := net.ResolveUDPAddr("udp", addr)
   303  			if err != nil {
   304  				return nil, err
   305  			}
   306  			return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
   307  		}
   308  	}
   309  
   310  	conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig)
   311  	if err != nil {
   312  		return nil, nil, err
   313  	}
   314  	return conn, r.newClient(conn), nil
   315  }
   316  
   317  func (r *RoundTripper) removeClient(hostname string) {
   318  	r.mutex.Lock()
   319  	defer r.mutex.Unlock()
   320  	if r.clients == nil {
   321  		return
   322  	}
   323  	delete(r.clients, hostname)
   324  }
   325  
   326  // Close closes the QUIC connections that this RoundTripper has used.
   327  // It also closes the underlying UDPConn if it is not nil.
   328  func (r *RoundTripper) Close() error {
   329  	r.mutex.Lock()
   330  	defer r.mutex.Unlock()
   331  	for _, cl := range r.clients {
   332  		if err := cl.Close(); err != nil {
   333  			return err
   334  		}
   335  	}
   336  	r.clients = nil
   337  	if r.transport != nil {
   338  		if err := r.transport.Close(); err != nil {
   339  			return err
   340  		}
   341  		if err := r.transport.Conn.Close(); err != nil {
   342  			return err
   343  		}
   344  		r.transport = nil
   345  	}
   346  	return nil
   347  }
   348  
   349  func closeRequestBody(req *http.Request) {
   350  	if req.Body != nil {
   351  		req.Body.Close()
   352  	}
   353  }
   354  
   355  func validMethod(method string) bool {
   356  	/*
   357  				     Method         = "OPTIONS"                ; Section 9.2
   358  		   		                    | "GET"                    ; Section 9.3
   359  		   		                    | "HEAD"                   ; Section 9.4
   360  		   		                    | "POST"                   ; Section 9.5
   361  		   		                    | "PUT"                    ; Section 9.6
   362  		   		                    | "DELETE"                 ; Section 9.7
   363  		   		                    | "TRACE"                  ; Section 9.8
   364  		   		                    | "CONNECT"                ; Section 9.9
   365  		   		                    | extension-method
   366  		   		   extension-method = token
   367  		   		     token          = 1*<any CHAR except CTLs or separators>
   368  	*/
   369  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   370  }
   371  
   372  // copied from net/http/http.go
   373  func isNotToken(r rune) bool {
   374  	return !httpguts.IsTokenRune(r)
   375  }
   376  
   377  func (r *RoundTripper) CloseIdleConnections() {
   378  	r.mutex.Lock()
   379  	defer r.mutex.Unlock()
   380  	for hostname, cl := range r.clients {
   381  		if cl.useCount.Load() == 0 {
   382  			cl.Close()
   383  			delete(r.clients, hostname)
   384  		}
   385  	}
   386  }