github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/roundtrip_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "io" 9 "net/http" 10 "sync/atomic" 11 "time" 12 13 "github.com/daeuniverse/quic-go" 14 "github.com/daeuniverse/quic-go/internal/qerr" 15 16 . "github.com/onsi/ginkgo/v2" 17 . "github.com/onsi/gomega" 18 "go.uber.org/mock/gomock" 19 ) 20 21 type mockBody struct { 22 reader bytes.Reader 23 readErr error 24 closeErr error 25 closed bool 26 } 27 28 // make sure the mockBody can be used as a http.Request.Body 29 var _ io.ReadCloser = &mockBody{} 30 31 func (m *mockBody) Read(p []byte) (int, error) { 32 if m.readErr != nil { 33 return 0, m.readErr 34 } 35 return m.reader.Read(p) 36 } 37 38 func (m *mockBody) SetData(data []byte) { 39 m.reader = *bytes.NewReader(data) 40 } 41 42 func (m *mockBody) Close() error { 43 m.closed = true 44 return m.closeErr 45 } 46 47 var _ = Describe("RoundTripper", func() { 48 var ( 49 rt *RoundTripper 50 req *http.Request 51 ) 52 53 BeforeEach(func() { 54 rt = &RoundTripper{} 55 var err error 56 req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) 57 Expect(err).ToNot(HaveOccurred()) 58 }) 59 60 Context("dialing hosts", func() { 61 It("creates new clients", func() { 62 testErr := errors.New("test err") 63 req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 64 Expect(err).ToNot(HaveOccurred()) 65 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 66 cl := NewMockRoundTripCloser(mockCtrl) 67 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) 68 return cl, nil 69 } 70 _, err = rt.RoundTrip(req) 71 Expect(err).To(MatchError(testErr)) 72 }) 73 74 It("creates new clients with additional settings", func() { 75 testErr := errors.New("test err") 76 req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 77 Expect(err).ToNot(HaveOccurred()) 78 rt.AdditionalSettings = map[uint64]uint64{1337: 42} 79 rt.newClient = func(_ string, _ *tls.Config, opts *roundTripperOpts, conf *quic.Config, _ dialFunc) (roundTripCloser, error) { 80 cl := NewMockRoundTripCloser(mockCtrl) 81 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) 82 Expect(opts.AdditionalSettings).To(HaveKeyWithValue(uint64(1337), uint64(42))) 83 return cl, nil 84 } 85 _, err = rt.RoundTrip(req) 86 Expect(err).To(MatchError(testErr)) 87 }) 88 89 It("uses the quic.Config, if provided", func() { 90 config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} 91 var receivedConfig *quic.Config 92 rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { 93 receivedConfig = config 94 return nil, errors.New("handshake error") 95 } 96 rt.QuicConfig = config 97 _, err := rt.RoundTrip(req) 98 Expect(err).To(MatchError("handshake error")) 99 Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) 100 }) 101 102 It("requires quic.Config.EnableDatagram if HTTP datagrams are enabled", func() { 103 rt.QuicConfig = &quic.Config{EnableDatagrams: false} 104 rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { 105 return nil, errors.New("handshake error") 106 } 107 rt.EnableDatagrams = true 108 _, err := rt.RoundTrip(req) 109 Expect(err).To(MatchError("HTTP Datagrams enabled, but QUIC Datagrams disabled")) 110 rt.QuicConfig.EnableDatagrams = true 111 _, err = rt.RoundTrip(req) 112 Expect(err).To(MatchError("handshake error")) 113 }) 114 115 It("uses the custom dialer, if provided", func() { 116 var dialed bool 117 dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 118 dialed = true 119 return nil, errors.New("handshake error") 120 } 121 rt.Dial = dialer 122 _, err := rt.RoundTrip(req) 123 Expect(err).To(MatchError("handshake error")) 124 Expect(dialed).To(BeTrue()) 125 }) 126 }) 127 128 Context("reusing clients", func() { 129 var req1, req2 *http.Request 130 131 BeforeEach(func() { 132 var err error 133 req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) 134 Expect(err).ToNot(HaveOccurred()) 135 req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) 136 Expect(err).ToNot(HaveOccurred()) 137 Expect(req1.URL).ToNot(Equal(req2.URL)) 138 }) 139 140 It("reuses existing clients", func() { 141 var count int 142 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 143 count++ 144 cl := NewMockRoundTripCloser(mockCtrl) 145 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 146 return &http.Response{Request: req}, nil 147 }).Times(2) 148 cl.EXPECT().HandshakeComplete().Return(true) 149 return cl, nil 150 } 151 rsp1, err := rt.RoundTrip(req1) 152 Expect(err).ToNot(HaveOccurred()) 153 Expect(rsp1.Request.URL).To(Equal(req1.URL)) 154 rsp2, err := rt.RoundTrip(req2) 155 Expect(err).ToNot(HaveOccurred()) 156 Expect(rsp2.Request.URL).To(Equal(req2.URL)) 157 Expect(count).To(Equal(1)) 158 }) 159 160 It("immediately removes a clients when a request errored", func() { 161 testErr := errors.New("test err") 162 163 var count int 164 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 165 count++ 166 cl := NewMockRoundTripCloser(mockCtrl) 167 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) 168 return cl, nil 169 } 170 _, err := rt.RoundTrip(req1) 171 Expect(err).To(MatchError(testErr)) 172 _, err = rt.RoundTrip(req2) 173 Expect(err).To(MatchError(testErr)) 174 Expect(count).To(Equal(2)) 175 }) 176 177 It("recreates a client when a request times out", func() { 178 var reqCount int 179 cl1 := NewMockRoundTripCloser(mockCtrl) 180 cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 181 reqCount++ 182 if reqCount == 1 { // the first request is successful... 183 Expect(req.URL).To(Equal(req1.URL)) 184 return &http.Response{Request: req}, nil 185 } 186 // ... after that, the connection timed out in the background 187 Expect(req.URL).To(Equal(req2.URL)) 188 return nil, &qerr.IdleTimeoutError{} 189 }).Times(2) 190 cl1.EXPECT().HandshakeComplete().Return(true) 191 cl2 := NewMockRoundTripCloser(mockCtrl) 192 cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 193 return &http.Response{Request: req}, nil 194 }) 195 196 var count int 197 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 198 count++ 199 if count == 1 { 200 return cl1, nil 201 } 202 return cl2, nil 203 } 204 rsp1, err := rt.RoundTrip(req1) 205 Expect(err).ToNot(HaveOccurred()) 206 Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr)) 207 rsp2, err := rt.RoundTrip(req2) 208 Expect(err).ToNot(HaveOccurred()) 209 Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr)) 210 }) 211 212 It("only issues a request once, even if a timeout error occurs", func() { 213 var count int 214 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 215 count++ 216 cl := NewMockRoundTripCloser(mockCtrl) 217 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{}) 218 return cl, nil 219 } 220 _, err := rt.RoundTrip(req1) 221 Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) 222 Expect(count).To(Equal(1)) 223 }) 224 225 It("handles a burst of requests", func() { 226 wait := make(chan struct{}) 227 reqs := make(chan struct{}, 2) 228 var count int 229 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 230 count++ 231 cl := NewMockRoundTripCloser(mockCtrl) 232 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 233 reqs <- struct{}{} 234 <-wait 235 return nil, &qerr.IdleTimeoutError{} 236 }).Times(2) 237 cl.EXPECT().HandshakeComplete() 238 return cl, nil 239 } 240 done := make(chan struct{}, 2) 241 go func() { 242 defer GinkgoRecover() 243 defer func() { done <- struct{}{} }() 244 _, err := rt.RoundTrip(req1) 245 Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) 246 }() 247 go func() { 248 defer GinkgoRecover() 249 defer func() { done <- struct{}{} }() 250 _, err := rt.RoundTrip(req2) 251 Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) 252 }() 253 // wait for both requests to be issued 254 Eventually(reqs).Should(Receive()) 255 Eventually(reqs).Should(Receive()) 256 close(wait) // now return the requests 257 Eventually(done).Should(Receive()) 258 Eventually(done).Should(Receive()) 259 Expect(count).To(Equal(1)) 260 }) 261 262 It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { 263 req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 264 Expect(err).ToNot(HaveOccurred()) 265 _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) 266 Expect(err).To(MatchError(ErrNoCachedConn)) 267 }) 268 }) 269 270 Context("validating request", func() { 271 It("rejects plain HTTP requests", func() { 272 req, err := http.NewRequest("GET", "http://www.example.org/", nil) 273 req.Body = &mockBody{} 274 Expect(err).ToNot(HaveOccurred()) 275 _, err = rt.RoundTrip(req) 276 Expect(err).To(MatchError("http3: unsupported protocol scheme: http")) 277 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 278 }) 279 280 It("rejects requests without a URL", func() { 281 req.URL = nil 282 req.Body = &mockBody{} 283 _, err := rt.RoundTrip(req) 284 Expect(err).To(MatchError("http3: nil Request.URL")) 285 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 286 }) 287 288 It("rejects request without a URL Host", func() { 289 req.URL.Host = "" 290 req.Body = &mockBody{} 291 _, err := rt.RoundTrip(req) 292 Expect(err).To(MatchError("http3: no Host in request URL")) 293 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 294 }) 295 296 It("doesn't try to close the body if the request doesn't have one", func() { 297 req.URL = nil 298 Expect(req.Body).To(BeNil()) 299 _, err := rt.RoundTrip(req) 300 Expect(err).To(MatchError("http3: nil Request.URL")) 301 }) 302 303 It("rejects requests without a header", func() { 304 req.Header = nil 305 req.Body = &mockBody{} 306 _, err := rt.RoundTrip(req) 307 Expect(err).To(MatchError("http3: nil Request.Header")) 308 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 309 }) 310 311 It("rejects requests with invalid header name fields", func() { 312 req.Header.Add("foobär", "value") 313 _, err := rt.RoundTrip(req) 314 Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) 315 }) 316 317 It("rejects requests with invalid header name values", func() { 318 req.Header.Add("foo", string([]byte{0x7})) 319 _, err := rt.RoundTrip(req) 320 Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) 321 }) 322 323 It("rejects requests with an invalid request method", func() { 324 req.Method = "foobär" 325 req.Body = &mockBody{} 326 _, err := rt.RoundTrip(req) 327 Expect(err).To(MatchError("http3: invalid method \"foobär\"")) 328 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 329 }) 330 }) 331 332 Context("closing", func() { 333 It("closes", func() { 334 rt.clients = make(map[string]*roundTripCloserWithCount) 335 cl := NewMockRoundTripCloser(mockCtrl) 336 cl.EXPECT().Close() 337 rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}} 338 err := rt.Close() 339 Expect(err).ToNot(HaveOccurred()) 340 Expect(len(rt.clients)).To(BeZero()) 341 }) 342 343 It("closes a RoundTripper that has never been used", func() { 344 Expect(len(rt.clients)).To(BeZero()) 345 err := rt.Close() 346 Expect(err).ToNot(HaveOccurred()) 347 Expect(len(rt.clients)).To(BeZero()) 348 }) 349 350 It("closes idle connections", func() { 351 Expect(len(rt.clients)).To(Equal(0)) 352 req1, err := http.NewRequest("GET", "https://site1.com", nil) 353 Expect(err).ToNot(HaveOccurred()) 354 req2, err := http.NewRequest("GET", "https://site2.com", nil) 355 Expect(err).ToNot(HaveOccurred()) 356 Expect(req1.Host).ToNot(Equal(req2.Host)) 357 ctx1, cancel1 := context.WithCancel(context.Background()) 358 ctx2, cancel2 := context.WithCancel(context.Background()) 359 req1 = req1.WithContext(ctx1) 360 req2 = req2.WithContext(ctx2) 361 roundTripCalled := make(chan struct{}) 362 reqFinished := make(chan struct{}) 363 rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { 364 cl := NewMockRoundTripCloser(mockCtrl) 365 cl.EXPECT().Close() 366 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) { 367 roundTripCalled <- struct{}{} 368 <-r.Context().Done() 369 return nil, nil 370 }) 371 return cl, nil 372 } 373 go func() { 374 rt.RoundTrip(req1) 375 reqFinished <- struct{}{} 376 }() 377 go func() { 378 rt.RoundTrip(req2) 379 reqFinished <- struct{}{} 380 }() 381 <-roundTripCalled 382 <-roundTripCalled 383 // Both two requests are started. 384 Expect(len(rt.clients)).To(Equal(2)) 385 cancel1() 386 <-reqFinished 387 // req1 is finished 388 rt.CloseIdleConnections() 389 Expect(len(rt.clients)).To(Equal(1)) 390 cancel2() 391 <-reqFinished 392 // all requests are finished 393 rt.CloseIdleConnections() 394 Expect(len(rt.clients)).To(Equal(0)) 395 }) 396 }) 397 })