github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "crypto/tls" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/TugasAkhir-QUIC/quic-go" 16 mockquic "github.com/TugasAkhir-QUIC/quic-go/internal/mocks/quic" 17 "github.com/TugasAkhir-QUIC/quic-go/internal/protocol" 18 "github.com/TugasAkhir-QUIC/quic-go/internal/utils" 19 "github.com/TugasAkhir-QUIC/quic-go/quicvarint" 20 21 "github.com/quic-go/qpack" 22 23 . "github.com/onsi/ginkgo/v2" 24 . "github.com/onsi/gomega" 25 "go.uber.org/mock/gomock" 26 ) 27 28 var _ = Describe("Client", func() { 29 var ( 30 cl *client 31 req *http.Request 32 origDialAddr = dialAddr 33 handshakeChan <-chan struct{} // a closed chan 34 ) 35 36 BeforeEach(func() { 37 origDialAddr = dialAddr 38 hostname := "quic.clemente.io:1337" 39 c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) 40 Expect(err).ToNot(HaveOccurred()) 41 cl = c.(*client) 42 Expect(cl.hostname).To(Equal(hostname)) 43 44 req, err = http.NewRequest("GET", "https://localhost:1337", nil) 45 Expect(err).ToNot(HaveOccurred()) 46 47 ch := make(chan struct{}) 48 close(ch) 49 handshakeChan = ch 50 }) 51 52 AfterEach(func() { 53 dialAddr = origDialAddr 54 }) 55 56 It("rejects quic.Configs that allow multiple QUIC versions", func() { 57 qconf := &quic.Config{ 58 Versions: []quic.Version{protocol.Version2, protocol.Version1}, 59 } 60 _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) 61 Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) 62 }) 63 64 It("uses the default QUIC and TLS config if none is give", func() { 65 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 66 Expect(err).ToNot(HaveOccurred()) 67 var dialAddrCalled bool 68 dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 69 Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) 70 Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) 71 Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) 72 dialAddrCalled = true 73 return nil, errors.New("test done") 74 } 75 client.RoundTripOpt(req, RoundTripOpt{}) 76 Expect(dialAddrCalled).To(BeTrue()) 77 }) 78 79 It("adds the port to the hostname, if none is given", func() { 80 client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) 81 Expect(err).ToNot(HaveOccurred()) 82 var dialAddrCalled bool 83 dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { 84 Expect(hostname).To(Equal("quic.clemente.io:443")) 85 dialAddrCalled = true 86 return nil, errors.New("test done") 87 } 88 req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) 89 Expect(err).ToNot(HaveOccurred()) 90 client.RoundTripOpt(req, RoundTripOpt{}) 91 Expect(dialAddrCalled).To(BeTrue()) 92 }) 93 94 It("sets the ServerName in the tls.Config, if not set", func() { 95 const host = "foo.bar" 96 dialCalled := false 97 dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 98 Expect(tlsCfg.ServerName).To(Equal(host)) 99 dialCalled = true 100 return nil, errors.New("test done") 101 } 102 client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) 103 Expect(err).ToNot(HaveOccurred()) 104 req, err := http.NewRequest("GET", "https://foo.bar", nil) 105 Expect(err).ToNot(HaveOccurred()) 106 client.RoundTripOpt(req, RoundTripOpt{}) 107 Expect(dialCalled).To(BeTrue()) 108 }) 109 110 It("uses the TLS config and QUIC config", func() { 111 tlsConf := &tls.Config{ 112 ServerName: "foo.bar", 113 NextProtos: []string{"proto foo", "proto bar"}, 114 } 115 quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} 116 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) 117 Expect(err).ToNot(HaveOccurred()) 118 var dialAddrCalled bool 119 dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 120 Expect(host).To(Equal("localhost:1337")) 121 Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) 122 Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) 123 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 124 dialAddrCalled = true 125 return nil, errors.New("test done") 126 } 127 client.RoundTripOpt(req, RoundTripOpt{}) 128 Expect(dialAddrCalled).To(BeTrue()) 129 // make sure the original tls.Config was not modified 130 Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) 131 }) 132 133 It("uses the custom dialer, if provided", func() { 134 testErr := errors.New("test done") 135 tlsConf := &tls.Config{ServerName: "foo.bar"} 136 quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} 137 ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 138 defer cancel() 139 var dialerCalled bool 140 dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 141 Expect(ctxP).To(Equal(ctx)) 142 Expect(address).To(Equal("localhost:1337")) 143 Expect(tlsConfP.ServerName).To(Equal("foo.bar")) 144 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 145 dialerCalled = true 146 return nil, testErr 147 } 148 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) 149 Expect(err).ToNot(HaveOccurred()) 150 _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) 151 Expect(err).To(MatchError(testErr)) 152 Expect(dialerCalled).To(BeTrue()) 153 }) 154 155 It("enables HTTP/3 Datagrams", func() { 156 testErr := errors.New("handshake error") 157 client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) 158 Expect(err).ToNot(HaveOccurred()) 159 dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 160 Expect(quicConf.EnableDatagrams).To(BeTrue()) 161 return nil, testErr 162 } 163 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 164 Expect(err).To(MatchError(testErr)) 165 }) 166 167 It("errors when dialing fails", func() { 168 testErr := errors.New("handshake error") 169 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 170 Expect(err).ToNot(HaveOccurred()) 171 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 172 return nil, testErr 173 } 174 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 175 Expect(err).To(MatchError(testErr)) 176 }) 177 178 It("closes correctly if connection was not created", func() { 179 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 180 Expect(err).ToNot(HaveOccurred()) 181 Expect(client.Close()).To(Succeed()) 182 }) 183 184 Context("validating the address", func() { 185 It("refuses to do requests for the wrong host", func() { 186 req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) 187 Expect(err).ToNot(HaveOccurred()) 188 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 189 Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) 190 }) 191 192 It("allows requests using a different scheme", func() { 193 testErr := errors.New("handshake error") 194 req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) 195 Expect(err).ToNot(HaveOccurred()) 196 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 197 return nil, testErr 198 } 199 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 200 Expect(err).To(MatchError(testErr)) 201 }) 202 }) 203 204 Context("hijacking bidirectional streams", func() { 205 var ( 206 request *http.Request 207 conn *mockquic.MockEarlyConnection 208 settingsFrameWritten chan struct{} 209 ) 210 testDone := make(chan struct{}) 211 212 BeforeEach(func() { 213 testDone = make(chan struct{}) 214 settingsFrameWritten = make(chan struct{}) 215 controlStr := mockquic.NewMockStream(mockCtrl) 216 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 217 defer GinkgoRecover() 218 close(settingsFrameWritten) 219 return len(b), nil 220 }) 221 conn = mockquic.NewMockEarlyConnection(mockCtrl) 222 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 223 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 224 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 225 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 226 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 227 return conn, nil 228 } 229 var err error 230 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 231 Expect(err).ToNot(HaveOccurred()) 232 }) 233 234 AfterEach(func() { 235 testDone <- struct{}{} 236 Eventually(settingsFrameWritten).Should(BeClosed()) 237 }) 238 239 It("hijacks a bidirectional stream of unknown frame type", func() { 240 frameTypeChan := make(chan FrameType, 1) 241 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 242 Expect(e).ToNot(HaveOccurred()) 243 frameTypeChan <- ft 244 return true, nil 245 } 246 247 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 248 unknownStr := mockquic.NewMockStream(mockCtrl) 249 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 250 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 251 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 252 <-testDone 253 return nil, errors.New("test done") 254 }) 255 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 256 Expect(err).To(MatchError("done")) 257 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 258 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 259 }) 260 261 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 262 frameTypeChan := make(chan FrameType, 1) 263 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 264 Expect(e).ToNot(HaveOccurred()) 265 frameTypeChan <- ft 266 return false, nil 267 } 268 269 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 270 unknownStr := mockquic.NewMockStream(mockCtrl) 271 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 272 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 273 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 274 <-testDone 275 return nil, errors.New("test done") 276 }) 277 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 278 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 279 Expect(err).To(MatchError("done")) 280 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 281 }) 282 283 It("closes the connection when hijacker returned error", func() { 284 frameTypeChan := make(chan FrameType, 1) 285 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 286 Expect(e).ToNot(HaveOccurred()) 287 frameTypeChan <- ft 288 return false, errors.New("error in hijacker") 289 } 290 291 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 292 unknownStr := mockquic.NewMockStream(mockCtrl) 293 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 294 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 295 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 296 <-testDone 297 return nil, errors.New("test done") 298 }) 299 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 300 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 301 Expect(err).To(MatchError("done")) 302 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 303 }) 304 305 It("handles errors that occur when reading the frame type", func() { 306 testErr := errors.New("test error") 307 unknownStr := mockquic.NewMockStream(mockCtrl) 308 done := make(chan struct{}) 309 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { 310 defer close(done) 311 Expect(e).To(MatchError(testErr)) 312 Expect(ft).To(BeZero()) 313 Expect(str).To(Equal(unknownStr)) 314 return false, nil 315 } 316 317 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 318 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 319 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 320 <-testDone 321 return nil, errors.New("test done") 322 }) 323 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 324 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 325 Expect(err).To(MatchError("done")) 326 Eventually(done).Should(BeClosed()) 327 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 328 }) 329 }) 330 331 Context("hijacking unidirectional streams", func() { 332 var ( 333 req *http.Request 334 conn *mockquic.MockEarlyConnection 335 settingsFrameWritten chan struct{} 336 ) 337 testDone := make(chan struct{}) 338 339 BeforeEach(func() { 340 testDone = make(chan struct{}) 341 settingsFrameWritten = make(chan struct{}) 342 controlStr := mockquic.NewMockStream(mockCtrl) 343 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 344 defer GinkgoRecover() 345 close(settingsFrameWritten) 346 return len(b), nil 347 }) 348 conn = mockquic.NewMockEarlyConnection(mockCtrl) 349 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 350 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 351 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 352 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 353 return conn, nil 354 } 355 var err error 356 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 357 Expect(err).ToNot(HaveOccurred()) 358 }) 359 360 AfterEach(func() { 361 testDone <- struct{}{} 362 Eventually(settingsFrameWritten).Should(BeClosed()) 363 }) 364 365 It("hijacks an unidirectional stream of unknown stream type", func() { 366 streamTypeChan := make(chan StreamType, 1) 367 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 368 Expect(err).ToNot(HaveOccurred()) 369 streamTypeChan <- st 370 return true 371 } 372 373 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 374 unknownStr := mockquic.NewMockStream(mockCtrl) 375 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 376 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 377 return unknownStr, nil 378 }) 379 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 380 <-testDone 381 return nil, errors.New("test done") 382 }) 383 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 384 Expect(err).To(MatchError("done")) 385 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 386 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 387 }) 388 389 It("handles errors that occur when reading the stream type", func() { 390 testErr := errors.New("test error") 391 done := make(chan struct{}) 392 unknownStr := mockquic.NewMockStream(mockCtrl) 393 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 394 defer close(done) 395 Expect(st).To(BeZero()) 396 Expect(str).To(Equal(unknownStr)) 397 Expect(err).To(MatchError(testErr)) 398 return true 399 } 400 401 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 402 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 403 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 404 <-testDone 405 return nil, errors.New("test done") 406 }) 407 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 408 Expect(err).To(MatchError("done")) 409 Eventually(done).Should(BeClosed()) 410 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 411 }) 412 413 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 414 streamTypeChan := make(chan StreamType, 1) 415 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 416 Expect(err).ToNot(HaveOccurred()) 417 streamTypeChan <- st 418 return false 419 } 420 421 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 422 unknownStr := mockquic.NewMockStream(mockCtrl) 423 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 424 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 425 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 426 return unknownStr, nil 427 }) 428 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 429 <-testDone 430 return nil, errors.New("test done") 431 }) 432 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 433 Expect(err).To(MatchError("done")) 434 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 435 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 436 }) 437 }) 438 439 Context("control stream handling", func() { 440 var ( 441 req *http.Request 442 conn *mockquic.MockEarlyConnection 443 settingsFrameWritten chan struct{} 444 ) 445 testDone := make(chan struct{}) 446 447 BeforeEach(func() { 448 settingsFrameWritten = make(chan struct{}) 449 controlStr := mockquic.NewMockStream(mockCtrl) 450 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 451 defer GinkgoRecover() 452 close(settingsFrameWritten) 453 return len(b), nil 454 }) 455 conn = mockquic.NewMockEarlyConnection(mockCtrl) 456 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 457 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 458 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 459 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 460 return conn, nil 461 } 462 var err error 463 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 464 Expect(err).ToNot(HaveOccurred()) 465 }) 466 467 AfterEach(func() { 468 testDone <- struct{}{} 469 Eventually(settingsFrameWritten).Should(BeClosed()) 470 }) 471 472 It("parses the SETTINGS frame", func() { 473 b := quicvarint.Append(nil, streamTypeControlStream) 474 b = (&settingsFrame{}).Append(b) 475 r := bytes.NewReader(b) 476 controlStr := mockquic.NewMockStream(mockCtrl) 477 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 478 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 479 return controlStr, nil 480 }) 481 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 482 <-testDone 483 return nil, errors.New("test done") 484 }) 485 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 486 Expect(err).To(MatchError("done")) 487 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 488 }) 489 490 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 491 streamType := t 492 name := "encoder" 493 if streamType == streamTypeQPACKDecoderStream { 494 name = "decoder" 495 } 496 497 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 498 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 499 str := mockquic.NewMockStream(mockCtrl) 500 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 501 502 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 503 return str, nil 504 }) 505 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 506 <-testDone 507 return nil, errors.New("test done") 508 }) 509 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 510 Expect(err).To(MatchError("done")) 511 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 512 }) 513 } 514 515 It("resets streams Other than the control stream and the QPACK streams", func() { 516 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) 517 str := mockquic.NewMockStream(mockCtrl) 518 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 519 done := make(chan struct{}) 520 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) 521 522 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 523 return str, nil 524 }) 525 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 526 <-testDone 527 return nil, errors.New("test done") 528 }) 529 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 530 Expect(err).To(MatchError("done")) 531 Eventually(done).Should(BeClosed()) 532 }) 533 534 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 535 b := quicvarint.Append(nil, streamTypeControlStream) 536 b = (&dataFrame{}).Append(b) 537 r := bytes.NewReader(b) 538 controlStr := mockquic.NewMockStream(mockCtrl) 539 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 540 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 541 return controlStr, nil 542 }) 543 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 544 <-testDone 545 return nil, errors.New("test done") 546 }) 547 done := make(chan struct{}) 548 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 549 close(done) 550 return nil 551 }) 552 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 553 Expect(err).To(MatchError("done")) 554 Eventually(done).Should(BeClosed()) 555 }) 556 557 It("errors when parsing the frame on the control stream fails", func() { 558 b := quicvarint.Append(nil, streamTypeControlStream) 559 b = (&settingsFrame{}).Append(b) 560 r := bytes.NewReader(b[:len(b)-1]) 561 controlStr := mockquic.NewMockStream(mockCtrl) 562 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 563 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 564 return controlStr, nil 565 }) 566 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 567 <-testDone 568 return nil, errors.New("test done") 569 }) 570 done := make(chan struct{}) 571 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error { 572 close(done) 573 return nil 574 }) 575 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 576 Expect(err).To(MatchError("done")) 577 Eventually(done).Should(BeClosed()) 578 }) 579 580 It("errors when parsing the server opens a push stream", func() { 581 buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) 582 controlStr := mockquic.NewMockStream(mockCtrl) 583 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 584 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 585 return controlStr, nil 586 }) 587 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 588 <-testDone 589 return nil, errors.New("test done") 590 }) 591 done := make(chan struct{}) 592 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 593 close(done) 594 return nil 595 }) 596 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 597 Expect(err).To(MatchError("done")) 598 Eventually(done).Should(BeClosed()) 599 }) 600 601 It("errors when the server advertises datagram support (and we enabled support for it)", func() { 602 cl.opts.EnableDatagram = true 603 b := quicvarint.Append(nil, streamTypeControlStream) 604 b = (&settingsFrame{Datagram: true}).Append(b) 605 r := bytes.NewReader(b) 606 controlStr := mockquic.NewMockStream(mockCtrl) 607 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 608 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 609 return controlStr, nil 610 }) 611 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 612 <-testDone 613 return nil, errors.New("test done") 614 }) 615 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 616 done := make(chan struct{}) 617 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { 618 close(done) 619 return nil 620 }) 621 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 622 Expect(err).To(MatchError("done")) 623 Eventually(done).Should(BeClosed()) 624 }) 625 }) 626 627 Context("Doing requests", func() { 628 var ( 629 req *http.Request 630 str *mockquic.MockStream 631 conn *mockquic.MockEarlyConnection 632 settingsFrameWritten chan struct{} 633 ) 634 testDone := make(chan struct{}) 635 636 decodeHeader := func(str io.Reader) map[string]string { 637 fields := make(map[string]string) 638 decoder := qpack.NewDecoder(nil) 639 640 frame, err := parseNextFrame(str, nil) 641 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 642 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 643 headersFrame := frame.(*headersFrame) 644 data := make([]byte, headersFrame.Length) 645 _, err = io.ReadFull(str, data) 646 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 647 hfs, err := decoder.DecodeFull(data) 648 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 649 for _, p := range hfs { 650 fields[p.Name] = p.Value 651 } 652 return fields 653 } 654 655 getResponse := func(status int) []byte { 656 buf := &bytes.Buffer{} 657 rstr := mockquic.NewMockStream(mockCtrl) 658 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 659 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 660 rw.WriteHeader(status) 661 rw.Flush() 662 return buf.Bytes() 663 } 664 665 BeforeEach(func() { 666 settingsFrameWritten = make(chan struct{}) 667 controlStr := mockquic.NewMockStream(mockCtrl) 668 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 669 defer GinkgoRecover() 670 r := bytes.NewReader(b) 671 streamType, err := quicvarint.Read(r) 672 Expect(err).ToNot(HaveOccurred()) 673 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 674 close(settingsFrameWritten) 675 return len(b), nil 676 }) // SETTINGS frame 677 str = mockquic.NewMockStream(mockCtrl) 678 conn = mockquic.NewMockEarlyConnection(mockCtrl) 679 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 680 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 681 <-testDone 682 return nil, errors.New("test done") 683 }) 684 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 685 return conn, nil 686 } 687 var err error 688 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 689 Expect(err).ToNot(HaveOccurred()) 690 }) 691 692 AfterEach(func() { 693 testDone <- struct{}{} 694 Eventually(settingsFrameWritten).Should(BeClosed()) 695 }) 696 697 It("errors if it can't open a stream", func() { 698 testErr := errors.New("stream open error") 699 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 700 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 701 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 702 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 703 Expect(err).To(MatchError(testErr)) 704 }) 705 706 It("performs a 0-RTT request", func() { 707 testErr := errors.New("stream open error") 708 req.Method = MethodGet0RTT 709 // don't EXPECT any calls to HandshakeComplete() 710 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 711 buf := &bytes.Buffer{} 712 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 713 str.EXPECT().Close() 714 str.EXPECT().CancelWrite(gomock.Any()) 715 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 716 return 0, testErr 717 }) 718 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 719 Expect(err).To(MatchError(testErr)) 720 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) 721 }) 722 723 It("returns a response", func() { 724 rspBuf := bytes.NewBuffer(getResponse(418)) 725 gomock.InOrder( 726 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 727 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 728 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 729 ) 730 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 731 str.EXPECT().Close() 732 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 733 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 734 Expect(err).ToNot(HaveOccurred()) 735 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 736 Expect(rsp.ProtoMajor).To(Equal(3)) 737 Expect(rsp.StatusCode).To(Equal(418)) 738 Expect(rsp.Request).ToNot(BeNil()) 739 }) 740 741 It("doesn't close the request stream, with DontCloseRequestStream set", func() { 742 rspBuf := bytes.NewBuffer(getResponse(418)) 743 gomock.InOrder( 744 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 745 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 746 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 747 ) 748 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 749 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 750 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) 751 Expect(err).ToNot(HaveOccurred()) 752 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 753 Expect(rsp.ProtoMajor).To(Equal(3)) 754 Expect(rsp.StatusCode).To(Equal(418)) 755 }) 756 757 Context("requests containing a Body", func() { 758 var strBuf *bytes.Buffer 759 760 BeforeEach(func() { 761 strBuf = &bytes.Buffer{} 762 gomock.InOrder( 763 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 764 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 765 ) 766 body := &mockBody{} 767 body.SetData([]byte("request body")) 768 var err error 769 req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) 770 Expect(err).ToNot(HaveOccurred()) 771 str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() 772 }) 773 774 It("sends a request", func() { 775 done := make(chan struct{}) 776 gomock.InOrder( 777 str.EXPECT().Close().Do(func() error { close(done); return nil }), 778 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors 779 ) 780 // the response body is sent asynchronously, while already reading the response 781 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 782 <-done 783 return 0, errors.New("test done") 784 }) 785 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 786 Expect(err).To(MatchError("test done")) 787 hfs := decodeHeader(strBuf) 788 Expect(hfs).To(HaveKeyWithValue(":method", "POST")) 789 Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) 790 }) 791 792 It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { 793 req.ContentLength = 7 794 var once sync.Once 795 done := make(chan struct{}) 796 gomock.InOrder( 797 str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { 798 once.Do(func() { 799 Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) 800 close(done) 801 }) 802 }).AnyTimes(), 803 str.EXPECT().Close().MaxTimes(1), 804 str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), 805 ) 806 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 807 <-done 808 return 0, errors.New("done") 809 }) 810 cl.RoundTripOpt(req, RoundTripOpt{}) 811 Expect(strBuf.String()).To(ContainSubstring("request")) 812 Expect(strBuf.String()).ToNot(ContainSubstring("request body")) 813 }) 814 815 It("returns the error that occurred when reading the body", func() { 816 req.Body.(*mockBody).readErr = errors.New("testErr") 817 done := make(chan struct{}) 818 gomock.InOrder( 819 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { 820 close(done) 821 }), 822 str.EXPECT().CancelWrite(gomock.Any()), 823 ) 824 825 // the response body is sent asynchronously, while already reading the response 826 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 827 <-done 828 return 0, errors.New("test done") 829 }) 830 closed := make(chan struct{}) 831 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 832 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 833 Expect(err).To(MatchError("test done")) 834 Eventually(closed).Should(BeClosed()) 835 }) 836 837 It("closes the connection when the first frame is not a HEADERS frame", func() { 838 b := (&dataFrame{Length: 0x42}).Append(nil) 839 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) 840 closed := make(chan struct{}) 841 r := bytes.NewReader(b) 842 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 843 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 844 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 845 Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) 846 Eventually(closed).Should(BeClosed()) 847 }) 848 849 It("cancels the stream when parsing the headers fails", func() { 850 headerBuf := &bytes.Buffer{} 851 enc := qpack.NewEncoder(headerBuf) 852 Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header 853 Expect(enc.Close()).To(Succeed()) 854 b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) 855 b = append(b, headerBuf.Bytes()...) 856 857 r := bytes.NewReader(b) 858 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) 859 closed := make(chan struct{}) 860 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 861 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 862 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 863 Expect(err).To(HaveOccurred()) 864 Eventually(closed).Should(BeClosed()) 865 }) 866 867 It("cancels the stream when the HEADERS frame is too large", func() { 868 b := (&headersFrame{Length: 1338}).Append(nil) 869 r := bytes.NewReader(b) 870 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) 871 closed := make(chan struct{}) 872 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 873 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 874 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 875 Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) 876 Eventually(closed).Should(BeClosed()) 877 }) 878 }) 879 880 Context("request cancellations", func() { 881 for _, dontClose := range []bool{false, true} { 882 dontClose := dontClose 883 884 Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() { 885 roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose} 886 887 It("cancels a request while waiting for the handshake to complete", func() { 888 ctx, cancel := context.WithCancel(context.Background()) 889 req := req.WithContext(ctx) 890 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 891 892 errChan := make(chan error) 893 go func() { 894 _, err := cl.RoundTripOpt(req, roundTripOpt) 895 errChan <- err 896 }() 897 Consistently(errChan).ShouldNot(Receive()) 898 cancel() 899 Eventually(errChan).Should(Receive(MatchError("context canceled"))) 900 }) 901 902 It("cancels a request while the request is still in flight", func() { 903 ctx, cancel := context.WithCancel(context.Background()) 904 req := req.WithContext(ctx) 905 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 906 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 907 buf := &bytes.Buffer{} 908 str.EXPECT().Close().MaxTimes(1) 909 910 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 911 912 done := make(chan struct{}) 913 canceled := make(chan struct{}) 914 gomock.InOrder( 915 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), 916 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), 917 ) 918 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) 919 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 920 cancel() 921 <-canceled 922 return 0, errors.New("test done") 923 }) 924 _, err := cl.RoundTripOpt(req, roundTripOpt) 925 Expect(err).To(MatchError(context.Canceled)) 926 Eventually(done).Should(BeClosed()) 927 }) 928 }) 929 } 930 931 It("cancels a request after the response arrived", func() { 932 rspBuf := bytes.NewBuffer(getResponse(404)) 933 934 ctx, cancel := context.WithCancel(context.Background()) 935 req := req.WithContext(ctx) 936 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 937 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 938 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 939 buf := &bytes.Buffer{} 940 str.EXPECT().Close().MaxTimes(1) 941 942 done := make(chan struct{}) 943 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 944 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 945 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 946 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) 947 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 948 Expect(err).ToNot(HaveOccurred()) 949 cancel() 950 Eventually(done).Should(BeClosed()) 951 }) 952 }) 953 954 Context("gzip compression", func() { 955 BeforeEach(func() { 956 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 957 }) 958 959 It("adds the gzip header to requests", func() { 960 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 961 buf := &bytes.Buffer{} 962 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 963 gomock.InOrder( 964 str.EXPECT().Close(), 965 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 966 ) 967 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 968 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 969 Expect(err).To(MatchError("test done")) 970 hfs := decodeHeader(buf) 971 Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) 972 }) 973 974 It("doesn't add gzip if the header disable it", func() { 975 client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) 976 Expect(err).ToNot(HaveOccurred()) 977 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 978 buf := &bytes.Buffer{} 979 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 980 gomock.InOrder( 981 str.EXPECT().Close(), 982 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 983 ) 984 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 985 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 986 Expect(err).To(MatchError("test done")) 987 hfs := decodeHeader(buf) 988 Expect(hfs).ToNot(HaveKey("accept-encoding")) 989 }) 990 991 It("decompresses the response", func() { 992 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 993 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 994 buf := &bytes.Buffer{} 995 rstr := mockquic.NewMockStream(mockCtrl) 996 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 997 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 998 rw.Header().Set("Content-Encoding", "gzip") 999 gz := gzip.NewWriter(rw) 1000 gz.Write([]byte("gzipped response")) 1001 gz.Close() 1002 rw.Flush() 1003 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1004 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1005 str.EXPECT().Close() 1006 1007 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1008 Expect(err).ToNot(HaveOccurred()) 1009 data, err := io.ReadAll(rsp.Body) 1010 Expect(err).ToNot(HaveOccurred()) 1011 Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) 1012 Expect(string(data)).To(Equal("gzipped response")) 1013 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1014 Expect(rsp.Uncompressed).To(BeTrue()) 1015 }) 1016 1017 It("only decompresses the response if the response contains the right content-encoding header", func() { 1018 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1019 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1020 buf := &bytes.Buffer{} 1021 rstr := mockquic.NewMockStream(mockCtrl) 1022 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 1023 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 1024 rw.Write([]byte("not gzipped")) 1025 rw.Flush() 1026 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1027 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1028 str.EXPECT().Close() 1029 1030 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1031 Expect(err).ToNot(HaveOccurred()) 1032 data, err := io.ReadAll(rsp.Body) 1033 Expect(err).ToNot(HaveOccurred()) 1034 Expect(string(data)).To(Equal("not gzipped")) 1035 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1036 }) 1037 }) 1038 }) 1039 })