github.com/quic-go/quic-go@v0.44.0/internal/handshake/crypto_setup_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"math/big"
    10  	"net"
    11  	"reflect"
    12  	"time"
    13  
    14  	mocktls "github.com/quic-go/quic-go/internal/mocks/tls"
    15  	"github.com/quic-go/quic-go/internal/protocol"
    16  	"github.com/quic-go/quic-go/internal/qerr"
    17  	"github.com/quic-go/quic-go/internal/testdata"
    18  	"github.com/quic-go/quic-go/internal/utils"
    19  	"github.com/quic-go/quic-go/internal/wire"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  	"go.uber.org/mock/gomock"
    24  )
    25  
    26  const (
    27  	typeClientHello      = 1
    28  	typeNewSessionTicket = 4
    29  )
    30  
    31  var _ = Describe("Crypto Setup TLS", func() {
    32  	generateCert := func() tls.Certificate {
    33  		priv, err := rsa.GenerateKey(rand.Reader, 2048)
    34  		Expect(err).ToNot(HaveOccurred())
    35  		tmpl := &x509.Certificate{
    36  			SerialNumber:          big.NewInt(1),
    37  			Subject:               pkix.Name{},
    38  			SignatureAlgorithm:    x509.SHA256WithRSA,
    39  			NotBefore:             time.Now(),
    40  			NotAfter:              time.Now().Add(time.Hour), // valid for an hour
    41  			BasicConstraintsValid: true,
    42  		}
    43  		certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
    44  		Expect(err).ToNot(HaveOccurred())
    45  		return tls.Certificate{
    46  			PrivateKey:  priv,
    47  			Certificate: [][]byte{certDER},
    48  		}
    49  	}
    50  
    51  	var clientConf, serverConf *tls.Config
    52  
    53  	BeforeEach(func() {
    54  		serverConf = testdata.GetTLSConfig()
    55  		serverConf.NextProtos = []string{"crypto-setup"}
    56  		clientConf = &tls.Config{
    57  			ServerName: "localhost",
    58  			RootCAs:    testdata.GetRootCA(),
    59  			NextProtos: []string{"crypto-setup"},
    60  		}
    61  	})
    62  
    63  	It("handles qtls errors occurring before during ClientHello generation", func() {
    64  		tlsConf := testdata.GetTLSConfig()
    65  		tlsConf.InsecureSkipVerify = true
    66  		tlsConf.NextProtos = []string{""}
    67  		cl := NewCryptoSetupClient(
    68  			protocol.ConnectionID{},
    69  			&wire.TransportParameters{},
    70  			tlsConf,
    71  			false,
    72  			&utils.RTTStats{},
    73  			nil,
    74  			utils.DefaultLogger.WithPrefix("client"),
    75  			protocol.Version1,
    76  		)
    77  
    78  		Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
    79  			ErrorCode:    qerr.InternalError,
    80  			ErrorMessage: "tls: invalid NextProtos value",
    81  		}))
    82  	})
    83  
    84  	It("errors when a message is received at the wrong encryption level", func() {
    85  		var token protocol.StatelessResetToken
    86  		server := NewCryptoSetupServer(
    87  			protocol.ConnectionID{},
    88  			&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
    89  			&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
    90  			&wire.TransportParameters{StatelessResetToken: &token},
    91  			testdata.GetTLSConfig(),
    92  			false,
    93  			&utils.RTTStats{},
    94  			nil,
    95  			utils.DefaultLogger.WithPrefix("server"),
    96  			protocol.Version1,
    97  		)
    98  
    99  		Expect(server.StartHandshake()).To(Succeed())
   100  
   101  		fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
   102  		// wrong encryption level
   103  		err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
   104  		Expect(err).To(HaveOccurred())
   105  		Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   106  	})
   107  
   108  	Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
   109  		var (
   110  			local  = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
   111  			remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
   112  		)
   113  
   114  		It("wraps GetCertificate", func() {
   115  			var localAddr, remoteAddr net.Addr
   116  			tlsConf := &tls.Config{
   117  				GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   118  					localAddr = info.Conn.LocalAddr()
   119  					remoteAddr = info.Conn.RemoteAddr()
   120  					cert := generateCert()
   121  					return &cert, nil
   122  				},
   123  			}
   124  			addConnToClientHelloInfo(tlsConf, local, remote)
   125  			_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
   126  			Expect(err).ToNot(HaveOccurred())
   127  			Expect(localAddr).To(Equal(local))
   128  			Expect(remoteAddr).To(Equal(remote))
   129  		})
   130  
   131  		It("wraps GetConfigForClient", func() {
   132  			var localAddr, remoteAddr net.Addr
   133  			tlsConf := &tls.Config{
   134  				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   135  					localAddr = info.Conn.LocalAddr()
   136  					remoteAddr = info.Conn.RemoteAddr()
   137  					return &tls.Config{}, nil
   138  				},
   139  			}
   140  			addConnToClientHelloInfo(tlsConf, local, remote)
   141  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   142  			Expect(err).ToNot(HaveOccurred())
   143  			Expect(localAddr).To(Equal(local))
   144  			Expect(remoteAddr).To(Equal(remote))
   145  			Expect(conf).ToNot(BeNil())
   146  			Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
   147  		})
   148  
   149  		It("wraps GetConfigForClient, recursively", func() {
   150  			var localAddr, remoteAddr net.Addr
   151  			tlsConf := &tls.Config{}
   152  			var innerConf *tls.Config
   153  			getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
   154  				localAddr = info.Conn.LocalAddr()
   155  				remoteAddr = info.Conn.RemoteAddr()
   156  				cert := generateCert()
   157  				return &cert, nil
   158  			}
   159  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   160  				innerConf = tlsConf.Clone()
   161  				// set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config
   162  				innerConf.MaxVersion = tls.VersionTLS12
   163  				innerConf.GetCertificate = getCert
   164  				return innerConf, nil
   165  			}
   166  			addConnToClientHelloInfo(tlsConf, local, remote)
   167  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   168  			Expect(err).ToNot(HaveOccurred())
   169  			Expect(conf).ToNot(BeNil())
   170  			Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
   171  			_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
   172  			Expect(err).ToNot(HaveOccurred())
   173  			Expect(localAddr).To(Equal(local))
   174  			Expect(remoteAddr).To(Equal(remote))
   175  			// make sure that the tls.Config returned by GetConfigForClient isn't modified
   176  			Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
   177  			Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
   178  		})
   179  	})
   180  
   181  	Context("doing the handshake", func() {
   182  		newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
   183  			rttStats := &utils.RTTStats{}
   184  			rttStats.UpdateRTT(rtt, 0, time.Now())
   185  			ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
   186  			return rttStats
   187  		}
   188  
   189  		// The clientEvents and serverEvents contain all events that were not processed by the function,
   190  		// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
   191  		handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
   192  			Expect(client.StartHandshake()).To(Succeed())
   193  			Expect(server.StartHandshake()).To(Succeed())
   194  
   195  			var clientHandshakeComplete, serverHandshakeComplete bool
   196  
   197  			for {
   198  			clientLoop:
   199  				for {
   200  					ev := client.NextEvent()
   201  					switch ev.Kind {
   202  					case EventNoEvent:
   203  						break clientLoop
   204  					case EventWriteInitialData:
   205  						if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   206  							serverErr = err
   207  							return
   208  						}
   209  					case EventWriteHandshakeData:
   210  						if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   211  							serverErr = err
   212  							return
   213  						}
   214  					case EventHandshakeComplete:
   215  						clientHandshakeComplete = true
   216  					default:
   217  						clientEvents = append(clientEvents, ev)
   218  					}
   219  				}
   220  
   221  			serverLoop:
   222  				for {
   223  					ev := server.NextEvent()
   224  					switch ev.Kind {
   225  					case EventNoEvent:
   226  						break serverLoop
   227  					case EventWriteInitialData:
   228  						if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   229  							clientErr = err
   230  							return
   231  						}
   232  					case EventWriteHandshakeData:
   233  						if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   234  							clientErr = err
   235  							return
   236  						}
   237  					case EventHandshakeComplete:
   238  						serverHandshakeComplete = true
   239  						ticket, err := server.GetSessionTicket()
   240  						Expect(err).ToNot(HaveOccurred())
   241  						if ticket != nil {
   242  							Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
   243  						}
   244  					default:
   245  						serverEvents = append(serverEvents, ev)
   246  					}
   247  				}
   248  
   249  				if clientHandshakeComplete && serverHandshakeComplete {
   250  					break
   251  				}
   252  			}
   253  			return
   254  		}
   255  
   256  		handshakeWithTLSConf := func(
   257  			clientConf, serverConf *tls.Config,
   258  			clientRTTStats, serverRTTStats *utils.RTTStats,
   259  			clientTransportParameters, serverTransportParameters *wire.TransportParameters,
   260  			enable0RTT bool,
   261  		) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
   262  			CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
   263  		) {
   264  			client := NewCryptoSetupClient(
   265  				protocol.ConnectionID{},
   266  				clientTransportParameters,
   267  				clientConf,
   268  				enable0RTT,
   269  				clientRTTStats,
   270  				nil,
   271  				utils.DefaultLogger.WithPrefix("client"),
   272  				protocol.Version1,
   273  			)
   274  
   275  			if serverTransportParameters.StatelessResetToken == nil {
   276  				var token protocol.StatelessResetToken
   277  				serverTransportParameters.StatelessResetToken = &token
   278  			}
   279  			server := NewCryptoSetupServer(
   280  				protocol.ConnectionID{},
   281  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   282  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   283  				serverTransportParameters,
   284  				serverConf,
   285  				enable0RTT,
   286  				serverRTTStats,
   287  				nil,
   288  				utils.DefaultLogger.WithPrefix("server"),
   289  				protocol.Version1,
   290  			)
   291  			cEvents, cErr, sEvents, sErr := handshake(client, server)
   292  			return client, cEvents, cErr, server, sEvents, sErr
   293  		}
   294  
   295  		It("handshakes", func() {
   296  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   297  				clientConf, serverConf,
   298  				&utils.RTTStats{}, &utils.RTTStats{},
   299  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   300  				false,
   301  			)
   302  			Expect(clientErr).ToNot(HaveOccurred())
   303  			Expect(serverErr).ToNot(HaveOccurred())
   304  		})
   305  
   306  		It("performs a HelloRetryRequst", func() {
   307  			serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
   308  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   309  				clientConf, serverConf,
   310  				&utils.RTTStats{}, &utils.RTTStats{},
   311  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   312  				false,
   313  			)
   314  			Expect(clientErr).ToNot(HaveOccurred())
   315  			Expect(serverErr).ToNot(HaveOccurred())
   316  		})
   317  
   318  		It("handshakes with client auth", func() {
   319  			clientConf.Certificates = []tls.Certificate{generateCert()}
   320  			serverConf.ClientAuth = tls.RequireAnyClientCert
   321  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   322  				clientConf, serverConf,
   323  				&utils.RTTStats{}, &utils.RTTStats{},
   324  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   325  				false,
   326  			)
   327  			Expect(clientErr).ToNot(HaveOccurred())
   328  			Expect(serverErr).ToNot(HaveOccurred())
   329  		})
   330  
   331  		It("receives transport parameters", func() {
   332  			cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
   333  			client := NewCryptoSetupClient(
   334  				protocol.ConnectionID{},
   335  				cTransportParameters,
   336  				clientConf,
   337  				false,
   338  				&utils.RTTStats{},
   339  				nil,
   340  				utils.DefaultLogger.WithPrefix("client"),
   341  				protocol.Version1,
   342  			)
   343  
   344  			var token protocol.StatelessResetToken
   345  			sTransportParameters := &wire.TransportParameters{
   346  				MaxIdleTimeout:          1337 * time.Second,
   347  				StatelessResetToken:     &token,
   348  				ActiveConnectionIDLimit: 2,
   349  			}
   350  			server := NewCryptoSetupServer(
   351  				protocol.ConnectionID{},
   352  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   353  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   354  				sTransportParameters,
   355  				serverConf,
   356  				false,
   357  				&utils.RTTStats{},
   358  				nil,
   359  				utils.DefaultLogger.WithPrefix("server"),
   360  				protocol.Version1,
   361  			)
   362  
   363  			clientEvents, cErr, serverEvents, sErr := handshake(client, server)
   364  			Expect(cErr).ToNot(HaveOccurred())
   365  			Expect(sErr).ToNot(HaveOccurred())
   366  			var clientReceivedTransportParameters *wire.TransportParameters
   367  			for _, ev := range clientEvents {
   368  				if ev.Kind == EventReceivedTransportParameters {
   369  					clientReceivedTransportParameters = ev.TransportParameters
   370  				}
   371  			}
   372  			Expect(clientReceivedTransportParameters).ToNot(BeNil())
   373  			Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second))
   374  
   375  			var serverReceivedTransportParameters *wire.TransportParameters
   376  			for _, ev := range serverEvents {
   377  				if ev.Kind == EventReceivedTransportParameters {
   378  					serverReceivedTransportParameters = ev.TransportParameters
   379  				}
   380  			}
   381  			Expect(serverReceivedTransportParameters).ToNot(BeNil())
   382  			Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second))
   383  		})
   384  
   385  		Context("with session tickets", func() {
   386  			It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
   387  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   388  					clientConf, serverConf,
   389  					&utils.RTTStats{}, &utils.RTTStats{},
   390  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   391  					false,
   392  				)
   393  				Expect(clientErr).ToNot(HaveOccurred())
   394  				Expect(serverErr).ToNot(HaveOccurred())
   395  
   396  				// inject an invalid session ticket
   397  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   398  				err := client.HandleMessage(b, protocol.EncryptionHandshake)
   399  				Expect(err).To(HaveOccurred())
   400  				Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   401  			})
   402  
   403  			It("errors when handling the NewSessionTicket fails", func() {
   404  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   405  					clientConf, serverConf,
   406  					&utils.RTTStats{}, &utils.RTTStats{},
   407  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   408  					false,
   409  				)
   410  				Expect(clientErr).ToNot(HaveOccurred())
   411  				Expect(serverErr).ToNot(HaveOccurred())
   412  
   413  				// inject an invalid session ticket
   414  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   415  				err := client.HandleMessage(b, protocol.Encryption1RTT)
   416  				Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   417  				Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
   418  			})
   419  
   420  			It("uses session resumption", func() {
   421  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   422  				var state *tls.ClientSessionState
   423  				receivedSessionTicket := make(chan struct{})
   424  				csc.EXPECT().Get(gomock.Any())
   425  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   426  					state = css
   427  					close(receivedSessionTicket)
   428  				})
   429  				clientConf.ClientSessionCache = csc
   430  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   431  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   432  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   433  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   434  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   435  					clientConf, serverConf,
   436  					clientOrigRTTStats, serverOrigRTTStats,
   437  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   438  					false,
   439  				)
   440  				Expect(clientErr).ToNot(HaveOccurred())
   441  				Expect(serverErr).ToNot(HaveOccurred())
   442  				Eventually(receivedSessionTicket).Should(BeClosed())
   443  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   444  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   445  
   446  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   447  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   448  				clientRTTStats := &utils.RTTStats{}
   449  				serverRTTStats := &utils.RTTStats{}
   450  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   451  					clientConf, serverConf,
   452  					clientRTTStats, serverRTTStats,
   453  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   454  					false,
   455  				)
   456  				Expect(clientErr).ToNot(HaveOccurred())
   457  				Expect(serverErr).ToNot(HaveOccurred())
   458  				Eventually(receivedSessionTicket).Should(BeClosed())
   459  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   460  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   461  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   462  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   463  			})
   464  
   465  			It("doesn't use session resumption if the server disabled it", func() {
   466  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   467  				var state *tls.ClientSessionState
   468  				receivedSessionTicket := make(chan struct{})
   469  				csc.EXPECT().Get(gomock.Any())
   470  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   471  					state = css
   472  					close(receivedSessionTicket)
   473  				})
   474  				clientConf.ClientSessionCache = csc
   475  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   476  					clientConf, serverConf,
   477  					&utils.RTTStats{}, &utils.RTTStats{},
   478  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   479  					false,
   480  				)
   481  				Expect(clientErr).ToNot(HaveOccurred())
   482  				Expect(serverErr).ToNot(HaveOccurred())
   483  				Eventually(receivedSessionTicket).Should(BeClosed())
   484  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   485  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   486  
   487  				serverConf.SessionTicketsDisabled = true
   488  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   489  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   490  					clientConf, serverConf,
   491  					&utils.RTTStats{}, &utils.RTTStats{},
   492  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   493  					false,
   494  				)
   495  				Expect(clientErr).ToNot(HaveOccurred())
   496  				Expect(serverErr).ToNot(HaveOccurred())
   497  				Eventually(receivedSessionTicket).Should(BeClosed())
   498  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   499  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   500  			})
   501  
   502  			It("uses 0-RTT", func() {
   503  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   504  				var state *tls.ClientSessionState
   505  				receivedSessionTicket := make(chan struct{})
   506  				csc.EXPECT().Get(gomock.Any())
   507  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   508  					state = css
   509  					close(receivedSessionTicket)
   510  				})
   511  				clientConf.ClientSessionCache = csc
   512  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   513  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   514  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   515  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   516  				const initialMaxData protocol.ByteCount = 1337
   517  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   518  					clientConf, serverConf,
   519  					clientOrigRTTStats, serverOrigRTTStats,
   520  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   521  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   522  					true,
   523  				)
   524  				Expect(clientErr).ToNot(HaveOccurred())
   525  				Expect(serverErr).ToNot(HaveOccurred())
   526  				Eventually(receivedSessionTicket).Should(BeClosed())
   527  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   528  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   529  
   530  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   531  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   532  
   533  				clientRTTStats := &utils.RTTStats{}
   534  				serverRTTStats := &utils.RTTStats{}
   535  				client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
   536  					clientConf, serverConf,
   537  					clientRTTStats, serverRTTStats,
   538  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   539  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   540  					true,
   541  				)
   542  				Expect(clientErr).ToNot(HaveOccurred())
   543  				Expect(serverErr).ToNot(HaveOccurred())
   544  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   545  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   546  
   547  				var tp *wire.TransportParameters
   548  				var clientReceived0RTTKeys bool
   549  				for _, ev := range clientEvents {
   550  					switch ev.Kind {
   551  					case EventRestoredTransportParameters:
   552  						tp = ev.TransportParameters
   553  					case EventReceivedReadKeys:
   554  						clientReceived0RTTKeys = true
   555  					}
   556  				}
   557  				Expect(clientReceived0RTTKeys).To(BeTrue())
   558  				Expect(tp).ToNot(BeNil())
   559  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   560  
   561  				var serverReceived0RTTKeys bool
   562  				for _, ev := range serverEvents {
   563  					switch ev.Kind {
   564  					case EventReceivedReadKeys:
   565  						serverReceived0RTTKeys = true
   566  					}
   567  				}
   568  				Expect(serverReceived0RTTKeys).To(BeTrue())
   569  
   570  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   571  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   572  				Expect(server.ConnectionState().Used0RTT).To(BeTrue())
   573  				Expect(client.ConnectionState().Used0RTT).To(BeTrue())
   574  			})
   575  
   576  			It("rejects 0-RTT, when the transport parameters changed", func() {
   577  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   578  				var state *tls.ClientSessionState
   579  				receivedSessionTicket := make(chan struct{})
   580  				csc.EXPECT().Get(gomock.Any())
   581  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   582  					state = css
   583  					close(receivedSessionTicket)
   584  				})
   585  				clientConf.ClientSessionCache = csc
   586  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   587  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   588  				const initialMaxData protocol.ByteCount = 1337
   589  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   590  					clientConf, serverConf,
   591  					clientOrigRTTStats, &utils.RTTStats{},
   592  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   593  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   594  					true,
   595  				)
   596  				Expect(clientErr).ToNot(HaveOccurred())
   597  				Expect(serverErr).ToNot(HaveOccurred())
   598  				Eventually(receivedSessionTicket).Should(BeClosed())
   599  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   600  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   601  
   602  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   603  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   604  
   605  				clientRTTStats := &utils.RTTStats{}
   606  				client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
   607  					clientConf, serverConf,
   608  					clientRTTStats, &utils.RTTStats{},
   609  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   610  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
   611  					true,
   612  				)
   613  				Expect(clientErr).ToNot(HaveOccurred())
   614  				Expect(serverErr).ToNot(HaveOccurred())
   615  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   616  
   617  				var tp *wire.TransportParameters
   618  				var clientReceived0RTTKeys bool
   619  				for _, ev := range clientEvents {
   620  					switch ev.Kind {
   621  					case EventRestoredTransportParameters:
   622  						tp = ev.TransportParameters
   623  					case EventReceivedReadKeys:
   624  						clientReceived0RTTKeys = true
   625  					}
   626  				}
   627  				Expect(clientReceived0RTTKeys).To(BeTrue())
   628  				Expect(tp).ToNot(BeNil())
   629  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   630  
   631  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   632  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   633  				Expect(server.ConnectionState().Used0RTT).To(BeFalse())
   634  				Expect(client.ConnectionState().Used0RTT).To(BeFalse())
   635  			})
   636  		})
   637  	})
   638  })