github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/roundtrip.go (about) 1 package http3 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "strings" 12 "sync" 13 "sync/atomic" 14 15 "golang.org/x/net/http/httpguts" 16 17 "github.com/apernet/quic-go" 18 "github.com/apernet/quic-go/internal/protocol" 19 ) 20 21 // Settings are HTTP/3 settings that apply to the underlying connection. 22 type Settings struct { 23 // Support for HTTP/3 datagrams (RFC 9297) 24 EnableDatagrams bool 25 // Extended CONNECT, RFC 9220 26 EnableExtendedConnect bool 27 // Other settings, defined by the application 28 Other map[uint64]uint64 29 } 30 31 // RoundTripOpt are options for the Transport.RoundTripOpt method. 32 type RoundTripOpt struct { 33 // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. 34 // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. 35 OnlyCachedConn bool 36 } 37 38 type singleRoundTripper interface { 39 OpenRequestStream(context.Context) (RequestStream, error) 40 RoundTrip(*http.Request) (*http.Response, error) 41 } 42 43 type roundTripperWithCount struct { 44 cancel context.CancelFunc 45 dialing chan struct{} // closed as soon as quic.Dial(Early) returned 46 dialErr error 47 conn quic.EarlyConnection 48 rt singleRoundTripper 49 50 useCount atomic.Int64 51 } 52 53 func (r *roundTripperWithCount) Close() error { 54 r.cancel() 55 <-r.dialing 56 if r.conn != nil { 57 return r.conn.CloseWithError(0, "") 58 } 59 return nil 60 } 61 62 // RoundTripper implements the http.RoundTripper interface 63 type RoundTripper struct { 64 mutex sync.Mutex 65 66 // TLSClientConfig specifies the TLS configuration to use with 67 // tls.Client. If nil, the default configuration is used. 68 TLSClientConfig *tls.Config 69 70 // QUICConfig is the quic.Config used for dialing new connections. 71 // If nil, reasonable default values will be used. 72 QUICConfig *quic.Config 73 74 // Dial specifies an optional dial function for creating QUIC 75 // connections for requests. 76 // If Dial is nil, a UDPConn will be created at the first request 77 // and will be reused for subsequent connections to other servers. 78 Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) 79 80 // Enable support for HTTP/3 datagrams (RFC 9297). 81 // If a QUICConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. 82 EnableDatagrams bool 83 84 // Additional HTTP/3 settings. 85 // It is invalid to specify any settings defined by RFC 9114 (HTTP/3) and RFC 9297 (HTTP Datagrams). 86 AdditionalSettings map[uint64]uint64 87 88 // MaxResponseHeaderBytes specifies a limit on how many response bytes are 89 // allowed in the server's response header. 90 // Zero means to use a default limit. 91 MaxResponseHeaderBytes int64 92 93 // DisableCompression, if true, prevents the Transport from requesting compression with an 94 // "Accept-Encoding: gzip" request header when the Request contains no existing Accept-Encoding value. 95 // If the Transport requests gzip on its own and gets a gzipped response, it's transparently 96 // decoded in the Response.Body. 97 // However, if the user explicitly requested gzip it is not automatically uncompressed. 98 DisableCompression bool 99 100 initOnce sync.Once 101 initErr error 102 103 newClient func(quic.EarlyConnection) singleRoundTripper 104 105 clients map[string]*roundTripperWithCount 106 transport *quic.Transport 107 } 108 109 var ( 110 _ http.RoundTripper = &RoundTripper{} 111 _ io.Closer = &RoundTripper{} 112 ) 113 114 // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set 115 var ErrNoCachedConn = errors.New("http3: no cached connection was available") 116 117 // RoundTripOpt is like RoundTrip, but takes options. 118 func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 119 r.initOnce.Do(func() { r.initErr = r.init() }) 120 if r.initErr != nil { 121 return nil, r.initErr 122 } 123 124 if req.URL == nil { 125 closeRequestBody(req) 126 return nil, errors.New("http3: nil Request.URL") 127 } 128 if req.URL.Scheme != "https" { 129 closeRequestBody(req) 130 return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) 131 } 132 if req.URL.Host == "" { 133 closeRequestBody(req) 134 return nil, errors.New("http3: no Host in request URL") 135 } 136 if req.Header == nil { 137 closeRequestBody(req) 138 return nil, errors.New("http3: nil Request.Header") 139 } 140 for k, vv := range req.Header { 141 if !httpguts.ValidHeaderFieldName(k) { 142 return nil, fmt.Errorf("http3: invalid http header field name %q", k) 143 } 144 for _, v := range vv { 145 if !httpguts.ValidHeaderFieldValue(v) { 146 return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) 147 } 148 } 149 } 150 151 if req.Method != "" && !validMethod(req.Method) { 152 closeRequestBody(req) 153 return nil, fmt.Errorf("http3: invalid method %q", req.Method) 154 } 155 156 hostname := authorityAddr(hostnameFromURL(req.URL)) 157 cl, isReused, err := r.getClient(req.Context(), hostname, opt.OnlyCachedConn) 158 if err != nil { 159 return nil, err 160 } 161 162 select { 163 case <-cl.dialing: 164 case <-req.Context().Done(): 165 return nil, context.Cause(req.Context()) 166 } 167 168 if cl.dialErr != nil { 169 return nil, cl.dialErr 170 } 171 defer cl.useCount.Add(-1) 172 rsp, err := cl.rt.RoundTrip(req) 173 if err != nil { 174 // non-nil errors on roundtrip are likely due to a problem with the connection 175 // so we remove the client from the cache so that subsequent trips reconnect 176 // context cancelation is excluded as is does not signify a connection error 177 if !errors.Is(err, context.Canceled) { 178 r.removeClient(hostname) 179 } 180 181 if isReused { 182 if nerr, ok := err.(net.Error); ok && nerr.Timeout() { 183 return r.RoundTripOpt(req, opt) 184 } 185 } 186 } 187 return rsp, err 188 } 189 190 // RoundTrip does a round trip. 191 func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 192 return r.RoundTripOpt(req, RoundTripOpt{}) 193 } 194 195 func (r *RoundTripper) init() error { 196 if r.newClient == nil { 197 r.newClient = func(conn quic.EarlyConnection) singleRoundTripper { 198 return &SingleDestinationRoundTripper{ 199 Connection: conn, 200 EnableDatagrams: r.EnableDatagrams, 201 DisableCompression: r.DisableCompression, 202 AdditionalSettings: r.AdditionalSettings, 203 MaxResponseHeaderBytes: r.MaxResponseHeaderBytes, 204 } 205 } 206 } 207 if r.QUICConfig == nil { 208 r.QUICConfig = defaultQuicConfig.Clone() 209 r.QUICConfig.EnableDatagrams = r.EnableDatagrams 210 } 211 if r.EnableDatagrams && !r.QUICConfig.EnableDatagrams { 212 return errors.New("HTTP Datagrams enabled, but QUIC Datagrams disabled") 213 } 214 if len(r.QUICConfig.Versions) == 0 { 215 r.QUICConfig = r.QUICConfig.Clone() 216 r.QUICConfig.Versions = []quic.Version{protocol.SupportedVersions[0]} 217 } 218 if len(r.QUICConfig.Versions) != 1 { 219 return errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") 220 } 221 if r.QUICConfig.MaxIncomingStreams == 0 { 222 r.QUICConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams 223 } 224 return nil 225 } 226 227 func (r *RoundTripper) getClient(ctx context.Context, hostname string, onlyCached bool) (rtc *roundTripperWithCount, isReused bool, err error) { 228 r.mutex.Lock() 229 defer r.mutex.Unlock() 230 231 if r.clients == nil { 232 r.clients = make(map[string]*roundTripperWithCount) 233 } 234 235 cl, ok := r.clients[hostname] 236 if !ok { 237 if onlyCached { 238 return nil, false, ErrNoCachedConn 239 } 240 ctx, cancel := context.WithCancel(ctx) 241 cl = &roundTripperWithCount{ 242 dialing: make(chan struct{}), 243 cancel: cancel, 244 } 245 go func() { 246 defer close(cl.dialing) 247 defer cancel() 248 conn, rt, err := r.dial(ctx, hostname) 249 if err != nil { 250 cl.dialErr = err 251 return 252 } 253 cl.conn = conn 254 cl.rt = rt 255 }() 256 r.clients[hostname] = cl 257 } 258 select { 259 case <-cl.dialing: 260 if cl.dialErr != nil { 261 return nil, false, cl.dialErr 262 } 263 select { 264 case <-cl.conn.HandshakeComplete(): 265 isReused = true 266 default: 267 } 268 default: 269 } 270 cl.useCount.Add(1) 271 return cl, isReused, nil 272 } 273 274 func (r *RoundTripper) dial(ctx context.Context, hostname string) (quic.EarlyConnection, singleRoundTripper, error) { 275 var tlsConf *tls.Config 276 if r.TLSClientConfig == nil { 277 tlsConf = &tls.Config{} 278 } else { 279 tlsConf = r.TLSClientConfig.Clone() 280 } 281 if tlsConf.ServerName == "" { 282 sni, _, err := net.SplitHostPort(hostname) 283 if err != nil { 284 // It's ok if net.SplitHostPort returns an error - it could be a hostname/IP address without a port. 285 sni = hostname 286 } 287 tlsConf.ServerName = sni 288 } 289 // Replace existing ALPNs by H3 290 tlsConf.NextProtos = []string{versionToALPN(r.QUICConfig.Versions[0])} 291 292 dial := r.Dial 293 if dial == nil { 294 if r.transport == nil { 295 udpConn, err := net.ListenUDP("udp", nil) 296 if err != nil { 297 return nil, nil, err 298 } 299 r.transport = &quic.Transport{Conn: udpConn} 300 } 301 dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 302 udpAddr, err := net.ResolveUDPAddr("udp", addr) 303 if err != nil { 304 return nil, err 305 } 306 return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) 307 } 308 } 309 310 conn, err := dial(ctx, hostname, tlsConf, r.QUICConfig) 311 if err != nil { 312 return nil, nil, err 313 } 314 return conn, r.newClient(conn), nil 315 } 316 317 func (r *RoundTripper) removeClient(hostname string) { 318 r.mutex.Lock() 319 defer r.mutex.Unlock() 320 if r.clients == nil { 321 return 322 } 323 delete(r.clients, hostname) 324 } 325 326 // Close closes the QUIC connections that this RoundTripper has used. 327 // It also closes the underlying UDPConn if it is not nil. 328 func (r *RoundTripper) Close() error { 329 r.mutex.Lock() 330 defer r.mutex.Unlock() 331 for _, cl := range r.clients { 332 if err := cl.Close(); err != nil { 333 return err 334 } 335 } 336 r.clients = nil 337 if r.transport != nil { 338 if err := r.transport.Close(); err != nil { 339 return err 340 } 341 if err := r.transport.Conn.Close(); err != nil { 342 return err 343 } 344 r.transport = nil 345 } 346 return nil 347 } 348 349 func closeRequestBody(req *http.Request) { 350 if req.Body != nil { 351 req.Body.Close() 352 } 353 } 354 355 func validMethod(method string) bool { 356 /* 357 Method = "OPTIONS" ; Section 9.2 358 | "GET" ; Section 9.3 359 | "HEAD" ; Section 9.4 360 | "POST" ; Section 9.5 361 | "PUT" ; Section 9.6 362 | "DELETE" ; Section 9.7 363 | "TRACE" ; Section 9.8 364 | "CONNECT" ; Section 9.9 365 | extension-method 366 extension-method = token 367 token = 1*<any CHAR except CTLs or separators> 368 */ 369 return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 370 } 371 372 // copied from net/http/http.go 373 func isNotToken(r rune) bool { 374 return !httpguts.IsTokenRune(r) 375 } 376 377 func (r *RoundTripper) CloseIdleConnections() { 378 r.mutex.Lock() 379 defer r.mutex.Unlock() 380 for hostname, cl := range r.clients { 381 if cl.useCount.Load() == 0 { 382 cl.Close() 383 delete(r.clients, hostname) 384 } 385 } 386 }