github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/mikelsr/quic-go" 17 mockquic "github.com/mikelsr/quic-go/internal/mocks/quic" 18 "github.com/mikelsr/quic-go/internal/protocol" 19 "github.com/mikelsr/quic-go/internal/testdata" 20 "github.com/mikelsr/quic-go/internal/utils" 21 "github.com/mikelsr/quic-go/quicvarint" 22 23 "github.com/golang/mock/gomock" 24 "github.com/quic-go/qpack" 25 26 . "github.com/onsi/ginkgo/v2" 27 . "github.com/onsi/gomega" 28 gmtypes "github.com/onsi/gomega/types" 29 ) 30 31 type mockAddr struct{ addr string } 32 33 func (ma *mockAddr) Network() string { return "udp" } 34 func (ma *mockAddr) String() string { return ma.addr } 35 36 type mockAddrListener struct { 37 *MockQUICEarlyListener 38 addr *mockAddr 39 } 40 41 func (m *mockAddrListener) Addr() net.Addr { 42 _ = m.MockQUICEarlyListener.Addr() 43 return m.addr 44 } 45 46 func newMockAddrListener(addr string) *mockAddrListener { 47 return &mockAddrListener{ 48 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 49 addr: &mockAddr{addr: addr}, 50 } 51 } 52 53 type noPortListener struct { 54 *mockAddrListener 55 } 56 57 func (m *noPortListener) Addr() net.Addr { 58 _ = m.mockAddrListener.Addr() 59 return &net.UnixAddr{ 60 Net: "unix", 61 Name: "/tmp/quic.sock", 62 } 63 } 64 65 var _ = Describe("Server", func() { 66 var ( 67 s *Server 68 origQuicListenAddr = quicListenAddr 69 ) 70 71 BeforeEach(func() { 72 s = &Server{ 73 TLSConfig: testdata.GetTLSConfig(), 74 logger: utils.DefaultLogger, 75 } 76 origQuicListenAddr = quicListenAddr 77 }) 78 79 AfterEach(func() { 80 quicListenAddr = origQuicListenAddr 81 }) 82 83 Context("handling requests", func() { 84 var ( 85 qpackDecoder *qpack.Decoder 86 str *mockquic.MockStream 87 conn *mockquic.MockEarlyConnection 88 
exampleGetRequest *http.Request 89 examplePostRequest *http.Request 90 ) 91 reqContext := context.Background() 92 93 decodeHeader := func(str io.Reader) map[string][]string { 94 fields := make(map[string][]string) 95 decoder := qpack.NewDecoder(nil) 96 97 frame, err := parseNextFrame(str, nil) 98 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 99 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 100 headersFrame := frame.(*headersFrame) 101 data := make([]byte, headersFrame.Length) 102 _, err = io.ReadFull(str, data) 103 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 104 hfs, err := decoder.DecodeFull(data) 105 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 106 for _, p := range hfs { 107 fields[p.Name] = append(fields[p.Name], p.Value) 108 } 109 return fields 110 } 111 112 encodeRequest := func(req *http.Request) []byte { 113 buf := &bytes.Buffer{} 114 str := mockquic.NewMockStream(mockCtrl) 115 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 116 rw := newRequestWriter(utils.DefaultLogger) 117 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 118 return buf.Bytes() 119 } 120 121 setRequest := func(data []byte) { 122 buf := bytes.NewBuffer(data) 123 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 124 if buf.Len() == 0 { 125 return 0, io.EOF 126 } 127 return buf.Read(p) 128 }).AnyTimes() 129 } 130 131 BeforeEach(func() { 132 var err error 133 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 134 Expect(err).ToNot(HaveOccurred()) 135 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 136 Expect(err).ToNot(HaveOccurred()) 137 138 qpackDecoder = qpack.NewDecoder(nil) 139 str = mockquic.NewMockStream(mockCtrl) 140 conn = mockquic.NewMockEarlyConnection(mockCtrl) 141 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 142 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 143 
conn.EXPECT().LocalAddr().AnyTimes() 144 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 145 }) 146 147 It("calls the HTTP handler function", func() { 148 requestChan := make(chan *http.Request, 1) 149 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 150 requestChan <- r 151 }) 152 153 setRequest(encodeRequest(exampleGetRequest)) 154 str.EXPECT().Context().Return(reqContext) 155 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 156 return len(p), nil 157 }).AnyTimes() 158 str.EXPECT().CancelRead(gomock.Any()) 159 160 Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) 161 var req *http.Request 162 Eventually(requestChan).Should(Receive(&req)) 163 Expect(req.Host).To(Equal("www.example.com")) 164 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 165 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 166 }) 167 168 It("returns 200 with an empty handler", func() { 169 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 170 171 responseBuf := &bytes.Buffer{} 172 setRequest(encodeRequest(exampleGetRequest)) 173 str.EXPECT().Context().Return(reqContext) 174 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 175 str.EXPECT().CancelRead(gomock.Any()) 176 177 serr := s.handleRequest(conn, str, qpackDecoder, nil) 178 Expect(serr.err).ToNot(HaveOccurred()) 179 hfs := decodeHeader(responseBuf) 180 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 181 }) 182 183 It("handles a aborting handler", func() { 184 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 185 panic(http.ErrAbortHandler) 186 }) 187 188 responseBuf := &bytes.Buffer{} 189 setRequest(encodeRequest(exampleGetRequest)) 190 str.EXPECT().Context().Return(reqContext) 191 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 192 str.EXPECT().CancelRead(gomock.Any()) 193 194 serr := 
s.handleRequest(conn, str, qpackDecoder, nil) 195 Expect(serr.err).ToNot(HaveOccurred()) 196 hfs := decodeHeader(responseBuf) 197 Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) 198 }) 199 200 It("handles a panicking handler", func() { 201 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 202 panic("foobar") 203 }) 204 205 responseBuf := &bytes.Buffer{} 206 setRequest(encodeRequest(exampleGetRequest)) 207 str.EXPECT().Context().Return(reqContext) 208 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 209 str.EXPECT().CancelRead(gomock.Any()) 210 211 serr := s.handleRequest(conn, str, qpackDecoder, nil) 212 Expect(serr.err).ToNot(HaveOccurred()) 213 hfs := decodeHeader(responseBuf) 214 Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) 215 }) 216 217 Context("hijacking bidirectional streams", func() { 218 var conn *mockquic.MockEarlyConnection 219 testDone := make(chan struct{}) 220 221 BeforeEach(func() { 222 testDone = make(chan struct{}) 223 conn = mockquic.NewMockEarlyConnection(mockCtrl) 224 controlStr := mockquic.NewMockStream(mockCtrl) 225 controlStr.EXPECT().Write(gomock.Any()) 226 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 227 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 228 conn.EXPECT().LocalAddr().AnyTimes() 229 }) 230 231 AfterEach(func() { testDone <- struct{}{} }) 232 233 It("hijacks a bidirectional stream of unknown frame type", func() { 234 frameTypeChan := make(chan FrameType, 1) 235 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 236 Expect(e).ToNot(HaveOccurred()) 237 frameTypeChan <- ft 238 return true, nil 239 } 240 241 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 242 unknownStr := mockquic.NewMockStream(mockCtrl) 243 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 244 
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 245 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 246 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 247 <-testDone 248 return nil, errors.New("test done") 249 }) 250 s.handleConn(conn) 251 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 252 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 253 }) 254 255 It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { 256 frameTypeChan := make(chan FrameType, 1) 257 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 258 Expect(e).ToNot(HaveOccurred()) 259 frameTypeChan <- ft 260 return false, nil 261 } 262 263 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 264 unknownStr := mockquic.NewMockStream(mockCtrl) 265 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 266 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 267 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 268 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 269 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 270 <-testDone 271 return nil, errors.New("test done") 272 }) 273 s.handleConn(conn) 274 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 275 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 276 }) 277 278 It("cancels writing when hijacker returned error", func() { 279 frameTypeChan := make(chan FrameType, 1) 280 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 281 Expect(e).ToNot(HaveOccurred()) 282 frameTypeChan <- ft 283 return false, errors.New("error in hijacker") 284 } 285 286 buf 
:= bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 287 unknownStr := mockquic.NewMockStream(mockCtrl) 288 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 289 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 290 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 291 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 292 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 293 <-testDone 294 return nil, errors.New("test done") 295 }) 296 s.handleConn(conn) 297 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 298 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 299 }) 300 301 It("handles errors that occur when reading the stream type", func() { 302 testErr := errors.New("test error") 303 done := make(chan struct{}) 304 unknownStr := mockquic.NewMockStream(mockCtrl) 305 s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { 306 defer close(done) 307 Expect(ft).To(BeZero()) 308 Expect(str).To(Equal(unknownStr)) 309 Expect(err).To(MatchError(testErr)) 310 return true, nil 311 } 312 313 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 314 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 315 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 316 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 317 <-testDone 318 return nil, errors.New("test done") 319 }) 320 s.handleConn(conn) 321 Eventually(done).Should(BeClosed()) 322 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 323 }) 324 }) 325 326 Context("hijacking unidirectional streams", func() { 327 var conn *mockquic.MockEarlyConnection 328 testDone := make(chan struct{}) 329 330 BeforeEach(func() { 331 testDone = 
make(chan struct{}) 332 conn = mockquic.NewMockEarlyConnection(mockCtrl) 333 controlStr := mockquic.NewMockStream(mockCtrl) 334 controlStr.EXPECT().Write(gomock.Any()) 335 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 336 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 337 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 338 conn.EXPECT().LocalAddr().AnyTimes() 339 }) 340 341 AfterEach(func() { testDone <- struct{}{} }) 342 343 It("hijacks an unidirectional stream of unknown stream type", func() { 344 streamTypeChan := make(chan StreamType, 1) 345 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 346 Expect(err).ToNot(HaveOccurred()) 347 streamTypeChan <- st 348 return true 349 } 350 351 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 352 unknownStr := mockquic.NewMockStream(mockCtrl) 353 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 354 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 355 return unknownStr, nil 356 }) 357 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 358 <-testDone 359 return nil, errors.New("test done") 360 }) 361 s.handleConn(conn) 362 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 363 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 364 }) 365 366 It("handles errors that occur when reading the stream type", func() { 367 testErr := errors.New("test error") 368 done := make(chan struct{}) 369 unknownStr := mockquic.NewMockStream(mockCtrl) 370 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 371 defer close(done) 372 Expect(st).To(BeZero()) 373 Expect(str).To(Equal(unknownStr)) 374 Expect(err).To(MatchError(testErr)) 375 return true 376 } 377 378 
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 379 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 380 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 381 <-testDone 382 return nil, errors.New("test done") 383 }) 384 s.handleConn(conn) 385 Eventually(done).Should(BeClosed()) 386 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 387 }) 388 389 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 390 streamTypeChan := make(chan StreamType, 1) 391 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 392 Expect(err).ToNot(HaveOccurred()) 393 streamTypeChan <- st 394 return false 395 } 396 397 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 398 unknownStr := mockquic.NewMockStream(mockCtrl) 399 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 400 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 401 402 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 403 return unknownStr, nil 404 }) 405 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 406 <-testDone 407 return nil, errors.New("test done") 408 }) 409 s.handleConn(conn) 410 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 411 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 412 }) 413 }) 414 415 Context("control stream handling", func() { 416 var conn *mockquic.MockEarlyConnection 417 testDone := make(chan struct{}) 418 419 BeforeEach(func() { 420 conn = mockquic.NewMockEarlyConnection(mockCtrl) 421 controlStr := mockquic.NewMockStream(mockCtrl) 422 controlStr.EXPECT().Write(gomock.Any()) 423 
conn.EXPECT().OpenUniStream().Return(controlStr, nil) 424 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 425 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 426 conn.EXPECT().LocalAddr().AnyTimes() 427 }) 428 429 AfterEach(func() { testDone <- struct{}{} }) 430 431 It("parses the SETTINGS frame", func() { 432 b := quicvarint.Append(nil, streamTypeControlStream) 433 b = (&settingsFrame{}).Append(b) 434 controlStr := mockquic.NewMockStream(mockCtrl) 435 r := bytes.NewReader(b) 436 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 437 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 438 return controlStr, nil 439 }) 440 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 441 <-testDone 442 return nil, errors.New("test done") 443 }) 444 s.handleConn(conn) 445 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 446 }) 447 448 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 449 streamType := t 450 name := "encoder" 451 if streamType == streamTypeQPACKDecoderStream { 452 name = "decoder" 453 } 454 455 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 456 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 457 str := mockquic.NewMockStream(mockCtrl) 458 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 459 460 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 461 return str, nil 462 }) 463 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 464 <-testDone 465 return nil, errors.New("test done") 466 }) 467 s.handleConn(conn) 468 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 469 }) 470 } 
471 472 It("reset streams other than the control stream and the QPACK streams", func() { 473 buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337)) 474 str := mockquic.NewMockStream(mockCtrl) 475 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 476 done := make(chan struct{}) 477 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { 478 close(done) 479 }) 480 481 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 482 return str, nil 483 }) 484 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 485 <-testDone 486 return nil, errors.New("test done") 487 }) 488 s.handleConn(conn) 489 Eventually(done).Should(BeClosed()) 490 }) 491 492 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 493 b := quicvarint.Append(nil, streamTypeControlStream) 494 b = (&dataFrame{}).Append(b) 495 controlStr := mockquic.NewMockStream(mockCtrl) 496 r := bytes.NewReader(b) 497 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 498 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 499 return controlStr, nil 500 }) 501 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 502 <-testDone 503 return nil, errors.New("test done") 504 }) 505 done := make(chan struct{}) 506 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 507 defer GinkgoRecover() 508 Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) 509 close(done) 510 }) 511 s.handleConn(conn) 512 Eventually(done).Should(BeClosed()) 513 }) 514 515 It("errors when parsing the frame on the control stream fails", func() { 516 b := quicvarint.Append(nil, streamTypeControlStream) 517 b = (&settingsFrame{}).Append(b) 518 r := 
bytes.NewReader(b[:len(b)-1]) 519 controlStr := mockquic.NewMockStream(mockCtrl) 520 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 521 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 522 return controlStr, nil 523 }) 524 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 525 <-testDone 526 return nil, errors.New("test done") 527 }) 528 done := make(chan struct{}) 529 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 530 defer GinkgoRecover() 531 Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) 532 close(done) 533 }) 534 s.handleConn(conn) 535 Eventually(done).Should(BeClosed()) 536 }) 537 538 It("errors when the client opens a push stream", func() { 539 b := quicvarint.Append(nil, streamTypePushStream) 540 b = (&dataFrame{}).Append(b) 541 r := bytes.NewReader(b) 542 controlStr := mockquic.NewMockStream(mockCtrl) 543 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 544 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 545 return controlStr, nil 546 }) 547 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 548 <-testDone 549 return nil, errors.New("test done") 550 }) 551 done := make(chan struct{}) 552 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 553 defer GinkgoRecover() 554 Expect(code).To(BeEquivalentTo(ErrCodeStreamCreationError)) 555 close(done) 556 }) 557 s.handleConn(conn) 558 Eventually(done).Should(BeClosed()) 559 }) 560 561 It("errors when the client advertises datagram support (and we enabled support for it)", func() { 562 s.EnableDatagrams = true 563 b := quicvarint.Append(nil, streamTypeControlStream) 564 b = (&settingsFrame{Datagram: true}).Append(b) 565 r := 
bytes.NewReader(b) 566 controlStr := mockquic.NewMockStream(mockCtrl) 567 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 568 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 569 return controlStr, nil 570 }) 571 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 572 <-testDone 573 return nil, errors.New("test done") 574 }) 575 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 576 done := make(chan struct{}) 577 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { 578 defer GinkgoRecover() 579 Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) 580 Expect(reason).To(Equal("missing QUIC Datagram support")) 581 close(done) 582 }) 583 s.handleConn(conn) 584 Eventually(done).Should(BeClosed()) 585 }) 586 }) 587 588 Context("stream- and connection-level errors", func() { 589 var conn *mockquic.MockEarlyConnection 590 testDone := make(chan struct{}) 591 592 BeforeEach(func() { 593 testDone = make(chan struct{}) 594 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 595 conn = mockquic.NewMockEarlyConnection(mockCtrl) 596 controlStr := mockquic.NewMockStream(mockCtrl) 597 controlStr.EXPECT().Write(gomock.Any()) 598 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 599 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 600 <-testDone 601 return nil, errors.New("test done") 602 }) 603 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 604 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 605 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 606 conn.EXPECT().LocalAddr().AnyTimes() 607 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 608 }) 609 610 AfterEach(func() { testDone <- struct{}{} }) 611 612 It("cancels 
reading when client sends a body in GET request", func() { 613 var handlerCalled bool 614 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 615 handlerCalled = true 616 }) 617 618 requestData := encodeRequest(exampleGetRequest) 619 b := (&dataFrame{Length: 6}).Append(nil) // add a body 620 b = append(b, []byte("foobar")...) 621 responseBuf := &bytes.Buffer{} 622 setRequest(append(requestData, b...)) 623 done := make(chan struct{}) 624 str.EXPECT().Context().Return(reqContext) 625 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 626 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 627 str.EXPECT().Close().Do(func() { close(done) }) 628 629 s.handleConn(conn) 630 Eventually(done).Should(BeClosed()) 631 hfs := decodeHeader(responseBuf) 632 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 633 Expect(handlerCalled).To(BeTrue()) 634 }) 635 636 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 637 handlerCalled := make(chan struct{}) 638 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 639 defer close(handlerCalled) 640 r.Body.(HTTPStreamer).HTTPStream() 641 str.Write([]byte("foobar")) 642 }) 643 644 requestData := encodeRequest(exampleGetRequest) 645 b := (&dataFrame{Length: 6}).Append(nil) // add a body 646 b = append(b, []byte("foobar")...) 
647 setRequest(append(requestData, b...)) 648 str.EXPECT().Context().Return(reqContext) 649 str.EXPECT().Write([]byte("foobar")).Return(6, nil) 650 651 s.handleConn(conn) 652 Eventually(handlerCalled).Should(BeClosed()) 653 }) 654 655 It("errors when the client sends a too large header frame", func() { 656 s.MaxHeaderBytes = 20 657 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 658 Fail("Handler should not be called.") 659 }) 660 661 requestData := encodeRequest(exampleGetRequest) 662 b := (&dataFrame{Length: 6}).Append(nil) // add a body 663 b = append(b, []byte("foobar")...) 664 responseBuf := &bytes.Buffer{} 665 setRequest(append(requestData, b...)) 666 done := make(chan struct{}) 667 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 668 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 669 670 s.handleConn(conn) 671 Eventually(done).Should(BeClosed()) 672 }) 673 674 It("handles a request for which the client immediately resets the stream", func() { 675 handlerCalled := make(chan struct{}) 676 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 677 close(handlerCalled) 678 }) 679 680 testErr := errors.New("stream reset") 681 done := make(chan struct{}) 682 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 683 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 684 685 s.handleConn(conn) 686 Consistently(handlerCalled).ShouldNot(BeClosed()) 687 }) 688 689 It("closes the connection when the first frame is not a HEADERS frame", func() { 690 handlerCalled := make(chan struct{}) 691 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 692 close(handlerCalled) 693 }) 694 695 b := (&dataFrame{}).Append(nil) 696 setRequest(b) 697 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 698 return len(p), nil 699 
}).AnyTimes() 700 701 done := make(chan struct{}) 702 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 703 Expect(code).To(Equal(quic.ApplicationErrorCode(ErrCodeFrameUnexpected))) 704 close(done) 705 }) 706 s.handleConn(conn) 707 Eventually(done).Should(BeClosed()) 708 }) 709 710 It("closes the connection when the first frame is not a HEADERS frame", func() { 711 handlerCalled := make(chan struct{}) 712 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 713 close(handlerCalled) 714 }) 715 716 // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, 717 // but the request will still end up larger than DefaultMaxHeaderBytes. 718 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 719 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 720 Expect(err).ToNot(HaveOccurred()) 721 setRequest(encodeRequest(req)) 722 // str.EXPECT().Context().Return(reqContext) 723 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 724 return len(p), nil 725 }).AnyTimes() 726 done := make(chan struct{}) 727 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 728 729 s.handleConn(conn) 730 Eventually(done).Should(BeClosed()) 731 }) 732 }) 733 734 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 735 handlerCalled := make(chan struct{}) 736 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 737 r.Body = struct { 738 io.Reader 739 io.Closer 740 }{} 741 close(handlerCalled) 742 }) 743 744 setRequest(encodeRequest(examplePostRequest)) 745 str.EXPECT().Context().Return(reqContext) 746 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 747 return len(p), nil 748 }).AnyTimes() 749 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 750 751 serr := 
s.handleRequest(conn, str, qpackDecoder, nil)
			Expect(serr.err).ToNot(HaveOccurred())
			Eventually(handlerCalled).Should(BeClosed())
		})

		// A request arriving on a stream whose context is already canceled still
		// reaches the handler, which must observe a canceled request context.
		It("cancels the request context when the stream is closed", func() {
			handlerCalled := make(chan struct{})
			s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				defer GinkgoRecover()
				Expect(r.Context().Done()).To(BeClosed())
				Expect(r.Context().Err()).To(MatchError(context.Canceled))
				close(handlerCalled)
			})
			setRequest(encodeRequest(examplePostRequest))

			// Named distinctly so it does not shadow the Context-level reqContext.
			canceledCtx, cancel := context.WithCancel(context.Background())
			cancel()
			str.EXPECT().Context().Return(canceledCtx)
			str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
				return len(p), nil
			}).AnyTimes()
			str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError))

			serr := s.handleRequest(conn, str, qpackDecoder, nil)
			Expect(serr.err).ToNot(HaveOccurred())
			Eventually(handlerCalled).Should(BeClosed())
		})
	})

	// SetQuicHeaders builds the Alt-Svc header from the registered listeners,
	// falling back to s.Port / s.Addr, and fails with ErrNoAltSvcPort when
	// no port can be determined.
	Context("setting http headers", func() {
		BeforeEach(func() {
			s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}}
		})

		var ln1 QUICEarlyListener
		var ln2 QUICEarlyListener
		expected := http.Header{
			"Alt-Svc": {`h3=":443"; ma=2592000`},
		}

		// addListener stores a mock listener reporting addr into *ln and
		// registers it with the server; the mock expects one Addr() call.
		addListener := func(addr string, ln *QUICEarlyListener) {
			mln := newMockAddrListener(addr)
			mln.EXPECT().Addr()
			*ln = mln
			s.addListener(ln)
		}

		// removeListener unregisters a listener previously added via addListener.
		removeListener := func(ln *QUICEarlyListener) {
			s.removeListener(ln)
		}

		// checkSetHeaders asserts SetQuicHeaders succeeds and the produced
		// header satisfies matcher.
		checkSetHeaders := func(matcher gmtypes.GomegaMatcher) {
			hdr := http.Header{}
			Expect(s.SetQuicHeaders(hdr)).To(Succeed())
			Expect(hdr).To(matcher)
		}

		// checkSetHeaderError asserts SetQuicHeaders fails with ErrNoAltSvcPort.
		checkSetHeaderError := func() {
			hdr := http.Header{}
			Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
		}

		It("sets proper headers with numeric port", func() {
			addListener(":443", &ln1)
			checkSetHeaders(Equal(expected))
			removeListener(&ln1)
			checkSetHeaderError()
		})

		It("sets proper headers with full addr", func() {
			addListener("127.0.0.1:443", &ln1)
			checkSetHeaders(Equal(expected))
			removeListener(&ln1)
			checkSetHeaderError()
		})

		It("sets proper headers with string port", func() {
			addListener(":https", &ln1)
			checkSetHeaders(Equal(expected))
			removeListener(&ln1)
			checkSetHeaderError()
		})

		It("works multiple times", func() {
			addListener(":https", &ln1)
			checkSetHeaders(Equal(expected))
			checkSetHeaders(Equal(expected))
			removeListener(&ln1)
			checkSetHeaderError()
		})

		It("works if the quic.Config sets QUIC versions", func() {
			s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version2}
			addListener(":443", &ln1)
			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
			removeListener(&ln1)
			checkSetHeaderError()
		})

		It("uses s.Port if set to a non-zero value", func() {
			s.Port = 8443
			addListener(":443", &ln1)
			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}))
			removeListener(&ln1)
			checkSetHeaderError()
		})

		It("uses s.Addr if listeners don't have ports available", func() {
			s.Addr = ":443"
			mln := &noPortListener{newMockAddrListener("")}
			mln.EXPECT().Addr()
			ln1 = mln
			s.addListener(&ln1)
			checkSetHeaders(Equal(expected))
			s.removeListener(&ln1)
			checkSetHeaderError()
		})

		It("properly announces multiple listeners", func() {
			addListener(":443", &ln1)
			addListener(":8443", &ln2)
			// Listener iteration order is not fixed, so accept either ordering.
			checkSetHeaders(Or(
				Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}),
				Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}),
			))
			removeListener(&ln1)
			removeListener(&ln2)
			checkSetHeaderError()
		})

		It("doesn't duplicate Alt-Svc values", func() {
			s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version1}
			addListener(":443", &ln1)
			checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}))
			removeListener(&ln1)
			checkSetHeaderError()
		})
	})

	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
		Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
	})

	It("should nop-Close() when s.server is nil", func() {
		Expect((&Server{}).Close()).To(Succeed())
	})

	It("errors when ListenAndServeTLS is called after Close", func() {
		serv := &Server{}
		Expect(serv.Close()).To(Succeed())
		Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
	})

	It("handles concurrent Serve and Close", func() {
		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		c, err := net.ListenUDP("udp", addr)
		Expect(err).ToNot(HaveOccurred())
		// The spec owns this socket; release it so it does not outlive the test
		// (Serve is not relied upon to close a caller-provided conn).
		defer c.Close()
		done := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			defer close(done)
			s.Serve(c)
		}()
		runtime.Gosched()
		// Close deliberately races with Serve; only Serve's termination is asserted.
		s.Close()
		Eventually(done).Should(BeClosed())
	})

	Context("ConfigureTLSConfig", func() {
		It("advertises v1 by default", func() {
			conf := ConfigureTLSConfig(testdata.GetTLSConfig())
			ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}})
			Expect(err).ToNot(HaveOccurred())
			defer ln.Close()
			c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil)
			Expect(err).ToNot(HaveOccurred())
			defer c.CloseWithError(0, "")
			Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
		})

		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
			var receivedConf *tls.Config
quicListenAddr = func(_ string, conf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
				// Record the TLS config handed to the listener, then abort startup.
				receivedConf = conf
				return nil, errors.New("listen err")
			}
			listenErr := s.ListenAndServe()
			Expect(listenErr).To(HaveOccurred())
			Expect(receivedConf).ToNot(BeNil())
		})

		// The per-connection config only offers foreign protocols, yet the client
		// must still negotiate h3 through the wrapped config.
		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
			serverConf := &tls.Config{
				GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
					perConn := testdata.GetTLSConfig()
					perConn.NextProtos = []string{"foo", "bar"}
					return perConn, nil
				},
			}
			quicConf := &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}
			listener, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(serverConf), quicConf)
			Expect(err).ToNot(HaveOccurred())
			defer listener.Close()

			clientTLS := &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}
			clientConn, err := quic.DialAddr(context.Background(), listener.Addr().String(), clientTLS, nil)
			Expect(err).ToNot(HaveOccurred())
			defer clientConn.CloseWithError(0, "")
			negotiated := clientConn.ConnectionState().TLS.NegotiatedProtocol
			Expect(negotiated).To(Equal(NextProtoH3))
		})

		// crypto/tls falls back to the outer config when GetConfigForClient
		// returns (nil, nil); the wrapper must tolerate that.
		It("works if GetConfigForClient returns a nil tls.Config", func() {
			serverConf := testdata.GetTLSConfig()
			serverConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }

			quicConf := &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}
			listener, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(serverConf), quicConf)
			Expect(err).ToNot(HaveOccurred())
			defer listener.Close()

			clientTLS := &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}
			clientConn, err := quic.DialAddr(context.Background(), listener.Addr().String(), clientTLS, nil)
			Expect(err).ToNot(HaveOccurred())
			defer clientConn.CloseWithError(0, "")
			negotiated := clientConn.ConnectionState().TLS.NegotiatedProtocol
			Expect(negotiated).To(Equal(NextProtoH3))
		})

		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
			tlsClientConf := testdata.GetTLSConfig()
tlsClientConf.NextProtos = []string{"foo", "bar"} 976 tlsConf := &tls.Config{ 977 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 978 return tlsClientConf, nil 979 }, 980 } 981 982 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 983 Expect(err).ToNot(HaveOccurred()) 984 defer ln.Close() 985 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 986 Expect(err).ToNot(HaveOccurred()) 987 defer c.CloseWithError(0, "") 988 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 989 // check that the original config was not modified 990 Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) 991 }) 992 }) 993 994 Context("Serve", func() { 995 origQuicListen := quicListen 996 997 AfterEach(func() { 998 quicListen = origQuicListen 999 }) 1000 1001 It("serves a packet conn", func() { 1002 ln := newMockAddrListener(":443") 1003 conn := &net.UDPConn{} 1004 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1005 Expect(c).To(Equal(conn)) 1006 return ln, nil 1007 } 1008 1009 s := &Server{ 1010 TLSConfig: &tls.Config{}, 1011 } 1012 1013 stopAccept := make(chan struct{}) 1014 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1015 <-stopAccept 1016 return nil, errors.New("closed") 1017 }) 1018 ln.EXPECT().Addr() // generate alt-svc headers 1019 done := make(chan struct{}) 1020 go func() { 1021 defer GinkgoRecover() 1022 defer close(done) 1023 s.Serve(conn) 1024 }() 1025 1026 Consistently(done).ShouldNot(BeClosed()) 1027 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 1028 Expect(s.Close()).To(Succeed()) 1029 Eventually(done).Should(BeClosed()) 1030 }) 1031 1032 It("serves two packet conns", func() { 1033 ln1 := newMockAddrListener(":443") 1034 ln2 := 
newMockAddrListener(":8443") 1035 lns := make(chan QUICEarlyListener, 2) 1036 lns <- ln1 1037 lns <- ln2 1038 conn1 := &net.UDPConn{} 1039 conn2 := &net.UDPConn{} 1040 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1041 return <-lns, nil 1042 } 1043 1044 s := &Server{ 1045 TLSConfig: &tls.Config{}, 1046 } 1047 1048 stopAccept1 := make(chan struct{}) 1049 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1050 <-stopAccept1 1051 return nil, errors.New("closed") 1052 }) 1053 ln1.EXPECT().Addr() // generate alt-svc headers 1054 stopAccept2 := make(chan struct{}) 1055 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1056 <-stopAccept2 1057 return nil, errors.New("closed") 1058 }) 1059 ln2.EXPECT().Addr() 1060 1061 done1 := make(chan struct{}) 1062 go func() { 1063 defer GinkgoRecover() 1064 defer close(done1) 1065 s.Serve(conn1) 1066 }() 1067 done2 := make(chan struct{}) 1068 go func() { 1069 defer GinkgoRecover() 1070 defer close(done2) 1071 s.Serve(conn2) 1072 }() 1073 1074 Consistently(done1).ShouldNot(BeClosed()) 1075 Expect(done2).ToNot(BeClosed()) 1076 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 1077 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 1078 Expect(s.Close()).To(Succeed()) 1079 Eventually(done1).Should(BeClosed()) 1080 Eventually(done2).Should(BeClosed()) 1081 }) 1082 }) 1083 1084 Context("ServeListener", func() { 1085 origQuicListen := quicListen 1086 1087 AfterEach(func() { 1088 quicListen = origQuicListen 1089 }) 1090 1091 It("serves a listener", func() { 1092 var called int32 1093 ln := newMockAddrListener(":443") 1094 quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1095 atomic.StoreInt32(&called, 1) 1096 return ln, nil 1097 } 1098 1099 s := &Server{} 1100 1101 stopAccept := make(chan struct{}) 1102 
ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1103 <-stopAccept 1104 return nil, errors.New("closed") 1105 }) 1106 ln.EXPECT().Addr() // generate alt-svc headers 1107 done := make(chan struct{}) 1108 go func() { 1109 defer GinkgoRecover() 1110 defer close(done) 1111 s.ServeListener(ln) 1112 }() 1113 1114 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1115 Consistently(done).ShouldNot(BeClosed()) 1116 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 1117 Expect(s.Close()).To(Succeed()) 1118 Eventually(done).Should(BeClosed()) 1119 }) 1120 1121 It("serves two listeners", func() { 1122 var called int32 1123 ln1 := newMockAddrListener(":443") 1124 ln2 := newMockAddrListener(":8443") 1125 lns := make(chan QUICEarlyListener, 2) 1126 lns <- ln1 1127 lns <- ln2 1128 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1129 atomic.StoreInt32(&called, 1) 1130 return <-lns, nil 1131 } 1132 1133 s := &Server{} 1134 1135 stopAccept1 := make(chan struct{}) 1136 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1137 <-stopAccept1 1138 return nil, errors.New("closed") 1139 }) 1140 ln1.EXPECT().Addr() // generate alt-svc headers 1141 stopAccept2 := make(chan struct{}) 1142 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1143 <-stopAccept2 1144 return nil, errors.New("closed") 1145 }) 1146 ln2.EXPECT().Addr() 1147 1148 done1 := make(chan struct{}) 1149 go func() { 1150 defer GinkgoRecover() 1151 defer close(done1) 1152 s.ServeListener(ln1) 1153 }() 1154 done2 := make(chan struct{}) 1155 go func() { 1156 defer GinkgoRecover() 1157 defer close(done2) 1158 s.ServeListener(ln2) 1159 }() 1160 1161 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1162 Consistently(done1).ShouldNot(BeClosed()) 1163 
Expect(done2).ToNot(BeClosed()) 1164 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 1165 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 1166 Expect(s.Close()).To(Succeed()) 1167 Eventually(done1).Should(BeClosed()) 1168 Eventually(done2).Should(BeClosed()) 1169 }) 1170 }) 1171 1172 Context("ServeQUICConn", func() { 1173 It("serves a QUIC connection", func() { 1174 mux := http.NewServeMux() 1175 mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { 1176 w.Write([]byte("foobar")) 1177 }) 1178 s.Handler = mux 1179 tlsConf := testdata.GetTLSConfig() 1180 tlsConf.NextProtos = []string{NextProtoH3} 1181 conn := mockquic.NewMockEarlyConnection(mockCtrl) 1182 controlStr := mockquic.NewMockStream(mockCtrl) 1183 controlStr.EXPECT().Write(gomock.Any()) 1184 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 1185 testDone := make(chan struct{}) 1186 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 1187 <-testDone 1188 return nil, errors.New("test done") 1189 }).MaxTimes(1) 1190 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}) 1191 s.ServeQUICConn(conn) 1192 close(testDone) 1193 }) 1194 }) 1195 1196 Context("ListenAndServe", func() { 1197 BeforeEach(func() { 1198 s.Addr = "localhost:0" 1199 }) 1200 1201 AfterEach(func() { 1202 Expect(s.Close()).To(Succeed()) 1203 }) 1204 1205 It("uses the quic.Config to start the QUIC server", func() { 1206 conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} 1207 var receivedConf *quic.Config 1208 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1209 receivedConf = config 1210 return nil, errors.New("listen err") 1211 } 1212 s.QuicConfig = conf 1213 Expect(s.ListenAndServe()).To(HaveOccurred()) 1214 Expect(receivedConf).To(Equal(conf)) 1215 }) 1216 }) 1217 1218 It("closes gracefully", func() { 1219 
Expect(s.CloseGracefully(0)).To(Succeed()) 1220 }) 1221 1222 It("errors when listening fails", func() { 1223 testErr := errors.New("listen error") 1224 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1225 return nil, testErr 1226 } 1227 fullpem, privkey := testdata.GetCertificatePaths() 1228 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 1229 }) 1230 1231 It("supports H3_DATAGRAM", func() { 1232 s.EnableDatagrams = true 1233 var receivedConf *quic.Config 1234 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1235 receivedConf = config 1236 return nil, errors.New("listen err") 1237 } 1238 Expect(s.ListenAndServe()).To(HaveOccurred()) 1239 Expect(receivedConf.EnableDatagrams).To(BeTrue()) 1240 }) 1241 })