github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "crypto/tls" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "sync" 13 "time" 14 15 "github.com/daeuniverse/quic-go" 16 mockquic "github.com/daeuniverse/quic-go/internal/mocks/quic" 17 "github.com/daeuniverse/quic-go/internal/protocol" 18 "github.com/daeuniverse/quic-go/internal/qerr" 19 "github.com/daeuniverse/quic-go/internal/utils" 20 "github.com/daeuniverse/quic-go/quicvarint" 21 22 "github.com/quic-go/qpack" 23 24 . "github.com/onsi/ginkgo/v2" 25 . "github.com/onsi/gomega" 26 "go.uber.org/mock/gomock" 27 ) 28 29 var _ = Describe("Client", func() { 30 var ( 31 cl *client 32 req *http.Request 33 origDialAddr = dialAddr 34 handshakeChan <-chan struct{} // a closed chan 35 ) 36 37 BeforeEach(func() { 38 origDialAddr = dialAddr 39 hostname := "quic.clemente.io:1337" 40 c, err := newClient(hostname, nil, &roundTripperOpts{MaxHeaderBytes: 1337}, nil, nil) 41 Expect(err).ToNot(HaveOccurred()) 42 cl = c.(*client) 43 Expect(cl.hostname).To(Equal(hostname)) 44 45 req, err = http.NewRequest("GET", "https://localhost:1337", nil) 46 Expect(err).ToNot(HaveOccurred()) 47 48 ch := make(chan struct{}) 49 close(ch) 50 handshakeChan = ch 51 }) 52 53 AfterEach(func() { 54 dialAddr = origDialAddr 55 }) 56 57 It("rejects quic.Configs that allow multiple QUIC versions", func() { 58 qconf := &quic.Config{ 59 Versions: []quic.Version{protocol.Version2, protocol.Version1}, 60 } 61 _, err := newClient("localhost:1337", nil, &roundTripperOpts{}, qconf, nil) 62 Expect(err).To(MatchError("can only use a single QUIC version for dialing a HTTP/3 connection")) 63 }) 64 65 It("uses the default QUIC and TLS config if none is give", func() { 66 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 67 Expect(err).ToNot(HaveOccurred()) 68 var dialAddrCalled bool 69 dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 70 Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) 71 Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) 72 Expect(quicConf.Versions).To(Equal([]protocol.Version{protocol.Version1})) 73 dialAddrCalled = true 74 return nil, errors.New("test done") 75 } 76 client.RoundTripOpt(req, RoundTripOpt{}) 77 Expect(dialAddrCalled).To(BeTrue()) 78 }) 79 80 It("adds the port to the hostname, if none is given", func() { 81 client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) 82 Expect(err).ToNot(HaveOccurred()) 83 var dialAddrCalled bool 84 dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { 85 Expect(hostname).To(Equal("quic.clemente.io:443")) 86 dialAddrCalled = true 87 return nil, errors.New("test done") 88 } 89 req, err := http.NewRequest("GET", "https://quic.clemente.io:443", nil) 90 Expect(err).ToNot(HaveOccurred()) 91 client.RoundTripOpt(req, RoundTripOpt{}) 92 Expect(dialAddrCalled).To(BeTrue()) 93 }) 94 95 It("sets the ServerName in the tls.Config, if not set", func() { 96 const host = "foo.bar" 97 dialCalled := false 98 dialFunc := func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { 99 Expect(tlsCfg.ServerName).To(Equal(host)) 100 dialCalled = true 101 return nil, errors.New("test done") 102 } 103 client, err := newClient(host, nil, &roundTripperOpts{}, nil, dialFunc) 104 Expect(err).ToNot(HaveOccurred()) 105 req, err := http.NewRequest("GET", "https://foo.bar", nil) 106 Expect(err).ToNot(HaveOccurred()) 107 client.RoundTripOpt(req, RoundTripOpt{}) 108 Expect(dialCalled).To(BeTrue()) 109 }) 110 111 It("uses the TLS config and QUIC config", func() { 112 tlsConf := &tls.Config{ 113 ServerName: "foo.bar", 114 NextProtos: []string{"proto foo", "proto bar"}, 115 } 116 quicConf := &quic.Config{MaxIdleTimeout: time.Nanosecond} 117 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) 118 Expect(err).ToNot(HaveOccurred()) 119 var dialAddrCalled bool 120 dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 121 Expect(host).To(Equal("localhost:1337")) 122 Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) 123 Expect(tlsConfP.NextProtos).To(Equal([]string{NextProtoH3})) 124 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 125 dialAddrCalled = true 126 return nil, errors.New("test done") 127 } 128 client.RoundTripOpt(req, RoundTripOpt{}) 129 Expect(dialAddrCalled).To(BeTrue()) 130 // make sure the original tls.Config was not modified 131 Expect(tlsConf.NextProtos).To(Equal([]string{"proto foo", "proto bar"})) 132 }) 133 134 It("uses the custom dialer, if provided", func() { 135 testErr := errors.New("test done") 136 tlsConf := &tls.Config{ServerName: "foo.bar"} 137 quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} 138 ctx, cancel := context.WithTimeout(context.Background(), time.Hour) 139 defer cancel() 140 var dialerCalled bool 141 dialer := func(ctxP context.Context, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlyConnection, error) { 142 Expect(ctxP).To(Equal(ctx)) 143 Expect(address).To(Equal("localhost:1337")) 144 Expect(tlsConfP.ServerName).To(Equal("foo.bar")) 145 Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) 146 dialerCalled = true 147 return nil, testErr 148 } 149 client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) 150 Expect(err).ToNot(HaveOccurred()) 151 _, err = client.RoundTripOpt(req.WithContext(ctx), RoundTripOpt{}) 152 Expect(err).To(MatchError(testErr)) 153 Expect(dialerCalled).To(BeTrue()) 154 }) 155 156 It("enables HTTP/3 Datagrams", func() { 157 testErr := errors.New("handshake error") 158 client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) 159 Expect(err).ToNot(HaveOccurred()) 160 dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { 161 Expect(quicConf.EnableDatagrams).To(BeTrue()) 162 return nil, testErr 163 } 164 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 165 Expect(err).To(MatchError(testErr)) 166 }) 167 168 It("errors when dialing fails", func() { 169 testErr := errors.New("handshake error") 170 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 171 Expect(err).ToNot(HaveOccurred()) 172 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 173 return nil, testErr 174 } 175 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 176 Expect(err).To(MatchError(testErr)) 177 }) 178 179 It("closes correctly if connection was not created", func() { 180 client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) 181 Expect(err).ToNot(HaveOccurred()) 182 Expect(client.Close()).To(Succeed()) 183 }) 184 185 Context("validating the address", func() { 186 It("refuses to do requests for the wrong host", func() { 187 req, err := http.NewRequest("https", "https://quic.clemente.io:1336/foobar.html", nil) 188 Expect(err).ToNot(HaveOccurred()) 189 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 190 Expect(err).To(MatchError("http3 client BUG: RoundTripOpt called for the wrong client (expected quic.clemente.io:1337, got quic.clemente.io:1336)")) 191 }) 192 193 It("allows requests using a different scheme", func() { 194 testErr := errors.New("handshake error") 195 req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) 196 Expect(err).ToNot(HaveOccurred()) 197 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 198 return nil, testErr 199 } 200 _, err = cl.RoundTripOpt(req, RoundTripOpt{}) 201 Expect(err).To(MatchError(testErr)) 202 }) 203 }) 204 205 Context("hijacking bidirectional streams", func() { 206 var ( 207 request *http.Request 208 conn *mockquic.MockEarlyConnection 209 settingsFrameWritten chan struct{} 210 ) 211 testDone := make(chan struct{}) 212 213 BeforeEach(func() { 214 testDone = make(chan struct{}) 215 settingsFrameWritten = make(chan struct{}) 216 controlStr := mockquic.NewMockStream(mockCtrl) 217 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 218 defer GinkgoRecover() 219 close(settingsFrameWritten) 220 return len(b), nil 221 }) 222 conn = mockquic.NewMockEarlyConnection(mockCtrl) 223 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 224 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 225 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 226 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 227 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 228 return conn, nil 229 } 230 var err error 231 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 232 Expect(err).ToNot(HaveOccurred()) 233 }) 234 235 AfterEach(func() { 236 testDone <- struct{}{} 237 Eventually(settingsFrameWritten).Should(BeClosed()) 238 }) 239 240 It("hijacks a bidirectional stream of unknown frame type", func() { 241 frameTypeChan := make(chan FrameType, 1) 242 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 243 Expect(e).ToNot(HaveOccurred()) 244 frameTypeChan <- ft 245 return true, nil 246 } 247 248 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 249 unknownStr := mockquic.NewMockStream(mockCtrl) 250 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 251 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 252 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 253 <-testDone 254 return nil, errors.New("test done") 255 }) 256 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 257 Expect(err).To(MatchError("done")) 258 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 259 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 260 }) 261 262 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 263 frameTypeChan := make(chan FrameType, 1) 264 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 265 Expect(e).ToNot(HaveOccurred()) 266 frameTypeChan <- ft 267 return false, nil 268 } 269 270 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 271 unknownStr := mockquic.NewMockStream(mockCtrl) 272 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 273 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 274 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 275 <-testDone 276 return nil, errors.New("test done") 277 }) 278 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 279 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 280 Expect(err).To(MatchError("done")) 281 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 282 }) 283 284 It("closes the connection when hijacker returned error", func() { 285 frameTypeChan := make(chan FrameType, 1) 286 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 287 Expect(e).ToNot(HaveOccurred()) 288 frameTypeChan <- ft 289 return false, errors.New("error in hijacker") 290 } 291 292 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 293 unknownStr := mockquic.NewMockStream(mockCtrl) 294 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 295 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 296 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 297 <-testDone 298 return nil, errors.New("test done") 299 }) 300 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 301 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 302 Expect(err).To(MatchError("done")) 303 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 304 }) 305 306 It("handles errors that occur when reading the frame type", func() { 307 testErr := errors.New("test error") 308 unknownStr := mockquic.NewMockStream(mockCtrl) 309 done := make(chan struct{}) 310 cl.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) { 311 defer close(done) 312 Expect(e).To(MatchError(testErr)) 313 Expect(ft).To(BeZero()) 314 Expect(str).To(Equal(unknownStr)) 315 return false, nil 316 } 317 318 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 319 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 320 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 321 <-testDone 322 return nil, errors.New("test done") 323 }) 324 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 325 _, err := cl.RoundTripOpt(request, RoundTripOpt{}) 326 Expect(err).To(MatchError("done")) 327 Eventually(done).Should(BeClosed()) 328 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 329 }) 330 }) 331 332 Context("hijacking unidirectional streams", func() { 333 var ( 334 req *http.Request 335 conn *mockquic.MockEarlyConnection 336 settingsFrameWritten chan struct{} 337 ) 338 testDone := make(chan struct{}) 339 340 BeforeEach(func() { 341 testDone = make(chan struct{}) 342 settingsFrameWritten = make(chan struct{}) 343 controlStr := mockquic.NewMockStream(mockCtrl) 344 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 345 defer GinkgoRecover() 346 close(settingsFrameWritten) 347 return len(b), nil 348 }) 349 conn = mockquic.NewMockEarlyConnection(mockCtrl) 350 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 351 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 352 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 353 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 354 return conn, nil 355 } 356 var err error 357 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 358 Expect(err).ToNot(HaveOccurred()) 359 }) 360 361 AfterEach(func() { 362 testDone <- struct{}{} 363 Eventually(settingsFrameWritten).Should(BeClosed()) 364 }) 365 366 It("hijacks an unidirectional stream of unknown stream type", func() { 367 streamTypeChan := make(chan StreamType, 1) 368 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 369 Expect(err).ToNot(HaveOccurred()) 370 streamTypeChan <- st 371 return true 372 } 373 374 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 375 unknownStr := mockquic.NewMockStream(mockCtrl) 376 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 377 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 378 return unknownStr, nil 379 }) 380 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 381 <-testDone 382 return nil, errors.New("test done") 383 }) 384 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 385 Expect(err).To(MatchError("done")) 386 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 387 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 388 }) 389 390 It("handles errors that occur when reading the stream type", func() { 391 testErr := errors.New("test error") 392 done := make(chan struct{}) 393 unknownStr := mockquic.NewMockStream(mockCtrl) 394 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 395 defer close(done) 396 Expect(st).To(BeZero()) 397 Expect(str).To(Equal(unknownStr)) 398 Expect(err).To(MatchError(testErr)) 399 return true 400 } 401 402 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 403 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 404 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 405 <-testDone 406 return nil, errors.New("test done") 407 }) 408 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 409 Expect(err).To(MatchError("done")) 410 Eventually(done).Should(BeClosed()) 411 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 412 }) 413 414 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 415 streamTypeChan := make(chan StreamType, 1) 416 cl.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 417 Expect(err).ToNot(HaveOccurred()) 418 streamTypeChan <- st 419 return false 420 } 421 422 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 423 unknownStr := mockquic.NewMockStream(mockCtrl) 424 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 425 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 426 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 427 return unknownStr, nil 428 }) 429 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 430 <-testDone 431 return nil, errors.New("test done") 432 }) 433 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 434 Expect(err).To(MatchError("done")) 435 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 436 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 437 }) 438 }) 439 440 Context("control stream handling", func() { 441 var ( 442 req *http.Request 443 conn *mockquic.MockEarlyConnection 444 settingsFrameWritten chan struct{} 445 ) 446 testDone := make(chan struct{}, 1) 447 448 BeforeEach(func() { 449 settingsFrameWritten = make(chan struct{}) 450 controlStr := mockquic.NewMockStream(mockCtrl) 451 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 452 defer GinkgoRecover() 453 close(settingsFrameWritten) 454 return len(b), nil 455 }) 456 conn = mockquic.NewMockEarlyConnection(mockCtrl) 457 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 458 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 459 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 460 return conn, nil 461 } 462 var err error 463 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 464 Expect(err).ToNot(HaveOccurred()) 465 }) 466 467 AfterEach(func() { 468 testDone <- struct{}{} 469 Eventually(settingsFrameWritten).Should(BeClosed()) 470 }) 471 472 It("parses the SETTINGS frame", func() { 473 b := quicvarint.Append(nil, streamTypeControlStream) 474 b = (&settingsFrame{ 475 Datagram: true, 476 ExtendedConnect: true, 477 Other: map[uint64]uint64{1337: 42}, 478 }).Append(b) 479 r := bytes.NewReader(b) 480 controlStr := mockquic.NewMockStream(mockCtrl) 481 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 482 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 483 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 484 return controlStr, nil 485 }) 486 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 487 <-testDone 488 return nil, errors.New("test done") 489 }) 490 conn.EXPECT().Context().Return(context.Background()) 491 _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { 492 defer GinkgoRecover() 493 Expect(settings.EnableDatagram).To(BeTrue()) 494 Expect(settings.EnableExtendedConnect).To(BeTrue()) 495 Expect(settings.Other).To(HaveLen(1)) 496 Expect(settings.Other).To(HaveKeyWithValue(uint64(1337), uint64(42))) 497 return nil 498 }}) 499 Expect(err).To(MatchError("done")) 500 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 501 }) 502 503 It("allows the client to reject the SETTINGS using the CheckSettings RoundTripOpt", func() { 504 b := quicvarint.Append(nil, streamTypeControlStream) 505 b = (&settingsFrame{}).Append(b) 506 r := bytes.NewReader(b) 507 controlStr := mockquic.NewMockStream(mockCtrl) 508 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 509 // Don't EXPECT any call to OpenStreamSync. 510 // When the SETTINGS are rejected, we don't even open the request stream. 511 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 512 return controlStr, nil 513 }) 514 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 515 <-testDone 516 return nil, errors.New("test done") 517 }) 518 conn.EXPECT().Context().Return(context.Background()) 519 _, err := cl.RoundTripOpt(req, RoundTripOpt{CheckSettings: func(settings Settings) error { 520 return errors.New("wrong settings") 521 }}) 522 Expect(err).To(MatchError("wrong settings")) 523 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 524 }) 525 526 It("rejects duplicate control streams", func() { 527 b := quicvarint.Append(nil, streamTypeControlStream) 528 b = (&settingsFrame{}).Append(b) 529 r1 := bytes.NewReader(b) 530 controlStr1 := mockquic.NewMockStream(mockCtrl) 531 controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() 532 r2 := bytes.NewReader(b) 533 controlStr2 := mockquic.NewMockStream(mockCtrl) 534 controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() 535 done := make(chan struct{}) 536 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 537 conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { 538 close(done) 539 return nil 540 }) 541 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 542 return controlStr1, nil 543 }) 544 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 545 return controlStr2, nil 546 }) 547 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 548 <-done 549 return nil, errors.New("test done") 550 }) 551 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 552 Expect(err).To(HaveOccurred()) 553 Eventually(done).Should(BeClosed()) 554 }) 555 556 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 557 streamType := t 558 name := "encoder" 559 if streamType == streamTypeQPACKDecoderStream { 560 name = "decoder" 561 } 562 563 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 564 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 565 str := mockquic.NewMockStream(mockCtrl) 566 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 567 568 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 569 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 570 return str, nil 571 }) 572 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 573 <-testDone 574 return nil, errors.New("test done") 575 }) 576 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 577 Expect(err).To(MatchError("done")) 578 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 579 }) 580 } 581 582 It("resets streams other than the control stream and the QPACK streams", func() { 583 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x1337)) 584 str := mockquic.NewMockStream(mockCtrl) 585 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 586 done := make(chan struct{}) 587 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) 588 589 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 590 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 591 return str, nil 592 }) 593 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 594 <-testDone 595 return nil, errors.New("test done") 596 }) 597 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 598 Expect(err).To(MatchError("done")) 599 Eventually(done).Should(BeClosed()) 600 }) 601 602 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 603 b := quicvarint.Append(nil, streamTypeControlStream) 604 b = (&dataFrame{}).Append(b) 605 r := bytes.NewReader(b) 606 controlStr := mockquic.NewMockStream(mockCtrl) 607 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 608 609 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 610 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 611 return controlStr, nil 612 }) 613 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 614 <-testDone 615 return nil, errors.New("test done") 616 }) 617 done := make(chan struct{}) 618 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 619 close(done) 620 return nil 621 }) 622 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 623 Expect(err).To(MatchError("done")) 624 Eventually(done).Should(BeClosed()) 625 }) 626 627 It("errors when the first frame on the control stream is not a SETTINGS frame, when checking SETTINGS", func() { 628 b := quicvarint.Append(nil, streamTypeControlStream) 629 b = (&dataFrame{}).Append(b) 630 r := bytes.NewReader(b) 631 controlStr := mockquic.NewMockStream(mockCtrl) 632 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 633 634 // Don't EXPECT any calls to OpenStreamSync. 635 // We fail before we even get the chance to open the request stream. 636 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 637 return controlStr, nil 638 }) 639 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 640 <-testDone 641 return nil, errors.New("test done") 642 }) 643 doneCtx, doneCancel := context.WithCancelCause(context.Background()) 644 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 645 doneCancel(errors.New("done")) 646 return nil 647 }) 648 conn.EXPECT().Context().Return(doneCtx).Times(2) 649 var checked bool 650 _, err := cl.RoundTripOpt(req, RoundTripOpt{ 651 CheckSettings: func(Settings) error { checked = true; return nil }, 652 }) 653 Expect(checked).To(BeFalse()) 654 Expect(err).To(MatchError("done")) 655 Eventually(doneCtx.Done()).Should(BeClosed()) 656 }) 657 658 It("errors when parsing the frame on the control stream fails", func() { 659 b := quicvarint.Append(nil, streamTypeControlStream) 660 b = (&settingsFrame{}).Append(b) 661 r := bytes.NewReader(b[:len(b)-1]) 662 controlStr := mockquic.NewMockStream(mockCtrl) 663 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 664 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 665 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 666 return controlStr, nil 667 }) 668 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 669 <-testDone 670 return nil, errors.New("test done") 671 }) 672 done := make(chan struct{}) 673 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) error { 674 close(done) 675 return nil 676 }) 677 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 678 Expect(err).To(MatchError("done")) 679 Eventually(done).Should(BeClosed()) 680 }) 681 682 It("errors when parsing the server opens a push stream", func() { 683 buf := bytes.NewBuffer(quicvarint.Append(nil, streamTypePushStream)) 684 controlStr := mockquic.NewMockStream(mockCtrl) 685 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 686 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 687 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 688 return controlStr, nil 689 }) 690 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 691 <-testDone 692 return nil, errors.New("test done") 693 }) 694 done := make(chan struct{}) 695 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 696 close(done) 697 return nil 698 }) 699 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 700 Expect(err).To(MatchError("done")) 701 Eventually(done).Should(BeClosed()) 702 }) 703 704 It("errors when the server advertises datagram support (and we enabled support for it)", func() { 705 cl.opts.EnableDatagram = true 706 b := quicvarint.Append(nil, streamTypeControlStream) 707 b = (&settingsFrame{Datagram: true}).Append(b) 708 r := bytes.NewReader(b) 709 controlStr := mockquic.NewMockStream(mockCtrl) 710 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 711 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 712 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 713 return controlStr, nil 714 }) 715 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 716 <-testDone 717 return nil, errors.New("test done") 718 }) 719 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 720 done := make(chan struct{}) 721 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { 722 close(done) 723 return nil 724 }) 725 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 726 Expect(err).To(MatchError("done")) 727 Eventually(done).Should(BeClosed()) 728 }) 729 }) 730 731 Context("Doing requests", func() { 732 var ( 733 req *http.Request 734 str *mockquic.MockStream 735 conn *mockquic.MockEarlyConnection 736 settingsFrameWritten chan struct{} 737 ) 738 testDone := make(chan struct{}) 739 740 decodeHeader := func(str io.Reader) map[string]string { 741 fields := make(map[string]string) 742 decoder := qpack.NewDecoder(nil) 743 744 frame, err := parseNextFrame(str, nil) 745 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 746 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 747 headersFrame := frame.(*headersFrame) 748 data := make([]byte, headersFrame.Length) 749 _, err = io.ReadFull(str, data) 750 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 751 hfs, err := decoder.DecodeFull(data) 752 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 753 for _, p := range hfs { 754 fields[p.Name] = p.Value 755 } 756 return fields 757 } 758 759 getResponse := func(status int) []byte { 760 buf := &bytes.Buffer{} 761 rstr := mockquic.NewMockStream(mockCtrl) 762 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 763 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 764 rw.WriteHeader(status) 765 rw.Flush() 766 return buf.Bytes() 767 } 768 769 BeforeEach(func() { 770 settingsFrameWritten = make(chan struct{}) 771 controlStr := mockquic.NewMockStream(mockCtrl) 772 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 773 defer GinkgoRecover() 774 r := bytes.NewReader(b) 775 streamType, err := quicvarint.Read(r) 776 Expect(err).ToNot(HaveOccurred()) 777 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 778 close(settingsFrameWritten) 779 return len(b), nil 780 }) // SETTINGS frame 781 str = mockquic.NewMockStream(mockCtrl) 782 conn = mockquic.NewMockEarlyConnection(mockCtrl) 783 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 784 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 785 <-testDone 786 return nil, errors.New("test done") 787 }) 788 dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { 789 return conn, nil 790 } 791 var err error 792 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 793 Expect(err).ToNot(HaveOccurred()) 794 }) 795 796 AfterEach(func() { 797 testDone <- struct{}{} 798 Eventually(settingsFrameWritten).Should(BeClosed()) 799 }) 800 801 It("errors if it can't open a stream", func() { 802 testErr := errors.New("stream open error") 803 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 804 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 805 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 806 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 807 Expect(err).To(MatchError(testErr)) 808 }) 809 810 It("performs a 0-RTT request", func() { 811 testErr := errors.New("stream open error") 812 req.Method = MethodGet0RTT 813 // don't EXPECT any calls to HandshakeComplete() 814 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 815 buf := &bytes.Buffer{} 816 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 817 str.EXPECT().Close() 818 str.EXPECT().CancelWrite(gomock.Any()) 819 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 820 return 0, testErr 821 }) 822 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 823 Expect(err).To(MatchError(testErr)) 824 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", "GET")) 825 }) 826 827 It("returns a response", func() { 828 rspBuf := bytes.NewBuffer(getResponse(418)) 829 gomock.InOrder( 830 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 831 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 832 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 833 ) 834 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 835 str.EXPECT().Close() 836 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 837 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 838 Expect(err).ToNot(HaveOccurred()) 839 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 840 Expect(rsp.ProtoMajor).To(Equal(3)) 841 Expect(rsp.StatusCode).To(Equal(418)) 842 Expect(rsp.Request).ToNot(BeNil()) 843 }) 844 845 It("doesn't close the request stream, with DontCloseRequestStream set", func() { 846 rspBuf := bytes.NewBuffer(getResponse(418)) 847 gomock.InOrder( 848 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 849 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 850 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 851 ) 852 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 853 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 854 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) 855 Expect(err).ToNot(HaveOccurred()) 856 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 857 Expect(rsp.ProtoMajor).To(Equal(3)) 858 Expect(rsp.StatusCode).To(Equal(418)) 859 }) 860 861 Context("requests containing a Body", func() { 862 var strBuf *bytes.Buffer 863 864 BeforeEach(func() { 865 strBuf = &bytes.Buffer{} 866 gomock.InOrder( 867 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 868 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 869 ) 870 body := &mockBody{} 871 body.SetData([]byte("request body")) 872 var err error 873 req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) 874 Expect(err).ToNot(HaveOccurred()) 875 str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() 876 }) 877 878 It("sends a request", func() { 879 done := make(chan struct{}) 880 gomock.InOrder( 881 str.EXPECT().Close().Do(func() error { close(done); return nil }), 882 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when reading the response errors 883 ) 884 // the response body is sent asynchronously, while already reading the response 885 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 886 <-done 887 return 0, errors.New("test done") 888 }) 889 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 890 Expect(err).To(MatchError("test done")) 891 hfs := decodeHeader(strBuf) 892 Expect(hfs).To(HaveKeyWithValue(":method", "POST")) 893 Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) 894 }) 895 896 It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { 897 req.ContentLength = 7 898 var once sync.Once 899 done := make(chan struct{}) 900 gomock.InOrder( 901 str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { 902 once.Do(func() { 903 Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) 904 close(done) 905 }) 906 }).AnyTimes(), 907 str.EXPECT().Close().MaxTimes(1), 908 str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), 909 ) 910 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 911 <-done 912 return 0, errors.New("done") 913 }) 914 cl.RoundTripOpt(req, RoundTripOpt{}) 915 Expect(strBuf.String()).To(ContainSubstring("request")) 916 Expect(strBuf.String()).ToNot(ContainSubstring("request body")) 917 }) 918 919 It("returns the error that occurred when reading the body", func() { 920 req.Body.(*mockBody).readErr = errors.New("testErr") 921 done := make(chan struct{}) 922 gomock.InOrder( 923 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { 924 close(done) 925 }), 926 str.EXPECT().CancelWrite(gomock.Any()), 927 ) 928 929 // the response body is sent asynchronously, while already reading the response 930 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 931 <-done 932 return 0, errors.New("test done") 933 }) 934 closed := make(chan struct{}) 935 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 936 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 937 Expect(err).To(MatchError("test done")) 938 Eventually(closed).Should(BeClosed()) 939 }) 940 941 It("closes the connection when the first frame is not a HEADERS frame", func() { 942 b := (&dataFrame{Length: 0x42}).Append(nil) 943 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) 944 closed := make(chan struct{}) 945 r := bytes.NewReader(b) 946 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 947 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 948 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 949 Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) 950 Eventually(closed).Should(BeClosed()) 951 }) 952 953 It("cancels the stream when parsing the headers fails", func() { 954 headerBuf := &bytes.Buffer{} 955 enc := qpack.NewEncoder(headerBuf) 956 Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header 957 Expect(enc.Close()).To(Succeed()) 958 b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) 959 b = append(b, headerBuf.Bytes()...) 960 961 r := bytes.NewReader(b) 962 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) 963 closed := make(chan struct{}) 964 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 965 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 966 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 967 Expect(err).To(HaveOccurred()) 968 Eventually(closed).Should(BeClosed()) 969 }) 970 971 It("cancels the stream when the HEADERS frame is too large", func() { 972 b := (&headersFrame{Length: 1338}).Append(nil) 973 r := bytes.NewReader(b) 974 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) 975 closed := make(chan struct{}) 976 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 977 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 978 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 979 Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) 980 Eventually(closed).Should(BeClosed()) 981 }) 982 }) 983 984 Context("request cancellations", func() { 985 for _, dontClose := range []bool{false, true} { 986 dontClose := dontClose 987 988 Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() { 989 roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose} 990 991 It("cancels a request while waiting for the handshake to complete", func() { 992 ctx, cancel := context.WithCancel(context.Background()) 993 req := req.WithContext(ctx) 994 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 995 996 errChan := make(chan error) 997 go func() { 998 _, err := cl.RoundTripOpt(req, roundTripOpt) 999 errChan <- err 1000 }() 1001 Consistently(errChan).ShouldNot(Receive()) 1002 cancel() 1003 Eventually(errChan).Should(Receive(MatchError("context canceled"))) 1004 }) 1005 1006 It("cancels a request while the request is still in flight", func() { 1007 ctx, cancel := context.WithCancel(context.Background()) 1008 req := req.WithContext(ctx) 1009 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1010 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 1011 buf := &bytes.Buffer{} 1012 str.EXPECT().Close().MaxTimes(1) 1013 1014 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 1015 1016 done := make(chan struct{}) 1017 canceled := make(chan struct{}) 1018 gomock.InOrder( 1019 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), 1020 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), 1021 ) 1022 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) 1023 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 1024 cancel() 1025 <-canceled 1026 return 0, errors.New("test done") 1027 }) 1028 _, err := cl.RoundTripOpt(req, roundTripOpt) 1029 Expect(err).To(MatchError(context.Canceled)) 1030 Eventually(done).Should(BeClosed()) 1031 }) 1032 }) 1033 } 1034 1035 It("cancels a request after the response arrived", func() { 1036 rspBuf := bytes.NewBuffer(getResponse(404)) 1037 1038 ctx, cancel := context.WithCancel(context.Background()) 1039 req := req.WithContext(ctx) 1040 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1041 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 1042 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1043 buf := &bytes.Buffer{} 1044 str.EXPECT().Close().MaxTimes(1) 1045 1046 done := make(chan struct{}) 1047 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 1048 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 1049 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 1050 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) 1051 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1052 Expect(err).ToNot(HaveOccurred()) 1053 cancel() 1054 Eventually(done).Should(BeClosed()) 1055 }) 1056 }) 1057 1058 Context("gzip compression", func() { 1059 BeforeEach(func() { 1060 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 1061 }) 1062 1063 It("adds the gzip header to requests", func() { 1064 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1065 buf := &bytes.Buffer{} 1066 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 1067 gomock.InOrder( 1068 str.EXPECT().Close(), 1069 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 1070 ) 1071 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 1072 _, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1073 Expect(err).To(MatchError("test done")) 1074 hfs := decodeHeader(buf) 1075 Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) 1076 }) 1077 1078 It("doesn't add gzip if the header disable it", func() { 1079 client, err := newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) 1080 Expect(err).ToNot(HaveOccurred()) 1081 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1082 buf := &bytes.Buffer{} 1083 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 1084 gomock.InOrder( 1085 str.EXPECT().Close(), 1086 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), // when the Read errors 1087 ) 1088 str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) 1089 _, err = client.RoundTripOpt(req, RoundTripOpt{}) 1090 Expect(err).To(MatchError("test done")) 1091 hfs := decodeHeader(buf) 1092 Expect(hfs).ToNot(HaveKey("accept-encoding")) 1093 }) 1094 1095 It("decompresses the response", func() { 1096 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1097 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1098 buf := &bytes.Buffer{} 1099 rstr := mockquic.NewMockStream(mockCtrl) 1100 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 1101 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 1102 rw.Header().Set("Content-Encoding", "gzip") 1103 gz := gzip.NewWriter(rw) 1104 gz.Write([]byte("gzipped response")) 1105 gz.Close() 1106 rw.Flush() 1107 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1108 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1109 str.EXPECT().Close() 1110 1111 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1112 Expect(err).ToNot(HaveOccurred()) 1113 data, err := io.ReadAll(rsp.Body) 1114 Expect(err).ToNot(HaveOccurred()) 1115 Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) 1116 Expect(string(data)).To(Equal("gzipped response")) 1117 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1118 Expect(rsp.Uncompressed).To(BeTrue()) 1119 }) 1120 1121 It("only decompresses the response if the response contains the right content-encoding header", func() { 1122 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 1123 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 1124 buf := &bytes.Buffer{} 1125 rstr := mockquic.NewMockStream(mockCtrl) 1126 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 1127 rw := newResponseWriter(rstr, nil, utils.DefaultLogger) 1128 rw.Write([]byte("not gzipped")) 1129 rw.Flush() 1130 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 1131 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 1132 str.EXPECT().Close() 1133 1134 rsp, err := cl.RoundTripOpt(req, RoundTripOpt{}) 1135 Expect(err).ToNot(HaveOccurred()) 1136 data, err := io.ReadAll(rsp.Body) 1137 Expect(err).ToNot(HaveOccurred()) 1138 Expect(string(data)).To(Equal("not gzipped")) 1139 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 1140 }) 1141 }) 1142 }) 1143 })