github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "golang.org/x/exp/slog" 9 "io" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/metacubex/quic-go" 17 mockquic "github.com/metacubex/quic-go/internal/mocks/quic" 18 "github.com/metacubex/quic-go/internal/protocol" 19 "github.com/metacubex/quic-go/internal/testdata" 20 "github.com/metacubex/quic-go/quicvarint" 21 22 "github.com/quic-go/qpack" 23 "go.uber.org/mock/gomock" 24 25 . "github.com/onsi/ginkgo/v2" 26 . "github.com/onsi/gomega" 27 gmtypes "github.com/onsi/gomega/types" 28 ) 29 30 /* mockAddr is a minimal address value for tests: Network() always reports "udp" and String() returns the stored addr string unchanged. */ type mockAddr struct{ addr string } 31 32 func (ma *mockAddr) Network() string { return "udp" } 33 func (ma *mockAddr) String() string { return ma.addr } 34 35 /* mockAddrListener embeds a generated MockQUICEarlyListener but answers Addr() with its own fixed mockAddr. */ type mockAddrListener struct { 36 *MockQUICEarlyListener 37 addr *mockAddr 38 } 39 40 /* Addr still invokes the embedded mock's Addr (result discarded) before returning m.addr; presumably this lets gomock record the call against EXPECT().Addr() - confirm against the generated mock. */ func (m *mockAddrListener) Addr() net.Addr { 41 _ = m.MockQUICEarlyListener.Addr() 42 return m.addr 43 } 44 45 /* newMockAddrListener builds a mockAddrListener that reports addr, backed by a new mock created from mockCtrl (declared outside this view). */ func newMockAddrListener(addr string) *mockAddrListener { 46 return &mockAddrListener{ 47 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 48 addr: &mockAddr{addr: addr}, 49 } 50 } 51 52 /* noPortListener wraps a mockAddrListener but reports a unix-socket address instead of the wrapped listener's address. */ type noPortListener struct { 53 *mockAddrListener 54 } 55 56 /* Addr calls the wrapped listener's Addr (result discarded) and returns a fixed *net.UnixAddr naming "/tmp/quic.sock". */ func (m *noPortListener) Addr() net.Addr { 57 _ = m.mockAddrListener.Addr() 58 return &net.UnixAddr{ 59 Net: "unix", 60 Name: "/tmp/quic.sock", 61 } 62 } 63 64 var _ = Describe("Server", func() { 65 var ( 66 s *Server 67 origQuicListenAddr = quicListenAddr 68 ) 69 type testConnContextKey string 70 71 BeforeEach(func() { 72 s = &Server{ 73 TLSConfig: testdata.GetTLSConfig(), 74 ConnContext: func(ctx context.Context, c quic.Connection) context.Context { 75 return context.WithValue(ctx, testConnContextKey("test"), c) 76 }, 77 } 78 /* remember the current quicListenAddr so AfterEach can restore it */ origQuicListenAddr = quicListenAddr 79 }) 80 81 AfterEach(func() { 82 /* put back the quicListenAddr value saved in BeforeEach */ quicListenAddr = origQuicListenAddr 83 }) 84 85 Context("handling requests", 
func() { 86 var ( 87 qpackDecoder *qpack.Decoder 88 str *mockquic.MockStream 89 conn *connection 90 exampleGetRequest *http.Request 91 examplePostRequest *http.Request 92 ) 93 reqContext, reqContextCancel := context.WithCancel(context.Background()) 94 95 decodeHeader := func(str io.Reader) map[string][]string { 96 fields := make(map[string][]string) 97 decoder := qpack.NewDecoder(nil) 98 99 fp := frameParser{r: str} 100 frame, err := fp.ParseNext() 101 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 102 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 103 headersFrame := frame.(*headersFrame) 104 data := make([]byte, headersFrame.Length) 105 _, err = io.ReadFull(str, data) 106 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 107 hfs, err := decoder.DecodeFull(data) 108 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 109 for _, p := range hfs { 110 fields[p.Name] = append(fields[p.Name], p.Value) 111 } 112 return fields 113 } 114 115 encodeRequest := func(req *http.Request) []byte { 116 buf := &bytes.Buffer{} 117 str := mockquic.NewMockStream(mockCtrl) 118 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 119 rw := newRequestWriter() 120 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 121 return buf.Bytes() 122 } 123 124 setRequest := func(data []byte) { 125 buf := bytes.NewBuffer(data) 126 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 127 if buf.Len() == 0 { 128 return 0, io.EOF 129 } 130 return buf.Read(p) 131 }).AnyTimes() 132 } 133 134 BeforeEach(func() { 135 var err error 136 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 137 Expect(err).ToNot(HaveOccurred()) 138 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 139 Expect(err).ToNot(HaveOccurred()) 140 141 qpackDecoder = qpack.NewDecoder(nil) 142 str = mockquic.NewMockStream(mockCtrl) 143 str.EXPECT().Context().Return(reqContext).AnyTimes() 144 
str.EXPECT().StreamID().AnyTimes() 145 qconn := mockquic.NewMockEarlyConnection(mockCtrl) 146 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 147 qconn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 148 qconn.EXPECT().LocalAddr().AnyTimes() 149 qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 150 qconn.EXPECT().Context().Return(context.Background()).AnyTimes() 151 conn = newConnection(qconn, false, protocol.PerspectiveServer, nil) 152 }) 153 154 It("calls the HTTP handler function", func() { 155 requestChan := make(chan *http.Request, 1) 156 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 157 requestChan <- r 158 }) 159 160 setRequest(encodeRequest(exampleGetRequest)) 161 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 162 return len(p), nil 163 }).AnyTimes() 164 str.EXPECT().CancelRead(gomock.Any()) 165 str.EXPECT().Close() 166 167 s.handleRequest(conn, str, nil, qpackDecoder) 168 var req *http.Request 169 Eventually(requestChan).Should(Receive(&req)) 170 Expect(req.Host).To(Equal("www.example.com")) 171 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 172 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 173 Expect(req.Context().Value(testConnContextKey("test"))).To(Equal(conn.Connection)) 174 }) 175 176 It("returns 200 with an empty handler", func() { 177 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 178 179 responseBuf := &bytes.Buffer{} 180 setRequest(encodeRequest(exampleGetRequest)) 181 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 182 str.EXPECT().CancelRead(gomock.Any()) 183 str.EXPECT().Close() 184 185 s.handleRequest(conn, str, nil, qpackDecoder) 186 hfs := decodeHeader(responseBuf) 187 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 188 }) 189 190 It("sets Content-Length when the handler doesn't flush to the client", func() { 191 s.Handler = http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { 192 w.Write([]byte("foobar")) 193 }) 194 195 responseBuf := &bytes.Buffer{} 196 setRequest(encodeRequest(exampleGetRequest)) 197 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 198 str.EXPECT().CancelRead(gomock.Any()) 199 str.EXPECT().Close() 200 201 s.handleRequest(conn, str, nil, qpackDecoder) 202 hfs := decodeHeader(responseBuf) 203 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 204 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) 205 // status, content-length, date, content-type 206 Expect(hfs).To(HaveLen(4)) 207 }) 208 209 It("sets Content-Type when WriteHeader is called but response is not flushed", func() { 210 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 211 w.WriteHeader(http.StatusNotFound) 212 w.Write([]byte("<html></html>")) 213 }) 214 215 responseBuf := &bytes.Buffer{} 216 setRequest(encodeRequest(exampleGetRequest)) 217 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 218 str.EXPECT().CancelRead(gomock.Any()) 219 str.EXPECT().Close() 220 221 s.handleRequest(conn, str, nil, qpackDecoder) 222 hfs := decodeHeader(responseBuf) 223 Expect(hfs).To(HaveKeyWithValue(":status", []string{"404"})) 224 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) 225 Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) 226 }) 227 228 It("not sets Content-Length when the handler flushes to the client", func() { 229 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 230 w.Write([]byte("foobar")) 231 // force flush 232 w.(http.Flusher).Flush() 233 }) 234 235 responseBuf := &bytes.Buffer{} 236 setRequest(encodeRequest(exampleGetRequest)) 237 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 238 str.EXPECT().CancelRead(gomock.Any()) 239 str.EXPECT().Close() 240 241 s.handleRequest(conn, str, nil, qpackDecoder) 242 hfs := 
decodeHeader(responseBuf) 243 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 244 // status, date, content-type 245 Expect(hfs).To(HaveLen(3)) 246 }) 247 248 It("ignores calls to Write for responses to HEAD requests", func() { 249 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 250 w.Write([]byte("foobar")) 251 }) 252 253 headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil) 254 Expect(err).ToNot(HaveOccurred()) 255 responseBuf := &bytes.Buffer{} 256 setRequest(encodeRequest(headRequest)) 257 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 258 str.EXPECT().CancelRead(gomock.Any()) 259 str.EXPECT().Close() 260 261 s.handleRequest(conn, str, nil, qpackDecoder) 262 hfs := decodeHeader(responseBuf) 263 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 264 Expect(responseBuf.Bytes()).To(BeEmpty()) 265 }) 266 267 It("response to HEAD request should also do content sniffing", func() { 268 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 269 w.Write([]byte("<html></html>")) 270 }) 271 272 headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil) 273 Expect(err).ToNot(HaveOccurred()) 274 responseBuf := &bytes.Buffer{} 275 setRequest(encodeRequest(headRequest)) 276 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 277 str.EXPECT().CancelRead(gomock.Any()) 278 str.EXPECT().Close() 279 280 s.handleRequest(conn, str, nil, qpackDecoder) 281 hfs := decodeHeader(responseBuf) 282 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 283 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) 284 Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) 285 }) 286 287 It("handles an aborting handler", func() { 288 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 289 panic(http.ErrAbortHandler) 290 }) 291 292 responseBuf := 
&bytes.Buffer{} 293 setRequest(encodeRequest(exampleGetRequest)) 294 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 295 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) 296 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) 297 298 s.handleRequest(conn, str, nil, qpackDecoder) 299 Expect(responseBuf.Bytes()).To(HaveLen(0)) 300 }) 301 302 It("handles a panicking handler", func() { 303 var logBuf bytes.Buffer 304 s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil)) 305 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 306 panic("foobar") 307 }) 308 309 responseBuf := &bytes.Buffer{} 310 setRequest(encodeRequest(exampleGetRequest)) 311 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 312 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) 313 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) 314 315 s.handleRequest(conn, str, nil, qpackDecoder) 316 Expect(responseBuf.Bytes()).To(HaveLen(0)) 317 Expect(logBuf.String()).To(ContainSubstring("http: panic serving")) 318 Expect(logBuf.String()).To(ContainSubstring("foobar")) 319 }) 320 321 Context("hijacking bidirectional streams", func() { 322 var conn *mockquic.MockEarlyConnection 323 testDone := make(chan struct{}) 324 325 BeforeEach(func() { 326 testDone = make(chan struct{}) 327 conn = mockquic.NewMockEarlyConnection(mockCtrl) 328 controlStr := mockquic.NewMockStream(mockCtrl) 329 controlStr.EXPECT().Write(gomock.Any()) 330 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 331 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 332 conn.EXPECT().LocalAddr().AnyTimes() 333 }) 334 335 AfterEach(func() { testDone <- struct{}{} }) 336 337 It("hijacks a bidirectional stream of unknown frame type", func() { 338 id := quic.ConnectionTracingID(1337) 339 frameTypeChan := make(chan FrameType, 1) 340 s.StreamHijacker = func(ft 
FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 341 defer GinkgoRecover() 342 Expect(e).ToNot(HaveOccurred()) 343 Expect(connTracingID).To(Equal(id)) 344 frameTypeChan <- ft 345 return true, nil 346 } 347 348 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 349 unknownStr := mockquic.NewMockStream(mockCtrl) 350 unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() 351 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 352 unknownStr.EXPECT().StreamID().AnyTimes() 353 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 354 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 355 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 356 <-testDone 357 return nil, errors.New("test done") 358 }) 359 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 360 conn.EXPECT().Context().Return(ctx).AnyTimes() 361 s.handleConn(conn) 362 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 363 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 364 }) 365 366 It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { 367 frameTypeChan := make(chan FrameType, 1) 368 s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 369 Expect(e).ToNot(HaveOccurred()) 370 frameTypeChan <- ft 371 return false, nil 372 } 373 374 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 375 unknownStr := mockquic.NewMockStream(mockCtrl) 376 unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() 377 unknownStr.EXPECT().StreamID().AnyTimes() 378 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 379 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 380 
unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 381 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 382 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 383 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 384 <-testDone 385 return nil, errors.New("test done") 386 }) 387 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 388 conn.EXPECT().Context().Return(ctx).AnyTimes() 389 s.handleConn(conn) 390 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 391 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 392 }) 393 394 It("cancels writing when hijacker returned error", func() { 395 frameTypeChan := make(chan FrameType, 1) 396 s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 397 Expect(e).ToNot(HaveOccurred()) 398 frameTypeChan <- ft 399 return false, errors.New("error in hijacker") 400 } 401 402 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 403 unknownStr := mockquic.NewMockStream(mockCtrl) 404 unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() 405 unknownStr.EXPECT().StreamID().AnyTimes() 406 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 407 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 408 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 409 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 410 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 411 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 412 <-testDone 413 return nil, errors.New("test done") 414 }) 415 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, 
quic.ConnectionTracingID(1234)) 416 conn.EXPECT().Context().Return(ctx).AnyTimes() 417 s.handleConn(conn) 418 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 419 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 420 }) 421 422 It("handles errors that occur when reading the stream type", func() { 423 const strID = protocol.StreamID(1234 * 4) 424 testErr := errors.New("test error") 425 done := make(chan struct{}) 426 s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) { 427 defer close(done) 428 Expect(ft).To(BeZero()) 429 Expect(str.StreamID()).To(Equal(strID)) 430 Expect(err).To(MatchError(testErr)) 431 return true, nil 432 } 433 unknownStr := mockquic.NewMockStream(mockCtrl) 434 unknownStr.EXPECT().Context().Return(context.Background()).AnyTimes() 435 unknownStr.EXPECT().StreamID().Return(strID).AnyTimes() 436 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 437 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 438 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 439 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 440 <-testDone 441 return nil, errors.New("test done") 442 }) 443 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 444 conn.EXPECT().Context().Return(ctx).AnyTimes() 445 s.handleConn(conn) 446 Eventually(done).Should(BeClosed()) 447 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 448 }) 449 }) 450 451 Context("hijacking unidirectional streams", func() { 452 var conn *mockquic.MockEarlyConnection 453 testDone := make(chan struct{}) 454 455 BeforeEach(func() { 456 testDone = make(chan struct{}) 457 conn = mockquic.NewMockEarlyConnection(mockCtrl) 458 controlStr := mockquic.NewMockStream(mockCtrl) 459 
controlStr.EXPECT().Write(gomock.Any()) 460 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 461 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 462 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 463 conn.EXPECT().LocalAddr().AnyTimes() 464 }) 465 466 AfterEach(func() { testDone <- struct{}{} }) 467 468 It("hijacks an unidirectional stream of unknown stream type", func() { 469 id := quic.ConnectionTracingID(42) 470 streamTypeChan := make(chan StreamType, 1) 471 s.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 472 Expect(err).ToNot(HaveOccurred()) 473 Expect(connTracingID).To(Equal(id)) 474 streamTypeChan <- st 475 return true 476 } 477 478 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 479 unknownStr := mockquic.NewMockStream(mockCtrl) 480 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 481 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 482 return unknownStr, nil 483 }) 484 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 485 <-testDone 486 return nil, errors.New("test done") 487 }) 488 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 489 conn.EXPECT().Context().Return(ctx).AnyTimes() 490 s.handleConn(conn) 491 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 492 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 493 }) 494 495 It("handles errors that occur when reading the stream type", func() { 496 testErr := errors.New("test error") 497 done := make(chan struct{}) 498 unknownStr := mockquic.NewMockStream(mockCtrl) 499 s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool { 500 defer close(done) 501 
Expect(st).To(BeZero()) 502 Expect(str).To(Equal(unknownStr)) 503 Expect(err).To(MatchError(testErr)) 504 return true 505 } 506 507 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 508 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 509 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 510 <-testDone 511 return nil, errors.New("test done") 512 }) 513 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 514 conn.EXPECT().Context().Return(ctx).AnyTimes() 515 s.handleConn(conn) 516 Eventually(done).Should(BeClosed()) 517 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 518 }) 519 520 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 521 streamTypeChan := make(chan StreamType, 1) 522 s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 523 Expect(err).ToNot(HaveOccurred()) 524 streamTypeChan <- st 525 return false 526 } 527 528 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 529 unknownStr := mockquic.NewMockStream(mockCtrl) 530 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 531 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 532 533 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 534 return unknownStr, nil 535 }) 536 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 537 <-testDone 538 return nil, errors.New("test done") 539 }) 540 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 541 conn.EXPECT().Context().Return(ctx).AnyTimes() 542 s.handleConn(conn) 543 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 
544 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 545 }) 546 }) 547 548 Context("stream- and connection-level errors", func() { 549 var conn *mockquic.MockEarlyConnection 550 testDone := make(chan struct{}) 551 552 BeforeEach(func() { 553 testDone = make(chan struct{}) 554 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 555 conn = mockquic.NewMockEarlyConnection(mockCtrl) 556 controlStr := mockquic.NewMockStream(mockCtrl) 557 controlStr.EXPECT().Write(gomock.Any()) 558 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 559 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 560 <-testDone 561 return nil, errors.New("test done") 562 }) 563 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 564 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 565 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 566 conn.EXPECT().LocalAddr().AnyTimes() 567 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 568 conn.EXPECT().Context().Return(context.Background()).AnyTimes() 569 }) 570 571 AfterEach(func() { testDone <- struct{}{} }) 572 573 It("cancels reading when client sends a body in GET request", func() { 574 var handlerCalled bool 575 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 576 handlerCalled = true 577 }) 578 579 requestData := encodeRequest(exampleGetRequest) 580 b := (&dataFrame{Length: 6}).Append(nil) // add a body 581 b = append(b, []byte("foobar")...) 
582 responseBuf := &bytes.Buffer{} 583 setRequest(append(requestData, b...)) 584 done := make(chan struct{}) 585 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 586 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 587 str.EXPECT().Close().Do(func() error { close(done); return nil }) 588 589 s.handleConn(conn) 590 Eventually(done).Should(BeClosed()) 591 hfs := decodeHeader(responseBuf) 592 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 593 Expect(handlerCalled).To(BeTrue()) 594 }) 595 596 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 597 handlerCalled := make(chan struct{}) 598 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 599 defer close(handlerCalled) 600 w.(HTTPStreamer).HTTPStream() 601 str.Write([]byte("foobar")) 602 }) 603 604 requestData := encodeRequest(exampleGetRequest) 605 b := (&dataFrame{Length: 6}).Append(nil) // add a body 606 b = append(b, []byte("foobar")...) 607 setRequest(append(requestData, b...)) 608 var buf bytes.Buffer 609 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 610 611 s.handleConn(conn) 612 Eventually(handlerCalled).Should(BeClosed()) 613 614 // The buffer is expected to contain: 615 // 1. The response header (in a HEADERS frame) 616 // 2. 
the "foobar" (unframed) 617 fp := frameParser{r: &buf} 618 frame, err := fp.ParseNext() 619 Expect(err).ToNot(HaveOccurred()) 620 Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) 621 df := frame.(*headersFrame) 622 data := make([]byte, df.Length) 623 _, err = io.ReadFull(&buf, data) 624 Expect(err).ToNot(HaveOccurred()) 625 hdrs, err := qpackDecoder.DecodeFull(data) 626 Expect(err).ToNot(HaveOccurred()) 627 Expect(hdrs).To(ContainElement(qpack.HeaderField{Name: ":status", Value: "200"})) 628 Expect(buf.Bytes()).To(Equal([]byte("foobar"))) 629 }) 630 631 It("errors when the client sends a too large header frame", func() { 632 s.MaxHeaderBytes = 20 633 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 634 Fail("Handler should not be called.") 635 }) 636 637 requestData := encodeRequest(exampleGetRequest) 638 b := (&dataFrame{Length: 6}).Append(nil) // add a body 639 b = append(b, []byte("foobar")...) 640 responseBuf := &bytes.Buffer{} 641 setRequest(append(requestData, b...)) 642 done := make(chan struct{}) 643 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 644 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) 645 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 646 647 s.handleConn(conn) 648 Eventually(done).Should(BeClosed()) 649 }) 650 651 It("handles a request for which the client immediately resets the stream", func() { 652 handlerCalled := make(chan struct{}) 653 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 654 close(handlerCalled) 655 }) 656 657 testErr := errors.New("stream reset") 658 done := make(chan struct{}) 659 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 660 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 661 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 662 663 s.handleConn(conn) 
664 Consistently(handlerCalled).ShouldNot(BeClosed()) 665 }) 666 667 It("closes the connection when the first frame is not a HEADERS frame", func() { 668 handlerCalled := make(chan struct{}) 669 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 670 close(handlerCalled) 671 }) 672 673 b := (&dataFrame{}).Append(nil) 674 setRequest(b) 675 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 676 return len(p), nil 677 }).AnyTimes() 678 679 done := make(chan struct{}) 680 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 681 close(done) 682 return nil 683 }) 684 s.handleConn(conn) 685 Eventually(done).Should(BeClosed()) 686 }) 687 688 It("rejects a request that has too large request headers", func() { 689 handlerCalled := make(chan struct{}) 690 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 691 close(handlerCalled) 692 }) 693 694 // use 2*DefaultMaxHeaderBytes here. qpack will compress the request, 695 // but the request will still end up larger than DefaultMaxHeaderBytes. 
696 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 697 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 698 Expect(err).ToNot(HaveOccurred()) 699 setRequest(encodeRequest(req)) 700 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 701 return len(p), nil 702 }).AnyTimes() 703 done := make(chan struct{}) 704 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) 705 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 706 707 s.handleConn(conn) 708 Eventually(done).Should(BeClosed()) 709 }) 710 }) 711 712 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 713 handlerCalled := make(chan struct{}) 714 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 715 r.Body = struct { 716 io.Reader 717 io.Closer 718 }{} 719 close(handlerCalled) 720 }) 721 722 setRequest(encodeRequest(examplePostRequest)) 723 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 724 return len(p), nil 725 }).AnyTimes() 726 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 727 str.EXPECT().Close() 728 729 s.handleRequest(conn, str, nil, qpackDecoder) 730 Eventually(handlerCalled).Should(BeClosed()) 731 }) 732 733 It("cancels the request context when the stream is closed", func() { 734 handlerCalled := make(chan struct{}) 735 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 736 defer GinkgoRecover() 737 Expect(r.Context().Done()).To(BeClosed()) 738 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 739 close(handlerCalled) 740 }) 741 setRequest(encodeRequest(examplePostRequest)) 742 743 reqContextCancel() 744 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 745 return len(p), nil 746 }).AnyTimes() 747 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 748 
str.EXPECT().Close()

			s.handleRequest(conn, str, nil, qpackDecoder)
			Eventually(handlerCalled).Should(BeClosed())
		})
	})

	// Specs for the Alt-Svc header written by Server.SetQUICHeaders.
	Context("setting http headers", func() {
		// Each spec starts from a QUIC config that lists only version 1.
		BeforeEach(func() {
			s.QUICConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}}
		})

		var (
			// Listener slots shared by the specs below; the server is handed
			// pointers to these variables.
			first, second QUICEarlyListener

			// defaultAltSvc is the header set expected for one listener on port 443.
			defaultAltSvc = http.Header{
				"Alt-Svc": {`h3=":443"; ma=2592000`},
			}
		)

		// register creates a mock listener reporting addr, stores it in *slot
		// and adds the slot to the server. One Addr() call is expected on the mock.
		register := func(addr string, slot *QUICEarlyListener) {
			mock := newMockAddrListener(addr)
			mock.EXPECT().Addr()
			*slot = mock
			s.addListener(slot)
		}

		// expectHeaders asserts that SetQUICHeaders succeeds on an empty header
		// map and that the resulting map satisfies matcher.
		expectHeaders := func(matcher gmtypes.GomegaMatcher) {
			hdr := make(http.Header)
			Expect(s.SetQUICHeaders(hdr)).To(Succeed())
			Expect(hdr).To(matcher)
		}

		// expectNoPort asserts that SetQUICHeaders reports ErrNoAltSvcPort.
		expectNoPort := func() {
			hdr := make(http.Header)
			Expect(s.SetQUICHeaders(hdr)).To(Equal(ErrNoAltSvcPort))
		}

		It("sets proper headers with numeric port", func() {
			register(":443", &first)
			expectHeaders(Equal(defaultAltSvc))
			s.removeListener(&first)
			expectNoPort()
		})

		It("sets proper headers with full addr", func() {
			register("127.0.0.1:443", &first)
			expectHeaders(Equal(defaultAltSvc))
			s.removeListener(&first)
			expectNoPort()
		})

		It("sets proper headers with string port", func() {
			register(":https", &first)
			expectHeaders(Equal(defaultAltSvc))
			s.removeListener(&first)
			expectNoPort()
		})

		It("works multiple times", func() {
			register(":https", &first)
			// Two consecutive calls must yield the same header.
			expectHeaders(Equal(defaultAltSvc))
			expectHeaders(Equal(defaultAltSvc))
			s.removeListener(&first)
			expectNoPort()
		})

		It("works if the quic.Config sets QUIC versions", func() {
			s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version2}
			register(":443", &first)
			want := http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}
			expectHeaders(Equal(want))
			s.removeListener(&first)
			expectNoPort()
		})

		It("uses s.Port if set to a non-zero value", func() {
			s.Port = 8443
			register(":443", &first)
			want := http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}}
			expectHeaders(Equal(want))
			s.removeListener(&first)
			expectNoPort()
		})

		It("uses s.Addr if listeners don't have ports available", func() {
			s.Addr = ":443"
			// Capture log output so the warning can be checked at the end.
			var logged bytes.Buffer
			s.Logger = slog.New(slog.NewTextHandler(&logged, nil))
			// noPortListener reports a unix socket address, which has no port.
			portless := &noPortListener{newMockAddrListener("")}
			portless.EXPECT().Addr()
			first = portless
			s.addListener(&first)
			expectHeaders(Equal(defaultAltSvc))
			s.removeListener(&first)
			expectNoPort()
			Expect(logged.String()).To(ContainSubstring("Unable to extract port from listener, will not be announced using SetQUICHeaders"))
		})

		It("properly announces multiple listeners", func() {
			register(":443", &first)
			register(":8443", &second)
			// Either ordering of the two entries is accepted.
			lowFirst := Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}})
			highFirst := Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}})
			expectHeaders(Or(lowFirst, highFirst))
			s.removeListener(&first)
			s.removeListener(&second)
			expectNoPort()
		})

		It("doesn't duplicate Alt-Svc values", func() {
			// Version 1 is listed twice on purpose.
			s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version1}
			register(":443", &first)
			want := http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}}
			expectHeaders(Equal(want))
			s.removeListener(&first)
			expectNoPort()
		})
	})

	It("errors when ListenAndServe is called with s.TLSConfig nil", func() {
		bare := &Server{}
		Expect(bare.ListenAndServe()).To(MatchError(errServerWithoutTLSConfig))
	})

	It("should nop-Close() when s.server is nil", func() {
		bare := &Server{}
		Expect(bare.Close()).To(Succeed())
	})

	It("errors when ListenAndServeTLS is called after Close", func() {
		closed := &Server{}
		Expect(closed.Close()).To(Succeed())
		Expect(closed.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed))
	})

	It("handles concurrent Serve and Close", func() {
		udpAddr, err := net.ResolveUDPAddr("udp", "localhost:0")
		Expect(err).ToNot(HaveOccurred())
		pconn, err := net.ListenUDP("udp", udpAddr)
		Expect(err).ToNot(HaveOccurred())

		served := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			defer close(served)
			s.Serve(pconn)
		}()
		// Yield so the Serve goroutine gets a chance to run before Close.
		runtime.Gosched()
		s.Close()
		Eventually(served).Should(BeClosed())
	})

	// Specs that run a real QUIC listener/dialer pair against the TLS config
	// produced by ConfigureTLSConfig and check the negotiated ALPN.
	Context("ConfigureTLSConfig", func() {
		It("advertises v1 by default", func() {
			serverTLS := ConfigureTLSConfig(testdata.GetTLSConfig())
			server, err := quic.ListenAddr("localhost:0", serverTLS, &quic.Config{Versions: []quic.Version{quic.Version1}})
			Expect(err).ToNot(HaveOccurred())
			defer server.Close()

			clientTLS := &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}
			client, err := quic.DialAddr(context.Background(), server.Addr().String(), clientTLS, nil)
			Expect(err).ToNot(HaveOccurred())
			defer client.CloseWithError(0, "")

			Expect(client.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
		})

		It("sets the GetConfigForClient callback if no tls.Config is given", func() {
			var seen *tls.Config
			// The stub records the TLS config it is given and fails the listen.
			quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
				seen = tlsConf
				return nil, errors.New("listen err")
			}
			Expect(s.ListenAndServe()).To(HaveOccurred())
			Expect(seen).ToNot(BeNil())
		})

		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() {
			// The callback builds a fresh config with unrelated ALPN values on each call.
			perClient := func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
				fresh := testdata.GetTLSConfig()
				fresh.NextProtos = []string{"foo", "bar"}
				return fresh, nil
			}
			serverTLS := &tls.Config{GetConfigForClient: perClient}

			server, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(serverTLS), &quic.Config{Versions: []quic.Version{quic.Version1}})
			Expect(err).ToNot(HaveOccurred())
			defer server.Close()

			clientTLS := &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}
			client, err := quic.DialAddr(context.Background(), server.Addr().String(), clientTLS, nil)
			Expect(err).ToNot(HaveOccurred())
			defer client.CloseWithError(0, "")

			Expect(client.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
		})

		It("works if GetConfigForClient returns a nil tls.Config", func() {
			serverTLS := testdata.GetTLSConfig()
			serverTLS.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }

			server, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(serverTLS), &quic.Config{Versions: []quic.Version{quic.Version1}})
			Expect(err).ToNot(HaveOccurred())
			defer server.Close()

			clientTLS := &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}
			client, err := quic.DialAddr(context.Background(), server.Addr().String(), clientTLS, nil)
			Expect(err).ToNot(HaveOccurred())
			defer client.CloseWithError(0, "")

			Expect(client.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
		})

		It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() {
			// The same config instance is returned for every client hello.
			shared := testdata.GetTLSConfig()
			shared.NextProtos = []string{"foo", "bar"}
			serverTLS := &tls.Config{
				GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
					return shared, nil
				},
			}

			server, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(serverTLS), &quic.Config{Versions: []quic.Version{quic.Version1}})
			Expect(err).ToNot(HaveOccurred())
			defer server.Close()

			clientTLS := &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}
			client, err := quic.DialAddr(context.Background(), server.Addr().String(), clientTLS, nil)
			Expect(err).ToNot(HaveOccurred())
			defer client.CloseWithError(0, "")

			Expect(client.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3))
			// check that the original config was not modified
			Expect(shared.NextProtos).To(Equal([]string{"foo", "bar"}))
		})
	})

	// Specs for Server.Serve with the quicListen hook replaced by a stub.
	Context("Serve", func() {
		savedQuicListen := quicListen

		// Put the real hook back after every spec.
		AfterEach(func() {
			quicListen = savedQuicListen
		})

		It("serves a packet conn", func() {
			listener := newMockAddrListener(":443")
			pconn := &net.UDPConn{}
			quicListen = func(pc net.PacketConn, _ *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
				Expect(pc).To(Equal(pconn))
				return listener, nil
			}

			srv := &Server{
				TLSConfig: &tls.Config{},
			}

			// Accept blocks until the listener's Close expectation fires below.
			release := make(chan struct{})
			listener.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-release
				return nil, errors.New("closed")
			})
			listener.EXPECT().Addr() // generate alt-svc headers

			served := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(served)
				srv.Serve(pconn)
			}()

			Consistently(served).ShouldNot(BeClosed())
			listener.EXPECT().Close().Do(func() error { close(release); return nil })
			Expect(srv.Close()).To(Succeed())
			Eventually(served).Should(BeClosed())
		})

		It("serves two packet conns", func() {
			listenerA := newMockAddrListener(":443")
			listenerB := newMockAddrListener(":8443")
			// The stub hands out the two mock listeners in FIFO order.
			pending := make(chan QUICEarlyListener, 2)
			pending <- listenerA
			pending <- listenerB
			pconnA := &net.UDPConn{}
			pconnB := &net.UDPConn{}
			quicListen = func(_ net.PacketConn, _ *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
				return <-pending, nil
			}

			srv := &Server{
				TLSConfig: &tls.Config{},
			}

			releaseA := make(chan struct{})
			listenerA.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-releaseA
				return nil, errors.New("closed")
			})
			listenerA.EXPECT().Addr() // generate alt-svc headers
			releaseB := make(chan struct{})
			listenerB.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-releaseB
				return nil, errors.New("closed")
			})
			listenerB.EXPECT().Addr()

			servedA := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(servedA)
				srv.Serve(pconnA)
			}()
			servedB := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(servedB)
				srv.Serve(pconnB)
			}()

			Consistently(servedA).ShouldNot(BeClosed())
			Expect(servedB).ToNot(BeClosed())
			listenerA.EXPECT().Close().Do(func() error { close(releaseA); return nil })
			listenerB.EXPECT().Close().Do(func() error { close(releaseB); return nil })
			Expect(srv.Close()).To(Succeed())
			Eventually(servedA).Should(BeClosed())
			Eventually(servedB).Should(BeClosed())
		})
	})

	// Specs for Server.ServeListener. The quicListen stub only records whether
	// it was invoked; the specs assert that it never is.
	Context("ServeListener", func() {
		savedQuicListen := quicListen

		// Put the real hook back after every spec.
		AfterEach(func() {
			quicListen = savedQuicListen
		})

		It("serves a listener", func() {
			var listenCalled int32
			listener := newMockAddrListener(":443")
			quicListen = func(_ net.PacketConn, _ *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
				atomic.StoreInt32(&listenCalled, 1)
				return listener, nil
			}

			srv := &Server{}

			// Accept blocks until the listener's Close expectation fires below.
			release := make(chan struct{})
			listener.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-release
				return nil, errors.New("closed")
			})
			listener.EXPECT().Addr() // generate alt-svc headers

			served := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(served)
				srv.ServeListener(listener)
			}()

			Consistently(func() int32 { return atomic.LoadInt32(&listenCalled) }).Should(Equal(int32(0)))
			Consistently(served).ShouldNot(BeClosed())
			listener.EXPECT().Close().Do(func() error { close(release); return nil })
			Expect(srv.Close()).To(Succeed())
			Eventually(served).Should(BeClosed())
		})

		It("serves two listeners", func() {
			var listenCalled int32
			listenerA := newMockAddrListener(":443")
			listenerB := newMockAddrListener(":8443")
			pending := make(chan QUICEarlyListener, 2)
			pending <- listenerA
			pending <- listenerB
			quicListen = func(_ net.PacketConn, _ *tls.Config, _ *quic.Config) (QUICEarlyListener, error) {
				atomic.StoreInt32(&listenCalled, 1)
				return <-pending, nil
			}

			srv := &Server{}

			releaseA := make(chan struct{})
			listenerA.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-releaseA
				return nil, errors.New("closed")
			})
			listenerA.EXPECT().Addr() // generate alt-svc headers
			releaseB := make(chan struct{})
			listenerB.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) {
				<-releaseB
				return nil, errors.New("closed")
			})
			listenerB.EXPECT().Addr()

			servedA := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(servedA)
				srv.ServeListener(listenerA)
			}()
			servedB := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(servedB)
				srv.ServeListener(listenerB)
			}()

			Consistently(func() int32 { return atomic.LoadInt32(&listenCalled) }).Should(Equal(int32(0)))
			Consistently(servedA).ShouldNot(BeClosed())
			Expect(servedB).ToNot(BeClosed())
			listenerA.EXPECT().Close().Do(func() error { close(releaseA); return nil })
			listenerB.EXPECT().Close().Do(func() error { close(releaseB); return nil })
			Expect(srv.Close()).To(Succeed())
			Eventually(servedA).Should(BeClosed())
			Eventually(servedB).Should(BeClosed())
		})
	})

	Context("ServeQUICConn", func() {
		It("serves a QUIC connection", func() {
			router := http.NewServeMux()
			router.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) {
				w.Write([]byte("foobar"))
			})
			s.Handler = router

			tlsConf := testdata.GetTLSConfig()
			tlsConf.NextProtos = []string{NextProtoH3}

			qconn := mockquic.NewMockEarlyConnection(mockCtrl)
			control := mockquic.NewMockStream(mockCtrl)
			control.EXPECT().Write(gomock.Any())
			qconn.EXPECT().OpenUniStream().Return(control, nil)

			// AcceptUniStream (called at most once) blocks until the spec is done.
			finished := make(chan struct{})
			qconn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
				<-finished
				return nil, errors.New("test done")
			}).MaxTimes(1)
			// AcceptStream fails immediately with an application error carrying ErrCodeNoError.
			qconn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)})

			s.ServeQUICConn(qconn)
			close(finished)
		})
	})

	Context("ListenAndServe", func() {
		BeforeEach(func() {
			s.Addr = "localhost:0"
		})

		AfterEach(func() {
			Expect(s.Close()).To(Succeed())
		})

		It("uses the quic.Config to start the QUIC server", func() {
			want := &quic.Config{HandshakeIdleTimeout: time.Nanosecond}
			s.QUICConfig = want

			var got *quic.Config
			// The stub records the QUIC config it is given and fails the listen.
			quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
				got = config
				return nil, errors.New("listen err")
			}

			Expect(s.ListenAndServe()).To(HaveOccurred())
			Expect(got).To(Equal(want))
		})
	})

	It("closes gracefully", func() {
		Expect(s.CloseGracefully(0)).To(Succeed())
	})

	It("errors when listening fails", func() {
		listenErr := errors.New("listen error")
		quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
			return nil, listenErr
		}
		certFile, keyFile := testdata.GetCertificatePaths()
		Expect(ListenAndServeQUIC("", certFile, keyFile, nil)).To(MatchError(listenErr))
	})

	It("supports H3_DATAGRAM", func() {
		s.EnableDatagrams = true

		var got *quic.Config
		// The stub records the QUIC config it is given and fails the listen.
		quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) {
			got = config
			return nil, errors.New("listen err")
		}

		Expect(s.ListenAndServe()).To(HaveOccurred())
		Expect(got.EnableDatagrams).To(BeTrue())
	})
})