github.com/quic-go/quic-go@v0.44.0/http3/client_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "context" 7 "errors" 8 "io" 9 "net/http" 10 "net/http/httptrace" 11 "net/textproto" 12 "sync" 13 "time" 14 15 "github.com/quic-go/quic-go" 16 mockquic "github.com/quic-go/quic-go/internal/mocks/quic" 17 "github.com/quic-go/quic-go/quicvarint" 18 19 "github.com/quic-go/qpack" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 "go.uber.org/mock/gomock" 24 ) 25 26 func encodeResponse(status int) []byte { 27 buf := &bytes.Buffer{} 28 rstr := mockquic.NewMockStream(mockCtrl) 29 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 30 rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) 31 if status == http.StatusEarlyHints { 32 rw.header.Add("Link", "</style.css>; rel=preload; as=style") 33 rw.header.Add("Link", "</script.js>; rel=preload; as=script") 34 } 35 rw.WriteHeader(status) 36 rw.Flush() 37 return buf.Bytes() 38 } 39 40 var _ = Describe("Client", func() { 41 var handshakeChan <-chan struct{} // a closed chan 42 43 BeforeEach(func() { 44 ch := make(chan struct{}) 45 close(ch) 46 handshakeChan = ch 47 }) 48 49 Context("hijacking bidirectional streams", func() { 50 var ( 51 request *http.Request 52 conn *mockquic.MockEarlyConnection 53 settingsFrameWritten chan struct{} 54 ) 55 testDone := make(chan struct{}) 56 57 BeforeEach(func() { 58 testDone = make(chan struct{}) 59 settingsFrameWritten = make(chan struct{}) 60 controlStr := mockquic.NewMockStream(mockCtrl) 61 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 62 defer GinkgoRecover() 63 close(settingsFrameWritten) 64 return len(b), nil 65 }) 66 conn = mockquic.NewMockEarlyConnection(mockCtrl) 67 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 68 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 69 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 70 
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() 71 var err error 72 request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 73 Expect(err).ToNot(HaveOccurred()) 74 }) 75 76 AfterEach(func() { 77 testDone <- struct{}{} 78 Eventually(settingsFrameWritten).Should(BeClosed()) 79 }) 80 81 It("hijacks a bidirectional stream of unknown frame type", func() { 82 id := quic.ConnectionTracingID(1234) 83 frameTypeChan := make(chan FrameType, 1) 84 rt := &SingleDestinationRoundTripper{ 85 Connection: conn, 86 StreamHijacker: func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 87 Expect(e).ToNot(HaveOccurred()) 88 Expect(connTracingID).To(Equal(id)) 89 frameTypeChan <- ft 90 return true, nil 91 }, 92 } 93 94 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 95 unknownStr := mockquic.NewMockStream(mockCtrl) 96 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 97 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 98 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 99 <-testDone 100 return nil, errors.New("test done") 101 }) 102 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 103 conn.EXPECT().Context().Return(ctx).AnyTimes() 104 _, err := rt.RoundTrip(request) 105 Expect(err).To(MatchError("done")) 106 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 107 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 108 }) 109 110 It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { 111 frameTypeChan := make(chan FrameType, 1) 112 rt := &SingleDestinationRoundTripper{ 113 Connection: conn, 114 StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 115 Expect(e).ToNot(HaveOccurred()) 116 
frameTypeChan <- ft 117 return false, nil 118 }, 119 } 120 121 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 122 unknownStr := mockquic.NewMockStream(mockCtrl) 123 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 124 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 125 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 126 <-testDone 127 return nil, errors.New("test done") 128 }) 129 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 130 conn.EXPECT().Context().Return(ctx).AnyTimes() 131 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 132 _, err := rt.RoundTrip(request) 133 Expect(err).To(MatchError("done")) 134 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 135 }) 136 137 It("closes the connection when hijacker returned error", func() { 138 frameTypeChan := make(chan FrameType, 1) 139 rt := &SingleDestinationRoundTripper{ 140 Connection: conn, 141 StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 142 Expect(e).ToNot(HaveOccurred()) 143 frameTypeChan <- ft 144 return false, errors.New("error in hijacker") 145 }, 146 } 147 148 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 149 unknownStr := mockquic.NewMockStream(mockCtrl) 150 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 151 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 152 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 153 <-testDone 154 return nil, errors.New("test done") 155 }) 156 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 157 conn.EXPECT().Context().Return(ctx).AnyTimes() 158 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), 
gomock.Any()).Return(nil).AnyTimes() 159 _, err := rt.RoundTrip(request) 160 Expect(err).To(MatchError("done")) 161 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 162 }) 163 164 It("handles errors that occur when reading the frame type", func() { 165 testErr := errors.New("test error") 166 unknownStr := mockquic.NewMockStream(mockCtrl) 167 done := make(chan struct{}) 168 rt := &SingleDestinationRoundTripper{ 169 Connection: conn, 170 StreamHijacker: func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, e error) (hijacked bool, err error) { 171 defer close(done) 172 Expect(e).To(MatchError(testErr)) 173 Expect(ft).To(BeZero()) 174 Expect(str).To(Equal(unknownStr)) 175 return false, nil 176 }, 177 } 178 179 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 180 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 181 conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 182 <-testDone 183 return nil, errors.New("test done") 184 }) 185 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 186 conn.EXPECT().Context().Return(ctx).AnyTimes() 187 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() 188 _, err := rt.RoundTrip(request) 189 Expect(err).To(MatchError("done")) 190 Eventually(done).Should(BeClosed()) 191 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 192 }) 193 }) 194 195 Context("hijacking unidirectional streams", func() { 196 var ( 197 req *http.Request 198 conn *mockquic.MockEarlyConnection 199 settingsFrameWritten chan struct{} 200 ) 201 testDone := make(chan struct{}) 202 203 BeforeEach(func() { 204 testDone = make(chan struct{}) 205 settingsFrameWritten = make(chan struct{}) 206 controlStr := mockquic.NewMockStream(mockCtrl) 207 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, 
error) { 208 defer GinkgoRecover() 209 close(settingsFrameWritten) 210 return len(b), nil 211 }) 212 conn = mockquic.NewMockEarlyConnection(mockCtrl) 213 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 214 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 215 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) 216 var err error 217 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 218 Expect(err).ToNot(HaveOccurred()) 219 }) 220 221 AfterEach(func() { 222 testDone <- struct{}{} 223 Eventually(settingsFrameWritten).Should(BeClosed()) 224 }) 225 226 It("hijacks an unidirectional stream of unknown stream type", func() { 227 id := quic.ConnectionTracingID(100) 228 streamTypeChan := make(chan StreamType, 1) 229 rt := &SingleDestinationRoundTripper{ 230 Connection: conn, 231 UniStreamHijacker: func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 232 Expect(connTracingID).To(Equal(id)) 233 Expect(err).ToNot(HaveOccurred()) 234 streamTypeChan <- st 235 return true 236 }, 237 } 238 239 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 240 unknownStr := mockquic.NewMockStream(mockCtrl) 241 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 242 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 243 return unknownStr, nil 244 }) 245 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 246 <-testDone 247 return nil, errors.New("test done") 248 }) 249 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 250 conn.EXPECT().Context().Return(ctx).AnyTimes() 251 _, err := rt.RoundTrip(req) 252 Expect(err).To(MatchError("done")) 253 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 254 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 255 }) 256 257 
It("handles errors that occur when reading the stream type", func() { 258 testErr := errors.New("test error") 259 done := make(chan struct{}) 260 unknownStr := mockquic.NewMockStream(mockCtrl) 261 rt := &SingleDestinationRoundTripper{ 262 Connection: conn, 263 UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool { 264 defer close(done) 265 Expect(st).To(BeZero()) 266 Expect(str).To(Equal(unknownStr)) 267 Expect(err).To(MatchError(testErr)) 268 return true 269 }, 270 } 271 272 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr) 273 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 274 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 275 <-testDone 276 return nil, errors.New("test done") 277 }) 278 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 279 conn.EXPECT().Context().Return(ctx).AnyTimes() 280 _, err := rt.RoundTrip(req) 281 Expect(err).To(MatchError("done")) 282 Eventually(done).Should(BeClosed()) 283 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 284 }) 285 286 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 287 streamTypeChan := make(chan StreamType, 1) 288 rt := &SingleDestinationRoundTripper{ 289 Connection: conn, 290 UniStreamHijacker: func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 291 Expect(err).ToNot(HaveOccurred()) 292 streamTypeChan <- st 293 return false 294 }, 295 } 296 297 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 298 unknownStr := mockquic.NewMockStream(mockCtrl) 299 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 300 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 301 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) 
(quic.ReceiveStream, error) { 302 return unknownStr, nil 303 }) 304 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 305 <-testDone 306 return nil, errors.New("test done") 307 }) 308 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 309 conn.EXPECT().Context().Return(ctx).AnyTimes() 310 _, err := rt.RoundTrip(req) 311 Expect(err).To(MatchError("done")) 312 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 313 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 314 }) 315 }) 316 317 Context("SETTINGS handling", func() { 318 sendSettings := func() { 319 settingsFrameWritten := make(chan struct{}) 320 controlStr := mockquic.NewMockStream(mockCtrl) 321 var buf bytes.Buffer 322 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 323 defer GinkgoRecover() 324 buf.Write(b) 325 close(settingsFrameWritten) 326 return len(b), nil 327 }) 328 conn := mockquic.NewMockEarlyConnection(mockCtrl) 329 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 330 conn.EXPECT().OpenStreamSync(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { 331 <-settingsFrameWritten 332 return nil, errors.New("test done") 333 }) 334 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 335 <-settingsFrameWritten 336 return nil, errors.New("test done") 337 }).AnyTimes() 338 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 339 rt := &SingleDestinationRoundTripper{ 340 Connection: conn, 341 EnableDatagrams: true, 342 } 343 req, err := http.NewRequest(http.MethodGet, "https://quic-go.net", nil) 344 Expect(err).ToNot(HaveOccurred()) 345 _, err = rt.RoundTrip(req) 346 Expect(err).To(MatchError("test done")) 347 t, err := quicvarint.Read(&buf) 348 Expect(err).ToNot(HaveOccurred()) 349 
Expect(t).To(BeEquivalentTo(streamTypeControlStream)) 350 settings, err := parseSettingsFrame(&buf, uint64(buf.Len())) 351 Expect(err).ToNot(HaveOccurred()) 352 Expect(settings.Datagram).To(BeTrue()) 353 } 354 355 It("receives SETTINGS", func() { 356 sendSettings() 357 done := make(chan struct{}) 358 conn := mockquic.NewMockEarlyConnection(mockCtrl) 359 conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { 360 <-done 361 return nil, errors.New("test done") 362 }).MaxTimes(1) 363 b := quicvarint.Append(nil, streamTypeControlStream) 364 b = (&settingsFrame{ExtendedConnect: true}).Append(b) 365 r := bytes.NewReader(b) 366 controlStr := mockquic.NewMockStream(mockCtrl) 367 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 368 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) 369 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 370 <-done 371 return nil, errors.New("test done") 372 }) 373 374 rt := &SingleDestinationRoundTripper{Connection: conn} 375 hconn := rt.Start() 376 Eventually(hconn.ReceivedSettings()).Should(BeClosed()) 377 settings := hconn.Settings() 378 Expect(settings.EnableExtendedConnect).To(BeTrue()) 379 // test shutdown 380 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 381 close(done) 382 }) 383 384 It("checks the server's SETTINGS before sending an Extended CONNECT request", func() { 385 sendSettings() 386 done := make(chan struct{}) 387 conn := mockquic.NewMockEarlyConnection(mockCtrl) 388 conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { 389 <-done 390 return nil, errors.New("test done") 391 }).MaxTimes(1) 392 b := quicvarint.Append(nil, streamTypeControlStream) 393 b = (&settingsFrame{ExtendedConnect: true}).Append(b) 394 r := bytes.NewReader(b) 395 controlStr := mockquic.NewMockStream(mockCtrl) 396 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 397 
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) 398 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 399 <-done 400 return nil, errors.New("test done") 401 }) 402 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 403 conn.EXPECT().Context().Return(context.Background()) 404 conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("test error")) 405 406 rt := &SingleDestinationRoundTripper{Connection: conn} 407 _, err := rt.RoundTrip(&http.Request{ 408 Method: http.MethodConnect, 409 Proto: "connect", 410 Host: "localhost", 411 }) 412 Expect(err).To(MatchError("test error")) 413 414 // test shutdown 415 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 416 close(done) 417 }) 418 419 It("rejects Extended CONNECT requests if the server doesn't enable it", func() { 420 sendSettings() 421 done := make(chan struct{}) 422 conn := mockquic.NewMockEarlyConnection(mockCtrl) 423 conn.EXPECT().OpenUniStream().DoAndReturn(func() (quic.SendStream, error) { 424 <-done 425 return nil, errors.New("test done") 426 }).MaxTimes(1) 427 b := quicvarint.Append(nil, streamTypeControlStream) 428 b = (&settingsFrame{}).Append(b) 429 r := bytes.NewReader(b) 430 controlStr := mockquic.NewMockStream(mockCtrl) 431 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 432 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(controlStr, nil) 433 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 434 <-done 435 return nil, errors.New("test done") 436 }) 437 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 438 conn.EXPECT().Context().Return(context.Background()) 439 440 rt := &SingleDestinationRoundTripper{Connection: conn} 441 _, err := rt.RoundTrip(&http.Request{ 442 Method: http.MethodConnect, 443 Proto: "connect", 444 Host: "localhost", 445 }) 446 Expect(err).To(MatchError("http3: server didn't enable 
Extended CONNECT")) 447 448 // test shutdown 449 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 450 close(done) 451 }) 452 }) 453 454 Context("Doing requests", func() { 455 var ( 456 req *http.Request 457 str *mockquic.MockStream 458 conn *mockquic.MockEarlyConnection 459 cl *SingleDestinationRoundTripper 460 settingsFrameWritten chan struct{} 461 ) 462 testDone := make(chan struct{}) 463 464 decodeHeader := func(str io.Reader) map[string]string { 465 fields := make(map[string]string) 466 decoder := qpack.NewDecoder(nil) 467 468 fp := frameParser{r: str} 469 frame, err := fp.ParseNext() 470 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 471 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 472 headersFrame := frame.(*headersFrame) 473 data := make([]byte, headersFrame.Length) 474 _, err = io.ReadFull(str, data) 475 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 476 hfs, err := decoder.DecodeFull(data) 477 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 478 for _, p := range hfs { 479 fields[p.Name] = p.Value 480 } 481 return fields 482 } 483 484 BeforeEach(func() { 485 settingsFrameWritten = make(chan struct{}) 486 controlStr := mockquic.NewMockStream(mockCtrl) 487 controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) (int, error) { 488 defer GinkgoRecover() 489 r := bytes.NewReader(b) 490 streamType, err := quicvarint.Read(r) 491 Expect(err).ToNot(HaveOccurred()) 492 Expect(streamType).To(BeEquivalentTo(streamTypeControlStream)) 493 close(settingsFrameWritten) 494 return len(b), nil 495 }) // SETTINGS frame 496 str = mockquic.NewMockStream(mockCtrl) 497 str.EXPECT().Context().Return(context.Background()).AnyTimes() 498 str.EXPECT().StreamID().AnyTimes() 499 conn = mockquic.NewMockEarlyConnection(mockCtrl) 500 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 501 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 502 <-testDone 503 return nil, errors.New("test 
done") 504 }) 505 cl = &SingleDestinationRoundTripper{Connection: conn} 506 var err error 507 req, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) 508 Expect(err).ToNot(HaveOccurred()) 509 }) 510 511 AfterEach(func() { 512 testDone <- struct{}{} 513 Eventually(settingsFrameWritten).Should(BeClosed()) 514 }) 515 516 It("errors if it can't open a request stream", func() { 517 testErr := errors.New("stream open error") 518 conn.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr) 519 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).MaxTimes(1) 520 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 521 _, err := cl.RoundTrip(req) 522 Expect(err).To(MatchError(testErr)) 523 }) 524 525 DescribeTable( 526 "performs a 0-RTT request", 527 func(method, serialized string) { 528 testErr := errors.New("stream open error") 529 req.Method = method 530 // don't EXPECT any calls to HandshakeComplete() 531 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 532 buf := &bytes.Buffer{} 533 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 534 str.EXPECT().Close() 535 str.EXPECT().CancelWrite(gomock.Any()) 536 str.EXPECT().CancelRead(gomock.Any()) 537 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 538 return 0, testErr 539 }) 540 _, err := cl.RoundTrip(req) 541 Expect(err).To(MatchError(testErr)) 542 Expect(decodeHeader(buf)).To(HaveKeyWithValue(":method", serialized)) 543 // make sure the request wasn't modified 544 Expect(req.Method).To(Equal(method)) 545 }, 546 Entry("GET", MethodGet0RTT, http.MethodGet), 547 Entry("HEAD", MethodHead0RTT, http.MethodHead), 548 ) 549 550 It("returns a response", func() { 551 rspBuf := bytes.NewBuffer(encodeResponse(418)) 552 gomock.InOrder( 553 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 554 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 555 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 
556 ) 557 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 558 str.EXPECT().Close() 559 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 560 rsp, err := cl.RoundTrip(req) 561 Expect(err).ToNot(HaveOccurred()) 562 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 563 Expect(rsp.ProtoMajor).To(Equal(3)) 564 Expect(rsp.StatusCode).To(Equal(418)) 565 Expect(rsp.Request).ToNot(BeNil()) 566 }) 567 568 Context("requests containing a Body", func() { 569 var strBuf *bytes.Buffer 570 571 BeforeEach(func() { 572 strBuf = &bytes.Buffer{} 573 gomock.InOrder( 574 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 575 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), 576 ) 577 body := &mockBody{} 578 body.SetData([]byte("request body")) 579 var err error 580 req, err = http.NewRequest("POST", "https://quic.clemente.io:1337/upload", body) 581 Expect(err).ToNot(HaveOccurred()) 582 str.EXPECT().Write(gomock.Any()).DoAndReturn(strBuf.Write).AnyTimes() 583 }) 584 585 It("sends a request", func() { 586 done := make(chan struct{}) 587 gomock.InOrder( 588 str.EXPECT().Close().Do(func() error { close(done); return nil }), 589 // when reading the response errors 590 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1), 591 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), 592 ) 593 // the response body is sent asynchronously, while already reading the response 594 testErr := errors.New("test done") 595 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 596 <-done 597 return 0, testErr 598 }) 599 _, err := cl.RoundTrip(req) 600 Expect(err).To(MatchError(testErr)) 601 hfs := decodeHeader(strBuf) 602 Expect(hfs).To(HaveKeyWithValue(":method", "POST")) 603 Expect(hfs).To(HaveKeyWithValue(":path", "/upload")) 604 }) 605 606 It("doesn't send more bytes than allowed by http.Request.ContentLength", func() { 607 req.ContentLength = 7 608 var once sync.Once 609 done := make(chan struct{}) 
610 str.EXPECT().CancelRead(gomock.Any()) 611 gomock.InOrder( 612 str.EXPECT().CancelWrite(gomock.Any()).Do(func(c quic.StreamErrorCode) { 613 once.Do(func() { 614 Expect(c).To(Equal(quic.StreamErrorCode(ErrCodeRequestCanceled))) 615 close(done) 616 }) 617 }).AnyTimes(), 618 str.EXPECT().Close().MaxTimes(1), 619 str.EXPECT().CancelWrite(gomock.Any()).AnyTimes(), 620 ) 621 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 622 <-done 623 return 0, errors.New("done") 624 }) 625 cl.RoundTrip(req) 626 Expect(strBuf.String()).To(ContainSubstring("request")) 627 Expect(strBuf.String()).ToNot(ContainSubstring("request body")) 628 }) 629 630 It("returns the error that occurred when reading the body", func() { 631 req.Body.(*mockBody).readErr = errors.New("testErr") 632 done := make(chan struct{}) 633 str.EXPECT().CancelRead(gomock.Any()) 634 gomock.InOrder( 635 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { 636 close(done) 637 }), 638 str.EXPECT().CancelWrite(gomock.Any()), 639 ) 640 641 // the response body is sent asynchronously, while already reading the response 642 testErr := errors.New("test done") 643 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 644 <-done 645 return 0, testErr 646 }) 647 closed := make(chan struct{}) 648 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 649 _, err := cl.RoundTrip(req) 650 Expect(err).To(MatchError(testErr)) 651 Eventually(closed).Should(BeClosed()) 652 }) 653 654 It("closes the connection when the first frame is not a HEADERS frame", func() { 655 b := (&dataFrame{Length: 0x42}).Append(nil) 656 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()) 657 closed := make(chan struct{}) 658 r := bytes.NewReader(b) 659 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 660 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 661 _, err := 
cl.RoundTrip(req) 662 Expect(err).To(MatchError("http3: expected first frame to be a HEADERS frame")) 663 Eventually(closed).Should(BeClosed()) 664 }) 665 666 It("cancels the stream when parsing the headers fails", func() { 667 headerBuf := &bytes.Buffer{} 668 enc := qpack.NewEncoder(headerBuf) 669 Expect(enc.WriteField(qpack.HeaderField{Name: ":method", Value: "GET"})).To(Succeed()) // not a valid response pseudo header 670 Expect(enc.Close()).To(Succeed()) 671 b := (&headersFrame{Length: uint64(headerBuf.Len())}).Append(nil) 672 b = append(b, headerBuf.Bytes()...) 673 674 r := bytes.NewReader(b) 675 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeMessageError)) 676 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeMessageError)) 677 closed := make(chan struct{}) 678 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 679 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 680 _, err := cl.RoundTrip(req) 681 Expect(err).To(HaveOccurred()) 682 Eventually(closed).Should(BeClosed()) 683 }) 684 685 It("cancels the stream when the HEADERS frame is too large", func() { 686 cl.MaxResponseHeaderBytes = 1337 687 b := (&headersFrame{Length: 1338}).Append(nil) 688 r := bytes.NewReader(b) 689 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) 690 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)) 691 closed := make(chan struct{}) 692 str.EXPECT().Close().Do(func() error { close(closed); return nil }) 693 str.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 694 _, err := cl.RoundTrip(req) 695 Expect(err).To(MatchError("http3: HEADERS frame too large: 1338 bytes (max: 1337)")) 696 Eventually(closed).Should(BeClosed()) 697 }) 698 699 It("opens a request stream", func() { 700 cl.Connection.(quic.EarlyConnection).HandshakeComplete() 701 str, err := cl.OpenRequestStream(context.Background()) 702 Expect(err).ToNot(HaveOccurred()) 703 Expect(str.SendRequestHeader(req)).To(Succeed()) 704 
str.Write([]byte("foobar")) 705 d := dataFrame{Length: 6} 706 data := d.Append([]byte{}) 707 data = append(data, []byte("foobar")...) 708 Expect(bytes.Contains(strBuf.Bytes(), data)).To(BeTrue()) 709 }) 710 }) 711 712 Context("request cancellations", func() { 713 It("cancels a request while waiting for the handshake to complete", func() { 714 ctx, cancel := context.WithCancel(context.Background()) 715 req := req.WithContext(ctx) 716 conn.EXPECT().HandshakeComplete().Return(make(chan struct{})) 717 718 errChan := make(chan error) 719 go func() { 720 _, err := cl.RoundTrip(req) 721 errChan <- err 722 }() 723 Consistently(errChan).ShouldNot(Receive()) 724 cancel() 725 Eventually(errChan).Should(Receive(MatchError("context canceled"))) 726 }) 727 728 It("cancels a request while the request is still in flight", func() { 729 ctx, cancel := context.WithCancel(context.Background()) 730 req := req.WithContext(ctx) 731 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 732 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 733 buf := &bytes.Buffer{} 734 str.EXPECT().Close().MaxTimes(1) 735 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 736 737 done := make(chan struct{}) 738 canceled := make(chan struct{}) 739 gomock.InOrder( 740 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), 741 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), 742 ) 743 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) 744 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1) 745 str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { 746 cancel() 747 <-canceled 748 return 0, errors.New("test done") 749 }) 750 _, err := cl.RoundTrip(req) 751 Expect(err).To(MatchError(context.Canceled)) 752 Eventually(done).Should(BeClosed()) 753 }) 754 755 It("cancels a request after the response arrived", func() { 756 rspBuf := 
bytes.NewBuffer(encodeResponse(404)) 757 758 ctx, cancel := context.WithCancel(context.Background()) 759 req := req.WithContext(ctx) 760 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 761 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) 762 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 763 buf := &bytes.Buffer{} 764 str.EXPECT().Close().MaxTimes(1) 765 766 done := make(chan struct{}) 767 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 768 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 769 /* Cancelling ctx after RoundTrip has already returned must still cancel both directions of the stream with ErrCodeRequestCanceled; done is closed once CancelRead is observed. */ str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestCanceled)) 770 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }) 771 _, err := cl.RoundTrip(req) 772 Expect(err).ToNot(HaveOccurred()) 773 cancel() 774 Eventually(done).Should(BeClosed()) 775 }) 776 }) 777 778 /* gzip compression: by default the round tripper sends "accept-encoding: gzip", but not when DisableCompression is set. Responses sent with Content-Encoding: gzip are transparently decompressed: the header is removed, ContentLength is reported as -1 and Uncompressed is true. Responses without that header are passed through unchanged. */ Context("gzip compression", func() { 779 BeforeEach(func() { 780 conn.EXPECT().HandshakeComplete().Return(handshakeChan) 781 }) 782 783 It("adds the gzip header to requests", func() { 784 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 785 buf := &bytes.Buffer{} 786 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 787 gomock.InOrder( 788 str.EXPECT().Close(), 789 // when the Read errors 790 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1), 791 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), 792 ) 793 testErr := errors.New("test done") 794 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 795 _, err := cl.RoundTrip(req) 796 Expect(err).To(MatchError(testErr)) 797 hfs := decodeHeader(buf) 798 Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) 799 }) 800 801 It("doesn't add gzip if the header disable it", func() { 802 client := &SingleDestinationRoundTripper{ 803 Connection: conn, 804 DisableCompression: true, 805 } 806 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 807 buf := &bytes.Buffer{} 808 
// DisableCompression test: buf records the request bytes written to str, so decodeHeader can verify that no accept-encoding header was sent.
// NOTE(review): the test name "if the header disable it" presumably means "if compression is disabled" — consider rewording.
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) 809 gomock.InOrder( 810 str.EXPECT().Close(), 811 // when the Read errors 812 str.EXPECT().CancelRead(gomock.Any()).MaxTimes(1), 813 str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1), 814 ) 815 testErr := errors.New("test done") 816 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 817 _, err := client.RoundTrip(req) 818 Expect(err).To(MatchError(testErr)) 819 hfs := decodeHeader(buf) 820 Expect(hfs).ToNot(HaveKey("accept-encoding")) 821 }) 822 823 It("decompresses the response", func() { 824 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 825 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 826 buf := &bytes.Buffer{} 827 rstr := mockquic.NewMockStream(mockCtrl) 828 rstr.EXPECT().StreamID().AnyTimes() 829 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 830 rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) 831 rw.Header().Set("Content-Encoding", "gzip") 832 gz := gzip.NewWriter(rw) 833 gz.Write([]byte("gzipped response")) 834 gz.Close() 835 rw.Flush() 836 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 837 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 838 str.EXPECT().Close() 839 840 rsp, err := cl.RoundTrip(req) 841 Expect(err).ToNot(HaveOccurred()) 842 data, err := io.ReadAll(rsp.Body) 843 Expect(err).ToNot(HaveOccurred()) 844 Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) 845 Expect(string(data)).To(Equal("gzipped response")) 846 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 847 Expect(rsp.Uncompressed).To(BeTrue()) 848 }) 849 850 It("only decompresses the response if the response contains the right content-encoding header", func() { 851 conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) 852 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) 853 buf := &bytes.Buffer{} 854 rstr := mockquic.NewMockStream(mockCtrl) 855 
// rstr is a throwaway mock stream. newResponseWriter serializes a server response into buf through it, and str.Read then replays buf to the client under test.
// NOTE(review): rsp.Body is fully read but never closed in the decompression tests.
rstr.EXPECT().StreamID().AnyTimes() 856 rstr.EXPECT().Write(gomock.Any()).Do(buf.Write).AnyTimes() 857 rw := newResponseWriter(newStream(rstr, nil, nil), nil, false, nil) 858 rw.Write([]byte("not gzipped")) 859 rw.Flush() 860 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 861 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 862 str.EXPECT().Close() 863 864 rsp, err := cl.RoundTrip(req) 865 Expect(err).ToNot(HaveOccurred()) 866 data, err := io.ReadAll(rsp.Body) 867 Expect(err).ToNot(HaveOccurred()) 868 Expect(string(data)).To(Equal("not gzipped")) 869 Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) 870 }) 871 }) 872 873 /* 1xx status codes: a 103 (Early Hints) response is reported through the httptrace Got1xxResponse hook, and RoundTrip keeps reading until the final response (status 200 in the first test). A 101 is returned directly as the final response and the hook is not called. */ Context("1xx status code", func() { 874 It("continues to read next header if code is 103", func() { 875 var ( 876 cnt int 877 status int 878 hdr textproto.MIMEHeader 879 ) 880 header1 := "</style.css>; rel=preload; as=style" 881 header2 := "</script.js>; rel=preload; as=script" 882 ctx := httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ 883 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 884 cnt++ 885 status = code 886 hdr = header 887 return nil 888 }, 889 }) 890 req := req.WithContext(ctx) 891 rspBuf := bytes.NewBuffer(encodeResponse(103)) 892 gomock.InOrder( 893 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 894 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil), 895 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 896 ) 897 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 898 str.EXPECT().Close() 899 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 900 rsp, err := cl.RoundTrip(req) 901 Expect(err).ToNot(HaveOccurred()) 902 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 903 Expect(rsp.ProtoMajor).To(Equal(3)) 904 Expect(rsp.StatusCode).To(Equal(200)) 905 Expect(rsp.Header).To(HaveKeyWithValue("Link", []string{header1, header2})) 906 
// Got1xxResponse must have fired exactly once, for the 103, carrying both Link values. The final response (checked above) carries the same Link headers.
Expect(status).To(Equal(103)) 907 Expect(cnt).To(Equal(1)) 908 Expect(hdr).To(HaveKeyWithValue("Link", []string{header1, header2})) 909 Expect(rsp.Request).ToNot(BeNil()) 910 }) 911 912 /* 101 is asserted to be a final status: RoundTrip returns it as rsp.StatusCode without ever invoking Got1xxResponse. */ It("doesn't continue to read next header if code is a terminal status", func() { 913 cnt := 0 914 status := 0 915 ctx := httptrace.WithClientTrace(req.Context(), &httptrace.ClientTrace{ 916 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 917 cnt++ 918 status = code 919 return nil 920 }, 921 }) 922 req := req.WithContext(ctx) 923 rspBuf := bytes.NewBuffer(encodeResponse(101)) 924 gomock.InOrder( 925 conn.EXPECT().HandshakeComplete().Return(handshakeChan), 926 conn.EXPECT().OpenStreamSync(ctx).Return(str, nil), 927 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), 928 ) 929 str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) 930 str.EXPECT().Close() 931 str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() 932 rsp, err := cl.RoundTrip(req) 933 Expect(err).ToNot(HaveOccurred()) 934 Expect(rsp.Proto).To(Equal("HTTP/3.0")) 935 Expect(rsp.ProtoMajor).To(Equal(3)) 936 Expect(rsp.StatusCode).To(Equal(101)) 937 Expect(status).To(Equal(0)) 938 Expect(cnt).To(Equal(0)) 939 Expect(rsp.Request).ToNot(BeNil()) 940 }) 941 }) 942 }) 943 })