github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/self/handshake_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"time"
    11  
    12  	"github.com/daeuniverse/quic-go"
    13  	quicproxy "github.com/daeuniverse/quic-go/integrationtests/tools/proxy"
    14  	"github.com/daeuniverse/quic-go/internal/protocol"
    15  	"github.com/daeuniverse/quic-go/internal/qerr"
    16  	"github.com/daeuniverse/quic-go/internal/qtls"
    17  
    18  	. "github.com/onsi/ginkgo/v2"
    19  	. "github.com/onsi/gomega"
    20  )
    21  
// tokenStore wraps a quic.TokenStore and reports the key of every Put and Pop
// call on a channel, so that specs can observe how the token store is used.
type tokenStore struct {
	// store is the wrapped store that actually holds the tokens.
	store quic.TokenStore
	// gets receives the key of every Pop call.
	gets chan<- string
	// puts receives the key of every Put call.
	puts chan<- string
}
    27  
    28  var _ quic.TokenStore = &tokenStore{}
    29  
    30  func newTokenStore(gets, puts chan<- string) quic.TokenStore {
    31  	return &tokenStore{
    32  		store: quic.NewLRUTokenStore(10, 4),
    33  		gets:  gets,
    34  		puts:  puts,
    35  	}
    36  }
    37  
// Put sends key on the puts channel and then stores token under key in the
// wrapped store. The channel send happens first, so Put blocks until the key
// has been accepted by the channel.
func (c *tokenStore) Put(key string, token *quic.ClientToken) {
	c.puts <- key
	c.store.Put(key, token)
}
    42  
    43  func (c *tokenStore) Pop(key string) *quic.ClientToken {
    44  	c.gets <- key
    45  	return c.store.Pop(key)
    46  }
    47  
    48  var _ = Describe("Handshake tests", func() {
	// State shared by the specs below; reset in BeforeEach.
	var (
		// server is the listener started by runServer. It stays nil for specs
		// that never call runServer.
		server *quic.Listener
		// serverConfig is the server-side QUIC config; specs may modify it
		// before starting a listener.
		serverConfig *quic.Config
		// acceptStopped is closed when the accept loop started by runServer exits.
		acceptStopped chan struct{}
	)
    54  
    55  	BeforeEach(func() {
    56  		server = nil
    57  		acceptStopped = make(chan struct{})
    58  		serverConfig = getQuicConfig(nil)
    59  	})
    60  
    61  	AfterEach(func() {
    62  		if server != nil {
    63  			server.Close()
    64  			<-acceptStopped
    65  		}
    66  	})
    67  
    68  	runServer := func(tlsConf *tls.Config) {
    69  		var err error
    70  		// start the server
    71  		server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig)
    72  		Expect(err).ToNot(HaveOccurred())
    73  
    74  		go func() {
    75  			defer GinkgoRecover()
    76  			defer close(acceptStopped)
    77  			for {
    78  				if _, err := server.Accept(context.Background()); err != nil {
    79  					return
    80  				}
    81  			}
    82  		}()
    83  	}
    84  
    85  	It("returns the context cancellation error on timeouts", func() {
    86  		ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(20*time.Millisecond))
    87  		defer cancel()
    88  		errChan := make(chan error, 1)
    89  		go func() {
    90  			_, err := quic.DialAddr(
    91  				ctx,
    92  				"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
    93  				getTLSClientConfig(),
    94  				getQuicConfig(nil),
    95  			)
    96  			errChan <- err
    97  		}()
    98  
    99  		var err error
   100  		Eventually(errChan).Should(Receive(&err))
   101  		Expect(err).To(HaveOccurred())
   102  		Expect(err).To(MatchError(context.DeadlineExceeded))
   103  	})
   104  
   105  	It("returns the cancellation reason when a dial is canceled", func() {
   106  		ctx, cancel := context.WithCancelCause(context.Background())
   107  		errChan := make(chan error, 1)
   108  		go func() {
   109  			_, err := quic.DialAddr(
   110  				ctx,
   111  				"localhost:1234", // nobody is listening on this port, but we're going to cancel this dial anyway
   112  				getTLSClientConfig(),
   113  				getQuicConfig(nil),
   114  			)
   115  			errChan <- err
   116  		}()
   117  
   118  		cancel(errors.New("application cancelled"))
   119  		var err error
   120  		Eventually(errChan).Should(Receive(&err))
   121  		Expect(err).To(HaveOccurred())
   122  		Expect(err).To(MatchError("application cancelled"))
   123  	})
   124  
   125  	Context("using different cipher suites", func() {
   126  		for n, id := range map[string]uint16{
   127  			"TLS_AES_128_GCM_SHA256":       tls.TLS_AES_128_GCM_SHA256,
   128  			"TLS_AES_256_GCM_SHA384":       tls.TLS_AES_256_GCM_SHA384,
   129  			"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
   130  		} {
   131  			name := n
   132  			suiteID := id
   133  
   134  			It(fmt.Sprintf("using %s", name), func() {
   135  				reset := qtls.SetCipherSuite(suiteID)
   136  				defer reset()
   137  
   138  				tlsConf := getTLSConfig()
   139  				ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig)
   140  				Expect(err).ToNot(HaveOccurred())
   141  				defer ln.Close()
   142  
   143  				go func() {
   144  					defer GinkgoRecover()
   145  					conn, err := ln.Accept(context.Background())
   146  					Expect(err).ToNot(HaveOccurred())
   147  					str, err := conn.OpenStream()
   148  					Expect(err).ToNot(HaveOccurred())
   149  					defer str.Close()
   150  					_, err = str.Write(PRData)
   151  					Expect(err).ToNot(HaveOccurred())
   152  				}()
   153  
   154  				conn, err := quic.DialAddr(
   155  					context.Background(),
   156  					fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   157  					getTLSClientConfig(),
   158  					getQuicConfig(nil),
   159  				)
   160  				Expect(err).ToNot(HaveOccurred())
   161  				str, err := conn.AcceptStream(context.Background())
   162  				Expect(err).ToNot(HaveOccurred())
   163  				data, err := io.ReadAll(str)
   164  				Expect(err).ToNot(HaveOccurred())
   165  				Expect(data).To(Equal(PRData))
   166  				Expect(conn.ConnectionState().TLS.CipherSuite).To(Equal(suiteID))
   167  				Expect(conn.CloseWithError(0, "")).To(Succeed())
   168  			})
   169  		}
   170  	})
   171  
   172  	Context("Certificate validation", func() {
   173  		It("accepts the certificate", func() {
   174  			runServer(getTLSConfig())
   175  			conn, err := quic.DialAddr(
   176  				context.Background(),
   177  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   178  				getTLSClientConfig(),
   179  				getQuicConfig(nil),
   180  			)
   181  			Expect(err).ToNot(HaveOccurred())
   182  			conn.CloseWithError(0, "")
   183  		})
   184  
		// Checks the addresses reported by tls.ClientHelloInfo.Conn on the server:
		// the local address must equal the listener's address and the remote port
		// must equal the client's local port. The same addresses must be seen in
		// GetConfigForClient and in the GetCertificate callback of the config it
		// returns.
		It("has the right local and remote address on the tls.Config.GetConfigForClient ClientHelloInfo.Conn", func() {
			// Addresses observed in GetConfigForClient.
			var local, remote net.Addr
			// Addresses observed in GetCertificate.
			var local2, remote2 net.Addr
			// done is closed once GetCertificate has run, i.e. after all four
			// addresses have been recorded.
			done := make(chan struct{})
			tlsConf := &tls.Config{
				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
					local = info.Conn.LocalAddr()
					remote = info.Conn.RemoteAddr()
					conf := getTLSConfig()
					conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
						defer close(done)
						local2 = info.Conn.LocalAddr()
						remote2 = info.Conn.RemoteAddr()
						return &(conf.Certificates[0]), nil
					}
					return conf, nil
				},
			}
			runServer(tlsConf)
			conn, err := quic.DialAddr(
				context.Background(),
				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
				getTLSClientConfig(),
				getQuicConfig(nil),
			)
			Expect(err).ToNot(HaveOccurred())
			defer conn.CloseWithError(0, "")
			// Only read the recorded addresses after the callbacks have finished.
			Eventually(done).Should(BeClosed())
			Expect(server.Addr()).To(Equal(local))
			Expect(conn.LocalAddr().(*net.UDPAddr).Port).To(Equal(remote.(*net.UDPAddr).Port))
			Expect(local).To(Equal(local2))
			Expect(remote).To(Equal(remote2))
		})
   218  
   219  		It("works with a long certificate chain", func() {
   220  			runServer(getTLSConfigWithLongCertChain())
   221  			conn, err := quic.DialAddr(
   222  				context.Background(),
   223  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   224  				getTLSClientConfig(),
   225  				getQuicConfig(nil),
   226  			)
   227  			Expect(err).ToNot(HaveOccurred())
   228  			conn.CloseWithError(0, "")
   229  		})
   230  
   231  		It("errors if the server name doesn't match", func() {
   232  			runServer(getTLSConfig())
   233  			conn, err := net.ListenUDP("udp", nil)
   234  			Expect(err).ToNot(HaveOccurred())
   235  			conf := getTLSClientConfig()
   236  			conf.ServerName = "foo.bar"
   237  			_, err = quic.Dial(
   238  				context.Background(),
   239  				conn,
   240  				server.Addr(),
   241  				conf,
   242  				getQuicConfig(nil),
   243  			)
   244  			Expect(err).To(HaveOccurred())
   245  			var transportErr *quic.TransportError
   246  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   247  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   248  			Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar"))
   249  			var certErr *tls.CertificateVerificationError
   250  			Expect(errors.As(transportErr, &certErr)).To(BeTrue())
   251  		})
   252  
   253  		It("fails the handshake if the client fails to provide the requested client cert", func() {
   254  			tlsConf := getTLSConfig()
   255  			tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
   256  			runServer(tlsConf)
   257  
   258  			conn, err := quic.DialAddr(
   259  				context.Background(),
   260  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   261  				getTLSClientConfig(),
   262  				getQuicConfig(nil),
   263  			)
   264  			// Usually, the error will occur after the client already finished the handshake.
   265  			// However, there's a race condition here. The server's CONNECTION_CLOSE might be
   266  			// received before the connection is returned, so we might already get the error while dialing.
   267  			if err == nil {
   268  				errChan := make(chan error)
   269  				go func() {
   270  					defer GinkgoRecover()
   271  					_, err := conn.AcceptStream(context.Background())
   272  					errChan <- err
   273  				}()
   274  				Eventually(errChan).Should(Receive(&err))
   275  			}
   276  			Expect(err).To(HaveOccurred())
   277  			var transportErr *quic.TransportError
   278  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   279  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   280  			Expect(transportErr.Error()).To(Or(
   281  				ContainSubstring("tls: certificate required"),
   282  				ContainSubstring("tls: bad certificate"),
   283  			))
   284  		})
   285  
   286  		It("uses the ServerName in the tls.Config", func() {
   287  			runServer(getTLSConfig())
   288  			tlsConf := getTLSClientConfig()
   289  			tlsConf.ServerName = "foo.bar"
   290  			_, err := quic.DialAddr(
   291  				context.Background(),
   292  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   293  				tlsConf,
   294  				getQuicConfig(nil),
   295  			)
   296  			Expect(err).To(HaveOccurred())
   297  			var transportErr *quic.TransportError
   298  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   299  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   300  			Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar"))
   301  		})
   302  	})
   303  
   304  	Context("queuening and accepting connections", func() {
		// State for the accept-queue specs, set up in this Context's BeforeEach.
		var (
			// server shadows the outer listener variable; Accept is not called on
			// it unless a spec does so explicitly.
			server *quic.Listener
			// pconn is the single UDP socket used for all client connections.
			pconn net.PacketConn
			// dialer is the transport that dials the server over pconn.
			dialer *quic.Transport
		)
   310  
   311  		dial := func() (quic.Connection, error) {
   312  			remoteAddr := fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port)
   313  			raddr, err := net.ResolveUDPAddr("udp", remoteAddr)
   314  			Expect(err).ToNot(HaveOccurred())
   315  			return dialer.Dial(context.Background(), raddr, getTLSClientConfig(), getQuicConfig(nil))
   316  		}
   317  
		// Starts a listener that nobody accepts from and prepares the client
		// transport used by dial.
		BeforeEach(func() {
			var err error
			// start the server, but don't call Accept
			server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
			Expect(err).ToNot(HaveOccurred())

			// prepare a (single) packet conn for dialing to the server
			laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
			Expect(err).ToNot(HaveOccurred())
			pconn, err = net.ListenUDP("udp", laddr)
			Expect(err).ToNot(HaveOccurred())
			dialer = &quic.Transport{
				Conn:               pconn,
				ConnectionIDLength: 4,
			}
		})
   334  
		// Tears down the listener, the client socket and the client transport
		// created in BeforeEach, asserting that each Close succeeds.
		AfterEach(func() {
			Expect(server.Close()).To(Succeed())
			Expect(pconn.Close()).To(Succeed())
			Expect(dialer.Close()).To(Succeed())
		})
   340  
   341  		It("rejects new connection attempts if connections don't get accepted", func() {
   342  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
   343  				conn, err := dial()
   344  				Expect(err).ToNot(HaveOccurred())
   345  				defer conn.CloseWithError(0, "")
   346  			}
   347  			time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
   348  
   349  			conn, err := dial()
   350  			Expect(err).ToNot(HaveOccurred())
   351  			ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   352  			defer cancel()
   353  			_, err = conn.AcceptStream(ctx)
   354  			var transportErr *quic.TransportError
   355  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   356  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   357  
   358  			// now accept one connection, freeing one spot in the queue
   359  			_, err = server.Accept(context.Background())
   360  			Expect(err).ToNot(HaveOccurred())
   361  			// dial again, and expect that this dial succeeds
   362  			conn2, err := dial()
   363  			Expect(err).ToNot(HaveOccurred())
   364  			defer conn2.CloseWithError(0, "")
   365  			time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued
   366  
   367  			conn3, err := dial()
   368  			Expect(err).ToNot(HaveOccurred())
   369  			ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
   370  			defer cancel()
   371  			_, err = conn3.AcceptStream(ctx)
   372  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   373  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   374  		})
   375  
   376  		It("also returns closed connections from the accept queue", func() {
   377  			firstConn, err := dial()
   378  			Expect(err).ToNot(HaveOccurred())
   379  
   380  			for i := 1; i < protocol.MaxAcceptQueueSize; i++ {
   381  				conn, err := dial()
   382  				Expect(err).ToNot(HaveOccurred())
   383  				defer conn.CloseWithError(0, "")
   384  			}
   385  			time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
   386  
   387  			conn, err := dial()
   388  			Expect(err).ToNot(HaveOccurred())
   389  			ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
   390  			defer cancel()
   391  			_, err = conn.AcceptStream(ctx)
   392  			var transportErr *quic.TransportError
   393  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   394  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   395  
   396  			// Now close the one of the connection that are waiting to be accepted.
   397  			const appErrCode quic.ApplicationErrorCode = 12345
   398  			Expect(firstConn.CloseWithError(appErrCode, ""))
   399  			Eventually(firstConn.Context().Done()).Should(BeClosed())
   400  			time.Sleep(scaleDuration(200 * time.Millisecond))
   401  
   402  			// dial again, and expect that this fails again
   403  			conn2, err := dial()
   404  			Expect(err).ToNot(HaveOccurred())
   405  			ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
   406  			defer cancel()
   407  			_, err = conn2.AcceptStream(ctx)
   408  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   409  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   410  
   411  			// now accept all connections
   412  			var closedConn quic.Connection
   413  			for i := 0; i < protocol.MaxAcceptQueueSize; i++ {
   414  				conn, err := server.Accept(context.Background())
   415  				Expect(err).ToNot(HaveOccurred())
   416  				if conn.Context().Err() != nil {
   417  					if closedConn != nil {
   418  						Fail("only expected a single closed connection")
   419  					}
   420  					closedConn = conn
   421  				}
   422  			}
   423  			Expect(closedConn).ToNot(BeNil()) // there should be exactly one closed connection
   424  			_, err = closedConn.AcceptStream(context.Background())
   425  			var appErr *quic.ApplicationError
   426  			Expect(errors.As(err, &appErr)).To(BeTrue())
   427  			Expect(appErr.ErrorCode).To(Equal(appErrCode))
   428  		})
   429  
   430  		It("closes handshaking connections when the server is closed", func() {
   431  			laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
   432  			Expect(err).ToNot(HaveOccurred())
   433  			udpConn, err := net.ListenUDP("udp", laddr)
   434  			Expect(err).ToNot(HaveOccurred())
   435  			tr := &quic.Transport{Conn: udpConn}
   436  			addTracer(tr)
   437  			defer tr.Close()
   438  			tlsConf := &tls.Config{}
   439  			done := make(chan struct{})
   440  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   441  				<-done
   442  				return nil, errors.New("closed")
   443  			}
   444  			ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
   445  			Expect(err).ToNot(HaveOccurred())
   446  
   447  			errChan := make(chan error, 1)
   448  			ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   449  			defer cancel()
   450  			go func() {
   451  				defer GinkgoRecover()
   452  				_, err := quic.DialAddr(ctx, ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
   453  				errChan <- err
   454  			}()
   455  			time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued
   456  			Expect(ln.Close()).To(Succeed())
   457  			close(done)
   458  			err = <-errChan
   459  			var transportErr *quic.TransportError
   460  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   461  			Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
   462  		})
   463  	})
   464  
   465  	Context("ALPN", func() {
   466  		It("negotiates an application protocol", func() {
   467  			ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   468  			Expect(err).ToNot(HaveOccurred())
   469  
   470  			done := make(chan struct{})
   471  			go func() {
   472  				defer GinkgoRecover()
   473  				conn, err := ln.Accept(context.Background())
   474  				Expect(err).ToNot(HaveOccurred())
   475  				cs := conn.ConnectionState()
   476  				Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn))
   477  				close(done)
   478  			}()
   479  
   480  			conn, err := quic.DialAddr(
   481  				context.Background(),
   482  				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   483  				getTLSClientConfig(),
   484  				nil,
   485  			)
   486  			Expect(err).ToNot(HaveOccurred())
   487  			defer conn.CloseWithError(0, "")
   488  			cs := conn.ConnectionState()
   489  			Expect(cs.TLS.NegotiatedProtocol).To(Equal(alpn))
   490  			Eventually(done).Should(BeClosed())
   491  			Expect(ln.Close()).To(Succeed())
   492  		})
   493  
   494  		It("errors if application protocol negotiation fails", func() {
   495  			runServer(getTLSConfig())
   496  
   497  			tlsConf := getTLSClientConfig()
   498  			tlsConf.NextProtos = []string{"foobar"}
   499  			_, err := quic.DialAddr(
   500  				context.Background(),
   501  				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   502  				tlsConf,
   503  				nil,
   504  			)
   505  			Expect(err).To(HaveOccurred())
   506  			var transportErr *quic.TransportError
   507  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   508  			Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue())
   509  			Expect(transportErr.Error()).To(ContainSubstring("no application protocol"))
   510  		})
   511  	})
   512  
   513  	Context("using tokens", func() {
		// Dials the same server twice with a shared, instrumented token store.
		// The first dial must query the store and later put a token into it; the
		// second dial must query the store again.
		It("uses tokens provided in NEW_TOKEN frames", func() {
			server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
			Expect(err).ToNot(HaveOccurred())
			defer server.Close()

			// dial the first connection and receive the token
			go func() {
				defer GinkgoRecover()
				_, err := server.Accept(context.Background())
				Expect(err).ToNot(HaveOccurred())
			}()

			// Buffered, so that the token store's channel sends don't block the
			// connection.
			gets := make(chan string, 100)
			puts := make(chan string, 100)
			tokenStore := newTokenStore(gets, puts)
			quicConf := getQuicConfig(&quic.Config{TokenStore: tokenStore})
			conn, err := quic.DialAddr(
				context.Background(),
				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
				getTLSClientConfig(),
				quicConf,
			)
			Expect(err).ToNot(HaveOccurred())
			// The store has already been queried by the time the dial returns;
			// the Put may arrive later, hence Eventually.
			Expect(gets).To(Receive())
			Eventually(puts).Should(Receive())
			// received a token. Close this connection.
			Expect(conn.CloseWithError(0, "")).To(Succeed())

			// dial the second connection and verify that the token was used
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				defer close(done)
				_, err := server.Accept(context.Background())
				Expect(err).ToNot(HaveOccurred())
			}()
			conn, err = quic.DialAddr(
				context.Background(),
				fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
				getTLSClientConfig(),
				quicConf,
			)
			Expect(err).ToNot(HaveOccurred())
			defer conn.CloseWithError(0, "")
			Expect(gets).To(Receive())

			Eventually(done).Should(BeClosed())
		})
   562  
		// Routes the client through a delaying proxy so that the Retry token has
		// expired by the time it gets back to the server, and expects the dial to
		// fail with INVALID_TOKEN.
		It("rejects invalid Retry token with the INVALID_TOKEN error", func() {
			const rtt = 10 * time.Millisecond

			// The validity period of the retry token is the handshake timeout,
			// which is twice the handshake idle timeout.
			// By setting the handshake timeout shorter than the RTT, the token will have expired by the time
			// it reaches the server.
			serverConfig.HandshakeIdleTimeout = rtt / 5

			laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
			Expect(err).ToNot(HaveOccurred())
			udpConn, err := net.ListenUDP("udp", laddr)
			Expect(err).ToNot(HaveOccurred())
			defer udpConn.Close()
			tr := &quic.Transport{
				Conn: udpConn,
				// NOTE(review): presumably returning true makes the server send a
				// Retry to every client — confirm against the Transport docs.
				VerifySourceAddress: func(net.Addr) bool { return true },
			}
			addTracer(tr)
			defer tr.Close()
			server, err := tr.Listen(getTLSConfig(), serverConfig)
			Expect(err).ToNot(HaveOccurred())
			defer server.Close()

			// The proxy delays every packet by rtt/2.
			serverPort := server.Addr().(*net.UDPAddr).Port
			proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
				RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
				DelayPacket: func(quicproxy.Direction, []byte) time.Duration {
					return rtt / 2
				},
			})
			Expect(err).ToNot(HaveOccurred())
			defer proxy.Close()

			// Dial through the proxy, not the server directly.
			_, err = quic.DialAddr(
				context.Background(),
				fmt.Sprintf("localhost:%d", proxy.LocalPort()),
				getTLSClientConfig(),
				nil,
			)
			Expect(err).To(HaveOccurred())
			var transportErr *quic.TransportError
			Expect(errors.As(err, &transportErr)).To(BeTrue())
			Expect(transportErr.ErrorCode).To(Equal(quic.InvalidToken))
		})
   608  	})
   609  
   610  	Context("GetConfigForClient", func() {
		// The listener's own config disables datagrams, but the config returned
		// by GetConfigForClient enables them. The client must see datagram
		// support, and the callback must have been called with the client's
		// address.
		It("uses the quic.Config returned by GetConfigForClient", func() {
			serverConfig.EnableDatagrams = false
			// Remote address passed to GetConfigForClient.
			var calledFrom net.Addr
			serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
				conf := serverConfig.Clone()
				conf.EnableDatagrams = true
				calledFrom = info.RemoteAddr
				return getQuicConfig(conf), nil
			}
			ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
			Expect(err).ToNot(HaveOccurred())

			// done is closed once the server has accepted the connection.
			done := make(chan struct{})
			go func() {
				defer GinkgoRecover()
				_, err := ln.Accept(context.Background())
				Expect(err).ToNot(HaveOccurred())
				close(done)
			}()

			conn, err := quic.DialAddr(
				context.Background(),
				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
				getTLSClientConfig(),
				getQuicConfig(&quic.Config{EnableDatagrams: true}),
			)
			Expect(err).ToNot(HaveOccurred())
			defer conn.CloseWithError(0, "")
			cs := conn.ConnectionState()
			Expect(cs.SupportsDatagrams).To(BeTrue())
			Eventually(done).Should(BeClosed())
			Expect(ln.Close()).To(Succeed())
			// The callback must have seen the client's local port as remote port.
			Expect(calledFrom.(*net.UDPAddr).Port).To(Equal(conn.LocalAddr().(*net.UDPAddr).Port))
		})
   645  
   646  		It("rejects the connection attempt if GetConfigForClient errors", func() {
   647  			serverConfig.EnableDatagrams = false
   648  			serverConfig.GetConfigForClient = func(info *quic.ClientHelloInfo) (*quic.Config, error) {
   649  				return nil, errors.New("rejected")
   650  			}
   651  			ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
   652  			Expect(err).ToNot(HaveOccurred())
   653  			defer ln.Close()
   654  
   655  			done := make(chan struct{})
   656  			go func() {
   657  				defer GinkgoRecover()
   658  				_, err := ln.Accept(context.Background())
   659  				Expect(err).To(HaveOccurred()) // we don't expect to accept any connection
   660  				close(done)
   661  			}()
   662  
   663  			_, err = quic.DialAddr(
   664  				context.Background(),
   665  				fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
   666  				getTLSClientConfig(),
   667  				getQuicConfig(&quic.Config{EnableDatagrams: true}),
   668  			)
   669  			Expect(err).To(HaveOccurred())
   670  			var transportErr *quic.TransportError
   671  			Expect(errors.As(err, &transportErr)).To(BeTrue())
   672  			Expect(transportErr.ErrorCode).To(Equal(qerr.ConnectionRefused))
   673  		})
   674  	})
   675  
	// Dials a plain UDP socket with a TLS config for which the spec expects
	// ClientHello generation to fail, and checks that the dial returns that
	// error and that no packet ever reaches the socket.
	It("doesn't send any packets when generating the ClientHello fails", func() {
		ln, err := net.ListenUDP("udp", nil)
		Expect(err).ToNot(HaveOccurred())
		// done is closed when the reader goroutine exits.
		done := make(chan struct{})
		// packetChan gets one value for every packet read from the socket.
		packetChan := make(chan struct{})
		go func() {
			defer GinkgoRecover()
			defer close(done)
			for {
				_, _, err := ln.ReadFromUDP(make([]byte, protocol.MaxPacketBufferSize))
				if err != nil {
					// Reading fails once the socket is closed below.
					return
				}
				packetChan <- struct{}{}
			}
		}()

		tlsConf := getTLSClientConfig()
		// An empty protocol name is what triggers the "invalid NextProtos value"
		// error expected below.
		tlsConf.NextProtos = []string{""}
		_, err = quic.DialAddr(
			context.Background(),
			fmt.Sprintf("localhost:%d", ln.LocalAddr().(*net.UDPAddr).Port),
			tlsConf,
			nil,
		)
		Expect(err).To(MatchError(&qerr.TransportError{
			ErrorCode:    qerr.InternalError,
			ErrorMessage: "tls: invalid NextProtos value",
		}))
		Consistently(packetChan).ShouldNot(Receive())
		// Closing the socket stops the reader goroutine.
		ln.Close()
		Eventually(done).Should(BeClosed())
	})
   709  })