github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/internal/handshake/crypto_setup_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"math/big"
    10  	"net"
    11  	"reflect"
    12  	"time"
    13  
    14  	mocktls "github.com/apernet/quic-go/internal/mocks/tls"
    15  	"github.com/apernet/quic-go/internal/protocol"
    16  	"github.com/apernet/quic-go/internal/qerr"
    17  	"github.com/apernet/quic-go/internal/testdata"
    18  	"github.com/apernet/quic-go/internal/utils"
    19  	"github.com/apernet/quic-go/internal/wire"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  	"go.uber.org/mock/gomock"
    24  )
    25  
    26  const (
    27  	typeClientHello      = 1
    28  	typeNewSessionTicket = 4
    29  )
    30  
    31  var _ = Describe("Crypto Setup TLS", func() {
    32  	generateCert := func() tls.Certificate {
    33  		priv, err := rsa.GenerateKey(rand.Reader, 2048)
    34  		Expect(err).ToNot(HaveOccurred())
    35  		tmpl := &x509.Certificate{
    36  			SerialNumber:          big.NewInt(1),
    37  			Subject:               pkix.Name{},
    38  			SignatureAlgorithm:    x509.SHA256WithRSA,
    39  			NotBefore:             time.Now(),
    40  			NotAfter:              time.Now().Add(time.Hour), // valid for an hour
    41  			BasicConstraintsValid: true,
    42  		}
    43  		certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
    44  		Expect(err).ToNot(HaveOccurred())
    45  		return tls.Certificate{
    46  			PrivateKey:  priv,
    47  			Certificate: [][]byte{certDER},
    48  		}
    49  	}
    50  
    51  	var clientConf, serverConf *tls.Config
    52  
    53  	BeforeEach(func() {
    54  		serverConf = testdata.GetTLSConfig()
    55  		serverConf.NextProtos = []string{"crypto-setup"}
    56  		clientConf = &tls.Config{
    57  			ServerName: "localhost",
    58  			RootCAs:    testdata.GetRootCA(),
    59  			NextProtos: []string{"crypto-setup"},
    60  		}
    61  	})
    62  
    63  	It("handles qtls errors occurring before during ClientHello generation", func() {
    64  		tlsConf := testdata.GetTLSConfig()
    65  		tlsConf.InsecureSkipVerify = true
    66  		tlsConf.NextProtos = []string{""}
    67  		cl := NewCryptoSetupClient(
    68  			protocol.ConnectionID{},
    69  			&wire.TransportParameters{},
    70  			tlsConf,
    71  			false,
    72  			&utils.RTTStats{},
    73  			nil,
    74  			utils.DefaultLogger.WithPrefix("client"),
    75  			protocol.Version1,
    76  		)
    77  
    78  		Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
    79  			ErrorCode:    qerr.InternalError,
    80  			ErrorMessage: "tls: invalid NextProtos value",
    81  		}))
    82  	})
    83  
    84  	It("errors when a message is received at the wrong encryption level", func() {
    85  		var token protocol.StatelessResetToken
    86  		server := NewCryptoSetupServer(
    87  			protocol.ConnectionID{},
    88  			&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
    89  			&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
    90  			&wire.TransportParameters{StatelessResetToken: &token},
    91  			testdata.GetTLSConfig(),
    92  			false,
    93  			&utils.RTTStats{},
    94  			nil,
    95  			utils.DefaultLogger.WithPrefix("server"),
    96  			protocol.Version1,
    97  		)
    98  
    99  		Expect(server.StartHandshake()).To(Succeed())
   100  
   101  		fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
   102  		// wrong encryption level
   103  		err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
   104  		Expect(err).To(HaveOccurred())
   105  		Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   106  	})
   107  
   108  	Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
   109  		var (
   110  			local  = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
   111  			remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
   112  		)
   113  
   114  		It("wraps GetCertificate", func() {
   115  			var localAddr, remoteAddr net.Addr
   116  			tlsConf := &tls.Config{
   117  				GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   118  					localAddr = info.Conn.LocalAddr()
   119  					remoteAddr = info.Conn.RemoteAddr()
   120  					cert := generateCert()
   121  					return &cert, nil
   122  				},
   123  			}
   124  			addConnToClientHelloInfo(tlsConf, local, remote)
   125  			_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
   126  			Expect(err).ToNot(HaveOccurred())
   127  			Expect(localAddr).To(Equal(local))
   128  			Expect(remoteAddr).To(Equal(remote))
   129  		})
   130  
   131  		It("wraps GetConfigForClient", func() {
   132  			var localAddr, remoteAddr net.Addr
   133  			tlsConf := &tls.Config{
   134  				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   135  					localAddr = info.Conn.LocalAddr()
   136  					remoteAddr = info.Conn.RemoteAddr()
   137  					return &tls.Config{}, nil
   138  				},
   139  			}
   140  			addConnToClientHelloInfo(tlsConf, local, remote)
   141  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   142  			Expect(err).ToNot(HaveOccurred())
   143  			Expect(localAddr).To(Equal(local))
   144  			Expect(remoteAddr).To(Equal(remote))
   145  			Expect(conf).ToNot(BeNil())
   146  			Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
   147  		})
   148  
   149  		It("wraps GetConfigForClient, recursively", func() {
   150  			var localAddr, remoteAddr net.Addr
   151  			tlsConf := &tls.Config{}
   152  			var innerConf *tls.Config
   153  			getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
   154  				localAddr = info.Conn.LocalAddr()
   155  				remoteAddr = info.Conn.RemoteAddr()
   156  				cert := generateCert()
   157  				return &cert, nil
   158  			}
   159  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   160  				innerConf = tlsConf.Clone()
   161  				// set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config
   162  				innerConf.MaxVersion = tls.VersionTLS12
   163  				innerConf.GetCertificate = getCert
   164  				return innerConf, nil
   165  			}
   166  			addConnToClientHelloInfo(tlsConf, local, remote)
   167  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   168  			Expect(err).ToNot(HaveOccurred())
   169  			Expect(conf).ToNot(BeNil())
   170  			Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
   171  			_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
   172  			Expect(err).ToNot(HaveOccurred())
   173  			Expect(localAddr).To(Equal(local))
   174  			Expect(remoteAddr).To(Equal(remote))
   175  			// make sure that the tls.Config returned by GetConfigForClient isn't modified
   176  			Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
   177  			Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
   178  		})
   179  	})
   180  
   181  	Context("doing the handshake", func() {
   182  		newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
   183  			rttStats := &utils.RTTStats{}
   184  			rttStats.UpdateRTT(rtt, 0, time.Now())
   185  			ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
   186  			return rttStats
   187  		}
   188  
   189  		// The clientEvents and serverEvents contain all events that were not processed by the function,
   190  		// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
   191  		handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
   192  			Expect(client.StartHandshake()).To(Succeed())
   193  			Expect(server.StartHandshake()).To(Succeed())
   194  
   195  			var clientHandshakeComplete, serverHandshakeComplete bool
   196  
   197  			for {
   198  			clientLoop:
   199  				for {
   200  					ev := client.NextEvent()
   201  					//nolint:exhaustive // only need to process a few events
   202  					switch ev.Kind {
   203  					case EventNoEvent:
   204  						break clientLoop
   205  					case EventWriteInitialData:
   206  						if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   207  							serverErr = err
   208  							return
   209  						}
   210  					case EventWriteHandshakeData:
   211  						if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   212  							serverErr = err
   213  							return
   214  						}
   215  					case EventHandshakeComplete:
   216  						clientHandshakeComplete = true
   217  					default:
   218  						clientEvents = append(clientEvents, ev)
   219  					}
   220  				}
   221  
   222  			serverLoop:
   223  				for {
   224  					ev := server.NextEvent()
   225  					//nolint:exhaustive // only need to process a few events
   226  					switch ev.Kind {
   227  					case EventNoEvent:
   228  						break serverLoop
   229  					case EventWriteInitialData:
   230  						if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   231  							clientErr = err
   232  							return
   233  						}
   234  					case EventWriteHandshakeData:
   235  						if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   236  							clientErr = err
   237  							return
   238  						}
   239  					case EventHandshakeComplete:
   240  						serverHandshakeComplete = true
   241  						ticket, err := server.GetSessionTicket()
   242  						Expect(err).ToNot(HaveOccurred())
   243  						if ticket != nil {
   244  							Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
   245  						}
   246  					default:
   247  						serverEvents = append(serverEvents, ev)
   248  					}
   249  				}
   250  
   251  				if clientHandshakeComplete && serverHandshakeComplete {
   252  					break
   253  				}
   254  			}
   255  			return
   256  		}
   257  
   258  		handshakeWithTLSConf := func(
   259  			clientConf, serverConf *tls.Config,
   260  			clientRTTStats, serverRTTStats *utils.RTTStats,
   261  			clientTransportParameters, serverTransportParameters *wire.TransportParameters,
   262  			enable0RTT bool,
   263  		) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
   264  			CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
   265  		) {
   266  			client := NewCryptoSetupClient(
   267  				protocol.ConnectionID{},
   268  				clientTransportParameters,
   269  				clientConf,
   270  				enable0RTT,
   271  				clientRTTStats,
   272  				nil,
   273  				utils.DefaultLogger.WithPrefix("client"),
   274  				protocol.Version1,
   275  			)
   276  
   277  			if serverTransportParameters.StatelessResetToken == nil {
   278  				var token protocol.StatelessResetToken
   279  				serverTransportParameters.StatelessResetToken = &token
   280  			}
   281  			server := NewCryptoSetupServer(
   282  				protocol.ConnectionID{},
   283  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   284  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   285  				serverTransportParameters,
   286  				serverConf,
   287  				enable0RTT,
   288  				serverRTTStats,
   289  				nil,
   290  				utils.DefaultLogger.WithPrefix("server"),
   291  				protocol.Version1,
   292  			)
   293  			cEvents, cErr, sEvents, sErr := handshake(client, server)
   294  			return client, cEvents, cErr, server, sEvents, sErr
   295  		}
   296  
   297  		It("handshakes", func() {
   298  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   299  				clientConf, serverConf,
   300  				&utils.RTTStats{}, &utils.RTTStats{},
   301  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   302  				false,
   303  			)
   304  			Expect(clientErr).ToNot(HaveOccurred())
   305  			Expect(serverErr).ToNot(HaveOccurred())
   306  		})
   307  
   308  		It("performs a HelloRetryRequst", func() {
   309  			serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
   310  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   311  				clientConf, serverConf,
   312  				&utils.RTTStats{}, &utils.RTTStats{},
   313  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   314  				false,
   315  			)
   316  			Expect(clientErr).ToNot(HaveOccurred())
   317  			Expect(serverErr).ToNot(HaveOccurred())
   318  		})
   319  
   320  		It("handshakes with client auth", func() {
   321  			clientConf.Certificates = []tls.Certificate{generateCert()}
   322  			serverConf.ClientAuth = tls.RequireAnyClientCert
   323  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   324  				clientConf, serverConf,
   325  				&utils.RTTStats{}, &utils.RTTStats{},
   326  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   327  				false,
   328  			)
   329  			Expect(clientErr).ToNot(HaveOccurred())
   330  			Expect(serverErr).ToNot(HaveOccurred())
   331  		})
   332  
   333  		It("receives transport parameters", func() {
   334  			cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
   335  			client := NewCryptoSetupClient(
   336  				protocol.ConnectionID{},
   337  				cTransportParameters,
   338  				clientConf,
   339  				false,
   340  				&utils.RTTStats{},
   341  				nil,
   342  				utils.DefaultLogger.WithPrefix("client"),
   343  				protocol.Version1,
   344  			)
   345  
   346  			var token protocol.StatelessResetToken
   347  			sTransportParameters := &wire.TransportParameters{
   348  				MaxIdleTimeout:          1337 * time.Second,
   349  				StatelessResetToken:     &token,
   350  				ActiveConnectionIDLimit: 2,
   351  			}
   352  			server := NewCryptoSetupServer(
   353  				protocol.ConnectionID{},
   354  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   355  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   356  				sTransportParameters,
   357  				serverConf,
   358  				false,
   359  				&utils.RTTStats{},
   360  				nil,
   361  				utils.DefaultLogger.WithPrefix("server"),
   362  				protocol.Version1,
   363  			)
   364  
   365  			clientEvents, cErr, serverEvents, sErr := handshake(client, server)
   366  			Expect(cErr).ToNot(HaveOccurred())
   367  			Expect(sErr).ToNot(HaveOccurred())
   368  			var clientReceivedTransportParameters *wire.TransportParameters
   369  			for _, ev := range clientEvents {
   370  				if ev.Kind == EventReceivedTransportParameters {
   371  					clientReceivedTransportParameters = ev.TransportParameters
   372  				}
   373  			}
   374  			Expect(clientReceivedTransportParameters).ToNot(BeNil())
   375  			Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second))
   376  
   377  			var serverReceivedTransportParameters *wire.TransportParameters
   378  			for _, ev := range serverEvents {
   379  				if ev.Kind == EventReceivedTransportParameters {
   380  					serverReceivedTransportParameters = ev.TransportParameters
   381  				}
   382  			}
   383  			Expect(serverReceivedTransportParameters).ToNot(BeNil())
   384  			Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second))
   385  		})
   386  
   387  		Context("with session tickets", func() {
   388  			It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
   389  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   390  					clientConf, serverConf,
   391  					&utils.RTTStats{}, &utils.RTTStats{},
   392  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   393  					false,
   394  				)
   395  				Expect(clientErr).ToNot(HaveOccurred())
   396  				Expect(serverErr).ToNot(HaveOccurred())
   397  
   398  				// inject an invalid session ticket
   399  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   400  				err := client.HandleMessage(b, protocol.EncryptionHandshake)
   401  				Expect(err).To(HaveOccurred())
   402  				Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   403  			})
   404  
   405  			It("errors when handling the NewSessionTicket fails", func() {
   406  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   407  					clientConf, serverConf,
   408  					&utils.RTTStats{}, &utils.RTTStats{},
   409  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   410  					false,
   411  				)
   412  				Expect(clientErr).ToNot(HaveOccurred())
   413  				Expect(serverErr).ToNot(HaveOccurred())
   414  
   415  				// inject an invalid session ticket
   416  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   417  				err := client.HandleMessage(b, protocol.Encryption1RTT)
   418  				Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   419  				Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
   420  			})
   421  
   422  			It("uses session resumption", func() {
   423  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   424  				var state *tls.ClientSessionState
   425  				receivedSessionTicket := make(chan struct{})
   426  				csc.EXPECT().Get(gomock.Any())
   427  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   428  					state = css
   429  					close(receivedSessionTicket)
   430  				})
   431  				clientConf.ClientSessionCache = csc
   432  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   433  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   434  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   435  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   436  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   437  					clientConf, serverConf,
   438  					clientOrigRTTStats, serverOrigRTTStats,
   439  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   440  					false,
   441  				)
   442  				Expect(clientErr).ToNot(HaveOccurred())
   443  				Expect(serverErr).ToNot(HaveOccurred())
   444  				Eventually(receivedSessionTicket).Should(BeClosed())
   445  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   446  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   447  
   448  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   449  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   450  				clientRTTStats := &utils.RTTStats{}
   451  				serverRTTStats := &utils.RTTStats{}
   452  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   453  					clientConf, serverConf,
   454  					clientRTTStats, serverRTTStats,
   455  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   456  					false,
   457  				)
   458  				Expect(clientErr).ToNot(HaveOccurred())
   459  				Expect(serverErr).ToNot(HaveOccurred())
   460  				Eventually(receivedSessionTicket).Should(BeClosed())
   461  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   462  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   463  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   464  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   465  			})
   466  
   467  			It("doesn't use session resumption if the server disabled it", func() {
   468  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   469  				var state *tls.ClientSessionState
   470  				receivedSessionTicket := make(chan struct{})
   471  				csc.EXPECT().Get(gomock.Any())
   472  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   473  					state = css
   474  					close(receivedSessionTicket)
   475  				})
   476  				clientConf.ClientSessionCache = csc
   477  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   478  					clientConf, serverConf,
   479  					&utils.RTTStats{}, &utils.RTTStats{},
   480  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   481  					false,
   482  				)
   483  				Expect(clientErr).ToNot(HaveOccurred())
   484  				Expect(serverErr).ToNot(HaveOccurred())
   485  				Eventually(receivedSessionTicket).Should(BeClosed())
   486  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   487  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   488  
   489  				serverConf.SessionTicketsDisabled = true
   490  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   491  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   492  					clientConf, serverConf,
   493  					&utils.RTTStats{}, &utils.RTTStats{},
   494  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   495  					false,
   496  				)
   497  				Expect(clientErr).ToNot(HaveOccurred())
   498  				Expect(serverErr).ToNot(HaveOccurred())
   499  				Eventually(receivedSessionTicket).Should(BeClosed())
   500  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   501  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   502  			})
   503  
   504  			It("uses 0-RTT", func() {
   505  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   506  				var state *tls.ClientSessionState
   507  				receivedSessionTicket := make(chan struct{})
   508  				csc.EXPECT().Get(gomock.Any())
   509  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   510  					state = css
   511  					close(receivedSessionTicket)
   512  				})
   513  				clientConf.ClientSessionCache = csc
   514  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   515  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   516  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   517  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   518  				const initialMaxData protocol.ByteCount = 1337
   519  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   520  					clientConf, serverConf,
   521  					clientOrigRTTStats, serverOrigRTTStats,
   522  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   523  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   524  					true,
   525  				)
   526  				Expect(clientErr).ToNot(HaveOccurred())
   527  				Expect(serverErr).ToNot(HaveOccurred())
   528  				Eventually(receivedSessionTicket).Should(BeClosed())
   529  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   530  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   531  
   532  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   533  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   534  
   535  				clientRTTStats := &utils.RTTStats{}
   536  				serverRTTStats := &utils.RTTStats{}
   537  				client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
   538  					clientConf, serverConf,
   539  					clientRTTStats, serverRTTStats,
   540  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   541  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   542  					true,
   543  				)
   544  				Expect(clientErr).ToNot(HaveOccurred())
   545  				Expect(serverErr).ToNot(HaveOccurred())
   546  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   547  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   548  
   549  				var tp *wire.TransportParameters
   550  				var clientReceived0RTTKeys bool
   551  				for _, ev := range clientEvents {
   552  					//nolint:exhaustive // only need to process a few events
   553  					switch ev.Kind {
   554  					case EventRestoredTransportParameters:
   555  						tp = ev.TransportParameters
   556  					case EventReceivedReadKeys:
   557  						clientReceived0RTTKeys = true
   558  					}
   559  				}
   560  				Expect(clientReceived0RTTKeys).To(BeTrue())
   561  				Expect(tp).ToNot(BeNil())
   562  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   563  
   564  				var serverReceived0RTTKeys bool
   565  				for _, ev := range serverEvents {
   566  					//nolint:exhaustive // only need to process a few events
   567  					switch ev.Kind {
   568  					case EventReceivedReadKeys:
   569  						serverReceived0RTTKeys = true
   570  					}
   571  				}
   572  				Expect(serverReceived0RTTKeys).To(BeTrue())
   573  
   574  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   575  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   576  				Expect(server.ConnectionState().Used0RTT).To(BeTrue())
   577  				Expect(client.ConnectionState().Used0RTT).To(BeTrue())
   578  			})
   579  
   580  			It("rejects 0-RTT, when the transport parameters changed", func() {
   581  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   582  				var state *tls.ClientSessionState
   583  				receivedSessionTicket := make(chan struct{})
   584  				csc.EXPECT().Get(gomock.Any())
   585  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   586  					state = css
   587  					close(receivedSessionTicket)
   588  				})
   589  				clientConf.ClientSessionCache = csc
   590  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   591  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   592  				const initialMaxData protocol.ByteCount = 1337
   593  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   594  					clientConf, serverConf,
   595  					clientOrigRTTStats, &utils.RTTStats{},
   596  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   597  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   598  					true,
   599  				)
   600  				Expect(clientErr).ToNot(HaveOccurred())
   601  				Expect(serverErr).ToNot(HaveOccurred())
   602  				Eventually(receivedSessionTicket).Should(BeClosed())
   603  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   604  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   605  
   606  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   607  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   608  
   609  				clientRTTStats := &utils.RTTStats{}
   610  				client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
   611  					clientConf, serverConf,
   612  					clientRTTStats, &utils.RTTStats{},
   613  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   614  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
   615  					true,
   616  				)
   617  				Expect(clientErr).ToNot(HaveOccurred())
   618  				Expect(serverErr).ToNot(HaveOccurred())
   619  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   620  
   621  				var tp *wire.TransportParameters
   622  				var clientReceived0RTTKeys bool
   623  				for _, ev := range clientEvents {
   624  					//nolint:exhaustive // only need to process a few events
   625  					switch ev.Kind {
   626  					case EventRestoredTransportParameters:
   627  						tp = ev.TransportParameters
   628  					case EventReceivedReadKeys:
   629  						clientReceived0RTTKeys = true
   630  					}
   631  				}
   632  				Expect(clientReceived0RTTKeys).To(BeTrue())
   633  				Expect(tp).ToNot(BeNil())
   634  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   635  
   636  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   637  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   638  				Expect(server.ConnectionState().Used0RTT).To(BeFalse())
   639  				Expect(client.ConnectionState().Used0RTT).To(BeFalse())
   640  			})
   641  		})
   642  	})
   643  })