github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "crypto/tls" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "time" 13 14 "github.com/mikelsr/quic-go" 15 mockquic "github.com/mikelsr/quic-go/internal/mocks/quic" 16 "github.com/mikelsr/quic-go/internal/protocol" 17 "github.com/mikelsr/quic-go/internal/utils" 18 "github.com/mikelsr/quic-go/quicvarint" 19 20 "github.com/golang/mock/gomock" 21 "github.com/quic-go/qpack" 22 23 . "github.com/onsi/ginkgo/v2" 24 . "github.com/onsi/gomega" 25 ) 26 27 var _ = Describe("Client", func() { 28 var ( 29 cl *client 30 req *http.Request 31 origDialAddr = dialAddr 32 handshakeChan <-chan struct{} // a closed chan 33 ) 34 35 BeforeEach(func() { 36 origDialAddr = dialAddr 37 hostname := "quic.clemente.io:1337" 38 c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) 39 Expect(err).ToNot(HaveOccurred()) 40 cl = c.(*client) 41 Expect(cl.hostname).To(Equal(hostname)) 42 43 req, err = http.NewRequest("GET", "https://localhost:1337", nil) 44 Expect(err).ToNot(HaveOccurred()) 45 46 ch := make(chan struct{}) 47 close(ch) 48 handshakeChan = ch 49 }) 50 51 AfterEach(func() { 52 dialAddr = origDialAddr 53 }) 54 55 It("rejects quic.Configs that allow multiple QUIC versions", func() { 56 qconf := &quic.Config{ 57 Versions: []quic.VersionNumber{protocol.Version2, protocol.Version1}, 58 } 59 _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) 60 Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) 61 }) 62 63 It("uses the default QUIC and TLS config if none is give", func() { 64 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 65 Expect(err).ToNot(HaveOccurred()) 66 var dialAddrCalled bool 67 dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 68 Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) 69 Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) 70 Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1})) 71 dialAddrCalled = true 72 return nil, errors.New("test done") 73 } 74 client.RoundTripOpt(req, RoundTripOpt{}) 75 Expect(dialAddrCalled).To(BeTrue()) 76 }) 77 78 It("adds the port to the hostname, if none is given", func() { 79 client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) 80 Expect(err).ToNot(HaveOccurred()) 81 var dialAddrCalled bool 82 dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { 83 Expect(hostname).To(Equal("quic.clemente.io:443")) 84 dialAddrCalled = true 85 return nil, errors.New("test done") 86 } 87 req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) 88 Expect(err).ToNot(HaveOccurred()) 89 client.RoundTripOpt(req, RoundTripOpt{}) 90 Expect(dialAddrCalled).To(BeTrue()) 91 }) 92 93 It("sets the ServerName in the tls.Config, if not set", func() { 94 const host = "foo.bar" 95 dialCalled := false 96 dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 97 Expect(tlsCfg.ServerName).To(Equal(host)) 98 dialCalled = true 99 return nil, errors.New("test done") 100 } 101 client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) 102 Expect(err).ToNot(HaveOccurred()) 103 req, err := http.NewRequest("GET", "https://foo.bar", nil) 104 Expect(err).ToNot(HaveOccurred()) 105 client.RoundTripOpt(req, RoundTripOpt{}) 106 Expect(dialCalled).To(BeTrue()) 107 }) 108 109 It("uses the TLS config and QUIC config", func() { 110 tlsConf := &tls.Config{ 111 ServerName: "foo.bar", 112 NextProtos: []string{"proto foo", "proto bar"}, 113 } 114 quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} 115 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) 116 Expect(err).ToNot(HaveOccurred()) 117 var dialAddrCalled bool 118 dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 119 Expect(host).To(Equal("localhost:1337")) 120 Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) 121 Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) 122 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 123 dialAddrCalled = true 124 return nil, errors.New("test done") 125 } 126 client.RoundTripOpt(req, RoundTripOpt{}) 127 Expect(dialAddrCalled).To(BeTrue()) 128 // make sure the original tls.Config was not modified 129 Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) 130 }) 131 132 It("uses the custom dialer, if provided", func() { 133 testErr := errors.New("test done") 134 tlsConf := &tls.Config{ServerName: "foo.bar"} 135 quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} 136 ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 137 defer cancel() 138 var dialerCalled bool 139 dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 140 Expect(ctxP).To(Equal(ctx)) 141 Expect(address).To(Equal("localhost:1337")) 142 Expect(tlsConfP.ServerName).To(Equal("foo.bar")) 143 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 144 dialerCalled = true 145 return nil, testErr 146 } 147 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) 148 Expect(err).ToNot(HaveOccurred()) 149 _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) 150 Expect(err).To(MatchError(testErr)) 151 Expect(dialerCalled).To(BeTrue()) 152 }) 153 154 It("enables HTTP/3 Datagrams", func() { 155 testErr := errors.New("handshake error") 156 client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) 157 Expect(err).ToNot(HaveOccurred()) 158 dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 159 Expect(quicConf.EnableDatagrams).To(BeTrue()) 160 return nil, testErr 161 } 162 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 163 Expect(err).To(MatchError(testErr)) 164 }) 165 166 It("errors when dialing fails", func() { 167 testErr := errors.New("handshake error") 168 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 169 Expect(err).ToNot(HaveOccurred()) 170 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 171 return nil, testErr 172 } 173 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 174 Expect(err).To(MatchError(testErr)) 175 }) 176 177 It("closes correctly if connection was not created", func() { 178 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 179 Expect(err).ToNot(HaveOccurred()) 180 Expect(client.Close()).To(Succeed()) 181 }) 182 183 Context("validating the address", func() { 184 It("refuses to do requests for the wrong host", func() { 185 req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) 186 Expect(err).ToNot(HaveOccurred()) 187 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 188 Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) 189 }) 190 191 It("allows requests using a different scheme", func() { 192 testErr := errors.New("handshake error") 193 req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) 194 Expect(err).ToNot(HaveOccurred()) 195 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 196 return nil, testErr 197 } 198 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 199 Expect(err).To(MatchError(testErr)) 200 }) 201 }) 202 203 Context("hijacking bidirectional streams", func() { 204 var ( 205 request *http.Request 206 conn *mockquic.MockEarlyConnection 207 settingsFrameWritten chan struct{} 208 ) 209 testDone := make(chan struct{}) 210 211 BeforeEach(func() { 212 testDone = make(chan struct{}) 213 settingsFrameWritten = make(chan struct{}) 214 controlStr := mockquic.NewMockStream(mockCtrl) 215 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 216 defer GinkgoRecover() 217 close(settingsFrameWritten) 218 }) 219 conn = mockquic.NewMockEarlyConnection(mockCtrl) 220 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 221 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 222 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 223 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 224 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 225 return conn, nil 226 } 227 var err error 228 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 229 Expect(err).ToNot(HaveOccurred()) 230 }) 231 232 AfterEach(func() { 233 testDone <- struct{}{} 234 Eventually(settingsFrameWritten).Should(BeClosed()) 235 }) 236 237 It("hijacks a bidirectional stream of unknown frame type", func() { 238 frameTypeChan := make(chan FrameType, 1) 239 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 240 Expect(e).ToNot(HaveOccurred()) 241 frameTypeChan <- ft 242 return true, nil 243 } 244 245 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 246 unknownStr := mockquic.NewMockStream(mockCtrl) 247 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 248 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 249 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 250 <-testDone 251 return nil, errors.New("test done") 252 }) 253 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 254 Expect(err).To(MatchError("done")) 255 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 256 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 257 }) 258 259 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 260 frameTypeChan := make(chan FrameType, 1) 261 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 262 Expect(e).ToNot(HaveOccurred()) 263 frameTypeChan <- ft 264 return false, nil 265 } 266 267 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 268 unknownStr := mockquic.NewMockStream(mockCtrl) 269 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 270 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 271 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 272 <-testDone 273 return nil, errors.New("test done") 274 }) 275 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 276 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 277 Expect(err).To(MatchError("done")) 278 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 279 }) 280 281 It("closes the connection when hijacker returned error", func() { 282 frameTypeChan := make(chan FrameType, 1) 283 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 284 Expect(e).ToNot(HaveOccurred()) 285 frameTypeChan <- ft 286 return false, errors.New("error in hijacker") 287 } 288 289 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 290 unknownStr := mockquic.NewMockStream(mockCtrl) 291 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 292 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 293 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 294 <-testDone 295 return nil, errors.New("test done") 296 }) 297 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 298 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 299 Expect(err).To(MatchError("done")) 300 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 301 }) 302 303 It("handles errors that occur when reading the frame type", func() { 304 testErr := errors.New("test error") 305 unknownStr := mockquic.NewMockStream(mockCtrl) 306 done := make(chan struct{}) 307 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { 308 defer close(done) 309 Expect(e).To(MatchError(testErr)) 310 Expect(ft).To(BeZero()) 311 Expect(str).To(Equal(unknownStr)) 312 return false, nil 313 } 314 315 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 316 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 317 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 318 <-testDone 319 return nil, errors.New("test done") 320 }) 321 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 322 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 323 Expect(err).To(MatchError("done")) 324 Eventually(done).Should(BeClosed()) 325 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 326 }) 327 }) 328 329 Context("hijacking unidirectional streams", func() { 330 var ( 331 req *http.Request 332 conn *mockquic.MockEarlyConnection 333 settingsFrameWritten chan struct{} 334 ) 335 testDone := make(chan struct{}) 336 337 BeforeEach(func() { 338 testDone = make(chan struct{}) 339 settingsFrameWritten = make(chan struct{}) 340 controlStr := mockquic.NewMockStream(mockCtrl) 341 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 342 defer GinkgoRecover() 343 close(settingsFrameWritten) 344 }) 345 conn = mockquic.NewMockEarlyConnection(mockCtrl) 346 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 347 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 348 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 349 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 350 return conn, nil 351 } 352 var err error 353 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 354 Expect(err).ToNot(HaveOccurred()) 355 }) 356 357 AfterEach(func() { 358 testDone <- struct{}{} 359 Eventually(settingsFrameWritten).Should(BeClosed()) 360 }) 361 362 It("hijacks an unidirectional stream of unknown stream type", func() { 363 streamTypeChan := make(chan StreamType, 1) 364 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 365 Expect(err).ToNot(HaveOccurred()) 366 streamTypeChan <- st 367 return true 368 } 369 370 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 371 unknownStr := mockquic.NewMockStream(mockCtrl) 372 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 373 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 374 return unknownStr, nil 375 }) 376 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 377 <-testDone 378 return nil, errors.New("test done") 379 }) 380 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 381 Expect(err).To(MatchError("done")) 382 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 383 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 384 }) 385 386 It("handles errors that occur when reading the stream type", func() { 387 testErr := errors.New("test error") 388 done := make(chan struct{}) 389 unknownStr := mockquic.NewMockStream(mockCtrl) 390 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 391 defer close(done) 392 Expect(st).To(BeZero()) 393 Expect(str).To(Equal(unknownStr)) 394 Expect(err).To(MatchError(testErr)) 395 return true 396 } 397 398 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 399 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 400 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 401 <-testDone 402 return nil, errors.New("test done") 403 }) 404 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 405 Expect(err).To(MatchError("done")) 406 Eventually(done).Should(BeClosed()) 407 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 408 }) 409 410 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 411 streamTypeChan := make(chan StreamType, 1) 412 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 413 Expect(err).ToNot(HaveOccurred()) 414 streamTypeChan <- st 415 return false 416 } 417 418 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 419 unknownStr := mockquic.NewMockStream(mockCtrl) 420 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 421 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 422 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 423 return unknownStr, nil 424 }) 425 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 426 <-testDone 427 return nil, errors.New("test done") 428 }) 429 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 430 Expect(err).To(MatchError("done")) 431 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 432 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 433 }) 434 }) 435 436 Context("control stream handling", func() { 437 var ( 438 req *http.Request 439 conn *mockquic.MockEarlyConnection 440 settingsFrameWritten chan struct{} 441 ) 442 testDone := make(chan struct{}) 443 444 BeforeEach(func() { 445 settingsFrameWritten = make(chan struct{}) 446 controlStr := mockquic.NewMockStream(mockCtrl) 447 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 448 defer GinkgoRecover() 449 close(settingsFrameWritten) 450 }) 451 conn = mockquic.NewMockEarlyConnection(mockCtrl) 452 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 453 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 454 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 455 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 456 return conn, nil 457 } 458 var err error 459 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 460 Expect(err).ToNot(HaveOccurred()) 461 }) 462 463 AfterEach(func() { 464 testDone <- struct{}{} 465 Eventually(settingsFrameWritten).Should(BeClosed()) 466 }) 467 468 It("parses the SETTINGS frame", func() { 469 b := quicvarint.Append(nil, streamTypeControlStream) 470 b = (&settingsFrame{}).Append(b) 471 r := bytes.NewReader(b) 472 controlStr := mockquic.NewMockStream(mockCtrl) 473 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 474 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 475 return controlStr, nil 476 }) 477 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 478 <-testDone 479 return nil, errors.New("test done") 480 }) 481 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 482 Expect(err).To(MatchError("done")) 483 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 484 }) 485 486 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 487 streamType := t 488 name := "encoder" 489 if streamType == streamTypeQPACKDecoderStream { 490 name = "decoder" 491 } 492 493 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 494 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 495 str := mockquic.NewMockStream(mockCtrl) 496 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 497 498 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 499 return str, nil 500 }) 501 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 502 <-testDone 503 return nil, errors.New("test done") 504 }) 505 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 506 Expect(err).To(MatchError("done")) 507 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 508 }) 509 } 510 511 It("resets streams Other than the control stream and the QPACK streams", func() { 512 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) 513 str := mockquic.NewMockStream(mockCtrl) 514 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 515 done := make(chan struct{}) 516 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { 517 close(done) 518 }) 519 520 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 521 return str, nil 522 }) 523 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 524 <-testDone 525 return nil, errors.New("test done") 526 }) 527 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 528 Expect(err).To(MatchError("done")) 529 Eventually(done).Should(BeClosed()) 530 }) 531 532 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 533 b := quicvarint.Append(nil, streamTypeControlStream) 534 b = (&dataFrame{}).Append(b) 535 r := bytes.NewReader(b) 536 controlStr := mockquic.NewMockStream(mockCtrl) 537 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 538 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 539 return controlStr, nil 540 }) 541 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 542 <-testDone 543 return nil, errors.New("test done") 544 }) 545 done := make(chan struct{}) 546 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 547 defer GinkgoRecover() 548 Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) 549 close(done) 550 }) 551 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 552 Expect(err).To(MatchError("done")) 553 Eventually(done).Should(BeClosed()) 554 }) 555 556 It("errors when parsing the frame on the control stream fails", func() { 557 b := quicvarint.Append(nil, streamTypeControlStream) 558 b = (&settingsFrame{}).Append(b) 559 r := bytes.NewReader(b[:len(b)-1]) 560 controlStr := mockquic.NewMockStream(mockCtrl) 561 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 562 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 563 return controlStr, nil 564 }) 565 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 566 <-testDone 567 return nil, errors.New("test done") 568 }) 569 done := make(chan struct{}) 570 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 571 defer GinkgoRecover() 572 Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) 573 close(done) 574 }) 575 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 576 Expect(err).To(MatchError("done")) 577 Eventually(done).Should(BeClosed()) 578 }) 579 580 It("errors when parsing the server opens a push stream", func() { 581 buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) 582 controlStr := mockquic.NewMockStream(mockCtrl) 583 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 584 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 585 return controlStr, nil 586 }) 587 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 588 <-testDone 589 return nil, errors.New("test done") 590 }) 591 done := make(chan struct{}) 592 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 593 defer GinkgoRecover() 594 Expect(code).To(BeEquivalentTo(ErrCodeIDError)) 595 close(done) 596 }) 597 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 598 Expect(err).To(MatchError("done")) 599 Eventually(done).Should(BeClosed()) 600 }) 601 602 It("errors when the server advertises datagram support (and we enabled support for it)", func() { 603 cl.opts.EnableDatagram = true 604 b := quicvarint.Append(nil, streamTypeControlStream) 605 b = (&settingsFrame{Datagram: true}).Append(b) 606 r := bytes.NewReader(b) 607 controlStr := mockquic.NewMockStream(mockCtrl) 608 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 609 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 610 return controlStr, nil 611 }) 612 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 613 <-testDone 614 return nil, errors.New("test done") 615 }) 616 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 617 done := make(chan struct{}) 618 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { 619 defer GinkgoRecover() 620 Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) 621 Expect(reason).To(Equal("missing QUIC Datagram support")) 622 close(done) 623 }) 624 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 625 Expect(err).To(MatchError("done")) 626 Eventually(done).Should(BeClosed()) 627 }) 628 }) 629 630 Context("Doing requests", func() { 631 var ( 632 req *http.Request 633 str *mockquic.MockStream 634 conn *mockquic.MockEarlyConnection 635 settingsFrameWritten chan struct{} 636 ) 637 testDone := make(chan struct{}) 638 639 getHeadersFrame := func(headers map[string]string) []byte { 640 headerBuf := &bytes.Buffer{} 641 enc := qpack.NewEncoder(headerBuf) 642 for name, value := range headers { 643 Expect(enc.WriteField(qpack.HeaderField{Name: name, Value: value})).To(Succeed()) 644 } 645 Expect(enc.Close()).To(Succeed()) 646 b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) 647 b = append(b, headerBuf.Bytes()...) 648 return b 649 } 650 651 decodeHeader := func(str io.Reader) map[string]string { 652 fields := make(map[string]string) 653 decoder := qpack.NewDecoder(nil) 654 655 frame, err := parseNextFrame(str, nil) 656 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 657 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 658 headersFrame := frame.(*headersFrame) 659 data := make([]byte, headersFrame.Length) 660 _, err = io.ReadFull(str, data) 661 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 662 hfs, err := decoder.DecodeFull(data) 663 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 664 for _, p := range hfs { 665 fields[p.Name] = p.Value 666 } 667 return fields 668 } 669 670 getResponse := func(status int) []byte { 671 buf := &bytes.Buffer{} 672 rstr := mockquic.NewMockStream(mockCtrl) 673 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 674 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 675 rw.WriteHeader(status) 676 rw.Flush() 677 return buf.Bytes() 678 } 679 680 BeforeEach(func() { 681 settingsFrameWritten = make(chan struct{}) 682 controlStr := mockquic.NewMockStream(mockCtrl) 683 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { 684 defer GinkgoRecover() 685 r := bytes.NewReader(b) 686 streamType, err := quicvarint.Read(r) 687 Expect(err).ToNot(HaveOccurred()) 688 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 689 close(settingsFrameWritten) 690 }) // SETTINGS frame 691 str = mockquic.NewMockStream(mockCtrl) 692 conn = mockquic.NewMockEarlyConnection(mockCtrl) 693 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 694 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 695 <-testDone 696 return nil, errors.New("test done") 697 }) 698 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 699 return conn, nil 700 } 701 var err error 702 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 703 Expect(err).ToNot(HaveOccurred()) 704 }) 705 706 AfterEach(func() { 707 testDone <- struct{}{} 708 Eventually(settingsFrameWritten).Should(BeClosed()) 709 }) 710 711 It("errors if it can't open a stream", func() { 712 testErr := errors.New("stream open error") 713 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 714 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 715 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 716 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 717 Expect(err).To(MatchError(testErr)) 718 }) 719 720 It("performs a 0-RTT request", func() { 721 testErr := errors.New("stream open error") 722 req.Method = MethodGet0RTT 723 // don't EXPECT any calls to HandshakeComplete() 724 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 725 buf := &bytes.Buffer{} 726 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 727 str.EXPECT().Close() 728 str.EXPECT().CancelWrite(gomock.Any()) 729 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 730 return 0, testErr 731 }) 732 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 733 Expect(err).To(MatchError(testErr)) 734 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) 735 }) 736 737 It("returns a response", func() { 738 rspBuf := bytes.NewBuffer(getResponse(418)) 739 gomock.InOrder( 740 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 741 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 742 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 743 ) 744 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 745 str.EXPECT().Close() 746 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 747 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 748 Expect(err).ToNot(HaveOccurred()) 749 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 750 Expect(rsp.ProtoMajor).To(Equal(3)) 751 Expect(rsp.StatusCode).To(Equal(418)) 752 Expect(rsp.Request).ToNot(BeNil()) 753 }) 754 755 It("doesn't close the request stream, with DontCloseRequestStream set", func() { 756 rspBuf := bytes.NewBuffer(getResponse(418)) 757 gomock.InOrder( 758 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 759 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 760 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 761 ) 762 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 763 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 764 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) 765 Expect(err).ToNot(HaveOccurred()) 766 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 767 Expect(rsp.ProtoMajor).To(Equal(3)) 768 Expect(rsp.StatusCode).To(Equal(418)) 769 }) 770 771 Context("requests containing a Body", func() { 772 var strBuf *bytes.Buffer 773 774 BeforeEach(func() { 775 strBuf = &bytes.Buffer{} 776 gomock.InOrder( 777 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 778 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 779 ) 780 body := &mockBody{} 781 body.SetData([]byte("request body")) 782 var err error 783 req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) 784 Expect(err).ToNot(HaveOccurred()) 785 str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() 786 }) 787 788 It("sends a request", func() { 789 done := make(chan struct{}) 790 gomock.InOrder( 791 str.EXPECT().Close().Do(func() { close(done) }), 792 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors 793 ) 794 // the response body is sent asynchronously, while already reading the response 795 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 796 <-done 797 return 0, errors.New("test done") 798 }) 799 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 800 Expect(err).To(MatchError("test done")) 801 hfs := decodeHeader(strBuf) 802 Expect(hfs).To(HaveKeyWithValue(":method", "POST")) 803 Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) 804 }) 805 806 It("returns the error that occurred when reading the body", func() { 807 req.Body.(*mockBody).readErr = errors.New("testErr") 808 done := make(chan struct{}) 809 gomock.InOrder( 810 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { 811 close(done) 812 }), 813 str.EXPECT().CancelWrite(gomock.Any()), 814 ) 815 816 // the response body is sent asynchronously, while already reading the response 817 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 818 <-done 819 return 0, errors.New("test done") 820 }) 821 closed := make(chan struct{}) 822 str.EXPECT().Close().Do(func() { close(closed) }) 823 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 824 Expect(err).To(MatchError("test done")) 825 Eventually(closed).Should(BeClosed()) 826 }) 827 828 It("sets the Content-Length", func() { 829 done := make(chan struct{}) 830 b := getHeadersFrame(map[string]string{ 831 ":status": "200", 832 "Content-Length": "1337", 833 }) 834 b = (&dataFrame{Length: 0x6}).Append(b) 835 b = append(b, []byte("foobar")...) 836 r := bytes.NewReader(b) 837 str.EXPECT().Close().Do(func() { close(done) }) 838 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 839 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors 840 // the response body is sent asynchronously, while already reading the response 841 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 842 req, err := cl.RoundTripOpt(req, RoundTripOpt{}) 843 Expect(err).ToNot(HaveOccurred()) 844 Expect(req.ContentLength).To(BeEquivalentTo(1337)) 845 Eventually(done).Should(BeClosed()) 846 }) 847 848 It("closes the connection when the first frame is not a HEADERS frame", func() { 849 b := (&dataFrame{Length: 0x42}).Append(nil) 850 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) 851 closed := make(chan struct{}) 852 r := bytes.NewReader(b) 853 str.EXPECT().Close().Do(func() { close(closed) }) 854 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 855 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 856 Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) 857 Eventually(closed).Should(BeClosed()) 858 }) 859 860 It("cancels the stream when the HEADERS frame is too large", func() { 861 b := (&headersFrame{Length: 1338}).Append(nil) 862 r := bytes.NewReader(b) 863 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) 864 closed := make(chan struct{}) 865 str.EXPECT().Close().Do(func() { close(closed) }) 866 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 867 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 868 Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) 869 Eventually(closed).Should(BeClosed()) 870 }) 871 }) 872 873 Context("request cancellations", func() { 874 for _, dontClose := range []bool{false, true} { 875 dontClose := dontClose 876 877 Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() { 878 roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose} 879 880 It("cancels a request while waiting for the handshake to complete", func() { 881 ctx, cancel := context.WithCancel(context.Background()) 882 req := req.WithContext(ctx) 883 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 884 885 errChan := make(chan error) 886 go func() { 887 _, err := cl.RoundTripOpt(req, roundTripOpt) 888 errChan <- err 889 }() 890 Consistently(errChan).ShouldNot(Receive()) 891 cancel() 892 Eventually(errChan).Should(Receive(MatchError("context canceled"))) 893 }) 894 895 It("cancels a request while the request is still in flight", func() { 896 ctx, cancel := context.WithCancel(context.Background()) 897 req := req.WithContext(ctx) 898 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 899 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 900 buf := &bytes.Buffer{} 901 str.EXPECT().Close().MaxTimes(1) 902 903 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 904 905 done := make(chan struct{}) 906 canceled := make(chan struct{}) 907 gomock.InOrder( 908 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), 909 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), 910 ) 911 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) 912 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 913 cancel() 914 <-canceled 915 return 0, errors.New("test done") 916 }) 917 _, err := cl.RoundTripOpt(req, roundTripOpt) 918 Expect(err).To(MatchError("test done")) 919 Eventually(done).Should(BeClosed()) 920 }) 921 }) 922 } 923 924 It("cancels a request after the response arrived", func() { 925 rspBuf := bytes.NewBuffer(getResponse(404)) 926 927 ctx, cancel := context.WithCancel(context.Background()) 928 req := req.WithContext(ctx) 929 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 930 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 931 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 932 buf := &bytes.Buffer{} 933 str.EXPECT().Close().MaxTimes(1) 934 935 done := make(chan struct{}) 936 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 937 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 938 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 939 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) 940 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 941 Expect(err).ToNot(HaveOccurred()) 942 cancel() 943 Eventually(done).Should(BeClosed()) 944 }) 945 }) 946 947 Context("gzip compression", func() { 948 BeforeEach(func() { 949 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 950 }) 951 952 It("adds the gzip header to requests", func() { 953 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 954 buf := &bytes.Buffer{} 955 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 956 gomock.InOrder( 957 str.EXPECT().Close(), 958 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 959 ) 960 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 961 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 962 Expect(err).To(MatchError("test done")) 963 hfs := decodeHeader(buf) 964 Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) 965 }) 966 967 It("doesn't add gzip if the header disable it", func() { 968 client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) 969 Expect(err).ToNot(HaveOccurred()) 970 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 971 buf := &bytes.Buffer{} 972 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 973 gomock.InOrder( 974 str.EXPECT().Close(), 975 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 976 ) 977 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 978 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 979 Expect(err).To(MatchError("test done")) 980 hfs := decodeHeader(buf) 981 Expect(hfs).ToNot(HaveKey("accept-encoding")) 982 }) 983 984 It("decompresses the response", func() { 985 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 986 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 987 buf := &bytes.Buffer{} 988 rstr := mockquic.NewMockStream(mockCtrl) 989 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 990 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 991 rw.Header().Set("Content-Encoding", "gzip") 992 gz := gzip.NewWriter(rw) 993 gz.Write([]byte("gzipped response")) 994 gz.Close() 995 rw.Flush() 996 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 997 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 998 str.EXPECT().Close() 999 1000 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1001 Expect(err).ToNot(HaveOccurred()) 1002 data, err := io.ReadAll(rsp.Body) 1003 Expect(err).ToNot(HaveOccurred()) 1004 Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) 1005 Expect(string(data)).To(Equal("gzipped response")) 1006 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1007 Expect(rsp.Uncompressed).To(BeTrue()) 1008 }) 1009 1010 It("only decompresses the response if the response contains the right content-encoding header", func() { 1011 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1012 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1013 buf := &bytes.Buffer{} 1014 rstr := mockquic.NewMockStream(mockCtrl) 1015 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 1016 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 1017 rw.Write([]byte("not gzipped")) 1018 rw.Flush() 1019 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1020 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1021 str.EXPECT().Close() 1022 1023 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1024 Expect(err).ToNot(HaveOccurred()) 1025 data, err := io.ReadAll(rsp.Body) 1026 Expect(err).ToNot(HaveOccurred()) 1027 Expect(string(data)).To(Equal("not gzipped")) 1028 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1029 }) 1030 }) 1031 }) 1032 })