github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/danielpfeifer02/quic-go-prio-packs" 17 mockquic "github.com/danielpfeifer02/quic-go-prio-packs/internal/mocks/quic" 18 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 19 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr" 20 "github.com/danielpfeifer02/quic-go-prio-packs/internal/testdata" 21 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 22 "github.com/danielpfeifer02/quic-go-prio-packs/quicvarint" 23 24 "github.com/quic-go/qpack" 25 "go.uber.org/mock/gomock" 26 27 . "github.com/onsi/ginkgo/v2" 28 . "github.com/onsi/gomega" 29 gmtypes "github.com/onsi/gomega/types" 30 ) 31 32 type mockAddr struct{ addr string } 33 34 func (ma *mockAddr) Network() string { return "udp" } 35 func (ma *mockAddr) String() string { return ma.addr } 36 37 type mockAddrListener struct { 38 *MockQUICEarlyListener 39 addr *mockAddr 40 } 41 42 func (m *mockAddrListener) Addr() net.Addr { 43 _ = m.MockQUICEarlyListener.Addr() 44 return m.addr 45 } 46 47 func newMockAddrListener(addr string) *mockAddrListener { 48 return &mockAddrListener{ 49 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 50 addr: &mockAddr{addr: addr}, 51 } 52 } 53 54 type noPortListener struct { 55 *mockAddrListener 56 } 57 58 func (m *noPortListener) Addr() net.Addr { 59 _ = m.mockAddrListener.Addr() 60 return &net.UnixAddr{ 61 Net: "unix", 62 Name: "/tmp/quic.sock", 63 } 64 } 65 66 var _ = Describe("Server", func() { 67 var ( 68 s *Server 69 origQuicListenAddr = quicListenAddr 70 ) 71 type testConnContextKey string 72 73 BeforeEach(func() { 74 s = &Server{ 75 TLSConfig: testdata.GetTLSConfig(), 76 logger: utils.DefaultLogger, 77 ConnContext: func(ctx context.Context, c quic.Connection) context.Context { 78 return context.WithValue(ctx, testConnContextKey("test"), c) 79 }, 80 } 81 origQuicListenAddr = quicListenAddr 82 }) 83 84 AfterEach(func() { 85 quicListenAddr = origQuicListenAddr 86 }) 87 88 Context("handling requests", func() { 89 var ( 90 qpackDecoder *qpack.Decoder 91 str *mockquic.MockStream 92 conn *mockquic.MockEarlyConnection 93 exampleGetRequest *http.Request 94 examplePostRequest *http.Request 95 ) 96 reqContext := context.Background() 97 98 decodeHeader := func(str io.Reader) map[string][]string { 99 fields := make(map[string][]string) 100 decoder := qpack.NewDecoder(nil) 101 102 frame, err := parseNextFrame(str, nil) 103 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 104 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 105 headersFrame := frame.(*headersFrame) 106 data := make([]byte, headersFrame.Length) 107 _, err = io.ReadFull(str, data) 108 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 109 hfs, err := decoder.DecodeFull(data) 110 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 111 for _, p := range hfs { 112 fields[p.Name] = append(fields[p.Name], p.Value) 113 } 114 return fields 115 } 116 117 encodeRequest := func(req *http.Request) []byte { 118 buf := &bytes.Buffer{} 119 str := mockquic.NewMockStream(mockCtrl) 120 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 121 rw := newRequestWriter(utils.DefaultLogger) 122 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 123 return buf.Bytes() 124 } 125 126 setRequest := func(data []byte) { 127 buf := bytes.NewBuffer(data) 128 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 129 if buf.Len() == 0 { 130 return 0, io.EOF 131 } 132 return buf.Read(p) 133 }).AnyTimes() 134 } 135 136 BeforeEach(func() { 137 var err error 138 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 139 Expect(err).ToNot(HaveOccurred()) 140 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 141 Expect(err).ToNot(HaveOccurred()) 142 143 qpackDecoder = qpack.NewDecoder(nil) 144 str = mockquic.NewMockStream(mockCtrl) 145 conn = mockquic.NewMockEarlyConnection(mockCtrl) 146 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 147 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 148 conn.EXPECT().LocalAddr().AnyTimes() 149 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 150 }) 151 152 It("calls the HTTP handler function", func() { 153 requestChan := make(chan *http.Request, 1) 154 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 155 requestChan <- r 156 }) 157 158 setRequest(encodeRequest(exampleGetRequest)) 159 str.EXPECT().Context().Return(reqContext) 160 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 161 return len(p), nil 162 }).AnyTimes() 163 str.EXPECT().CancelRead(gomock.Any()) 164 165 Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) 166 var req *http.Request 167 Eventually(requestChan).Should(Receive(&req)) 168 Expect(req.Host).To(Equal("www.example.com")) 169 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 170 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 171 Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil)) 172 }) 173 174 It("returns 200 with an empty handler", func() { 175 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 176 177 responseBuf := &bytes.Buffer{} 178 setRequest(encodeRequest(exampleGetRequest)) 179 str.EXPECT().Context().Return(reqContext) 180 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 181 str.EXPECT().CancelRead(gomock.Any()) 182 183 serr := s.handleRequest(conn, str, qpackDecoder, nil) 184 Expect(serr.err).ToNot(HaveOccurred()) 185 hfs := decodeHeader(responseBuf) 186 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 187 }) 188 189 It("sets Content-Length when the handler doesn't flush to the client", func() { 190 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 191 w.Write([]byte("foobar")) 192 }) 193 194 responseBuf := &bytes.Buffer{} 195 setRequest(encodeRequest(exampleGetRequest)) 196 str.EXPECT().Context().Return(reqContext) 197 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 198 str.EXPECT().CancelRead(gomock.Any()) 199 200 serr := s.handleRequest(conn, str, qpackDecoder, nil) 201 Expect(serr.err).ToNot(HaveOccurred()) 202 hfs := decodeHeader(responseBuf) 203 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 204 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) 205 // status, content-length, date, content-type 206 Expect(hfs).To(HaveLen(4)) 207 }) 208 209 It("not sets Content-Length when the handler flushes to the client", func() { 210 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 211 w.Write([]byte("foobar")) 212 // force flush 213 w.(http.Flusher).Flush() 214 }) 215 216 responseBuf := &bytes.Buffer{} 217 setRequest(encodeRequest(exampleGetRequest)) 218 str.EXPECT().Context().Return(reqContext) 219 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 220 str.EXPECT().CancelRead(gomock.Any()) 221 222 serr := s.handleRequest(conn, str, qpackDecoder, nil) 223 Expect(serr.err).ToNot(HaveOccurred()) 224 hfs := decodeHeader(responseBuf) 225 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 226 // status, date, content-type 227 Expect(hfs).To(HaveLen(3)) 228 }) 229 230 It("response to HEAD request should not have body", func() { 231 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 232 w.Write([]byte("foobar")) 233 }) 234 235 headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) 236 Expect(err).ToNot(HaveOccurred()) 237 responseBuf := &bytes.Buffer{} 238 setRequest(encodeRequest(headRequest)) 239 str.EXPECT().Context().Return(reqContext) 240 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 241 str.EXPECT().CancelRead(gomock.Any()) 242 serr := s.handleRequest(conn, str, qpackDecoder, nil) 243 Expect(serr.err).ToNot(HaveOccurred()) 244 hfs := decodeHeader(responseBuf) 245 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 246 Expect(responseBuf.Bytes()).To(HaveLen(0)) 247 }) 248 249 It("response to HEAD request should also do content sniffing", func() { 250 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 251 w.Write([]byte("<html></html>")) 252 }) 253 254 headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) 255 Expect(err).ToNot(HaveOccurred()) 256 responseBuf := &bytes.Buffer{} 257 setRequest(encodeRequest(headRequest)) 258 str.EXPECT().Context().Return(reqContext) 259 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 260 str.EXPECT().CancelRead(gomock.Any()) 261 serr := s.handleRequest(conn, str, qpackDecoder, nil) 262 Expect(serr.err).ToNot(HaveOccurred()) 263 hfs := decodeHeader(responseBuf) 264 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 265 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) 266 Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) 267 }) 268 269 It("handles a aborting handler", func() { 270 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 271 panic(http.ErrAbortHandler) 272 }) 273 274 responseBuf := &bytes.Buffer{} 275 setRequest(encodeRequest(exampleGetRequest)) 276 str.EXPECT().Context().Return(reqContext) 277 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 278 str.EXPECT().CancelRead(gomock.Any()) 279 280 serr := s.handleRequest(conn, str, qpackDecoder, nil) 281 Expect(serr.err).To(MatchError(errPanicked)) 282 Expect(responseBuf.Bytes()).To(HaveLen(0)) 283 }) 284 285 It("handles a panicking handler", func() { 286 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 287 panic("foobar") 288 }) 289 290 responseBuf := &bytes.Buffer{} 291 setRequest(encodeRequest(exampleGetRequest)) 292 str.EXPECT().Context().Return(reqContext) 293 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 294 str.EXPECT().CancelRead(gomock.Any()) 295 296 serr := s.handleRequest(conn, str, qpackDecoder, nil) 297 Expect(serr.err).To(MatchError(errPanicked)) 298 Expect(responseBuf.Bytes()).To(HaveLen(0)) 299 }) 300 301 Context("hijacking bidirectional streams", func() { 302 var conn *mockquic.MockEarlyConnection 303 testDone := make(chan struct{}) 304 305 BeforeEach(func() { 306 testDone = make(chan struct{}) 307 conn = mockquic.NewMockEarlyConnection(mockCtrl) 308 controlStr := mockquic.NewMockStream(mockCtrl) 309 controlStr.EXPECT().Write(gomock.Any()) 310 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 311 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 312 conn.EXPECT().LocalAddr().AnyTimes() 313 }) 314 315 AfterEach(func() { testDone <- struct{}{} }) 316 317 It("hijacks a bidirectional stream of unknown frame type", func() { 318 frameTypeChan := make(chan FrameType, 1) 319 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 320 Expect(e).ToNot(HaveOccurred()) 321 frameTypeChan <- ft 322 return true, nil 323 } 324 325 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 326 unknownStr := mockquic.NewMockStream(mockCtrl) 327 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 328 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 329 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 330 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 331 <-testDone 332 return nil, errors.New("test done") 333 }) 334 s.handleConn(conn) 335 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 336 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 337 }) 338 339 It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { 340 frameTypeChan := make(chan FrameType, 1) 341 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 342 Expect(e).ToNot(HaveOccurred()) 343 frameTypeChan <- ft 344 return false, nil 345 } 346 347 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 348 unknownStr := mockquic.NewMockStream(mockCtrl) 349 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 350 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 351 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 352 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 353 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 354 <-testDone 355 return nil, errors.New("test done") 356 }) 357 s.handleConn(conn) 358 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 359 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 360 }) 361 362 It("cancels writing when hijacker returned error", func() { 363 frameTypeChan := make(chan FrameType, 1) 364 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 365 Expect(e).ToNot(HaveOccurred()) 366 frameTypeChan <- ft 367 return false, errors.New("error in hijacker") 368 } 369 370 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 371 unknownStr := mockquic.NewMockStream(mockCtrl) 372 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 373 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 374 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 375 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 376 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 377 <-testDone 378 return nil, errors.New("test done") 379 }) 380 s.handleConn(conn) 381 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 382 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 383 }) 384 385 It("handles errors that occur when reading the stream type", func() { 386 testErr := errors.New("test error") 387 done := make(chan struct{}) 388 unknownStr := mockquic.NewMockStream(mockCtrl) 389 s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { 390 defer close(done) 391 Expect(ft).To(BeZero()) 392 Expect(str).To(Equal(unknownStr)) 393 Expect(err).To(MatchError(testErr)) 394 return true, nil 395 } 396 397 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 398 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 399 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 400 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 401 <-testDone 402 return nil, errors.New("test done") 403 }) 404 s.handleConn(conn) 405 Eventually(done).Should(BeClosed()) 406 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 407 }) 408 }) 409 410 Context("hijacking unidirectional streams", func() { 411 var conn *mockquic.MockEarlyConnection 412 testDone := make(chan struct{}) 413 414 BeforeEach(func() { 415 testDone = make(chan struct{}) 416 conn = mockquic.NewMockEarlyConnection(mockCtrl) 417 controlStr := mockquic.NewMockStream(mockCtrl) 418 controlStr.EXPECT().Write(gomock.Any()) 419 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 420 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 421 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 422 conn.EXPECT().LocalAddr().AnyTimes() 423 }) 424 425 AfterEach(func() { testDone <- struct{}{} }) 426 427 It("hijacks an unidirectional stream of unknown stream type", func() { 428 streamTypeChan := make(chan StreamType, 1) 429 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 430 Expect(err).ToNot(HaveOccurred()) 431 streamTypeChan <- st 432 return true 433 } 434 435 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 436 unknownStr := mockquic.NewMockStream(mockCtrl) 437 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 438 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 439 return unknownStr, nil 440 }) 441 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 442 <-testDone 443 return nil, errors.New("test done") 444 }) 445 s.handleConn(conn) 446 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 447 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 448 }) 449 450 It("handles errors that occur when reading the stream type", func() { 451 testErr := errors.New("test error") 452 done := make(chan struct{}) 453 unknownStr := mockquic.NewMockStream(mockCtrl) 454 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 455 defer close(done) 456 Expect(st).To(BeZero()) 457 Expect(str).To(Equal(unknownStr)) 458 Expect(err).To(MatchError(testErr)) 459 return true 460 } 461 462 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 463 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 464 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 465 <-testDone 466 return nil, errors.New("test done") 467 }) 468 s.handleConn(conn) 469 Eventually(done).Should(BeClosed()) 470 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 471 }) 472 473 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 474 streamTypeChan := make(chan StreamType, 1) 475 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 476 Expect(err).ToNot(HaveOccurred()) 477 streamTypeChan <- st 478 return false 479 } 480 481 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 482 unknownStr := mockquic.NewMockStream(mockCtrl) 483 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 484 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 485 486 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 487 return unknownStr, nil 488 }) 489 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 490 <-testDone 491 return nil, errors.New("test done") 492 }) 493 s.handleConn(conn) 494 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 495 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 496 }) 497 }) 498 499 Context("control stream handling", func() { 500 var conn *mockquic.MockEarlyConnection 501 testDone := make(chan struct{}, 1) 502 503 BeforeEach(func() { 504 conn = mockquic.NewMockEarlyConnection(mockCtrl) 505 controlStr := mockquic.NewMockStream(mockCtrl) 506 controlStr.EXPECT().Write(gomock.Any()) 507 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 508 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 509 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 510 conn.EXPECT().LocalAddr().AnyTimes() 511 }) 512 513 AfterEach(func() { testDone <- struct{}{} }) 514 515 It("parses the SETTINGS frame", func() { 516 b := quicvarint.Append(nil, streamTypeControlStream) 517 b = (&settingsFrame{}).Append(b) 518 controlStr := mockquic.NewMockStream(mockCtrl) 519 r := bytes.NewReader(b) 520 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 521 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 522 return controlStr, nil 523 }) 524 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 525 <-testDone 526 return nil, errors.New("test done") 527 }) 528 s.handleConn(conn) 529 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 530 }) 531 532 It("rejects duplicate control streams", func() { 533 b := quicvarint.Append(nil, streamTypeControlStream) 534 b = (&settingsFrame{}).Append(b) 535 r1 := bytes.NewReader(b) 536 controlStr1 := mockquic.NewMockStream(mockCtrl) 537 controlStr1.EXPECT().Read(gomock.Any()).DoAndReturn(r1.Read).AnyTimes() 538 r2 := bytes.NewReader(b) 539 controlStr2 := mockquic.NewMockStream(mockCtrl) 540 controlStr2.EXPECT().Read(gomock.Any()).DoAndReturn(r2.Read).AnyTimes() 541 done := make(chan struct{}) 542 conn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream").Do(func(qerr.ApplicationErrorCode, string) error { 543 close(done) 544 return nil 545 }) 546 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 547 return controlStr1, nil 548 }) 549 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 550 return controlStr2, nil 551 }) 552 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 553 <-done 554 return nil, errors.New("test done") 555 }) 556 s.handleConn(conn) 557 Eventually(done).Should(BeClosed()) 558 }) 559 560 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 561 streamType := t 562 name := "encoder" 563 if streamType == streamTypeQPACKDecoderStream { 564 name = "decoder" 565 } 566 567 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 568 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 569 str := mockquic.NewMockStream(mockCtrl) 570 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 571 572 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 573 return str, nil 574 }) 575 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 576 <-testDone 577 return nil, errors.New("test done") 578 }) 579 s.handleConn(conn) 580 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 581 }) 582 } 583 584 It("reset streams other than the control stream and the QPACK streams", func() { 585 buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337)) 586 str := mockquic.NewMockStream(mockCtrl) 587 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 588 done := make(chan struct{}) 589 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) 590 591 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 592 return str, nil 593 }) 594 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 595 <-testDone 596 return nil, errors.New("test done") 597 }) 598 s.handleConn(conn) 599 Eventually(done).Should(BeClosed()) 600 }) 601 602 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 603 b := quicvarint.Append(nil, streamTypeControlStream) 604 b = (&dataFrame{}).Append(b) 605 controlStr := mockquic.NewMockStream(mockCtrl) 606 r := bytes.NewReader(b) 607 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 608 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 609 return controlStr, nil 610 }) 611 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 612 <-testDone 613 return nil, errors.New("test done") 614 }) 615 done := make(chan struct{}) 616 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 617 close(done) 618 return nil 619 }) 620 s.handleConn(conn) 621 Eventually(done).Should(BeClosed()) 622 }) 623 624 It("errors when parsing the frame on the control stream fails", func() { 625 b := quicvarint.Append(nil, streamTypeControlStream) 626 b = (&settingsFrame{}).Append(b) 627 r := bytes.NewReader(b[:len(b)-1]) 628 controlStr := mockquic.NewMockStream(mockCtrl) 629 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 630 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 631 return controlStr, nil 632 }) 633 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 634 <-testDone 635 return nil, errors.New("test done") 636 }) 637 done := make(chan struct{}) 638 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 639 close(done) 640 return nil 641 }) 642 s.handleConn(conn) 643 Eventually(done).Should(BeClosed()) 644 }) 645 646 It("errors when the client opens a push stream", func() { 647 b := quicvarint.Append(nil, streamTypePushStream) 648 b = (&dataFrame{}).Append(b) 649 r := bytes.NewReader(b) 650 controlStr := mockquic.NewMockStream(mockCtrl) 651 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 652 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 653 return controlStr, nil 654 }) 655 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 656 <-testDone 657 return nil, errors.New("test done") 658 }) 659 done := make(chan struct{}) 660 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 661 close(done) 662 return nil 663 }) 664 s.handleConn(conn) 665 Eventually(done).Should(BeClosed()) 666 }) 667 668 It("errors when the client advertises datagram support (and we enabled support for it)", func() { 669 s.EnableDatagrams = true 670 b := quicvarint.Append(nil, streamTypeControlStream) 671 b = (&settingsFrame{Datagram: true}).Append(b) 672 r := bytes.NewReader(b) 673 controlStr := mockquic.NewMockStream(mockCtrl) 674 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 675 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 676 return controlStr, nil 677 }) 678 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 679 <-testDone 680 return nil, errors.New("test done") 681 }) 682 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 683 done := make(chan struct{}) 684 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { 685 close(done) 686 return nil 687 }) 688 s.handleConn(conn) 689 Eventually(done).Should(BeClosed()) 690 }) 691 }) 692 693 Context("stream- and connection-level errors", func() { 694 var conn *mockquic.MockEarlyConnection 695 testDone := make(chan struct{}) 696 697 BeforeEach(func() { 698 testDone = make(chan struct{}) 699 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 700 conn = mockquic.NewMockEarlyConnection(mockCtrl) 701 controlStr := mockquic.NewMockStream(mockCtrl) 702 controlStr.EXPECT().Write(gomock.Any()) 703 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 704 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 705 <-testDone 706 return nil, errors.New("test done") 707 }) 708 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 709 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 710 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 711 conn.EXPECT().LocalAddr().AnyTimes() 712 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 713 }) 714 715 AfterEach(func() { testDone <- struct{}{} }) 716 717 It("cancels reading when client sends a body in GET request", func() { 718 var handlerCalled bool 719 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 720 handlerCalled = true 721 }) 722 723 requestData := encodeRequest(exampleGetRequest) 724 b := (&dataFrame{Length: 6}).Append(nil) // add a body 725 b = append(b, []byte("foobar")...) 726 responseBuf := &bytes.Buffer{} 727 setRequest(append(requestData, b...)) 728 done := make(chan struct{}) 729 str.EXPECT().Context().Return(reqContext) 730 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 731 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 732 str.EXPECT().Close().Do(func() error { close(done); return nil }) 733 734 s.handleConn(conn) 735 Eventually(done).Should(BeClosed()) 736 hfs := decodeHeader(responseBuf) 737 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 738 Expect(handlerCalled).To(BeTrue()) 739 }) 740 741 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 742 handlerCalled := make(chan struct{}) 743 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 744 defer close(handlerCalled) 745 r.Body.(HTTPStreamer).HTTPStream() 746 str.Write([]byte("foobar")) 747 }) 748 749 requestData := encodeRequest(exampleGetRequest) 750 b := (&dataFrame{Length: 6}).Append(nil) // add a body 751 b = append(b, []byte("foobar")...) 752 setRequest(append(requestData, b...)) 753 str.EXPECT().Context().Return(reqContext) 754 str.EXPECT().Write([]byte("foobar")).Return(6, nil) 755 756 s.handleConn(conn) 757 Eventually(handlerCalled).Should(BeClosed()) 758 }) 759 760 It("errors when the client sends a too large header frame", func() { 761 s.MaxHeaderBytes = 20 762 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 763 Fail("Handler should not be called.") 764 }) 765 766 requestData := encodeRequest(exampleGetRequest) 767 b := (&dataFrame{Length: 6}).Append(nil) // add a body 768 b = append(b, []byte("foobar")...) 769 responseBuf := &bytes.Buffer{} 770 setRequest(append(requestData, b...)) 771 done := make(chan struct{}) 772 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 773 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 774 775 s.handleConn(conn) 776 Eventually(done).Should(BeClosed()) 777 }) 778 779 It("handles a request for which the client immediately resets the stream", func() { 780 handlerCalled := make(chan struct{}) 781 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 782 close(handlerCalled) 783 }) 784 785 testErr := errors.New("stream reset") 786 done := make(chan struct{}) 787 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 788 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 789 790 s.handleConn(conn) 791 Consistently(handlerCalled).ShouldNot(BeClosed()) 792 }) 793 794 It("closes the connection when the first frame is not a HEADERS frame", func() { 795 handlerCalled := make(chan struct{}) 796 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 797 close(handlerCalled) 798 }) 799 800 b := (&dataFrame{}).Append(nil) 801 setRequest(b) 802 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 803 return len(p), nil 804 }).AnyTimes() 805 806 done := make(chan struct{}) 807 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 808 close(done) 809 return nil 810 }) 811 s.handleConn(conn) 812 Eventually(done).Should(BeClosed()) 813 }) 814 815 It("closes the connection when the first frame is not a HEADERS frame", func() { 816 handlerCalled := make(chan struct{}) 817 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 818 close(handlerCalled) 819 }) 820 821 // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, 822 // but the request will still end up larger than DefaultMaxHeaderBytes. 823 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 824 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 825 Expect(err).ToNot(HaveOccurred()) 826 setRequest(encodeRequest(req)) 827 // str.EXPECT().Context().Return(reqContext) 828 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 829 return len(p), nil 830 }).AnyTimes() 831 done := make(chan struct{}) 832 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 833 834 s.handleConn(conn) 835 Eventually(done).Should(BeClosed()) 836 }) 837 }) 838 839 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 840 handlerCalled := make(chan struct{}) 841 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 842 r.Body = struct { 843 io.Reader 844 io.Closer 845 }{} 846 close(handlerCalled) 847 }) 848 849 setRequest(encodeRequest(examplePostRequest)) 850 str.EXPECT().Context().Return(reqContext) 851 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 852 return len(p), nil 853 }).AnyTimes() 854 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 855 856 serr := s.handleRequest(conn, str, qpackDecoder, nil) 857 Expect(serr.err).ToNot(HaveOccurred()) 858 Eventually(handlerCalled).Should(BeClosed()) 859 }) 860 861 It("cancels the request context when the stream is closed", func() { 862 handlerCalled := make(chan struct{}) 863 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 864 defer GinkgoRecover() 865 Expect(r.Context().Done()).To(BeClosed()) 866 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 867 close(handlerCalled) 868 }) 869 setRequest(encodeRequest(examplePostRequest)) 870 871 reqContext, cancel := context.WithCancel(context.Background()) 872 cancel() 873 str.EXPECT().Context().Return(reqContext) 874 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 875 return len(p), nil 876 }).AnyTimes() 877 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 878 879 serr := s.handleRequest(conn, str, qpackDecoder, nil) 880 Expect(serr.err).ToNot(HaveOccurred()) 881 Eventually(handlerCalled).Should(BeClosed()) 882 }) 883 }) 884 885 Context("setting http headers", func() { 886 BeforeEach(func() { 887 s.QuicConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}} 888 }) 889 890 var ln1 QUICEarlyListener 891 var ln2 QUICEarlyListener 892 expected := http.Header{ 893 "Alt-Svc": {`h3=":443"; ma=2592000`}, 894 } 895 896 addListener := func(addr string, ln *QUICEarlyListener) { 897 mln := newMockAddrListener(addr) 898 mln.EXPECT().Addr() 899 *ln = mln 900 s.addListener(ln) 901 } 902 903 removeListener := func(ln *QUICEarlyListener) { 904 s.removeListener(ln) 905 } 906 907 checkSetHeaders := func(expected gmtypes.GomegaMatcher) { 908 hdr := http.Header{} 909 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 910 Expect(hdr).To(expected) 911 } 912 913 checkSetHeaderError := func() { 914 hdr := http.Header{} 915 Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) 916 } 917 918 It("sets proper headers with numeric port", func() { 919 addListener(":443", &ln1) 920 checkSetHeaders(Equal(expected)) 921 removeListener(&ln1) 922 checkSetHeaderError() 923 }) 924 925 It("sets proper headers with full addr", func() { 926 addListener("127.0.0.1:443", &ln1) 927 checkSetHeaders(Equal(expected)) 928 removeListener(&ln1) 929 checkSetHeaderError() 930 }) 931 932 It("sets proper headers with string port", func() { 933 addListener(":https", &ln1) 934 checkSetHeaders(Equal(expected)) 935 removeListener(&ln1) 936 checkSetHeaderError() 937 }) 938 939 It("works multiple times", func() { 940 addListener(":https", &ln1) 941 checkSetHeaders(Equal(expected)) 942 checkSetHeaders(Equal(expected)) 943 removeListener(&ln1) 944 checkSetHeaderError() 945 }) 946 947 It("works if the quic.Config sets QUIC versions", func() { 948 s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version2} 949 addListener(":443", &ln1) 950 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 951 removeListener(&ln1) 952 checkSetHeaderError() 953 }) 954 955 It("uses s.Port if set to a non-zero value", func() { 956 s.Port = 8443 957 addListener(":443", &ln1) 958 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}})) 959 removeListener(&ln1) 960 checkSetHeaderError() 961 }) 962 963 It("uses s.Addr if listeners don't have ports available", func() { 964 s.Addr = ":443" 965 mln := &noPortListener{newMockAddrListener("")} 966 mln.EXPECT().Addr() 967 ln1 = mln 968 s.addListener(&ln1) 969 checkSetHeaders(Equal(expected)) 970 s.removeListener(&ln1) 971 checkSetHeaderError() 972 }) 973 974 It("properly announces multiple listeners", func() { 975 addListener(":443", &ln1) 976 addListener(":8443", &ln2) 977 checkSetHeaders(Or( 978 Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}), 979 Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}), 980 )) 981 removeListener(&ln1) 982 removeListener(&ln2) 983 checkSetHeaderError() 984 }) 985 986 It("doesn't duplicate Alt-Svc values", func() { 987 s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version1} 988 addListener(":443", &ln1) 989 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 990 removeListener(&ln1) 991 checkSetHeaderError() 992 }) 993 }) 994 995 It("errors when ListenAndServe is called with s.TLSConfig nil", func() { 996 Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) 997 }) 998 999 It("should nop-Close() when s.server is nil", func() { 1000 Expect((&Server{}).Close()).To(Succeed()) 1001 }) 1002 1003 It("errors when ListenAndServeTLS is called after Close", func() { 1004 serv := &Server{} 1005 Expect(serv.Close()).To(Succeed()) 1006 Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) 1007 }) 1008 1009 It("handles concurrent Serve and Close", func() { 1010 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 1011 Expect(err).ToNot(HaveOccurred()) 1012 c, err := net.ListenUDP("udp", addr) 1013 Expect(err).ToNot(HaveOccurred()) 1014 done := make(chan struct{}) 1015 go func() { 1016 defer GinkgoRecover() 1017 defer close(done) 1018 s.Serve(c) 1019 }() 1020 runtime.Gosched() 1021 s.Close() 1022 Eventually(done).Should(BeClosed()) 1023 }) 1024 1025 Context("ConfigureTLSConfig", func() { 1026 It("advertises v1 by default", func() { 1027 conf := ConfigureTLSConfig(testdata.GetTLSConfig()) 1028 ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}}) 1029 Expect(err).ToNot(HaveOccurred()) 1030 defer ln.Close() 1031 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1032 Expect(err).ToNot(HaveOccurred()) 1033 defer c.CloseWithError(0, "") 1034 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1035 }) 1036 1037 It("sets the GetConfigForClient callback if no tls.Config is given", func() { 1038 var receivedConf *tls.Config 1039 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) { 1040 receivedConf = tlsConf 1041 return nil, errors.New("listen err") 1042 } 1043 Expect(s.ListenAndServe()).To(HaveOccurred()) 1044 Expect(receivedConf).ToNot(BeNil()) 1045 }) 1046 1047 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { 1048 tlsConf := &tls.Config{ 1049 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 1050 c := testdata.GetTLSConfig() 1051 c.NextProtos = []string{"foo", "bar"} 1052 return c, nil 1053 }, 1054 } 1055 1056 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 1057 Expect(err).ToNot(HaveOccurred()) 1058 defer ln.Close() 1059 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1060 Expect(err).ToNot(HaveOccurred()) 1061 defer c.CloseWithError(0, "") 1062 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1063 }) 1064 1065 It("works if GetConfigForClient returns a nil tls.Config", func() { 1066 tlsConf := testdata.GetTLSConfig() 1067 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } 1068 1069 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 1070 Expect(err).ToNot(HaveOccurred()) 1071 defer ln.Close() 1072 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1073 Expect(err).ToNot(HaveOccurred()) 1074 defer c.CloseWithError(0, "") 1075 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1076 }) 1077 1078 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { 1079 tlsClientConf := testdata.GetTLSConfig() 1080 tlsClientConf.NextProtos = []string{"foo", "bar"} 1081 tlsConf := &tls.Config{ 1082 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 1083 return tlsClientConf, nil 1084 }, 1085 } 1086 1087 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 1088 Expect(err).ToNot(HaveOccurred()) 1089 defer ln.Close() 1090 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1091 Expect(err).ToNot(HaveOccurred()) 1092 defer c.CloseWithError(0, "") 1093 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1094 // check that the original config was not modified 1095 Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) 1096 }) 1097 }) 1098 1099 Context("Serve", func() { 1100 origQuicListen := quicListen 1101 1102 AfterEach(func() { 1103 quicListen = origQuicListen 1104 }) 1105 1106 It("serves a packet conn", func() { 1107 ln := newMockAddrListener(":443") 1108 conn := &net.UDPConn{} 1109 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1110 Expect(c).To(Equal(conn)) 1111 return ln, nil 1112 } 1113 1114 s := &Server{ 1115 TLSConfig: &tls.Config{}, 1116 } 1117 1118 stopAccept := make(chan struct{}) 1119 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1120 <-stopAccept 1121 return nil, errors.New("closed") 1122 }) 1123 ln.EXPECT().Addr() // generate alt-svc headers 1124 done := make(chan struct{}) 1125 go func() { 1126 defer GinkgoRecover() 1127 defer close(done) 1128 s.Serve(conn) 1129 }() 1130 1131 Consistently(done).ShouldNot(BeClosed()) 1132 ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) 1133 Expect(s.Close()).To(Succeed()) 1134 Eventually(done).Should(BeClosed()) 1135 }) 1136 1137 It("serves two packet conns", func() { 1138 ln1 := newMockAddrListener(":443") 1139 ln2 := newMockAddrListener(":8443") 1140 lns := make(chan QUICEarlyListener, 2) 1141 lns <- ln1 1142 lns <- ln2 1143 conn1 := &net.UDPConn{} 1144 conn2 := &net.UDPConn{} 1145 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1146 return <-lns, nil 1147 } 1148 1149 s := &Server{ 1150 TLSConfig: &tls.Config{}, 1151 } 1152 1153 stopAccept1 := make(chan struct{}) 1154 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1155 <-stopAccept1 1156 return nil, errors.New("closed") 1157 }) 1158 ln1.EXPECT().Addr() // generate alt-svc headers 1159 stopAccept2 := make(chan struct{}) 1160 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1161 <-stopAccept2 1162 return nil, errors.New("closed") 1163 }) 1164 ln2.EXPECT().Addr() 1165 1166 done1 := make(chan struct{}) 1167 go func() { 1168 defer GinkgoRecover() 1169 defer close(done1) 1170 s.Serve(conn1) 1171 }() 1172 done2 := make(chan struct{}) 1173 go func() { 1174 defer GinkgoRecover() 1175 defer close(done2) 1176 s.Serve(conn2) 1177 }() 1178 1179 Consistently(done1).ShouldNot(BeClosed()) 1180 Expect(done2).ToNot(BeClosed()) 1181 ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) 1182 ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) 1183 Expect(s.Close()).To(Succeed()) 1184 Eventually(done1).Should(BeClosed()) 1185 Eventually(done2).Should(BeClosed()) 1186 }) 1187 }) 1188 1189 Context("ServeListener", func() { 1190 origQuicListen := quicListen 1191 1192 AfterEach(func() { 1193 quicListen = origQuicListen 1194 }) 1195 1196 It("serves a listener", func() { 1197 var called int32 1198 ln := newMockAddrListener(":443") 1199 quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1200 atomic.StoreInt32(&called, 1) 1201 return ln, nil 1202 } 1203 1204 s := &Server{} 1205 1206 stopAccept := make(chan struct{}) 1207 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1208 <-stopAccept 1209 return nil, errors.New("closed") 1210 }) 1211 ln.EXPECT().Addr() // generate alt-svc headers 1212 done := make(chan struct{}) 1213 go func() { 1214 defer GinkgoRecover() 1215 defer close(done) 1216 s.ServeListener(ln) 1217 }() 1218 1219 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1220 Consistently(done).ShouldNot(BeClosed()) 1221 ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) 1222 Expect(s.Close()).To(Succeed()) 1223 Eventually(done).Should(BeClosed()) 1224 }) 1225 1226 It("serves two listeners", func() { 1227 var called int32 1228 ln1 := newMockAddrListener(":443") 1229 ln2 := newMockAddrListener(":8443") 1230 lns := make(chan QUICEarlyListener, 2) 1231 lns <- ln1 1232 lns <- ln2 1233 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1234 atomic.StoreInt32(&called, 1) 1235 return <-lns, nil 1236 } 1237 1238 s := &Server{} 1239 1240 stopAccept1 := make(chan struct{}) 1241 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1242 <-stopAccept1 1243 return nil, errors.New("closed") 1244 }) 1245 ln1.EXPECT().Addr() // generate alt-svc headers 1246 stopAccept2 := make(chan struct{}) 1247 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1248 <-stopAccept2 1249 return nil, errors.New("closed") 1250 }) 1251 ln2.EXPECT().Addr() 1252 1253 done1 := make(chan struct{}) 1254 go func() { 1255 defer GinkgoRecover() 1256 defer close(done1) 1257 s.ServeListener(ln1) 1258 }() 1259 done2 := make(chan struct{}) 1260 go func() { 1261 defer GinkgoRecover() 1262 defer close(done2) 1263 s.ServeListener(ln2) 1264 }() 1265 1266 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1267 Consistently(done1).ShouldNot(BeClosed()) 1268 Expect(done2).ToNot(BeClosed()) 1269 ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) 1270 ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) 1271 Expect(s.Close()).To(Succeed()) 1272 Eventually(done1).Should(BeClosed()) 1273 Eventually(done2).Should(BeClosed()) 1274 }) 1275 }) 1276 1277 Context("ServeQUICConn", func() { 1278 It("serves a QUIC connection", func() { 1279 mux := http.NewServeMux() 1280 mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { 1281 w.Write([]byte("foobar")) 1282 }) 1283 s.Handler = mux 1284 tlsConf := testdata.GetTLSConfig() 1285 tlsConf.NextProtos = []string{NextProtoH3} 1286 conn := mockquic.NewMockEarlyConnection(mockCtrl) 1287 controlStr := mockquic.NewMockStream(mockCtrl) 1288 controlStr.EXPECT().Write(gomock.Any()) 1289 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 1290 testDone := make(chan struct{}) 1291 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 1292 <-testDone 1293 return nil, errors.New("test done") 1294 }).MaxTimes(1) 1295 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}) 1296 s.ServeQUICConn(conn) 1297 close(testDone) 1298 }) 1299 }) 1300 1301 Context("ListenAndServe", func() { 1302 BeforeEach(func() { 1303 s.Addr = "localhost:0" 1304 }) 1305 1306 AfterEach(func() { 1307 Expect(s.Close()).To(Succeed()) 1308 }) 1309 1310 It("uses the quic.Config to start the QUIC server", func() { 1311 conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} 1312 var receivedConf *quic.Config 1313 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1314 receivedConf = config 1315 return nil, errors.New("listen err") 1316 } 1317 s.QuicConfig = conf 1318 Expect(s.ListenAndServe()).To(HaveOccurred()) 1319 Expect(receivedConf).To(Equal(conf)) 1320 }) 1321 }) 1322 1323 It("closes gracefully", func() { 1324 Expect(s.CloseGracefully(0)).To(Succeed()) 1325 }) 1326 1327 It("errors when listening fails", func() { 1328 testErr := errors.New("listen error") 1329 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1330 return nil, testErr 1331 } 1332 fullpem, privkey := testdata.GetCertificatePaths() 1333 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 1334 }) 1335 1336 It("supports H3_DATAGRAM", func() { 1337 s.EnableDatagrams = true 1338 var receivedConf *quic.Config 1339 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1340 receivedConf = config 1341 return nil, errors.New("listen err") 1342 } 1343 Expect(s.ListenAndServe()).To(HaveOccurred()) 1344 Expect(receivedConf.EnableDatagrams).To(BeTrue()) 1345 }) 1346 })