github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "crypto/tls" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/danielpfeifer02/quic-go-prio-packs" 16 mockquic "github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks/quic" 17 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 18 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr" 19 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 20 "github.com/danielpfeifer02/quic-go-prio-packs/quicvarint" 21 22 "github.com/quic-go/qpack" 23 24 . "github.com/onsi/ginkgo/v2" 25 . "github.com/onsi/gomega" 26 "go.uber.org/mock/gomock" 27 ) 28 29 var _ = Describe("Client", func() { 30 var ( 31 cl *client 32 req *http.Request 33 origDialAddr = dialAddr 34 handshakeChan <-chan struct{} // a closed chan 35 ) 36 37 BeforeEach(func() { 38 origDialAddr = dialAddr 39 hostname := "quic.clemente.io:1337" 40 c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) 41 Expect(err).ToNot(HaveOccurred()) 42 cl = c.(*client) 43 Expect(cl.hostname).To(Equal(hostname)) 44 45 req, err = http.NewRequest("GET", "https://localhost:1337", nil) 46 Expect(err).ToNot(HaveOccurred()) 47 48 ch := make(chan struct{}) 49 close(ch) 50 handshakeChan = ch 51 }) 52 53 AfterEach(func() { 54 dialAddr = origDialAddr 55 }) 56 57 It("rejects quic.Configs that allow multiple QUIC versions", func() { 58 qconf := &quic.Config{ 59 Versions: []quic.Version{protocol.Version2, protocol.Version1}, 60 } 61 _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) 62 Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) 63 }) 64 65 It("uses the default QUIC and TLS config if none is give", func() { 66 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 67 Expect(err).ToNot(HaveOccurred()) 68 var dialAddrCalled bool 69 dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 70 Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) 71 Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) 72 Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) 73 dialAddrCalled = true 74 return nil, errors.New("test done") 75 } 76 client.RoundTripOpt(req, RoundTripOpt{}) 77 Expect(dialAddrCalled).To(BeTrue()) 78 }) 79 80 It("adds the port to the hostname, if none is given", func() { 81 client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) 82 Expect(err).ToNot(HaveOccurred()) 83 var dialAddrCalled bool 84 dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { 85 Expect(hostname).To(Equal("quic.clemente.io:443")) 86 dialAddrCalled = true 87 return nil, errors.New("test done") 88 } 89 req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) 90 Expect(err).ToNot(HaveOccurred()) 91 client.RoundTripOpt(req, RoundTripOpt{}) 92 Expect(dialAddrCalled).To(BeTrue()) 93 }) 94 95 It("sets the ServerName in the tls.Config, if not set", func() { 96 const host = "foo.bar" 97 dialCalled := false 98 dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 99 Expect(tlsCfg.ServerName).To(Equal(host)) 100 dialCalled = true 101 return nil, errors.New("test done") 102 } 103 client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) 104 Expect(err).ToNot(HaveOccurred()) 105 req, err := http.NewRequest("GET", "https://foo.bar", nil) 106 Expect(err).ToNot(HaveOccurred()) 107 client.RoundTripOpt(req, RoundTripOpt{}) 108 Expect(dialCalled).To(BeTrue()) 109 }) 110 111 It("uses the TLS config and QUIC config", func() { 112 tlsConf := &tls.Config{ 113 ServerName: "foo.bar", 114 NextProtos: []string{"proto foo", "proto bar"}, 115 } 116 quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} 117 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) 118 Expect(err).ToNot(HaveOccurred()) 119 var dialAddrCalled bool 120 dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 121 Expect(host).To(Equal("localhost:1337")) 122 Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) 123 Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) 124 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 125 dialAddrCalled = true 126 return nil, errors.New("test done") 127 } 128 client.RoundTripOpt(req, RoundTripOpt{}) 129 Expect(dialAddrCalled).To(BeTrue()) 130 // make sure the original tls.Config was not modified 131 Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) 132 }) 133 134 It("uses the custom dialer, if provided", func() { 135 testErr := errors.New("test done") 136 tlsConf := &tls.Config{ServerName: "foo.bar"} 137 quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} 138 ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 139 defer cancel() 140 var dialerCalled bool 141 dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 142 Expect(ctxP).To(Equal(ctx)) 143 Expect(address).To(Equal("localhost:1337")) 144 Expect(tlsConfP.ServerName).To(Equal("foo.bar")) 145 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 146 dialerCalled = true 147 return nil, testErr 148 } 149 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) 150 Expect(err).ToNot(HaveOccurred()) 151 _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) 152 Expect(err).To(MatchError(testErr)) 153 Expect(dialerCalled).To(BeTrue()) 154 }) 155 156 It("enables HTTP/3 Datagrams", func() { 157 testErr := errors.New("handshake error") 158 client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) 159 Expect(err).ToNot(HaveOccurred()) 160 dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 161 Expect(quicConf.EnableDatagrams).To(BeTrue()) 162 return nil, testErr 163 } 164 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 165 Expect(err).To(MatchError(testErr)) 166 }) 167 168 It("errors when dialing fails", func() { 169 testErr := errors.New("handshake error") 170 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 171 Expect(err).ToNot(HaveOccurred()) 172 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 173 return nil, testErr 174 } 175 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 176 Expect(err).To(MatchError(testErr)) 177 }) 178 179 It("closes correctly if connection was not created", func() { 180 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 181 Expect(err).ToNot(HaveOccurred()) 182 Expect(client.Close()).To(Succeed()) 183 }) 184 185 Context("validating the address", func() { 186 It("refuses to do requests for the wrong host", func() { 187 req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) 188 Expect(err).ToNot(HaveOccurred()) 189 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 190 Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) 191 }) 192 193 It("allows requests using a different scheme", func() { 194 testErr := errors.New("handshake error") 195 req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) 196 Expect(err).ToNot(HaveOccurred()) 197 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 198 return nil, testErr 199 } 200 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 201 Expect(err).To(MatchError(testErr)) 202 }) 203 }) 204 205 Context("hijacking bidirectional streams", func() { 206 var ( 207 request *http.Request 208 conn *mockquic.MockEarlyConnection 209 settingsFrameWritten chan struct{} 210 ) 211 testDone := make(chan struct{}) 212 213 BeforeEach(func() { 214 testDone = make(chan struct{}) 215 settingsFrameWritten = make(chan struct{}) 216 controlStr := mockquic.NewMockStream(mockCtrl) 217 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 218 defer GinkgoRecover() 219 close(settingsFrameWritten) 220 return len(b), nil 221 }) 222 conn = mockquic.NewMockEarlyConnection(mockCtrl) 223 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 224 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 225 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 226 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 227 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 228 return conn, nil 229 } 230 var err error 231 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 232 Expect(err).ToNot(HaveOccurred()) 233 }) 234 235 AfterEach(func() { 236 testDone <- struct{}{} 237 Eventually(settingsFrameWritten).Should(BeClosed()) 238 }) 239 240 It("hijacks a bidirectional stream of unknown frame type", func() { 241 frameTypeChan := make(chan FrameType, 1) 242 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 243 Expect(e).ToNot(HaveOccurred()) 244 frameTypeChan <- ft 245 return true, nil 246 } 247 248 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 249 unknownStr := mockquic.NewMockStream(mockCtrl) 250 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 251 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 252 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 253 <-testDone 254 return nil, errors.New("test done") 255 }) 256 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 257 Expect(err).To(MatchError("done")) 258 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 259 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 260 }) 261 262 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 263 frameTypeChan := make(chan FrameType, 1) 264 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 265 Expect(e).ToNot(HaveOccurred()) 266 frameTypeChan <- ft 267 return false, nil 268 } 269 270 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 271 unknownStr := mockquic.NewMockStream(mockCtrl) 272 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 273 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 274 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 275 <-testDone 276 return nil, errors.New("test done") 277 }) 278 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 279 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 280 Expect(err).To(MatchError("done")) 281 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 282 }) 283 284 It("closes the connection when hijacker returned error", func() { 285 frameTypeChan := make(chan FrameType, 1) 286 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 287 Expect(e).ToNot(HaveOccurred()) 288 frameTypeChan <- ft 289 return false, errors.New("error in hijacker") 290 } 291 292 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 293 unknownStr := mockquic.NewMockStream(mockCtrl) 294 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 295 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 296 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 297 <-testDone 298 return nil, errors.New("test done") 299 }) 300 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 301 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 302 Expect(err).To(MatchError("done")) 303 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 304 }) 305 306 It("handles errors that occur when reading the frame type", func() { 307 testErr := errors.New("test error") 308 unknownStr := mockquic.NewMockStream(mockCtrl) 309 done := make(chan struct{}) 310 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { 311 defer close(done) 312 Expect(e).To(MatchError(testErr)) 313 Expect(ft).To(BeZero()) 314 Expect(str).To(Equal(unknownStr)) 315 return false, nil 316 } 317 318 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 319 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 320 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 321 <-testDone 322 return nil, errors.New("test done") 323 }) 324 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 325 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 326 Expect(err).To(MatchError("done")) 327 Eventually(done).Should(BeClosed()) 328 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 329 }) 330 }) 331 332 Context("hijacking unidirectional streams", func() { 333 var ( 334 req *http.Request 335 conn *mockquic.MockEarlyConnection 336 settingsFrameWritten chan struct{} 337 ) 338 testDone := make(chan struct{}) 339 340 BeforeEach(func() { 341 testDone = make(chan struct{}) 342 settingsFrameWritten = make(chan struct{}) 343 controlStr := mockquic.NewMockStream(mockCtrl) 344 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 345 defer GinkgoRecover() 346 close(settingsFrameWritten) 347 return len(b), nil 348 }) 349 conn = mockquic.NewMockEarlyConnection(mockCtrl) 350 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 351 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 352 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 353 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 354 return conn, nil 355 } 356 var err error 357 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 358 Expect(err).ToNot(HaveOccurred()) 359 }) 360 361 AfterEach(func() { 362 testDone <- struct{}{} 363 Eventually(settingsFrameWritten).Should(BeClosed()) 364 }) 365 366 It("hijacks an unidirectional stream of unknown stream type", func() { 367 streamTypeChan := make(chan StreamType, 1) 368 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 369 Expect(err).ToNot(HaveOccurred()) 370 streamTypeChan <- st 371 return true 372 } 373 374 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 375 unknownStr := mockquic.NewMockStream(mockCtrl) 376 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 377 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 378 return unknownStr, nil 379 }) 380 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 381 <-testDone 382 return nil, errors.New("test done") 383 }) 384 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 385 Expect(err).To(MatchError("done")) 386 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 387 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 388 }) 389 390 It("handles errors that occur when reading the stream type", func() { 391 testErr := errors.New("test error") 392 done := make(chan struct{}) 393 unknownStr := mockquic.NewMockStream(mockCtrl) 394 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 395 defer close(done) 396 Expect(st).To(BeZero()) 397 Expect(str).To(Equal(unknownStr)) 398 Expect(err).To(MatchError(testErr)) 399 return true 400 } 401 402 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 403 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 404 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 405 <-testDone 406 return nil, errors.New("test done") 407 }) 408 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 409 Expect(err).To(MatchError("done")) 410 Eventually(done).Should(BeClosed()) 411 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 412 }) 413 414 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 415 streamTypeChan := make(chan StreamType, 1) 416 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 417 Expect(err).ToNot(HaveOccurred()) 418 streamTypeChan <- st 419 return false 420 } 421 422 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 423 unknownStr := mockquic.NewMockStream(mockCtrl) 424 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 425 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 426 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 427 return unknownStr, nil 428 }) 429 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 430 <-testDone 431 return nil, errors.New("test done") 432 }) 433 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 434 Expect(err).To(MatchError("done")) 435 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 436 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 437 }) 438 }) 439 440 Context("control stream handling", func() { 441 var ( 442 req *http.Request 443 conn *mockquic.MockEarlyConnection 444 settingsFrameWritten chan struct{} 445 ) 446 testDone := make(chan struct{}, 1) 447 448 BeforeEach(func() { 449 settingsFrameWritten = make(chan struct{}) 450 controlStr := mockquic.NewMockStream(mockCtrl) 451 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 452 defer GinkgoRecover() 453 close(settingsFrameWritten) 454 return len(b), nil 455 }) 456 conn = mockquic.NewMockEarlyConnection(mockCtrl) 457 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 458 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 459 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 460 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 461 return conn, nil 462 } 463 var err error 464 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 465 Expect(err).ToNot(HaveOccurred()) 466 }) 467 468 AfterEach(func() { 469 testDone <- struct{}{} 470 Eventually(settingsFrameWritten).Should(BeClosed()) 471 }) 472 473 It("parses the SETTINGS frame", func() { 474 b := quicvarint.Append(nil, streamTypeControlStream) 475 b = (&settingsFrame{}).Append(b) 476 r := bytes.NewReader(b) 477 controlStr := mockquic.NewMockStream(mockCtrl) 478 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 479 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 480 return controlStr, nil 481 }) 482 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 483 <-testDone 484 return nil, errors.New("test done") 485 }) 486 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 487 Expect(err).To(MatchError("done")) 488 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 489 }) 490 491 It("rejects duplicate control streams", func() { 492 b := quicvarint.Append(nil, streamTypeControlStream) 493 b = (&settingsFrame{}).Append(b) 494 r1 := bytes.NewReader(b) 495 controlStr1 := mockquic.NewMockStream(mockCtrl) 496 controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() 497 r2 := bytes.NewReader(b) 498 controlStr2 := mockquic.NewMockStream(mockCtrl) 499 controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() 500 done := make(chan struct{}) 501 conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { 502 close(done) 503 return nil 504 }) 505 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 506 return controlStr1, nil 507 }) 508 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 509 return controlStr2, nil 510 }) 511 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 512 <-done 513 return nil, errors.New("test done") 514 }) 515 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 516 Expect(err).To(HaveOccurred()) 517 Eventually(done).Should(BeClosed()) 518 }) 519 520 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 521 streamType := t 522 name := "encoder" 523 if streamType == streamTypeQPACKDecoderStream { 524 name = "decoder" 525 } 526 527 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 528 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 529 str := mockquic.NewMockStream(mockCtrl) 530 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 531 532 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 533 return str, nil 534 }) 535 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 536 <-testDone 537 return nil, errors.New("test done") 538 }) 539 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 540 Expect(err).To(MatchError("done")) 541 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 542 }) 543 } 544 545 It("resets streams Other than the control stream and the QPACK streams", func() { 546 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) 547 str := mockquic.NewMockStream(mockCtrl) 548 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 549 done := make(chan struct{}) 550 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) 551 552 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 553 return str, nil 554 }) 555 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 556 <-testDone 557 return nil, errors.New("test done") 558 }) 559 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 560 Expect(err).To(MatchError("done")) 561 Eventually(done).Should(BeClosed()) 562 }) 563 564 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 565 b := quicvarint.Append(nil, streamTypeControlStream) 566 b = (&dataFrame{}).Append(b) 567 r := bytes.NewReader(b) 568 controlStr := mockquic.NewMockStream(mockCtrl) 569 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 570 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 571 return controlStr, nil 572 }) 573 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 574 <-testDone 575 return nil, errors.New("test done") 576 }) 577 done := make(chan struct{}) 578 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 579 close(done) 580 return nil 581 }) 582 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 583 Expect(err).To(MatchError("done")) 584 Eventually(done).Should(BeClosed()) 585 }) 586 587 It("errors when parsing the frame on the control stream fails", func() { 588 b := quicvarint.Append(nil, streamTypeControlStream) 589 b = (&settingsFrame{}).Append(b) 590 r := bytes.NewReader(b[:len(b)-1]) 591 controlStr := mockquic.NewMockStream(mockCtrl) 592 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 593 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 594 return controlStr, nil 595 }) 596 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 597 <-testDone 598 return nil, errors.New("test done") 599 }) 600 done := make(chan struct{}) 601 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error { 602 close(done) 603 return nil 604 }) 605 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 606 Expect(err).To(MatchError("done")) 607 Eventually(done).Should(BeClosed()) 608 }) 609 610 It("errors when parsing the server opens a push stream", func() { 611 buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) 612 controlStr := mockquic.NewMockStream(mockCtrl) 613 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 614 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 615 return controlStr, nil 616 }) 617 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 618 <-testDone 619 return nil, errors.New("test done") 620 }) 621 done := make(chan struct{}) 622 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 623 close(done) 624 return nil 625 }) 626 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 627 Expect(err).To(MatchError("done")) 628 Eventually(done).Should(BeClosed()) 629 }) 630 631 It("errors when the server advertises datagram support (and we enabled support for it)", func() { 632 cl.opts.EnableDatagram = true 633 b := quicvarint.Append(nil, streamTypeControlStream) 634 b = (&settingsFrame{Datagram: true}).Append(b) 635 r := bytes.NewReader(b) 636 controlStr := mockquic.NewMockStream(mockCtrl) 637 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 638 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 639 return controlStr, nil 640 }) 641 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 642 <-testDone 643 return nil, errors.New("test done") 644 }) 645 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 646 done := make(chan struct{}) 647 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { 648 close(done) 649 return nil 650 }) 651 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 652 Expect(err).To(MatchError("done")) 653 Eventually(done).Should(BeClosed()) 654 }) 655 }) 656 657 Context("Doing requests", func() { 658 var ( 659 req *http.Request 660 str *mockquic.MockStream 661 conn *mockquic.MockEarlyConnection 662 settingsFrameWritten chan struct{} 663 ) 664 testDone := make(chan struct{}) 665 666 decodeHeader := func(str io.Reader) map[string]string { 667 fields := make(map[string]string) 668 decoder := qpack.NewDecoder(nil) 669 670 frame, err := parseNextFrame(str, nil) 671 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 672 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 673 headersFrame := frame.(*headersFrame) 674 data := make([]byte, headersFrame.Length) 675 _, err = io.ReadFull(str, data) 676 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 677 hfs, err := decoder.DecodeFull(data) 678 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 679 for _, p := range hfs { 680 fields[p.Name] = p.Value 681 } 682 return fields 683 } 684 685 getResponse := func(status int) []byte { 686 buf := &bytes.Buffer{} 687 rstr := mockquic.NewMockStream(mockCtrl) 688 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 689 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 690 rw.WriteHeader(status) 691 rw.Flush() 692 return buf.Bytes() 693 } 694 695 BeforeEach(func() { 696 settingsFrameWritten = make(chan struct{}) 697 controlStr := mockquic.NewMockStream(mockCtrl) 698 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 699 defer GinkgoRecover() 700 r := bytes.NewReader(b) 701 streamType, err := quicvarint.Read(r) 702 Expect(err).ToNot(HaveOccurred()) 703 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 704 close(settingsFrameWritten) 705 return len(b), nil 706 }) // SETTINGS frame 707 str = mockquic.NewMockStream(mockCtrl) 708 conn = mockquic.NewMockEarlyConnection(mockCtrl) 709 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 710 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 711 <-testDone 712 return nil, errors.New("test done") 713 }) 714 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 715 return conn, nil 716 } 717 var err error 718 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 719 Expect(err).ToNot(HaveOccurred()) 720 }) 721 722 AfterEach(func() { 723 testDone <- struct{}{} 724 Eventually(settingsFrameWritten).Should(BeClosed()) 725 }) 726 727 It("errors if it can't open a stream", func() { 728 testErr := errors.New("stream open error") 729 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 730 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 731 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 732 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 733 Expect(err).To(MatchError(testErr)) 734 }) 735 736 It("performs a 0-RTT request", func() { 737 testErr := errors.New("stream open error") 738 req.Method = MethodGet0RTT 739 // don't EXPECT any calls to HandshakeComplete() 740 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 741 buf := &bytes.Buffer{} 742 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 743 str.EXPECT().Close() 744 str.EXPECT().CancelWrite(gomock.Any()) 745 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 746 return 0, testErr 747 }) 748 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 749 Expect(err).To(MatchError(testErr)) 750 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) 751 }) 752 753 It("returns a response", func() { 754 rspBuf := bytes.NewBuffer(getResponse(418)) 755 gomock.InOrder( 756 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 757 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 758 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 759 ) 760 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 761 str.EXPECT().Close() 762 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 763 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 764 Expect(err).ToNot(HaveOccurred()) 765 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 766 Expect(rsp.ProtoMajor).To(Equal(3)) 767 Expect(rsp.StatusCode).To(Equal(418)) 768 Expect(rsp.Request).ToNot(BeNil()) 769 }) 770 771 It("doesn't close the request stream, with DontCloseRequestStream set", func() { 772 rspBuf := bytes.NewBuffer(getResponse(418)) 773 gomock.InOrder( 774 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 775 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 776 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 777 ) 778 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 779 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 780 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) 781 Expect(err).ToNot(HaveOccurred()) 782 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 783 Expect(rsp.ProtoMajor).To(Equal(3)) 784 Expect(rsp.StatusCode).To(Equal(418)) 785 }) 786 787 Context("requests containing a Body", func() { 788 var strBuf *bytes.Buffer 789 790 BeforeEach(func() { 791 strBuf = &bytes.Buffer{} 792 gomock.InOrder( 793 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 794 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 795 ) 796 body := &mockBody{} 797 body.SetData([]byte("request body")) 798 var err error 799 req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) 800 Expect(err).ToNot(HaveOccurred()) 801 str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() 802 }) 803 804 It("sends a request", func() { 805 done := make(chan struct{}) 806 gomock.InOrder( 807 str.EXPECT().Close().Do(func() error { close(done); return nil }), 808 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors 809 ) 810 // the response body is sent asynchronously, while already reading the response 811 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 812 <-done 813 return 0, errors.New("test done") 814 }) 815 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 816 Expect(err).To(MatchError("test done")) 817 hfs := decodeHeader(strBuf) 818 Expect(hfs).To(HaveKeyWithValue(":method", "POST")) 819 Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) 820 }) 821 822 It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { 823 req.ContentLength = 7 824 var once sync.Once 825 done := make(chan struct{}) 826 gomock.InOrder( 827 str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { 828 once.Do(func() { 829 Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) 830 close(done) 831 }) 832 }).AnyTimes(), 833 str.EXPECT().Close().MaxTimes(1), 834 str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), 835 ) 836 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 837 <-done 838 return 0, errors.New("done") 839 }) 840 cl.RoundTripOpt(req, RoundTripOpt{}) 841 Expect(strBuf.String()).To(ContainSubstring("request")) 842 Expect(strBuf.String()).ToNot(ContainSubstring("request body")) 843 }) 844 845 It("returns the error that occurred when reading the body", func() { 846 req.Body.(*mockBody).readErr = errors.New("testErr") 847 done := make(chan struct{}) 848 gomock.InOrder( 849 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { 850 close(done) 851 }), 852 str.EXPECT().CancelWrite(gomock.Any()), 853 ) 854 855 // the response body is sent asynchronously, while already reading the response 856 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 857 <-done 858 return 0, errors.New("test done") 859 }) 860 closed := make(chan struct{}) 861 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 862 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 863 Expect(err).To(MatchError("test done")) 864 Eventually(closed).Should(BeClosed()) 865 }) 866 867 It("closes the connection when the first frame is not a HEADERS frame", func() { 868 b := (&dataFrame{Length: 0x42}).Append(nil) 869 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) 870 closed := make(chan struct{}) 871 r := bytes.NewReader(b) 872 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 873 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 874 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 875 Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) 876 Eventually(closed).Should(BeClosed()) 877 }) 878 879 It("cancels the stream when parsing the headers fails", func() { 880 headerBuf := &bytes.Buffer{} 881 enc := qpack.NewEncoder(headerBuf) 882 Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header 883 Expect(enc.Close()).To(Succeed()) 884 b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) 885 b = append(b, headerBuf.Bytes()...) 886 887 r := bytes.NewReader(b) 888 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) 889 closed := make(chan struct{}) 890 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 891 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 892 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 893 Expect(err).To(HaveOccurred()) 894 Eventually(closed).Should(BeClosed()) 895 }) 896 897 It("cancels the stream when the HEADERS frame is too large", func() { 898 b := (&headersFrame{Length: 1338}).Append(nil) 899 r := bytes.NewReader(b) 900 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) 901 closed := make(chan struct{}) 902 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 903 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 904 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 905 Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) 906 Eventually(closed).Should(BeClosed()) 907 }) 908 }) 909 910 Context("request cancellations", func() { 911 for _, dontClose := range []bool{false, true} { 912 dontClose := dontClose 913 914 Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() { 915 roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose} 916 917 It("cancels a request while waiting for the handshake to complete", func() { 918 ctx, cancel := context.WithCancel(context.Background()) 919 req := req.WithContext(ctx) 920 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 921 922 errChan := make(chan error) 923 go func() { 924 _, err := cl.RoundTripOpt(req, roundTripOpt) 925 errChan <- err 926 }() 927 Consistently(errChan).ShouldNot(Receive()) 928 cancel() 929 Eventually(errChan).Should(Receive(MatchError("context canceled"))) 930 }) 931 932 It("cancels a request while the request is still in flight", func() { 933 ctx, cancel := context.WithCancel(context.Background()) 934 req := req.WithContext(ctx) 935 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 936 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 937 buf := &bytes.Buffer{} 938 str.EXPECT().Close().MaxTimes(1) 939 940 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 941 942 done := make(chan struct{}) 943 canceled := make(chan struct{}) 944 gomock.InOrder( 945 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), 946 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), 947 ) 948 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) 949 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 950 cancel() 951 <-canceled 952 return 0, errors.New("test done") 953 }) 954 _, err := cl.RoundTripOpt(req, roundTripOpt) 955 Expect(err).To(MatchError(context.Canceled)) 956 Eventually(done).Should(BeClosed()) 957 }) 958 }) 959 } 960 961 It("cancels a request after the response arrived", func() { 962 rspBuf := bytes.NewBuffer(getResponse(404)) 963 964 ctx, cancel := context.WithCancel(context.Background()) 965 req := req.WithContext(ctx) 966 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 967 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 968 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 969 buf := &bytes.Buffer{} 970 str.EXPECT().Close().MaxTimes(1) 971 972 done := make(chan struct{}) 973 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 974 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 975 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 976 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) 977 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 978 Expect(err).ToNot(HaveOccurred()) 979 cancel() 980 Eventually(done).Should(BeClosed()) 981 }) 982 }) 983 984 Context("gzip compression", func() { 985 BeforeEach(func() { 986 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 987 }) 988 989 It("adds the gzip header to requests", func() { 990 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 991 buf := &bytes.Buffer{} 992 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 993 gomock.InOrder( 994 str.EXPECT().Close(), 995 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 996 ) 997 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 998 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 999 Expect(err).To(MatchError("test done")) 1000 hfs := decodeHeader(buf) 1001 Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) 1002 }) 1003 1004 It("doesn't add gzip if the header disable it", func() { 1005 client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) 1006 Expect(err).ToNot(HaveOccurred()) 1007 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1008 buf := &bytes.Buffer{} 1009 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 1010 gomock.InOrder( 1011 str.EXPECT().Close(), 1012 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 1013 ) 1014 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 1015 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 1016 Expect(err).To(MatchError("test done")) 1017 hfs := decodeHeader(buf) 1018 Expect(hfs).ToNot(HaveKey("accept-encoding")) 1019 }) 1020 1021 It("decompresses the response", func() { 1022 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1023 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1024 buf := &bytes.Buffer{} 1025 rstr := mockquic.NewMockStream(mockCtrl) 1026 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 1027 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 1028 rw.Header().Set("Content-Encoding", "gzip") 1029 gz := gzip.NewWriter(rw) 1030 gz.Write([]byte("gzipped response")) 1031 gz.Close() 1032 rw.Flush() 1033 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1034 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1035 str.EXPECT().Close() 1036 1037 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1038 Expect(err).ToNot(HaveOccurred()) 1039 data, err := io.ReadAll(rsp.Body) 1040 Expect(err).ToNot(HaveOccurred()) 1041 Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) 1042 Expect(string(data)).To(Equal("gzipped response")) 1043 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1044 Expect(rsp.Uncompressed).To(BeTrue()) 1045 }) 1046 1047 It("only decompresses the response if the response contains the right content-encoding header", func() { 1048 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1049 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1050 buf := &bytes.Buffer{} 1051 rstr := mockquic.NewMockStream(mockCtrl) 1052 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 1053 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 1054 rw.Write([]byte("not gzipped")) 1055 rw.Flush() 1056 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1057 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1058 str.EXPECT().Close() 1059 1060 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1061 Expect(err).ToNot(HaveOccurred()) 1062 data, err := io.ReadAll(rsp.Body) 1063 Expect(err).ToNot(HaveOccurred()) 1064 Expect(string(data)).To(Equal("not gzipped")) 1065 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1066 }) 1067 }) 1068 }) 1069 })