github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "errors" 8 "io" 9 "net/http" 10 "net/http/httptrace" 11 "net/textproto" 12 "sync" 13 "time" 14 15 "github.com/apernet/quic-go" 16 mockquic "github.com/apernet/quic-go/internal/mocks/quic" 17 "github.com/apernet/quic-go/quicvarint" 18 19 "github.com/quic-go/qpack" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 "go.uber.org/mock/gomock" 24 ) 25 26 func encodeResponse(status int) []byte { 27 buf := &bytes.Buffer{} 28 rstr := mockquic.NewMockStream(mockCtrl) 29 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 30 rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) 31 if status == http.StatusEarlyHints { 32 rw.header.Add("Link", "</style.css>; rel=preload; as=style") 33 rw.header.Add("Link", "</script.js>; rel=preload; as=script") 34 } 35 rw.WriteHeader(status) 36 rw.Flush() 37 return buf.Bytes() 38 } 39 40 var _ = Describe("Client", func() { 41 var handshakeChan <-chan struct{} // a closed chan 42 43 BeforeEach(func() { 44 ch := make(chan struct{}) 45 close(ch) 46 handshakeChan = ch 47 }) 48 49 Context("hijacking bidirectional streams", func() { 50 var ( 51 request *http.Request 52 conn *mockquic.MockEarlyConnection 53 settingsFrameWritten chan struct{} 54 ) 55 testDone := make(chan struct{}) 56 57 BeforeEach(func() { 58 testDone = make(chan struct{}) 59 settingsFrameWritten = make(chan struct{}) 60 controlStr := mockquic.NewMockStream(mockCtrl) 61 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 62 defer GinkgoRecover() 63 close(settingsFrameWritten) 64 return len(b), nil 65 }) 66 conn = mockquic.NewMockEarlyConnection(mockCtrl) 67 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 68 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 69 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 70 
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 71 var err error 72 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 73 Expect(err).ToNot(HaveOccurred()) 74 }) 75 76 AfterEach(func() { 77 testDone <- struct{}{} 78 Eventually(settingsFrameWritten).Should(BeClosed()) 79 }) 80 81 It("hijacks a bidirectional stream of unknown frame type", func() { 82 id := quic.ConnectionTracingID(1234) 83 frameTypeChan := make(chan FrameType, 1) 84 rt := &SingleDestinationRoundTripper{ 85 Connection: conn, 86 StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 87 Expect(e).ToNot(HaveOccurred()) 88 Expect(connTracingID).To(Equal(id)) 89 frameTypeChan <- ft 90 return true, nil 91 }, 92 } 93 94 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 95 unknownStr := mockquic.NewMockStream(mockCtrl) 96 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 97 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 98 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 99 <-testDone 100 return nil, errors.New("test done") 101 }) 102 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 103 conn.EXPECT().Context().Return(ctx).AnyTimes() 104 _, err := rt.RoundTrip(request) 105 Expect(err).To(MatchError("done")) 106 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 107 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 108 }) 109 110 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 111 frameTypeChan := make(chan FrameType, 1) 112 rt := &SingleDestinationRoundTripper{ 113 Connection: conn, 114 StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 115 Expect(e).ToNot(HaveOccurred()) 116 
frameTypeChan <- ft 117 return false, nil 118 }, 119 } 120 121 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 122 unknownStr := mockquic.NewMockStream(mockCtrl) 123 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 124 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 125 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 126 <-testDone 127 return nil, errors.New("test done") 128 }) 129 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 130 conn.EXPECT().Context().Return(ctx).AnyTimes() 131 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 132 _, err := rt.RoundTrip(request) 133 Expect(err).To(MatchError("done")) 134 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 135 }) 136 137 It("closes the connection when hijacker returned error", func() { 138 frameTypeChan := make(chan FrameType, 1) 139 rt := &SingleDestinationRoundTripper{ 140 Connection: conn, 141 StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 142 Expect(e).ToNot(HaveOccurred()) 143 frameTypeChan <- ft 144 return false, errors.New("error in hijacker") 145 }, 146 } 147 148 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 149 unknownStr := mockquic.NewMockStream(mockCtrl) 150 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 151 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 152 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 153 <-testDone 154 return nil, errors.New("test done") 155 }) 156 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 157 conn.EXPECT().Context().Return(ctx).AnyTimes() 158 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), 
gomock.Any()).Return(nil).AnyTimes() 159 _, err := rt.RoundTrip(request) 160 Expect(err).To(MatchError("done")) 161 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 162 }) 163 164 It("handles errors that occur when reading the frame type", func() { 165 testErr := errors.New("test error") 166 unknownStr := mockquic.NewMockStream(mockCtrl) 167 done := make(chan struct{}) 168 rt := &SingleDestinationRoundTripper{ 169 Connection: conn, 170 StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) { 171 defer close(done) 172 Expect(e).To(MatchError(testErr)) 173 Expect(ft).To(BeZero()) 174 Expect(str).To(Equal(unknownStr)) 175 return false, nil 176 }, 177 } 178 179 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 180 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 181 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 182 <-testDone 183 return nil, errors.New("test done") 184 }) 185 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 186 conn.EXPECT().Context().Return(ctx).AnyTimes() 187 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 188 _, err := rt.RoundTrip(request) 189 Expect(err).To(MatchError("done")) 190 Eventually(done).Should(BeClosed()) 191 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 192 }) 193 }) 194 195 Context("hijacking unidirectional streams", func() { 196 var ( 197 req *http.Request 198 conn *mockquic.MockEarlyConnection 199 settingsFrameWritten chan struct{} 200 ) 201 testDone := make(chan struct{}) 202 203 BeforeEach(func() { 204 testDone = make(chan struct{}) 205 settingsFrameWritten = make(chan struct{}) 206 controlStr := mockquic.NewMockStream(mockCtrl) 207 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, 
error) { 208 defer GinkgoRecover() 209 close(settingsFrameWritten) 210 return len(b), nil 211 }) 212 conn = mockquic.NewMockEarlyConnection(mockCtrl) 213 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 214 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 215 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 216 var err error 217 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 218 Expect(err).ToNot(HaveOccurred()) 219 }) 220 221 AfterEach(func() { 222 testDone <- struct{}{} 223 Eventually(settingsFrameWritten).Should(BeClosed()) 224 }) 225 226 It("hijacks an unidirectional stream of unknown stream type", func() { 227 id := quic.ConnectionTracingID(100) 228 streamTypeChan := make(chan StreamType, 1) 229 rt := &SingleDestinationRoundTripper{ 230 Connection: conn, 231 UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 232 Expect(connTracingID).To(Equal(id)) 233 Expect(err).ToNot(HaveOccurred()) 234 streamTypeChan <- st 235 return true 236 }, 237 } 238 239 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 240 unknownStr := mockquic.NewMockStream(mockCtrl) 241 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 242 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 243 return unknownStr, nil 244 }) 245 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 246 <-testDone 247 return nil, errors.New("test done") 248 }) 249 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 250 conn.EXPECT().Context().Return(ctx).AnyTimes() 251 _, err := rt.RoundTrip(req) 252 Expect(err).To(MatchError("done")) 253 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 254 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 255 }) 256 257 
It("handles errors that occur when reading the stream type", func() { 258 testErr := errors.New("test error") 259 done := make(chan struct{}) 260 unknownStr := mockquic.NewMockStream(mockCtrl) 261 rt := &SingleDestinationRoundTripper{ 262 Connection: conn, 263 UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool { 264 defer close(done) 265 Expect(st).To(BeZero()) 266 Expect(str).To(Equal(unknownStr)) 267 Expect(err).To(MatchError(testErr)) 268 return true 269 }, 270 } 271 272 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 273 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 274 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 275 <-testDone 276 return nil, errors.New("test done") 277 }) 278 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 279 conn.EXPECT().Context().Return(ctx).AnyTimes() 280 _, err := rt.RoundTrip(req) 281 Expect(err).To(MatchError("done")) 282 Eventually(done).Should(BeClosed()) 283 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 284 }) 285 286 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 287 streamTypeChan := make(chan StreamType, 1) 288 rt := &SingleDestinationRoundTripper{ 289 Connection: conn, 290 UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 291 Expect(err).ToNot(HaveOccurred()) 292 streamTypeChan <- st 293 return false 294 }, 295 } 296 297 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 298 unknownStr := mockquic.NewMockStream(mockCtrl) 299 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 300 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 301 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) 
(quic.ReceiveStream, error) { 302 return unknownStr, nil 303 }) 304 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 305 <-testDone 306 return nil, errors.New("test done") 307 }) 308 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 309 conn.EXPECT().Context().Return(ctx).AnyTimes() 310 _, err := rt.RoundTrip(req) 311 Expect(err).To(MatchError("done")) 312 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 313 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 314 }) 315 }) 316 317 Context("SETTINGS handling", func() { 318 sendSettings := func() { 319 settingsFrameWritten := make(chan struct{}) 320 controlStr := mockquic.NewMockStream(mockCtrl) 321 var buf bytes.Buffer 322 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 323 defer GinkgoRecover() 324 buf.Write(b) 325 close(settingsFrameWritten) 326 return len(b), nil 327 }) 328 conn := mockquic.NewMockEarlyConnection(mockCtrl) 329 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 330 conn.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 331 <-settingsFrameWritten 332 return nil, errors.New("test done") 333 }) 334 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 335 <-settingsFrameWritten 336 return nil, errors.New("test done") 337 }).AnyTimes() 338 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 339 rt := &SingleDestinationRoundTripper{ 340 Connection: conn, 341 EnableDatagrams: true, 342 } 343 req, err := http.NewRequest(http.MethodGet, "https://quic-go.net", nil) 344 Expect(err).ToNot(HaveOccurred()) 345 _, err = rt.RoundTrip(req) 346 Expect(err).To(MatchError("test done")) 347 t, err := quicvarint.Read(&buf) 348 Expect(err).ToNot(HaveOccurred()) 349 
Expect(t).To(BeEquivalentTo(streamTypeControlStream)) 350 settings, err := parseSettingsFrame(&buf, uint64(buf.Len())) 351 Expect(err).ToNot(HaveOccurred()) 352 Expect(settings.Datagram).To(BeTrue()) 353 } 354 355 It("receives SETTINGS", func() { 356 sendSettings() 357 done := make(chan struct{}) 358 conn := mockquic.NewMockEarlyConnection(mockCtrl) 359 conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { 360 <-done 361 return nil, errors.New("test done") 362 }).MaxTimes(1) 363 b := quicvarint.Append(nil, streamTypeControlStream) 364 b = (&settingsFrame{ExtendedConnect: true}).Append(b) 365 r := bytes.NewReader(b) 366 controlStr := mockquic.NewMockStream(mockCtrl) 367 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 368 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) 369 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 370 <-done 371 return nil, errors.New("test done") 372 }) 373 374 rt := &SingleDestinationRoundTripper{Connection: conn} 375 hconn := rt.Start() 376 Eventually(hconn.ReceivedSettings()).Should(BeClosed()) 377 settings := hconn.Settings() 378 Expect(settings.EnableExtendedConnect).To(BeTrue()) 379 // test shutdown 380 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 381 close(done) 382 }) 383 384 It("checks the server's SETTINGS before sending an Extended CONNECT request", func() { 385 sendSettings() 386 done := make(chan struct{}) 387 conn := mockquic.NewMockEarlyConnection(mockCtrl) 388 conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { 389 <-done 390 return nil, errors.New("test done") 391 }).MaxTimes(1) 392 b := quicvarint.Append(nil, streamTypeControlStream) 393 b = (&settingsFrame{ExtendedConnect: true}).Append(b) 394 r := bytes.NewReader(b) 395 controlStr := mockquic.NewMockStream(mockCtrl) 396 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 397 
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) 398 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 399 <-done 400 return nil, errors.New("test done") 401 }) 402 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 403 conn.EXPECT().Context().Return(context.Background()) 404 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("test error")) 405 406 rt := &SingleDestinationRoundTripper{Connection: conn} 407 _, err := rt.RoundTrip(&http.Request{ 408 Method: http.MethodConnect, 409 Proto: "connect", 410 Host: "localhost", 411 }) 412 Expect(err).To(MatchError("test error")) 413 414 // test shutdown 415 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 416 close(done) 417 }) 418 419 It("rejects Extended CONNECT requests if the server doesn't enable it", func() { 420 sendSettings() 421 done := make(chan struct{}) 422 conn := mockquic.NewMockEarlyConnection(mockCtrl) 423 conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { 424 <-done 425 return nil, errors.New("test done") 426 }).MaxTimes(1) 427 b := quicvarint.Append(nil, streamTypeControlStream) 428 b = (&settingsFrame{}).Append(b) 429 r := bytes.NewReader(b) 430 controlStr := mockquic.NewMockStream(mockCtrl) 431 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 432 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) 433 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 434 <-done 435 return nil, errors.New("test done") 436 }) 437 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 438 conn.EXPECT().Context().Return(context.Background()) 439 440 rt := &SingleDestinationRoundTripper{Connection: conn} 441 _, err := rt.RoundTrip(&http.Request{ 442 Method: http.MethodConnect, 443 Proto: "connect", 444 Host: "localhost", 445 }) 446 Expect(err).To(MatchError("http3: server didn't enable 
Extended CONNECT")) 447 448 // test shutdown 449 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 450 close(done) 451 }) 452 }) 453 454 Context("Doing requests", func() { 455 var ( 456 req *http.Request 457 str *mockquic.MockStream 458 conn *mockquic.MockEarlyConnection 459 cl *SingleDestinationRoundTripper 460 settingsFrameWritten chan struct{} 461 ) 462 testDone := make(chan struct{}) 463 464 decodeHeader := func(str io.Reader) map[string]string { 465 fields := make(map[string]string) 466 decoder := qpack.NewDecoder(nil) 467 468 frame, err := parseNextFrame(str, nil) 469 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 470 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 471 headersFrame := frame.(*headersFrame) 472 data := make([]byte, headersFrame.Length) 473 _, err = io.ReadFull(str, data) 474 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 475 hfs, err := decoder.DecodeFull(data) 476 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 477 for _, p := range hfs { 478 fields[p.Name] = p.Value 479 } 480 return fields 481 } 482 483 BeforeEach(func() { 484 settingsFrameWritten = make(chan struct{}) 485 controlStr := mockquic.NewMockStream(mockCtrl) 486 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 487 defer GinkgoRecover() 488 r := bytes.NewReader(b) 489 streamType, err := quicvarint.Read(r) 490 Expect(err).ToNot(HaveOccurred()) 491 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 492 close(settingsFrameWritten) 493 return len(b), nil 494 }) // SETTINGS frame 495 str = mockquic.NewMockStream(mockCtrl) 496 str.EXPECT().StreamID().AnyTimes() 497 conn = mockquic.NewMockEarlyConnection(mockCtrl) 498 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 499 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 500 <-testDone 501 return nil, errors.New("test done") 502 }) 503 cl = &SingleDestinationRoundTripper{Connection: conn} 504 var err 
error 505 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 506 Expect(err).ToNot(HaveOccurred()) 507 }) 508 509 AfterEach(func() { 510 testDone <- struct{}{} 511 Eventually(settingsFrameWritten).Should(BeClosed()) 512 }) 513 514 It("errors if it can't open a request stream", func() { 515 testErr := errors.New("stream open error") 516 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 517 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 518 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 519 _, err := cl.RoundTrip(req) 520 Expect(err).To(MatchError(testErr)) 521 }) 522 523 DescribeTable( 524 "performs a 0-RTT request", 525 func(method, serialized string) { 526 testErr := errors.New("stream open error") 527 req.Method = method 528 // don't EXPECT any calls to HandshakeComplete() 529 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 530 buf := &bytes.Buffer{} 531 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 532 str.EXPECT().Close() 533 str.EXPECT().CancelWrite(gomock.Any()) 534 str.EXPECT().CancelRead(gomock.Any()) 535 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 536 return 0, testErr 537 }) 538 _, err := cl.RoundTrip(req) 539 Expect(err).To(MatchError(testErr)) 540 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", serialized)) 541 // make sure the request wasn't modified 542 Expect(req.Method).To(Equal(method)) 543 }, 544 Entry("GET", MethodGet0RTT, http.MethodGet), 545 Entry("HEAD", MethodHead0RTT, http.MethodHead), 546 ) 547 548 It("returns a response", func() { 549 rspBuf := bytes.NewBuffer(encodeResponse(418)) 550 gomock.InOrder( 551 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 552 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 553 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 554 ) 555 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) 
(int, error) { return len(p), nil }) 556 str.EXPECT().Close() 557 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 558 rsp, err := cl.RoundTrip(req) 559 Expect(err).ToNot(HaveOccurred()) 560 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 561 Expect(rsp.ProtoMajor).To(Equal(3)) 562 Expect(rsp.StatusCode).To(Equal(418)) 563 Expect(rsp.Request).ToNot(BeNil()) 564 }) 565 566 Context("requests containing a Body", func() { 567 var strBuf *bytes.Buffer 568 569 BeforeEach(func() { 570 strBuf = &bytes.Buffer{} 571 gomock.InOrder( 572 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 573 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 574 ) 575 body := &mockBody{} 576 body.SetData([]byte("request body")) 577 var err error 578 req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) 579 Expect(err).ToNot(HaveOccurred()) 580 str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() 581 }) 582 583 It("sends a request", func() { 584 done := make(chan struct{}) 585 gomock.InOrder( 586 str.EXPECT().Close().Do(func() error { close(done); return nil }), 587 // when reading the response errors 588 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1), 589 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), 590 ) 591 // the response body is sent asynchronously, while already reading the response 592 testErr := errors.New("test done") 593 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 594 <-done 595 return 0, testErr 596 }) 597 _, err := cl.RoundTrip(req) 598 Expect(err).To(MatchError(testErr)) 599 hfs := decodeHeader(strBuf) 600 Expect(hfs).To(HaveKeyWithValue(":method", "POST")) 601 Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) 602 }) 603 604 It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { 605 req.ContentLength = 7 606 var once sync.Once 607 done := make(chan struct{}) 608 str.EXPECT().CancelRead(gomock.Any()) 609 gomock.InOrder( 610 
str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { 611 once.Do(func() { 612 Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) 613 close(done) 614 }) 615 }).AnyTimes(), 616 str.EXPECT().Close().MaxTimes(1), 617 str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), 618 ) 619 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 620 <-done 621 return 0, errors.New("done") 622 }) 623 cl.RoundTrip(req) 624 Expect(strBuf.String()).To(ContainSubstring("request")) 625 Expect(strBuf.String()).ToNot(ContainSubstring("request body")) 626 }) 627 628 It("returns the error that occurred when reading the body", func() { 629 req.Body.(*mockBody).readErr = errors.New("testErr") 630 done := make(chan struct{}) 631 str.EXPECT().CancelRead(gomock.Any()) 632 gomock.InOrder( 633 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { 634 close(done) 635 }), 636 str.EXPECT().CancelWrite(gomock.Any()), 637 ) 638 639 // the response body is sent asynchronously, while already reading the response 640 testErr := errors.New("test done") 641 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 642 <-done 643 return 0, testErr 644 }) 645 closed := make(chan struct{}) 646 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 647 _, err := cl.RoundTrip(req) 648 Expect(err).To(MatchError(testErr)) 649 Eventually(closed).Should(BeClosed()) 650 }) 651 652 It("closes the connection when the first frame is not a HEADERS frame", func() { 653 b := (&dataFrame{Length: 0x42}).Append(nil) 654 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) 655 closed := make(chan struct{}) 656 r := bytes.NewReader(b) 657 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 658 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 659 _, err := cl.RoundTrip(req) 660 Expect(err).To(MatchError("http3: expected first frame to 
be a HEADERS frame")) 661 Eventually(closed).Should(BeClosed()) 662 }) 663 664 It("cancels the stream when parsing the headers fails", func() { 665 headerBuf := &bytes.Buffer{} 666 enc := qpack.NewEncoder(headerBuf) 667 Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header 668 Expect(enc.Close()).To(Succeed()) 669 b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) 670 b = append(b, headerBuf.Bytes()...) 671 672 r := bytes.NewReader(b) 673 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) 674 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) 675 closed := make(chan struct{}) 676 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 677 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 678 _, err := cl.RoundTrip(req) 679 Expect(err).To(HaveOccurred()) 680 Eventually(closed).Should(BeClosed()) 681 }) 682 683 It("cancels the stream when the HEADERS frame is too large", func() { 684 cl.MaxResponseHeaderBytes = 1337 685 b := (&headersFrame{Length: 1338}).Append(nil) 686 r := bytes.NewReader(b) 687 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) 688 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) 689 closed := make(chan struct{}) 690 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 691 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 692 _, err := cl.RoundTrip(req) 693 Expect(err).To(MatchError("http3: HEADERS frame too large: 1338 bytes (max: 1337)")) 694 Eventually(closed).Should(BeClosed()) 695 }) 696 697 It("opens a request stream", func() { 698 cl.Connection.(quic.EarlyConnection).HandshakeComplete() 699 str, err := cl.OpenRequestStream(context.Background()) 700 Expect(err).ToNot(HaveOccurred()) 701 Expect(str.SendRequestHeader(req)).To(Succeed()) 702 str.Write([]byte("foobar")) 703 d := dataFrame{Length: 6} 704 data := d.Append([]byte{}) 705 data = 
append(data, []byte("foobar")...) 706 Expect(bytes.Contains(strBuf.Bytes(), data)).To(BeTrue()) 707 }) 708 }) 709 710 Context("request cancellations", func() { 711 It("cancels a request while waiting for the handshake to complete", func() { 712 ctx, cancel := context.WithCancel(context.Background()) 713 req := req.WithContext(ctx) 714 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 715 716 errChan := make(chan error) 717 go func() { 718 _, err := cl.RoundTrip(req) 719 errChan <- err 720 }() 721 Consistently(errChan).ShouldNot(Receive()) 722 cancel() 723 Eventually(errChan).Should(Receive(MatchError("context canceled"))) 724 }) 725 726 It("cancels a request while the request is still in flight", func() { 727 ctx, cancel := context.WithCancel(context.Background()) 728 req := req.WithContext(ctx) 729 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 730 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 731 buf := &bytes.Buffer{} 732 str.EXPECT().Close().MaxTimes(1) 733 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 734 735 done := make(chan struct{}) 736 canceled := make(chan struct{}) 737 gomock.InOrder( 738 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), 739 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), 740 ) 741 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) 742 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1) 743 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 744 cancel() 745 <-canceled 746 return 0, errors.New("test done") 747 }) 748 _, err := cl.RoundTrip(req) 749 Expect(err).To(MatchError(context.Canceled)) 750 Eventually(done).Should(BeClosed()) 751 }) 752 753 It("cancels a request after the response arrived", func() { 754 rspBuf := bytes.NewBuffer(encodeResponse(404)) 755 756 ctx, cancel := context.WithCancel(context.Background()) 757 req := 
req.WithContext(ctx) 758 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 759 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 760 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 761 buf := &bytes.Buffer{} 762 str.EXPECT().Close().MaxTimes(1) 763 764 done := make(chan struct{}) 765 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 766 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 767 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 768 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) 769 _, err := cl.RoundTrip(req) 770 Expect(err).ToNot(HaveOccurred()) 771 cancel() 772 Eventually(done).Should(BeClosed()) 773 }) 774 }) 775 776 Context("gzip compression", func() { 777 BeforeEach(func() { 778 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 779 }) 780 781 It("adds the gzip header to requests", func() { 782 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 783 buf := &bytes.Buffer{} 784 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 785 gomock.InOrder( 786 str.EXPECT().Close(), 787 // when the Read errors 788 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1), 789 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), 790 ) 791 testErr := errors.New("test done") 792 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 793 _, err := cl.RoundTrip(req) 794 Expect(err).To(MatchError(testErr)) 795 hfs := decodeHeader(buf) 796 Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) 797 }) 798 799 It("doesn't add gzip if the header disable it", func() { 800 client := &SingleDestinationRoundTripper{ 801 Connection: conn, 802 DisableCompression: true, 803 } 804 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 805 buf := &bytes.Buffer{} 806 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 807 gomock.InOrder( 808 str.EXPECT().Close(), 809 // when the Read errors 810 
str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1),
					str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1),
				)
				testErr := errors.New("test done")
				str.EXPECT().Read(gomock.Any()).Return(0, testErr)
				_, err := client.RoundTrip(req)
				Expect(err).To(MatchError(testErr))
				hfs := decodeHeader(buf)
				Expect(hfs).ToNot(HaveKey("accept-encoding"))
			})

			// A gzip-encoded response is expected to come back decompressed: the body
			// reads as plain text, Content-Encoding is removed, ContentLength is -1
			// and rsp.Uncompressed is true.
			It("decompresses the response", func() {
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

				// Serialize the response with a real responseWriter; its mock stream
				// copies every write into buf, which is then replayed to the client.
				buf := &bytes.Buffer{}
				rstr := mockquic.NewMockStream(mockCtrl)
				rstr.EXPECT().StreamID().AnyTimes()
				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
				rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil)
				rw.Header().Set("Content-Encoding", "gzip")
				gz := gzip.NewWriter(rw)
				// Check fixture writes so a broken setup fails here rather than as a
				// body mismatch further down.
				_, err := gz.Write([]byte("gzipped response"))
				Expect(err).ToNot(HaveOccurred())
				Expect(gz.Close()).To(Succeed())
				rw.Flush()

				// The request stream accepts all writes and serves the recorded bytes.
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
				str.EXPECT().Close()

				rsp, err := cl.RoundTrip(req)
				Expect(err).ToNot(HaveOccurred())
				data, err := io.ReadAll(rsp.Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
				Expect(string(data)).To(Equal("gzipped response"))
				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
				Expect(rsp.Uncompressed).To(BeTrue())
			})

			// Without a Content-Encoding header the body must be passed through as-is.
			It("only decompresses the response if the response contains the right content-encoding header", func() {
				conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil)
				conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})

				// Record a plain (non-gzipped) response into buf via a real responseWriter.
				buf := &bytes.Buffer{}
				rstr := mockquic.NewMockStream(mockCtrl)
				rstr.EXPECT().StreamID().AnyTimes()
				rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes()
				rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil)
				_, err := rw.Write([]byte("not gzipped"))
				Expect(err).ToNot(HaveOccurred())
				rw.Flush()

				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
				str.EXPECT().Close()

				rsp, err := cl.RoundTrip(req)
				Expect(err).ToNot(HaveOccurred())
				data, err := io.ReadAll(rsp.Body)
				Expect(err).ToNot(HaveOccurred())
				Expect(string(data)).To(Equal("not gzipped"))
				Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
			})
		})

		// Informational responses. A 103 (Early Hints) is reported through the
		// httptrace Got1xxResponse hook and the client is expected to go on to the
		// final 200. A 101 is expected to be returned as the response itself,
		// without invoking the hook.
		Context("1xx status code", func() {
			It("continues to read next header if code is 103", func() {
				var (
					cnt    int                  // number of Got1xxResponse invocations
					status int                  // last code passed to Got1xxResponse
					hdr    textproto.MIMEHeader // last header passed to Got1xxResponse
				)
				// The two Link values encodeResponse attaches to a 103 response.
				header1 := "</style.css>; rel=preload; as=style"
				header2 := "</script.js>; rel=preload; as=script"
				ctx := httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
					Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
						cnt++
						status = code
						hdr = header
						return nil
					},
				})
				req := req.WithContext(ctx)
				rspBuf := bytes.NewBuffer(encodeResponse(103))
				gomock.InOrder(
					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
					conn.EXPECT().OpenStreamSync(ctx).Return(str, nil),
					conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
				)
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Close()
				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()

				rsp, err := cl.RoundTrip(req)
				Expect(err).ToNot(HaveOccurred())
				Expect(rsp.Proto).To(Equal("HTTP/3.0"))
				Expect(rsp.ProtoMajor).To(Equal(3))
				// The caller sees the final response, which still carries the Link headers.
				Expect(rsp.StatusCode).To(Equal(200))
				Expect(rsp.Header).To(HaveKeyWithValue("Link", []string{header1, header2}))
				// The 103 was surfaced exactly once through the trace hook.
				Expect(status).To(Equal(103))
				Expect(cnt).To(Equal(1))
				Expect(hdr).To(HaveKeyWithValue("Link", []string{header1, header2}))
				Expect(rsp.Request).ToNot(BeNil())
			})

			It("doesn't continue to read next header if code is a terminal status", func() {
				cnt := 0
				status := 0
				ctx := httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{
					Got1xxResponse: func(code int, _ textproto.MIMEHeader) error {
						cnt++
						status = code
						return nil
					},
				})
				req := req.WithContext(ctx)
				rspBuf := bytes.NewBuffer(encodeResponse(101))
				gomock.InOrder(
					conn.EXPECT().HandshakeComplete().Return(handshakeChan),
					conn.EXPECT().OpenStreamSync(ctx).Return(str, nil),
					conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
				)
				str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
				str.EXPECT().Close()
				str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()

				rsp, err := cl.RoundTrip(req)
				Expect(err).ToNot(HaveOccurred())
				Expect(rsp.Proto).To(Equal("HTTP/3.0"))
				Expect(rsp.ProtoMajor).To(Equal(3))
				// 101 is handed back as the response; the trace hook never fires.
				Expect(rsp.StatusCode).To(Equal(101))
				Expect(status).To(Equal(0))
				Expect(cnt).To(Equal(0))
				Expect(rsp.Request).ToNot(BeNil())
			})
		})
	})
})