github.com/MerlinKodo/quic-go@v0.39.2/http3/roundtrip_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "io" 9 "net/http" 10 "sync/atomic" 11 "time" 12 13 "github.com/MerlinKodo/quic-go" 14 "github.com/MerlinKodo/quic-go/internal/qerr" 15 16 . "github.com/onsi/ginkgo/v2" 17 . "github.com/onsi/gomega" 18 "go.uber.org/mock/gomock" 19 ) 20 21 type mockBody struct { 22 reader bytes.Reader 23 readErr error 24 closeErr error 25 closed bool 26 } 27 28 // make sure the mockBody can be used as a http.Request.Body 29 var _ io.ReadCloser = &mockBody{} 30 31 func (m *mockBody) Read(p []byte) (int, error) { 32 if m.readErr != nil { 33 return 0, m.readErr 34 } 35 return m.reader.Read(p) 36 } 37 38 func (m *mockBody) SetData(data []byte) { 39 m.reader = *bytes.NewReader(data) 40 } 41 42 func (m *mockBody) Close() error { 43 m.closed = true 44 return m.closeErr 45 } 46 47 var _ = Describe("RoundTripper", func() { 48 var ( 49 rt *RoundTripper 50 req *http.Request 51 ) 52 53 BeforeEach(func() { 54 rt = &RoundTripper{} 55 var err error 56 req, err = http.NewRequest("GET", "https://www.example.org/file1.html", nil) 57 Expect(err).ToNot(HaveOccurred()) 58 }) 59 60 Context("dialing hosts", func() { 61 It("creates new clients", func() { 62 testErr := errors.New("test err") 63 req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 64 Expect(err).ToNot(HaveOccurred()) 65 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 66 cl := NewMockRoundTripCloser(mockCtrl) 67 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) 68 return cl, nil 69 } 70 _, err = rt.RoundTrip(req) 71 Expect(err).To(MatchError(testErr)) 72 }) 73 74 It("uses the quic.Config, if provided", func() { 75 config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} 76 var receivedConfig *quic.Config 77 rt.Dial = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlyConnection, error) { 78 receivedConfig = config 79 return nil, errors.New("handshake error") 80 } 81 rt.QuicConfig = config 82 _, err := rt.RoundTrip(req) 83 Expect(err).To(MatchError("handshake error")) 84 Expect(receivedConfig.HandshakeIdleTimeout).To(Equal(config.HandshakeIdleTimeout)) 85 }) 86 87 It("uses the custom dialer, if provided", func() { 88 var dialed bool 89 dialer := func(_ context.Context, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 90 dialed = true 91 return nil, errors.New("handshake error") 92 } 93 rt.Dial = dialer 94 _, err := rt.RoundTrip(req) 95 Expect(err).To(MatchError("handshake error")) 96 Expect(dialed).To(BeTrue()) 97 }) 98 }) 99 100 Context("reusing clients", func() { 101 var req1, req2 *http.Request 102 103 BeforeEach(func() { 104 var err error 105 req1, err = http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil) 106 Expect(err).ToNot(HaveOccurred()) 107 req2, err = http.NewRequest("GET", "https://quic.clemente.io/file2.html", nil) 108 Expect(err).ToNot(HaveOccurred()) 109 Expect(req1.URL).ToNot(Equal(req2.URL)) 110 }) 111 112 It("reuses existing clients", func() { 113 var count int 114 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 115 count++ 116 cl := NewMockRoundTripCloser(mockCtrl) 117 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 118 return &http.Response{Request: req}, nil 119 }).Times(2) 120 cl.EXPECT().HandshakeComplete().Return(true) 121 return cl, nil 122 } 123 rsp1, err := rt.RoundTrip(req1) 124 Expect(err).ToNot(HaveOccurred()) 125 Expect(rsp1.Request.URL).To(Equal(req1.URL)) 126 rsp2, err := rt.RoundTrip(req2) 127 Expect(err).ToNot(HaveOccurred()) 128 Expect(rsp2.Request.URL).To(Equal(req2.URL)) 129 Expect(count).To(Equal(1)) 130 }) 131 132 It("immediately removes a clients when a request errored", func() { 133 testErr := errors.New("test err") 134 135 var count int 136 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 137 count++ 138 cl := NewMockRoundTripCloser(mockCtrl) 139 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, testErr) 140 return cl, nil 141 } 142 _, err := rt.RoundTrip(req1) 143 Expect(err).To(MatchError(testErr)) 144 _, err = rt.RoundTrip(req2) 145 Expect(err).To(MatchError(testErr)) 146 Expect(count).To(Equal(2)) 147 }) 148 149 It("recreates a client when a request times out", func() { 150 var reqCount int 151 cl1 := NewMockRoundTripCloser(mockCtrl) 152 cl1.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 153 reqCount++ 154 if reqCount == 1 { // the first request is successful... 155 Expect(req.URL).To(Equal(req1.URL)) 156 return &http.Response{Request: req}, nil 157 } 158 // ... after that, the connection timed out in the background 159 Expect(req.URL).To(Equal(req2.URL)) 160 return nil, &qerr.IdleTimeoutError{} 161 }).Times(2) 162 cl1.EXPECT().HandshakeComplete().Return(true) 163 cl2 := NewMockRoundTripCloser(mockCtrl) 164 cl2.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 165 return &http.Response{Request: req}, nil 166 }) 167 168 var count int 169 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 170 count++ 171 if count == 1 { 172 return cl1, nil 173 } 174 return cl2, nil 175 } 176 rsp1, err := rt.RoundTrip(req1) 177 Expect(err).ToNot(HaveOccurred()) 178 Expect(rsp1.Request.RemoteAddr).To(Equal(req1.RemoteAddr)) 179 rsp2, err := rt.RoundTrip(req2) 180 Expect(err).ToNot(HaveOccurred()) 181 Expect(rsp2.Request.RemoteAddr).To(Equal(req2.RemoteAddr)) 182 }) 183 184 It("only issues a request once, even if a timeout error occurs", func() { 185 var count int 186 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 187 count++ 188 cl := NewMockRoundTripCloser(mockCtrl) 189 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).Return(nil, &qerr.IdleTimeoutError{}) 190 return cl, nil 191 } 192 _, err := rt.RoundTrip(req1) 193 Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) 194 Expect(count).To(Equal(1)) 195 }) 196 197 It("handles a burst of requests", func() { 198 wait := make(chan struct{}) 199 reqs := make(chan struct{}, 2) 200 var count int 201 rt.newClient = func(string, *tls.Config, *roundTripperOpts, *quic.Config, dialFunc) (roundTripCloser, error) { 202 count++ 203 cl := NewMockRoundTripCloser(mockCtrl) 204 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(req *http.Request, _ RoundTripOpt) (*http.Response, error) { 205 reqs <- struct{}{} 206 <-wait 207 return nil, &qerr.IdleTimeoutError{} 208 }).Times(2) 209 cl.EXPECT().HandshakeComplete() 210 return cl, nil 211 } 212 done := make(chan struct{}, 2) 213 go func() { 214 defer GinkgoRecover() 215 defer func() { done <- struct{}{} }() 216 _, err := rt.RoundTrip(req1) 217 Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) 218 }() 219 go func() { 220 defer GinkgoRecover() 221 defer func() { done <- struct{}{} }() 222 _, err := rt.RoundTrip(req2) 223 Expect(err).To(MatchError(&qerr.IdleTimeoutError{})) 224 }() 225 // wait for both requests to be issued 226 Eventually(reqs).Should(Receive()) 227 Eventually(reqs).Should(Receive()) 228 close(wait) // now return the requests 229 Eventually(done).Should(Receive()) 230 Eventually(done).Should(Receive()) 231 Expect(count).To(Equal(1)) 232 }) 233 234 It("doesn't create new clients if RoundTripOpt.OnlyCachedConn is set", func() { 235 req, err := http.NewRequest("GET", "https://quic.clemente.io/foobar.html", nil) 236 Expect(err).ToNot(HaveOccurred()) 237 _, err = rt.RoundTripOpt(req, RoundTripOpt{OnlyCachedConn: true}) 238 Expect(err).To(MatchError(ErrNoCachedConn)) 239 }) 240 }) 241 242 Context("validating request", func() { 243 It("rejects plain HTTP requests", func() { 244 req, err := http.NewRequest("GET", "http://www.example.org/", nil) 245 req.Body = &mockBody{} 246 Expect(err).ToNot(HaveOccurred()) 247 _, err = rt.RoundTrip(req) 248 Expect(err).To(MatchError("http3: unsupported protocol scheme: http")) 249 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 250 }) 251 252 It("rejects requests without a URL", func() { 253 req.URL = nil 254 req.Body = &mockBody{} 255 _, err := rt.RoundTrip(req) 256 Expect(err).To(MatchError("http3: nil Request.URL")) 257 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 258 }) 259 260 It("rejects request without a URL Host", func() { 261 req.URL.Host = "" 262 req.Body = &mockBody{} 263 _, err := rt.RoundTrip(req) 264 Expect(err).To(MatchError("http3: no Host in request URL")) 265 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 266 }) 267 268 It("doesn't try to close the body if the request doesn't have one", func() { 269 req.URL = nil 270 Expect(req.Body).To(BeNil()) 271 _, err := rt.RoundTrip(req) 272 Expect(err).To(MatchError("http3: nil Request.URL")) 273 }) 274 275 It("rejects requests without a header", func() { 276 req.Header = nil 277 req.Body = &mockBody{} 278 _, err := rt.RoundTrip(req) 279 Expect(err).To(MatchError("http3: nil Request.Header")) 280 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 281 }) 282 283 It("rejects requests with invalid header name fields", func() { 284 req.Header.Add("foobär", "value") 285 _, err := rt.RoundTrip(req) 286 Expect(err).To(MatchError("http3: invalid http header field name \"foobär\"")) 287 }) 288 289 It("rejects requests with invalid header name values", func() { 290 req.Header.Add("foo", string([]byte{0x7})) 291 _, err := rt.RoundTrip(req) 292 Expect(err.Error()).To(ContainSubstring("http3: invalid http header field value")) 293 }) 294 295 It("rejects requests with an invalid request method", func() { 296 req.Method = "foobär" 297 req.Body = &mockBody{} 298 _, err := rt.RoundTrip(req) 299 Expect(err).To(MatchError("http3: invalid method \"foobär\"")) 300 Expect(req.Body.(*mockBody).closed).To(BeTrue()) 301 }) 302 }) 303 304 Context("closing", func() { 305 It("closes", func() { 306 rt.clients = make(map[string]*roundTripCloserWithCount) 307 cl := NewMockRoundTripCloser(mockCtrl) 308 cl.EXPECT().Close() 309 rt.clients["foo.bar"] = &roundTripCloserWithCount{cl, atomic.Int64{}} 310 err := rt.Close() 311 Expect(err).ToNot(HaveOccurred()) 312 Expect(len(rt.clients)).To(BeZero()) 313 }) 314 315 It("closes a RoundTripper that has never been used", func() { 316 Expect(len(rt.clients)).To(BeZero()) 317 err := rt.Close() 318 Expect(err).ToNot(HaveOccurred()) 319 Expect(len(rt.clients)).To(BeZero()) 320 }) 321 322 It("closes idle connections", func() { 323 Expect(len(rt.clients)).To(Equal(0)) 324 req1, err := http.NewRequest("GET", "https://site1.com", nil) 325 Expect(err).ToNot(HaveOccurred()) 326 req2, err := http.NewRequest("GET", "https://site2.com", nil) 327 Expect(err).ToNot(HaveOccurred()) 328 Expect(req1.Host).ToNot(Equal(req2.Host)) 329 ctx1, cancel1 := context.WithCancel(context.Background()) 330 ctx2, cancel2 := context.WithCancel(context.Background()) 331 req1 = req1.WithContext(ctx1) 332 req2 = req2.WithContext(ctx2) 333 roundTripCalled := make(chan struct{}) 334 reqFinished := make(chan struct{}) 335 rt.newClient = func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { 336 cl := NewMockRoundTripCloser(mockCtrl) 337 cl.EXPECT().Close() 338 cl.EXPECT().RoundTripOpt(gomock.Any(), gomock.Any()).DoAndReturn(func(r *http.Request, _ RoundTripOpt) (*http.Response, error) { 339 roundTripCalled <- struct{}{} 340 <-r.Context().Done() 341 return nil, nil 342 }) 343 return cl, nil 344 } 345 go func() { 346 rt.RoundTrip(req1) 347 reqFinished <- struct{}{} 348 }() 349 go func() { 350 rt.RoundTrip(req2) 351 reqFinished <- struct{}{} 352 }() 353 <-roundTripCalled 354 <-roundTripCalled 355 // Both two requests are started. 356 Expect(len(rt.clients)).To(Equal(2)) 357 cancel1() 358 <-reqFinished 359 // req1 is finished 360 rt.CloseIdleConnections() 361 Expect(len(rt.clients)).To(Equal(1)) 362 cancel2() 363 <-reqFinished 364 // all requests are finished 365 rt.CloseIdleConnections() 366 Expect(len(rt.clients)).To(Equal(0)) 367 }) 368 }) 369 })