/* NOTE(review): the leading "github.com/...(about)" text and the numbers interleaved between tokens look like web-page scrape residue rather than Go source; TODO confirm and restore the original file layout. */ github.com/MerlinKodo/quic-go@v0.39.2/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/MerlinKodo/quic-go" 17 mockquic "github.com/MerlinKodo/quic-go/internal/mocks/quic" 18 "github.com/MerlinKodo/quic-go/internal/protocol" 19 "github.com/MerlinKodo/quic-go/internal/testdata" 20 "github.com/MerlinKodo/quic-go/internal/utils" 21 "github.com/MerlinKodo/quic-go/quicvarint" 22 23 "github.com/quic-go/qpack" 24 "go.uber.org/mock/gomock" 25 26 . "github.com/onsi/ginkgo/v2" 27 . "github.com/onsi/gomega" 28 gmtypes "github.com/onsi/gomega/types" 29 ) 30 31 /* mockAddr is a minimal net.Addr for tests: Network() always reports "udp" and String() returns whatever address string it was built with. */ type mockAddr struct{ addr string } 32 33 func (ma *mockAddr) Network() string { return "udp" } 34 func (ma *mockAddr) String() string { return ma.addr } 35 36 /* mockAddrListener embeds a generated MockQUICEarlyListener and overrides Addr() so the test controls the address the listener reports. */ type mockAddrListener struct { 37 *MockQUICEarlyListener 38 addr *mockAddr 39 } 40 41 /* Addr still calls the embedded mock's Addr() so that a registered gomock expectation (e.g. mln.EXPECT().Addr()) is consumed, then discards that result and returns the fixed m.addr instead. */ func (m *mockAddrListener) Addr() net.Addr { 42 _ = m.MockQUICEarlyListener.Addr() 43 return m.addr 44 } 45 46 /* newMockAddrListener returns a mockAddrListener whose Addr() reports addr. The underlying mock is created on mockCtrl, a gomock controller defined outside this view (presumably set up per test; TODO confirm). */ func newMockAddrListener(addr string) *mockAddrListener { 47 return &mockAddrListener{ 48 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 49 addr: &mockAddr{addr: addr}, 50 } 51 } 52 53 /* noPortListener wraps a mockAddrListener but reports a Unix-socket address that carries no port; used to exercise the fallback to Server.Addr when no listener exposes a port. */ type noPortListener struct { 54 *mockAddrListener 55 } 56 57 /* Addr forwards to the wrapped listener's Addr() (so the mock expectation is still consumed) and then returns a fixed *net.UnixAddr for /tmp/quic.sock instead of that result. */ func (m *noPortListener) Addr() net.Addr { 58 _ = m.mockAddrListener.Addr() 59 return &net.UnixAddr{ 60 Net: "unix", 61 Name: "/tmp/quic.sock", 62 } 63 } 64 65 var _ = Describe("Server", func() { 66 var ( 67 s *Server 68 origQuicListenAddr = quicListenAddr 69 ) 70 71 BeforeEach(func() { 72 s = &Server{ 73 TLSConfig: testdata.GetTLSConfig(), 74 logger: utils.DefaultLogger, 75 } 76 origQuicListenAddr = quicListenAddr 77 }) 78 79 AfterEach(func() { 80 quicListenAddr = origQuicListenAddr 81 }) 82 83 Context("handling requests", func() { 84 var ( 85 qpackDecoder *qpack.Decoder 86 str *mockquic.MockStream 87 conn *mockquic.MockEarlyConnection 88 exampleGetRequest 
*http.Request 89 examplePostRequest *http.Request 90 ) 91 reqContext := context.Background() 92 93 decodeHeader := func(str io.Reader) map[string][]string { 94 fields := make(map[string][]string) 95 decoder := qpack.NewDecoder(nil) 96 97 frame, err := parseNextFrame(str, nil) 98 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 99 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 100 headersFrame := frame.(*headersFrame) 101 data := make([]byte, headersFrame.Length) 102 _, err = io.ReadFull(str, data) 103 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 104 hfs, err := decoder.DecodeFull(data) 105 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 106 for _, p := range hfs { 107 fields[p.Name] = append(fields[p.Name], p.Value) 108 } 109 return fields 110 } 111 112 encodeRequest := func(req *http.Request) []byte { 113 buf := &bytes.Buffer{} 114 str := mockquic.NewMockStream(mockCtrl) 115 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 116 rw := newRequestWriter(utils.DefaultLogger) 117 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 118 return buf.Bytes() 119 } 120 121 setRequest := func(data []byte) { 122 buf := bytes.NewBuffer(data) 123 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 124 if buf.Len() == 0 { 125 return 0, io.EOF 126 } 127 return buf.Read(p) 128 }).AnyTimes() 129 } 130 131 BeforeEach(func() { 132 var err error 133 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 134 Expect(err).ToNot(HaveOccurred()) 135 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 136 Expect(err).ToNot(HaveOccurred()) 137 138 qpackDecoder = qpack.NewDecoder(nil) 139 str = mockquic.NewMockStream(mockCtrl) 140 conn = mockquic.NewMockEarlyConnection(mockCtrl) 141 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 142 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 143 conn.EXPECT().LocalAddr().AnyTimes() 144 
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 145 }) 146 147 It("calls the HTTP handler function", func() { 148 requestChan := make(chan *http.Request, 1) 149 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 150 requestChan <- r 151 }) 152 153 setRequest(encodeRequest(exampleGetRequest)) 154 str.EXPECT().Context().Return(reqContext) 155 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 156 return len(p), nil 157 }).AnyTimes() 158 str.EXPECT().CancelRead(gomock.Any()) 159 160 Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) 161 var req *http.Request 162 Eventually(requestChan).Should(Receive(&req)) 163 Expect(req.Host).To(Equal("www.example.com")) 164 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 165 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 166 }) 167 168 It("returns 200 with an empty handler", func() { 169 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 170 171 responseBuf := &bytes.Buffer{} 172 setRequest(encodeRequest(exampleGetRequest)) 173 str.EXPECT().Context().Return(reqContext) 174 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 175 str.EXPECT().CancelRead(gomock.Any()) 176 177 serr := s.handleRequest(conn, str, qpackDecoder, nil) 178 Expect(serr.err).ToNot(HaveOccurred()) 179 hfs := decodeHeader(responseBuf) 180 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 181 }) 182 183 It("sets Content-Length when the handler doesn't flush to the client", func() { 184 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 185 w.Write([]byte("foobar")) 186 }) 187 188 responseBuf := &bytes.Buffer{} 189 setRequest(encodeRequest(exampleGetRequest)) 190 str.EXPECT().Context().Return(reqContext) 191 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 192 str.EXPECT().CancelRead(gomock.Any()) 193 194 serr := s.handleRequest(conn, 
str, qpackDecoder, nil) 195 Expect(serr.err).ToNot(HaveOccurred()) 196 hfs := decodeHeader(responseBuf) 197 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 198 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) 199 // status, content-length, date, content-type 200 Expect(hfs).To(HaveLen(4)) 201 }) 202 203 It("not sets Content-Length when the handler flushes to the client", func() { 204 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 205 w.Write([]byte("foobar")) 206 // force flush 207 w.(http.Flusher).Flush() 208 }) 209 210 responseBuf := &bytes.Buffer{} 211 setRequest(encodeRequest(exampleGetRequest)) 212 str.EXPECT().Context().Return(reqContext) 213 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 214 str.EXPECT().CancelRead(gomock.Any()) 215 216 serr := s.handleRequest(conn, str, qpackDecoder, nil) 217 Expect(serr.err).ToNot(HaveOccurred()) 218 hfs := decodeHeader(responseBuf) 219 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 220 // status, date, content-type 221 Expect(hfs).To(HaveLen(3)) 222 }) 223 224 It("handles a aborting handler", func() { 225 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 226 panic(http.ErrAbortHandler) 227 }) 228 229 responseBuf := &bytes.Buffer{} 230 setRequest(encodeRequest(exampleGetRequest)) 231 str.EXPECT().Context().Return(reqContext) 232 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 233 str.EXPECT().CancelRead(gomock.Any()) 234 235 serr := s.handleRequest(conn, str, qpackDecoder, nil) 236 Expect(serr.err).ToNot(HaveOccurred()) 237 Expect(responseBuf.Bytes()).To(HaveLen(0)) 238 }) 239 240 It("handles a panicking handler", func() { 241 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 242 panic("foobar") 243 }) 244 245 responseBuf := &bytes.Buffer{} 246 setRequest(encodeRequest(exampleGetRequest)) 247 str.EXPECT().Context().Return(reqContext) 248 
str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 249 str.EXPECT().CancelRead(gomock.Any()) 250 251 serr := s.handleRequest(conn, str, qpackDecoder, nil) 252 Expect(serr.err).ToNot(HaveOccurred()) 253 Expect(responseBuf.Bytes()).To(HaveLen(0)) 254 }) 255 256 Context("hijacking bidirectional streams", func() { 257 var conn *mockquic.MockEarlyConnection 258 testDone := make(chan struct{}) 259 260 BeforeEach(func() { 261 testDone = make(chan struct{}) 262 conn = mockquic.NewMockEarlyConnection(mockCtrl) 263 controlStr := mockquic.NewMockStream(mockCtrl) 264 controlStr.EXPECT().Write(gomock.Any()) 265 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 266 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 267 conn.EXPECT().LocalAddr().AnyTimes() 268 }) 269 270 AfterEach(func() { testDone <- struct{}{} }) 271 272 It("hijacks a bidirectional stream of unknown frame type", func() { 273 frameTypeChan := make(chan FrameType, 1) 274 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 275 Expect(e).ToNot(HaveOccurred()) 276 frameTypeChan <- ft 277 return true, nil 278 } 279 280 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 281 unknownStr := mockquic.NewMockStream(mockCtrl) 282 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 283 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 284 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 285 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 286 <-testDone 287 return nil, errors.New("test done") 288 }) 289 s.handleConn(conn) 290 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 291 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 292 }) 293 294 It("cancels writing when hijacker didn't hijack a bidirectional stream", 
func() { 295 frameTypeChan := make(chan FrameType, 1) 296 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 297 Expect(e).ToNot(HaveOccurred()) 298 frameTypeChan <- ft 299 return false, nil 300 } 301 302 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 303 unknownStr := mockquic.NewMockStream(mockCtrl) 304 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 305 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 306 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 307 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 308 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 309 <-testDone 310 return nil, errors.New("test done") 311 }) 312 s.handleConn(conn) 313 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 314 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 315 }) 316 317 It("cancels writing when hijacker returned error", func() { 318 frameTypeChan := make(chan FrameType, 1) 319 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 320 Expect(e).ToNot(HaveOccurred()) 321 frameTypeChan <- ft 322 return false, errors.New("error in hijacker") 323 } 324 325 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 326 unknownStr := mockquic.NewMockStream(mockCtrl) 327 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 328 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 329 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 330 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 331 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 332 <-testDone 333 return nil, errors.New("test done") 334 }) 335 
s.handleConn(conn) 336 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 337 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 338 }) 339 340 It("handles errors that occur when reading the stream type", func() { 341 testErr := errors.New("test error") 342 done := make(chan struct{}) 343 unknownStr := mockquic.NewMockStream(mockCtrl) 344 s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { 345 defer close(done) 346 Expect(ft).To(BeZero()) 347 Expect(str).To(Equal(unknownStr)) 348 Expect(err).To(MatchError(testErr)) 349 return true, nil 350 } 351 352 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 353 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 354 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 355 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 356 <-testDone 357 return nil, errors.New("test done") 358 }) 359 s.handleConn(conn) 360 Eventually(done).Should(BeClosed()) 361 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 362 }) 363 }) 364 365 Context("hijacking unidirectional streams", func() { 366 var conn *mockquic.MockEarlyConnection 367 testDone := make(chan struct{}) 368 369 BeforeEach(func() { 370 testDone = make(chan struct{}) 371 conn = mockquic.NewMockEarlyConnection(mockCtrl) 372 controlStr := mockquic.NewMockStream(mockCtrl) 373 controlStr.EXPECT().Write(gomock.Any()) 374 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 375 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 376 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 377 conn.EXPECT().LocalAddr().AnyTimes() 378 }) 379 380 AfterEach(func() { testDone <- struct{}{} }) 381 382 It("hijacks an unidirectional stream of unknown stream type", func() 
{ 383 streamTypeChan := make(chan StreamType, 1) 384 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 385 Expect(err).ToNot(HaveOccurred()) 386 streamTypeChan <- st 387 return true 388 } 389 390 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 391 unknownStr := mockquic.NewMockStream(mockCtrl) 392 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 393 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 394 return unknownStr, nil 395 }) 396 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 397 <-testDone 398 return nil, errors.New("test done") 399 }) 400 s.handleConn(conn) 401 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 402 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 403 }) 404 405 It("handles errors that occur when reading the stream type", func() { 406 testErr := errors.New("test error") 407 done := make(chan struct{}) 408 unknownStr := mockquic.NewMockStream(mockCtrl) 409 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 410 defer close(done) 411 Expect(st).To(BeZero()) 412 Expect(str).To(Equal(unknownStr)) 413 Expect(err).To(MatchError(testErr)) 414 return true 415 } 416 417 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 418 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 419 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 420 <-testDone 421 return nil, errors.New("test done") 422 }) 423 s.handleConn(conn) 424 Eventually(done).Should(BeClosed()) 425 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 426 }) 427 428 It("cancels reading when hijacker didn't hijack an 
unidirectional stream", func() { 429 streamTypeChan := make(chan StreamType, 1) 430 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 431 Expect(err).ToNot(HaveOccurred()) 432 streamTypeChan <- st 433 return false 434 } 435 436 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 437 unknownStr := mockquic.NewMockStream(mockCtrl) 438 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 439 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 440 441 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 442 return unknownStr, nil 443 }) 444 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 445 <-testDone 446 return nil, errors.New("test done") 447 }) 448 s.handleConn(conn) 449 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 450 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 451 }) 452 }) 453 454 Context("control stream handling", func() { 455 var conn *mockquic.MockEarlyConnection 456 testDone := make(chan struct{}) 457 458 BeforeEach(func() { 459 conn = mockquic.NewMockEarlyConnection(mockCtrl) 460 controlStr := mockquic.NewMockStream(mockCtrl) 461 controlStr.EXPECT().Write(gomock.Any()) 462 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 463 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 464 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 465 conn.EXPECT().LocalAddr().AnyTimes() 466 }) 467 468 AfterEach(func() { testDone <- struct{}{} }) 469 470 It("parses the SETTINGS frame", func() { 471 b := quicvarint.Append(nil, streamTypeControlStream) 472 b = (&settingsFrame{}).Append(b) 473 controlStr := mockquic.NewMockStream(mockCtrl) 474 r := bytes.NewReader(b) 475 
controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 476 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 477 return controlStr, nil 478 }) 479 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 480 <-testDone 481 return nil, errors.New("test done") 482 }) 483 s.handleConn(conn) 484 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 485 }) 486 487 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 488 streamType := t 489 name := "encoder" 490 if streamType == streamTypeQPACKDecoderStream { 491 name = "decoder" 492 } 493 494 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 495 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 496 str := mockquic.NewMockStream(mockCtrl) 497 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 498 499 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 500 return str, nil 501 }) 502 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 503 <-testDone 504 return nil, errors.New("test done") 505 }) 506 s.handleConn(conn) 507 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 508 }) 509 } 510 511 It("reset streams other than the control stream and the QPACK streams", func() { 512 buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337)) 513 str := mockquic.NewMockStream(mockCtrl) 514 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 515 done := make(chan struct{}) 516 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(code quic.StreamErrorCode) { 517 close(done) 518 }) 519 520 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 521 return str, 
nil 522 }) 523 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 524 <-testDone 525 return nil, errors.New("test done") 526 }) 527 s.handleConn(conn) 528 Eventually(done).Should(BeClosed()) 529 }) 530 531 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 532 b := quicvarint.Append(nil, streamTypeControlStream) 533 b = (&dataFrame{}).Append(b) 534 controlStr := mockquic.NewMockStream(mockCtrl) 535 r := bytes.NewReader(b) 536 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 537 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 538 return controlStr, nil 539 }) 540 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 541 <-testDone 542 return nil, errors.New("test done") 543 }) 544 done := make(chan struct{}) 545 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 546 defer GinkgoRecover() 547 Expect(code).To(BeEquivalentTo(ErrCodeMissingSettings)) 548 close(done) 549 }) 550 s.handleConn(conn) 551 Eventually(done).Should(BeClosed()) 552 }) 553 554 It("errors when parsing the frame on the control stream fails", func() { 555 b := quicvarint.Append(nil, streamTypeControlStream) 556 b = (&settingsFrame{}).Append(b) 557 r := bytes.NewReader(b[:len(b)-1]) 558 controlStr := mockquic.NewMockStream(mockCtrl) 559 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 560 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 561 return controlStr, nil 562 }) 563 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 564 <-testDone 565 return nil, errors.New("test done") 566 }) 567 done := make(chan struct{}) 568 conn.EXPECT().CloseWithError(gomock.Any(), 
gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 569 defer GinkgoRecover() 570 Expect(code).To(BeEquivalentTo(ErrCodeFrameError)) 571 close(done) 572 }) 573 s.handleConn(conn) 574 Eventually(done).Should(BeClosed()) 575 }) 576 577 It("errors when the client opens a push stream", func() { 578 b := quicvarint.Append(nil, streamTypePushStream) 579 b = (&dataFrame{}).Append(b) 580 r := bytes.NewReader(b) 581 controlStr := mockquic.NewMockStream(mockCtrl) 582 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 583 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 584 return controlStr, nil 585 }) 586 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 587 <-testDone 588 return nil, errors.New("test done") 589 }) 590 done := make(chan struct{}) 591 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 592 defer GinkgoRecover() 593 Expect(code).To(BeEquivalentTo(ErrCodeStreamCreationError)) 594 close(done) 595 }) 596 s.handleConn(conn) 597 Eventually(done).Should(BeClosed()) 598 }) 599 600 It("errors when the client advertises datagram support (and we enabled support for it)", func() { 601 s.EnableDatagrams = true 602 b := quicvarint.Append(nil, streamTypeControlStream) 603 b = (&settingsFrame{Datagram: true}).Append(b) 604 r := bytes.NewReader(b) 605 controlStr := mockquic.NewMockStream(mockCtrl) 606 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 607 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 608 return controlStr, nil 609 }) 610 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 611 <-testDone 612 return nil, errors.New("test done") 613 }) 614 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: 
false}) 615 done := make(chan struct{}) 616 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, reason string) { 617 defer GinkgoRecover() 618 Expect(code).To(BeEquivalentTo(ErrCodeSettingsError)) 619 Expect(reason).To(Equal("missing QUIC Datagram support")) 620 close(done) 621 }) 622 s.handleConn(conn) 623 Eventually(done).Should(BeClosed()) 624 }) 625 }) 626 627 Context("stream- and connection-level errors", func() { 628 var conn *mockquic.MockEarlyConnection 629 testDone := make(chan struct{}) 630 631 BeforeEach(func() { 632 testDone = make(chan struct{}) 633 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 634 conn = mockquic.NewMockEarlyConnection(mockCtrl) 635 controlStr := mockquic.NewMockStream(mockCtrl) 636 controlStr.EXPECT().Write(gomock.Any()) 637 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 638 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 639 <-testDone 640 return nil, errors.New("test done") 641 }) 642 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 643 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 644 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 645 conn.EXPECT().LocalAddr().AnyTimes() 646 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 647 }) 648 649 AfterEach(func() { testDone <- struct{}{} }) 650 651 It("cancels reading when client sends a body in GET request", func() { 652 var handlerCalled bool 653 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 654 handlerCalled = true 655 }) 656 657 requestData := encodeRequest(exampleGetRequest) 658 b := (&dataFrame{Length: 6}).Append(nil) // add a body 659 b = append(b, []byte("foobar")...) 
660 responseBuf := &bytes.Buffer{} 661 setRequest(append(requestData, b...)) 662 done := make(chan struct{}) 663 str.EXPECT().Context().Return(reqContext) 664 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 665 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 666 str.EXPECT().Close().Do(func() { close(done) }) 667 668 s.handleConn(conn) 669 Eventually(done).Should(BeClosed()) 670 hfs := decodeHeader(responseBuf) 671 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 672 Expect(handlerCalled).To(BeTrue()) 673 }) 674 675 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 676 handlerCalled := make(chan struct{}) 677 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 678 defer close(handlerCalled) 679 r.Body.(HTTPStreamer).HTTPStream() 680 str.Write([]byte("foobar")) 681 }) 682 683 requestData := encodeRequest(exampleGetRequest) 684 b := (&dataFrame{Length: 6}).Append(nil) // add a body 685 b = append(b, []byte("foobar")...) 686 setRequest(append(requestData, b...)) 687 str.EXPECT().Context().Return(reqContext) 688 str.EXPECT().Write([]byte("foobar")).Return(6, nil) 689 690 s.handleConn(conn) 691 Eventually(handlerCalled).Should(BeClosed()) 692 }) 693 694 It("errors when the client sends a too large header frame", func() { 695 s.MaxHeaderBytes = 20 696 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 697 Fail("Handler should not be called.") 698 }) 699 700 requestData := encodeRequest(exampleGetRequest) 701 b := (&dataFrame{Length: 6}).Append(nil) // add a body 702 b = append(b, []byte("foobar")...) 
703 responseBuf := &bytes.Buffer{} 704 setRequest(append(requestData, b...)) 705 done := make(chan struct{}) 706 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 707 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 708 709 s.handleConn(conn) 710 Eventually(done).Should(BeClosed()) 711 }) 712 713 It("handles a request for which the client immediately resets the stream", func() { 714 handlerCalled := make(chan struct{}) 715 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 716 close(handlerCalled) 717 }) 718 719 testErr := errors.New("stream reset") 720 done := make(chan struct{}) 721 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 722 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 723 724 s.handleConn(conn) 725 Consistently(handlerCalled).ShouldNot(BeClosed()) 726 }) 727 728 It("closes the connection when the first frame is not a HEADERS frame", func() { 729 handlerCalled := make(chan struct{}) 730 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 731 close(handlerCalled) 732 }) 733 734 b := (&dataFrame{}).Append(nil) 735 setRequest(b) 736 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 737 return len(p), nil 738 }).AnyTimes() 739 740 done := make(chan struct{}) 741 conn.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(code quic.ApplicationErrorCode, _ string) { 742 Expect(code).To(Equal(quic.ApplicationErrorCode(ErrCodeFrameUnexpected))) 743 close(done) 744 }) 745 s.handleConn(conn) 746 Eventually(done).Should(BeClosed()) 747 }) 748 749 It("closes the connection when the first frame is not a HEADERS frame", func() { 750 handlerCalled := make(chan struct{}) 751 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 752 close(handlerCalled) 753 }) 754 755 // use 2*DefaultMaxHeaderBytes here. 
qpack will compress the requiest, 756 // but the request will still end up larger than DefaultMaxHeaderBytes. 757 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 758 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 759 Expect(err).ToNot(HaveOccurred()) 760 setRequest(encodeRequest(req)) 761 // str.EXPECT().Context().Return(reqContext) 762 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 763 return len(p), nil 764 }).AnyTimes() 765 done := make(chan struct{}) 766 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 767 768 s.handleConn(conn) 769 Eventually(done).Should(BeClosed()) 770 }) 771 }) 772 773 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 774 handlerCalled := make(chan struct{}) 775 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 776 r.Body = struct { 777 io.Reader 778 io.Closer 779 }{} 780 close(handlerCalled) 781 }) 782 783 setRequest(encodeRequest(examplePostRequest)) 784 str.EXPECT().Context().Return(reqContext) 785 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 786 return len(p), nil 787 }).AnyTimes() 788 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 789 790 serr := s.handleRequest(conn, str, qpackDecoder, nil) 791 Expect(serr.err).ToNot(HaveOccurred()) 792 Eventually(handlerCalled).Should(BeClosed()) 793 }) 794 795 It("cancels the request context when the stream is closed", func() { 796 handlerCalled := make(chan struct{}) 797 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 798 defer GinkgoRecover() 799 Expect(r.Context().Done()).To(BeClosed()) 800 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 801 close(handlerCalled) 802 }) 803 setRequest(encodeRequest(examplePostRequest)) 804 805 reqContext, cancel := 
context.WithCancel(context.Background()) 806 cancel() 807 str.EXPECT().Context().Return(reqContext) 808 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 809 return len(p), nil 810 }).AnyTimes() 811 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 812 813 serr := s.handleRequest(conn, str, qpackDecoder, nil) 814 Expect(serr.err).ToNot(HaveOccurred()) 815 Eventually(handlerCalled).Should(BeClosed()) 816 }) 817 }) 818 819 Context("setting http headers", func() { 820 BeforeEach(func() { 821 s.QuicConfig = &quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}} 822 }) 823 824 var ln1 QUICEarlyListener 825 var ln2 QUICEarlyListener 826 expected := http.Header{ 827 "Alt-Svc": {`h3=":443"; ma=2592000`}, 828 } 829 830 addListener := func(addr string, ln *QUICEarlyListener) { 831 mln := newMockAddrListener(addr) 832 mln.EXPECT().Addr() 833 *ln = mln 834 s.addListener(ln) 835 } 836 837 removeListener := func(ln *QUICEarlyListener) { 838 s.removeListener(ln) 839 } 840 841 checkSetHeaders := func(expected gmtypes.GomegaMatcher) { 842 hdr := http.Header{} 843 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 844 Expect(hdr).To(expected) 845 } 846 847 checkSetHeaderError := func() { 848 hdr := http.Header{} 849 Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) 850 } 851 852 It("sets proper headers with numeric port", func() { 853 addListener(":443", &ln1) 854 checkSetHeaders(Equal(expected)) 855 removeListener(&ln1) 856 checkSetHeaderError() 857 }) 858 859 It("sets proper headers with full addr", func() { 860 addListener("127.0.0.1:443", &ln1) 861 checkSetHeaders(Equal(expected)) 862 removeListener(&ln1) 863 checkSetHeaderError() 864 }) 865 866 It("sets proper headers with string port", func() { 867 addListener(":https", &ln1) 868 checkSetHeaders(Equal(expected)) 869 removeListener(&ln1) 870 checkSetHeaderError() 871 }) 872 873 It("works multiple times", func() { 874 addListener(":https", &ln1) 875 
checkSetHeaders(Equal(expected)) 876 checkSetHeaders(Equal(expected)) 877 removeListener(&ln1) 878 checkSetHeaderError() 879 }) 880 881 It("works if the quic.Config sets QUIC versions", func() { 882 s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version2} 883 addListener(":443", &ln1) 884 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 885 removeListener(&ln1) 886 checkSetHeaderError() 887 }) 888 889 It("uses s.Port if set to a non-zero value", func() { 890 s.Port = 8443 891 addListener(":443", &ln1) 892 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}})) 893 removeListener(&ln1) 894 checkSetHeaderError() 895 }) 896 897 It("uses s.Addr if listeners don't have ports available", func() { 898 s.Addr = ":443" 899 mln := &noPortListener{newMockAddrListener("")} 900 mln.EXPECT().Addr() 901 ln1 = mln 902 s.addListener(&ln1) 903 checkSetHeaders(Equal(expected)) 904 s.removeListener(&ln1) 905 checkSetHeaderError() 906 }) 907 908 It("properly announces multiple listeners", func() { 909 addListener(":443", &ln1) 910 addListener(":8443", &ln2) 911 checkSetHeaders(Or( 912 Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}), 913 Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}), 914 )) 915 removeListener(&ln1) 916 removeListener(&ln2) 917 checkSetHeaderError() 918 }) 919 920 It("doesn't duplicate Alt-Svc values", func() { 921 s.QuicConfig.Versions = []quic.VersionNumber{quic.Version1, quic.Version1} 922 addListener(":443", &ln1) 923 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 924 removeListener(&ln1) 925 checkSetHeaderError() 926 }) 927 }) 928 929 It("errors when ListenAndServe is called with s.TLSConfig nil", func() { 930 Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) 931 }) 932 933 It("should nop-Close() when s.server is nil", func() { 934 Expect((&Server{}).Close()).To(Succeed()) 935 }) 936 
937 It("errors when ListenAndServeTLS is called after Close", func() { 938 serv := &Server{} 939 Expect(serv.Close()).To(Succeed()) 940 Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) 941 }) 942 943 It("handles concurrent Serve and Close", func() { 944 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 945 Expect(err).ToNot(HaveOccurred()) 946 c, err := net.ListenUDP("udp", addr) 947 Expect(err).ToNot(HaveOccurred()) 948 done := make(chan struct{}) 949 go func() { 950 defer GinkgoRecover() 951 defer close(done) 952 s.Serve(c) 953 }() 954 runtime.Gosched() 955 s.Close() 956 Eventually(done).Should(BeClosed()) 957 }) 958 959 Context("ConfigureTLSConfig", func() { 960 It("advertises v1 by default", func() { 961 conf := ConfigureTLSConfig(testdata.GetTLSConfig()) 962 ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 963 Expect(err).ToNot(HaveOccurred()) 964 defer ln.Close() 965 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 966 Expect(err).ToNot(HaveOccurred()) 967 defer c.CloseWithError(0, "") 968 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 969 }) 970 971 It("sets the GetConfigForClient callback if no tls.Config is given", func() { 972 var receivedConf *tls.Config 973 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) { 974 receivedConf = tlsConf 975 return nil, errors.New("listen err") 976 } 977 Expect(s.ListenAndServe()).To(HaveOccurred()) 978 Expect(receivedConf).ToNot(BeNil()) 979 }) 980 981 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { 982 tlsConf := &tls.Config{ 983 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 984 c := testdata.GetTLSConfig() 985 c.NextProtos = []string{"foo", "bar"} 986 return c, nil 987 }, 988 
} 989 990 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 991 Expect(err).ToNot(HaveOccurred()) 992 defer ln.Close() 993 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 994 Expect(err).ToNot(HaveOccurred()) 995 defer c.CloseWithError(0, "") 996 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 997 }) 998 999 It("works if GetConfigForClient returns a nil tls.Config", func() { 1000 tlsConf := testdata.GetTLSConfig() 1001 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } 1002 1003 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 1004 Expect(err).ToNot(HaveOccurred()) 1005 defer ln.Close() 1006 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1007 Expect(err).ToNot(HaveOccurred()) 1008 defer c.CloseWithError(0, "") 1009 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1010 }) 1011 1012 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { 1013 tlsClientConf := testdata.GetTLSConfig() 1014 tlsClientConf.NextProtos = []string{"foo", "bar"} 1015 tlsConf := &tls.Config{ 1016 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 1017 return tlsClientConf, nil 1018 }, 1019 } 1020 1021 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) 1022 Expect(err).ToNot(HaveOccurred()) 1023 defer ln.Close() 1024 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1025 
Expect(err).ToNot(HaveOccurred()) 1026 defer c.CloseWithError(0, "") 1027 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1028 // check that the original config was not modified 1029 Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) 1030 }) 1031 }) 1032 1033 Context("Serve", func() { 1034 origQuicListen := quicListen 1035 1036 AfterEach(func() { 1037 quicListen = origQuicListen 1038 }) 1039 1040 It("serves a packet conn", func() { 1041 ln := newMockAddrListener(":443") 1042 conn := &net.UDPConn{} 1043 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1044 Expect(c).To(Equal(conn)) 1045 return ln, nil 1046 } 1047 1048 s := &Server{ 1049 TLSConfig: &tls.Config{}, 1050 } 1051 1052 stopAccept := make(chan struct{}) 1053 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1054 <-stopAccept 1055 return nil, errors.New("closed") 1056 }) 1057 ln.EXPECT().Addr() // generate alt-svc headers 1058 done := make(chan struct{}) 1059 go func() { 1060 defer GinkgoRecover() 1061 defer close(done) 1062 s.Serve(conn) 1063 }() 1064 1065 Consistently(done).ShouldNot(BeClosed()) 1066 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 1067 Expect(s.Close()).To(Succeed()) 1068 Eventually(done).Should(BeClosed()) 1069 }) 1070 1071 It("serves two packet conns", func() { 1072 ln1 := newMockAddrListener(":443") 1073 ln2 := newMockAddrListener(":8443") 1074 lns := make(chan QUICEarlyListener, 2) 1075 lns <- ln1 1076 lns <- ln2 1077 conn1 := &net.UDPConn{} 1078 conn2 := &net.UDPConn{} 1079 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1080 return <-lns, nil 1081 } 1082 1083 s := &Server{ 1084 TLSConfig: &tls.Config{}, 1085 } 1086 1087 stopAccept1 := make(chan struct{}) 1088 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1089 <-stopAccept1 1090 
return nil, errors.New("closed") 1091 }) 1092 ln1.EXPECT().Addr() // generate alt-svc headers 1093 stopAccept2 := make(chan struct{}) 1094 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1095 <-stopAccept2 1096 return nil, errors.New("closed") 1097 }) 1098 ln2.EXPECT().Addr() 1099 1100 done1 := make(chan struct{}) 1101 go func() { 1102 defer GinkgoRecover() 1103 defer close(done1) 1104 s.Serve(conn1) 1105 }() 1106 done2 := make(chan struct{}) 1107 go func() { 1108 defer GinkgoRecover() 1109 defer close(done2) 1110 s.Serve(conn2) 1111 }() 1112 1113 Consistently(done1).ShouldNot(BeClosed()) 1114 Expect(done2).ToNot(BeClosed()) 1115 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 1116 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 1117 Expect(s.Close()).To(Succeed()) 1118 Eventually(done1).Should(BeClosed()) 1119 Eventually(done2).Should(BeClosed()) 1120 }) 1121 }) 1122 1123 Context("ServeListener", func() { 1124 origQuicListen := quicListen 1125 1126 AfterEach(func() { 1127 quicListen = origQuicListen 1128 }) 1129 1130 It("serves a listener", func() { 1131 var called int32 1132 ln := newMockAddrListener(":443") 1133 quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1134 atomic.StoreInt32(&called, 1) 1135 return ln, nil 1136 } 1137 1138 s := &Server{} 1139 1140 stopAccept := make(chan struct{}) 1141 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1142 <-stopAccept 1143 return nil, errors.New("closed") 1144 }) 1145 ln.EXPECT().Addr() // generate alt-svc headers 1146 done := make(chan struct{}) 1147 go func() { 1148 defer GinkgoRecover() 1149 defer close(done) 1150 s.ServeListener(ln) 1151 }() 1152 1153 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1154 Consistently(done).ShouldNot(BeClosed()) 1155 ln.EXPECT().Close().Do(func() { close(stopAccept) }) 1156 
Expect(s.Close()).To(Succeed()) 1157 Eventually(done).Should(BeClosed()) 1158 }) 1159 1160 It("serves two listeners", func() { 1161 var called int32 1162 ln1 := newMockAddrListener(":443") 1163 ln2 := newMockAddrListener(":8443") 1164 lns := make(chan QUICEarlyListener, 2) 1165 lns <- ln1 1166 lns <- ln2 1167 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1168 atomic.StoreInt32(&called, 1) 1169 return <-lns, nil 1170 } 1171 1172 s := &Server{} 1173 1174 stopAccept1 := make(chan struct{}) 1175 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1176 <-stopAccept1 1177 return nil, errors.New("closed") 1178 }) 1179 ln1.EXPECT().Addr() // generate alt-svc headers 1180 stopAccept2 := make(chan struct{}) 1181 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Connection, error) { 1182 <-stopAccept2 1183 return nil, errors.New("closed") 1184 }) 1185 ln2.EXPECT().Addr() 1186 1187 done1 := make(chan struct{}) 1188 go func() { 1189 defer GinkgoRecover() 1190 defer close(done1) 1191 s.ServeListener(ln1) 1192 }() 1193 done2 := make(chan struct{}) 1194 go func() { 1195 defer GinkgoRecover() 1196 defer close(done2) 1197 s.ServeListener(ln2) 1198 }() 1199 1200 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1201 Consistently(done1).ShouldNot(BeClosed()) 1202 Expect(done2).ToNot(BeClosed()) 1203 ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) 1204 ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) 1205 Expect(s.Close()).To(Succeed()) 1206 Eventually(done1).Should(BeClosed()) 1207 Eventually(done2).Should(BeClosed()) 1208 }) 1209 }) 1210 1211 Context("ServeQUICConn", func() { 1212 It("serves a QUIC connection", func() { 1213 mux := http.NewServeMux() 1214 mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { 1215 w.Write([]byte("foobar")) 1216 }) 1217 s.Handler = mux 1218 tlsConf := 
testdata.GetTLSConfig() 1219 tlsConf.NextProtos = []string{NextProtoH3} 1220 conn := mockquic.NewMockEarlyConnection(mockCtrl) 1221 controlStr := mockquic.NewMockStream(mockCtrl) 1222 controlStr.EXPECT().Write(gomock.Any()) 1223 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 1224 testDone := make(chan struct{}) 1225 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 1226 <-testDone 1227 return nil, errors.New("test done") 1228 }).MaxTimes(1) 1229 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}) 1230 s.ServeQUICConn(conn) 1231 close(testDone) 1232 }) 1233 }) 1234 1235 Context("ListenAndServe", func() { 1236 BeforeEach(func() { 1237 s.Addr = "localhost:0" 1238 }) 1239 1240 AfterEach(func() { 1241 Expect(s.Close()).To(Succeed()) 1242 }) 1243 1244 It("uses the quic.Config to start the QUIC server", func() { 1245 conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} 1246 var receivedConf *quic.Config 1247 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1248 receivedConf = config 1249 return nil, errors.New("listen err") 1250 } 1251 s.QuicConfig = conf 1252 Expect(s.ListenAndServe()).To(HaveOccurred()) 1253 Expect(receivedConf).To(Equal(conf)) 1254 }) 1255 }) 1256 1257 It("closes gracefully", func() { 1258 Expect(s.CloseGracefully(0)).To(Succeed()) 1259 }) 1260 1261 It("errors when listening fails", func() { 1262 testErr := errors.New("listen error") 1263 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1264 return nil, testErr 1265 } 1266 fullpem, privkey := testdata.GetCertificatePaths() 1267 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 1268 }) 1269 1270 It("supports H3_DATAGRAM", func() { 1271 s.EnableDatagrams = true 1272 var receivedConf *quic.Config 1273 quicListenAddr = 
func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1274 receivedConf = config 1275 return nil, errors.New("listen err") 1276 } 1277 Expect(s.ListenAndServe()).To(HaveOccurred()) 1278 Expect(receivedConf.EnableDatagrams).To(BeTrue()) 1279 }) 1280 })