github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/integrationtests/self/handshake_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/danielpfeifer02/quic-go-prio-packs"
    14  	quicproxy "github.com/danielpfeifer02/quic-go-prio-packs/integrationtests/tools/proxy"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    16  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    17  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qtls"
    18  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    19  
    20  	. "github.com/onsi/ginkgo/v2"
    21  	. "github.com/onsi/gomega"
    22  )
    23  
    24  type tokenStore struct {
    25  	store quic.TokenStore
    26  	gets  chan<- string
    27  	puts  chan<- string
    28  }
    29  
    30  var _ quic.TokenStore = &tokenStore{}
    31  
    32  func newTokenStore(gets, puts chan<- string) quic.TokenStore {
    33  	return &tokenStore{
    34  		store: quic.NewLRUTokenStore(10, 4),
    35  		gets:  gets,
    36  		puts:  puts,
    37  	}
    38  }
    39  
    40  func (c *tokenStore) Put(key string, token *quic.ClientToken) {
    41  	c.puts <- key
    42  	c.store.Put(key, token)
    43  }
    44  
    45  func (c *tokenStore) Pop(key string) *quic.ClientToken {
    46  	c.gets <- key
    47  	return c.store.Pop(key)
    48  }
    49  
    50  var _ = Describe("Handshake tests", func() {
    51  	var (
    52  		server        *quic.Listener
    53  		serverConfig  *quic.Config
    54  		acceptStopped chan struct{}
    55  	)
    56  
    57  	BeforeEach(func() {
    58  		server = nil
    59  		acceptStopped = make(chan struct{})
    60  		serverConfig = getQuicConfig(nil)
    61  	})
    62  
    63  	AfterEach(func() {
    64  		if server != nil {
    65  			server.Close()
    66  			<-acceptStopped
    67  		}
    68  	})
    69  
    70  	runServer := func(tlsConf *tls.Config) {
    71  		var err error
    72  		// start the server
    73  		server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
    74  		Expect(err).ToNot(HaveOccurred())
    75  
    76  		go func() {
    77  			defer GinkgoRecover()
    78  			defer close(acceptStopped)
    79  			for {
    80  				if _, err := server.Accept(context.Background()); err != nil {
    81  					return
    82  				}
    83  			}
    84  		}()
    85  	}
    86  
    87  	It("returns the context cancellation error on timeouts", func() {
    88  		ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
    89  		defer cancel()
    90  		errChan := make(chan error, 1)
    91  		go func() {
    92  			_, err := quic.DialAddr(
    93  				ctx,
    94  				"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
    95  				getTLSClientConfig(),
    96  				getQuicConfig(nil),
    97  			)
    98  			errChan <- err
    99  		}()
   100  
   101  		var err error
   102  		Eventually(errChan).Should(Receive(&err))
   103  		Expect(err).To(HaveOccurred())
   104  		Expect(err).To(MatchError(context.DeadlineExceeded))
   105  	})
   106  
   107  	It("returns the cancellation reason when a dial is canceled", func() {
   108  		ctx, cancel := context.WithCancelCause(context.Background())
   109  		errChan := make(chan error, 1)
   110  		go func() {
   111  			_, err := quic.DialAddr(
   112  				ctx,
   113  				"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
   114  				getTLSClientConfig(),
   115  				getQuicConfig(nil),
   116  			)
   117  			errChan <- err
   118  		}()
   119  
   120  		cancel(errors.New("application cancelled"))
   121  		var err error
   122  		Eventually(errChan).Should(Receive(&err))
   123  		Expect(err).To(HaveOccurred())
   124  		Expect(err).To(MatchError("application cancelled"))
   125  	})
   126  
   127  	Context("using different cipher suites", func() {
   128  		for n, id := range map[string]uint16{
   129  			"TLS_AES_128_GCM_SHA256":       tls.TLS_AES_128_GCM_SHA256,
   130  			"TLS_AES_256_GCM_SHA384":       tls.TLS_AES_256_GCM_SHA384,
   131  			"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
   132  		} {
   133  			name := n
   134  			suiteID := id
   135  
   136  			It(fmt.Sprintf("using %s", name), func() {
   137  				reset := qtls.SetCipherSuite(suiteID)
   138  				defer reset()
   139  
   140  				tlsConf := getTLSConfig()
   141  				ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
   142  				Expect(err).ToNot(HaveOccurred())
   143  				defer ln.Close()
   144  
   145  				go func() {
   146  					defer GinkgoRecover()
   147  					conn, err := ln.Accept(context.Background())
   148  					Expect(err).ToNot(HaveOccurred())
   149  					str, err := conn.OpenStream()
   150  					Expect(err).ToNot(HaveOccurred())
   151  					defer str.Close()
   152  					_, err = str.Write(PRData)
   153  					Expect(err).ToNot(HaveOccurred())
   154  				}()
   155  
   156  				conn, err := quic.DialAddr(
   157  					context.Background(),
   158  					fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   159  					getTLSClientConfig(),
   160  					getQuicConfig(nil),
   161  				)
   162  				Expect(err).ToNot(HaveOccurred())
   163  				str, err := conn.AcceptStream(context.Background())
   164  				Expect(err).ToNot(HaveOccurred())
   165  				data, err := io.ReadAll(str)
   166  				Expect(err).ToNot(HaveOccurred())
   167  				Expect(data).To(Equal(PRData))
   168  				Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID))
   169  				Expect(conn.CloseWithError(0, "")).To(Succeed())
   170  			})
   171  		}
   172  	})
   173  
   174  	Context("Certificate validation", func() {
   175  		It("accepts the certificate", func() {
   176  			runServer(getTLSConfig())
   177  			conn, err := quic.DialAddr(
   178  				context.Background(),
   179  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   180  				getTLSClientConfig(),
   181  				getQuicConfig(nil),
   182  			)
   183  			Expect(err).ToNot(HaveOccurred())
   184  			conn.CloseWithError(0, "")
   185  		})
   186  
   187  		It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
   188  			var local, remote net.Addr
   189  			var local2, remote2 net.Addr
   190  			done := make(chan struct{})
   191  			tlsConf := &tls.Config{
   192  				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   193  					local = info.Conn.LocalAddr()
   194  					remote = info.Conn.RemoteAddr()
   195  					conf := getTLSConfig()
   196  					conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   197  						defer close(done)
   198  						local2 = info.Conn.LocalAddr()
   199  						remote2 = info.Conn.RemoteAddr()
   200  						return &(conf.Certificates[0]), nil
   201  					}
   202  					return conf, nil
   203  				},
   204  			}
   205  			runServer(tlsConf)
   206  			conn, err := quic.DialAddr(
   207  				context.Background(),
   208  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   209  				getTLSClientConfig(),
   210  				getQuicConfig(nil),
   211  			)
   212  			Expect(err).ToNot(HaveOccurred())
   213  			defer conn.CloseWithError(0, "")
   214  			Eventually(done).Should(BeClosed())
   215  			Expect(server.Addr()).To(Equal(local))
   216  			Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
   217  			Expect(local).To(Equal(local2))
   218  			Expect(remote).To(Equal(remote2))
   219  		})
   220  
   221  		It("works with a long certificate chain", func() {
   222  			runServer(getTLSConfigWithLongCertChain())
   223  			conn, err := quic.DialAddr(
   224  				context.Background(),
   225  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   226  				getTLSClientConfig(),
   227  				getQuicConfig(nil),
   228  			)
   229  			Expect(err).ToNot(HaveOccurred())
   230  			conn.CloseWithError(0, "")
   231  		})
   232  
   233  		It("errors if the server name doesn't match", func() {
   234  			runServer(getTLSConfig())
   235  			conn, err := net.ListenUDP("udp", nil)
   236  			Expect(err).ToNot(HaveOccurred())
   237  			conf := getTLSClientConfig()
   238  			conf.ServerName = "foo.bar"
   239  			_, err = quic.Dial(
   240  				context.Background(),
   241  				conn,
   242  				server.Addr(),
   243  				conf,
   244  				getQuicConfig(nil),
   245  			)
   246  			Expect(err).To(HaveOccurred())
   247  			var transportErr *quic.TransportError
   248  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   249  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   250  			Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar"))
   251  			var certErr *tls.CertificateVerificationError
   252  			Expect(errors.As(transportErr, &certErr)).To(BeTrue())
   253  		})
   254  
   255  		It("fails the handshake if the client fails to provide the requested client cert", func() {
   256  			tlsConf := getTLSConfig()
   257  			tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
   258  			runServer(tlsConf)
   259  
   260  			conn, err := quic.DialAddr(
   261  				context.Background(),
   262  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   263  				getTLSClientConfig(),
   264  				getQuicConfig(nil),
   265  			)
   266  			// Usually, the error will occur after the client already finished the handshake.
   267  			// However, there's a race condition here. The server's CONNECTION_CLOSE might be
   268  			// received before the connection is returned, so we might already get the error while dialing.
   269  			if err == nil {
   270  				errChan := make(chan error)
   271  				go func() {
   272  					defer GinkgoRecover()
   273  					_, err := conn.AcceptStream(context.Background())
   274  					errChan <- err
   275  				}()
   276  				Eventually(errChan).Should(Receive(&err))
   277  			}
   278  			Expect(err).To(HaveOccurred())
   279  			var transportErr *quic.TransportError
   280  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   281  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   282  			Expect(transportErr.Error()).To(Or(
   283  				ContainSubstring("tls: certificate required"),
   284  				ContainSubstring("tls: bad certificate"),
   285  			))
   286  		})
   287  
   288  		It("uses the ServerName in the tls.Config", func() {
   289  			runServer(getTLSConfig())
   290  			tlsConf := getTLSClientConfig()
   291  			tlsConf.ServerName = "foo.bar"
   292  			_, err := quic.DialAddr(
   293  				context.Background(),
   294  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   295  				tlsConf,
   296  				getQuicConfig(nil),
   297  			)
   298  			Expect(err).To(HaveOccurred())
   299  			var transportErr *quic.TransportError
   300  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   301  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   302  			Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar"))
   303  		})
   304  	})
   305  
   306  	Context("queuening and accepting connections", func() {
   307  		var (
   308  			server *quic.Listener
   309  			pconn  net.PacketConn
   310  			dialer *quic.Transport
   311  		)
   312  
   313  		dial := func() (quic.Connection, error) {
   314  			remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
   315  			raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
   316  			Expect(err).ToNot(HaveOccurred())
   317  			return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
   318  		}
   319  
   320  		BeforeEach(func() {
   321  			var err error
   322  			// start the server, but don't call Accept
   323  			server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   324  			Expect(err).ToNot(HaveOccurred())
   325  
   326  			// prepare a (single) packet conn for dialing to the server
   327  			laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
   328  			Expect(err).ToNot(HaveOccurred())
   329  			pconn, err = net.ListenUDP("udp", laddr)
   330  			Expect(err).ToNot(HaveOccurred())
   331  			dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
   332  		})
   333  
   334  		AfterEach(func() {
   335  			Expect(server.Close()).To(Succeed())
   336  			Expect(pconn.Close()).To(Succeed())
   337  			Expect(dialer.Close()).To(Succeed())
   338  		})
   339  
   340  		It("rejects new connection attempts if connections don't get accepted", func() {
   341  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
   342  				conn, err := dial()
   343  				Expect(err).ToNot(HaveOccurred())
   344  				defer conn.CloseWithError(0, "")
   345  			}
   346  			time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
   347  
   348  			conn, err := dial()
   349  			Expect(err).ToNot(HaveOccurred())
   350  			ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   351  			defer cancel()
   352  			_, err = conn.AcceptStream(ctx)
   353  			var transportErr *quic.TransportError
   354  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   355  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   356  
   357  			// now accept one connection, freeing one spot in the queue
   358  			_, err = server.Accept(context.Background())
   359  			Expect(err).ToNot(HaveOccurred())
   360  			// dial again, and expect that this dial succeeds
   361  			conn2, err := dial()
   362  			Expect(err).ToNot(HaveOccurred())
   363  			defer conn2.CloseWithError(0, "")
   364  			time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
   365  
   366  			conn3, err := dial()
   367  			Expect(err).ToNot(HaveOccurred())
   368  			ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
   369  			defer cancel()
   370  			_, err = conn3.AcceptStream(ctx)
   371  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   372  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   373  		})
   374  
   375  		It("also returns closed connections from the accept queue", func() {
   376  			firstConn, err := dial()
   377  			Expect(err).ToNot(HaveOccurred())
   378  
   379  			for i := 1; i < protocol.MaxAcceptQueueSize; i++ {
   380  				conn, err := dial()
   381  				Expect(err).ToNot(HaveOccurred())
   382  				defer conn.CloseWithError(0, "")
   383  			}
   384  			time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
   385  
   386  			conn, err := dial()
   387  			Expect(err).ToNot(HaveOccurred())
   388  			ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   389  			defer cancel()
   390  			_, err = conn.AcceptStream(ctx)
   391  			var transportErr *quic.TransportError
   392  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   393  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   394  
   395  			// Now close the one of the connection that are waiting to be accepted.
   396  			const appErrCode quic.ApplicationErrorCode = 12345
   397  			Expect(firstConn.CloseWithError(appErrCode, ""))
   398  			Eventually(firstConn.Context().Done()).Should(BeClosed())
   399  			time.Sleep(scaleDuration(200 * time.Millisecond))
   400  
   401  			// dial again, and expect that this fails again
   402  			conn2, err := dial()
   403  			Expect(err).ToNot(HaveOccurred())
   404  			ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
   405  			defer cancel()
   406  			_, err = conn2.AcceptStream(ctx)
   407  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   408  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   409  
   410  			// now accept all connections
   411  			var closedConn quic.Connection
   412  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
   413  				conn, err := server.Accept(context.Background())
   414  				Expect(err).ToNot(HaveOccurred())
   415  				if conn.Context().Err() != nil {
   416  					if closedConn != nil {
   417  						Fail("only expected a single closed connection")
   418  					}
   419  					closedConn = conn
   420  				}
   421  			}
   422  			Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection
   423  			_, err = closedConn.AcceptStream(context.Background())
   424  			var appErr *quic.ApplicationError
   425  			Expect(errors.As(err, &appErr)).To(BeTrue())
   426  			Expect(appErr.ErrorCode).To(Equal(appErrCode))
   427  		})
   428  
   429  		It("closes handshaking connections when the server is closed", func() {
   430  			laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
   431  			Expect(err).ToNot(HaveOccurred())
   432  			udpConn, err := net.ListenUDP("udp", laddr)
   433  			Expect(err).ToNot(HaveOccurred())
   434  			tr := quic.Transport{
   435  				Conn: udpConn,
   436  			}
   437  			defer tr.Close()
   438  			tlsConf := &tls.Config{}
   439  			done := make(chan struct{})
   440  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   441  				<-done
   442  				return nil, errors.New("closed")
   443  			}
   444  			ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
   445  			Expect(err).ToNot(HaveOccurred())
   446  
   447  			errChan := make(chan error, 1)
   448  			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   449  			defer cancel()
   450  			go func() {
   451  				defer GinkgoRecover()
   452  				_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
   453  				errChan <- err
   454  			}()
   455  			time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
   456  			Expect(ln.Close()).To(Succeed())
   457  			close(done)
   458  			err = <-errChan
   459  			var transportErr *quic.TransportError
   460  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   461  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   462  		})
   463  	})
   464  
   465  	Context("limiting handshakes", func() {
   466  		var conn *net.UDPConn
   467  
   468  		BeforeEach(func() {
   469  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   470  			Expect(err).ToNot(HaveOccurred())
   471  			conn, err = net.ListenUDP("udp", addr)
   472  			Expect(err).ToNot(HaveOccurred())
   473  		})
   474  
   475  		AfterEach(func() { conn.Close() })
   476  
   477  		It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
   478  			const limit = 3
   479  			tr := quic.Transport{
   480  				Conn:                     conn,
   481  				MaxUnvalidatedHandshakes: limit,
   482  			}
   483  			defer tr.Close()
   484  
   485  			// Block all handshakes.
   486  			handshakes := make(chan struct{})
   487  			var tlsConf tls.Config
   488  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   489  				handshakes <- struct{}{}
   490  				return getTLSConfig(), nil
   491  			}
   492  			ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
   493  			Expect(err).ToNot(HaveOccurred())
   494  			defer ln.Close()
   495  
   496  			const additional = 2
   497  			results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
   498  			// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
   499  			// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
   500  			// exactly 2 to experience a Retry.
   501  			for i := 0; i < limit+additional; i++ {
   502  				go func(index int) {
   503  					defer GinkgoRecover()
   504  					quicConf := getQuicConfig(&quic.Config{
   505  						Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   506  							return &logging.ConnectionTracer{
   507  								ReceivedRetry:    func(*logging.Header) { results[index].retry.Store(true) },
   508  								ClosedConnection: func(error) { results[index].closed.Store(true) },
   509  							}
   510  						},
   511  					})
   512  					conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
   513  					Expect(err).ToNot(HaveOccurred())
   514  					conn.CloseWithError(0, "")
   515  				}(i)
   516  			}
   517  			numRetries := func() (n int) {
   518  				for i := 0; i < limit+additional; i++ {
   519  					if results[i].retry.Load() {
   520  						n++
   521  					}
   522  				}
   523  				return
   524  			}
   525  			numClosed := func() (n int) {
   526  				for i := 0; i < limit+2; i++ {
   527  					if results[i].closed.Load() {
   528  						n++
   529  					}
   530  				}
   531  				return
   532  			}
   533  			Eventually(numRetries).Should(Equal(additional))
   534  			// allow the handshakes to complete
   535  			for i := 0; i < limit+additional; i++ {
   536  				Eventually(handshakes).Should(Receive())
   537  			}
   538  			Eventually(numClosed).Should(Equal(limit + additional))
   539  			Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
   540  		})
   541  
   542  		It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
   543  			const limit = 3
   544  			tr := quic.Transport{
   545  				Conn:          conn,
   546  				MaxHandshakes: limit,
   547  			}
   548  			defer tr.Close()
   549  
   550  			// Block all handshakes.
   551  			handshakes := make(chan struct{})
   552  			var tlsConf tls.Config
   553  			tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
   554  				handshakes <- struct{}{}
   555  				return getTLSConfig(), nil
   556  			}
   557  			ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
   558  			Expect(err).ToNot(HaveOccurred())
   559  			defer ln.Close()
   560  
   561  			const additional = 2
   562  			// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
   563  			// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
   564  			// exactly 2 to experience a Retry.
   565  			var numSuccessful, numFailed atomic.Int32
   566  			for i := 0; i < limit+additional; i++ {
   567  				go func() {
   568  					defer GinkgoRecover()
   569  					quicConf := getQuicConfig(&quic.Config{
   570  						Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   571  							return &logging.ConnectionTracer{
   572  								ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
   573  							}
   574  						},
   575  					})
   576  					conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
   577  					if err != nil {
   578  						var transportErr *quic.TransportError
   579  						if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
   580  							Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
   581  						}
   582  						numFailed.Add(1)
   583  						return
   584  					}
   585  					numSuccessful.Add(1)
   586  					conn.CloseWithError(0, "")
   587  				}()
   588  			}
   589  			Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
   590  			// allow the handshakes to complete
   591  			for i := 0; i < limit; i++ {
   592  				Eventually(handshakes).Should(Receive())
   593  			}
   594  			Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))
   595  
   596  			// make sure that the server is reachable again after these handshakes have completed
   597  			go func() { <-handshakes }() // allow this handshake to complete immediately
   598  			conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
   599  			Expect(err).ToNot(HaveOccurred())
   600  			conn.CloseWithError(0, "")
   601  		})
   602  	})
   603  
   604  	Context("ALPN", func() {
   605  		It("negotiates an application protocol", func() {
   606  			ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   607  			Expect(err).ToNot(HaveOccurred())
   608  
   609  			done := make(chan struct{})
   610  			go func() {
   611  				defer GinkgoRecover()
   612  				conn, err := ln.Accept(context.Background())
   613  				Expect(err).ToNot(HaveOccurred())
   614  				cs := conn.ConnectionState()
   615  				Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn))
   616  				close(done)
   617  			}()
   618  
   619  			conn, err := quic.DialAddr(
   620  				context.Background(),
   621  				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   622  				getTLSClientConfig(),
   623  				nil,
   624  			)
   625  			Expect(err).ToNot(HaveOccurred())
   626  			defer conn.CloseWithError(0, "")
   627  			cs := conn.ConnectionState()
   628  			Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn))
   629  			Eventually(done).Should(BeClosed())
   630  			Expect(ln.Close()).To(Succeed())
   631  		})
   632  
   633  		It("errors if application protocol negotiation fails", func() {
   634  			runServer(getTLSConfig())
   635  
   636  			tlsConf := getTLSClientConfig()
   637  			tlsConf.NextProtos = []string{"foobar"}
   638  			_, err := quic.DialAddr(
   639  				context.Background(),
   640  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   641  				tlsConf,
   642  				nil,
   643  			)
   644  			Expect(err).To(HaveOccurred())
   645  			var transportErr *quic.TransportError
   646  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   647  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   648  			Expect(transportErr.Error()).To(ContainSubstring("no application protocol"))
   649  		})
   650  	})
   651  
   652  	Context("using tokens", func() {
   653  		It("uses tokens provided in NEW_TOKEN frames", func() {
   654  			server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   655  			Expect(err).ToNot(HaveOccurred())
   656  			defer server.Close()
   657  
   658  			// dial the first connection and receive the token
   659  			go func() {
   660  				defer GinkgoRecover()
   661  				_, err := server.Accept(context.Background())
   662  				Expect(err).ToNot(HaveOccurred())
   663  			}()
   664  
   665  			gets := make(chan string, 100)
   666  			puts := make(chan string, 100)
   667  			tokenStore := newTokenStore(gets, puts)
   668  			quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore})
   669  			conn, err := quic.DialAddr(
   670  				context.Background(),
   671  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   672  				getTLSClientConfig(),
   673  				quicConf,
   674  			)
   675  			Expect(err).ToNot(HaveOccurred())
   676  			Expect(gets).To(Receive())
   677  			Eventually(puts).Should(Receive())
   678  			// received a token. Close this connection.
   679  			Expect(conn.CloseWithError(0, "")).To(Succeed())
   680  
   681  			// dial the second connection and verify that the token was used
   682  			done := make(chan struct{})
   683  			go func() {
   684  				defer GinkgoRecover()
   685  				defer close(done)
   686  				_, err := server.Accept(context.Background())
   687  				Expect(err).ToNot(HaveOccurred())
   688  			}()
   689  			conn, err = quic.DialAddr(
   690  				context.Background(),
   691  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   692  				getTLSClientConfig(),
   693  				quicConf,
   694  			)
   695  			Expect(err).ToNot(HaveOccurred())
   696  			defer conn.CloseWithError(0, "")
   697  			Expect(gets).To(Receive())
   698  
   699  			Eventually(done).Should(BeClosed())
   700  		})
   701  
   702  		It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
   703  			const rtt = 10 * time.Millisecond
   704  
   705  			// The validity period of the retry token is the handshake timeout,
   706  			// which is twice the handshake idle timeout.
   707  			// By setting the handshake timeout shorter than the RTT, the token will have expired by the time
   708  			// it reaches the server.
   709  			serverConfig.HandshakeIdleTimeout = rtt / 5
   710  
   711  			laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
   712  			Expect(err).ToNot(HaveOccurred())
   713  			udpConn, err := net.ListenUDP("udp", laddr)
   714  			Expect(err).ToNot(HaveOccurred())
   715  			defer udpConn.Close()
   716  			tr := &quic.Transport{
   717  				Conn:                     udpConn,
   718  				MaxUnvalidatedHandshakes: -1,
   719  			}
   720  			defer tr.Close()
   721  			server, err := tr.Listen(getTLSConfig(), serverConfig)
   722  			Expect(err).ToNot(HaveOccurred())
   723  			defer server.Close()
   724  
   725  			serverPort := server.Addr().(*net.UDPAddr).Port
   726  			proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
   727  				RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
   728  				DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
   729  					return rtt / 2
   730  				},
   731  			})
   732  			Expect(err).ToNot(HaveOccurred())
   733  			defer proxy.Close()
   734  
   735  			_, err = quic.DialAddr(
   736  				context.Background(),
   737  				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
   738  				getTLSClientConfig(),
   739  				nil,
   740  			)
   741  			Expect(err).To(HaveOccurred())
   742  			var transportErr *quic.TransportError
   743  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   744  			Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
   745  		})
   746  	})
   747  
   748  	Context("GetConfigForClient", func() {
   749  		It("uses the quic.Config returned by GetConfigForClient", func() {
   750  			serverConfig.EnableDatagrams = false
   751  			var calledFrom net.Addr
   752  			serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
   753  				conf := serverConfig.Clone()
   754  				conf.EnableDatagrams = true
   755  				calledFrom = info.RemoteAddr
   756  				return getQuicConfig(conf), nil
   757  			}
   758  			ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   759  			Expect(err).ToNot(HaveOccurred())
   760  
   761  			done := make(chan struct{})
   762  			go func() {
   763  				defer GinkgoRecover()
   764  				_, err := ln.Accept(context.Background())
   765  				Expect(err).ToNot(HaveOccurred())
   766  				close(done)
   767  			}()
   768  
   769  			conn, err := quic.DialAddr(
   770  				context.Background(),
   771  				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   772  				getTLSClientConfig(),
   773  				getQuicConfig(&quic.Config{EnableDatagrams: true}),
   774  			)
   775  			Expect(err).ToNot(HaveOccurred())
   776  			defer conn.CloseWithError(0, "")
   777  			cs := conn.ConnectionState()
   778  			Expect(cs.SupportsDatagrams).To(BeTrue())
   779  			Eventually(done).Should(BeClosed())
   780  			Expect(ln.Close()).To(Succeed())
   781  			Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
   782  		})
   783  
   784  		It("rejects the connection attempt if GetConfigForClient errors", func() {
   785  			serverConfig.EnableDatagrams = false
   786  			serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
   787  				return nil, errors.New("rejected")
   788  			}
   789  			ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   790  			Expect(err).ToNot(HaveOccurred())
   791  			defer ln.Close()
   792  
   793  			done := make(chan struct{})
   794  			go func() {
   795  				defer GinkgoRecover()
   796  				_, err := ln.Accept(context.Background())
   797  				Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
   798  				close(done)
   799  			}()
   800  
   801  			_, err = quic.DialAddr(
   802  				context.Background(),
   803  				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   804  				getTLSClientConfig(),
   805  				getQuicConfig(&quic.Config{EnableDatagrams: true}),
   806  			)
   807  			Expect(err).To(HaveOccurred())
   808  			var transportErr *quic.TransportError
   809  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   810  			Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))
   811  		})
   812  	})
   813  
   814  	It("doesn't send any packets when generating the ClientHello fails", func() {
   815  		ln, err := net.ListenUDP("udp", nil)
   816  		Expect(err).ToNot(HaveOccurred())
   817  		done := make(chan struct{})
   818  		packetChan := make(chan struct{})
   819  		go func() {
   820  			defer GinkgoRecover()
   821  			defer close(done)
   822  			for {
   823  				_, _, err := ln.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize))
   824  				if err != nil {
   825  					return
   826  				}
   827  				packetChan <- struct{}{}
   828  			}
   829  		}()
   830  
   831  		tlsConf := getTLSClientConfig()
   832  		tlsConf.NextProtos = []string{""}
   833  		_, err = quic.DialAddr(
   834  			context.Background(),
   835  			fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port),
   836  			tlsConf,
   837  			nil,
   838  		)
   839  		Expect(err).To(MatchError(&qerr.TransportError{
   840  			ErrorCode:    qerr.InternalError,
   841  			ErrorMessage: "tls: invalid NextProtos value",
   842  		}))
   843  		Consistently(packetChan).ShouldNot(Receive())
   844  		ln.Close()
   845  		Eventually(done).Should(BeClosed())
   846  	})
   847  })