github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "io" 9 "log/slog" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/apernet/quic-go" 17 mockquic "github.com/apernet/quic-go/internal/mocks/quic" 18 "github.com/apernet/quic-go/internal/protocol" 19 "github.com/apernet/quic-go/internal/testdata" 20 "github.com/apernet/quic-go/quicvarint" 21 22 "github.com/quic-go/qpack" 23 "go.uber.org/mock/gomock" 24 25 . "github.com/onsi/ginkgo/v2" 26 . "github.com/onsi/gomega" 27 gmtypes "github.com/onsi/gomega/types" 28 ) 29 30 type mockAddr struct{ addr string } 31 32 func (ma *mockAddr) Network() string { return "udp" } 33 func (ma *mockAddr) String() string { return ma.addr } 34 35 type mockAddrListener struct { 36 *MockQUICEarlyListener 37 addr *mockAddr 38 } 39 40 func (m *mockAddrListener) Addr() net.Addr { 41 _ = m.MockQUICEarlyListener.Addr() 42 return m.addr 43 } 44 45 func newMockAddrListener(addr string) *mockAddrListener { 46 return &mockAddrListener{ 47 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 48 addr: &mockAddr{addr: addr}, 49 } 50 } 51 52 type noPortListener struct { 53 *mockAddrListener 54 } 55 56 func (m *noPortListener) Addr() net.Addr { 57 _ = m.mockAddrListener.Addr() 58 return &net.UnixAddr{ 59 Net: "unix", 60 Name: "/tmp/quic.sock", 61 } 62 } 63 64 var _ = Describe("Server", func() { 65 var ( 66 s *Server 67 origQuicListenAddr = quicListenAddr 68 ) 69 type testConnContextKey string 70 71 BeforeEach(func() { 72 s = &Server{ 73 TLSConfig: testdata.GetTLSConfig(), 74 ConnContext: func(ctx context.Context, c quic.Connection) context.Context { 75 return context.WithValue(ctx, testConnContextKey("test"), c) 76 }, 77 } 78 origQuicListenAddr = quicListenAddr 79 }) 80 81 AfterEach(func() { 82 quicListenAddr = origQuicListenAddr 83 }) 84 85 Context("handling requests", func() { 86 var ( 87 qpackDecoder *qpack.Decoder 88 str *mockquic.MockStream 89 conn *connection 90 exampleGetRequest *http.Request 91 examplePostRequest *http.Request 92 ) 93 reqContext := context.Background() 94 95 decodeHeader := func(str io.Reader) map[string][]string { 96 fields := make(map[string][]string) 97 decoder := qpack.NewDecoder(nil) 98 99 frame, err := parseNextFrame(str, nil) 100 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 101 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 102 headersFrame := frame.(*headersFrame) 103 data := make([]byte, headersFrame.Length) 104 _, err = io.ReadFull(str, data) 105 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 106 hfs, err := decoder.DecodeFull(data) 107 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 108 for _, p := range hfs { 109 fields[p.Name] = append(fields[p.Name], p.Value) 110 } 111 return fields 112 } 113 114 encodeRequest := func(req *http.Request) []byte { 115 buf := &bytes.Buffer{} 116 str := mockquic.NewMockStream(mockCtrl) 117 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 118 rw := newRequestWriter() 119 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 120 return buf.Bytes() 121 } 122 123 setRequest := func(data []byte) { 124 buf := bytes.NewBuffer(data) 125 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 126 if buf.Len() == 0 { 127 return 0, io.EOF 128 } 129 return buf.Read(p) 130 }).AnyTimes() 131 } 132 133 BeforeEach(func() { 134 var err error 135 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 136 Expect(err).ToNot(HaveOccurred()) 137 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 138 Expect(err).ToNot(HaveOccurred()) 139 140 qpackDecoder = qpack.NewDecoder(nil) 141 str = mockquic.NewMockStream(mockCtrl) 142 str.EXPECT().StreamID().AnyTimes() 143 qconn := mockquic.NewMockEarlyConnection(mockCtrl) 144 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 145 qconn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 146 qconn.EXPECT().LocalAddr().AnyTimes() 147 qconn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 148 qconn.EXPECT().Context().Return(context.Background()).AnyTimes() 149 conn = newConnection(qconn, false, protocol.PerspectiveServer, nil) 150 }) 151 152 It("calls the HTTP handler function", func() { 153 requestChan := make(chan *http.Request, 1) 154 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 155 requestChan <- r 156 }) 157 158 setRequest(encodeRequest(exampleGetRequest)) 159 str.EXPECT().Context().Return(reqContext) 160 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 161 return len(p), nil 162 }).AnyTimes() 163 str.EXPECT().CancelRead(gomock.Any()) 164 str.EXPECT().Close() 165 166 s.handleRequest(conn, str, nil, qpackDecoder) 167 var req *http.Request 168 Eventually(requestChan).Should(Receive(&req)) 169 Expect(req.Host).To(Equal("www.example.com")) 170 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 171 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 172 Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil)) 173 }) 174 175 It("returns 200 with an empty handler", func() { 176 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 177 178 responseBuf := &bytes.Buffer{} 179 setRequest(encodeRequest(exampleGetRequest)) 180 str.EXPECT().Context().Return(reqContext) 181 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 182 str.EXPECT().CancelRead(gomock.Any()) 183 str.EXPECT().Close() 184 185 s.handleRequest(conn, str, nil, qpackDecoder) 186 hfs := decodeHeader(responseBuf) 187 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 188 }) 189 190 It("sets Content-Length when the handler doesn't flush to the client", func() { 191 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 192 w.Write([]byte("foobar")) 193 }) 194 195 responseBuf := &bytes.Buffer{} 196 setRequest(encodeRequest(exampleGetRequest)) 197 str.EXPECT().Context().Return(reqContext) 198 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 199 str.EXPECT().CancelRead(gomock.Any()) 200 str.EXPECT().Close() 201 202 s.handleRequest(conn, str, nil, qpackDecoder) 203 hfs := decodeHeader(responseBuf) 204 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 205 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) 206 // status, content-length, date, content-type 207 Expect(hfs).To(HaveLen(4)) 208 }) 209 210 It("sets Content-Type when WriteHeader is called but response is not flushed", func() { 211 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 212 w.WriteHeader(http.StatusNotFound) 213 w.Write([]byte("<html></html>")) 214 }) 215 216 responseBuf := &bytes.Buffer{} 217 setRequest(encodeRequest(exampleGetRequest)) 218 str.EXPECT().Context().Return(reqContext) 219 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 220 str.EXPECT().CancelRead(gomock.Any()) 221 str.EXPECT().Close() 222 223 s.handleRequest(conn, str, nil, qpackDecoder) 224 hfs := decodeHeader(responseBuf) 225 Expect(hfs).To(HaveKeyWithValue(":status", []string{"404"})) 226 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) 227 Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) 228 }) 229 230 It("not sets Content-Length when the handler flushes to the client", func() { 231 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 232 w.Write([]byte("foobar")) 233 // force flush 234 w.(http.Flusher).Flush() 235 }) 236 237 responseBuf := &bytes.Buffer{} 238 setRequest(encodeRequest(exampleGetRequest)) 239 str.EXPECT().Context().Return(reqContext) 240 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 241 str.EXPECT().CancelRead(gomock.Any()) 242 str.EXPECT().Close() 243 244 s.handleRequest(conn, str, nil, qpackDecoder) 245 hfs := decodeHeader(responseBuf) 246 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 247 // status, date, content-type 248 Expect(hfs).To(HaveLen(3)) 249 }) 250 251 It("ignores calls to Write for responses to HEAD requests", func() { 252 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 253 w.Write([]byte("foobar")) 254 }) 255 256 headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil) 257 Expect(err).ToNot(HaveOccurred()) 258 responseBuf := &bytes.Buffer{} 259 setRequest(encodeRequest(headRequest)) 260 str.EXPECT().Context().Return(reqContext) 261 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 262 str.EXPECT().CancelRead(gomock.Any()) 263 str.EXPECT().Close() 264 265 s.handleRequest(conn, str, nil, qpackDecoder) 266 hfs := decodeHeader(responseBuf) 267 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 268 Expect(responseBuf.Bytes()).To(BeEmpty()) 269 }) 270 271 It("response to HEAD request should also do content sniffing", func() { 272 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 273 w.Write([]byte("<html></html>")) 274 }) 275 276 headRequest, err := http.NewRequest(http.MethodHead, "https://www.example.com", nil) 277 Expect(err).ToNot(HaveOccurred()) 278 responseBuf := &bytes.Buffer{} 279 setRequest(encodeRequest(headRequest)) 280 str.EXPECT().Context().Return(reqContext) 281 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 282 str.EXPECT().CancelRead(gomock.Any()) 283 str.EXPECT().Close() 284 285 s.handleRequest(conn, str, nil, qpackDecoder) 286 hfs := decodeHeader(responseBuf) 287 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 288 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) 289 Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) 290 }) 291 292 It("handles an aborting handler", func() { 293 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 294 panic(http.ErrAbortHandler) 295 }) 296 297 responseBuf := &bytes.Buffer{} 298 setRequest(encodeRequest(exampleGetRequest)) 299 str.EXPECT().Context().Return(reqContext) 300 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 301 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) 302 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) 303 304 s.handleRequest(conn, str, nil, qpackDecoder) 305 Expect(responseBuf.Bytes()).To(HaveLen(0)) 306 }) 307 308 It("handles a panicking handler", func() { 309 var logBuf bytes.Buffer 310 s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil)) 311 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 312 panic("foobar") 313 }) 314 315 responseBuf := &bytes.Buffer{} 316 setRequest(encodeRequest(exampleGetRequest)) 317 str.EXPECT().Context().Return(reqContext) 318 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 319 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeInternalError)) 320 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeInternalError)) 321 322 s.handleRequest(conn, str, nil, qpackDecoder) 323 Expect(responseBuf.Bytes()).To(HaveLen(0)) 324 Expect(logBuf.String()).To(ContainSubstring("http: panic serving")) 325 Expect(logBuf.String()).To(ContainSubstring("foobar")) 326 }) 327 328 Context("hijacking bidirectional streams", func() { 329 var conn *mockquic.MockEarlyConnection 330 testDone := make(chan struct{}) 331 332 BeforeEach(func() { 333 testDone = make(chan struct{}) 334 conn = mockquic.NewMockEarlyConnection(mockCtrl) 335 controlStr := mockquic.NewMockStream(mockCtrl) 336 controlStr.EXPECT().Write(gomock.Any()) 337 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 338 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 339 conn.EXPECT().LocalAddr().AnyTimes() 340 }) 341 342 AfterEach(func() { testDone <- struct{}{} }) 343 344 It("hijacks a bidirectional stream of unknown frame type", func() { 345 id := quic.ConnectionTracingID(1337) 346 frameTypeChan := make(chan FrameType, 1) 347 s.StreamHijacker = func(ft FrameType, connTracingID quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 348 defer GinkgoRecover() 349 Expect(e).ToNot(HaveOccurred()) 350 Expect(connTracingID).To(Equal(id)) 351 frameTypeChan <- ft 352 return true, nil 353 } 354 355 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 356 unknownStr := mockquic.NewMockStream(mockCtrl) 357 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 358 unknownStr.EXPECT().StreamID().AnyTimes() 359 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 360 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 361 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 362 <-testDone 363 return nil, errors.New("test done") 364 }) 365 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 366 conn.EXPECT().Context().Return(ctx).AnyTimes() 367 s.handleConn(conn) 368 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 369 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 370 }) 371 372 It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { 373 frameTypeChan := make(chan FrameType, 1) 374 s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 375 Expect(e).ToNot(HaveOccurred()) 376 frameTypeChan <- ft 377 return false, nil 378 } 379 380 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 381 unknownStr := mockquic.NewMockStream(mockCtrl) 382 unknownStr.EXPECT().StreamID().AnyTimes() 383 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 384 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 385 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 386 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 387 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 388 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 389 <-testDone 390 return nil, errors.New("test done") 391 }) 392 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 393 conn.EXPECT().Context().Return(ctx).AnyTimes() 394 s.handleConn(conn) 395 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 396 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 397 }) 398 399 It("cancels writing when hijacker returned error", func() { 400 frameTypeChan := make(chan FrameType, 1) 401 s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, _ quic.Stream, e error) (hijacked bool, err error) { 402 Expect(e).ToNot(HaveOccurred()) 403 frameTypeChan <- ft 404 return false, errors.New("error in hijacker") 405 } 406 407 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 408 unknownStr := mockquic.NewMockStream(mockCtrl) 409 unknownStr.EXPECT().StreamID().AnyTimes() 410 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 411 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 412 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 413 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 414 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 415 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 416 <-testDone 417 return nil, errors.New("test done") 418 }) 419 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 420 conn.EXPECT().Context().Return(ctx).AnyTimes() 421 s.handleConn(conn) 422 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 423 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 424 }) 425 426 It("handles errors that occur when reading the stream type", func() { 427 const strID = protocol.StreamID(1234 * 4) 428 testErr := errors.New("test error") 429 done := make(chan struct{}) 430 unknownStr := mockquic.NewMockStream(mockCtrl) 431 s.StreamHijacker = func(ft FrameType, _ quic.ConnectionTracingID, str quic.Stream, err error) (bool, error) { 432 defer close(done) 433 Expect(ft).To(BeZero()) 434 Expect(str.StreamID()).To(Equal(strID)) 435 Expect(err).To(MatchError(testErr)) 436 return true, nil 437 } 438 unknownStr.EXPECT().StreamID().Return(strID).AnyTimes() 439 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 440 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 441 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 442 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 443 <-testDone 444 return nil, errors.New("test done") 445 }) 446 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 447 conn.EXPECT().Context().Return(ctx).AnyTimes() 448 s.handleConn(conn) 449 Eventually(done).Should(BeClosed()) 450 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 451 }) 452 }) 453 454 Context("hijacking unidirectional streams", func() { 455 var conn *mockquic.MockEarlyConnection 456 testDone := make(chan struct{}) 457 458 BeforeEach(func() { 459 testDone = make(chan struct{}) 460 conn = mockquic.NewMockEarlyConnection(mockCtrl) 461 controlStr := mockquic.NewMockStream(mockCtrl) 462 controlStr.EXPECT().Write(gomock.Any()) 463 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 464 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 465 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 466 conn.EXPECT().LocalAddr().AnyTimes() 467 }) 468 469 AfterEach(func() { testDone <- struct{}{} }) 470 471 It("hijacks an unidirectional stream of unknown stream type", func() { 472 id := quic.ConnectionTracingID(42) 473 streamTypeChan := make(chan StreamType, 1) 474 s.UniStreamHijacker = func(st StreamType, connTracingID quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 475 Expect(err).ToNot(HaveOccurred()) 476 Expect(connTracingID).To(Equal(id)) 477 streamTypeChan <- st 478 return true 479 } 480 481 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 482 unknownStr := mockquic.NewMockStream(mockCtrl) 483 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 484 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 485 return unknownStr, nil 486 }) 487 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 488 <-testDone 489 return nil, errors.New("test done") 490 }) 491 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, id) 492 conn.EXPECT().Context().Return(ctx).AnyTimes() 493 s.handleConn(conn) 494 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 495 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 496 }) 497 498 It("handles errors that occur when reading the stream type", func() { 499 testErr := errors.New("test error") 500 done := make(chan struct{}) 501 unknownStr := mockquic.NewMockStream(mockCtrl) 502 s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, str quic.ReceiveStream, err error) bool { 503 defer close(done) 504 Expect(st).To(BeZero()) 505 Expect(str).To(Equal(unknownStr)) 506 Expect(err).To(MatchError(testErr)) 507 return true 508 } 509 510 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 511 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 512 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 513 <-testDone 514 return nil, errors.New("test done") 515 }) 516 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 517 conn.EXPECT().Context().Return(ctx).AnyTimes() 518 s.handleConn(conn) 519 Eventually(done).Should(BeClosed()) 520 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 521 }) 522 523 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 524 streamTypeChan := make(chan StreamType, 1) 525 s.UniStreamHijacker = func(st StreamType, _ quic.ConnectionTracingID, _ quic.ReceiveStream, err error) bool { 526 Expect(err).ToNot(HaveOccurred()) 527 streamTypeChan <- st 528 return false 529 } 530 531 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 532 unknownStr := mockquic.NewMockStream(mockCtrl) 533 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 534 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 535 536 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 537 return unknownStr, nil 538 }) 539 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 540 <-testDone 541 return nil, errors.New("test done") 542 }) 543 ctx := context.WithValue(context.Background(), quic.ConnectionTracingKey, quic.ConnectionTracingID(1234)) 544 conn.EXPECT().Context().Return(ctx).AnyTimes() 545 s.handleConn(conn) 546 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 547 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 548 }) 549 }) 550 551 Context("stream- and connection-level errors", func() { 552 var conn *mockquic.MockEarlyConnection 553 testDone := make(chan struct{}) 554 555 BeforeEach(func() { 556 testDone = make(chan struct{}) 557 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 558 conn = mockquic.NewMockEarlyConnection(mockCtrl) 559 controlStr := mockquic.NewMockStream(mockCtrl) 560 controlStr.EXPECT().Write(gomock.Any()) 561 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 562 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 563 <-testDone 564 return nil, errors.New("test done") 565 }) 566 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 567 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 568 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 569 conn.EXPECT().LocalAddr().AnyTimes() 570 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 571 conn.EXPECT().Context().Return(context.Background()).AnyTimes() 572 }) 573 574 AfterEach(func() { testDone <- struct{}{} }) 575 576 It("cancels reading when client sends a body in GET request", func() { 577 var handlerCalled bool 578 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 579 handlerCalled = true 580 }) 581 582 requestData := encodeRequest(exampleGetRequest) 583 b := (&dataFrame{Length: 6}).Append(nil) // add a body 584 b = append(b, []byte("foobar")...) 585 responseBuf := &bytes.Buffer{} 586 setRequest(append(requestData, b...)) 587 done := make(chan struct{}) 588 str.EXPECT().Context().Return(reqContext) 589 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 590 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 591 str.EXPECT().Close().Do(func() error { close(done); return nil }) 592 593 s.handleConn(conn) 594 Eventually(done).Should(BeClosed()) 595 hfs := decodeHeader(responseBuf) 596 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 597 Expect(handlerCalled).To(BeTrue()) 598 }) 599 600 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 601 handlerCalled := make(chan struct{}) 602 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 603 defer close(handlerCalled) 604 w.(HTTPStreamer).HTTPStream() 605 str.Write([]byte("foobar")) 606 }) 607 608 requestData := encodeRequest(exampleGetRequest) 609 b := (&dataFrame{Length: 6}).Append(nil) // add a body 610 b = append(b, []byte("foobar")...) 611 setRequest(append(requestData, b...)) 612 str.EXPECT().Context().Return(reqContext) 613 var buf bytes.Buffer 614 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 615 616 s.handleConn(conn) 617 Eventually(handlerCalled).Should(BeClosed()) 618 619 // The buffer is expected to contain: 620 // 1. The response header (in a HEADERS frame) 621 // 2. the "foobar" (unframed) 622 frame, err := parseNextFrame(&buf, nil) 623 Expect(err).ToNot(HaveOccurred()) 624 Expect(frame).To(BeAssignableToTypeOf(&headersFrame{})) 625 df := frame.(*headersFrame) 626 data := make([]byte, df.Length) 627 _, err = io.ReadFull(&buf, data) 628 Expect(err).ToNot(HaveOccurred()) 629 hdrs, err := qpackDecoder.DecodeFull(data) 630 Expect(err).ToNot(HaveOccurred()) 631 Expect(hdrs).To(ContainElement(qpack.HeaderField{Name: ":status", Value: "200"})) 632 Expect(buf.Bytes()).To(Equal([]byte("foobar"))) 633 }) 634 635 It("errors when the client sends a too large header frame", func() { 636 s.MaxHeaderBytes = 20 637 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 638 Fail("Handler should not be called.") 639 }) 640 641 requestData := encodeRequest(exampleGetRequest) 642 b := (&dataFrame{Length: 6}).Append(nil) // add a body 643 b = append(b, []byte("foobar")...) 644 responseBuf := &bytes.Buffer{} 645 setRequest(append(requestData, b...)) 646 done := make(chan struct{}) 647 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 648 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) 649 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 650 651 s.handleConn(conn) 652 Eventually(done).Should(BeClosed()) 653 }) 654 655 It("handles a request for which the client immediately resets the stream", func() { 656 handlerCalled := make(chan struct{}) 657 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 658 close(handlerCalled) 659 }) 660 661 testErr := errors.New("stream reset") 662 done := make(chan struct{}) 663 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 664 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 665 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 666 667 s.handleConn(conn) 668 Consistently(handlerCalled).ShouldNot(BeClosed()) 669 }) 670 671 It("closes the connection when the first frame is not a HEADERS frame", func() { 672 handlerCalled := make(chan struct{}) 673 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 674 close(handlerCalled) 675 }) 676 677 b := (&dataFrame{}).Append(nil) 678 setRequest(b) 679 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 680 return len(p), nil 681 }).AnyTimes() 682 683 done := make(chan struct{}) 684 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 685 close(done) 686 return nil 687 }) 688 s.handleConn(conn) 689 Eventually(done).Should(BeClosed()) 690 }) 691 692 It("rejects a request that has too large request headers", func() { 693 handlerCalled := make(chan struct{}) 694 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 695 close(handlerCalled) 696 }) 697 698 // use 2*DefaultMaxHeaderBytes here. qpack will compress the request, 699 // but the request will still end up larger than DefaultMaxHeaderBytes. 700 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 701 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 702 Expect(err).ToNot(HaveOccurred()) 703 setRequest(encodeRequest(req)) 704 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 705 return len(p), nil 706 }).AnyTimes() 707 done := make(chan struct{}) 708 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeFrameError)) 709 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 710 711 s.handleConn(conn) 712 Eventually(done).Should(BeClosed()) 713 }) 714 }) 715 716 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 717 handlerCalled := make(chan struct{}) 718 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 719 r.Body = struct { 720 io.Reader 721 io.Closer 722 }{} 723 close(handlerCalled) 724 }) 725 726 setRequest(encodeRequest(examplePostRequest)) 727 str.EXPECT().Context().Return(reqContext) 728 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 729 return len(p), nil 730 }).AnyTimes() 731 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 732 str.EXPECT().Close() 733 734 s.handleRequest(conn, str, nil, qpackDecoder) 735 Eventually(handlerCalled).Should(BeClosed()) 736 }) 737 738 It("cancels the request context when the stream is closed", func() { 739 handlerCalled := make(chan struct{}) 740 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 741 defer GinkgoRecover() 742 Expect(r.Context().Done()).To(BeClosed()) 743 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 744 close(handlerCalled) 745 }) 746 setRequest(encodeRequest(examplePostRequest)) 747 748 reqContext, cancel := context.WithCancel(context.Background()) 749 cancel() 750 str.EXPECT().Context().Return(reqContext) 751 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 752 return len(p), nil 753 }).AnyTimes() 754 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 755 str.EXPECT().Close() 756 757 s.handleRequest(conn, str, nil, qpackDecoder) 758 Eventually(handlerCalled).Should(BeClosed()) 759 }) 760 }) 761 762 Context("setting http headers", func() { 763 BeforeEach(func() { 764 s.QUICConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}} 765 }) 766 767 var ln1 QUICEarlyListener 768 var ln2 QUICEarlyListener 769 expected := http.Header{ 770 "Alt-Svc": {`h3=":443"; ma=2592000`}, 771 } 772 773 addListener := func(addr string, ln *QUICEarlyListener) { 774 mln := newMockAddrListener(addr) 775 mln.EXPECT().Addr() 776 *ln = mln 777 s.addListener(ln) 778 } 779 780 removeListener := func(ln *QUICEarlyListener) { 781 s.removeListener(ln) 782 } 783 784 checkSetHeaders := func(expected gmtypes.GomegaMatcher) { 785 hdr := http.Header{} 786 Expect(s.SetQUICHeaders(hdr)).To(Succeed()) 787 Expect(hdr).To(expected) 788 } 789 790 checkSetHeaderError := func() { 791 hdr := http.Header{} 792 Expect(s.SetQUICHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) 793 } 794 795 It("sets proper headers with numeric port", func() { 796 addListener(":443", &ln1) 797 checkSetHeaders(Equal(expected)) 798 removeListener(&ln1) 799 checkSetHeaderError() 800 }) 801 802 It("sets proper headers with full addr", func() { 803 addListener("127.0.0.1:443", &ln1) 804 checkSetHeaders(Equal(expected)) 805 removeListener(&ln1) 806 checkSetHeaderError() 807 }) 808 809 It("sets proper headers with string port", func() { 810 addListener(":https", &ln1) 811 checkSetHeaders(Equal(expected)) 812 removeListener(&ln1) 813 checkSetHeaderError() 814 }) 815 816 It("works multiple times", func() { 817 addListener(":https", &ln1) 818 checkSetHeaders(Equal(expected)) 819 checkSetHeaders(Equal(expected)) 820 removeListener(&ln1) 821 checkSetHeaderError() 822 }) 823 824 It("works if the quic.Config sets QUIC versions", func() { 825 s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version2} 826 addListener(":443", &ln1) 827 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 828 removeListener(&ln1) 829 checkSetHeaderError() 830 }) 831 832 It("uses s.Port if set to a non-zero value", func() { 833 s.Port = 8443 834 addListener(":443", &ln1) 835 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}})) 836 removeListener(&ln1) 837 checkSetHeaderError() 838 }) 839 840 It("uses s.Addr if listeners don't have ports available", func() { 841 s.Addr = ":443" 842 var logBuf bytes.Buffer 843 s.Logger = slog.New(slog.NewTextHandler(&logBuf, nil)) 844 mln := &noPortListener{newMockAddrListener("")} 845 mln.EXPECT().Addr() 846 ln1 = mln 847 s.addListener(&ln1) 848 checkSetHeaders(Equal(expected)) 849 s.removeListener(&ln1) 850 checkSetHeaderError() 851 Expect(logBuf.String()).To(ContainSubstring("Unable to extract port from listener, will not be announced using SetQUICHeaders")) 852 }) 853 854 It("properly announces multiple listeners", func() { 855 addListener(":443", &ln1) 856 addListener(":8443", &ln2) 857 checkSetHeaders(Or( 858 Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}), 859 Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}), 860 )) 861 removeListener(&ln1) 862 removeListener(&ln2) 863 checkSetHeaderError() 864 }) 865 866 It("doesn't duplicate Alt-Svc values", func() { 867 s.QUICConfig.Versions = []quic.Version{quic.Version1, quic.Version1} 868 addListener(":443", &ln1) 869 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 870 removeListener(&ln1) 871 checkSetHeaderError() 872 }) 873 }) 874 875 It("errors when ListenAndServe is called with s.TLSConfig nil", func() { 876 Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) 877 }) 878 879 It("should nop-Close() when s.server is nil", func() { 880 Expect((&Server{}).Close()).To(Succeed()) 881 }) 882 883 It("errors when ListenAndServeTLS is called after Close", func() { 884 serv := &Server{} 885 Expect(serv.Close()).To(Succeed()) 886 Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) 887 }) 888 889 It("handles concurrent Serve and Close", func() { 890 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 891 Expect(err).ToNot(HaveOccurred()) 892 c, err := net.ListenUDP("udp", addr) 893 Expect(err).ToNot(HaveOccurred()) 894 done := make(chan struct{}) 895 go func() { 896 defer GinkgoRecover() 897 defer close(done) 898 s.Serve(c) 899 }() 900 runtime.Gosched() 901 s.Close() 902 Eventually(done).Should(BeClosed()) 903 }) 904 905 Context("ConfigureTLSConfig", func() { 906 It("advertises v1 by default", func() { 907 conf := ConfigureTLSConfig(testdata.GetTLSConfig()) 908 ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}}) 909 Expect(err).ToNot(HaveOccurred()) 910 defer ln.Close() 911 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 912 Expect(err).ToNot(HaveOccurred()) 913 defer c.CloseWithError(0, "") 914 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 915 }) 916 917 It("sets the GetConfigForClient callback if no tls.Config is given", func() { 918 var receivedConf *tls.Config 919 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) { 920 receivedConf = tlsConf 921 return nil, errors.New("listen err") 922 } 923 Expect(s.ListenAndServe()).To(HaveOccurred()) 924 Expect(receivedConf).ToNot(BeNil()) 925 }) 926 927 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { 928 tlsConf := &tls.Config{ 929 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 930 c := testdata.GetTLSConfig() 931 c.NextProtos = []string{"foo", "bar"} 932 return c, nil 933 }, 934 } 935 936 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 937 Expect(err).ToNot(HaveOccurred()) 938 defer ln.Close() 939 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 940 Expect(err).ToNot(HaveOccurred()) 941 defer c.CloseWithError(0, "") 942 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 943 }) 944 945 It("works if GetConfigForClient returns a nil tls.Config", func() { 946 tlsConf := testdata.GetTLSConfig() 947 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } 948 949 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 950 Expect(err).ToNot(HaveOccurred()) 951 defer ln.Close() 952 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 953 Expect(err).ToNot(HaveOccurred()) 954 defer c.CloseWithError(0, "") 955 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 956 }) 957 958 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { 959 tlsClientConf := testdata.GetTLSConfig() 960 tlsClientConf.NextProtos = []string{"foo", "bar"} 961 tlsConf := &tls.Config{ 962 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 963 return tlsClientConf, nil 964 }, 965 } 966 967 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 968 Expect(err).ToNot(HaveOccurred()) 969 defer ln.Close() 970 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 971 Expect(err).ToNot(HaveOccurred()) 972 defer c.CloseWithError(0, "") 973 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 974 // check that the original config was not modified 975 Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) 976 }) 977 }) 978 979 Context("Serve", func() { 980 origQuicListen := quicListen 981 982 AfterEach(func() { 983 quicListen = origQuicListen 984 }) 985 986 It("serves a packet conn", func() { 987 ln := newMockAddrListener(":443") 988 conn := &net.UDPConn{} 989 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 990 Expect(c).To(Equal(conn)) 991 return ln, nil 992 } 993 994 s := &Server{ 995 TLSConfig: &tls.Config{}, 996 } 997 998 stopAccept := make(chan struct{}) 999 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1000 <-stopAccept 1001 return nil, errors.New("closed") 1002 }) 1003 ln.EXPECT().Addr() // generate alt-svc headers 1004 done := make(chan struct{}) 1005 go func() { 1006 defer GinkgoRecover() 1007 defer close(done) 1008 s.Serve(conn) 1009 }() 1010 1011 Consistently(done).ShouldNot(BeClosed()) 1012 ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) 1013 Expect(s.Close()).To(Succeed()) 1014 Eventually(done).Should(BeClosed()) 1015 }) 1016 1017 It("serves two packet conns", func() { 1018 ln1 := newMockAddrListener(":443") 1019 ln2 := newMockAddrListener(":8443") 1020 lns := make(chan QUICEarlyListener, 2) 1021 lns <- ln1 1022 lns <- ln2 1023 conn1 := &net.UDPConn{} 1024 conn2 := &net.UDPConn{} 1025 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1026 return <-lns, nil 1027 } 1028 1029 s := &Server{ 1030 TLSConfig: &tls.Config{}, 1031 } 1032 1033 stopAccept1 := make(chan struct{}) 1034 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1035 <-stopAccept1 1036 return nil, errors.New("closed") 1037 }) 1038 ln1.EXPECT().Addr() // generate alt-svc headers 1039 stopAccept2 := make(chan struct{}) 1040 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1041 <-stopAccept2 1042 return nil, errors.New("closed") 1043 }) 1044 ln2.EXPECT().Addr() 1045 1046 done1 := make(chan struct{}) 1047 go func() { 1048 defer GinkgoRecover() 1049 defer close(done1) 1050 s.Serve(conn1) 1051 }() 1052 done2 := make(chan struct{}) 1053 go func() { 1054 defer GinkgoRecover() 1055 defer close(done2) 1056 s.Serve(conn2) 1057 }() 1058 1059 Consistently(done1).ShouldNot(BeClosed()) 1060 Expect(done2).ToNot(BeClosed()) 1061 ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) 1062 ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) 1063 Expect(s.Close()).To(Succeed()) 1064 Eventually(done1).Should(BeClosed()) 1065 Eventually(done2).Should(BeClosed()) 1066 }) 1067 }) 1068 1069 Context("ServeListener", func() { 1070 origQuicListen := quicListen 1071 1072 AfterEach(func() { 1073 quicListen = origQuicListen 1074 }) 1075 1076 It("serves a listener", func() { 1077 var called int32 1078 ln := newMockAddrListener(":443") 1079 quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1080 atomic.StoreInt32(&called, 1) 1081 return ln, nil 1082 } 1083 1084 s := &Server{} 1085 1086 stopAccept := make(chan struct{}) 1087 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1088 <-stopAccept 1089 return nil, errors.New("closed") 1090 }) 1091 ln.EXPECT().Addr() // generate alt-svc headers 1092 done := make(chan struct{}) 1093 go func() { 1094 defer GinkgoRecover() 1095 defer close(done) 1096 s.ServeListener(ln) 1097 }() 1098 1099 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1100 Consistently(done).ShouldNot(BeClosed()) 1101 ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) 1102 Expect(s.Close()).To(Succeed()) 1103 Eventually(done).Should(BeClosed()) 1104 }) 1105 1106 It("serves two listeners", func() { 1107 var called int32 1108 ln1 := newMockAddrListener(":443") 1109 ln2 := newMockAddrListener(":8443") 1110 lns := make(chan QUICEarlyListener, 2) 1111 lns <- ln1 1112 lns <- ln2 1113 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1114 atomic.StoreInt32(&called, 1) 1115 return <-lns, nil 1116 } 1117 1118 s := &Server{} 1119 1120 stopAccept1 := make(chan struct{}) 1121 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1122 <-stopAccept1 1123 return nil, errors.New("closed") 1124 }) 1125 ln1.EXPECT().Addr() // generate alt-svc headers 1126 stopAccept2 := make(chan struct{}) 1127 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1128 <-stopAccept2 1129 return nil, errors.New("closed") 1130 }) 1131 ln2.EXPECT().Addr() 1132 1133 done1 := make(chan struct{}) 1134 go func() { 1135 defer GinkgoRecover() 1136 defer close(done1) 1137 s.ServeListener(ln1) 1138 }() 1139 done2 := make(chan struct{}) 1140 go func() { 1141 defer GinkgoRecover() 1142 defer close(done2) 1143 s.ServeListener(ln2) 1144 }() 1145 1146 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1147 Consistently(done1).ShouldNot(BeClosed()) 1148 Expect(done2).ToNot(BeClosed()) 1149 ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) 1150 ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) 1151 Expect(s.Close()).To(Succeed()) 1152 Eventually(done1).Should(BeClosed()) 1153 Eventually(done2).Should(BeClosed()) 1154 }) 1155 }) 1156 1157 Context("ServeQUICConn", func() { 1158 It("serves a QUIC connection", func() { 1159 mux := http.NewServeMux() 1160 mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { 1161 w.Write([]byte("foobar")) 1162 }) 1163 s.Handler = mux 1164 tlsConf := testdata.GetTLSConfig() 1165 tlsConf.NextProtos = []string{NextProtoH3} 1166 conn := mockquic.NewMockEarlyConnection(mockCtrl) 1167 controlStr := mockquic.NewMockStream(mockCtrl) 1168 controlStr.EXPECT().Write(gomock.Any()) 1169 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 1170 testDone := make(chan struct{}) 1171 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 1172 <-testDone 1173 return nil, errors.New("test done") 1174 }).MaxTimes(1) 1175 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}) 1176 s.ServeQUICConn(conn) 1177 close(testDone) 1178 }) 1179 }) 1180 1181 Context("ListenAndServe", func() { 1182 BeforeEach(func() { 1183 s.Addr = "localhost:0" 1184 }) 1185 1186 AfterEach(func() { 1187 Expect(s.Close()).To(Succeed()) 1188 }) 1189 1190 It("uses the quic.Config to start the QUIC server", func() { 1191 conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} 1192 var receivedConf *quic.Config 1193 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1194 receivedConf = config 1195 return nil, errors.New("listen err") 1196 } 1197 s.QUICConfig = conf 1198 Expect(s.ListenAndServe()).To(HaveOccurred()) 1199 Expect(receivedConf).To(Equal(conf)) 1200 }) 1201 }) 1202 1203 It("closes gracefully", func() { 1204 Expect(s.CloseGracefully(0)).To(Succeed()) 1205 }) 1206 1207 It("errors when listening fails", func() { 1208 testErr := errors.New("listen error") 1209 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1210 return nil, testErr 1211 } 1212 fullpem, privkey := testdata.GetCertificatePaths() 1213 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 1214 }) 1215 1216 It("supports H3_DATAGRAM", func() { 1217 s.EnableDatagrams = true 1218 var receivedConf *quic.Config 1219 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1220 receivedConf = config 1221 return nil, errors.New("listen err") 1222 } 1223 Expect(s.ListenAndServe()).To(HaveOccurred()) 1224 Expect(receivedConf.EnableDatagrams).To(BeTrue()) 1225 }) 1226 })