github.com/sagernet/quic-go@v0.43.1-beta.1/http3_ech/roundtrip.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  
    14  	"github.com/sagernet/quic-go/ech"
    15  	"github.com/sagernet/quic-go/internal/protocol"
    16  	"github.com/sagernet/cloudflare-tls"
    17  	"golang.org/x/net/http/httpguts"
    18  )
    19  
// Settings are HTTP/3 settings that apply to the underlying connection.
type Settings struct {
	// EnableDatagrams indicates support for HTTP/3 datagrams (RFC 9297).
	EnableDatagrams bool
	// EnableExtendedConnect indicates support for Extended CONNECT (RFC 9220).
	EnableExtendedConnect bool
	// Other holds any other settings, defined by the application.
	Other map[uint64]uint64
}
    29  
// RoundTripOpt are options for the RoundTripper.RoundTripOpt method.
type RoundTripOpt struct {
	// OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection.
	// If set to true and no cached connection is available, RoundTripOpt returns ErrNoCachedConn.
	OnlyCachedConn bool
}
    36  
// singleRoundTripper is the client that RoundTripper creates for each dialed
// connection (see RoundTripper.newClient) and to which it hands requests.
type singleRoundTripper interface {
	OpenRequestStream(context.Context) (RequestStream, error)
	RoundTrip(*http.Request) (*http.Response, error)
}
    41  
// roundTripperWithCount is one entry of the RoundTripper connection cache: a
// dial that may still be in progress, its outcome, and a count of the requests
// currently using the entry.
//
// dialErr, conn and rt are written by the dialing goroutine before it closes
// the dialing channel, so they must only be read after dialing is closed.
type roundTripperWithCount struct {
	cancel  context.CancelFunc   // cancels the context the dial runs under
	dialing chan struct{}        // closed as soon as quic.Dial(Early) returned
	dialErr error                // non-nil if the dial failed
	conn    quic.EarlyConnection // the dialed connection; nil if the dial failed
	rt      singleRoundTripper   // client created for conn; nil if the dial failed

	useCount atomic.Int64 // incremented by getClient, decremented when the request is done
}
    51  
    52  func (r *roundTripperWithCount) Close() error {
    53  	r.cancel()
    54  	<-r.dialing
    55  	if r.conn != nil {
    56  		return r.conn.CloseWithError(0, "")
    57  	}
    58  	return nil
    59  }
    60  
// RoundTripper implements the http.RoundTripper interface.
// It keeps a cache of QUIC connections, one per host, and dials new ones on demand.
type RoundTripper struct {
	// mutex is held by getClient, removeClient, Close and CloseIdleConnections
	// while they access the clients map.
	mutex sync.Mutex

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// QUICConfig is the quic.Config used for dialing new connections.
	// If nil, reasonable default values will be used.
	QUICConfig *quic.Config

	// Dial specifies an optional dial function for creating QUIC
	// connections for requests.
	// If Dial is nil, a UDPConn will be created at the first request
	// and will be reused for subsequent connections to other servers.
	Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error)

	// Enable support for HTTP/3 datagrams (RFC 9297).
	// If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams.
	EnableDatagrams bool

	// Additional HTTP/3 settings.
	// It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams).
	AdditionalSettings map[uint64]uint64

	// MaxResponseHeaderBytes specifies a limit on how many response bytes are
	// allowed in the server's response header.
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// DisableCompression, if true, prevents the Transport from requesting compression with an
	// "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value.
	// If the Transport requests gzip on its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body.
	// However, if the user explicitly requested gzip it is not automatically uncompressed.
	DisableCompression bool

	// initOnce runs init on the first call to RoundTripOpt. initErr holds the
	// result of init and, if non-nil, is returned by every call to RoundTripOpt.
	initOnce sync.Once
	initErr  error

	// newClient creates the client used for a freshly dialed connection.
	// If it is nil, init installs a default that returns a SingleDestinationRoundTripper.
	newClient func(quic.EarlyConnection) singleRoundTripper

	// clients is the connection cache, keyed by the host address that
	// RoundTripOpt derives from the request URL. It is allocated lazily by getClient.
	// transport is only used when Dial is nil: dial creates it on top of a new
	// UDP socket the first time it is needed, and Close closes it again.
	clients   map[string]*roundTripperWithCount
	transport *quic.Transport
}
   107  
   108  var (
   109  	_ http.RoundTripper = &RoundTripper{}
   110  	_ io.Closer         = &RoundTripper{}
   111  )
   112  
// ErrNoCachedConn is returned when RoundTripOpt.OnlyCachedConn is set
// and no cached connection is available for the requested host.
var ErrNoCachedConn = errors.New("http3: no cached connection was available")
   115  
   116  // RoundTripOpt is like RoundTrip, but takes options.
   117  func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
   118  	r.initOnce.Do(func() { r.initErr = r.init() })
   119  	if r.initErr != nil {
   120  		return nil, r.initErr
   121  	}
   122  
   123  	if req.URL == nil {
   124  		closeRequestBody(req)
   125  		return nil, errors.New("http3: nil Request.URL")
   126  	}
   127  	if req.URL.Scheme != "https" {
   128  		closeRequestBody(req)
   129  		return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
   130  	}
   131  	if req.URL.Host == "" {
   132  		closeRequestBody(req)
   133  		return nil, errors.New("http3: no Host in request URL")
   134  	}
   135  	if req.Header == nil {
   136  		closeRequestBody(req)
   137  		return nil, errors.New("http3: nil Request.Header")
   138  	}
   139  	for k, vv := range req.Header {
   140  		if !httpguts.ValidHeaderFieldName(k) {
   141  			return nil, fmt.Errorf("http3: invalid http header field name %q", k)
   142  		}
   143  		for _, v := range vv {
   144  			if !httpguts.ValidHeaderFieldValue(v) {
   145  				return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
   146  			}
   147  		}
   148  	}
   149  
   150  	if req.Method != "" && !validMethod(req.Method) {
   151  		closeRequestBody(req)
   152  		return nil, fmt.Errorf("http3: invalid method %q", req.Method)
   153  	}
   154  
   155  	hostname := authorityAddr(hostnameFromURL(req.URL))
   156  	cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	select {
   162  	case <-cl.dialing:
   163  	case <-req.Context().Done():
   164  		return nil, context.Cause(req.Context())
   165  	}
   166  
   167  	if cl.dialErr != nil {
   168  		return nil, cl.dialErr
   169  	}
   170  	defer cl.useCount.Add(-1)
   171  	rsp, err := cl.rt.RoundTrip(req)
   172  	if err != nil {
   173  		// non-nil errors on roundtrip are likely due to a problem with the connection
   174  		// so we remove the client from the cache so that subsequent trips reconnect
   175  		// context cancelation is excluded as is does not signify a connection error
   176  		if !errors.Is(err, context.Canceled) {
   177  			r.removeClient(hostname)
   178  		}
   179  
   180  		if isReused {
   181  			if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   182  				return r.RoundTripOpt(req, opt)
   183  			}
   184  		}
   185  	}
   186  	return rsp, err
   187  }
   188  
   189  // RoundTrip does a round trip.
   190  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   191  	return r.RoundTripOpt(req, RoundTripOpt{})
   192  }
   193  
   194  func (r *RoundTripper) init() error {
   195  	if r.newClient == nil {
   196  		r.newClient = func(conn quic.EarlyConnection) singleRoundTripper {
   197  			return &SingleDestinationRoundTripper{
   198  				Connection:             conn,
   199  				EnableDatagrams:        r.EnableDatagrams,
   200  				DisableCompression:     r.DisableCompression,
   201  				AdditionalSettings:     r.AdditionalSettings,
   202  				MaxResponseHeaderBytes: r.MaxResponseHeaderBytes,
   203  			}
   204  		}
   205  	}
   206  	if r.QUICConfig == nil {
   207  		r.QUICConfig = defaultQuicConfig.Clone()
   208  		r.QUICConfig.EnableDatagrams = r.EnableDatagrams
   209  	}
   210  	if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams {
   211  		return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled")
   212  	}
   213  	if len(r.QUICConfig.Versions) == 0 {
   214  		r.QUICConfig = r.QUICConfig.Clone()
   215  		r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]}
   216  	}
   217  	if len(r.QUICConfig.Versions) != 1 {
   218  		return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
   219  	}
   220  	if r.QUICConfig.MaxIncomingStreams == 0 {
   221  		r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
   222  	}
   223  	return nil
   224  }
   225  
   226  func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) {
   227  	r.mutex.Lock()
   228  	defer r.mutex.Unlock()
   229  
   230  	if r.clients == nil {
   231  		r.clients = make(map[string]*roundTripperWithCount)
   232  	}
   233  
   234  	cl, ok := r.clients[hostname]
   235  	if !ok {
   236  		if onlyCached {
   237  			return nil, false, ErrNoCachedConn
   238  		}
   239  		ctx, cancel := context.WithCancel(ctx)
   240  		cl = &roundTripperWithCount{
   241  			dialing: make(chan struct{}),
   242  			cancel:  cancel,
   243  		}
   244  		go func() {
   245  			defer close(cl.dialing)
   246  			defer cancel()
   247  			conn, rt, err := r.dial(ctx, hostname)
   248  			if err != nil {
   249  				cl.dialErr = err
   250  				return
   251  			}
   252  			cl.conn = conn
   253  			cl.rt = rt
   254  		}()
   255  		r.clients[hostname] = cl
   256  	}
   257  	select {
   258  	case <-cl.dialing:
   259  		if cl.dialErr != nil {
   260  			return nil, false, cl.dialErr
   261  		}
   262  		select {
   263  		case <-cl.conn.HandshakeComplete():
   264  			isReused = true
   265  		default:
   266  		}
   267  	default:
   268  	}
   269  	cl.useCount.Add(1)
   270  	return cl, isReused, nil
   271  }
   272  
   273  func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) {
   274  	var tlsConf *tls.Config
   275  	if r.TLSClientConfig == nil {
   276  		tlsConf = &tls.Config{}
   277  	} else {
   278  		tlsConf = r.TLSClientConfig.Clone()
   279  	}
   280  	if tlsConf.ServerName == "" {
   281  		sni, _, err := net.SplitHostPort(hostname)
   282  		if err != nil {
   283  			// It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port.
   284  			sni = hostname
   285  		}
   286  		tlsConf.ServerName = sni
   287  	}
   288  	// Replace existing ALPNs by H3
   289  	tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])}
   290  
   291  	dial := r.Dial
   292  	if dial == nil {
   293  		if r.transport == nil {
   294  			udpConn, err := net.ListenUDP("udp", nil)
   295  			if err != nil {
   296  				return nil, nil, err
   297  			}
   298  			r.transport = &quic.Transport{Conn: udpConn}
   299  		}
   300  		dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
   301  			udpAddr, err := net.ResolveUDPAddr("udp", addr)
   302  			if err != nil {
   303  				return nil, err
   304  			}
   305  			return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
   306  		}
   307  	}
   308  
   309  	conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig)
   310  	if err != nil {
   311  		return nil, nil, err
   312  	}
   313  	return conn, r.newClient(conn), nil
   314  }
   315  
   316  func (r *RoundTripper) removeClient(hostname string) {
   317  	r.mutex.Lock()
   318  	defer r.mutex.Unlock()
   319  	if r.clients == nil {
   320  		return
   321  	}
   322  	delete(r.clients, hostname)
   323  }
   324  
   325  // Close closes the QUIC connections that this RoundTripper has used.
   326  // It also closes the underlying UDPConn if it is not nil.
   327  func (r *RoundTripper) Close() error {
   328  	r.mutex.Lock()
   329  	defer r.mutex.Unlock()
   330  	for _, cl := range r.clients {
   331  		if err := cl.Close(); err != nil {
   332  			return err
   333  		}
   334  	}
   335  	r.clients = nil
   336  	if r.transport != nil {
   337  		if err := r.transport.Close(); err != nil {
   338  			return err
   339  		}
   340  		if err := r.transport.Conn.Close(); err != nil {
   341  			return err
   342  		}
   343  		r.transport = nil
   344  	}
   345  	return nil
   346  }
   347  
   348  func closeRequestBody(req *http.Request) {
   349  	if req.Body != nil {
   350  		req.Body.Close()
   351  	}
   352  }
   353  
   354  func validMethod(method string) bool {
   355  	/*
   356  				     Method         = "OPTIONS"                ; Section 9.2
   357  		   		                    | "GET"                    ; Section 9.3
   358  		   		                    | "HEAD"                   ; Section 9.4
   359  		   		                    | "POST"                   ; Section 9.5
   360  		   		                    | "PUT"                    ; Section 9.6
   361  		   		                    | "DELETE"                 ; Section 9.7
   362  		   		                    | "TRACE"                  ; Section 9.8
   363  		   		                    | "CONNECT"                ; Section 9.9
   364  		   		                    | extension-method
   365  		   		   extension-method = token
   366  		   		     token          = 1*<any CHAR except CTLs or separators>
   367  	*/
   368  	return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
   369  }
   370  
   371  // copied from net/http/http.go
   372  func isNotToken(r rune) bool {
   373  	return !httpguts.IsTokenRune(r)
   374  }
   375  
   376  func (r *RoundTripper) CloseIdleConnections() {
   377  	r.mutex.Lock()
   378  	defer r.mutex.Unlock()
   379  	for hostname, cl := range r.clients {
   380  		if cl.useCount.Load() == 0 {
   381  			cl.Close()
   382  			delete(r.clients, hostname)
   383  		}
   384  	}
   385  }