github.com/tumi8/quic-go@v0.37.4-tum/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "crypto/tls" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/golang/mock/gomock" 16 "github.com/tumi8/quic-go" 17 mockquic "github.com/tumi8/quic-go/noninternal/mocks/quic" 18 "github.com/tumi8/quic-go/quicvarint" 19 20 "github.com/tumi8/quic-go/noninternal/protocol" 21 "github.com/tumi8/quic-go/noninternal/utils" 22 "github.com/quic-go/qpack" 23 24 . "github.com/onsi/ginkgo/v2" 25 . "github.com/onsi/gomega" 26 ) 27 28 var _ = Describe("Client", func() { 29 var ( 30 cl *client 31 req *http.Request 32 origDialAddr = dialAddr 33 handshakeChan <-chan struct{} // a closed chan 34 ) 35 36 BeforeEach(func() { 37 origDialAddr = dialAddr 38 hostname := "quic.clemente.io:1337" 39 c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) 40 Expect(err).ToNot(HaveOccurred()) 41 cl = c.(*client) 42 Expect(cl.hostname).To(Equal(hostname)) 43 44 req, err = http.NewRequest("GET", "https://localhost:1337", nil) 45 Expect(err).ToNot(HaveOccurred()) 46 47 ch := make(chan struct{}) 48 close(ch) 49 handshakeChan = ch 50 }) 51 52 AfterEach(func() { 53 dialAddr = origDialAddr 54 }) 55 56 It("rejects quic.Configs that allow multiple QUIC versions", func() { 57 qconf := &quic.Config{ 58 Versions: []quic.VersionNumber{protocol.Version2, protocol.Version1}, 59 } 60 _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) 61 Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) 62 }) 63 64 It("uses the default QUIC and TLS config if none is give", func() { 65 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 66 Expect(err).ToNot(HaveOccurred()) 67 var dialAddrCalled bool 68 dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 69 
Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) 70 Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) 71 Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1})) 72 dialAddrCalled = true 73 return nil, errors.New("test done") 74 } 75 client.RoundTripOpt(req, RoundTripOpt{}) 76 Expect(dialAddrCalled).To(BeTrue()) 77 }) 78 79 It("adds the port to the hostname, if none is given", func() { 80 client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) 81 Expect(err).ToNot(HaveOccurred()) 82 var dialAddrCalled bool 83 dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { 84 Expect(hostname).To(Equal("quic.clemente.io:443")) 85 dialAddrCalled = true 86 return nil, errors.New("test done") 87 } 88 req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) 89 Expect(err).ToNot(HaveOccurred()) 90 client.RoundTripOpt(req, RoundTripOpt{}) 91 Expect(dialAddrCalled).To(BeTrue()) 92 }) 93 94 It("sets the ServerName in the tls.Config, if not set", func() { 95 const host = "foo.bar" 96 dialCalled := false 97 dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 98 Expect(tlsCfg.ServerName).To(Equal(host)) 99 dialCalled = true 100 return nil, errors.New("test done") 101 } 102 client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) 103 Expect(err).ToNot(HaveOccurred()) 104 req, err := http.NewRequest("GET", "https://foo.bar", nil) 105 Expect(err).ToNot(HaveOccurred()) 106 client.RoundTripOpt(req, RoundTripOpt{}) 107 Expect(dialCalled).To(BeTrue()) 108 }) 109 110 It("uses the TLS config and QUIC config", func() { 111 tlsConf := &tls.Config{ 112 ServerName: "foo.bar", 113 NextProtos: []string{"proto foo", "proto bar"}, 114 } 115 quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} 116 client, err := newClient("localhost:1337", tlsConf, 
&roundTripperOpts{}, quicConf, nil) 117 Expect(err).ToNot(HaveOccurred()) 118 var dialAddrCalled bool 119 dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 120 Expect(host).To(Equal("localhost:1337")) 121 Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) 122 Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) 123 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 124 dialAddrCalled = true 125 return nil, errors.New("test done") 126 } 127 client.RoundTripOpt(req, RoundTripOpt{}) 128 Expect(dialAddrCalled).To(BeTrue()) 129 // make sure the original tls.Config was not modified 130 Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) 131 }) 132 133 It("uses the custom dialer, if provided", func() { 134 testErr := errors.New("test done") 135 tlsConf := &tls.Config{ServerName: "foo.bar"} 136 quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} 137 ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 138 defer cancel() 139 var dialerCalled bool 140 dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 141 Expect(ctxP).To(Equal(ctx)) 142 Expect(address).To(Equal("localhost:1337")) 143 Expect(tlsConfP.ServerName).To(Equal("foo.bar")) 144 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 145 dialerCalled = true 146 return nil, testErr 147 } 148 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) 149 Expect(err).ToNot(HaveOccurred()) 150 _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) 151 Expect(err).To(MatchError(testErr)) 152 Expect(dialerCalled).To(BeTrue()) 153 }) 154 155 It("enables HTTP/3 Datagrams", func() { 156 testErr := errors.New("handshake error") 157 client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) 158 
Expect(err).ToNot(HaveOccurred()) 159 dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 160 Expect(quicConf.EnableDatagrams).To(BeTrue()) 161 return nil, testErr 162 } 163 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 164 Expect(err).To(MatchError(testErr)) 165 }) 166 167 It("errors when dialing fails", func() { 168 testErr := errors.New("handshake error") 169 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 170 Expect(err).ToNot(HaveOccurred()) 171 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 172 return nil, testErr 173 } 174 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 175 Expect(err).To(MatchError(testErr)) 176 }) 177 178 It("closes correctly if connection was not created", func() { 179 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 180 Expect(err).ToNot(HaveOccurred()) 181 Expect(client.Close()).To(Succeed()) 182 }) 183 184 Context("validating the address", func() { 185 It("refuses to do requests for the wrong host", func() { 186 req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) 187 Expect(err).ToNot(HaveOccurred()) 188 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 189 Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) 190 }) 191 192 It("allows requests using a different scheme", func() { 193 testErr := errors.New("handshake error") 194 req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) 195 Expect(err).ToNot(HaveOccurred()) 196 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 197 return nil, testErr 198 } 199 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 200 Expect(err).To(MatchError(testErr)) 201 }) 202 }) 203 204 Context("hijacking bidirectional 
streams", func() { 205 var ( 206 request *http.Request 207 conn *mockquic.MockEarlyConnection 208 settingsFrameWritten chan struct{} 209 ) 210 testDone := make(chan struct{}) 211 212 BeforeEach(func() { 213 testDone = make(chan struct{}) 214 settingsFrameWritten = make(chan struct{}) 215 controlStr := mockquic.NewMockStream(mockCtrl) 216 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 217 defer GinkgoRecover() 218 close(settingsFrameWritten) 219 }) 220 conn = mockquic.NewMockEarlyConnection(mockCtrl) 221 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 222 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 223 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 224 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 225 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 226 return conn, nil 227 } 228 var err error 229 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 230 Expect(err).ToNot(HaveOccurred()) 231 }) 232 233 AfterEach(func() { 234 testDone <- struct{}{} 235 Eventually(settingsFrameWritten).Should(BeClosed()) 236 }) 237 238 It("hijacks a bidirectional stream of unknown frame type", func() { 239 frameTypeChan := make(chan FrameType, 1) 240 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 241 Expect(e).ToNot(HaveOccurred()) 242 frameTypeChan <- ft 243 return true, nil 244 } 245 246 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 247 unknownStr := mockquic.NewMockStream(mockCtrl) 248 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 249 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 250 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 251 <-testDone 252 return nil, errors.New("test done") 253 }) 254 _, err := cl.RoundTripOpt(request, 
RoundTripOpt{}) 255 Expect(err).To(MatchError("done")) 256 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 257 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 258 }) 259 260 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 261 frameTypeChan := make(chan FrameType, 1) 262 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 263 Expect(e).ToNot(HaveOccurred()) 264 frameTypeChan <- ft 265 return false, nil 266 } 267 268 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 269 unknownStr := mockquic.NewMockStream(mockCtrl) 270 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 271 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 272 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 273 <-testDone 274 return nil, errors.New("test done") 275 }) 276 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 277 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 278 Expect(err).To(MatchError("done")) 279 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 280 }) 281 282 It("closes the connection when hijacker returned error", func() { 283 frameTypeChan := make(chan FrameType, 1) 284 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 285 Expect(e).ToNot(HaveOccurred()) 286 frameTypeChan <- ft 287 return false, errors.New("error in hijacker") 288 } 289 290 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 291 unknownStr := mockquic.NewMockStream(mockCtrl) 292 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 293 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 294 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, 
error) { 295 <-testDone 296 return nil, errors.New("test done") 297 }) 298 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 299 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 300 Expect(err).To(MatchError("done")) 301 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 302 }) 303 304 It("handles errors that occur when reading the frame type", func() { 305 testErr := errors.New("test error") 306 unknownStr := mockquic.NewMockStream(mockCtrl) 307 done := make(chan struct{}) 308 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { 309 defer close(done) 310 Expect(e).To(MatchError(testErr)) 311 Expect(ft).To(BeZero()) 312 Expect(str).To(Equal(unknownStr)) 313 return false, nil 314 } 315 316 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 317 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 318 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 319 <-testDone 320 return nil, errors.New("test done") 321 }) 322 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 323 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 324 Expect(err).To(MatchError("done")) 325 Eventually(done).Should(BeClosed()) 326 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 327 }) 328 }) 329 330 Context("hijacking unidirectional streams", func() { 331 var ( 332 req *http.Request 333 conn *mockquic.MockEarlyConnection 334 settingsFrameWritten chan struct{} 335 ) 336 testDone := make(chan struct{}) 337 338 BeforeEach(func() { 339 testDone = make(chan struct{}) 340 settingsFrameWritten = make(chan struct{}) 341 controlStr := mockquic.NewMockStream(mockCtrl) 342 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 343 defer GinkgoRecover() 344 
close(settingsFrameWritten) 345 }) 346 conn = mockquic.NewMockEarlyConnection(mockCtrl) 347 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 348 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 349 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 350 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 351 return conn, nil 352 } 353 var err error 354 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 355 Expect(err).ToNot(HaveOccurred()) 356 }) 357 358 AfterEach(func() { 359 testDone <- struct{}{} 360 Eventually(settingsFrameWritten).Should(BeClosed()) 361 }) 362 363 It("hijacks an unidirectional stream of unknown stream type", func() { 364 streamTypeChan := make(chan StreamType, 1) 365 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 366 Expect(err).ToNot(HaveOccurred()) 367 streamTypeChan <- st 368 return true 369 } 370 371 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 372 unknownStr := mockquic.NewMockStream(mockCtrl) 373 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 374 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 375 return unknownStr, nil 376 }) 377 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 378 <-testDone 379 return nil, errors.New("test done") 380 }) 381 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 382 Expect(err).To(MatchError("done")) 383 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 384 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 385 }) 386 387 It("handles errors that occur when reading the stream type", func() { 388 testErr := errors.New("test error") 389 done := make(chan struct{}) 390 unknownStr := mockquic.NewMockStream(mockCtrl) 391 
cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 392 defer close(done) 393 Expect(st).To(BeZero()) 394 Expect(str).To(Equal(unknownStr)) 395 Expect(err).To(MatchError(testErr)) 396 return true 397 } 398 399 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 400 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 401 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 402 <-testDone 403 return nil, errors.New("test done") 404 }) 405 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 406 Expect(err).To(MatchError("done")) 407 Eventually(done).Should(BeClosed()) 408 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 409 }) 410 411 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 412 streamTypeChan := make(chan StreamType, 1) 413 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 414 Expect(err).ToNot(HaveOccurred()) 415 streamTypeChan <- st 416 return false 417 } 418 419 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 420 unknownStr := mockquic.NewMockStream(mockCtrl) 421 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 422 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 423 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 424 return unknownStr, nil 425 }) 426 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 427 <-testDone 428 return nil, errors.New("test done") 429 }) 430 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 431 Expect(err).To(MatchError("done")) 432 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 433 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 
434 }) 435 }) 436 437 Context("control stream handling", func() { 438 var ( 439 req *http.Request 440 conn *mockquic.MockEarlyConnection 441 settingsFrameWritten chan struct{} 442 ) 443 testDone := make(chan struct{}) 444 445 BeforeEach(func() { 446 settingsFrameWritten = make(chan struct{}) 447 controlStr := mockquic.NewMockStream(mockCtrl) 448 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 449 defer GinkgoRecover() 450 close(settingsFrameWritten) 451 }) 452 conn = mockquic.NewMockEarlyConnection(mockCtrl) 453 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 454 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 455 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 456 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 457 return conn, nil 458 } 459 var err error 460 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 461 Expect(err).ToNot(HaveOccurred()) 462 }) 463 464 AfterEach(func() { 465 testDone <- struct{}{} 466 Eventually(settingsFrameWritten).Should(BeClosed()) 467 }) 468 469 It("parses the SETTINGS frame", func() { 470 b := quicvarint.Append(nil, streamTypeControlStream) 471 b = (&settingsFrame{}).Append(b) 472 r := bytes.NewReader(b) 473 controlStr := mockquic.NewMockStream(mockCtrl) 474 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 475 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 476 return controlStr, nil 477 }) 478 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 479 <-testDone 480 return nil, errors.New("test done") 481 }) 482 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 483 Expect(err).To(MatchError("done")) 484 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 485 }) 486 487 for _, t := range []uint64{streamTypeQPACKEncoderStream, 
streamTypeQPACKDecoderStream} { 488 streamType := t 489 name := "encoder" 490 if streamType == streamTypeQPACKDecoderStream { 491 name = "decoder" 492 } 493 494 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 495 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 496 str := mockquic.NewMockStream(mockCtrl) 497 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 498 499 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 500 return str, nil 501 }) 502 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 503 <-testDone 504 return nil, errors.New("test done") 505 }) 506 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 507 Expect(err).To(MatchError("done")) 508 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 509 }) 510 } 511 512 It("resets streams Other than the control stream and the QPACK streams", func() { 513 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) 514 str := mockquic.NewMockStream(mockCtrl) 515 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 516 done := make(chan struct{}) 517 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { 518 close(done) 519 }) 520 521 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 522 return str, nil 523 }) 524 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 525 <-testDone 526 return nil, errors.New("test done") 527 }) 528 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 529 Expect(err).To(MatchError("done")) 530 Eventually(done).Should(BeClosed()) 531 }) 532 533 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 534 b := quicvarint.Append(nil, streamTypeControlStream) 535 b = (&dataFrame{}).Append(b) 536 r := 
bytes.NewReader(b) 537 controlStr := mockquic.NewMockStream(mockCtrl) 538 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 539 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 540 return controlStr, nil 541 }) 542 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 543 <-testDone 544 return nil, errors.New("test done") 545 }) 546 done := make(chan struct{}) 547 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 548 defer GinkgoRecover() 549 Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) 550 close(done) 551 }) 552 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 553 Expect(err).To(MatchError("done")) 554 Eventually(done).Should(BeClosed()) 555 }) 556 557 It("errors when parsing the frame on the control stream fails", func() { 558 b := quicvarint.Append(nil, streamTypeControlStream) 559 b = (&settingsFrame{}).Append(b) 560 r := bytes.NewReader(b[:len(b)-1]) 561 controlStr := mockquic.NewMockStream(mockCtrl) 562 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 563 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 564 return controlStr, nil 565 }) 566 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 567 <-testDone 568 return nil, errors.New("test done") 569 }) 570 done := make(chan struct{}) 571 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 572 defer GinkgoRecover() 573 Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) 574 close(done) 575 }) 576 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 577 Expect(err).To(MatchError("done")) 578 Eventually(done).Should(BeClosed()) 579 }) 580 581 It("errors when parsing the server opens a push stream", func() { 582 buf := 
bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) 583 controlStr := mockquic.NewMockStream(mockCtrl) 584 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 585 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 586 return controlStr, nil 587 }) 588 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 589 <-testDone 590 return nil, errors.New("test done") 591 }) 592 done := make(chan struct{}) 593 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 594 defer GinkgoRecover() 595 Expect(code).To(BeEquivalentTo(ErrCodeIDError)) 596 close(done) 597 }) 598 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 599 Expect(err).To(MatchError("done")) 600 Eventually(done).Should(BeClosed()) 601 }) 602 603 It("errors when the server advertises datagram support (and we enabled support for it)", func() { 604 cl.opts.EnableDatagram = true 605 b := quicvarint.Append(nil, streamTypeControlStream) 606 b = (&settingsFrame{Datagram: true}).Append(b) 607 r := bytes.NewReader(b) 608 controlStr := mockquic.NewMockStream(mockCtrl) 609 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 610 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 611 return controlStr, nil 612 }) 613 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 614 <-testDone 615 return nil, errors.New("test done") 616 }) 617 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 618 done := make(chan struct{}) 619 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { 620 defer GinkgoRecover() 621 Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) 622 Expect(reason).To(Equal("missing QUIC Datagram support")) 
623 close(done) 624 }) 625 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 626 Expect(err).To(MatchError("done")) 627 Eventually(done).Should(BeClosed()) 628 }) 629 }) 630 631 Context("Doing requests", func() { 632 var ( 633 req *http.Request 634 str *mockquic.MockStream 635 conn *mockquic.MockEarlyConnection 636 settingsFrameWritten chan struct{} 637 ) 638 testDone := make(chan struct{}) 639 640 decodeHeader := func(str io.Reader) map[string]string { 641 fields := make(map[string]string) 642 decoder := qpack.NewDecoder(nil) 643 644 frame, err := parseNextFrame(str, nil) 645 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 646 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 647 headersFrame := frame.(*headersFrame) 648 data := make([]byte, headersFrame.Length) 649 _, err = io.ReadFull(str, data) 650 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 651 hfs, err := decoder.DecodeFull(data) 652 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 653 for _, p := range hfs { 654 fields[p.Name] = p.Value 655 } 656 return fields 657 } 658 659 getResponse := func(status int) []byte { 660 buf := &bytes.Buffer{} 661 rstr := mockquic.NewMockStream(mockCtrl) 662 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 663 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 664 rw.WriteHeader(status) 665 rw.Flush() 666 return buf.Bytes() 667 } 668 669 BeforeEach(func() { 670 settingsFrameWritten = make(chan struct{}) 671 controlStr := mockquic.NewMockStream(mockCtrl) 672 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 673 defer GinkgoRecover() 674 r := bytes.NewReader(b) 675 streamType, err := quicvarint.Read(r) 676 Expect(err).ToNot(HaveOccurred()) 677 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 678 close(settingsFrameWritten) 679 }) // SETTINGS frame 680 str = mockquic.NewMockStream(mockCtrl) 681 conn = mockquic.NewMockEarlyConnection(mockCtrl) 682 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 683 
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 684 <-testDone 685 return nil, errors.New("test done") 686 }) 687 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 688 return conn, nil 689 } 690 var err error 691 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 692 Expect(err).ToNot(HaveOccurred()) 693 }) 694 695 AfterEach(func() { 696 testDone <- struct{}{} 697 Eventually(settingsFrameWritten).Should(BeClosed()) 698 }) 699 700 It("errors if it can't open a stream", func() { 701 testErr := errors.New("stream open error") 702 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 703 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 704 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 705 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 706 Expect(err).To(MatchError(testErr)) 707 }) 708 709 It("performs a 0-RTT request", func() { 710 testErr := errors.New("stream open error") 711 req.Method = MethodGet0RTT 712 // don't EXPECT any calls to HandshakeComplete() 713 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 714 buf := &bytes.Buffer{} 715 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 716 str.EXPECT().Close() 717 str.EXPECT().CancelWrite(gomock.Any()) 718 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 719 return 0, testErr 720 }) 721 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 722 Expect(err).To(MatchError(testErr)) 723 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) 724 }) 725 726 It("returns a response", func() { 727 rspBuf := bytes.NewBuffer(getResponse(418)) 728 gomock.InOrder( 729 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 730 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 731 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 732 ) 733 
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
			str.EXPECT().Close()
			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
			Expect(rsp.ProtoMajor).To(Equal(3))
			Expect(rsp.StatusCode).To(Equal(418))
			Expect(rsp.Request).ToNot(BeNil())
		})

		// No Close expectation is registered on str in this test: with
		// DontCloseRequestStream set, a Close call on the request stream would be
		// an unexpected mock call and fail the test.
		It("doesn't close the request stream, with DontCloseRequestStream set", func() {
			rspBuf := bytes.NewBuffer(getResponse(418))
			gomock.InOrder(
				conn.EXPECT().HandshakeComplete().Return(handshakeChan),
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
			)
			str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
			str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
			rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
			Expect(err).ToNot(HaveOccurred())
			Expect(rsp.Proto).To(Equal("HTTP/3.0"))
			Expect(rsp.ProtoMajor).To(Equal(3))
			Expect(rsp.StatusCode).To(Equal(418))
		})

		// Every test in this context issues a POST to /upload whose body is
		// "request body"; everything the client writes to the stream lands in strBuf.
		Context("requests containing a Body", func() {
			var strBuf *bytes.Buffer

			BeforeEach(func() {
				strBuf = &bytes.Buffer{}
				gomock.InOrder(
					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
					conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
				)
				body := &mockBody{}
				body.SetData([]byte("request body"))
				var err error
				req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body)
				Expect(err).ToNot(HaveOccurred())
				str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes()
			})

			It("sends a request", func() {
				done := make(chan struct{})
				gomock.InOrder(
					str.EXPECT().Close().Do(func() { close(done) }),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors
				)
				// the response body is sent asynchronously, while already reading the response
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					<-done
					return 0, errors.New("test done")
				})
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(MatchError("test done"))
				hfs := decodeHeader(strBuf)
				Expect(hfs).To(HaveKeyWithValue(":method", "POST"))
				Expect(hfs).To(HaveKeyWithValue(":path", "/upload"))
			})

			It("doesn't send more bytes than allowed by http.Request.ContentLength", func() {
				// "request body" is 12 bytes; a ContentLength of 7 must cut it down to "request".
				req.ContentLength = 7
				var once sync.Once
				done := make(chan struct{})
				gomock.InOrder(
					// The first CancelWrite must carry ErrCodeRequestCanceled; it
					// unblocks the mocked Read below.
					str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) {
						once.Do(func() {
							Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled)))
							close(done)
						})
					}).AnyTimes(),
					str.EXPECT().Close().MaxTimes(1),
					str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(),
				)
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					<-done
					return 0, errors.New("done")
				})
				// The response error only comes from the mocked Read above; this
				// test is about the bytes written to the stream, checked below.
				_, _ = cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(strBuf.String()).To(ContainSubstring("request"))
				Expect(strBuf.String()).ToNot(ContainSubstring("request body"))
			})

			It("returns the error that occurred when reading the body", func() {
				req.Body.(*mockBody).readErr = errors.New("testErr")
				done := make(chan struct{})
				gomock.InOrder(
					str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) {
						close(done)
					}),
					str.EXPECT().CancelWrite(gomock.Any()),
				)

				// the response body is sent asynchronously, while already reading the response
				str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
					<-done
					return 0, errors.New("test done")
				})
				closed := make(chan struct{})
				str.EXPECT().Close().Do(func() { close(closed) })
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(MatchError("test done"))
				Eventually(closed).Should(BeClosed())
			})

			It("closes the connection when the first frame is not a HEADERS frame", func() {
				b := (&dataFrame{Length: 0x42}).Append(nil)
				conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any())
				closed := make(chan struct{})
				r := bytes.NewReader(b)
				str.EXPECT().Close().Do(func() { close(closed) })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(MatchError("expected first frame to be a HEADERS frame"))
				Eventually(closed).Should(BeClosed())
			})

			It("cancels the stream when parsing the headers fails", func() {
				headerBuf := &bytes.Buffer{}
				enc := qpack.NewEncoder(headerBuf)
				Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header
				Expect(enc.Close()).To(Succeed())
				b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil)
				b = append(b, headerBuf.Bytes()...)

				r := bytes.NewReader(b)
				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
				closed := make(chan struct{})
				str.EXPECT().Close().Do(func() { close(closed) })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(HaveOccurred())
				Eventually(closed).Should(BeClosed())
			})

			It("cancels the stream when the HEADERS frame is too large", func() {
				// 1338 is one byte over the limit reported in the expected error (max: 1337).
				b := (&headersFrame{Length: 1338}).Append(nil)
				r := bytes.NewReader(b)
				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError))
				closed := make(chan struct{})
				str.EXPECT().Close().Do(func() { close(closed) })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes()
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)"))
				Eventually(closed).Should(BeClosed())
			})
		})

		Context("request cancellations", func() {
			// Each cancellation scenario runs both with and without DontCloseRequestStream.
			for _, dontClose := range []bool{false, true} {
				dontClose := dontClose // capture a per-iteration copy for the closures below

				Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
					roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}

					It("cancels a request while waiting for the handshake to complete", func() {
						ctx, cancel := context.WithCancel(context.Background())
						// Cancel on every exit path: if an assertion below fails before the
						// explicit cancel(), the RoundTripOpt goroutine would otherwise wait
						// on the never-completing handshake forever.
						defer cancel()
						req := req.WithContext(ctx)
						// This handshake channel is never closed, so RoundTripOpt can only
						// return once ctx is canceled.
						conn.EXPECT().HandshakeComplete().Return(make(chan struct{}))

						// Buffered so the goroutine can always deliver its result and exit,
						// even if nobody is receiving any more.
						errChan := make(chan error, 1)
						go func() {
							_, err := cl.RoundTripOpt(req, roundTripOpt)
							errChan <- err
						}()
						Consistently(errChan).ShouldNot(Receive())
						cancel()
						Eventually(errChan).Should(Receive(MatchError("context canceled")))
					})

					It("cancels a request while the request is still in flight", func() {
						ctx, cancel := context.WithCancel(context.Background())
						req := req.WithContext(ctx)
						conn.EXPECT().HandshakeComplete().Return(handshakeChan)
						conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
						buf := &bytes.Buffer{}
						str.EXPECT().Close().MaxTimes(1)

						str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)

						done := make(chan struct{})
						canceled := make(chan struct{})
						gomock.InOrder(
							str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
							str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
						)
						str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
						// Cancel the context while the client is blocked reading the
						// response, then wait until the client has reacted with CancelWrite.
						str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
							cancel()
							<-canceled
							return 0, errors.New("test done")
						})
						_, err := cl.RoundTripOpt(req, roundTripOpt)
						Expect(err).To(MatchError("test done"))
						Eventually(done).Should(BeClosed())
					})
				})
			}

			It("cancels a request after the response arrived", func() {
				rspBuf := bytes.NewBuffer(getResponse(404))

				ctx, cancel := context.WithCancel(context.Background())
				req := req.WithContext(ctx)
				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
				conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
				buf := &bytes.Buffer{}
				str.EXPECT().Close().MaxTimes(1)

				done := make(chan struct{})
				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
				str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled))
				str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) })
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).ToNot(HaveOccurred())
				// Canceling after RoundTripOpt returned must still reset both directions of the stream.
				cancel()
				Eventually(done).Should(BeClosed())
			})
		})

		Context("gzip compression", func() {
			BeforeEach(func() {
				conn.EXPECT().HandshakeComplete().Return(handshakeChan)
			})

			It("adds the gzip header to requests", func() {
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
				gomock.InOrder(
					str.EXPECT().Close(),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
				)
				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
				_, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(MatchError("test done"))
				hfs := decodeHeader(buf)
				Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
			})

			It("doesn't add gzip if compression is disabled", func() {
				client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
				Expect(err).ToNot(HaveOccurred())
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				buf := &bytes.Buffer{}
				str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
				gomock.InOrder(
					str.EXPECT().Close(),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors
				)
				str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
				_, err = client.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).To(MatchError("test done"))
				hfs := decodeHeader(buf)
				Expect(hfs).ToNot(HaveKey("accept-encoding"))
			})

			It("decompresses the response", func() {
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
				// Build the wire bytes of a gzip-encoded response by running a real
				// responseWriter on top of a mock stream that records into buf.
				buf := &bytes.Buffer{}
				rstr := mockquic.NewMockStream(mockCtrl)
				// DoAndReturn (not Do): gomock's Do ignores the function's return
				// values, so the mocked Write would not report the bytes buf accepted.
				rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
				rw.Header().Set("Content-Encoding", "gzip")
				gz := gzip.NewWriter(rw)
				_, err := gz.Write([]byte("gzipped response"))
				Expect(err).ToNot(HaveOccurred())
				Expect(gz.Close()).To(Succeed())
				rw.Flush()
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
				str.EXPECT().Close()

				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).ToNot(HaveOccurred())
				data, err := io.ReadAll(rsp.Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
				Expect(string(data)).To(Equal("gzipped response"))
				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
				Expect(rsp.Uncompressed).To(BeTrue())
			})

			It("only decompresses the response if the response contains the right content-encoding header", func() {
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
				// Same recording setup as above, but the response carries no Content-Encoding.
				buf := &bytes.Buffer{}
				rstr := mockquic.NewMockStream(mockCtrl)
				rstr.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes()
				rw := newResponseWriter(rstr, nil, utils.DefaultLogger)
				_, err := rw.Write([]byte("not gzipped"))
				Expect(err).ToNot(HaveOccurred())
				rw.Flush()
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
				str.EXPECT().Close()

				rsp, err := cl.RoundTripOpt(req, RoundTripOpt{})
				Expect(err).ToNot(HaveOccurred())
				data, err := io.ReadAll(rsp.Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(string(data)).To(Equal("not gzipped"))
				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
			})
		})
	})
})