github.com/tumi8/quic-go@v0.37.4-tum/noninternal/handshake/crypto_setup_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"math/big"
    10  	"net"
    11  	"time"
    12  
    13  	mocktls "github.com/tumi8/quic-go/noninternal/mocks/tls"
    14  	"github.com/tumi8/quic-go/noninternal/protocol"
    15  	"github.com/tumi8/quic-go/noninternal/qerr"
    16  	"github.com/tumi8/quic-go/noninternal/testdata"
    17  	"github.com/tumi8/quic-go/noninternal/utils"
    18  	"github.com/tumi8/quic-go/noninternal/wire"
    19  
    20  	"github.com/golang/mock/gomock"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  )
    25  
    26  const (
    27  	typeClientHello      = 1
    28  	typeNewSessionTicket = 4
    29  )
    30  
    31  var _ = Describe("Crypto Setup TLS", func() {
    32  	generateCert := func() tls.Certificate {
    33  		priv, err := rsa.GenerateKey(rand.Reader, 2048)
    34  		Expect(err).ToNot(HaveOccurred())
    35  		tmpl := &x509.Certificate{
    36  			SerialNumber:          big.NewInt(1),
    37  			Subject:               pkix.Name{},
    38  			SignatureAlgorithm:    x509.SHA256WithRSA,
    39  			NotBefore:             time.Now(),
    40  			NotAfter:              time.Now().Add(time.Hour), // valid for an hour
    41  			BasicConstraintsValid: true,
    42  		}
    43  		certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
    44  		Expect(err).ToNot(HaveOccurred())
    45  		return tls.Certificate{
    46  			PrivateKey:  priv,
    47  			Certificate: [][]byte{certDER},
    48  		}
    49  	}
    50  
    51  	var clientConf, serverConf *tls.Config
    52  
    53  	BeforeEach(func() {
    54  		serverConf = testdata.GetTLSConfig()
    55  		serverConf.NextProtos = []string{"crypto-setup"}
    56  		clientConf = &tls.Config{
    57  			ServerName: "localhost",
    58  			RootCAs:    testdata.GetRootCA(),
    59  			NextProtos: []string{"crypto-setup"},
    60  		}
    61  	})
    62  
    63  	It("handles qtls errors occurring before during ClientHello generation", func() {
    64  		tlsConf := testdata.GetTLSConfig()
    65  		tlsConf.InsecureSkipVerify = true
    66  		tlsConf.NextProtos = []string{""}
    67  		cl := NewCryptoSetupClient(
    68  			protocol.ConnectionID{},
    69  			&wire.TransportParameters{},
    70  			tlsConf,
    71  			false,
    72  			&utils.RTTStats{},
    73  			nil,
    74  			utils.DefaultLogger.WithPrefix("client"),
    75  			protocol.Version1,
    76  		)
    77  
    78  		Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
    79  			ErrorCode:    qerr.InternalError,
    80  			ErrorMessage: "tls: invalid NextProtos value",
    81  		}))
    82  	})
    83  
    84  	It("errors when a message is received at the wrong encryption level", func() {
    85  		var token protocol.StatelessResetToken
    86  		server := NewCryptoSetupServer(
    87  			protocol.ConnectionID{},
    88  			&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
    89  			&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
    90  			&wire.TransportParameters{StatelessResetToken: &token},
    91  			testdata.GetTLSConfig(),
    92  			false,
    93  			&utils.RTTStats{},
    94  			nil,
    95  			utils.DefaultLogger.WithPrefix("server"),
    96  			protocol.Version1,
    97  		)
    98  
    99  		Expect(server.StartHandshake()).To(Succeed())
   100  
   101  		fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
   102  		// wrong encryption level
   103  		err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
   104  		Expect(err).To(HaveOccurred())
   105  		Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   106  	})
   107  
   108  	Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
   109  		var (
   110  			local  = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
   111  			remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
   112  		)
   113  
   114  		It("wraps GetCertificate", func() {
   115  			var localAddr, remoteAddr net.Addr
   116  			tlsConf := &tls.Config{
   117  				GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   118  					localAddr = info.Conn.LocalAddr()
   119  					remoteAddr = info.Conn.RemoteAddr()
   120  					cert := generateCert()
   121  					return &cert, nil
   122  				},
   123  			}
   124  			addConnToClientHelloInfo(tlsConf, local, remote)
   125  			_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
   126  			Expect(err).ToNot(HaveOccurred())
   127  			Expect(localAddr).To(Equal(local))
   128  			Expect(remoteAddr).To(Equal(remote))
   129  		})
   130  
   131  		It("wraps GetConfigForClient", func() {
   132  			var localAddr, remoteAddr net.Addr
   133  			tlsConf := &tls.Config{
   134  				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   135  					localAddr = info.Conn.LocalAddr()
   136  					remoteAddr = info.Conn.RemoteAddr()
   137  					return &tls.Config{}, nil
   138  				},
   139  			}
   140  			addConnToClientHelloInfo(tlsConf, local, remote)
   141  			_, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   142  			Expect(err).ToNot(HaveOccurred())
   143  			Expect(localAddr).To(Equal(local))
   144  			Expect(remoteAddr).To(Equal(remote))
   145  		})
   146  
   147  		It("wraps GetConfigForClient, recursively", func() {
   148  			var localAddr, remoteAddr net.Addr
   149  			tlsConf := &tls.Config{}
   150  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   151  				conf := tlsConf.Clone()
   152  				conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   153  					localAddr = info.Conn.LocalAddr()
   154  					remoteAddr = info.Conn.RemoteAddr()
   155  					cert := generateCert()
   156  					return &cert, nil
   157  				}
   158  				return conf, nil
   159  			}
   160  			addConnToClientHelloInfo(tlsConf, local, remote)
   161  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   162  			Expect(err).ToNot(HaveOccurred())
   163  			_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
   164  			Expect(err).ToNot(HaveOccurred())
   165  			Expect(localAddr).To(Equal(local))
   166  			Expect(remoteAddr).To(Equal(remote))
   167  		})
   168  	})
   169  
   170  	Context("doing the handshake", func() {
   171  		newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
   172  			rttStats := &utils.RTTStats{}
   173  			rttStats.UpdateRTT(rtt, 0, time.Now())
   174  			ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
   175  			return rttStats
   176  		}
   177  
   178  		// The clientEvents and serverEvents contain all events that were not processed by the function,
   179  		// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
   180  		handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
   181  			Expect(client.StartHandshake()).To(Succeed())
   182  			Expect(server.StartHandshake()).To(Succeed())
   183  
   184  			var clientHandshakeComplete, serverHandshakeComplete bool
   185  
   186  			for {
   187  			clientLoop:
   188  				for {
   189  					ev := client.NextEvent()
   190  					//nolint:exhaustive // only need to process a few events
   191  					switch ev.Kind {
   192  					case EventNoEvent:
   193  						break clientLoop
   194  					case EventWriteInitialData:
   195  						if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   196  							serverErr = err
   197  							return
   198  						}
   199  					case EventWriteHandshakeData:
   200  						if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   201  							serverErr = err
   202  							return
   203  						}
   204  					case EventHandshakeComplete:
   205  						clientHandshakeComplete = true
   206  					default:
   207  						clientEvents = append(clientEvents, ev)
   208  					}
   209  				}
   210  
   211  			serverLoop:
   212  				for {
   213  					ev := server.NextEvent()
   214  					//nolint:exhaustive // only need to process a few events
   215  					switch ev.Kind {
   216  					case EventNoEvent:
   217  						break serverLoop
   218  					case EventWriteInitialData:
   219  						if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   220  							clientErr = err
   221  							return
   222  						}
   223  					case EventWriteHandshakeData:
   224  						if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   225  							clientErr = err
   226  							return
   227  						}
   228  					case EventHandshakeComplete:
   229  						serverHandshakeComplete = true
   230  						ticket, err := server.GetSessionTicket()
   231  						Expect(err).ToNot(HaveOccurred())
   232  						if ticket != nil {
   233  							Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
   234  						}
   235  					default:
   236  						serverEvents = append(serverEvents, ev)
   237  					}
   238  				}
   239  
   240  				if clientHandshakeComplete && serverHandshakeComplete {
   241  					break
   242  				}
   243  			}
   244  			return
   245  		}
   246  
   247  		handshakeWithTLSConf := func(
   248  			clientConf, serverConf *tls.Config,
   249  			clientRTTStats, serverRTTStats *utils.RTTStats,
   250  			clientTransportParameters, serverTransportParameters *wire.TransportParameters,
   251  			enable0RTT bool,
   252  		) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
   253  			CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
   254  		) {
   255  			client := NewCryptoSetupClient(
   256  				protocol.ConnectionID{},
   257  				clientTransportParameters,
   258  				clientConf,
   259  				enable0RTT,
   260  				clientRTTStats,
   261  				nil,
   262  				utils.DefaultLogger.WithPrefix("client"),
   263  				protocol.Version1,
   264  			)
   265  
   266  			if serverTransportParameters.StatelessResetToken == nil {
   267  				var token protocol.StatelessResetToken
   268  				serverTransportParameters.StatelessResetToken = &token
   269  			}
   270  			server := NewCryptoSetupServer(
   271  				protocol.ConnectionID{},
   272  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   273  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   274  				serverTransportParameters,
   275  				serverConf,
   276  				enable0RTT,
   277  				serverRTTStats,
   278  				nil,
   279  				utils.DefaultLogger.WithPrefix("server"),
   280  				protocol.Version1,
   281  			)
   282  			cEvents, cErr, sEvents, sErr := handshake(client, server)
   283  			return client, cEvents, cErr, server, sEvents, sErr
   284  		}
   285  
   286  		It("handshakes", func() {
   287  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   288  				clientConf, serverConf,
   289  				&utils.RTTStats{}, &utils.RTTStats{},
   290  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   291  				false,
   292  			)
   293  			Expect(clientErr).ToNot(HaveOccurred())
   294  			Expect(serverErr).ToNot(HaveOccurred())
   295  		})
   296  
   297  		It("performs a HelloRetryRequst", func() {
   298  			serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
   299  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   300  				clientConf, serverConf,
   301  				&utils.RTTStats{}, &utils.RTTStats{},
   302  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   303  				false,
   304  			)
   305  			Expect(clientErr).ToNot(HaveOccurred())
   306  			Expect(serverErr).ToNot(HaveOccurred())
   307  		})
   308  
   309  		It("handshakes with client auth", func() {
   310  			clientConf.Certificates = []tls.Certificate{generateCert()}
   311  			serverConf.ClientAuth = tls.RequireAnyClientCert
   312  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   313  				clientConf, serverConf,
   314  				&utils.RTTStats{}, &utils.RTTStats{},
   315  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   316  				false,
   317  			)
   318  			Expect(clientErr).ToNot(HaveOccurred())
   319  			Expect(serverErr).ToNot(HaveOccurred())
   320  		})
   321  
   322  		It("receives transport parameters", func() {
   323  			cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
   324  			client := NewCryptoSetupClient(
   325  				protocol.ConnectionID{},
   326  				cTransportParameters,
   327  				clientConf,
   328  				false,
   329  				&utils.RTTStats{},
   330  				nil,
   331  				utils.DefaultLogger.WithPrefix("client"),
   332  				protocol.Version1,
   333  			)
   334  
   335  			var token protocol.StatelessResetToken
   336  			sTransportParameters := &wire.TransportParameters{
   337  				MaxIdleTimeout:          1337 * time.Second,
   338  				StatelessResetToken:     &token,
   339  				ActiveConnectionIDLimit: 2,
   340  			}
   341  			server := NewCryptoSetupServer(
   342  				protocol.ConnectionID{},
   343  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   344  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   345  				sTransportParameters,
   346  				serverConf,
   347  				false,
   348  				&utils.RTTStats{},
   349  				nil,
   350  				utils.DefaultLogger.WithPrefix("server"),
   351  				protocol.Version1,
   352  			)
   353  
   354  			clientEvents, cErr, serverEvents, sErr := handshake(client, server)
   355  			Expect(cErr).ToNot(HaveOccurred())
   356  			Expect(sErr).ToNot(HaveOccurred())
   357  			var clientReceivedTransportParameters *wire.TransportParameters
   358  			for _, ev := range clientEvents {
   359  				if ev.Kind == EventReceivedTransportParameters {
   360  					clientReceivedTransportParameters = ev.TransportParameters
   361  				}
   362  			}
   363  			Expect(clientReceivedTransportParameters).ToNot(BeNil())
   364  			Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second))
   365  
   366  			var serverReceivedTransportParameters *wire.TransportParameters
   367  			for _, ev := range serverEvents {
   368  				if ev.Kind == EventReceivedTransportParameters {
   369  					serverReceivedTransportParameters = ev.TransportParameters
   370  				}
   371  			}
   372  			Expect(serverReceivedTransportParameters).ToNot(BeNil())
   373  			Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second))
   374  		})
   375  
   376  		Context("with session tickets", func() {
   377  			It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
   378  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   379  					clientConf, serverConf,
   380  					&utils.RTTStats{}, &utils.RTTStats{},
   381  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   382  					false,
   383  				)
   384  				Expect(clientErr).ToNot(HaveOccurred())
   385  				Expect(serverErr).ToNot(HaveOccurred())
   386  
   387  				// inject an invalid session ticket
   388  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   389  				err := client.HandleMessage(b, protocol.EncryptionHandshake)
   390  				Expect(err).To(HaveOccurred())
   391  				Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   392  			})
   393  
   394  			It("errors when handling the NewSessionTicket fails", func() {
   395  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   396  					clientConf, serverConf,
   397  					&utils.RTTStats{}, &utils.RTTStats{},
   398  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   399  					false,
   400  				)
   401  				Expect(clientErr).ToNot(HaveOccurred())
   402  				Expect(serverErr).ToNot(HaveOccurred())
   403  
   404  				// inject an invalid session ticket
   405  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   406  				err := client.HandleMessage(b, protocol.Encryption1RTT)
   407  				Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   408  				Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
   409  			})
   410  
   411  			It("uses session resumption", func() {
   412  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   413  				var state *tls.ClientSessionState
   414  				receivedSessionTicket := make(chan struct{})
   415  				csc.EXPECT().Get(gomock.Any())
   416  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   417  					state = css
   418  					close(receivedSessionTicket)
   419  				})
   420  				clientConf.ClientSessionCache = csc
   421  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   422  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   423  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   424  					clientConf, serverConf,
   425  					clientOrigRTTStats, &utils.RTTStats{},
   426  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   427  					false,
   428  				)
   429  				Expect(clientErr).ToNot(HaveOccurred())
   430  				Expect(serverErr).ToNot(HaveOccurred())
   431  				Eventually(receivedSessionTicket).Should(BeClosed())
   432  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   433  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   434  
   435  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   436  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   437  				clientRTTStats := &utils.RTTStats{}
   438  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   439  					clientConf, serverConf,
   440  					clientRTTStats, &utils.RTTStats{},
   441  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   442  					false,
   443  				)
   444  				Expect(clientErr).ToNot(HaveOccurred())
   445  				Expect(serverErr).ToNot(HaveOccurred())
   446  				Eventually(receivedSessionTicket).Should(BeClosed())
   447  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   448  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   449  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   450  			})
   451  
   452  			It("doesn't use session resumption if the server disabled it", func() {
   453  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   454  				var state *tls.ClientSessionState
   455  				receivedSessionTicket := make(chan struct{})
   456  				csc.EXPECT().Get(gomock.Any())
   457  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   458  					state = css
   459  					close(receivedSessionTicket)
   460  				})
   461  				clientConf.ClientSessionCache = csc
   462  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   463  					clientConf, serverConf,
   464  					&utils.RTTStats{}, &utils.RTTStats{},
   465  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   466  					false,
   467  				)
   468  				Expect(clientErr).ToNot(HaveOccurred())
   469  				Expect(serverErr).ToNot(HaveOccurred())
   470  				Eventually(receivedSessionTicket).Should(BeClosed())
   471  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   472  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   473  
   474  				serverConf.SessionTicketsDisabled = true
   475  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   476  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   477  					clientConf, serverConf,
   478  					&utils.RTTStats{}, &utils.RTTStats{},
   479  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   480  					false,
   481  				)
   482  				Expect(clientErr).ToNot(HaveOccurred())
   483  				Expect(serverErr).ToNot(HaveOccurred())
   484  				Eventually(receivedSessionTicket).Should(BeClosed())
   485  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   486  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   487  			})
   488  
   489  			It("uses 0-RTT", func() {
   490  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   491  				var state *tls.ClientSessionState
   492  				receivedSessionTicket := make(chan struct{})
   493  				csc.EXPECT().Get(gomock.Any())
   494  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   495  					state = css
   496  					close(receivedSessionTicket)
   497  				})
   498  				clientConf.ClientSessionCache = csc
   499  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   500  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   501  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   502  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   503  				const initialMaxData protocol.ByteCount = 1337
   504  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   505  					clientConf, serverConf,
   506  					clientOrigRTTStats, serverOrigRTTStats,
   507  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   508  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   509  					true,
   510  				)
   511  				Expect(clientErr).ToNot(HaveOccurred())
   512  				Expect(serverErr).ToNot(HaveOccurred())
   513  				Eventually(receivedSessionTicket).Should(BeClosed())
   514  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   515  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   516  
   517  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   518  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   519  
   520  				clientRTTStats := &utils.RTTStats{}
   521  				serverRTTStats := &utils.RTTStats{}
   522  				client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
   523  					clientConf, serverConf,
   524  					clientRTTStats, serverRTTStats,
   525  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   526  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   527  					true,
   528  				)
   529  				Expect(clientErr).ToNot(HaveOccurred())
   530  				Expect(serverErr).ToNot(HaveOccurred())
   531  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   532  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   533  
   534  				var tp *wire.TransportParameters
   535  				var clientReceived0RTTKeys bool
   536  				for _, ev := range clientEvents {
   537  					//nolint:exhaustive // only need to process a few events
   538  					switch ev.Kind {
   539  					case EventRestoredTransportParameters:
   540  						tp = ev.TransportParameters
   541  					case EventReceivedReadKeys:
   542  						clientReceived0RTTKeys = true
   543  					}
   544  				}
   545  				Expect(clientReceived0RTTKeys).To(BeTrue())
   546  				Expect(tp).ToNot(BeNil())
   547  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   548  
   549  				var serverReceived0RTTKeys bool
   550  				for _, ev := range serverEvents {
   551  					//nolint:exhaustive // only need to process a few events
   552  					switch ev.Kind {
   553  					case EventReceivedReadKeys:
   554  						serverReceived0RTTKeys = true
   555  					}
   556  				}
   557  				Expect(serverReceived0RTTKeys).To(BeTrue())
   558  
   559  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   560  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   561  				Expect(server.ConnectionState().Used0RTT).To(BeTrue())
   562  				Expect(client.ConnectionState().Used0RTT).To(BeTrue())
   563  			})
   564  
   565  			It("rejects 0-RTT, when the transport parameters changed", func() {
   566  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   567  				var state *tls.ClientSessionState
   568  				receivedSessionTicket := make(chan struct{})
   569  				csc.EXPECT().Get(gomock.Any())
   570  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   571  					state = css
   572  					close(receivedSessionTicket)
   573  				})
   574  				clientConf.ClientSessionCache = csc
   575  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   576  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   577  				const initialMaxData protocol.ByteCount = 1337
   578  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   579  					clientConf, serverConf,
   580  					clientOrigRTTStats, &utils.RTTStats{},
   581  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   582  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   583  					true,
   584  				)
   585  				Expect(clientErr).ToNot(HaveOccurred())
   586  				Expect(serverErr).ToNot(HaveOccurred())
   587  				Eventually(receivedSessionTicket).Should(BeClosed())
   588  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   589  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   590  
   591  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   592  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   593  
   594  				clientRTTStats := &utils.RTTStats{}
   595  				client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
   596  					clientConf, serverConf,
   597  					clientRTTStats, &utils.RTTStats{},
   598  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   599  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
   600  					true,
   601  				)
   602  				Expect(clientErr).ToNot(HaveOccurred())
   603  				Expect(serverErr).ToNot(HaveOccurred())
   604  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   605  
   606  				var tp *wire.TransportParameters
   607  				var clientReceived0RTTKeys bool
   608  				for _, ev := range clientEvents {
   609  					//nolint:exhaustive // only need to process a few events
   610  					switch ev.Kind {
   611  					case EventRestoredTransportParameters:
   612  						tp = ev.TransportParameters
   613  					case EventReceivedReadKeys:
   614  						clientReceived0RTTKeys = true
   615  					}
   616  				}
   617  				Expect(clientReceived0RTTKeys).To(BeTrue())
   618  				Expect(tp).ToNot(BeNil())
   619  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   620  
   621  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   622  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   623  				Expect(server.ConnectionState().Used0RTT).To(BeFalse())
   624  				Expect(client.ConnectionState().Used0RTT).To(BeFalse())
   625  			})
   626  		})
   627  	})
   628  })