github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/internal/handshake/crypto_setup_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"math/big"
    10  	"net"
    11  	"reflect"
    12  	"runtime"
    13  	"strings"
    14  	"time"
    15  
    16  	mocktls "github.com/metacubex/quic-go/internal/mocks/tls"
    17  	"github.com/metacubex/quic-go/internal/protocol"
    18  	"github.com/metacubex/quic-go/internal/qerr"
    19  	"github.com/metacubex/quic-go/internal/testdata"
    20  	"github.com/metacubex/quic-go/internal/utils"
    21  	"github.com/metacubex/quic-go/internal/wire"
    22  
    23  	. "github.com/onsi/ginkgo/v2"
    24  	. "github.com/onsi/gomega"
    25  	"go.uber.org/mock/gomock"
    26  )
    27  
    28  const (
    29  	typeClientHello      = 1
    30  	typeNewSessionTicket = 4
    31  )
    32  
    33  var _ = Describe("Crypto Setup TLS", func() {
    34  	generateCert := func() tls.Certificate {
    35  		priv, err := rsa.GenerateKey(rand.Reader, 2048)
    36  		Expect(err).ToNot(HaveOccurred())
    37  		tmpl := &x509.Certificate{
    38  			SerialNumber:          big.NewInt(1),
    39  			Subject:               pkix.Name{},
    40  			SignatureAlgorithm:    x509.SHA256WithRSA,
    41  			NotBefore:             time.Now(),
    42  			NotAfter:              time.Now().Add(time.Hour), // valid for an hour
    43  			BasicConstraintsValid: true,
    44  		}
    45  		certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
    46  		Expect(err).ToNot(HaveOccurred())
    47  		return tls.Certificate{
    48  			PrivateKey:  priv,
    49  			Certificate: [][]byte{certDER},
    50  		}
    51  	}
    52  
    53  	var clientConf, serverConf *tls.Config
    54  
    55  	BeforeEach(func() {
    56  		serverConf = testdata.GetTLSConfig()
    57  		serverConf.NextProtos = []string{"crypto-setup"}
    58  		clientConf = &tls.Config{
    59  			ServerName: "localhost",
    60  			RootCAs:    testdata.GetRootCA(),
    61  			NextProtos: []string{"crypto-setup"},
    62  		}
    63  	})
    64  
    65  	It("handles qtls errors occurring before during ClientHello generation", func() {
    66  		tlsConf := testdata.GetTLSConfig()
    67  		tlsConf.InsecureSkipVerify = true
    68  		tlsConf.NextProtos = []string{""}
    69  		cl := NewCryptoSetupClient(
    70  			protocol.ConnectionID{},
    71  			&wire.TransportParameters{},
    72  			tlsConf,
    73  			false,
    74  			&utils.RTTStats{},
    75  			nil,
    76  			utils.DefaultLogger.WithPrefix("client"),
    77  			protocol.Version1,
    78  		)
    79  
    80  		Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
    81  			ErrorCode:    qerr.InternalError,
    82  			ErrorMessage: "tls: invalid NextProtos value",
    83  		}))
    84  	})
    85  
    86  	It("errors when a message is received at the wrong encryption level", func() {
    87  		var token protocol.StatelessResetToken
    88  		server := NewCryptoSetupServer(
    89  			protocol.ConnectionID{},
    90  			&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
    91  			&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
    92  			&wire.TransportParameters{StatelessResetToken: &token},
    93  			testdata.GetTLSConfig(),
    94  			false,
    95  			&utils.RTTStats{},
    96  			nil,
    97  			utils.DefaultLogger.WithPrefix("server"),
    98  			protocol.Version1,
    99  		)
   100  
   101  		Expect(server.StartHandshake()).To(Succeed())
   102  
   103  		fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
   104  		// wrong encryption level
   105  		err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
   106  		Expect(err).To(HaveOccurred())
   107  		Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   108  	})
   109  
   110  	Context("filling in a net.Conn in tls.ClientHelloInfo", func() {
   111  		var (
   112  			local  = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
   113  			remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
   114  		)
   115  
   116  		It("wraps GetCertificate", func() {
   117  			var localAddr, remoteAddr net.Addr
   118  			tlsConf := &tls.Config{
   119  				GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
   120  					localAddr = info.Conn.LocalAddr()
   121  					remoteAddr = info.Conn.RemoteAddr()
   122  					cert := generateCert()
   123  					return &cert, nil
   124  				},
   125  			}
   126  			addConnToClientHelloInfo(tlsConf, local, remote)
   127  			_, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{})
   128  			Expect(err).ToNot(HaveOccurred())
   129  			Expect(localAddr).To(Equal(local))
   130  			Expect(remoteAddr).To(Equal(remote))
   131  		})
   132  
   133  		It("wraps GetConfigForClient", func() {
   134  			var localAddr, remoteAddr net.Addr
   135  			tlsConf := &tls.Config{
   136  				GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   137  					localAddr = info.Conn.LocalAddr()
   138  					remoteAddr = info.Conn.RemoteAddr()
   139  					return &tls.Config{}, nil
   140  				},
   141  			}
   142  			addConnToClientHelloInfo(tlsConf, local, remote)
   143  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   144  			Expect(err).ToNot(HaveOccurred())
   145  			Expect(localAddr).To(Equal(local))
   146  			Expect(remoteAddr).To(Equal(remote))
   147  			Expect(conf).ToNot(BeNil())
   148  			Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
   149  		})
   150  
   151  		It("wraps GetConfigForClient, recursively", func() {
   152  			var localAddr, remoteAddr net.Addr
   153  			tlsConf := &tls.Config{}
   154  			var innerConf *tls.Config
   155  			getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
   156  				localAddr = info.Conn.LocalAddr()
   157  				remoteAddr = info.Conn.RemoteAddr()
   158  				cert := generateCert()
   159  				return &cert, nil
   160  			}
   161  			tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
   162  				innerConf = tlsConf.Clone()
   163  				// set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config
   164  				innerConf.MaxVersion = tls.VersionTLS12
   165  				innerConf.GetCertificate = getCert
   166  				return innerConf, nil
   167  			}
   168  			addConnToClientHelloInfo(tlsConf, local, remote)
   169  			conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
   170  			Expect(err).ToNot(HaveOccurred())
   171  			Expect(conf).ToNot(BeNil())
   172  			Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
   173  			_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
   174  			Expect(err).ToNot(HaveOccurred())
   175  			Expect(localAddr).To(Equal(local))
   176  			Expect(remoteAddr).To(Equal(remote))
   177  			// make sure that the tls.Config returned by GetConfigForClient isn't modified
   178  			Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
   179  			Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
   180  		})
   181  	})
   182  
   183  	Context("doing the handshake", func() {
   184  		newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
   185  			rttStats := &utils.RTTStats{}
   186  			rttStats.UpdateRTT(rtt, 0, time.Now())
   187  			ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
   188  			return rttStats
   189  		}
   190  
   191  		// The clientEvents and serverEvents contain all events that were not processed by the function,
   192  		// i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete.
   193  		handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) {
   194  			Expect(client.StartHandshake()).To(Succeed())
   195  			Expect(server.StartHandshake()).To(Succeed())
   196  
   197  			var clientHandshakeComplete, serverHandshakeComplete bool
   198  
   199  			for {
   200  			clientLoop:
   201  				for {
   202  					ev := client.NextEvent()
   203  					switch ev.Kind {
   204  					case EventNoEvent:
   205  						break clientLoop
   206  					case EventWriteInitialData:
   207  						if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   208  							serverErr = err
   209  							return
   210  						}
   211  					case EventWriteHandshakeData:
   212  						if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   213  							serverErr = err
   214  							return
   215  						}
   216  					case EventHandshakeComplete:
   217  						clientHandshakeComplete = true
   218  					default:
   219  						clientEvents = append(clientEvents, ev)
   220  					}
   221  				}
   222  
   223  			serverLoop:
   224  				for {
   225  					ev := server.NextEvent()
   226  					switch ev.Kind {
   227  					case EventNoEvent:
   228  						break serverLoop
   229  					case EventWriteInitialData:
   230  						if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil {
   231  							clientErr = err
   232  							return
   233  						}
   234  					case EventWriteHandshakeData:
   235  						if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil {
   236  							clientErr = err
   237  							return
   238  						}
   239  					case EventHandshakeComplete:
   240  						serverHandshakeComplete = true
   241  						ticket, err := server.GetSessionTicket()
   242  						Expect(err).ToNot(HaveOccurred())
   243  						if ticket != nil {
   244  							Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
   245  						}
   246  					default:
   247  						serverEvents = append(serverEvents, ev)
   248  					}
   249  				}
   250  
   251  				if clientHandshakeComplete && serverHandshakeComplete {
   252  					break
   253  				}
   254  			}
   255  			return
   256  		}
   257  
   258  		handshakeWithTLSConf := func(
   259  			clientConf, serverConf *tls.Config,
   260  			clientRTTStats, serverRTTStats *utils.RTTStats,
   261  			clientTransportParameters, serverTransportParameters *wire.TransportParameters,
   262  			enable0RTT bool,
   263  		) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */
   264  			CryptoSetup /* server */, []Event /* more server events */, error, /* server error */
   265  		) {
   266  			client := NewCryptoSetupClient(
   267  				protocol.ConnectionID{},
   268  				clientTransportParameters,
   269  				clientConf,
   270  				enable0RTT,
   271  				clientRTTStats,
   272  				nil,
   273  				utils.DefaultLogger.WithPrefix("client"),
   274  				protocol.Version1,
   275  			)
   276  
   277  			if serverTransportParameters.StatelessResetToken == nil {
   278  				var token protocol.StatelessResetToken
   279  				serverTransportParameters.StatelessResetToken = &token
   280  			}
   281  			server := NewCryptoSetupServer(
   282  				protocol.ConnectionID{},
   283  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   284  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   285  				serverTransportParameters,
   286  				serverConf,
   287  				enable0RTT,
   288  				serverRTTStats,
   289  				nil,
   290  				utils.DefaultLogger.WithPrefix("server"),
   291  				protocol.Version1,
   292  			)
   293  			cEvents, cErr, sEvents, sErr := handshake(client, server)
   294  			return client, cEvents, cErr, server, sEvents, sErr
   295  		}
   296  
   297  		It("handshakes", func() {
   298  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   299  				clientConf, serverConf,
   300  				&utils.RTTStats{}, &utils.RTTStats{},
   301  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   302  				false,
   303  			)
   304  			Expect(clientErr).ToNot(HaveOccurred())
   305  			Expect(serverErr).ToNot(HaveOccurred())
   306  		})
   307  
   308  		It("performs a HelloRetryRequst", func() {
   309  			serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
   310  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   311  				clientConf, serverConf,
   312  				&utils.RTTStats{}, &utils.RTTStats{},
   313  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   314  				false,
   315  			)
   316  			Expect(clientErr).ToNot(HaveOccurred())
   317  			Expect(serverErr).ToNot(HaveOccurred())
   318  		})
   319  
   320  		It("handshakes with client auth", func() {
   321  			clientConf.Certificates = []tls.Certificate{generateCert()}
   322  			serverConf.ClientAuth = tls.RequireAnyClientCert
   323  			_, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   324  				clientConf, serverConf,
   325  				&utils.RTTStats{}, &utils.RTTStats{},
   326  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   327  				false,
   328  			)
   329  			Expect(clientErr).ToNot(HaveOccurred())
   330  			Expect(serverErr).ToNot(HaveOccurred())
   331  		})
   332  
   333  		It("receives transport parameters", func() {
   334  			cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second}
   335  			client := NewCryptoSetupClient(
   336  				protocol.ConnectionID{},
   337  				cTransportParameters,
   338  				clientConf,
   339  				false,
   340  				&utils.RTTStats{},
   341  				nil,
   342  				utils.DefaultLogger.WithPrefix("client"),
   343  				protocol.Version1,
   344  			)
   345  
   346  			var token protocol.StatelessResetToken
   347  			sTransportParameters := &wire.TransportParameters{
   348  				MaxIdleTimeout:          1337 * time.Second,
   349  				StatelessResetToken:     &token,
   350  				ActiveConnectionIDLimit: 2,
   351  			}
   352  			server := NewCryptoSetupServer(
   353  				protocol.ConnectionID{},
   354  				&net.UDPAddr{IP: net.IPv6loopback, Port: 1234},
   355  				&net.UDPAddr{IP: net.IPv6loopback, Port: 4321},
   356  				sTransportParameters,
   357  				serverConf,
   358  				false,
   359  				&utils.RTTStats{},
   360  				nil,
   361  				utils.DefaultLogger.WithPrefix("server"),
   362  				protocol.Version1,
   363  			)
   364  
   365  			clientEvents, cErr, serverEvents, sErr := handshake(client, server)
   366  			Expect(cErr).ToNot(HaveOccurred())
   367  			Expect(sErr).ToNot(HaveOccurred())
   368  			var clientReceivedTransportParameters *wire.TransportParameters
   369  			for _, ev := range clientEvents {
   370  				if ev.Kind == EventReceivedTransportParameters {
   371  					clientReceivedTransportParameters = ev.TransportParameters
   372  				}
   373  			}
   374  			Expect(clientReceivedTransportParameters).ToNot(BeNil())
   375  			Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second))
   376  
   377  			var serverReceivedTransportParameters *wire.TransportParameters
   378  			for _, ev := range serverEvents {
   379  				if ev.Kind == EventReceivedTransportParameters {
   380  					serverReceivedTransportParameters = ev.TransportParameters
   381  				}
   382  			}
   383  			Expect(serverReceivedTransportParameters).ToNot(BeNil())
   384  			Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second))
   385  		})
   386  
   387  		Context("with session tickets", func() {
   388  			It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
   389  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   390  					clientConf, serverConf,
   391  					&utils.RTTStats{}, &utils.RTTStats{},
   392  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   393  					false,
   394  				)
   395  				Expect(clientErr).ToNot(HaveOccurred())
   396  				Expect(serverErr).ToNot(HaveOccurred())
   397  
   398  				// inject an invalid session ticket
   399  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   400  				err := client.HandleMessage(b, protocol.EncryptionHandshake)
   401  				Expect(err).To(HaveOccurred())
   402  				Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   403  			})
   404  
   405  			It("errors when handling the NewSessionTicket fails", func() {
   406  				client, _, clientErr, _, _, serverErr := handshakeWithTLSConf(
   407  					clientConf, serverConf,
   408  					&utils.RTTStats{}, &utils.RTTStats{},
   409  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   410  					false,
   411  				)
   412  				Expect(clientErr).ToNot(HaveOccurred())
   413  				Expect(serverErr).ToNot(HaveOccurred())
   414  
   415  				// inject an invalid session ticket
   416  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   417  				err := client.HandleMessage(b, protocol.Encryption1RTT)
   418  				Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   419  				Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
   420  			})
   421  
   422  			It("uses session resumption", func() {
   423  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   424  				var state *tls.ClientSessionState
   425  				receivedSessionTicket := make(chan struct{})
   426  				csc.EXPECT().Get(gomock.Any())
   427  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   428  					state = css
   429  					close(receivedSessionTicket)
   430  				})
   431  				clientConf.ClientSessionCache = csc
   432  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   433  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   434  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   435  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   436  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   437  					clientConf, serverConf,
   438  					clientOrigRTTStats, serverOrigRTTStats,
   439  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   440  					false,
   441  				)
   442  				Expect(clientErr).ToNot(HaveOccurred())
   443  				Expect(serverErr).ToNot(HaveOccurred())
   444  				Eventually(receivedSessionTicket).Should(BeClosed())
   445  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   446  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   447  
   448  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   449  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   450  				clientRTTStats := &utils.RTTStats{}
   451  				serverRTTStats := &utils.RTTStats{}
   452  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   453  					clientConf, serverConf,
   454  					clientRTTStats, serverRTTStats,
   455  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   456  					false,
   457  				)
   458  				Expect(clientErr).ToNot(HaveOccurred())
   459  				Expect(serverErr).ToNot(HaveOccurred())
   460  				Eventually(receivedSessionTicket).Should(BeClosed())
   461  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   462  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   463  				if !strings.Contains(runtime.Version(), "go1.20") {
   464  					Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   465  					Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   466  				}
   467  			})
   468  
   469  			It("doesn't use session resumption if the server disabled it", func() {
   470  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   471  				var state *tls.ClientSessionState
   472  				receivedSessionTicket := make(chan struct{})
   473  				csc.EXPECT().Get(gomock.Any())
   474  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   475  					state = css
   476  					close(receivedSessionTicket)
   477  				})
   478  				clientConf.ClientSessionCache = csc
   479  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   480  					clientConf, serverConf,
   481  					&utils.RTTStats{}, &utils.RTTStats{},
   482  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   483  					false,
   484  				)
   485  				Expect(clientErr).ToNot(HaveOccurred())
   486  				Expect(serverErr).ToNot(HaveOccurred())
   487  				Eventually(receivedSessionTicket).Should(BeClosed())
   488  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   489  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   490  
   491  				serverConf.SessionTicketsDisabled = true
   492  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   493  				client, _, clientErr, server, _, serverErr = handshakeWithTLSConf(
   494  					clientConf, serverConf,
   495  					&utils.RTTStats{}, &utils.RTTStats{},
   496  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   497  					false,
   498  				)
   499  				Expect(clientErr).ToNot(HaveOccurred())
   500  				Expect(serverErr).ToNot(HaveOccurred())
   501  				Eventually(receivedSessionTicket).Should(BeClosed())
   502  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   503  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   504  			})
   505  
   506  			It("uses 0-RTT", func() {
   507  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   508  				var state *tls.ClientSessionState
   509  				receivedSessionTicket := make(chan struct{})
   510  				csc.EXPECT().Get(gomock.Any())
   511  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   512  					state = css
   513  					close(receivedSessionTicket)
   514  				})
   515  				clientConf.ClientSessionCache = csc
   516  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   517  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   518  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   519  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   520  				const initialMaxData protocol.ByteCount = 1337
   521  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   522  					clientConf, serverConf,
   523  					clientOrigRTTStats, serverOrigRTTStats,
   524  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   525  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   526  					true,
   527  				)
   528  				Expect(clientErr).ToNot(HaveOccurred())
   529  				Expect(serverErr).ToNot(HaveOccurred())
   530  				Eventually(receivedSessionTicket).Should(BeClosed())
   531  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   532  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   533  
   534  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   535  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   536  
   537  				clientRTTStats := &utils.RTTStats{}
   538  				serverRTTStats := &utils.RTTStats{}
   539  				client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf(
   540  					clientConf, serverConf,
   541  					clientRTTStats, serverRTTStats,
   542  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   543  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   544  					true,
   545  				)
   546  				Expect(clientErr).ToNot(HaveOccurred())
   547  				Expect(serverErr).ToNot(HaveOccurred())
   548  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   549  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   550  
   551  				var tp *wire.TransportParameters
   552  				var clientReceived0RTTKeys bool
   553  				for _, ev := range clientEvents {
   554  					switch ev.Kind {
   555  					case EventRestoredTransportParameters:
   556  						tp = ev.TransportParameters
   557  					case EventReceivedReadKeys:
   558  						clientReceived0RTTKeys = true
   559  					}
   560  				}
   561  				Expect(clientReceived0RTTKeys).To(BeTrue())
   562  				Expect(tp).ToNot(BeNil())
   563  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   564  
   565  				var serverReceived0RTTKeys bool
   566  				for _, ev := range serverEvents {
   567  					switch ev.Kind {
   568  					case EventReceivedReadKeys:
   569  						serverReceived0RTTKeys = true
   570  					}
   571  				}
   572  				Expect(serverReceived0RTTKeys).To(BeTrue())
   573  
   574  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   575  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   576  				Expect(server.ConnectionState().Used0RTT).To(BeTrue())
   577  				Expect(client.ConnectionState().Used0RTT).To(BeTrue())
   578  			})
   579  
   580  			It("rejects 0-RTT, when the transport parameters changed", func() {
   581  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   582  				var state *tls.ClientSessionState
   583  				receivedSessionTicket := make(chan struct{})
   584  				csc.EXPECT().Get(gomock.Any())
   585  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   586  					state = css
   587  					close(receivedSessionTicket)
   588  				})
   589  				clientConf.ClientSessionCache = csc
   590  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   591  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   592  				const initialMaxData protocol.ByteCount = 1337
   593  				client, _, clientErr, server, _, serverErr := handshakeWithTLSConf(
   594  					clientConf, serverConf,
   595  					clientOrigRTTStats, &utils.RTTStats{},
   596  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   597  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   598  					true,
   599  				)
   600  				Expect(clientErr).ToNot(HaveOccurred())
   601  				Expect(serverErr).ToNot(HaveOccurred())
   602  				Eventually(receivedSessionTicket).Should(BeClosed())
   603  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   604  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   605  
   606  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   607  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   608  
   609  				clientRTTStats := &utils.RTTStats{}
   610  				client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf(
   611  					clientConf, serverConf,
   612  					clientRTTStats, &utils.RTTStats{},
   613  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   614  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
   615  					true,
   616  				)
   617  				Expect(clientErr).ToNot(HaveOccurred())
   618  				Expect(serverErr).ToNot(HaveOccurred())
   619  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   620  
   621  				var tp *wire.TransportParameters
   622  				var clientReceived0RTTKeys bool
   623  				for _, ev := range clientEvents {
   624  					switch ev.Kind {
   625  					case EventRestoredTransportParameters:
   626  						tp = ev.TransportParameters
   627  					case EventReceivedReadKeys:
   628  						clientReceived0RTTKeys = true
   629  					}
   630  				}
   631  				Expect(clientReceived0RTTKeys).To(BeTrue())
   632  				Expect(tp).ToNot(BeNil())
   633  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   634  
   635  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   636  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   637  				Expect(server.ConnectionState().Used0RTT).To(BeFalse())
   638  				Expect(client.ConnectionState().Used0RTT).To(BeFalse())
   639  			})
   640  		})
   641  	})
   642  })