github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/integrationtests/self/handshake_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "sync/atomic" 11 "time" 12 13 "github.com/danielpfeifer02/quic-go-prio-packs" 14 quicproxy "github.com/danielpfeifer02/quic-go-prio-packs/integrationtests/tools/proxy" 15 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 16 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr" 17 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qtls" 18 "github.com/danielpfeifer02/quic-go-prio-packs/logging" 19 20 . "github.com/onsi/ginkgo/v2" 21 . "github.com/onsi/gomega" 22 ) 23 24 type tokenStore struct { 25 store quic.TokenStore 26 gets chan<- string 27 puts chan<- string 28 } 29 30 var _ quic.TokenStore = &tokenStore{} 31 32 func newTokenStore(gets, puts chan<- string) quic.TokenStore { 33 return &tokenStore{ 34 store: quic.NewLRUTokenStore(10, 4), 35 gets: gets, 36 puts: puts, 37 } 38 } 39 40 func (c *tokenStore) Put(key string, token *quic.ClientToken) { 41 c.puts <- key 42 c.store.Put(key, token) 43 } 44 45 func (c *tokenStore) Pop(key string) *quic.ClientToken { 46 c.gets <- key 47 return c.store.Pop(key) 48 } 49 50 var _ = Describe("Handshake tests", func() { 51 var ( 52 server *quic.Listener 53 serverConfig *quic.Config 54 acceptStopped chan struct{} 55 ) 56 57 BeforeEach(func() { 58 server = nil 59 acceptStopped = make(chan struct{}) 60 serverConfig = getQuicConfig(nil) 61 }) 62 63 AfterEach(func() { 64 if server != nil { 65 server.Close() 66 <-acceptStopped 67 } 68 }) 69 70 runServer := func(tlsConf *tls.Config) { 71 var err error 72 // start the server 73 server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig) 74 Expect(err).ToNot(HaveOccurred()) 75 76 go func() { 77 defer GinkgoRecover() 78 defer close(acceptStopped) 79 for { 80 if _, err := server.Accept(context.Background()); err != nil { 81 return 82 } 83 } 84 }() 85 } 86 87 It("returns the context cancellation error on timeouts", func() { 88 ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond)) 89 defer cancel() 90 errChan := make(chan error, 1) 91 go func() { 92 _, err := quic.DialAddr( 93 ctx, 94 "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway 95 getTLSClientConfig(), 96 getQuicConfig(nil), 97 ) 98 errChan <- err 99 }() 100 101 var err error 102 Eventually(errChan).Should(Receive(&err)) 103 Expect(err).To(HaveOccurred()) 104 Expect(err).To(MatchError(context.DeadlineExceeded)) 105 }) 106 107 It("returns the cancellation reason when a dial is canceled", func() { 108 ctx, cancel := context.WithCancelCause(context.Background()) 109 errChan := make(chan error, 1) 110 go func() { 111 _, err := quic.DialAddr( 112 ctx, 113 "localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway 114 getTLSClientConfig(), 115 getQuicConfig(nil), 116 ) 117 errChan <- err 118 }() 119 120 cancel(errors.New("application cancelled")) 121 var err error 122 Eventually(errChan).Should(Receive(&err)) 123 Expect(err).To(HaveOccurred()) 124 Expect(err).To(MatchError("application cancelled")) 125 }) 126 127 Context("using different cipher suites", func() { 128 for n, id := range map[string]uint16{ 129 "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, 130 "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, 131 "TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, 132 } { 133 name := n 134 suiteID := id 135 136 It(fmt.Sprintf("using %s", name), func() { 137 reset := qtls.SetCipherSuite(suiteID) 138 defer reset() 139 140 tlsConf := getTLSConfig() 141 ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig) 142 Expect(err).ToNot(HaveOccurred()) 143 defer ln.Close() 144 145 go func() { 146 defer GinkgoRecover() 147 conn, err := ln.Accept(context.Background()) 148 Expect(err).ToNot(HaveOccurred()) 149 str, err := conn.OpenStream() 150 Expect(err).ToNot(HaveOccurred()) 151 defer str.Close() 152 _, err = str.Write(PRData) 153 Expect(err).ToNot(HaveOccurred()) 154 }() 155 156 conn, err := quic.DialAddr( 157 context.Background(), 158 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 159 getTLSClientConfig(), 160 getQuicConfig(nil), 161 ) 162 Expect(err).ToNot(HaveOccurred()) 163 str, err := conn.AcceptStream(context.Background()) 164 Expect(err).ToNot(HaveOccurred()) 165 data, err := io.ReadAll(str) 166 Expect(err).ToNot(HaveOccurred()) 167 Expect(data).To(Equal(PRData)) 168 Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID)) 169 Expect(conn.CloseWithError(0, "")).To(Succeed()) 170 }) 171 } 172 }) 173 174 Context("Certificate validation", func() { 175 It("accepts the certificate", func() { 176 runServer(getTLSConfig()) 177 conn, err := quic.DialAddr( 178 context.Background(), 179 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 180 getTLSClientConfig(), 181 getQuicConfig(nil), 182 ) 183 Expect(err).ToNot(HaveOccurred()) 184 conn.CloseWithError(0, "") 185 }) 186 187 It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() { 188 var local, remote net.Addr 189 var local2, remote2 net.Addr 190 done := make(chan struct{}) 191 tlsConf := &tls.Config{ 192 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 193 local = info.Conn.LocalAddr() 194 remote = info.Conn.RemoteAddr() 195 conf := getTLSConfig() 196 conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 197 defer close(done) 198 local2 = info.Conn.LocalAddr() 199 remote2 = info.Conn.RemoteAddr() 200 return &(conf.Certificates[0]), nil 201 } 202 return conf, nil 203 }, 204 } 205 runServer(tlsConf) 206 conn, err := quic.DialAddr( 207 context.Background(), 208 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 209 getTLSClientConfig(), 210 getQuicConfig(nil), 211 ) 212 Expect(err).ToNot(HaveOccurred()) 213 defer conn.CloseWithError(0, "") 214 Eventually(done).Should(BeClosed()) 215 Expect(server.Addr()).To(Equal(local)) 216 Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port)) 217 Expect(local).To(Equal(local2)) 218 Expect(remote).To(Equal(remote2)) 219 }) 220 221 It("works with a long certificate chain", func() { 222 runServer(getTLSConfigWithLongCertChain()) 223 conn, err := quic.DialAddr( 224 context.Background(), 225 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 226 getTLSClientConfig(), 227 getQuicConfig(nil), 228 ) 229 Expect(err).ToNot(HaveOccurred()) 230 conn.CloseWithError(0, "") 231 }) 232 233 It("errors if the server name doesn't match", func() { 234 runServer(getTLSConfig()) 235 conn, err := net.ListenUDP("udp", nil) 236 Expect(err).ToNot(HaveOccurred()) 237 conf := getTLSClientConfig() 238 conf.ServerName = "foo.bar" 239 _, err = quic.Dial( 240 context.Background(), 241 conn, 242 server.Addr(), 243 conf, 244 getQuicConfig(nil), 245 ) 246 Expect(err).To(HaveOccurred()) 247 var transportErr *quic.TransportError 248 Expect(errors.As(err, &transportErr)).To(BeTrue()) 249 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 250 Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) 251 var certErr *tls.CertificateVerificationError 252 Expect(errors.As(transportErr, &certErr)).To(BeTrue()) 253 }) 254 255 It("fails the handshake if the client fails to provide the requested client cert", func() { 256 tlsConf := getTLSConfig() 257 tlsConf.ClientAuth = tls.RequireAndVerifyClientCert 258 runServer(tlsConf) 259 260 conn, err := quic.DialAddr( 261 context.Background(), 262 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 263 getTLSClientConfig(), 264 getQuicConfig(nil), 265 ) 266 // Usually, the error will occur after the client already finished the handshake. 267 // However, there's a race condition here. The server's CONNECTION_CLOSE might be 268 // received before the connection is returned, so we might already get the error while dialing. 269 if err == nil { 270 errChan := make(chan error) 271 go func() { 272 defer GinkgoRecover() 273 _, err := conn.AcceptStream(context.Background()) 274 errChan <- err 275 }() 276 Eventually(errChan).Should(Receive(&err)) 277 } 278 Expect(err).To(HaveOccurred()) 279 var transportErr *quic.TransportError 280 Expect(errors.As(err, &transportErr)).To(BeTrue()) 281 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 282 Expect(transportErr.Error()).To(Or( 283 ContainSubstring("tls: certificate required"), 284 ContainSubstring("tls: bad certificate"), 285 )) 286 }) 287 288 It("uses the ServerName in the tls.Config", func() { 289 runServer(getTLSConfig()) 290 tlsConf := getTLSClientConfig() 291 tlsConf.ServerName = "foo.bar" 292 _, err := quic.DialAddr( 293 context.Background(), 294 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 295 tlsConf, 296 getQuicConfig(nil), 297 ) 298 Expect(err).To(HaveOccurred()) 299 var transportErr *quic.TransportError 300 Expect(errors.As(err, &transportErr)).To(BeTrue()) 301 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 302 Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) 303 }) 304 }) 305 306 Context("queuening and accepting connections", func() { 307 var ( 308 server *quic.Listener 309 pconn net.PacketConn 310 dialer *quic.Transport 311 ) 312 313 dial := func() (quic.Connection, error) { 314 remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) 315 raddr, err := net.ResolveUDPAddr("udp", remoteAddr) 316 Expect(err).ToNot(HaveOccurred()) 317 return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil)) 318 } 319 320 BeforeEach(func() { 321 var err error 322 // start the server, but don't call Accept 323 server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 324 Expect(err).ToNot(HaveOccurred()) 325 326 // prepare a (single) packet conn for dialing to the server 327 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 328 Expect(err).ToNot(HaveOccurred()) 329 pconn, err = net.ListenUDP("udp", laddr) 330 Expect(err).ToNot(HaveOccurred()) 331 dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4} 332 }) 333 334 AfterEach(func() { 335 Expect(server.Close()).To(Succeed()) 336 Expect(pconn.Close()).To(Succeed()) 337 Expect(dialer.Close()).To(Succeed()) 338 }) 339 340 It("rejects new connection attempts if connections don't get accepted", func() { 341 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 342 conn, err := dial() 343 Expect(err).ToNot(HaveOccurred()) 344 defer conn.CloseWithError(0, "") 345 } 346 time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued 347 348 conn, err := dial() 349 Expect(err).ToNot(HaveOccurred()) 350 ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 351 defer cancel() 352 _, err = conn.AcceptStream(ctx) 353 var transportErr *quic.TransportError 354 Expect(errors.As(err, &transportErr)).To(BeTrue()) 355 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 356 357 // now accept one connection, freeing one spot in the queue 358 _, err = server.Accept(context.Background()) 359 Expect(err).ToNot(HaveOccurred()) 360 // dial again, and expect that this dial succeeds 361 conn2, err := dial() 362 Expect(err).ToNot(HaveOccurred()) 363 defer conn2.CloseWithError(0, "") 364 time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued 365 366 conn3, err := dial() 367 Expect(err).ToNot(HaveOccurred()) 368 ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) 369 defer cancel() 370 _, err = conn3.AcceptStream(ctx) 371 Expect(errors.As(err, &transportErr)).To(BeTrue()) 372 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 373 }) 374 375 It("also returns closed connections from the accept queue", func() { 376 firstConn, err := dial() 377 Expect(err).ToNot(HaveOccurred()) 378 379 for i := 1; i < protocol.MaxAcceptQueueSize; i++ { 380 conn, err := dial() 381 Expect(err).ToNot(HaveOccurred()) 382 defer conn.CloseWithError(0, "") 383 } 384 time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued 385 386 conn, err := dial() 387 Expect(err).ToNot(HaveOccurred()) 388 ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 389 defer cancel() 390 _, err = conn.AcceptStream(ctx) 391 var transportErr *quic.TransportError 392 Expect(errors.As(err, &transportErr)).To(BeTrue()) 393 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 394 395 // Now close the one of the connection that are waiting to be accepted. 396 const appErrCode quic.ApplicationErrorCode = 12345 397 Expect(firstConn.CloseWithError(appErrCode, "")) 398 Eventually(firstConn.Context().Done()).Should(BeClosed()) 399 time.Sleep(scaleDuration(200 * time.Millisecond)) 400 401 // dial again, and expect that this fails again 402 conn2, err := dial() 403 Expect(err).ToNot(HaveOccurred()) 404 ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond) 405 defer cancel() 406 _, err = conn2.AcceptStream(ctx) 407 Expect(errors.As(err, &transportErr)).To(BeTrue()) 408 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 409 410 // now accept all connections 411 var closedConn quic.Connection 412 for i := 0; i < protocol.MaxAcceptQueueSize; i++ { 413 conn, err := server.Accept(context.Background()) 414 Expect(err).ToNot(HaveOccurred()) 415 if conn.Context().Err() != nil { 416 if closedConn != nil { 417 Fail("only expected a single closed connection") 418 } 419 closedConn = conn 420 } 421 } 422 Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection 423 _, err = closedConn.AcceptStream(context.Background()) 424 var appErr *quic.ApplicationError 425 Expect(errors.As(err, &appErr)).To(BeTrue()) 426 Expect(appErr.ErrorCode).To(Equal(appErrCode)) 427 }) 428 429 It("closes handshaking connections when the server is closed", func() { 430 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 431 Expect(err).ToNot(HaveOccurred()) 432 udpConn, err := net.ListenUDP("udp", laddr) 433 Expect(err).ToNot(HaveOccurred()) 434 tr := quic.Transport{ 435 Conn: udpConn, 436 } 437 defer tr.Close() 438 tlsConf := &tls.Config{} 439 done := make(chan struct{}) 440 tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { 441 <-done 442 return nil, errors.New("closed") 443 } 444 ln, err := tr.Listen(tlsConf, getQuicConfig(nil)) 445 Expect(err).ToNot(HaveOccurred()) 446 447 errChan := make(chan error, 1) 448 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 449 defer cancel() 450 go func() { 451 defer GinkgoRecover() 452 _, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) 453 errChan <- err 454 }() 455 time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued 456 Expect(ln.Close()).To(Succeed()) 457 close(done) 458 err = <-errChan 459 var transportErr *quic.TransportError 460 Expect(errors.As(err, &transportErr)).To(BeTrue()) 461 Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused)) 462 }) 463 }) 464 465 Context("limiting handshakes", func() { 466 var conn *net.UDPConn 467 468 BeforeEach(func() { 469 addr, err := net.ResolveUDPAddr("udp", "localhost:0") 470 Expect(err).ToNot(HaveOccurred()) 471 conn, err = net.ListenUDP("udp", addr) 472 Expect(err).ToNot(HaveOccurred()) 473 }) 474 475 AfterEach(func() { conn.Close() }) 476 477 It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() { 478 const limit = 3 479 tr := quic.Transport{ 480 Conn: conn, 481 MaxUnvalidatedHandshakes: limit, 482 } 483 defer tr.Close() 484 485 // Block all handshakes. 486 handshakes := make(chan struct{}) 487 var tlsConf tls.Config 488 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 489 handshakes <- struct{}{} 490 return getTLSConfig(), nil 491 } 492 ln, err := tr.Listen(&tlsConf, getQuicConfig(nil)) 493 Expect(err).ToNot(HaveOccurred()) 494 defer ln.Close() 495 496 const additional = 2 497 results := make([]struct{ retry, closed atomic.Bool }, limit+additional) 498 // Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel. 499 // Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and 500 // exactly 2 to experience a Retry. 501 for i := 0; i < limit+additional; i++ { 502 go func(index int) { 503 defer GinkgoRecover() 504 quicConf := getQuicConfig(&quic.Config{ 505 Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 506 return &logging.ConnectionTracer{ 507 ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) }, 508 ClosedConnection: func(error) { results[index].closed.Store(true) }, 509 } 510 }, 511 }) 512 conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf) 513 Expect(err).ToNot(HaveOccurred()) 514 conn.CloseWithError(0, "") 515 }(i) 516 } 517 numRetries := func() (n int) { 518 for i := 0; i < limit+additional; i++ { 519 if results[i].retry.Load() { 520 n++ 521 } 522 } 523 return 524 } 525 numClosed := func() (n int) { 526 for i := 0; i < limit+2; i++ { 527 if results[i].closed.Load() { 528 n++ 529 } 530 } 531 return 532 } 533 Eventually(numRetries).Should(Equal(additional)) 534 // allow the handshakes to complete 535 for i := 0; i < limit+additional; i++ { 536 Eventually(handshakes).Should(Receive()) 537 } 538 Eventually(numClosed).Should(Equal(limit + additional)) 539 Expect(numRetries()).To(Equal(additional)) // just to be on the safe side 540 }) 541 542 It("rejects connections when the number of handshakes reaches MaxHandshakes", func() { 543 const limit = 3 544 tr := quic.Transport{ 545 Conn: conn, 546 MaxHandshakes: limit, 547 } 548 defer tr.Close() 549 550 // Block all handshakes. 551 handshakes := make(chan struct{}) 552 var tlsConf tls.Config 553 tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { 554 handshakes <- struct{}{} 555 return getTLSConfig(), nil 556 } 557 ln, err := tr.Listen(&tlsConf, getQuicConfig(nil)) 558 Expect(err).ToNot(HaveOccurred()) 559 defer ln.Close() 560 561 const additional = 2 562 // Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel. 563 // Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and 564 // exactly 2 to experience a Retry. 565 var numSuccessful, numFailed atomic.Int32 566 for i := 0; i < limit+additional; i++ { 567 go func() { 568 defer GinkgoRecover() 569 quicConf := getQuicConfig(&quic.Config{ 570 Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 571 return &logging.ConnectionTracer{ 572 ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") }, 573 } 574 }, 575 }) 576 conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf) 577 if err != nil { 578 var transportErr *quic.TransportError 579 if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused { 580 Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err)) 581 } 582 numFailed.Add(1) 583 return 584 } 585 numSuccessful.Add(1) 586 conn.CloseWithError(0, "") 587 }() 588 } 589 Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional)) 590 // allow the handshakes to complete 591 for i := 0; i < limit; i++ { 592 Eventually(handshakes).Should(Receive()) 593 } 594 Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit)) 595 596 // make sure that the server is reachable again after these handshakes have completed 597 go func() { <-handshakes }() // allow this handshake to complete immediately 598 conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil)) 599 Expect(err).ToNot(HaveOccurred()) 600 conn.CloseWithError(0, "") 601 }) 602 }) 603 604 Context("ALPN", func() { 605 It("negotiates an application protocol", func() { 606 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 607 Expect(err).ToNot(HaveOccurred()) 608 609 done := make(chan struct{}) 610 go func() { 611 defer GinkgoRecover() 612 conn, err := ln.Accept(context.Background()) 613 Expect(err).ToNot(HaveOccurred()) 614 cs := conn.ConnectionState() 615 Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) 616 close(done) 617 }() 618 619 conn, err := quic.DialAddr( 620 context.Background(), 621 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 622 getTLSClientConfig(), 623 nil, 624 ) 625 Expect(err).ToNot(HaveOccurred()) 626 defer conn.CloseWithError(0, "") 627 cs := conn.ConnectionState() 628 Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn)) 629 Eventually(done).Should(BeClosed()) 630 Expect(ln.Close()).To(Succeed()) 631 }) 632 633 It("errors if application protocol negotiation fails", func() { 634 runServer(getTLSConfig()) 635 636 tlsConf := getTLSClientConfig() 637 tlsConf.NextProtos = []string{"foobar"} 638 _, err := quic.DialAddr( 639 context.Background(), 640 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 641 tlsConf, 642 nil, 643 ) 644 Expect(err).To(HaveOccurred()) 645 var transportErr *quic.TransportError 646 Expect(errors.As(err, &transportErr)).To(BeTrue()) 647 Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) 648 Expect(transportErr.Error()).To(ContainSubstring("no application protocol")) 649 }) 650 }) 651 652 Context("using tokens", func() { 653 It("uses tokens provided in NEW_TOKEN frames", func() { 654 server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 655 Expect(err).ToNot(HaveOccurred()) 656 defer server.Close() 657 658 // dial the first connection and receive the token 659 go func() { 660 defer GinkgoRecover() 661 _, err := server.Accept(context.Background()) 662 Expect(err).ToNot(HaveOccurred()) 663 }() 664 665 gets := make(chan string, 100) 666 puts := make(chan string, 100) 667 tokenStore := newTokenStore(gets, puts) 668 quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore}) 669 conn, err := quic.DialAddr( 670 context.Background(), 671 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 672 getTLSClientConfig(), 673 quicConf, 674 ) 675 Expect(err).ToNot(HaveOccurred()) 676 Expect(gets).To(Receive()) 677 Eventually(puts).Should(Receive()) 678 // received a token. Close this connection. 679 Expect(conn.CloseWithError(0, "")).To(Succeed()) 680 681 // dial the second connection and verify that the token was used 682 done := make(chan struct{}) 683 go func() { 684 defer GinkgoRecover() 685 defer close(done) 686 _, err := server.Accept(context.Background()) 687 Expect(err).ToNot(HaveOccurred()) 688 }() 689 conn, err = quic.DialAddr( 690 context.Background(), 691 fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), 692 getTLSClientConfig(), 693 quicConf, 694 ) 695 Expect(err).ToNot(HaveOccurred()) 696 defer conn.CloseWithError(0, "") 697 Expect(gets).To(Receive()) 698 699 Eventually(done).Should(BeClosed()) 700 }) 701 702 It("rejects invalid Retry token with the INVALID_TOKEN error", func() { 703 const rtt = 10 * time.Millisecond 704 705 // The validity period of the retry token is the handshake timeout, 706 // which is twice the handshake idle timeout. 707 // By setting the handshake timeout shorter than the RTT, the token will have expired by the time 708 // it reaches the server. 709 serverConfig.HandshakeIdleTimeout = rtt / 5 710 711 laddr, err := net.ResolveUDPAddr("udp", "localhost:0") 712 Expect(err).ToNot(HaveOccurred()) 713 udpConn, err := net.ListenUDP("udp", laddr) 714 Expect(err).ToNot(HaveOccurred()) 715 defer udpConn.Close() 716 tr := &quic.Transport{ 717 Conn: udpConn, 718 MaxUnvalidatedHandshakes: -1, 719 } 720 defer tr.Close() 721 server, err := tr.Listen(getTLSConfig(), serverConfig) 722 Expect(err).ToNot(HaveOccurred()) 723 defer server.Close() 724 725 serverPort := server.Addr().(*net.UDPAddr).Port 726 proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ 727 RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), 728 DelayPacket: func(quicproxy.Direction, []byte) time.Duration { 729 return rtt / 2 730 }, 731 }) 732 Expect(err).ToNot(HaveOccurred()) 733 defer proxy.Close() 734 735 _, err = quic.DialAddr( 736 context.Background(), 737 fmt.Sprintf("localhost:%d", proxy.LocalPort()), 738 getTLSClientConfig(), 739 nil, 740 ) 741 Expect(err).To(HaveOccurred()) 742 var transportErr *quic.TransportError 743 Expect(errors.As(err, &transportErr)).To(BeTrue()) 744 Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken)) 745 }) 746 }) 747 748 Context("GetConfigForClient", func() { 749 It("uses the quic.Config returned by GetConfigForClient", func() { 750 serverConfig.EnableDatagrams = false 751 var calledFrom net.Addr 752 serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { 753 conf := serverConfig.Clone() 754 conf.EnableDatagrams = true 755 calledFrom = info.RemoteAddr 756 return getQuicConfig(conf), nil 757 } 758 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 759 Expect(err).ToNot(HaveOccurred()) 760 761 done := make(chan struct{}) 762 go func() { 763 defer GinkgoRecover() 764 _, err := ln.Accept(context.Background()) 765 Expect(err).ToNot(HaveOccurred()) 766 close(done) 767 }() 768 769 conn, err := quic.DialAddr( 770 context.Background(), 771 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 772 getTLSClientConfig(), 773 getQuicConfig(&quic.Config{EnableDatagrams: true}), 774 ) 775 Expect(err).ToNot(HaveOccurred()) 776 defer conn.CloseWithError(0, "") 777 cs := conn.ConnectionState() 778 Expect(cs.SupportsDatagrams).To(BeTrue()) 779 Eventually(done).Should(BeClosed()) 780 Expect(ln.Close()).To(Succeed()) 781 Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port)) 782 }) 783 784 It("rejects the connection attempt if GetConfigForClient errors", func() { 785 serverConfig.EnableDatagrams = false 786 serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) { 787 return nil, errors.New("rejected") 788 } 789 ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) 790 Expect(err).ToNot(HaveOccurred()) 791 defer ln.Close() 792 793 done := make(chan struct{}) 794 go func() { 795 defer GinkgoRecover() 796 _, err := ln.Accept(context.Background()) 797 Expect(err).To(HaveOccurred()) // we don't expect to accept any connection 798 close(done) 799 }() 800 801 _, err = quic.DialAddr( 802 context.Background(), 803 fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), 804 getTLSClientConfig(), 805 getQuicConfig(&quic.Config{EnableDatagrams: true}), 806 ) 807 Expect(err).To(HaveOccurred()) 808 var transportErr *quic.TransportError 809 Expect(errors.As(err, &transportErr)).To(BeTrue()) 810 Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused)) 811 }) 812 }) 813 814 It("doesn't send any packets when generating the ClientHello fails", func() { 815 ln, err := net.ListenUDP("udp", nil) 816 Expect(err).ToNot(HaveOccurred()) 817 done := make(chan struct{}) 818 packetChan := make(chan struct{}) 819 go func() { 820 defer GinkgoRecover() 821 defer close(done) 822 for { 823 _, _, err := ln.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize)) 824 if err != nil { 825 return 826 } 827 packetChan <- struct{}{} 828 } 829 }() 830 831 tlsConf := getTLSClientConfig() 832 tlsConf.NextProtos = []string{""} 833 _, err = quic.DialAddr( 834 context.Background(), 835 fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port), 836 tlsConf, 837 nil, 838 ) 839 Expect(err).To(MatchError(&qerr.TransportError{ 840 ErrorCode: qerr.InternalError, 841 ErrorMessage: "tls: invalid NextProtos value", 842 })) 843 Consistently(packetChan).ShouldNot(Receive()) 844 ln.Close() 845 Eventually(done).Should(BeClosed()) 846 }) 847 })