github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/handshake_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "time" 11 12 "github.com/MerlinKodo/quic-go" 13 quicproxy "github.com/MerlinKodo/quic-go/integrationtests/tools/proxy" 14 "github.com/MerlinKodo/quic-go/internal/protocol" 15 "github.com/MerlinKodo/quic-go/internal/qerr" 16 "github.com/MerlinKodo/quic-go/internal/qtls" 17 18 . "github.com/onsi/ginkgo/v2" 19 . "github.com/onsi/gomega" 20 ) 21 22 type tokenStore struct { 23 store quic.TokenStore 24 gets chan<- string 25 puts chan<- string 26 } 27 28 var _ quic.TokenStore = &tokenStore{} 29 30 func newTokenStore(gets, puts chan<- string) quic.TokenStore { 31 return &tokenStore{ 32 store: quic.NewLRUTokenStore(10, 4), 33 gets: gets, 34 puts: puts, 35 } 36 } 37 38 func (c *tokenStore) Put(key string, token *quic.ClientToken) { 39 c.puts <- key 40 c.store.Put(key, token) 41 } 42 43 func (c *tokenStore) Pop(key string) *quic.ClientToken { 44 c.gets <- key 45 return c.store.Pop(key) 46 } 47 48 var _ = Describe("Handshake tests", func() { 49 var ( 50 server *quic.Listener 51 serverConfig *quic.Config 52 acceptStopped chan struct{} 53 ) 54 55 BeforeEach(func() { 56 server = nil 57 acceptStopped = make(chan struct{}) 58 serverConfig = getQuicConfig(nil) 59 }) 60 61 AfterEach(func() { 62 if server != nil { 63 server.Close() 64 <-acceptStopped 65 } 66 }) 67 68 runServer := func(tlsConf *tls.Config) { 69 var err error 70 // start the server 71 server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig) 72 Expect(err).ToNot(HaveOccurred()) 73 74 go func() { 75 defer GinkgoRecover() 76 defer close(acceptStopped) 77 for { 78 if _, err := server.Accept(context.Background()); err != nil { 79 return 80 } 81 } 82 }() 83 } 84 85 It("returns the cancellation reason when a dial is canceled", func() { 86 ctx, cancel := context.WithCancelCause(context.Background()) 87 errChan := make(chan error, 1) 88 go func() { 89 _, err := quic.DialAddr( 90 ctx, 91 "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway 92 getTLSClientConfig(), 93 getQuicConfig(nil), 94 ) 95 errChan <- err 96 }() 97 98 cancel(errors.New("application cancelled")) 99 var err error 100 Eventually(errChan).Should(Receive(&err)) 101 Expect(err).To(HaveOccurred()) 102 Expect(err).To(MatchError("application cancelled")) 103 }) 104 105 Context("using different cipher suites", func() { 106 for n, id := range map[string]uint16{ 107 "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, 108 "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, 109 "TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, 110 } { 111 name := n 112 suiteID := id 113 114 It(fmt.Sprintf("using %s", name), func() { 115 reset := qtls.SetCipherSuite(suiteID) 116 defer reset() 117 118 tlsConf := getTLSConfig() 119 ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig) 120 Expect(err).ToNot(HaveOccurred()) 121 defer ln.Close() 122 123 go func() { 124 defer GinkgoRecover() 125 conn, err := ln.Accept(context.Background()) 126 Expect(err).ToNot(HaveOccurred()) 127 str, err := conn.OpenStream() 128 Expect(err).ToNot(HaveOccurred()) 129 defer str.Close() 130 _, err = str.Write(PRData) 131 Expect(err).ToNot(HaveOccurred()) 132 }() 133 134 conn, err := quic.DialAddr( 135 context.Background(), 136 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 137 getTLSClientConfig(), 138 getQuicConfig(nil), 139 ) 140 Expect(err).ToNot(HaveOccurred()) 141 str, err := conn.AcceptStream(context.Background()) 142 Expect(err).ToNot(HaveOccurred()) 143 data, err := io.ReadAll(str) 144 Expect(err).ToNot(HaveOccurred()) 145 Expect(data).To(Equal(PRData)) 146 Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) 147 Expect(conn.CloseWithError(0, "")).To(Succeed()) 148 }) 149 } 150 }) 151 152 Context("Certificate validation", func() { 153 It("accepts the certificate", func() { 154 runServer(getTLSConfig()) 155 _, err := quic.DialAddr( 156 context.Background(), 157 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 158 getTLSClientConfig(), 159 getQuicConfig(nil), 160 ) 161 Expect(err).ToNot(HaveOccurred()) 162 }) 163 164 It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { 165 var local, remote net.Addr 166 var local2, remote2 net.Addr 167 done := make(chan struct{}) 168 tlsConf := &tls.Config{ 169 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 170 local = info.Conn.LocalAddr() 171 remote = info.Conn.RemoteAddr() 172 conf := getTLSConfig() 173 conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 174 defer close(done) 175 local2 = info.Conn.LocalAddr() 176 remote2 = info.Conn.RemoteAddr() 177 return &(conf.Certificates[0]), nil 178 } 179 return conf, nil 180 }, 181 } 182 runServer(tlsConf) 183 conn, err := quic.DialAddr( 184 context.Background(), 185 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 186 getTLSClientConfig(), 187 getQuicConfig(nil), 188 ) 189 Expect(err).ToNot(HaveOccurred()) 190 Eventually(done).Should(BeClosed()) 191 Expect(server.Addr()).To(Equal(local)) 192 Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) 193 Expect(local).To(Equal(local2)) 194 Expect(remote).To(Equal(remote2)) 195 }) 196 197 It("works with a long certificate chain", func() { 198 runServer(getTLSConfigWithLongCertChain()) 199 _, err := quic.DialAddr( 200 context.Background(), 201 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 202 getTLSClientConfig(), 203 getQuicConfig(nil), 204 ) 205 Expect(err).ToNot(HaveOccurred()) 206 }) 207 208 It("errors if the server name doesn't match", func() { 209 runServer(getTLSConfig()) 210 conn, err := net.ListenUDP("udp", nil) 211 Expect(err).ToNot(HaveOccurred()) 212 conf := getTLSClientConfig() 213 conf.ServerName = "foo.bar" 214 _, err = quic.Dial( 215 context.Background(), 216 conn, 217 server.Addr(), 218 conf, 219 getQuicConfig(nil), 220 ) 221 Expect(err).To(HaveOccurred()) 222 var transportErr *quic.TransportError 223 Expect(errors.As(err, &transportErr)).To(BeTrue()) 224 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 225 Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) 226 var certErr *tls.CertificateVerificationError 227 Expect(errors.As(transportErr, &certErr)).To(BeTrue()) 228 }) 229 230 It("fails the handshake if the client fails to provide the requested client cert", func() { 231 tlsConf := getTLSConfig() 232 tlsConf.ClientAuth = tls.RequireAndVerifyClientCert 233 runServer(tlsConf) 234 235 conn, err := quic.DialAddr( 236 context.Background(), 237 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 238 getTLSClientConfig(), 239 getQuicConfig(nil), 240 ) 241 // Usually, the error will occur after the client already finished the handshake. 242 // However, there's a race condition here. The server's CONNECTION_CLOSE might be 243 // received before the connection is returned, so we might already get the error while dialing. 244 if err == nil { 245 errChan := make(chan error) 246 go func() { 247 defer GinkgoRecover() 248 _, err := conn.AcceptStream(context.Background()) 249 errChan <- err 250 }() 251 Eventually(errChan).Should(Receive(&err)) 252 } 253 Expect(err).To(HaveOccurred()) 254 var transportErr *quic.TransportError 255 Expect(errors.As(err, &transportErr)).To(BeTrue()) 256 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 257 Expect(transportErr.Error()).To(Or( 258 ContainSubstring("tls: certificate required"), 259 ContainSubstring("tls: bad certificate"), 260 )) 261 }) 262 263 It("uses the ServerName in the tls.Config", func() { 264 runServer(getTLSConfig()) 265 tlsConf := getTLSClientConfig() 266 tlsConf.ServerName = "foo.bar" 267 _, err := quic.DialAddr( 268 context.Background(), 269 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 270 tlsConf, 271 getQuicConfig(nil), 272 ) 273 Expect(err).To(HaveOccurred()) 274 var transportErr *quic.TransportError 275 Expect(errors.As(err, &transportErr)).To(BeTrue()) 276 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 277 Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) 278 }) 279 }) 280 281 Context("rate limiting", func() { 282 var ( 283 server *quic.Listener 284 pconn net.PacketConn 285 dialer *quic.Transport 286 ) 287 288 dial := func() (quic.Connection, error) { 289 remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) 290 raddr, err := net.ResolveUDPAddr("udp", remoteAddr) 291 Expect(err).ToNot(HaveOccurred()) 292 return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil)) 293 } 294 295 BeforeEach(func() { 296 var err error 297 // start the server, but don't call Accept 298 server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 299 Expect(err).ToNot(HaveOccurred()) 300 301 // prepare a (single) packet conn for dialing to the server 302 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 303 Expect(err).ToNot(HaveOccurred()) 304 pconn, err = net.ListenUDP("udp", laddr) 305 Expect(err).ToNot(HaveOccurred()) 306 dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} 307 }) 308 309 AfterEach(func() { 310 Expect(server.Close()).To(Succeed()) 311 Expect(pconn.Close()).To(Succeed()) 312 Expect(dialer.Close()).To(Succeed()) 313 }) 314 315 It("rejects new connection attempts if connections don't get accepted", func() { 316 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 317 conn, err := dial() 318 Expect(err).ToNot(HaveOccurred()) 319 defer conn.CloseWithError(0, "") 320 } 321 time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued 322 323 _, err := dial() 324 Expect(err).To(HaveOccurred()) 325 var transportErr *quic.TransportError 326 Expect(errors.As(err, &transportErr)).To(BeTrue()) 327 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 328 329 // now accept one connection, freeing one spot in the queue 330 _, err = server.Accept(context.Background()) 331 Expect(err).ToNot(HaveOccurred()) 332 // dial again, and expect that this dial succeeds 333 conn, err := dial() 334 Expect(err).ToNot(HaveOccurred()) 335 defer conn.CloseWithError(0, "") 336 time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued 337 338 _, err = dial() 339 Expect(err).To(HaveOccurred()) 340 Expect(errors.As(err, &transportErr)).To(BeTrue()) 341 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 342 }) 343 344 It("removes closed connections from the accept queue", func() { 345 firstConn, err := dial() 346 Expect(err).ToNot(HaveOccurred()) 347 348 for i := 1; i < protocol.MaxAcceptQueueSize; i++ { 349 conn, err := dial() 350 Expect(err).ToNot(HaveOccurred()) 351 defer conn.CloseWithError(0, "") 352 } 353 time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued 354 355 _, err = dial() 356 Expect(err).To(HaveOccurred()) 357 var transportErr *quic.TransportError 358 Expect(errors.As(err, &transportErr)).To(BeTrue()) 359 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 360 361 // Now close the one of the connection that are waiting to be accepted. 362 // This should free one spot in the queue. 363 Expect(firstConn.CloseWithError(0, "")) 364 Eventually(firstConn.Context().Done()).Should(BeClosed()) 365 time.Sleep(scaleDuration(200 * time.Millisecond)) 366 367 // dial again, and expect that this dial succeeds 368 _, err = dial() 369 Expect(err).ToNot(HaveOccurred()) 370 time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued 371 372 _, err = dial() 373 Expect(err).To(HaveOccurred()) 374 Expect(errors.As(err, &transportErr)).To(BeTrue()) 375 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 376 }) 377 }) 378 379 Context("ALPN", func() { 380 It("negotiates an application protocol", func() { 381 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 382 Expect(err).ToNot(HaveOccurred()) 383 384 done := make(chan struct{}) 385 go func() { 386 defer GinkgoRecover() 387 conn, err := ln.Accept(context.Background()) 388 Expect(err).ToNot(HaveOccurred()) 389 cs := conn.ConnectionState() 390 Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) 391 close(done) 392 }() 393 394 conn, err := quic.DialAddr( 395 context.Background(), 396 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 397 getTLSClientConfig(), 398 nil, 399 ) 400 Expect(err).ToNot(HaveOccurred()) 401 defer conn.CloseWithError(0, "") 402 cs := conn.ConnectionState() 403 Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) 404 Eventually(done).Should(BeClosed()) 405 Expect(ln.Close()).To(Succeed()) 406 }) 407 408 It("errors if application protocol negotiation fails", func() { 409 runServer(getTLSConfig()) 410 411 tlsConf := getTLSClientConfig() 412 tlsConf.NextProtos = []string{"foobar"} 413 _, err := quic.DialAddr( 414 context.Background(), 415 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 416 tlsConf, 417 nil, 418 ) 419 Expect(err).To(HaveOccurred()) 420 var transportErr *quic.TransportError 421 Expect(errors.As(err, &transportErr)).To(BeTrue()) 422 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 423 Expect(transportErr.Error()).To(ContainSubstring("no application protocol")) 424 }) 425 }) 426 427 Context("using tokens", func() { 428 It("uses tokens provided in NEW_TOKEN frames", func() { 429 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 430 Expect(err).ToNot(HaveOccurred()) 431 defer server.Close() 432 433 // dial the first connection and receive the token 434 go func() { 435 defer GinkgoRecover() 436 _, err := server.Accept(context.Background()) 437 Expect(err).ToNot(HaveOccurred()) 438 }() 439 440 gets := make(chan string, 100) 441 puts := make(chan string, 100) 442 tokenStore := newTokenStore(gets, puts) 443 quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore}) 444 conn, err := quic.DialAddr( 445 context.Background(), 446 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 447 getTLSClientConfig(), 448 quicConf, 449 ) 450 Expect(err).ToNot(HaveOccurred()) 451 Expect(gets).To(Receive()) 452 Eventually(puts).Should(Receive()) 453 // received a token. Close this connection. 454 Expect(conn.CloseWithError(0, "")).To(Succeed()) 455 456 // dial the second connection and verify that the token was used 457 done := make(chan struct{}) 458 go func() { 459 defer GinkgoRecover() 460 defer close(done) 461 _, err := server.Accept(context.Background()) 462 Expect(err).ToNot(HaveOccurred()) 463 }() 464 conn, err = quic.DialAddr( 465 context.Background(), 466 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 467 getTLSClientConfig(), 468 quicConf, 469 ) 470 Expect(err).ToNot(HaveOccurred()) 471 defer conn.CloseWithError(0, "") 472 Expect(gets).To(Receive()) 473 474 Eventually(done).Should(BeClosed()) 475 }) 476 477 It("rejects invalid Retry token with the INVALID_TOKEN error", func() { 478 const rtt = 10 * time.Millisecond 479 serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } 480 // The validity period of the retry token is the handshake timeout, 481 // which is twice the handshake idle timeout. 482 // By setting the handshake timeout shorter than the RTT, the token will have expired by the time 483 // it reaches the server. 484 serverConfig.HandshakeIdleTimeout = rtt / 5 485 486 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 487 Expect(err).ToNot(HaveOccurred()) 488 defer server.Close() 489 490 serverPort := server.Addr().(*net.UDPAddr).Port 491 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 492 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 493 DelayPacket: func(quicproxy.Direction, []byte) time.Duration { 494 return rtt / 2 495 }, 496 }) 497 Expect(err).ToNot(HaveOccurred()) 498 defer proxy.Close() 499 500 _, err = quic.DialAddr( 501 context.Background(), 502 fmt.Sprintf("localhost:%d", proxy.LocalPort()), 503 getTLSClientConfig(), 504 nil, 505 ) 506 Expect(err).To(HaveOccurred()) 507 var transportErr *quic.TransportError 508 Expect(errors.As(err, &transportErr)).To(BeTrue()) 509 Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken)) 510 }) 511 }) 512 513 Context("GetConfigForClient", func() { 514 It("uses the quic.Config returned by GetConfigForClient", func() { 515 serverConfig.EnableDatagrams = false 516 var calledFrom net.Addr 517 serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { 518 conf := serverConfig.Clone() 519 conf.EnableDatagrams = true 520 calledFrom = info.RemoteAddr 521 return getQuicConfig(conf), nil 522 } 523 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 524 Expect(err).ToNot(HaveOccurred()) 525 526 done := make(chan struct{}) 527 go func() { 528 defer GinkgoRecover() 529 _, err := ln.Accept(context.Background()) 530 Expect(err).ToNot(HaveOccurred()) 531 close(done) 532 }() 533 534 conn, err := quic.DialAddr( 535 context.Background(), 536 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 537 getTLSClientConfig(), 538 getQuicConfig(&quic.Config{EnableDatagrams: true}), 539 ) 540 Expect(err).ToNot(HaveOccurred()) 541 defer conn.CloseWithError(0, "") 542 cs := conn.ConnectionState() 543 Expect(cs.SupportsDatagrams).To(BeTrue()) 544 Eventually(done).Should(BeClosed()) 545 Expect(ln.Close()).To(Succeed()) 546 Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port)) 547 }) 548 549 It("rejects the connection attempt if GetConfigForClient errors", func() { 550 serverConfig.EnableDatagrams = false 551 serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { 552 return nil, errors.New("rejected") 553 } 554 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 555 Expect(err).ToNot(HaveOccurred()) 556 defer ln.Close() 557 558 done := make(chan struct{}) 559 go func() { 560 defer GinkgoRecover() 561 _, err := ln.Accept(context.Background()) 562 Expect(err).To(HaveOccurred()) // we don't expect to accept any connection 563 close(done) 564 }() 565 566 _, err = quic.DialAddr( 567 context.Background(), 568 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 569 getTLSClientConfig(), 570 getQuicConfig(&quic.Config{EnableDatagrams: true}), 571 ) 572 Expect(err).To(HaveOccurred()) 573 var transportErr *quic.TransportError 574 Expect(errors.As(err, &transportErr)).To(BeTrue()) 575 Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) 576 }) 577 }) 578 579 It("doesn't send any packets when generating the ClientHello fails", func() { 580 ln, err := net.ListenUDP("udp", nil) 581 Expect(err).ToNot(HaveOccurred()) 582 done := make(chan struct{}) 583 packetChan := make(chan struct{}) 584 go func() { 585 defer GinkgoRecover() 586 defer close(done) 587 for { 588 _, _, err := ln.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize)) 589 if err != nil { 590 return 591 } 592 packetChan <- struct{}{} 593 } 594 }() 595 596 tlsConf := getTLSClientConfig() 597 tlsConf.NextProtos = []string{""} 598 _, err = quic.DialAddr( 599 context.Background(), 600 fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port), 601 tlsConf, 602 nil, 603 ) 604 Expect(err).To(MatchError(&qerr.TransportError{ 605 ErrorCode: qerr.InternalError, 606 ErrorMessage: "tls: invalid NextProtos value", 607 })) 608 Consistently(packetChan).ShouldNot(Receive()) 609 ln.Close() 610 Eventually(done).Should(BeClosed()) 611 }) 612 })