github.com/MerlinKodo/quic-go@v0.39.2/internal/handshake/crypto_setup_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"math/big"
    10  	"net"
    11  	"runtime"
    12  	"strings"
    13  	"time"
    14  
    15  	mocktls "github.com/MerlinKodo/quic-go/internal/mocks/tls"
    16  	"github.com/MerlinKodo/quic-go/internal/protocol"
    17  	"github.com/MerlinKodo/quic-go/internal/qerr"
    18  	"github.com/MerlinKodo/quic-go/internal/testdata"
    19  	"github.com/MerlinKodo/quic-go/internal/utils"
    20  	"github.com/MerlinKodo/quic-go/internal/wire"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  	"go.uber.org/mock/gomock"
    25  )
    26  
    27  const (
    28  	typeClientHello      = 1
    29  	typeNewSessionTicket = 4
    30  )
    31  
    32  var _ = Describe("Crypto Setup TLS", func() {
    33  	generateCert := func() tls.Certificate {
    34  		priv, err := rsa.GenerateKey(rand.Reader, 2048)
    35  		Expect(err).ToNot(HaveOccurred())
    36  		tmpl := &x509.Certificate{
    37  			SerialNumber:          big.NewInt(1),
    38  			Subject:               pkix.Name{},
    39  			SignatureAlgorithm:    x509.SHA256WithRSA,
    40  			NotBefore:             time.Now(),
    41  			NotAfter:              time.Now().Add(time.Hour), // valid for an hour
    42  			BasicConstraintsValid: true,
    43  		}
    44  		certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
    45  		Expect(err).ToNot(HaveOccurred())
    46  		return tls.Certificate{
    47  			PrivateKey:  priv,
    48  			Certificate: [][]byte{certDER},
    49  		}
    50  	}
    51  
    52  	var clientConf, serverConf *tls.Config
    53  
    54  	BeforeEach(func() {
    55  		serverConf = testdata.GetTLSConfig()
    56  		serverConf.NextProtos = []string{"crypto-setup"}
    57  		clientConf = &tls.Config{
    58  			ServerName: "localhost",
    59  			RootCAs:    testdata.GetRootCA(),
    60  			NextProtos: []string{"crypto-setup"},
    61  		}
    62  	})
    63  
    64  	It("handles qtls errors occurring before during ClientHello generation", func() {
    65  		tlsConf := testdata.GetTLSConfig()
    66  		tlsConf.InsecureSkipVerify = true
    67  		tlsConf.NextProtos = []string{""}
    68  		cl := NewCryptoSetupClient(
    69  			protocol.ConnectionID{},
    70  			&wire.TransportParameters{},
    71  			tlsConf,
    72  			false,
    73  			&utils.RTTStats{},
    74  			nil,
    75  			utils.DefaultLogger.WithPrefix("client"),
    76  			protocol.Version1,
    77  		)
    78  
    79  		Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
    80  			ErrorCode:    qerr.InternalError,
    81  			ErrorMessage: "tls: invalid NextProtos value",
    82  		}))
    83  	})
    84  
    85  	It("errors when a message is received at the wrong encryption level", func() {
    86  		var token protocol.StatelessResetToken
    87  		server := NewCryptoSetupServer(
    88  			protocol.ConnectionID{},
    89  			&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
    90  			&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
    91  			&wire.TransportParameters{StatelessResetToken: &token},
    92  			testdata.GetTLSConfig(),
    93  			false,
    94  			&utils.RTTStats{},
    95  			nil,
    96  			utils.DefaultLogger.WithPrefix("server"),
    97  			protocol.Version1,
    98  		)
    99  
   100  		Expect(server.StartHandshake()).To(Succeed())
   101  
   102  		fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
   103  		// wrong encryption level
   104  		err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
   105  		Expect(err).To(HaveOccurred())
   106  		Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   107  	})
   108  
   109  	Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
   110  		var (
   111  			local  = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
   112  			remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
   113  		)
   114  
   115  		It("wraps GetCertificate", func() {
   116  			var localAddr, remoteAddr net.Addr
   117  			tlsConf := &tls.Config{
   118  				GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   119  					localAddr = info.Conn.LocalAddr()
   120  					remoteAddr = info.Conn.RemoteAddr()
   121  					cert := generateCert()
   122  					return &cert, nil
   123  				},
   124  			}
   125  			addConnToClientHelloInfo(tlsConf, local, remote)
   126  			_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
   127  			Expect(err).ToNot(HaveOccurred())
   128  			Expect(localAddr).To(Equal(local))
   129  			Expect(remoteAddr).To(Equal(remote))
   130  		})
   131  
   132  		It("wraps GetConfigForClient", func() {
   133  			var localAddr, remoteAddr net.Addr
   134  			tlsConf := &tls.Config{
   135  				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   136  					localAddr = info.Conn.LocalAddr()
   137  					remoteAddr = info.Conn.RemoteAddr()
   138  					return &tls.Config{}, nil
   139  				},
   140  			}
   141  			addConnToClientHelloInfo(tlsConf, local, remote)
   142  			_, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   143  			Expect(err).ToNot(HaveOccurred())
   144  			Expect(localAddr).To(Equal(local))
   145  			Expect(remoteAddr).To(Equal(remote))
   146  		})
   147  
   148  		It("wraps GetConfigForClient, recursively", func() {
   149  			var localAddr, remoteAddr net.Addr
   150  			tlsConf := &tls.Config{}
   151  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   152  				conf := tlsConf.Clone()
   153  				conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   154  					localAddr = info.Conn.LocalAddr()
   155  					remoteAddr = info.Conn.RemoteAddr()
   156  					cert := generateCert()
   157  					return &cert, nil
   158  				}
   159  				return conf, nil
   160  			}
   161  			addConnToClientHelloInfo(tlsConf, local, remote)
   162  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   163  			Expect(err).ToNot(HaveOccurred())
   164  			_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
   165  			Expect(err).ToNot(HaveOccurred())
   166  			Expect(localAddr).To(Equal(local))
   167  			Expect(remoteAddr).To(Equal(remote))
   168  		})
   169  	})
   170  
   171  	Context("doing the handshake", func() {
   172  		newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
   173  			rttStats := &utils.RTTStats{}
   174  			rttStats.UpdateRTT(rtt, 0, time.Now())
   175  			ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
   176  			return rttStats
   177  		}
   178  
   179  		// The clientEvents and serverEvents contain all events that were not processed by the function,
   180  		// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
   181  		handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
   182  			Expect(client.StartHandshake()).To(Succeed())
   183  			Expect(server.StartHandshake()).To(Succeed())
   184  
   185  			var clientHandshakeComplete, serverHandshakeComplete bool
   186  
   187  			for {
   188  			clientLoop:
   189  				for {
   190  					ev := client.NextEvent()
   191  					//nolint:exhaustive // only need to process a few events
   192  					switch ev.Kind {
   193  					case EventNoEvent:
   194  						break clientLoop
   195  					case EventWriteInitialData:
   196  						if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   197  							serverErr = err
   198  							return
   199  						}
   200  					case EventWriteHandshakeData:
   201  						if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   202  							serverErr = err
   203  							return
   204  						}
   205  					case EventHandshakeComplete:
   206  						clientHandshakeComplete = true
   207  					default:
   208  						clientEvents = append(clientEvents, ev)
   209  					}
   210  				}
   211  
   212  			serverLoop:
   213  				for {
   214  					ev := server.NextEvent()
   215  					//nolint:exhaustive // only need to process a few events
   216  					switch ev.Kind {
   217  					case EventNoEvent:
   218  						break serverLoop
   219  					case EventWriteInitialData:
   220  						if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   221  							clientErr = err
   222  							return
   223  						}
   224  					case EventWriteHandshakeData:
   225  						if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   226  							clientErr = err
   227  							return
   228  						}
   229  					case EventHandshakeComplete:
   230  						serverHandshakeComplete = true
   231  						ticket, err := server.GetSessionTicket()
   232  						Expect(err).ToNot(HaveOccurred())
   233  						if ticket != nil {
   234  							Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
   235  						}
   236  					default:
   237  						serverEvents = append(serverEvents, ev)
   238  					}
   239  				}
   240  
   241  				if clientHandshakeComplete && serverHandshakeComplete {
   242  					break
   243  				}
   244  			}
   245  			return
   246  		}
   247  
   248  		handshakeWithTLSConf := func(
   249  			clientConf, serverConf *tls.Config,
   250  			clientRTTStats, serverRTTStats *utils.RTTStats,
   251  			clientTransportParameters, serverTransportParameters *wire.TransportParameters,
   252  			enable0RTT bool,
   253  		) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
   254  			CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
   255  		) {
   256  			client := NewCryptoSetupClient(
   257  				protocol.ConnectionID{},
   258  				clientTransportParameters,
   259  				clientConf,
   260  				enable0RTT,
   261  				clientRTTStats,
   262  				nil,
   263  				utils.DefaultLogger.WithPrefix("client"),
   264  				protocol.Version1,
   265  			)
   266  
   267  			if serverTransportParameters.StatelessResetToken == nil {
   268  				var token protocol.StatelessResetToken
   269  				serverTransportParameters.StatelessResetToken = &token
   270  			}
   271  			server := NewCryptoSetupServer(
   272  				protocol.ConnectionID{},
   273  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   274  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   275  				serverTransportParameters,
   276  				serverConf,
   277  				enable0RTT,
   278  				serverRTTStats,
   279  				nil,
   280  				utils.DefaultLogger.WithPrefix("server"),
   281  				protocol.Version1,
   282  			)
   283  			cEvents, cErr, sEvents, sErr := handshake(client, server)
   284  			return client, cEvents, cErr, server, sEvents, sErr
   285  		}
   286  
   287  		It("handshakes", func() {
   288  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   289  				clientConf, serverConf,
   290  				&utils.RTTStats{}, &utils.RTTStats{},
   291  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   292  				false,
   293  			)
   294  			Expect(clientErr).ToNot(HaveOccurred())
   295  			Expect(serverErr).ToNot(HaveOccurred())
   296  		})
   297  
   298  		It("performs a HelloRetryRequst", func() {
   299  			serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
   300  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   301  				clientConf, serverConf,
   302  				&utils.RTTStats{}, &utils.RTTStats{},
   303  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   304  				false,
   305  			)
   306  			Expect(clientErr).ToNot(HaveOccurred())
   307  			Expect(serverErr).ToNot(HaveOccurred())
   308  		})
   309  
   310  		It("handshakes with client auth", func() {
   311  			clientConf.Certificates = []tls.Certificate{generateCert()}
   312  			serverConf.ClientAuth = tls.RequireAnyClientCert
   313  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   314  				clientConf, serverConf,
   315  				&utils.RTTStats{}, &utils.RTTStats{},
   316  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   317  				false,
   318  			)
   319  			Expect(clientErr).ToNot(HaveOccurred())
   320  			Expect(serverErr).ToNot(HaveOccurred())
   321  		})
   322  
   323  		It("receives transport parameters", func() {
   324  			cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
   325  			client := NewCryptoSetupClient(
   326  				protocol.ConnectionID{},
   327  				cTransportParameters,
   328  				clientConf,
   329  				false,
   330  				&utils.RTTStats{},
   331  				nil,
   332  				utils.DefaultLogger.WithPrefix("client"),
   333  				protocol.Version1,
   334  			)
   335  
   336  			var token protocol.StatelessResetToken
   337  			sTransportParameters := &wire.TransportParameters{
   338  				MaxIdleTimeout:          1337 * time.Second,
   339  				StatelessResetToken:     &token,
   340  				ActiveConnectionIDLimit: 2,
   341  			}
   342  			server := NewCryptoSetupServer(
   343  				protocol.ConnectionID{},
   344  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   345  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   346  				sTransportParameters,
   347  				serverConf,
   348  				false,
   349  				&utils.RTTStats{},
   350  				nil,
   351  				utils.DefaultLogger.WithPrefix("server"),
   352  				protocol.Version1,
   353  			)
   354  
   355  			clientEvents, cErr, serverEvents, sErr := handshake(client, server)
   356  			Expect(cErr).ToNot(HaveOccurred())
   357  			Expect(sErr).ToNot(HaveOccurred())
   358  			var clientReceivedTransportParameters *wire.TransportParameters
   359  			for _, ev := range clientEvents {
   360  				if ev.Kind == EventReceivedTransportParameters {
   361  					clientReceivedTransportParameters = ev.TransportParameters
   362  				}
   363  			}
   364  			Expect(clientReceivedTransportParameters).ToNot(BeNil())
   365  			Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second))
   366  
   367  			var serverReceivedTransportParameters *wire.TransportParameters
   368  			for _, ev := range serverEvents {
   369  				if ev.Kind == EventReceivedTransportParameters {
   370  					serverReceivedTransportParameters = ev.TransportParameters
   371  				}
   372  			}
   373  			Expect(serverReceivedTransportParameters).ToNot(BeNil())
   374  			Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second))
   375  		})
   376  
   377  		Context("with session tickets", func() {
   378  			It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
   379  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   380  					clientConf, serverConf,
   381  					&utils.RTTStats{}, &utils.RTTStats{},
   382  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   383  					false,
   384  				)
   385  				Expect(clientErr).ToNot(HaveOccurred())
   386  				Expect(serverErr).ToNot(HaveOccurred())
   387  
   388  				// inject an invalid session ticket
   389  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   390  				err := client.HandleMessage(b, protocol.EncryptionHandshake)
   391  				Expect(err).To(HaveOccurred())
   392  				Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   393  			})
   394  
   395  			It("errors when handling the NewSessionTicket fails", func() {
   396  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   397  					clientConf, serverConf,
   398  					&utils.RTTStats{}, &utils.RTTStats{},
   399  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   400  					false,
   401  				)
   402  				Expect(clientErr).ToNot(HaveOccurred())
   403  				Expect(serverErr).ToNot(HaveOccurred())
   404  
   405  				// inject an invalid session ticket
   406  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   407  				err := client.HandleMessage(b, protocol.Encryption1RTT)
   408  				Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   409  				Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
   410  			})
   411  
   412  			It("uses session resumption", func() {
   413  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   414  				var state *tls.ClientSessionState
   415  				receivedSessionTicket := make(chan struct{})
   416  				csc.EXPECT().Get(gomock.Any())
   417  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   418  					state = css
   419  					close(receivedSessionTicket)
   420  				})
   421  				clientConf.ClientSessionCache = csc
   422  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   423  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   424  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   425  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   426  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   427  					clientConf, serverConf,
   428  					clientOrigRTTStats, serverOrigRTTStats,
   429  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   430  					false,
   431  				)
   432  				Expect(clientErr).ToNot(HaveOccurred())
   433  				Expect(serverErr).ToNot(HaveOccurred())
   434  				Eventually(receivedSessionTicket).Should(BeClosed())
   435  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   436  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   437  
   438  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   439  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   440  				clientRTTStats := &utils.RTTStats{}
   441  				serverRTTStats := &utils.RTTStats{}
   442  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   443  					clientConf, serverConf,
   444  					clientRTTStats, serverRTTStats,
   445  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   446  					false,
   447  				)
   448  				Expect(clientErr).ToNot(HaveOccurred())
   449  				Expect(serverErr).ToNot(HaveOccurred())
   450  				Eventually(receivedSessionTicket).Should(BeClosed())
   451  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   452  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   453  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   454  				if !strings.Contains(runtime.Version(), "go1.20") {
   455  					Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   456  				}
   457  			})
   458  
   459  			It("doesn't use session resumption if the server disabled it", func() {
   460  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   461  				var state *tls.ClientSessionState
   462  				receivedSessionTicket := make(chan struct{})
   463  				csc.EXPECT().Get(gomock.Any())
   464  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   465  					state = css
   466  					close(receivedSessionTicket)
   467  				})
   468  				clientConf.ClientSessionCache = csc
   469  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   470  					clientConf, serverConf,
   471  					&utils.RTTStats{}, &utils.RTTStats{},
   472  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   473  					false,
   474  				)
   475  				Expect(clientErr).ToNot(HaveOccurred())
   476  				Expect(serverErr).ToNot(HaveOccurred())
   477  				Eventually(receivedSessionTicket).Should(BeClosed())
   478  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   479  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   480  
   481  				serverConf.SessionTicketsDisabled = true
   482  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   483  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   484  					clientConf, serverConf,
   485  					&utils.RTTStats{}, &utils.RTTStats{},
   486  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   487  					false,
   488  				)
   489  				Expect(clientErr).ToNot(HaveOccurred())
   490  				Expect(serverErr).ToNot(HaveOccurred())
   491  				Eventually(receivedSessionTicket).Should(BeClosed())
   492  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   493  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   494  			})
   495  
   496  			It("uses 0-RTT", func() {
   497  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   498  				var state *tls.ClientSessionState
   499  				receivedSessionTicket := make(chan struct{})
   500  				csc.EXPECT().Get(gomock.Any())
   501  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   502  					state = css
   503  					close(receivedSessionTicket)
   504  				})
   505  				clientConf.ClientSessionCache = csc
   506  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   507  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   508  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   509  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   510  				const initialMaxData protocol.ByteCount = 1337
   511  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   512  					clientConf, serverConf,
   513  					clientOrigRTTStats, serverOrigRTTStats,
   514  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   515  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   516  					true,
   517  				)
   518  				Expect(clientErr).ToNot(HaveOccurred())
   519  				Expect(serverErr).ToNot(HaveOccurred())
   520  				Eventually(receivedSessionTicket).Should(BeClosed())
   521  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   522  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   523  
   524  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   525  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   526  
   527  				clientRTTStats := &utils.RTTStats{}
   528  				serverRTTStats := &utils.RTTStats{}
   529  				client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
   530  					clientConf, serverConf,
   531  					clientRTTStats, serverRTTStats,
   532  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   533  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   534  					true,
   535  				)
   536  				Expect(clientErr).ToNot(HaveOccurred())
   537  				Expect(serverErr).ToNot(HaveOccurred())
   538  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   539  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   540  
   541  				var tp *wire.TransportParameters
   542  				var clientReceived0RTTKeys bool
   543  				for _, ev := range clientEvents {
   544  					//nolint:exhaustive // only need to process a few events
   545  					switch ev.Kind {
   546  					case EventRestoredTransportParameters:
   547  						tp = ev.TransportParameters
   548  					case EventReceivedReadKeys:
   549  						clientReceived0RTTKeys = true
   550  					}
   551  				}
   552  				Expect(clientReceived0RTTKeys).To(BeTrue())
   553  				Expect(tp).ToNot(BeNil())
   554  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   555  
   556  				var serverReceived0RTTKeys bool
   557  				for _, ev := range serverEvents {
   558  					//nolint:exhaustive // only need to process a few events
   559  					switch ev.Kind {
   560  					case EventReceivedReadKeys:
   561  						serverReceived0RTTKeys = true
   562  					}
   563  				}
   564  				Expect(serverReceived0RTTKeys).To(BeTrue())
   565  
   566  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   567  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   568  				Expect(server.ConnectionState().Used0RTT).To(BeTrue())
   569  				Expect(client.ConnectionState().Used0RTT).To(BeTrue())
   570  			})
   571  
   572  			It("rejects 0-RTT, when the transport parameters changed", func() {
   573  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   574  				var state *tls.ClientSessionState
   575  				receivedSessionTicket := make(chan struct{})
   576  				csc.EXPECT().Get(gomock.Any())
   577  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   578  					state = css
   579  					close(receivedSessionTicket)
   580  				})
   581  				clientConf.ClientSessionCache = csc
   582  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   583  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   584  				const initialMaxData protocol.ByteCount = 1337
   585  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   586  					clientConf, serverConf,
   587  					clientOrigRTTStats, &utils.RTTStats{},
   588  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   589  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   590  					true,
   591  				)
   592  				Expect(clientErr).ToNot(HaveOccurred())
   593  				Expect(serverErr).ToNot(HaveOccurred())
   594  				Eventually(receivedSessionTicket).Should(BeClosed())
   595  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   596  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   597  
   598  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   599  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   600  
   601  				clientRTTStats := &utils.RTTStats{}
   602  				client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
   603  					clientConf, serverConf,
   604  					clientRTTStats, &utils.RTTStats{},
   605  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   606  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
   607  					true,
   608  				)
   609  				Expect(clientErr).ToNot(HaveOccurred())
   610  				Expect(serverErr).ToNot(HaveOccurred())
   611  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   612  
   613  				var tp *wire.TransportParameters
   614  				var clientReceived0RTTKeys bool
   615  				for _, ev := range clientEvents {
   616  					//nolint:exhaustive // only need to process a few events
   617  					switch ev.Kind {
   618  					case EventRestoredTransportParameters:
   619  						tp = ev.TransportParameters
   620  					case EventReceivedReadKeys:
   621  						clientReceived0RTTKeys = true
   622  					}
   623  				}
   624  				Expect(clientReceived0RTTKeys).To(BeTrue())
   625  				Expect(tp).ToNot(BeNil())
   626  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   627  
   628  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   629  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   630  				Expect(server.ConnectionState().Used0RTT).To(BeFalse())
   631  				Expect(client.ConnectionState().Used0RTT).To(BeFalse())
   632  			})
   633  		})
   634  	})
   635  })