github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/http3/server_test.go (about) 1 package http3 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "errors" 8 "fmt" 9 "io" 10 "net" 11 "net/http" 12 "runtime" 13 "sync/atomic" 14 "time" 15 16 "github.com/TugasAkhir-QUIC/quic-go" 17 mockquic "github.com/TugasAkhir-QUIC/quic-go/internal/mocks/quic" 18 "github.com/TugasAkhir-QUIC/quic-go/internal/protocol" 19 "github.com/TugasAkhir-QUIC/quic-go/internal/testdata" 20 "github.com/TugasAkhir-QUIC/quic-go/internal/utils" 21 "github.com/TugasAkhir-QUIC/quic-go/quicvarint" 22 23 "github.com/quic-go/qpack" 24 "go.uber.org/mock/gomock" 25 26 . "github.com/onsi/ginkgo/v2" 27 . "github.com/onsi/gomega" 28 gmtypes "github.com/onsi/gomega/types" 29 ) 30 31 type mockAddr struct{ addr string } 32 33 func (ma *mockAddr) Network() string { return "udp" } 34 func (ma *mockAddr) String() string { return ma.addr } 35 36 type mockAddrListener struct { 37 *MockQUICEarlyListener 38 addr *mockAddr 39 } 40 41 func (m *mockAddrListener) Addr() net.Addr { 42 _ = m.MockQUICEarlyListener.Addr() 43 return m.addr 44 } 45 46 func newMockAddrListener(addr string) *mockAddrListener { 47 return &mockAddrListener{ 48 MockQUICEarlyListener: NewMockQUICEarlyListener(mockCtrl), 49 addr: &mockAddr{addr: addr}, 50 } 51 } 52 53 type noPortListener struct { 54 *mockAddrListener 55 } 56 57 func (m *noPortListener) Addr() net.Addr { 58 _ = m.mockAddrListener.Addr() 59 return &net.UnixAddr{ 60 Net: "unix", 61 Name: "/tmp/quic.sock", 62 } 63 } 64 65 var _ = Describe("Server", func() { 66 var ( 67 s *Server 68 origQuicListenAddr = quicListenAddr 69 ) 70 type testConnContextKey string 71 72 BeforeEach(func() { 73 s = &Server{ 74 TLSConfig: testdata.GetTLSConfig(), 75 logger: utils.DefaultLogger, 76 ConnContext: func(ctx context.Context, c quic.Connection) context.Context { 77 return context.WithValue(ctx, testConnContextKey("test"), c) 78 }, 79 } 80 origQuicListenAddr = quicListenAddr 81 }) 82 83 AfterEach(func() { 84 quicListenAddr = origQuicListenAddr 85 }) 86 87 Context("handling requests", func() { 88 var ( 89 qpackDecoder *qpack.Decoder 90 str *mockquic.MockStream 91 conn *mockquic.MockEarlyConnection 92 exampleGetRequest *http.Request 93 examplePostRequest *http.Request 94 ) 95 reqContext := context.Background() 96 97 decodeHeader := func(str io.Reader) map[string][]string { 98 fields := make(map[string][]string) 99 decoder := qpack.NewDecoder(nil) 100 101 frame, err := parseNextFrame(str, nil) 102 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 103 ExpectWithOffset(1, frame).To(BeAssignableToTypeOf(&headersFrame{})) 104 headersFrame := frame.(*headersFrame) 105 data := make([]byte, headersFrame.Length) 106 _, err = io.ReadFull(str, data) 107 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 108 hfs, err := decoder.DecodeFull(data) 109 ExpectWithOffset(1, err).ToNot(HaveOccurred()) 110 for _, p := range hfs { 111 fields[p.Name] = append(fields[p.Name], p.Value) 112 } 113 return fields 114 } 115 116 encodeRequest := func(req *http.Request) []byte { 117 buf := &bytes.Buffer{} 118 str := mockquic.NewMockStream(mockCtrl) 119 str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write).AnyTimes() 120 rw := newRequestWriter(utils.DefaultLogger) 121 Expect(rw.WriteRequestHeader(str, req, false)).To(Succeed()) 122 return buf.Bytes() 123 } 124 125 setRequest := func(data []byte) { 126 buf := bytes.NewBuffer(data) 127 str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 128 if buf.Len() == 0 { 129 return 0, io.EOF 130 } 131 return buf.Read(p) 132 }).AnyTimes() 133 } 134 135 BeforeEach(func() { 136 var err error 137 exampleGetRequest, err = http.NewRequest("GET", "https://www.example.com", nil) 138 Expect(err).ToNot(HaveOccurred()) 139 examplePostRequest, err = http.NewRequest("POST", "https://www.example.com", bytes.NewReader([]byte("foobar"))) 140 Expect(err).ToNot(HaveOccurred()) 141 142 qpackDecoder = qpack.NewDecoder(nil) 143 str = mockquic.NewMockStream(mockCtrl) 144 conn = mockquic.NewMockEarlyConnection(mockCtrl) 145 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 146 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 147 conn.EXPECT().LocalAddr().AnyTimes() 148 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 149 }) 150 151 It("calls the HTTP handler function", func() { 152 requestChan := make(chan *http.Request, 1) 153 s.Handler = http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { 154 requestChan <- r 155 }) 156 157 setRequest(encodeRequest(exampleGetRequest)) 158 str.EXPECT().Context().Return(reqContext) 159 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 160 return len(p), nil 161 }).AnyTimes() 162 str.EXPECT().CancelRead(gomock.Any()) 163 164 Expect(s.handleRequest(conn, str, qpackDecoder, nil)).To(Equal(requestError{})) 165 var req *http.Request 166 Eventually(requestChan).Should(Receive(&req)) 167 Expect(req.Host).To(Equal("www.example.com")) 168 Expect(req.RemoteAddr).To(Equal("127.0.0.1:1337")) 169 Expect(req.Context().Value(ServerContextKey)).To(Equal(s)) 170 Expect(req.Context().Value(testConnContextKey("test"))).ToNot(Equal(nil)) 171 }) 172 173 It("returns 200 with an empty handler", func() { 174 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 175 176 responseBuf := &bytes.Buffer{} 177 setRequest(encodeRequest(exampleGetRequest)) 178 str.EXPECT().Context().Return(reqContext) 179 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 180 str.EXPECT().CancelRead(gomock.Any()) 181 182 serr := s.handleRequest(conn, str, qpackDecoder, nil) 183 Expect(serr.err).ToNot(HaveOccurred()) 184 hfs := decodeHeader(responseBuf) 185 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 186 }) 187 188 It("sets Content-Length when the handler doesn't flush to the client", func() { 189 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 190 w.Write([]byte("foobar")) 191 }) 192 193 responseBuf := &bytes.Buffer{} 194 setRequest(encodeRequest(exampleGetRequest)) 195 str.EXPECT().Context().Return(reqContext) 196 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 197 str.EXPECT().CancelRead(gomock.Any()) 198 199 serr := s.handleRequest(conn, str, qpackDecoder, nil) 200 Expect(serr.err).ToNot(HaveOccurred()) 201 hfs := decodeHeader(responseBuf) 202 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 203 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"6"})) 204 // status, content-length, date, content-type 205 Expect(hfs).To(HaveLen(4)) 206 }) 207 208 It("not sets Content-Length when the handler flushes to the client", func() { 209 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 210 w.Write([]byte("foobar")) 211 // force flush 212 w.(http.Flusher).Flush() 213 }) 214 215 responseBuf := &bytes.Buffer{} 216 setRequest(encodeRequest(exampleGetRequest)) 217 str.EXPECT().Context().Return(reqContext) 218 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 219 str.EXPECT().CancelRead(gomock.Any()) 220 221 serr := s.handleRequest(conn, str, qpackDecoder, nil) 222 Expect(serr.err).ToNot(HaveOccurred()) 223 hfs := decodeHeader(responseBuf) 224 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 225 // status, date, content-type 226 Expect(hfs).To(HaveLen(3)) 227 }) 228 229 It("response to HEAD request should not have body", func() { 230 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 231 w.Write([]byte("foobar")) 232 }) 233 234 headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) 235 Expect(err).ToNot(HaveOccurred()) 236 responseBuf := &bytes.Buffer{} 237 setRequest(encodeRequest(headRequest)) 238 str.EXPECT().Context().Return(reqContext) 239 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 240 str.EXPECT().CancelRead(gomock.Any()) 241 serr := s.handleRequest(conn, str, qpackDecoder, nil) 242 Expect(serr.err).ToNot(HaveOccurred()) 243 hfs := decodeHeader(responseBuf) 244 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 245 Expect(responseBuf.Bytes()).To(HaveLen(0)) 246 }) 247 248 It("response to HEAD request should also do content sniffing", func() { 249 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 250 w.Write([]byte("<html></html>")) 251 }) 252 253 headRequest, err := http.NewRequest("HEAD", "https://www.example.com", nil) 254 Expect(err).ToNot(HaveOccurred()) 255 responseBuf := &bytes.Buffer{} 256 setRequest(encodeRequest(headRequest)) 257 str.EXPECT().Context().Return(reqContext) 258 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 259 str.EXPECT().CancelRead(gomock.Any()) 260 serr := s.handleRequest(conn, str, qpackDecoder, nil) 261 Expect(serr.err).ToNot(HaveOccurred()) 262 hfs := decodeHeader(responseBuf) 263 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 264 Expect(hfs).To(HaveKeyWithValue("content-length", []string{"13"})) 265 Expect(hfs).To(HaveKeyWithValue("content-type", []string{"text/html; charset=utf-8"})) 266 }) 267 268 It("handles a aborting handler", func() { 269 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 270 panic(http.ErrAbortHandler) 271 }) 272 273 responseBuf := &bytes.Buffer{} 274 setRequest(encodeRequest(exampleGetRequest)) 275 str.EXPECT().Context().Return(reqContext) 276 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 277 str.EXPECT().CancelRead(gomock.Any()) 278 279 serr := s.handleRequest(conn, str, qpackDecoder, nil) 280 Expect(serr.err).To(MatchError(errPanicked)) 281 Expect(responseBuf.Bytes()).To(HaveLen(0)) 282 }) 283 284 It("handles a panicking handler", func() { 285 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 286 panic("foobar") 287 }) 288 289 responseBuf := &bytes.Buffer{} 290 setRequest(encodeRequest(exampleGetRequest)) 291 str.EXPECT().Context().Return(reqContext) 292 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 293 str.EXPECT().CancelRead(gomock.Any()) 294 295 serr := s.handleRequest(conn, str, qpackDecoder, nil) 296 Expect(serr.err).To(MatchError(errPanicked)) 297 Expect(responseBuf.Bytes()).To(HaveLen(0)) 298 }) 299 300 Context("hijacking bidirectional streams", func() { 301 var conn *mockquic.MockEarlyConnection 302 testDone := make(chan struct{}) 303 304 BeforeEach(func() { 305 testDone = make(chan struct{}) 306 conn = mockquic.NewMockEarlyConnection(mockCtrl) 307 controlStr := mockquic.NewMockStream(mockCtrl) 308 controlStr.EXPECT().Write(gomock.Any()) 309 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 310 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 311 conn.EXPECT().LocalAddr().AnyTimes() 312 }) 313 314 AfterEach(func() { testDone <- struct{}{} }) 315 316 It("hijacks a bidirectional stream of unknown frame type", func() { 317 frameTypeChan := make(chan FrameType, 1) 318 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 319 Expect(e).ToNot(HaveOccurred()) 320 frameTypeChan <- ft 321 return true, nil 322 } 323 324 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 325 unknownStr := mockquic.NewMockStream(mockCtrl) 326 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 327 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 328 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 329 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 330 <-testDone 331 return nil, errors.New("test done") 332 }) 333 s.handleConn(conn) 334 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 335 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 336 }) 337 338 It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { 339 frameTypeChan := make(chan FrameType, 1) 340 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 341 Expect(e).ToNot(HaveOccurred()) 342 frameTypeChan <- ft 343 return false, nil 344 } 345 346 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 347 unknownStr := mockquic.NewMockStream(mockCtrl) 348 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 349 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 350 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 351 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 352 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 353 <-testDone 354 return nil, errors.New("test done") 355 }) 356 s.handleConn(conn) 357 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 358 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 359 }) 360 361 It("cancels writing when hijacker returned error", func() { 362 frameTypeChan := make(chan FrameType, 1) 363 s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) { 364 Expect(e).ToNot(HaveOccurred()) 365 frameTypeChan <- ft 366 return false, errors.New("error in hijacker") 367 } 368 369 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x41)) 370 unknownStr := mockquic.NewMockStream(mockCtrl) 371 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 372 unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)) 373 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 374 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 375 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 376 <-testDone 377 return nil, errors.New("test done") 378 }) 379 s.handleConn(conn) 380 Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) 381 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 382 }) 383 384 It("handles errors that occur when reading the stream type", func() { 385 testErr := errors.New("test error") 386 done := make(chan struct{}) 387 unknownStr := mockquic.NewMockStream(mockCtrl) 388 s.StreamHijacker = func(ft FrameType, _ quic.Connection, str quic.Stream, err error) (bool, error) { 389 defer close(done) 390 Expect(ft).To(BeZero()) 391 Expect(str).To(Equal(unknownStr)) 392 Expect(err).To(MatchError(testErr)) 393 return true, nil 394 } 395 396 unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes() 397 conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) 398 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 399 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 400 <-testDone 401 return nil, errors.New("test done") 402 }) 403 s.handleConn(conn) 404 Eventually(done).Should(BeClosed()) 405 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 406 }) 407 }) 408 409 Context("hijacking unidirectional streams", func() { 410 var conn *mockquic.MockEarlyConnection 411 testDone := make(chan struct{}) 412 413 BeforeEach(func() { 414 testDone = make(chan struct{}) 415 conn = mockquic.NewMockEarlyConnection(mockCtrl) 416 controlStr := mockquic.NewMockStream(mockCtrl) 417 controlStr.EXPECT().Write(gomock.Any()) 418 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 419 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 420 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 421 conn.EXPECT().LocalAddr().AnyTimes() 422 }) 423 424 AfterEach(func() { testDone <- struct{}{} }) 425 426 It("hijacks an unidirectional stream of unknown stream type", func() { 427 streamTypeChan := make(chan StreamType, 1) 428 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 429 Expect(err).ToNot(HaveOccurred()) 430 streamTypeChan <- st 431 return true 432 } 433 434 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 435 unknownStr := mockquic.NewMockStream(mockCtrl) 436 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 437 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 438 return unknownStr, nil 439 }) 440 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 441 <-testDone 442 return nil, errors.New("test done") 443 }) 444 s.handleConn(conn) 445 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 446 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 447 }) 448 449 It("handles errors that occur when reading the stream type", func() { 450 testErr := errors.New("test error") 451 done := make(chan struct{}) 452 unknownStr := mockquic.NewMockStream(mockCtrl) 453 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool { 454 defer close(done) 455 Expect(st).To(BeZero()) 456 Expect(str).To(Equal(unknownStr)) 457 Expect(err).To(MatchError(testErr)) 458 return true 459 } 460 461 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr }) 462 conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil) 463 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 464 <-testDone 465 return nil, errors.New("test done") 466 }) 467 s.handleConn(conn) 468 Eventually(done).Should(BeClosed()) 469 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 470 }) 471 472 It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { 473 streamTypeChan := make(chan StreamType, 1) 474 s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool { 475 Expect(err).ToNot(HaveOccurred()) 476 streamTypeChan <- st 477 return false 478 } 479 480 buf := bytes.NewBuffer(quicvarint.Append(nil, 0x54)) 481 unknownStr := mockquic.NewMockStream(mockCtrl) 482 unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 483 unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)) 484 485 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 486 return unknownStr, nil 487 }) 488 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 489 <-testDone 490 return nil, errors.New("test done") 491 }) 492 s.handleConn(conn) 493 Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) 494 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 495 }) 496 }) 497 498 Context("control stream handling", func() { 499 var conn *mockquic.MockEarlyConnection 500 testDone := make(chan struct{}) 501 502 BeforeEach(func() { 503 conn = mockquic.NewMockEarlyConnection(mockCtrl) 504 controlStr := mockquic.NewMockStream(mockCtrl) 505 controlStr.EXPECT().Write(gomock.Any()) 506 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 507 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 508 conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() 509 conn.EXPECT().LocalAddr().AnyTimes() 510 }) 511 512 AfterEach(func() { testDone <- struct{}{} }) 513 514 It("parses the SETTINGS frame", func() { 515 b := quicvarint.Append(nil, streamTypeControlStream) 516 b = (&settingsFrame{}).Append(b) 517 controlStr := mockquic.NewMockStream(mockCtrl) 518 r := bytes.NewReader(b) 519 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 520 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 521 return controlStr, nil 522 }) 523 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 524 <-testDone 525 return nil, errors.New("test done") 526 }) 527 s.handleConn(conn) 528 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError 529 }) 530 531 for _, t := range []uint64{streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream} { 532 streamType := t 533 name := "encoder" 534 if streamType == streamTypeQPACKDecoderStream { 535 name = "decoder" 536 } 537 538 It(fmt.Sprintf("ignores the QPACK %s streams", name), func() { 539 buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) 540 str := mockquic.NewMockStream(mockCtrl) 541 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 542 543 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 544 return str, nil 545 }) 546 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 547 <-testDone 548 return nil, errors.New("test done") 549 }) 550 s.handleConn(conn) 551 time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to str.CancelRead 552 }) 553 } 554 555 It("reset streams other than the control stream and the QPACK streams", func() { 556 buf := bytes.NewBuffer(quicvarint.Append(nil, 0o1337)) 557 str := mockquic.NewMockStream(mockCtrl) 558 str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() 559 done := make(chan struct{}) 560 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError)).Do(func(quic.StreamErrorCode) { close(done) }) 561 562 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 563 return str, nil 564 }) 565 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 566 <-testDone 567 return nil, errors.New("test done") 568 }) 569 s.handleConn(conn) 570 Eventually(done).Should(BeClosed()) 571 }) 572 573 It("errors when the first frame on the control stream is not a SETTINGS frame", func() { 574 b := quicvarint.Append(nil, streamTypeControlStream) 575 b = (&dataFrame{}).Append(b) 576 controlStr := mockquic.NewMockStream(mockCtrl) 577 r := bytes.NewReader(b) 578 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 579 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 580 return controlStr, nil 581 }) 582 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 583 <-testDone 584 return nil, errors.New("test done") 585 }) 586 done := make(chan struct{}) 587 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 588 close(done) 589 return nil 590 }) 591 s.handleConn(conn) 592 Eventually(done).Should(BeClosed()) 593 }) 594 595 It("errors when parsing the frame on the control stream fails", func() { 596 b := quicvarint.Append(nil, streamTypeControlStream) 597 b = (&settingsFrame{}).Append(b) 598 r := bytes.NewReader(b[:len(b)-1]) 599 controlStr := mockquic.NewMockStream(mockCtrl) 600 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 601 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 602 return controlStr, nil 603 }) 604 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 605 <-testDone 606 return nil, errors.New("test done") 607 }) 608 done := make(chan struct{}) 609 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 610 close(done) 611 return nil 612 }) 613 s.handleConn(conn) 614 Eventually(done).Should(BeClosed()) 615 }) 616 617 It("errors when the client opens a push stream", func() { 618 b := quicvarint.Append(nil, streamTypePushStream) 619 b = (&dataFrame{}).Append(b) 620 r := bytes.NewReader(b) 621 controlStr := mockquic.NewMockStream(mockCtrl) 622 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 623 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 624 return controlStr, nil 625 }) 626 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 627 <-testDone 628 return nil, errors.New("test done") 629 }) 630 done := make(chan struct{}) 631 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 632 close(done) 633 return nil 634 }) 635 s.handleConn(conn) 636 Eventually(done).Should(BeClosed()) 637 }) 638 639 It("errors when the client advertises datagram support (and we enabled support for it)", func() { 640 s.EnableDatagrams = true 641 b := quicvarint.Append(nil, streamTypeControlStream) 642 b = (&settingsFrame{Datagram: true}).Append(b) 643 r := bytes.NewReader(b) 644 controlStr := mockquic.NewMockStream(mockCtrl) 645 controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(r.Read).AnyTimes() 646 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 647 return controlStr, nil 648 }) 649 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 650 <-testDone 651 return nil, errors.New("test done") 652 }) 653 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{SupportsDatagrams: false}) 654 done := make(chan struct{}) 655 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support").Do(func(quic.ApplicationErrorCode, string) error { 656 close(done) 657 return nil 658 }) 659 s.handleConn(conn) 660 Eventually(done).Should(BeClosed()) 661 }) 662 }) 663 664 Context("stream- and connection-level errors", func() { 665 var conn *mockquic.MockEarlyConnection 666 testDone := make(chan struct{}) 667 668 BeforeEach(func() { 669 testDone = make(chan struct{}) 670 addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} 671 conn = mockquic.NewMockEarlyConnection(mockCtrl) 672 controlStr := mockquic.NewMockStream(mockCtrl) 673 controlStr.EXPECT().Write(gomock.Any()) 674 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 675 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 676 <-testDone 677 return nil, errors.New("test done") 678 }) 679 conn.EXPECT().AcceptStream(gomock.Any()).Return(str, nil) 680 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) 681 conn.EXPECT().RemoteAddr().Return(addr).AnyTimes() 682 conn.EXPECT().LocalAddr().AnyTimes() 683 conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}).AnyTimes() 684 }) 685 686 AfterEach(func() { testDone <- struct{}{} }) 687 688 It("cancels reading when client sends a body in GET request", func() { 689 var handlerCalled bool 690 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 691 handlerCalled = true 692 }) 693 694 requestData := encodeRequest(exampleGetRequest) 695 b := (&dataFrame{Length: 6}).Append(nil) // add a body 696 b = append(b, []byte("foobar")...) 697 responseBuf := &bytes.Buffer{} 698 setRequest(append(requestData, b...)) 699 done := make(chan struct{}) 700 str.EXPECT().Context().Return(reqContext) 701 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 702 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 703 str.EXPECT().Close().Do(func() error { close(done); return nil }) 704 705 s.handleConn(conn) 706 Eventually(done).Should(BeClosed()) 707 hfs := decodeHeader(responseBuf) 708 Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) 709 Expect(handlerCalled).To(BeTrue()) 710 }) 711 712 It("doesn't close the stream if the stream was hijacked (via HTTPStream)", func() { 713 handlerCalled := make(chan struct{}) 714 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 715 defer close(handlerCalled) 716 r.Body.(HTTPStreamer).HTTPStream() 717 str.Write([]byte("foobar")) 718 }) 719 720 requestData := encodeRequest(exampleGetRequest) 721 b := (&dataFrame{Length: 6}).Append(nil) // add a body 722 b = append(b, []byte("foobar")...) 723 setRequest(append(requestData, b...)) 724 str.EXPECT().Context().Return(reqContext) 725 str.EXPECT().Write([]byte("foobar")).Return(6, nil) 726 727 s.handleConn(conn) 728 Eventually(handlerCalled).Should(BeClosed()) 729 }) 730 731 It("errors when the client sends a too large header frame", func() { 732 s.MaxHeaderBytes = 20 733 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 734 Fail("Handler should not be called.") 735 }) 736 737 requestData := encodeRequest(exampleGetRequest) 738 b := (&dataFrame{Length: 6}).Append(nil) // add a body 739 b = append(b, []byte("foobar")...) 740 responseBuf := &bytes.Buffer{} 741 setRequest(append(requestData, b...)) 742 done := make(chan struct{}) 743 str.EXPECT().Write(gomock.Any()).DoAndReturn(responseBuf.Write).AnyTimes() 744 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 745 746 s.handleConn(conn) 747 Eventually(done).Should(BeClosed()) 748 }) 749 750 It("handles a request for which the client immediately resets the stream", func() { 751 handlerCalled := make(chan struct{}) 752 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 753 close(handlerCalled) 754 }) 755 756 testErr := errors.New("stream reset") 757 done := make(chan struct{}) 758 str.EXPECT().Read(gomock.Any()).Return(0, testErr) 759 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeRequestIncomplete)).Do(func(quic.StreamErrorCode) { close(done) }) 760 761 s.handleConn(conn) 762 Consistently(handlerCalled).ShouldNot(BeClosed()) 763 }) 764 765 It("closes the connection when the first frame is not a HEADERS frame", func() { 766 handlerCalled := make(chan struct{}) 767 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 768 close(handlerCalled) 769 }) 770 771 b := (&dataFrame{}).Append(nil) 772 setRequest(b) 773 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 774 return len(p), nil 775 }).AnyTimes() 776 777 done := make(chan struct{}) 778 conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameUnexpected), gomock.Any()).Do(func(quic.ApplicationErrorCode, string) error { 779 close(done) 780 return nil 781 }) 782 s.handleConn(conn) 783 Eventually(done).Should(BeClosed()) 784 }) 785 786 It("closes the connection when the first frame is not a HEADERS frame", func() { 787 handlerCalled := make(chan struct{}) 788 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 789 close(handlerCalled) 790 }) 791 792 // use 2*DefaultMaxHeaderBytes here. qpack will compress the requiest, 793 // but the request will still end up larger than DefaultMaxHeaderBytes. 794 url := bytes.Repeat([]byte{'a'}, http.DefaultMaxHeaderBytes*2) 795 req, err := http.NewRequest(http.MethodGet, "https://"+string(url), nil) 796 Expect(err).ToNot(HaveOccurred()) 797 setRequest(encodeRequest(req)) 798 // str.EXPECT().Context().Return(reqContext) 799 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 800 return len(p), nil 801 }).AnyTimes() 802 done := make(chan struct{}) 803 str.EXPECT().CancelWrite(quic.StreamErrorCode(ErrCodeFrameError)).Do(func(quic.StreamErrorCode) { close(done) }) 804 805 s.handleConn(conn) 806 Eventually(done).Should(BeClosed()) 807 }) 808 }) 809 810 It("resets the stream when the body of POST request is not read, and the request handler replaces the request.Body", func() { 811 handlerCalled := make(chan struct{}) 812 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 813 r.Body = struct { 814 io.Reader 815 io.Closer 816 }{} 817 close(handlerCalled) 818 }) 819 820 setRequest(encodeRequest(examplePostRequest)) 821 str.EXPECT().Context().Return(reqContext) 822 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 823 return len(p), nil 824 }).AnyTimes() 825 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 826 827 serr := s.handleRequest(conn, str, qpackDecoder, nil) 828 Expect(serr.err).ToNot(HaveOccurred()) 829 Eventually(handlerCalled).Should(BeClosed()) 830 }) 831 832 It("cancels the request context when the stream is closed", func() { 833 handlerCalled := make(chan struct{}) 834 s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 835 defer GinkgoRecover() 836 Expect(r.Context().Done()).To(BeClosed()) 837 Expect(r.Context().Err()).To(MatchError(context.Canceled)) 838 close(handlerCalled) 839 }) 840 setRequest(encodeRequest(examplePostRequest)) 841 842 reqContext, cancel := context.WithCancel(context.Background()) 843 cancel() 844 str.EXPECT().Context().Return(reqContext) 845 str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { 846 return len(p), nil 847 }).AnyTimes() 848 str.EXPECT().CancelRead(quic.StreamErrorCode(ErrCodeNoError)) 849 850 serr := s.handleRequest(conn, str, qpackDecoder, nil) 851 Expect(serr.err).ToNot(HaveOccurred()) 852 Eventually(handlerCalled).Should(BeClosed()) 853 }) 854 }) 855 856 Context("setting http headers", func() { 857 BeforeEach(func() { 858 s.QuicConfig = &quic.Config{Versions: []protocol.Version{protocol.Version1}} 859 }) 860 861 var ln1 QUICEarlyListener 862 var ln2 QUICEarlyListener 863 expected := http.Header{ 864 "Alt-Svc": {`h3=":443"; ma=2592000`}, 865 } 866 867 addListener := func(addr string, ln *QUICEarlyListener) { 868 mln := newMockAddrListener(addr) 869 mln.EXPECT().Addr() 870 *ln = mln 871 s.addListener(ln) 872 } 873 874 removeListener := func(ln *QUICEarlyListener) { 875 s.removeListener(ln) 876 } 877 878 checkSetHeaders := func(expected gmtypes.GomegaMatcher) { 879 hdr := http.Header{} 880 Expect(s.SetQuicHeaders(hdr)).To(Succeed()) 881 Expect(hdr).To(expected) 882 } 883 884 checkSetHeaderError := func() { 885 hdr := http.Header{} 886 Expect(s.SetQuicHeaders(hdr)).To(Equal(ErrNoAltSvcPort)) 887 } 888 889 It("sets proper headers with numeric port", func() { 890 addListener(":443", &ln1) 891 checkSetHeaders(Equal(expected)) 892 removeListener(&ln1) 893 checkSetHeaderError() 894 }) 895 896 It("sets proper headers with full addr", func() { 897 addListener("127.0.0.1:443", &ln1) 898 checkSetHeaders(Equal(expected)) 899 removeListener(&ln1) 900 checkSetHeaderError() 901 }) 902 903 It("sets proper headers with string port", func() { 904 addListener(":https", &ln1) 905 checkSetHeaders(Equal(expected)) 906 removeListener(&ln1) 907 checkSetHeaderError() 908 }) 909 910 It("works multiple times", func() { 911 addListener(":https", &ln1) 912 checkSetHeaders(Equal(expected)) 913 checkSetHeaders(Equal(expected)) 914 removeListener(&ln1) 915 checkSetHeaderError() 916 }) 917 918 It("works if the quic.Config sets QUIC versions", func() { 919 s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version2} 920 addListener(":443", &ln1) 921 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 922 removeListener(&ln1) 923 checkSetHeaderError() 924 }) 925 926 It("uses s.Port if set to a non-zero value", func() { 927 s.Port = 8443 928 addListener(":443", &ln1) 929 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000`}})) 930 removeListener(&ln1) 931 checkSetHeaderError() 932 }) 933 934 It("uses s.Addr if listeners don't have ports available", func() { 935 s.Addr = ":443" 936 mln := &noPortListener{newMockAddrListener("")} 937 mln.EXPECT().Addr() 938 ln1 = mln 939 s.addListener(&ln1) 940 checkSetHeaders(Equal(expected)) 941 s.removeListener(&ln1) 942 checkSetHeaderError() 943 }) 944 945 It("properly announces multiple listeners", func() { 946 addListener(":443", &ln1) 947 addListener(":8443", &ln2) 948 checkSetHeaders(Or( 949 Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000,h3=":8443"; ma=2592000`}}), 950 Equal(http.Header{"Alt-Svc": {`h3=":8443"; ma=2592000,h3=":443"; ma=2592000`}}), 951 )) 952 removeListener(&ln1) 953 removeListener(&ln2) 954 checkSetHeaderError() 955 }) 956 957 It("doesn't duplicate Alt-Svc values", func() { 958 s.QuicConfig.Versions = []quic.Version{quic.Version1, quic.Version1} 959 addListener(":443", &ln1) 960 checkSetHeaders(Equal(http.Header{"Alt-Svc": {`h3=":443"; ma=2592000`}})) 961 removeListener(&ln1) 962 checkSetHeaderError() 963 }) 964 }) 965 966 It("errors when ListenAndServe is called with s.TLSConfig nil", func() { 967 Expect((&Server{}).ListenAndServe()).To(MatchError(errServerWithoutTLSConfig)) 968 }) 969 970 It("should nop-Close() when s.server is nil", func() { 971 Expect((&Server{}).Close()).To(Succeed()) 972 }) 973 974 It("errors when ListenAndServeTLS is called after Close", func() { 975 serv := &Server{} 976 Expect(serv.Close()).To(Succeed()) 977 Expect(serv.ListenAndServeTLS(testdata.GetCertificatePaths())).To(MatchError(http.ErrServerClosed)) 978 }) 979 980 It("handles concurrent Serve and Close", func() { 981 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 982 Expect(err).ToNot(HaveOccurred()) 983 c, err := net.ListenUDP("udp", addr) 984 Expect(err).ToNot(HaveOccurred()) 985 done := make(chan struct{}) 986 go func() { 987 defer GinkgoRecover() 988 defer close(done) 989 s.Serve(c) 990 }() 991 runtime.Gosched() 992 s.Close() 993 Eventually(done).Should(BeClosed()) 994 }) 995 996 Context("ConfigureTLSConfig", func() { 997 It("advertises v1 by default", func() { 998 conf := ConfigureTLSConfig(testdata.GetTLSConfig()) 999 ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.Version{quic.Version1}}) 1000 Expect(err).ToNot(HaveOccurred()) 1001 defer ln.Close() 1002 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1003 Expect(err).ToNot(HaveOccurred()) 1004 defer c.CloseWithError(0, "") 1005 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1006 }) 1007 1008 It("sets the GetConfigForClient callback if no tls.Config is given", func() { 1009 var receivedConf *tls.Config 1010 quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (QUICEarlyListener, error) { 1011 receivedConf = tlsConf 1012 return nil, errors.New("listen err") 1013 } 1014 Expect(s.ListenAndServe()).To(HaveOccurred()) 1015 Expect(receivedConf).ToNot(BeNil()) 1016 }) 1017 1018 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { 1019 tlsConf := &tls.Config{ 1020 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 1021 c := testdata.GetTLSConfig() 1022 c.NextProtos = []string{"foo", "bar"} 1023 return c, nil 1024 }, 1025 } 1026 1027 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 1028 Expect(err).ToNot(HaveOccurred()) 1029 defer ln.Close() 1030 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1031 Expect(err).ToNot(HaveOccurred()) 1032 defer c.CloseWithError(0, "") 1033 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1034 }) 1035 1036 It("works if GetConfigForClient returns a nil tls.Config", func() { 1037 tlsConf := testdata.GetTLSConfig() 1038 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } 1039 1040 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 1041 Expect(err).ToNot(HaveOccurred()) 1042 defer ln.Close() 1043 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1044 Expect(err).ToNot(HaveOccurred()) 1045 defer c.CloseWithError(0, "") 1046 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1047 }) 1048 1049 It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { 1050 tlsClientConf := testdata.GetTLSConfig() 1051 tlsClientConf.NextProtos = []string{"foo", "bar"} 1052 tlsConf := &tls.Config{ 1053 GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { 1054 return tlsClientConf, nil 1055 }, 1056 } 1057 1058 ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.Version{quic.Version1}}) 1059 Expect(err).ToNot(HaveOccurred()) 1060 defer ln.Close() 1061 c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) 1062 Expect(err).ToNot(HaveOccurred()) 1063 defer c.CloseWithError(0, "") 1064 Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) 1065 // check that the original config was not modified 1066 Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) 1067 }) 1068 }) 1069 1070 Context("Serve", func() { 1071 origQuicListen := quicListen 1072 1073 AfterEach(func() { 1074 quicListen = origQuicListen 1075 }) 1076 1077 It("serves a packet conn", func() { 1078 ln := newMockAddrListener(":443") 1079 conn := &net.UDPConn{} 1080 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1081 Expect(c).To(Equal(conn)) 1082 return ln, nil 1083 } 1084 1085 s := &Server{ 1086 TLSConfig: &tls.Config{}, 1087 } 1088 1089 stopAccept := make(chan struct{}) 1090 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1091 <-stopAccept 1092 return nil, errors.New("closed") 1093 }) 1094 ln.EXPECT().Addr() // generate alt-svc headers 1095 done := make(chan struct{}) 1096 go func() { 1097 defer GinkgoRecover() 1098 defer close(done) 1099 s.Serve(conn) 1100 }() 1101 1102 Consistently(done).ShouldNot(BeClosed()) 1103 ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) 1104 Expect(s.Close()).To(Succeed()) 1105 Eventually(done).Should(BeClosed()) 1106 }) 1107 1108 It("serves two packet conns", func() { 1109 ln1 := newMockAddrListener(":443") 1110 ln2 := newMockAddrListener(":8443") 1111 lns := make(chan QUICEarlyListener, 2) 1112 lns <- ln1 1113 lns <- ln2 1114 conn1 := &net.UDPConn{} 1115 conn2 := &net.UDPConn{} 1116 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1117 return <-lns, nil 1118 } 1119 1120 s := &Server{ 1121 TLSConfig: &tls.Config{}, 1122 } 1123 1124 stopAccept1 := make(chan struct{}) 1125 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1126 <-stopAccept1 1127 return nil, errors.New("closed") 1128 }) 1129 ln1.EXPECT().Addr() // generate alt-svc headers 1130 stopAccept2 := make(chan struct{}) 1131 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1132 <-stopAccept2 1133 return nil, errors.New("closed") 1134 }) 1135 ln2.EXPECT().Addr() 1136 1137 done1 := make(chan struct{}) 1138 go func() { 1139 defer GinkgoRecover() 1140 defer close(done1) 1141 s.Serve(conn1) 1142 }() 1143 done2 := make(chan struct{}) 1144 go func() { 1145 defer GinkgoRecover() 1146 defer close(done2) 1147 s.Serve(conn2) 1148 }() 1149 1150 Consistently(done1).ShouldNot(BeClosed()) 1151 Expect(done2).ToNot(BeClosed()) 1152 ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) 1153 ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) 1154 Expect(s.Close()).To(Succeed()) 1155 Eventually(done1).Should(BeClosed()) 1156 Eventually(done2).Should(BeClosed()) 1157 }) 1158 }) 1159 1160 Context("ServeListener", func() { 1161 origQuicListen := quicListen 1162 1163 AfterEach(func() { 1164 quicListen = origQuicListen 1165 }) 1166 1167 It("serves a listener", func() { 1168 var called int32 1169 ln := newMockAddrListener(":443") 1170 quicListen = func(conn net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1171 atomic.StoreInt32(&called, 1) 1172 return ln, nil 1173 } 1174 1175 s := &Server{} 1176 1177 stopAccept := make(chan struct{}) 1178 ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1179 <-stopAccept 1180 return nil, errors.New("closed") 1181 }) 1182 ln.EXPECT().Addr() // generate alt-svc headers 1183 done := make(chan struct{}) 1184 go func() { 1185 defer GinkgoRecover() 1186 defer close(done) 1187 s.ServeListener(ln) 1188 }() 1189 1190 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1191 Consistently(done).ShouldNot(BeClosed()) 1192 ln.EXPECT().Close().Do(func() error { close(stopAccept); return nil }) 1193 Expect(s.Close()).To(Succeed()) 1194 Eventually(done).Should(BeClosed()) 1195 }) 1196 1197 It("serves two listeners", func() { 1198 var called int32 1199 ln1 := newMockAddrListener(":443") 1200 ln2 := newMockAddrListener(":8443") 1201 lns := make(chan QUICEarlyListener, 2) 1202 lns <- ln1 1203 lns <- ln2 1204 quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1205 atomic.StoreInt32(&called, 1) 1206 return <-lns, nil 1207 } 1208 1209 s := &Server{} 1210 1211 stopAccept1 := make(chan struct{}) 1212 ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1213 <-stopAccept1 1214 return nil, errors.New("closed") 1215 }) 1216 ln1.EXPECT().Addr() // generate alt-svc headers 1217 stopAccept2 := make(chan struct{}) 1218 ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.EarlyConnection, error) { 1219 <-stopAccept2 1220 return nil, errors.New("closed") 1221 }) 1222 ln2.EXPECT().Addr() 1223 1224 done1 := make(chan struct{}) 1225 go func() { 1226 defer GinkgoRecover() 1227 defer close(done1) 1228 s.ServeListener(ln1) 1229 }() 1230 done2 := make(chan struct{}) 1231 go func() { 1232 defer GinkgoRecover() 1233 defer close(done2) 1234 s.ServeListener(ln2) 1235 }() 1236 1237 Consistently(func() int32 { return atomic.LoadInt32(&called) }).Should(Equal(int32(0))) 1238 Consistently(done1).ShouldNot(BeClosed()) 1239 Expect(done2).ToNot(BeClosed()) 1240 ln1.EXPECT().Close().Do(func() error { close(stopAccept1); return nil }) 1241 ln2.EXPECT().Close().Do(func() error { close(stopAccept2); return nil }) 1242 Expect(s.Close()).To(Succeed()) 1243 Eventually(done1).Should(BeClosed()) 1244 Eventually(done2).Should(BeClosed()) 1245 }) 1246 }) 1247 1248 Context("ServeQUICConn", func() { 1249 It("serves a QUIC connection", func() { 1250 mux := http.NewServeMux() 1251 mux.HandleFunc("/hello", func(w http.ResponseWriter, _ *http.Request) { 1252 w.Write([]byte("foobar")) 1253 }) 1254 s.Handler = mux 1255 tlsConf := testdata.GetTLSConfig() 1256 tlsConf.NextProtos = []string{NextProtoH3} 1257 conn := mockquic.NewMockEarlyConnection(mockCtrl) 1258 controlStr := mockquic.NewMockStream(mockCtrl) 1259 controlStr.EXPECT().Write(gomock.Any()) 1260 conn.EXPECT().OpenUniStream().Return(controlStr, nil) 1261 testDone := make(chan struct{}) 1262 conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { 1263 <-testDone 1264 return nil, errors.New("test done") 1265 }).MaxTimes(1) 1266 conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, &quic.ApplicationError{ErrorCode: quic.ApplicationErrorCode(ErrCodeNoError)}) 1267 s.ServeQUICConn(conn) 1268 close(testDone) 1269 }) 1270 }) 1271 1272 Context("ListenAndServe", func() { 1273 BeforeEach(func() { 1274 s.Addr = "localhost:0" 1275 }) 1276 1277 AfterEach(func() { 1278 Expect(s.Close()).To(Succeed()) 1279 }) 1280 1281 It("uses the quic.Config to start the QUIC server", func() { 1282 conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} 1283 var receivedConf *quic.Config 1284 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1285 receivedConf = config 1286 return nil, errors.New("listen err") 1287 } 1288 s.QuicConfig = conf 1289 Expect(s.ListenAndServe()).To(HaveOccurred()) 1290 Expect(receivedConf).To(Equal(conf)) 1291 }) 1292 }) 1293 1294 It("closes gracefully", func() { 1295 Expect(s.CloseGracefully(0)).To(Succeed()) 1296 }) 1297 1298 It("errors when listening fails", func() { 1299 testErr := errors.New("listen error") 1300 quicListenAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1301 return nil, testErr 1302 } 1303 fullpem, privkey := testdata.GetCertificatePaths() 1304 Expect(ListenAndServeQUIC("", fullpem, privkey, nil)).To(MatchError(testErr)) 1305 }) 1306 1307 It("supports H3_DATAGRAM", func() { 1308 s.EnableDatagrams = true 1309 var receivedConf *quic.Config 1310 quicListenAddr = func(addr string, _ *tls.Config, config *quic.Config) (QUICEarlyListener, error) { 1311 receivedConf = config 1312 return nil, errors.New("listen err") 1313 } 1314 Expect(s.ListenAndServe()).To(HaveOccurred()) 1315 Expect(receivedConf.EnableDatagrams).To(BeTrue()) 1316 }) 1317 })