github.com/tumi8/quic-go@v0.37.4-tum/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/tumi8/quic-go" 17 mockquic "github.com/tumi8/quic-go/noninternal/mocks/quic" 18 "github.com/tumi8/quic-go/noninternal/protocol" 19 "github.com/tumi8/quic-go/noninternal/testdata" 20 "github.com/tumi8/quic-go/noninternal/utils" 21 "github.com/tumi8/quic-go/quicvarint" 22 23 "github.com/golang/mock/gomock" 24 "github.com/quic-go/qpack" 25 26 . "github.com/onsi/ginkgo/v2" 27 . "github.com/onsi/gomega" 28 gmtypes "github.com/onsi/gomega/types" 29 ) 30 31 type mockAddr struct{ addr string } 32 33 func (ma *mockAddr) Network() string { return "udp" } 34 func (ma *mockAddr) String() string { return ma.addr } 35 36 type mockAddrListener struct { 37 *MockQUICEarlyListener 38 addr *mockAddr 39 } 40 41 func (m *mockAddrListener) Addr() net.Addr { 42 _ = m.MockQUICEarlyListener.Addr() 43 return m.addr 44 } 45 46 func newMockAddrListener(addr string) *mockAddrListener { 47 return &mockAddrListener{ 48 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 49 addr: &mockAddr{addr: addr}, 50 } 51 } 52 53 type noPortListener struct { 54 *mockAddrListener 55 } 56 57 func (m *noPortListener) Addr() net.Addr { 58 _ = m.mockAddrListener.Addr() 59 return &net.UnixAddr{ 60 Net: "unix", 61 Name: "/tmp/quic.sock", 62 } 63 } 64 65 var _ = Describe("Server", func() { 66 var ( 67 s *Server 68 origQuicListenAddr = quicListenAddr 69 ) 70 71 BeforeEach(func() { 72 s = &Server{ 73 TLSConfig: testdata.GetTLSConfig(), 74 logger: utils.DefaultLogger, 75 } 76 origQuicListenAddr = quicListenAddr 77 }) 78 79 AfterEach(func() { 80 quicListenAddr = origQuicListenAddr 81 }) 82 83 Context("handling requests", func() { 84 var ( 85 qpackDecoder *qpack.Decoder 86 str *mockquic.MockStream 87 conn *mockquic.MockEarlyConnection 88 exampleGetRequest *http.Request 89 examplePostRequest *http.Request 90 ) 91 reqContext := context.Background() 92 93 decodeHeader := func(str io.Reader) map[string][]string { 94 fields := make(map[string][]string) 95 decoder := qpack.NewDecoder(nil) 96 97 frame, err := parseNextFrame(str, nil) 98 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 99 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 100 headersFrame := frame.(*headersFrame) 101 data := make([]byte, headersFrame.Length) 102 _, err = io.ReadFull(str, data) 103 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 104 hfs, err := decoder.DecodeFull(data) 105 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 106 for _, p := range hfs { 107 fields[p.Name] = append(fields[p.Name], p.Value) 108 } 109 return fields 110 } 111 112 encodeRequest := func(req *http.Request) []byte { 113 buf := &bytes.Buffer{} 114 str := mockquic.NewMockStream(mockCtrl) 115 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 116 rw := newRequestWriter(utils.DefaultLogger) 117 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 118 return buf.Bytes() 119 } 120 121 setRequest := func(data []byte) { 122 buf := bytes.NewBuffer(data) 123 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 124 if buf.Len() == 0 { 125 return 0, io.EOF 126 } 127 return buf.Read(p) 128 }).AnyTimes() 129 } 130 131 BeforeEach(func() { 132 var err error 133 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 134 Expect(err).ToNot(HaveOccurred()) 135 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 136 Expect(err).ToNot(HaveOccurred()) 137 138 qpackDecoder = qpack.NewDecoder(nil) 139 str = mockquic.NewMockStream(mockCtrl) 140 conn = mockquic.NewMockEarlyConnection(mockCtrl) 141 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 142 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 143 conn.EXPECT().LocalAddr().AnyTimes() 144 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 145 }) 146 147 It("calls the HTTP handler function", func() { 148 requestChan := make(chan *http.Request, 1) 149 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 150 requestChan <- r 151 }) 152 153 setRequest(encodeRequest(exampleGetRequest)) 154 str.EXPECT().Context().Return(reqContext) 155 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 156 return len(p), nil 157 }).AnyTimes() 158 str.EXPECT().CancelRead(gomock.Any()) 159 160 Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) 161 var req *http.Request 162 Eventually(requestChan).Should(Receive(&req)) 163 Expect(req.Host).To(Equal("www.example.com")) 164 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 165 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 166 }) 167 168 It("returns 200 with an empty handler", func() { 169 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 170 171 responseBuf := &bytes.Buffer{} 172 setRequest(encodeRequest(exampleGetRequest)) 173 str.EXPECT().Context().Return(reqContext) 174 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 175 str.EXPECT().CancelRead(gomock.Any()) 176 177 serr := s.handleRequest(conn, str, qpackDecoder, nil) 178 Expect(serr.err).ToNot(HaveOccurred()) 179 hfs := decodeHeader(responseBuf) 180 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 181 }) 182 183 It("handles a aborting handler", func() { 184 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 185 panic(http.ErrAbortHandler) 186 }) 187 188 responseBuf := &bytes.Buffer{} 189 setRequest(encodeRequest(exampleGetRequest)) 190 str.EXPECT().Context().Return(reqContext) 191 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 192 str.EXPECT().CancelRead(gomock.Any()) 193 194 serr := s.handleRequest(conn, str, qpackDecoder, nil) 195 Expect(serr.err).ToNot(HaveOccurred()) 196 Expect(responseBuf.Bytes()).To(HaveLen(0)) 197 }) 198 199 It("handles a panicking handler", func() { 200 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 201 panic("foobar") 202 }) 203 204 responseBuf := &bytes.Buffer{} 205 setRequest(encodeRequest(exampleGetRequest)) 206 str.EXPECT().Context().Return(reqContext) 207 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 208 str.EXPECT().CancelRead(gomock.Any()) 209 210 serr := s.handleRequest(conn, str, qpackDecoder, nil) 211 Expect(serr.err).ToNot(HaveOccurred()) 212 Expect(responseBuf.Bytes()).To(HaveLen(0)) 213 }) 214 215 Context("hijacking bidirectional streams", func() { 216 var conn *mockquic.MockEarlyConnection 217 testDone := make(chan struct{}) 218 219 BeforeEach(func() { 220 testDone = make(chan struct{}) 221 conn = mockquic.NewMockEarlyConnection(mockCtrl) 222 controlStr := mockquic.NewMockStream(mockCtrl) 223 controlStr.EXPECT().Write(gomock.Any()) 224 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 225 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 226 conn.EXPECT().LocalAddr().AnyTimes() 227 }) 228 229 AfterEach(func() { testDone <- struct{}{} }) 230 231 It("hijacks a bidirectional stream of unknown frame type", func() { 232 frameTypeChan := make(chan FrameType, 1) 233 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 234 Expect(e).ToNot(HaveOccurred()) 235 frameTypeChan <- ft 236 return true, nil 237 } 238 239 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 240 unknownStr := mockquic.NewMockStream(mockCtrl) 241 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 242 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 243 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 244 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 245 <-testDone 246 return nil, errors.New("test done") 247 }) 248 s.handleConn(conn) 249 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 250 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 251 }) 252 253 It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { 254 frameTypeChan := make(chan FrameType, 1) 255 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 256 Expect(e).ToNot(HaveOccurred()) 257 frameTypeChan <- ft 258 return false, nil 259 } 260 261 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 262 unknownStr := mockquic.NewMockStream(mockCtrl) 263 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 264 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 265 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 266 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 267 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 268 <-testDone 269 return nil, errors.New("test done") 270 }) 271 s.handleConn(conn) 272 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 273 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 274 }) 275 276 It("cancels writing when hijacker returned error", func() { 277 frameTypeChan := make(chan FrameType, 1) 278 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 279 Expect(e).ToNot(HaveOccurred()) 280 frameTypeChan <- ft 281 return false, errors.New("error in hijacker") 282 } 283 284 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 285 unknownStr := mockquic.NewMockStream(mockCtrl) 286 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 287 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 288 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 289 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 290 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 291 <-testDone 292 return nil, errors.New("test done") 293 }) 294 s.handleConn(conn) 295 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 296 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 297 }) 298 299 It("handles errors that occur when reading the stream type", func() { 300 testErr := errors.New("test error") 301 done := make(chan struct{}) 302 unknownStr := mockquic.NewMockStream(mockCtrl) 303 s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { 304 defer close(done) 305 Expect(ft).To(BeZero()) 306 Expect(str).To(Equal(unknownStr)) 307 Expect(err).To(MatchError(testErr)) 308 return true, nil 309 } 310 311 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 312 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 313 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 314 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 315 <-testDone 316 return nil, errors.New("test done") 317 }) 318 s.handleConn(conn) 319 Eventually(done).Should(BeClosed()) 320 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 321 }) 322 }) 323 324 Context("hijacking unidirectional streams", func() { 325 var conn *mockquic.MockEarlyConnection 326 testDone := make(chan struct{}) 327 328 BeforeEach(func() { 329 testDone = make(chan struct{}) 330 conn = mockquic.NewMockEarlyConnection(mockCtrl) 331 controlStr := mockquic.NewMockStream(mockCtrl) 332 controlStr.EXPECT().Write(gomock.Any()) 333 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 334 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 335 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 336 conn.EXPECT().LocalAddr().AnyTimes() 337 }) 338 339 AfterEach(func() { testDone <- struct{}{} }) 340 341 It("hijacks an unidirectional stream of unknown stream type", func() { 342 streamTypeChan := make(chan StreamType, 1) 343 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 344 Expect(err).ToNot(HaveOccurred()) 345 streamTypeChan <- st 346 return true 347 } 348 349 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 350 unknownStr := mockquic.NewMockStream(mockCtrl) 351 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 352 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 353 return unknownStr, nil 354 }) 355 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 356 <-testDone 357 return nil, errors.New("test done") 358 }) 359 s.handleConn(conn) 360 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 361 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 362 }) 363 364 It("handles errors that occur when reading the stream type", func() { 365 testErr := errors.New("test error") 366 done := make(chan struct{}) 367 unknownStr := mockquic.NewMockStream(mockCtrl) 368 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 369 defer close(done) 370 Expect(st).To(BeZero()) 371 Expect(str).To(Equal(unknownStr)) 372 Expect(err).To(MatchError(testErr)) 373 return true 374 } 375 376 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 377 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 378 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 379 <-testDone 380 return nil, errors.New("test done") 381 }) 382 s.handleConn(conn) 383 Eventually(done).Should(BeClosed()) 384 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 385 }) 386 387 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 388 streamTypeChan := make(chan StreamType, 1) 389 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 390 Expect(err).ToNot(HaveOccurred()) 391 streamTypeChan <- st 392 return false 393 } 394 395 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 396 unknownStr := mockquic.NewMockStream(mockCtrl) 397 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 398 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 399 400 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 401 return unknownStr, nil 402 }) 403 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 404 <-testDone 405 return nil, errors.New("test done") 406 }) 407 s.handleConn(conn) 408 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 409 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 410 }) 411 }) 412 413 Context("control stream handling", func() { 414 var conn *mockquic.MockEarlyConnection 415 testDone := make(chan struct{}) 416 417 BeforeEach(func() { 418 conn = mockquic.NewMockEarlyConnection(mockCtrl) 419 controlStr := mockquic.NewMockStream(mockCtrl) 420 controlStr.EXPECT().Write(gomock.Any()) 421 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 422 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 423 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 424 conn.EXPECT().LocalAddr().AnyTimes() 425 }) 426 427 AfterEach(func() { testDone <- struct{}{} }) 428 429 It("parses the SETTINGS frame", func() { 430 b := quicvarint.Append(nil, streamTypeControlStream) 431 b = (&settingsFrame{}).Append(b) 432 controlStr := mockquic.NewMockStream(mockCtrl) 433 r := bytes.NewReader(b) 434 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 435 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 436 return controlStr, nil 437 }) 438 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 439 <-testDone 440 return nil, errors.New("test done") 441 }) 442 s.handleConn(conn) 443 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 444 }) 445 446 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 447 streamType := t 448 name := "encoder" 449 if streamType == streamTypeQPACKDecoderStream { 450 name = "decoder" 451 } 452 453 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 454 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 455 str := mockquic.NewMockStream(mockCtrl) 456 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 457 458 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 459 return str, nil 460 }) 461 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 462 <-testDone 463 return nil, errors.New("test done") 464 }) 465 s.handleConn(conn) 466 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 467 }) 468 } 469 470 It("reset streams other than the control stream and the QPACK streams", func() { 471 buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337)) 472 str := mockquic.NewMockStream(mockCtrl) 473 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 474 done := make(chan struct{}) 475 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { 476 close(done) 477 }) 478 479 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 480 return str, nil 481 }) 482 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 483 <-testDone 484 return nil, errors.New("test done") 485 }) 486 s.handleConn(conn) 487 Eventually(done).Should(BeClosed()) 488 }) 489 490 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 491 b := quicvarint.Append(nil, streamTypeControlStream) 492 b = (&dataFrame{}).Append(b) 493 controlStr := mockquic.NewMockStream(mockCtrl) 494 r := bytes.NewReader(b) 495 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 496 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 497 return controlStr, nil 498 }) 499 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 500 <-testDone 501 return nil, errors.New("test done") 502 }) 503 done := make(chan struct{}) 504 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 505 defer GinkgoRecover() 506 Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) 507 close(done) 508 }) 509 s.handleConn(conn) 510 Eventually(done).Should(BeClosed()) 511 }) 512 513 It("errors when parsing the frame on the control stream fails", func() { 514 b := quicvarint.Append(nil, streamTypeControlStream) 515 b = (&settingsFrame{}).Append(b) 516 r := bytes.NewReader(b[:len(b)-1]) 517 controlStr := mockquic.NewMockStream(mockCtrl) 518 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 519 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 520 return controlStr, nil 521 }) 522 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 523 <-testDone 524 return nil, errors.New("test done") 525 }) 526 done := make(chan struct{}) 527 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 528 defer GinkgoRecover() 529 Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) 530 close(done) 531 }) 532 s.handleConn(conn) 533 Eventually(done).Should(BeClosed()) 534 }) 535 536 It("errors when the client opens a push stream", func() { 537 b := quicvarint.Append(nil, streamTypePushStream) 538 b = (&dataFrame{}).Append(b) 539 r := bytes.NewReader(b) 540 controlStr := mockquic.NewMockStream(mockCtrl) 541 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 542 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 543 return controlStr, nil 544 }) 545 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 546 <-testDone 547 return nil, errors.New("test done") 548 }) 549 done := make(chan struct{}) 550 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 551 defer GinkgoRecover() 552 Expect(code).To(BeEquivalentTo(ErrCodeStreamCreationError)) 553 close(done) 554 }) 555 s.handleConn(conn) 556 Eventually(done).Should(BeClosed()) 557 }) 558 559 It("errors when the client advertises datagram support (and we enabled support for it)", func() { 560 s.EnableDatagrams = true 561 b := quicvarint.Append(nil, streamTypeControlStream) 562 b = (&settingsFrame{Datagram: true}).Append(b) 563 r := bytes.NewReader(b) 564 controlStr := mockquic.NewMockStream(mockCtrl) 565 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 566 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 567 return controlStr, nil 568 }) 569 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 570 <-testDone 571 return nil, errors.New("test done") 572 }) 573 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 574 done := make(chan struct{}) 575 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { 576 defer GinkgoRecover() 577 Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) 578 Expect(reason).To(Equal("missing QUIC Datagram support")) 579 close(done) 580 }) 581 s.handleConn(conn) 582 Eventually(done).Should(BeClosed()) 583 }) 584 }) 585 586 Context("stream- and connection-level errors", func() { 587 var conn *mockquic.MockEarlyConnection 588 testDone := make(chan struct{}) 589 590 BeforeEach(func() { 591 testDone = make(chan struct{}) 592 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 593 conn = mockquic.NewMockEarlyConnection(mockCtrl) 594 controlStr := mockquic.NewMockStream(mockCtrl) 595 controlStr.EXPECT().Write(gomock.Any()) 596 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 597 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 598 <-testDone 599 return nil, errors.New("test done") 600 }) 601 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 602 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 603 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 604 conn.EXPECT().LocalAddr().AnyTimes() 605 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 606 }) 607 608 AfterEach(func() { testDone <- struct{}{} }) 609 610 It("cancels reading when client sends a body in GET request", func() { 611 var handlerCalled bool 612 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 613 handlerCalled = true 614 }) 615 616 requestData := encodeRequest(exampleGetRequest) 617 b := (&dataFrame{Length: 6}).Append(nil) // add a body 618 b = append(b, []byte("foobar")...) 619 responseBuf := &bytes.Buffer{} 620 setRequest(append(requestData, b...)) 621 done := make(chan struct{}) 622 str.EXPECT().Context().Return(reqContext) 623 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 624 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 625 str.EXPECT().Close().Do(func() { close(done) }) 626 627 s.handleConn(conn) 628 Eventually(done).Should(BeClosed()) 629 hfs := decodeHeader(responseBuf) 630 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 631 Expect(handlerCalled).To(BeTrue()) 632 }) 633 634 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 635 handlerCalled := make(chan struct{}) 636 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 637 defer close(handlerCalled) 638 r.Body.(HTTPStreamer).HTTPStream() 639 str.Write([]byte("foobar")) 640 }) 641 642 requestData := encodeRequest(exampleGetRequest) 643 b := (&dataFrame{Length: 6}).Append(nil) // add a body 644 b = append(b, []byte("foobar")...) 645 setRequest(append(requestData, b...)) 646 str.EXPECT().Context().Return(reqContext) 647 str.EXPECT().Write([]byte("foobar")).Return(6, nil) 648 649 s.handleConn(conn) 650 Eventually(handlerCalled).Should(BeClosed()) 651 }) 652 653 It("errors when the client sends a too large header frame", func() { 654 s.MaxHeaderBytes = 20 655 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 656 Fail("Handler should not be called.") 657 }) 658 659 requestData := encodeRequest(exampleGetRequest) 660 b := (&dataFrame{Length: 6}).Append(nil) // add a body 661 b = append(b, []byte("foobar")...) 662 responseBuf := &bytes.Buffer{} 663 setRequest(append(requestData, b...)) 664 done := make(chan struct{}) 665 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 666 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 667 668 s.handleConn(conn) 669 Eventually(done).Should(BeClosed()) 670 }) 671 672 It("handles a request for which the client immediately resets the stream", func() { 673 handlerCalled := make(chan struct{}) 674 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 675 close(handlerCalled) 676 }) 677 678 testErr := errors.New("stream reset") 679 done := make(chan struct{}) 680 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 681 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 682 683 s.handleConn(conn) 684 Consistently(handlerCalled).ShouldNot(BeClosed()) 685 }) 686 687 It("closes the connection when the first frame is not a HEADERS frame", func() { 688 handlerCalled := make(chan struct{}) 689 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 690 close(handlerCalled) 691 }) 692 693 b := (&dataFrame{}).Append(nil) 694 setRequest(b) 695 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 696 return len(p), nil 697 }).AnyTimes() 698 699 done := make(chan struct{}) 700 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 701 Expect(code).To(Equal(quic.ApplicationErrorCode(ErrCodeFrameUnexpected))) 702 close(done) 703 }) 704 s.handleConn(conn) 705 Eventually(done).Should(BeClosed()) 706 }) 707 708 It("closes the connection when the first frame is not a HEADERS frame", func() { 709 handlerCalled := make(chan struct{}) 710 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 711 close(handlerCalled) 712 }) 713 714 // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, 715 // but the request will still end up larger than DefaultMaxHeaderBytes. 716 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 717 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 718 Expect(err).ToNot(HaveOccurred()) 719 setRequest(encodeRequest(req)) 720 // str.EXPECT().Context().Return(reqContext) 721 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 722 return len(p), nil 723 }).AnyTimes() 724 done := make(chan struct{}) 725 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 726 727 s.handleConn(conn) 728 Eventually(done).Should(BeClosed()) 729 }) 730 }) 731 732 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 733 handlerCalled := make(chan struct{}) 734 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 735 r.Body = struct { 736 io.Reader 737 io.Closer 738 }{} 739 close(handlerCalled) 740 }) 741 742 setRequest(encodeRequest(examplePostRequest)) 743 str.EXPECT().Context().Return(reqContext) 744 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 745 return len(p), nil 746 }).AnyTimes() 747 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 748 749 serr := s.handleRequest(conn, str, qpackDecoder, nil) 750 Expect(serr.err).ToNot(HaveOccurred()) 751 Eventually(handlerCalled).Should(BeClosed()) 752 }) 753 754 It("cancels the request context when the stream is closed", func() { 755 handlerCalled := make(chan struct{}) 756 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 757 defer GinkgoRecover() 758 Expect(r.Context().Done()).To(BeClosed()) 759 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 760 close(handlerCalled) 761 }) 762 setRequest(encodeRequest(examplePostRequest)) 763 764 reqContext, cancel := context.WithCancel(context.Background()) 765 cancel() 766 str.EXPECT().Context().Return(reqContext) 767 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 768 return len(p), nil 769 }).AnyTimes() 770 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 771 772 serr := s.handleRequest(conn, str, qpackDecoder, nil) 773 Expect(serr.err).ToNot(HaveOccurred()) 774 Eventually(handlerCalled).Should(BeClosed()) 775 }) 776 }) 777 778 Context("setting http headers", func() { 779 BeforeEach(func() { 780 s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}} 781 }) 782 783 var ln1 QUICEarlyListener 784 var ln2 QUICEarlyListener 785 expected := http.Header{ 786 "Alt-Svc": {`h3=":443"; ma=2592000`}, 787 } 788 789 addListener := func(addr string, ln *QUICEarlyListener) { 790 mln := newMockAddrListener(addr) 791 mln.EXPECT().Addr() 792 *ln = mln 793 s.addListener(ln) 794 } 795 796 removeListener := func(ln *QUICEarlyListener) { 797 s.removeListener(ln) 798 } 799 800 checkSetHeaders := func(expected gmtypes.GomegaMatcher) { 801 hdr := http.Header{} 802 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 803 Expect(hdr).To(expected) 804 } 805 806 checkSetHeaderError := func() { 807 hdr := http.Header{} 808 Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) 809 } 810 811 It("sets proper headers with numeric port", func() { 812 addListener(":443", &ln1) 813 checkSetHeaders(Equal(expected)) 814 removeListener(&ln1) 815 checkSetHeaderError() 816 }) 817 818 It("sets proper headers with full addr", func() { 819 addListener("127.0.0.1:443", &ln1) 820 checkSetHeaders(Equal(expected)) 821 removeListener(&ln1) 822 checkSetHeaderError() 823 }) 824 825 It("sets proper headers with string port", func() { 826 addListener(":https", &ln1) 827 checkSetHeaders(Equal(expected)) 828 removeListener(&ln1) 829 checkSetHeaderError() 830 }) 831 832 It("works multiple times", func() { 833 addListener(":https", &ln1) 834 checkSetHeaders(Equal(expected)) 835 checkSetHeaders(Equal(expected)) 836 removeListener(&ln1) 837 checkSetHeaderError() 838 }) 839 840 It("works if the quic.Config sets QUIC versions", func() { 841 s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version2} 842 addListener(":443", &ln1) 843 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 844 removeListener(&ln1) 845 checkSetHeaderError() 846 }) 847 848 It("uses s.Port if set to a non-zero value", func() { 849 s.Port = 8443 850 addListener(":443", &ln1) 851 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}})) 852 removeListener(&ln1) 853 checkSetHeaderError() 854 }) 855 856 It("uses s.Addr if listeners don't have ports available", func() { 857 s.Addr = ":443" 858 mln := &noPortListener{newMockAddrListener("")} 859 mln.EXPECT().Addr() 860 ln1 = mln 861 s.addListener(&ln1) 862 checkSetHeaders(Equal(expected)) 863 s.removeListener(&ln1) 864 checkSetHeaderError() 865 }) 866 867 It("properly announces multiple listeners", func() { 868 addListener(":443", &ln1) 869 addListener(":8443", &ln2) 870 checkSetHeaders(Or( 871 Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}), 872 Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}), 873 )) 874 removeListener(&ln1) 875 removeListener(&ln2) 876 checkSetHeaderError() 877 }) 878 879 It("doesn't duplicate Alt-Svc values", func() { 880 s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version1} 881 addListener(":443", &ln1) 882 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 883 removeListener(&ln1) 884 checkSetHeaderError() 885 }) 886 }) 887 888 It("errors when ListenAndServe is called with s.TLSConfig nil", func() { 889 Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) 890 }) 891 892 It("should nop-Close() when s.server is nil", func() { 893 Expect((&Server{}).Close()).To(Succeed()) 894 }) 895 896 It("errors when ListenAndServeTLS is called after Close", func() { 897 serv := &Server{} 898 Expect(serv.Close()).To(Succeed()) 899 Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) 900 }) 901 902 It("handles concurrent Serve and Close", func() { 903 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 904 Expect(err).ToNot(HaveOccurred()) 905 c, err := net.ListenUDP("udp", addr) 906 Expect(err).ToNot(HaveOccurred()) 907 done := make(chan struct{}) 908 go func() { 909 defer GinkgoRecover() 910 defer close(done) 911 s.Serve(c) 912 }() 913 runtime.Gosched() 914 s.Close() 915 Eventually(done).Should(BeClosed()) 916 }) 917 918 Context("ConfigureTLSConfig", func() { 919 It("advertises v1 by default", func() { 920 conf := ConfigureTLSConfig(testdata.GetTLSConfig()) 921 ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 922 Expect(err).ToNot(HaveOccurred()) 923 defer ln.Close() 924 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 925 Expect(err).ToNot(HaveOccurred()) 926 defer c.CloseWithError(0, "") 927 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 928 }) 929 930 It("sets the GetConfigForClient callback if no tls.Config is given", func() { 931 var receivedConf *tls.Config 932 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) { 933 receivedConf = tlsConf 934 return nil, errors.New("listen err") 935 } 936 Expect(s.ListenAndServe()).To(HaveOccurred()) 937 Expect(receivedConf).ToNot(BeNil()) 938 }) 939 940 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { 941 tlsConf := &tls.Config{ 942 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 943 c := testdata.GetTLSConfig() 944 c.NextProtos = []string{"foo", "bar"} 945 return c, nil 946 }, 947 } 948 949 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 950 Expect(err).ToNot(HaveOccurred()) 951 defer ln.Close() 952 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 953 Expect(err).ToNot(HaveOccurred()) 954 defer c.CloseWithError(0, "") 955 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 956 }) 957 958 It("works if GetConfigForClient returns a nil tls.Config", func() { 959 tlsConf := testdata.GetTLSConfig() 960 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } 961 962 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 963 Expect(err).ToNot(HaveOccurred()) 964 defer ln.Close() 965 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 966 Expect(err).ToNot(HaveOccurred()) 967 defer c.CloseWithError(0, "") 968 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 969 }) 970 971 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { 972 tlsClientConf := testdata.GetTLSConfig() 973 tlsClientConf.NextProtos = []string{"foo", "bar"} 974 tlsConf := &tls.Config{ 975 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 976 return tlsClientConf, nil 977 }, 978 } 979 980 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 981 Expect(err).ToNot(HaveOccurred()) 982 defer ln.Close() 983 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 984 Expect(err).ToNot(HaveOccurred()) 985 defer c.CloseWithError(0, "") 986 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 987 // check that the original config was not modified 988 Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) 989 }) 990 }) 991 992 Context("Serve", func() { 993 origQuicListen := quicListen 994 995 AfterEach(func() { 996 quicListen = origQuicListen 997 }) 998 999 It("serves a packet conn", func() { 1000 ln := newMockAddrListener(":443") 1001 conn := &net.UDPConn{} 1002 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1003 Expect(c).To(Equal(conn)) 1004 return ln, nil 1005 } 1006 1007 s := &Server{ 1008 TLSConfig: &tls.Config{}, 1009 } 1010 1011 stopAccept := make(chan struct{}) 1012 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1013 <-stopAccept 1014 return nil, errors.New("closed") 1015 }) 1016 ln.EXPECT().Addr() // generate alt-svc headers 1017 done := make(chan struct{}) 1018 go func() { 1019 defer GinkgoRecover() 1020 defer close(done) 1021 s.Serve(conn) 1022 }() 1023 1024 Consistently(done).ShouldNot(BeClosed()) 1025 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 1026 Expect(s.Close()).To(Succeed()) 1027 Eventually(done).Should(BeClosed()) 1028 }) 1029 1030 It("serves two packet conns", func() { 1031 ln1 := newMockAddrListener(":443") 1032 ln2 := newMockAddrListener(":8443") 1033 lns := make(chan QUICEarlyListener, 2) 1034 lns <- ln1 1035 lns <- ln2 1036 conn1 := &net.UDPConn{} 1037 conn2 := &net.UDPConn{} 1038 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1039 return <-lns, nil 1040 } 1041 1042 s := &Server{ 1043 TLSConfig: &tls.Config{}, 1044 } 1045 1046 stopAccept1 := make(chan struct{}) 1047 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1048 <-stopAccept1 1049 return nil, errors.New("closed") 1050 }) 1051 ln1.EXPECT().Addr() // generate alt-svc headers 1052 stopAccept2 := make(chan struct{}) 1053 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1054 <-stopAccept2 1055 return nil, errors.New("closed") 1056 }) 1057 ln2.EXPECT().Addr() 1058 1059 done1 := make(chan struct{}) 1060 go func() { 1061 defer GinkgoRecover() 1062 defer close(done1) 1063 s.Serve(conn1) 1064 }() 1065 done2 := make(chan struct{}) 1066 go func() { 1067 defer GinkgoRecover() 1068 defer close(done2) 1069 s.Serve(conn2) 1070 }() 1071 1072 Consistently(done1).ShouldNot(BeClosed()) 1073 Expect(done2).ToNot(BeClosed()) 1074 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 1075 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 1076 Expect(s.Close()).To(Succeed()) 1077 Eventually(done1).Should(BeClosed()) 1078 Eventually(done2).Should(BeClosed()) 1079 }) 1080 }) 1081 1082 Context("ServeListener", func() { 1083 origQuicListen := quicListen 1084 1085 AfterEach(func() { 1086 quicListen = origQuicListen 1087 }) 1088 1089 It("serves a listener", func() { 1090 var called int32 1091 ln := newMockAddrListener(":443") 1092 quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1093 atomic.StoreInt32(&called, 1) 1094 return ln, nil 1095 } 1096 1097 s := &Server{} 1098 1099 stopAccept := make(chan struct{}) 1100 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1101 <-stopAccept 1102 return nil, errors.New("closed") 1103 }) 1104 ln.EXPECT().Addr() // generate alt-svc headers 1105 done := make(chan struct{}) 1106 go func() { 1107 defer GinkgoRecover() 1108 defer close(done) 1109 s.ServeListener(ln) 1110 }() 1111 1112 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1113 Consistently(done).ShouldNot(BeClosed()) 1114 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 1115 Expect(s.Close()).To(Succeed()) 1116 Eventually(done).Should(BeClosed()) 1117 }) 1118 1119 It("serves two listeners", func() { 1120 var called int32 1121 ln1 := newMockAddrListener(":443") 1122 ln2 := newMockAddrListener(":8443") 1123 lns := make(chan QUICEarlyListener, 2) 1124 lns <- ln1 1125 lns <- ln2 1126 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1127 atomic.StoreInt32(&called, 1) 1128 return <-lns, nil 1129 } 1130 1131 s := &Server{} 1132 1133 stopAccept1 := make(chan struct{}) 1134 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1135 <-stopAccept1 1136 return nil, errors.New("closed") 1137 }) 1138 ln1.EXPECT().Addr() // generate alt-svc headers 1139 stopAccept2 := make(chan struct{}) 1140 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1141 <-stopAccept2 1142 return nil, errors.New("closed") 1143 }) 1144 ln2.EXPECT().Addr() 1145 1146 done1 := make(chan struct{}) 1147 go func() { 1148 defer GinkgoRecover() 1149 defer close(done1) 1150 s.ServeListener(ln1) 1151 }() 1152 done2 := make(chan struct{}) 1153 go func() { 1154 defer GinkgoRecover() 1155 defer close(done2) 1156 s.ServeListener(ln2) 1157 }() 1158 1159 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1160 Consistently(done1).ShouldNot(BeClosed()) 1161 Expect(done2).ToNot(BeClosed()) 1162 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 1163 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 1164 Expect(s.Close()).To(Succeed()) 1165 Eventually(done1).Should(BeClosed()) 1166 Eventually(done2).Should(BeClosed()) 1167 }) 1168 }) 1169 1170 Context("ServeQUICConn", func() { 1171 It("serves a QUIC connection", func() { 1172 mux := http.NewServeMux() 1173 mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { 1174 w.Write([]byte("foobar")) 1175 }) 1176 s.Handler = mux 1177 tlsConf := testdata.GetTLSConfig() 1178 tlsConf.NextProtos = []string{NextProtoH3} 1179 conn := mockquic.NewMockEarlyConnection(mockCtrl) 1180 controlStr := mockquic.NewMockStream(mockCtrl) 1181 controlStr.EXPECT().Write(gomock.Any()) 1182 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 1183 testDone := make(chan struct{}) 1184 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 1185 <-testDone 1186 return nil, errors.New("test done") 1187 }).MaxTimes(1) 1188 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}) 1189 s.ServeQUICConn(conn) 1190 close(testDone) 1191 }) 1192 }) 1193 1194 Context("ListenAndServe", func() { 1195 BeforeEach(func() { 1196 s.Addr = "localhost:0" 1197 }) 1198 1199 AfterEach(func() { 1200 Expect(s.Close()).To(Succeed()) 1201 }) 1202 1203 It("uses the quic.Config to start the QUIC server", func() { 1204 conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} 1205 var receivedConf *quic.Config 1206 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1207 receivedConf = config 1208 return nil, errors.New("listen err") 1209 } 1210 s.QuicConfig = conf 1211 Expect(s.ListenAndServe()).To(HaveOccurred()) 1212 Expect(receivedConf).To(Equal(conf)) 1213 }) 1214 }) 1215 1216 It("closes gracefully", func() { 1217 Expect(s.CloseGracefully(0)).To(Succeed()) 1218 }) 1219 1220 It("errors when listening fails", func() { 1221 testErr := errors.New("listen error") 1222 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1223 return nil, testErr 1224 } 1225 fullpem, privkey := testdata.GetCertificatePaths() 1226 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 1227 }) 1228 1229 It("supports H3_DATAGRAM", func() { 1230 s.EnableDatagrams = true 1231 var receivedConf *quic.Config 1232 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1233 receivedConf = config 1234 return nil, errors.New("listen err") 1235 } 1236 Expect(s.ListenAndServe()).To(HaveOccurred()) 1237 Expect(receivedConf.EnableDatagrams).To(BeTrue()) 1238 }) 1239 })