github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/roundtrip.go (about) 1 package http3 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "net/http" 11 "strings" 12 "sync" 13 "sync/atomic" 14 15 "golang.org/x/net/http/httpguts" 16 17 "github.com/daeuniverse/quic-go" 18 ) 19 20 // Settings are HTTP/3 settings that apply to the underlying connection. 21 type Settings struct { 22 // Support for HTTP/3 datagrams (RFC 9297) 23 EnableDatagram bool 24 // Extended CONNECT, RFC 9220 25 EnableExtendedConnect bool 26 // Other settings, defined by the application 27 Other map[uint64]uint64 28 } 29 30 // RoundTripOpt are options for the Transport.RoundTripOpt method. 31 type RoundTripOpt struct { 32 // OnlyCachedConn controls whether the RoundTripper may create a new QUIC connection. 33 // If set true and no cached connection is available, RoundTripOpt will return ErrNoCachedConn. 34 OnlyCachedConn bool 35 // DontCloseRequestStream controls whether the request stream is closed after sending the request. 36 // If set, context cancellations have no effect after the response headers are received. 37 DontCloseRequestStream bool 38 // CheckSettings is run before the request is sent to the server. 39 // If not yet received, it blocks until the server's SETTINGS frame is received. 40 // If an error is returned, the request won't be sent to the server, and the error is returned. 41 CheckSettings func(Settings) error 42 } 43 44 type roundTripCloser interface { 45 RoundTripOpt(*http.Request, RoundTripOpt) (*http.Response, error) 46 HandshakeComplete() bool 47 io.Closer 48 } 49 50 type roundTripCloserWithCount struct { 51 roundTripCloser 52 useCount atomic.Int64 53 } 54 55 // RoundTripper implements the http.RoundTripper interface 56 type RoundTripper struct { 57 mutex sync.Mutex 58 59 // DisableCompression, if true, prevents the Transport from 60 // requesting compression with an "Accept-Encoding: gzip" 61 // request header when the Request contains no existing 62 // Accept-Encoding value. If the Transport requests gzip on 63 // its own and gets a gzipped response, it's transparently 64 // decoded in the Response.Body. However, if the user 65 // explicitly requested gzip it is not automatically 66 // uncompressed. 67 DisableCompression bool 68 69 // TLSClientConfig specifies the TLS configuration to use with 70 // tls.Client. If nil, the default configuration is used. 71 TLSClientConfig *tls.Config 72 73 // QuicConfig is the quic.Config used for dialing new connections. 74 // If nil, reasonable default values will be used. 75 QuicConfig *quic.Config 76 77 // Enable support for HTTP/3 datagrams (RFC 9297). 78 // If a QuicConfig is set, datagram support also needs to be enabled on the QUIC layer by setting EnableDatagrams. 79 EnableDatagrams bool 80 81 // Additional HTTP/3 settings. 82 // It is invalid to specify any settings defined by the HTTP/3 draft and the datagram draft. 83 AdditionalSettings map[uint64]uint64 84 85 // When set, this callback is called for the first unknown frame parsed on a bidirectional stream. 86 // It is called right after parsing the frame type. 87 // If parsing the frame type fails, the error is passed to the callback. 88 // In that case, the frame type will not be set. 89 // Callers can either ignore the frame and return control of the stream back to HTTP/3 90 // (by returning hijacked false). 91 // Alternatively, callers can take over the QUIC stream (by returning hijacked true). 92 StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error) 93 94 // When set, this callback is called for unknown unidirectional stream of unknown stream type. 95 // If parsing the stream type fails, the error is passed to the callback. 96 // In that case, the stream type will not be set. 97 UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool) 98 99 // Dial specifies an optional dial function for creating QUIC 100 // connections for requests. 101 // If Dial is nil, a UDPConn will be created at the first request 102 // and will be reused for subsequent connections to other servers. 103 Dial func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) 104 105 // MaxResponseHeaderBytes specifies a limit on how many response bytes are 106 // allowed in the server's response header. 107 // Zero means to use a default limit. 108 MaxResponseHeaderBytes int64 109 110 newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests 111 clients map[string]*roundTripCloserWithCount 112 transport *quic.Transport 113 } 114 115 var ( 116 _ http.RoundTripper = &RoundTripper{} 117 _ io.Closer = &RoundTripper{} 118 ) 119 120 // ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set 121 var ErrNoCachedConn = errors.New("http3: no cached connection was available") 122 123 // RoundTripOpt is like RoundTrip, but takes options. 124 func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { 125 if req.URL == nil { 126 closeRequestBody(req) 127 return nil, errors.New("http3: nil Request.URL") 128 } 129 if req.URL.Scheme != "https" { 130 closeRequestBody(req) 131 return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) 132 } 133 if req.URL.Host == "" { 134 closeRequestBody(req) 135 return nil, errors.New("http3: no Host in request URL") 136 } 137 if req.Header == nil { 138 closeRequestBody(req) 139 return nil, errors.New("http3: nil Request.Header") 140 } 141 for k, vv := range req.Header { 142 if !httpguts.ValidHeaderFieldName(k) { 143 return nil, fmt.Errorf("http3: invalid http header field name %q", k) 144 } 145 for _, v := range vv { 146 if !httpguts.ValidHeaderFieldValue(v) { 147 return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) 148 } 149 } 150 } 151 152 if req.Method != "" && !validMethod(req.Method) { 153 closeRequestBody(req) 154 return nil, fmt.Errorf("http3: invalid method %q", req.Method) 155 } 156 157 hostname := authorityAddr("https", hostnameFromRequest(req)) 158 cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) 159 if err != nil { 160 return nil, err 161 } 162 defer cl.useCount.Add(-1) 163 rsp, err := cl.RoundTripOpt(req, opt) 164 if err != nil { 165 r.removeClient(hostname) 166 if isReused { 167 if nerr, ok := err.(net.Error); ok && nerr.Timeout() { 168 return r.RoundTripOpt(req, opt) 169 } 170 } 171 } 172 return rsp, err 173 } 174 175 // RoundTrip does a round trip. 176 func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 177 return r.RoundTripOpt(req, RoundTripOpt{}) 178 } 179 180 func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { 181 r.mutex.Lock() 182 defer r.mutex.Unlock() 183 184 if r.clients == nil { 185 r.clients = make(map[string]*roundTripCloserWithCount) 186 } 187 188 client, ok := r.clients[hostname] 189 if !ok { 190 if onlyCached { 191 return nil, false, ErrNoCachedConn 192 } 193 var err error 194 newCl := newClient 195 if r.newClient != nil { 196 newCl = r.newClient 197 } 198 dial := r.Dial 199 if dial == nil { 200 if r.transport == nil { 201 udpConn, err := net.ListenUDP("udp", nil) 202 if err != nil { 203 return nil, false, err 204 } 205 r.transport = &quic.Transport{Conn: udpConn} 206 } 207 dial = r.makeDialer() 208 } 209 c, err := newCl( 210 hostname, 211 r.TLSClientConfig, 212 &roundTripperOpts{ 213 EnableDatagram: r.EnableDatagrams, 214 DisableCompression: r.DisableCompression, 215 MaxHeaderBytes: r.MaxResponseHeaderBytes, 216 StreamHijacker: r.StreamHijacker, 217 UniStreamHijacker: r.UniStreamHijacker, 218 AdditionalSettings: r.AdditionalSettings, 219 }, 220 r.QuicConfig, 221 dial, 222 ) 223 if err != nil { 224 return nil, false, err 225 } 226 client = &roundTripCloserWithCount{roundTripCloser: c} 227 r.clients[hostname] = client 228 } else if client.HandshakeComplete() { 229 isReused = true 230 } 231 client.useCount.Add(1) 232 return client, isReused, nil 233 } 234 235 func (r *RoundTripper) removeClient(hostname string) { 236 r.mutex.Lock() 237 defer r.mutex.Unlock() 238 if r.clients == nil { 239 return 240 } 241 delete(r.clients, hostname) 242 } 243 244 // Close closes the QUIC connections that this RoundTripper has used. 245 // It also closes the underlying UDPConn if it is not nil. 246 func (r *RoundTripper) Close() error { 247 r.mutex.Lock() 248 defer r.mutex.Unlock() 249 for _, client := range r.clients { 250 if err := client.Close(); err != nil { 251 return err 252 } 253 } 254 r.clients = nil 255 if r.transport != nil { 256 if err := r.transport.Close(); err != nil { 257 return err 258 } 259 if err := r.transport.Conn.Close(); err != nil { 260 return err 261 } 262 r.transport = nil 263 } 264 return nil 265 } 266 267 func closeRequestBody(req *http.Request) { 268 if req.Body != nil { 269 req.Body.Close() 270 } 271 } 272 273 func validMethod(method string) bool { 274 /* 275 Method = "OPTIONS" ; Section 9.2 276 | "GET" ; Section 9.3 277 | "HEAD" ; Section 9.4 278 | "POST" ; Section 9.5 279 | "PUT" ; Section 9.6 280 | "DELETE" ; Section 9.7 281 | "TRACE" ; Section 9.8 282 | "CONNECT" ; Section 9.9 283 | extension-method 284 extension-method = token 285 token = 1*<any CHAR except CTLs or separators> 286 */ 287 return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 288 } 289 290 // copied from net/http/http.go 291 func isNotToken(r rune) bool { 292 return !httpguts.IsTokenRune(r) 293 } 294 295 // makeDialer makes a QUIC dialer using r.udpConn. 296 func (r *RoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 297 return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 298 udpAddr, err := net.ResolveUDPAddr("udp", addr) 299 if err != nil { 300 return nil, err 301 } 302 return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) 303 } 304 } 305 306 func (r *RoundTripper) CloseIdleConnections() { 307 r.mutex.Lock() 308 defer r.mutex.Unlock() 309 for hostname, client := range r.clients { 310 if client.useCount.Load() == 0 { 311 client.Close() 312 delete(r.clients, hostname) 313 } 314 } 315 }