github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/self/handshake_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "time" 11 12 "github.com/apernet/quic-go" 13 quicproxy "github.com/apernet/quic-go/integrationtests/tools/proxy" 14 "github.com/apernet/quic-go/internal/protocol" 15 "github.com/apernet/quic-go/internal/qerr" 16 "github.com/apernet/quic-go/internal/qtls" 17 18 . "github.com/onsi/ginkgo/v2" 19 . "github.com/onsi/gomega" 20 ) 21 22 type tokenStore struct { 23 store quic.TokenStore 24 gets chan<- string 25 puts chan<- string 26 } 27 28 var _ quic.TokenStore = &tokenStore{} 29 30 func newTokenStore(gets, puts chan<- string) quic.TokenStore { 31 return &tokenStore{ 32 store: quic.NewLRUTokenStore(10, 4), 33 gets: gets, 34 puts: puts, 35 } 36 } 37 38 func (c *tokenStore) Put(key string, token *quic.ClientToken) { 39 c.puts <- key 40 c.store.Put(key, token) 41 } 42 43 func (c *tokenStore) Pop(key string) *quic.ClientToken { 44 c.gets <- key 45 return c.store.Pop(key) 46 } 47 48 var _ = Describe("Handshake tests", func() { 49 var ( 50 server *quic.Listener 51 serverConfig *quic.Config 52 acceptStopped chan struct{} 53 ) 54 55 BeforeEach(func() { 56 server = nil 57 acceptStopped = make(chan struct{}) 58 serverConfig = getQuicConfig(nil) 59 }) 60 61 AfterEach(func() { 62 if server != nil { 63 server.Close() 64 <-acceptStopped 65 } 66 }) 67 68 runServer := func(tlsConf *tls.Config) { 69 var err error 70 // start the server 71 server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig) 72 Expect(err).ToNot(HaveOccurred()) 73 74 go func() { 75 defer GinkgoRecover() 76 defer close(acceptStopped) 77 for { 78 if _, err := server.Accept(context.Background()); err != nil { 79 return 80 } 81 } 82 }() 83 } 84 85 It("returns the context cancellation error on timeouts", func() { 86 ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) 87 defer cancel() 88 errChan := make(chan error, 1) 89 go func() { 90 _, err := quic.DialAddr( 91 ctx, 92 "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway 93 getTLSClientConfig(), 94 getQuicConfig(nil), 95 ) 96 errChan <- err 97 }() 98 99 var err error 100 Eventually(errChan).Should(Receive(&err)) 101 Expect(err).To(HaveOccurred()) 102 Expect(err).To(MatchError(context.DeadlineExceeded)) 103 }) 104 105 It("returns the cancellation reason when a dial is canceled", func() { 106 ctx, cancel := context.WithCancelCause(context.Background()) 107 errChan := make(chan error, 1) 108 go func() { 109 _, err := quic.DialAddr( 110 ctx, 111 "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway 112 getTLSClientConfig(), 113 getQuicConfig(nil), 114 ) 115 errChan <- err 116 }() 117 118 cancel(errors.New("application cancelled")) 119 var err error 120 Eventually(errChan).Should(Receive(&err)) 121 Expect(err).To(HaveOccurred()) 122 Expect(err).To(MatchError("application cancelled")) 123 }) 124 125 Context("using different cipher suites", func() { 126 for n, id := range map[string]uint16{ 127 "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, 128 "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, 129 "TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, 130 } { 131 name := n 132 suiteID := id 133 134 It(fmt.Sprintf("using %s", name), func() { 135 reset := qtls.SetCipherSuite(suiteID) 136 defer reset() 137 138 tlsConf := getTLSConfig() 139 ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig) 140 Expect(err).ToNot(HaveOccurred()) 141 defer ln.Close() 142 143 go func() { 144 defer GinkgoRecover() 145 conn, err := ln.Accept(context.Background()) 146 Expect(err).ToNot(HaveOccurred()) 147 str, err := conn.OpenStream() 148 Expect(err).ToNot(HaveOccurred()) 149 defer str.Close() 150 _, err = str.Write(PRData) 151 Expect(err).ToNot(HaveOccurred()) 152 }() 153 154 conn, err := quic.DialAddr( 155 context.Background(), 156 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 157 getTLSClientConfig(), 158 getQuicConfig(nil), 159 ) 160 Expect(err).ToNot(HaveOccurred()) 161 str, err := conn.AcceptStream(context.Background()) 162 Expect(err).ToNot(HaveOccurred()) 163 data, err := io.ReadAll(str) 164 Expect(err).ToNot(HaveOccurred()) 165 Expect(data).To(Equal(PRData)) 166 Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) 167 Expect(conn.CloseWithError(0, "")).To(Succeed()) 168 }) 169 } 170 }) 171 172 Context("Certificate validation", func() { 173 It("accepts the certificate", func() { 174 runServer(getTLSConfig()) 175 conn, err := quic.DialAddr( 176 context.Background(), 177 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 178 getTLSClientConfig(), 179 getQuicConfig(nil), 180 ) 181 Expect(err).ToNot(HaveOccurred()) 182 conn.CloseWithError(0, "") 183 }) 184 185 It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { 186 var local, remote net.Addr 187 var local2, remote2 net.Addr 188 done := make(chan struct{}) 189 tlsConf := &tls.Config{ 190 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 191 local = info.Conn.LocalAddr() 192 remote = info.Conn.RemoteAddr() 193 conf := getTLSConfig() 194 conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 195 defer close(done) 196 local2 = info.Conn.LocalAddr() 197 remote2 = info.Conn.RemoteAddr() 198 return &(conf.Certificates[0]), nil 199 } 200 return conf, nil 201 }, 202 } 203 runServer(tlsConf) 204 conn, err := quic.DialAddr( 205 context.Background(), 206 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 207 getTLSClientConfig(), 208 getQuicConfig(nil), 209 ) 210 Expect(err).ToNot(HaveOccurred()) 211 defer conn.CloseWithError(0, "") 212 Eventually(done).Should(BeClosed()) 213 Expect(server.Addr()).To(Equal(local)) 214 Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) 215 Expect(local).To(Equal(local2)) 216 Expect(remote).To(Equal(remote2)) 217 }) 218 219 It("works with a long certificate chain", func() { 220 runServer(getTLSConfigWithLongCertChain()) 221 conn, err := quic.DialAddr( 222 context.Background(), 223 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 224 getTLSClientConfig(), 225 getQuicConfig(nil), 226 ) 227 Expect(err).ToNot(HaveOccurred()) 228 conn.CloseWithError(0, "") 229 }) 230 231 It("errors if the server name doesn't match", func() { 232 runServer(getTLSConfig()) 233 conn, err := net.ListenUDP("udp", nil) 234 Expect(err).ToNot(HaveOccurred()) 235 conf := getTLSClientConfig() 236 conf.ServerName = "foo.bar" 237 _, err = quic.Dial( 238 context.Background(), 239 conn, 240 server.Addr(), 241 conf, 242 getQuicConfig(nil), 243 ) 244 Expect(err).To(HaveOccurred()) 245 var transportErr *quic.TransportError 246 Expect(errors.As(err, &transportErr)).To(BeTrue()) 247 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 248 Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) 249 var certErr *tls.CertificateVerificationError 250 Expect(errors.As(transportErr, &certErr)).To(BeTrue()) 251 }) 252 253 It("fails the handshake if the client fails to provide the requested client cert", func() { 254 tlsConf := getTLSConfig() 255 tlsConf.ClientAuth = tls.RequireAndVerifyClientCert 256 runServer(tlsConf) 257 258 conn, err := quic.DialAddr( 259 context.Background(), 260 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 261 getTLSClientConfig(), 262 getQuicConfig(nil), 263 ) 264 // Usually, the error will occur after the client already finished the handshake. 265 // However, there's a race condition here. The server's CONNECTION_CLOSE might be 266 // received before the connection is returned, so we might already get the error while dialing. 267 if err == nil { 268 errChan := make(chan error) 269 go func() { 270 defer GinkgoRecover() 271 _, err := conn.AcceptStream(context.Background()) 272 errChan <- err 273 }() 274 Eventually(errChan).Should(Receive(&err)) 275 } 276 Expect(err).To(HaveOccurred()) 277 var transportErr *quic.TransportError 278 Expect(errors.As(err, &transportErr)).To(BeTrue()) 279 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 280 Expect(transportErr.Error()).To(Or( 281 ContainSubstring("tls: certificate required"), 282 ContainSubstring("tls: bad certificate"), 283 )) 284 }) 285 286 It("uses the ServerName in the tls.Config", func() { 287 runServer(getTLSConfig()) 288 tlsConf := getTLSClientConfig() 289 tlsConf.ServerName = "foo.bar" 290 _, err := quic.DialAddr( 291 context.Background(), 292 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 293 tlsConf, 294 getQuicConfig(nil), 295 ) 296 Expect(err).To(HaveOccurred()) 297 var transportErr *quic.TransportError 298 Expect(errors.As(err, &transportErr)).To(BeTrue()) 299 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 300 Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) 301 }) 302 }) 303 304 Context("queuening and accepting connections", func() { 305 var ( 306 server *quic.Listener 307 pconn net.PacketConn 308 dialer *quic.Transport 309 ) 310 311 dial := func() (quic.Connection, error) { 312 remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) 313 raddr, err := net.ResolveUDPAddr("udp", remoteAddr) 314 Expect(err).ToNot(HaveOccurred()) 315 return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil)) 316 } 317 318 BeforeEach(func() { 319 var err error 320 // start the server, but don't call Accept 321 server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 322 Expect(err).ToNot(HaveOccurred()) 323 324 // prepare a (single) packet conn for dialing to the server 325 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 326 Expect(err).ToNot(HaveOccurred()) 327 pconn, err = net.ListenUDP("udp", laddr) 328 Expect(err).ToNot(HaveOccurred()) 329 dialer = &quic.Transport{ 330 Conn: pconn, 331 ConnectionIDLength: 4, 332 } 333 }) 334 335 AfterEach(func() { 336 Expect(server.Close()).To(Succeed()) 337 Expect(pconn.Close()).To(Succeed()) 338 Expect(dialer.Close()).To(Succeed()) 339 }) 340 341 It("rejects new connection attempts if connections don't get accepted", func() { 342 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 343 conn, err := dial() 344 Expect(err).ToNot(HaveOccurred()) 345 defer conn.CloseWithError(0, "") 346 } 347 time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued 348 349 conn, err := dial() 350 Expect(err).ToNot(HaveOccurred()) 351 ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 352 defer cancel() 353 _, err = conn.AcceptStream(ctx) 354 var transportErr *quic.TransportError 355 Expect(errors.As(err, &transportErr)).To(BeTrue()) 356 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 357 358 // now accept one connection, freeing one spot in the queue 359 _, err = server.Accept(context.Background()) 360 Expect(err).ToNot(HaveOccurred()) 361 // dial again, and expect that this dial succeeds 362 conn2, err := dial() 363 Expect(err).ToNot(HaveOccurred()) 364 defer conn2.CloseWithError(0, "") 365 time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued 366 367 conn3, err := dial() 368 Expect(err).ToNot(HaveOccurred()) 369 ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) 370 defer cancel() 371 _, err = conn3.AcceptStream(ctx) 372 Expect(errors.As(err, &transportErr)).To(BeTrue()) 373 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 374 }) 375 376 It("also returns closed connections from the accept queue", func() { 377 firstConn, err := dial() 378 Expect(err).ToNot(HaveOccurred()) 379 380 for i := 1; i < protocol.MaxAcceptQueueSize; i++ { 381 conn, err := dial() 382 Expect(err).ToNot(HaveOccurred()) 383 defer conn.CloseWithError(0, "") 384 } 385 time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued 386 387 conn, err := dial() 388 Expect(err).ToNot(HaveOccurred()) 389 ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 390 defer cancel() 391 _, err = conn.AcceptStream(ctx) 392 var transportErr *quic.TransportError 393 Expect(errors.As(err, &transportErr)).To(BeTrue()) 394 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 395 396 // Now close the one of the connection that are waiting to be accepted. 397 const appErrCode quic.ApplicationErrorCode = 12345 398 Expect(firstConn.CloseWithError(appErrCode, "")) 399 Eventually(firstConn.Context().Done()).Should(BeClosed()) 400 time.Sleep(scaleDuration(200 * time.Millisecond)) 401 402 // dial again, and expect that this fails again 403 conn2, err := dial() 404 Expect(err).ToNot(HaveOccurred()) 405 ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) 406 defer cancel() 407 _, err = conn2.AcceptStream(ctx) 408 Expect(errors.As(err, &transportErr)).To(BeTrue()) 409 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 410 411 // now accept all connections 412 var closedConn quic.Connection 413 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 414 conn, err := server.Accept(context.Background()) 415 Expect(err).ToNot(HaveOccurred()) 416 if conn.Context().Err() != nil { 417 if closedConn != nil { 418 Fail("only expected a single closed connection") 419 } 420 closedConn = conn 421 } 422 } 423 Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection 424 _, err = closedConn.AcceptStream(context.Background()) 425 var appErr *quic.ApplicationError 426 Expect(errors.As(err, &appErr)).To(BeTrue()) 427 Expect(appErr.ErrorCode).To(Equal(appErrCode)) 428 }) 429 430 It("closes handshaking connections when the server is closed", func() { 431 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 432 Expect(err).ToNot(HaveOccurred()) 433 udpConn, err := net.ListenUDP("udp", laddr) 434 Expect(err).ToNot(HaveOccurred()) 435 tr := &quic.Transport{Conn: udpConn} 436 addTracer(tr) 437 defer tr.Close() 438 tlsConf := &tls.Config{} 439 done := make(chan struct{}) 440 tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { 441 <-done 442 return nil, errors.New("closed") 443 } 444 ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) 445 Expect(err).ToNot(HaveOccurred()) 446 447 errChan := make(chan error, 1) 448 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 449 defer cancel() 450 go func() { 451 defer GinkgoRecover() 452 _, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) 453 errChan <- err 454 }() 455 time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued 456 Expect(ln.Close()).To(Succeed()) 457 close(done) 458 err = <-errChan 459 var transportErr *quic.TransportError 460 Expect(errors.As(err, &transportErr)).To(BeTrue()) 461 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 462 }) 463 }) 464 465 Context("ALPN", func() { 466 It("negotiates an application protocol", func() { 467 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 468 Expect(err).ToNot(HaveOccurred()) 469 470 done := make(chan struct{}) 471 go func() { 472 defer GinkgoRecover() 473 conn, err := ln.Accept(context.Background()) 474 Expect(err).ToNot(HaveOccurred()) 475 cs := conn.ConnectionState() 476 Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) 477 close(done) 478 }() 479 480 conn, err := quic.DialAddr( 481 context.Background(), 482 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 483 getTLSClientConfig(), 484 nil, 485 ) 486 Expect(err).ToNot(HaveOccurred()) 487 defer conn.CloseWithError(0, "") 488 cs := conn.ConnectionState() 489 Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) 490 Eventually(done).Should(BeClosed()) 491 Expect(ln.Close()).To(Succeed()) 492 }) 493 494 It("errors if application protocol negotiation fails", func() { 495 runServer(getTLSConfig()) 496 497 tlsConf := getTLSClientConfig() 498 tlsConf.NextProtos = []string{"foobar"} 499 _, err := quic.DialAddr( 500 context.Background(), 501 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 502 tlsConf, 503 nil, 504 ) 505 Expect(err).To(HaveOccurred()) 506 var transportErr *quic.TransportError 507 Expect(errors.As(err, &transportErr)).To(BeTrue()) 508 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 509 Expect(transportErr.Error()).To(ContainSubstring("no application protocol")) 510 }) 511 }) 512 513 Context("using tokens", func() { 514 It("uses tokens provided in NEW_TOKEN frames", func() { 515 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 516 Expect(err).ToNot(HaveOccurred()) 517 defer server.Close() 518 519 // dial the first connection and receive the token 520 go func() { 521 defer GinkgoRecover() 522 _, err := server.Accept(context.Background()) 523 Expect(err).ToNot(HaveOccurred()) 524 }() 525 526 gets := make(chan string, 100) 527 puts := make(chan string, 100) 528 tokenStore := newTokenStore(gets, puts) 529 quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore}) 530 conn, err := quic.DialAddr( 531 context.Background(), 532 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 533 getTLSClientConfig(), 534 quicConf, 535 ) 536 Expect(err).ToNot(HaveOccurred()) 537 Expect(gets).To(Receive()) 538 Eventually(puts).Should(Receive()) 539 // received a token. Close this connection. 540 Expect(conn.CloseWithError(0, "")).To(Succeed()) 541 542 // dial the second connection and verify that the token was used 543 done := make(chan struct{}) 544 go func() { 545 defer GinkgoRecover() 546 defer close(done) 547 _, err := server.Accept(context.Background()) 548 Expect(err).ToNot(HaveOccurred()) 549 }() 550 conn, err = quic.DialAddr( 551 context.Background(), 552 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 553 getTLSClientConfig(), 554 quicConf, 555 ) 556 Expect(err).ToNot(HaveOccurred()) 557 defer conn.CloseWithError(0, "") 558 Expect(gets).To(Receive()) 559 560 Eventually(done).Should(BeClosed()) 561 }) 562 563 It("rejects invalid Retry token with the INVALID_TOKEN error", func() { 564 const rtt = 10 * time.Millisecond 565 566 // The validity period of the retry token is the handshake timeout, 567 // which is twice the handshake idle timeout. 568 // By setting the handshake timeout shorter than the RTT, the token will have expired by the time 569 // it reaches the server. 570 serverConfig.HandshakeIdleTimeout = rtt / 5 571 572 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 573 Expect(err).ToNot(HaveOccurred()) 574 udpConn, err := net.ListenUDP("udp", laddr) 575 Expect(err).ToNot(HaveOccurred()) 576 defer udpConn.Close() 577 tr := &quic.Transport{ 578 Conn: udpConn, 579 VerifySourceAddress: func(net.Addr) bool { return true }, 580 } 581 addTracer(tr) 582 defer tr.Close() 583 server, err := tr.Listen(getTLSConfig(), serverConfig) 584 Expect(err).ToNot(HaveOccurred()) 585 defer server.Close() 586 587 serverPort := server.Addr().(*net.UDPAddr).Port 588 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 589 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 590 DelayPacket: func(quicproxy.Direction, []byte) time.Duration { 591 return rtt / 2 592 }, 593 }) 594 Expect(err).ToNot(HaveOccurred()) 595 defer proxy.Close() 596 597 _, err = quic.DialAddr( 598 context.Background(), 599 fmt.Sprintf("localhost:%d", proxy.LocalPort()), 600 getTLSClientConfig(), 601 nil, 602 ) 603 Expect(err).To(HaveOccurred()) 604 var transportErr *quic.TransportError 605 Expect(errors.As(err, &transportErr)).To(BeTrue()) 606 Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken)) 607 }) 608 }) 609 610 Context("GetConfigForClient", func() { 611 It("uses the quic.Config returned by GetConfigForClient", func() { 612 serverConfig.EnableDatagrams = false 613 var calledFrom net.Addr 614 serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { 615 conf := serverConfig.Clone() 616 conf.EnableDatagrams = true 617 calledFrom = info.RemoteAddr 618 return getQuicConfig(conf), nil 619 } 620 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 621 Expect(err).ToNot(HaveOccurred()) 622 623 done := make(chan struct{}) 624 go func() { 625 defer GinkgoRecover() 626 _, err := ln.Accept(context.Background()) 627 Expect(err).ToNot(HaveOccurred()) 628 close(done) 629 }() 630 631 conn, err := quic.DialAddr( 632 context.Background(), 633 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 634 getTLSClientConfig(), 635 getQuicConfig(&quic.Config{EnableDatagrams: true}), 636 ) 637 Expect(err).ToNot(HaveOccurred()) 638 defer conn.CloseWithError(0, "") 639 cs := conn.ConnectionState() 640 Expect(cs.SupportsDatagrams).To(BeTrue()) 641 Eventually(done).Should(BeClosed()) 642 Expect(ln.Close()).To(Succeed()) 643 Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port)) 644 }) 645 646 It("rejects the connection attempt if GetConfigForClient errors", func() { 647 serverConfig.EnableDatagrams = false 648 serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { 649 return nil, errors.New("rejected") 650 } 651 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 652 Expect(err).ToNot(HaveOccurred()) 653 defer ln.Close() 654 655 done := make(chan struct{}) 656 go func() { 657 defer GinkgoRecover() 658 _, err := ln.Accept(context.Background()) 659 Expect(err).To(HaveOccurred()) // we don't expect to accept any connection 660 close(done) 661 }() 662 663 _, err = quic.DialAddr( 664 context.Background(), 665 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 666 getTLSClientConfig(), 667 getQuicConfig(&quic.Config{EnableDatagrams: true}), 668 ) 669 Expect(err).To(HaveOccurred()) 670 var transportErr *quic.TransportError 671 Expect(errors.As(err, &transportErr)).To(BeTrue()) 672 Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) 673 }) 674 }) 675 676 It("doesn't send any packets when generating the ClientHello fails", func() { 677 ln, err := net.ListenUDP("udp", nil) 678 Expect(err).ToNot(HaveOccurred()) 679 done := make(chan struct{}) 680 packetChan := make(chan struct{}) 681 go func() { 682 defer GinkgoRecover() 683 defer close(done) 684 for { 685 _, _, err := ln.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize)) 686 if err != nil { 687 return 688 } 689 packetChan <- struct{}{} 690 } 691 }() 692 693 tlsConf := getTLSClientConfig() 694 tlsConf.NextProtos = []string{""} 695 _, err = quic.DialAddr( 696 context.Background(), 697 fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port), 698 tlsConf, 699 nil, 700 ) 701 Expect(err).To(MatchError(&qerr.TransportError{ 702 ErrorCode: qerr.InternalError, 703 ErrorMessage: "tls: invalid NextProtos value", 704 })) 705 Consistently(packetChan).ShouldNot(Receive()) 706 ln.Close() 707 Eventually(done).Should(BeClosed()) 708 }) 709 })