github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/internal/handshake/crypto_setup_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"crypto/x509/pkix"
     9  	"math/big"
    10  	"time"
    11  
    12  	mocktls "github.com/mikelsr/quic-go/internal/mocks/tls"
    13  	"github.com/mikelsr/quic-go/internal/protocol"
    14  	"github.com/mikelsr/quic-go/internal/qerr"
    15  	"github.com/mikelsr/quic-go/internal/testdata"
    16  	"github.com/mikelsr/quic-go/internal/utils"
    17  	"github.com/mikelsr/quic-go/internal/wire"
    18  
    19  	"github.com/golang/mock/gomock"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  )
    24  
    25  const (
    26  	typeClientHello      = 1
    27  	typeNewSessionTicket = 4
    28  )
    29  
// chunk is a single write captured by a stream: a private copy of the written
// bytes plus the encryption level of the stream that produced them.
type chunk struct {
	data     []byte
	encLevel protocol.EncryptionLevel
}
    34  
// stream is a test sink for CryptoSetup output. Each Write becomes a chunk
// tagged with encLevel and pushed onto chunkChan. A test can then replay the
// data into the peer at the same encryption level it was written at.
type stream struct {
	encLevel  protocol.EncryptionLevel
	chunkChan chan<- chunk
}
    39  
    40  func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream {
    41  	return &stream{
    42  		chunkChan: chunkChan,
    43  		encLevel:  encLevel,
    44  	}
    45  }
    46  
    47  func (s *stream) Write(b []byte) (int, error) {
    48  	data := make([]byte, len(b))
    49  	copy(data, b)
    50  	select {
    51  	case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}:
    52  	default:
    53  		panic("chunkChan too small")
    54  	}
    55  	return len(b), nil
    56  }
    57  
    58  var _ = Describe("Crypto Setup TLS", func() {
    59  	var clientConf, serverConf *tls.Config
    60  
    61  	// unparam incorrectly complains that the first argument is never used.
    62  	//nolint:unparam
    63  	initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) {
    64  		chunkChan := make(chan chunk, 100)
    65  		initialStream := newStream(chunkChan, protocol.EncryptionInitial)
    66  		handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake)
    67  		return chunkChan, initialStream, handshakeStream
    68  	}
    69  
    70  	BeforeEach(func() {
    71  		serverConf = testdata.GetTLSConfig()
    72  		serverConf.NextProtos = []string{"crypto-setup"}
    73  		clientConf = &tls.Config{
    74  			ServerName: "localhost",
    75  			RootCAs:    testdata.GetRootCA(),
    76  			NextProtos: []string{"crypto-setup"},
    77  		}
    78  	})
    79  
    80  	It("handles qtls errors occurring before during ClientHello generation", func() {
    81  		_, sInitialStream, sHandshakeStream := initStreams()
    82  		tlsConf := testdata.GetTLSConfig()
    83  		tlsConf.InsecureSkipVerify = true
    84  		tlsConf.NextProtos = []string{""}
    85  		cl, _ := NewCryptoSetupClient(
    86  			sInitialStream,
    87  			sHandshakeStream,
    88  			nil,
    89  			protocol.ConnectionID{},
    90  			&wire.TransportParameters{},
    91  			NewMockHandshakeRunner(mockCtrl),
    92  			tlsConf,
    93  			false,
    94  			&utils.RTTStats{},
    95  			nil,
    96  			utils.DefaultLogger.WithPrefix("client"),
    97  			protocol.Version1,
    98  		)
    99  
   100  		Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{
   101  			ErrorCode:    qerr.InternalError,
   102  			ErrorMessage: "tls: invalid NextProtos value",
   103  		}))
   104  	})
   105  
   106  	It("errors when a message is received at the wrong encryption level", func() {
   107  		_, sInitialStream, sHandshakeStream := initStreams()
   108  		runner := NewMockHandshakeRunner(mockCtrl)
   109  		var token protocol.StatelessResetToken
   110  		server := NewCryptoSetupServer(
   111  			sInitialStream,
   112  			sHandshakeStream,
   113  			nil,
   114  			protocol.ConnectionID{},
   115  			&wire.TransportParameters{StatelessResetToken: &token},
   116  			runner,
   117  			testdata.GetTLSConfig(),
   118  			false,
   119  			&utils.RTTStats{},
   120  			nil,
   121  			utils.DefaultLogger.WithPrefix("server"),
   122  			protocol.Version1,
   123  		)
   124  
   125  		Expect(server.StartHandshake()).To(Succeed())
   126  
   127  		fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...)
   128  		// wrong encryption level
   129  		err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake)
   130  		Expect(err).To(HaveOccurred())
   131  		Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   132  	})
   133  
   134  	Context("doing the handshake", func() {
   135  		generateCert := func() tls.Certificate {
   136  			priv, err := rsa.GenerateKey(rand.Reader, 2048)
   137  			Expect(err).ToNot(HaveOccurred())
   138  			tmpl := &x509.Certificate{
   139  				SerialNumber:          big.NewInt(1),
   140  				Subject:               pkix.Name{},
   141  				SignatureAlgorithm:    x509.SHA256WithRSA,
   142  				NotBefore:             time.Now(),
   143  				NotAfter:              time.Now().Add(time.Hour), // valid for an hour
   144  				BasicConstraintsValid: true,
   145  			}
   146  			certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
   147  			Expect(err).ToNot(HaveOccurred())
   148  			return tls.Certificate{
   149  				PrivateKey:  priv,
   150  				Certificate: [][]byte{certDER},
   151  			}
   152  		}
   153  
   154  		newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats {
   155  			rttStats := &utils.RTTStats{}
   156  			rttStats.UpdateRTT(rtt, 0, time.Now())
   157  			ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt))
   158  			return rttStats
   159  		}
   160  
   161  		handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) {
   162  			Expect(client.StartHandshake()).To(Succeed())
   163  			Expect(server.StartHandshake()).To(Succeed())
   164  
   165  			for {
   166  				select {
   167  				case c := <-cChunkChan:
   168  					Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed())
   169  					continue
   170  				default:
   171  				}
   172  				select {
   173  				case c := <-sChunkChan:
   174  					Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed())
   175  					continue
   176  				default:
   177  				}
   178  				// no more messages to send from client and server. Handshake complete?
   179  				break
   180  			}
   181  
   182  			ticket, err := server.GetSessionTicket()
   183  			Expect(err).ToNot(HaveOccurred())
   184  			if ticket != nil {
   185  				Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed())
   186  			}
   187  		}
   188  
   189  		handshakeWithTLSConf := func(
   190  			clientConf, serverConf *tls.Config,
   191  			clientRTTStats, serverRTTStats *utils.RTTStats,
   192  			clientTransportParameters, serverTransportParameters *wire.TransportParameters,
   193  			enable0RTT bool,
   194  		) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) {
   195  			var cHandshakeComplete bool
   196  			cChunkChan, cInitialStream, cHandshakeStream := initStreams()
   197  			cErrChan := make(chan error, 1)
   198  			cRunner := NewMockHandshakeRunner(mockCtrl)
   199  			cRunner.EXPECT().OnReceivedParams(gomock.Any())
   200  			cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
   201  			cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1)
   202  			cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1)
   203  			client, clientHelloWrittenChan := NewCryptoSetupClient(
   204  				cInitialStream,
   205  				cHandshakeStream,
   206  				nil,
   207  				protocol.ConnectionID{},
   208  				clientTransportParameters,
   209  				cRunner,
   210  				clientConf,
   211  				enable0RTT,
   212  				clientRTTStats,
   213  				nil,
   214  				utils.DefaultLogger.WithPrefix("client"),
   215  				protocol.Version1,
   216  			)
   217  
   218  			var sHandshakeComplete bool
   219  			sChunkChan, sInitialStream, sHandshakeStream := initStreams()
   220  			sErrChan := make(chan error, 1)
   221  			sRunner := NewMockHandshakeRunner(mockCtrl)
   222  			sRunner.EXPECT().OnReceivedParams(gomock.Any())
   223  			sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise
   224  			sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1)
   225  			if serverTransportParameters.StatelessResetToken == nil {
   226  				var token protocol.StatelessResetToken
   227  				serverTransportParameters.StatelessResetToken = &token
   228  			}
   229  			server := NewCryptoSetupServer(
   230  				sInitialStream,
   231  				sHandshakeStream,
   232  				nil,
   233  				protocol.ConnectionID{},
   234  				serverTransportParameters,
   235  				sRunner,
   236  				serverConf,
   237  				enable0RTT,
   238  				serverRTTStats,
   239  				nil,
   240  				utils.DefaultLogger.WithPrefix("server"),
   241  				protocol.Version1,
   242  			)
   243  
   244  			handshake(client, cChunkChan, server, sChunkChan)
   245  			var cErr, sErr error
   246  			select {
   247  			case sErr = <-sErrChan:
   248  			default:
   249  				Expect(sHandshakeComplete).To(BeTrue())
   250  			}
   251  			select {
   252  			case cErr = <-cErrChan:
   253  			default:
   254  				Expect(cHandshakeComplete).To(BeTrue())
   255  			}
   256  			return clientHelloWrittenChan, client, cErr, server, sErr
   257  		}
   258  
   259  		It("handshakes", func() {
   260  			_, _, clientErr, _, serverErr := handshakeWithTLSConf(
   261  				clientConf, serverConf,
   262  				&utils.RTTStats{}, &utils.RTTStats{},
   263  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   264  				false,
   265  			)
   266  			Expect(clientErr).ToNot(HaveOccurred())
   267  			Expect(serverErr).ToNot(HaveOccurred())
   268  		})
   269  
   270  		It("performs a HelloRetryRequst", func() {
   271  			serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384}
   272  			_, _, clientErr, _, serverErr := handshakeWithTLSConf(
   273  				clientConf, serverConf,
   274  				&utils.RTTStats{}, &utils.RTTStats{},
   275  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   276  				false,
   277  			)
   278  			Expect(clientErr).ToNot(HaveOccurred())
   279  			Expect(serverErr).ToNot(HaveOccurred())
   280  		})
   281  
   282  		It("handshakes with client auth", func() {
   283  			clientConf.Certificates = []tls.Certificate{generateCert()}
   284  			serverConf.ClientAuth = tls.RequireAnyClientCert
   285  			_, _, clientErr, _, serverErr := handshakeWithTLSConf(
   286  				clientConf, serverConf,
   287  				&utils.RTTStats{}, &utils.RTTStats{},
   288  				&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   289  				false,
   290  			)
   291  			Expect(clientErr).ToNot(HaveOccurred())
   292  			Expect(serverErr).ToNot(HaveOccurred())
   293  		})
   294  
   295  		It("signals when it has written the ClientHello", func() {
   296  			runner := NewMockHandshakeRunner(mockCtrl)
   297  			cChunkChan, cInitialStream, cHandshakeStream := initStreams()
   298  			client, chChan := NewCryptoSetupClient(
   299  				cInitialStream,
   300  				cHandshakeStream,
   301  				nil,
   302  				protocol.ConnectionID{},
   303  				&wire.TransportParameters{},
   304  				runner,
   305  				&tls.Config{InsecureSkipVerify: true},
   306  				false,
   307  				&utils.RTTStats{},
   308  				nil,
   309  				utils.DefaultLogger.WithPrefix("client"),
   310  				protocol.Version1,
   311  			)
   312  
   313  			Expect(client.StartHandshake()).To(Succeed())
   314  			var ch chunk
   315  			Eventually(cChunkChan).Should(Receive(&ch))
   316  			Eventually(chChan).Should(Receive(BeNil()))
   317  			// make sure the whole ClientHello was written
   318  			Expect(len(ch.data)).To(BeNumerically(">=", 4))
   319  			Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello))
   320  			length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3])
   321  			Expect(len(ch.data) - 4).To(Equal(length))
   322  		})
   323  
   324  		It("receives transport parameters", func() {
   325  			var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters
   326  			cChunkChan, cInitialStream, cHandshakeStream := initStreams()
   327  			cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second}
   328  			cRunner := NewMockHandshakeRunner(mockCtrl)
   329  			cRunner.EXPECT().OnReceivedReadKeys().Times(2)
   330  			cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp })
   331  			cRunner.EXPECT().OnHandshakeComplete()
   332  			client, _ := NewCryptoSetupClient(
   333  				cInitialStream,
   334  				cHandshakeStream,
   335  				nil,
   336  				protocol.ConnectionID{},
   337  				cTransportParameters,
   338  				cRunner,
   339  				clientConf,
   340  				false,
   341  				&utils.RTTStats{},
   342  				nil,
   343  				utils.DefaultLogger.WithPrefix("client"),
   344  				protocol.Version1,
   345  			)
   346  
   347  			sChunkChan, sInitialStream, sHandshakeStream := initStreams()
   348  			var token protocol.StatelessResetToken
   349  			sRunner := NewMockHandshakeRunner(mockCtrl)
   350  			sRunner.EXPECT().OnReceivedReadKeys().Times(2)
   351  			sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp })
   352  			sRunner.EXPECT().OnHandshakeComplete()
   353  			sTransportParameters := &wire.TransportParameters{
   354  				MaxIdleTimeout:          0x1337 * time.Second,
   355  				StatelessResetToken:     &token,
   356  				ActiveConnectionIDLimit: 2,
   357  			}
   358  			server := NewCryptoSetupServer(
   359  				sInitialStream,
   360  				sHandshakeStream,
   361  				nil,
   362  				protocol.ConnectionID{},
   363  				sTransportParameters,
   364  				sRunner,
   365  				serverConf,
   366  				false,
   367  				&utils.RTTStats{},
   368  				nil,
   369  				utils.DefaultLogger.WithPrefix("server"),
   370  				protocol.Version1,
   371  			)
   372  
   373  			done := make(chan struct{})
   374  			go func() {
   375  				defer GinkgoRecover()
   376  				handshake(client, cChunkChan, server, sChunkChan)
   377  				close(done)
   378  			}()
   379  			Eventually(done).Should(BeClosed())
   380  			Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout))
   381  			Expect(sTransportParametersRcvd).ToNot(BeNil())
   382  			Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout))
   383  		})
   384  
   385  		Context("with session tickets", func() {
   386  			It("errors when the NewSessionTicket is sent at the wrong encryption level", func() {
   387  				cChunkChan, cInitialStream, cHandshakeStream := initStreams()
   388  				cRunner := NewMockHandshakeRunner(mockCtrl)
   389  				cRunner.EXPECT().OnReceivedParams(gomock.Any())
   390  				cRunner.EXPECT().OnReceivedReadKeys().Times(2)
   391  				cRunner.EXPECT().OnHandshakeComplete()
   392  				client, _ := NewCryptoSetupClient(
   393  					cInitialStream,
   394  					cHandshakeStream,
   395  					nil,
   396  					protocol.ConnectionID{},
   397  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   398  					cRunner,
   399  					clientConf,
   400  					false,
   401  					&utils.RTTStats{},
   402  					nil,
   403  					utils.DefaultLogger.WithPrefix("client"),
   404  					protocol.Version1,
   405  				)
   406  
   407  				sChunkChan, sInitialStream, sHandshakeStream := initStreams()
   408  				sRunner := NewMockHandshakeRunner(mockCtrl)
   409  				sRunner.EXPECT().OnReceivedParams(gomock.Any())
   410  				sRunner.EXPECT().OnReceivedReadKeys().Times(2)
   411  				sRunner.EXPECT().OnHandshakeComplete()
   412  				var token protocol.StatelessResetToken
   413  				server := NewCryptoSetupServer(
   414  					sInitialStream,
   415  					sHandshakeStream,
   416  					nil,
   417  					protocol.ConnectionID{},
   418  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
   419  					sRunner,
   420  					serverConf,
   421  					false,
   422  					&utils.RTTStats{},
   423  					nil,
   424  					utils.DefaultLogger.WithPrefix("server"),
   425  					protocol.Version1,
   426  				)
   427  
   428  				done := make(chan struct{})
   429  				go func() {
   430  					defer GinkgoRecover()
   431  					handshake(client, cChunkChan, server, sChunkChan)
   432  					close(done)
   433  				}()
   434  				Eventually(done).Should(BeClosed())
   435  
   436  				// inject an invalid session ticket
   437  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   438  				err := client.HandleMessage(b, protocol.EncryptionHandshake)
   439  				Expect(err).To(HaveOccurred())
   440  				Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level"))
   441  			})
   442  
   443  			It("errors when handling the NewSessionTicket fails", func() {
   444  				cChunkChan, cInitialStream, cHandshakeStream := initStreams()
   445  				cRunner := NewMockHandshakeRunner(mockCtrl)
   446  				cRunner.EXPECT().OnReceivedParams(gomock.Any())
   447  				cRunner.EXPECT().OnReceivedReadKeys().Times(2)
   448  				cRunner.EXPECT().OnHandshakeComplete()
   449  				client, _ := NewCryptoSetupClient(
   450  					cInitialStream,
   451  					cHandshakeStream,
   452  					nil,
   453  					protocol.ConnectionID{},
   454  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   455  					cRunner,
   456  					clientConf,
   457  					false,
   458  					&utils.RTTStats{},
   459  					nil,
   460  					utils.DefaultLogger.WithPrefix("client"),
   461  					protocol.Version1,
   462  				)
   463  
   464  				sChunkChan, sInitialStream, sHandshakeStream := initStreams()
   465  				sRunner := NewMockHandshakeRunner(mockCtrl)
   466  				sRunner.EXPECT().OnReceivedParams(gomock.Any())
   467  				sRunner.EXPECT().OnReceivedReadKeys().Times(2)
   468  				sRunner.EXPECT().OnHandshakeComplete()
   469  				var token protocol.StatelessResetToken
   470  				server := NewCryptoSetupServer(
   471  					sInitialStream,
   472  					sHandshakeStream,
   473  					nil,
   474  					protocol.ConnectionID{},
   475  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
   476  					sRunner,
   477  					serverConf,
   478  					false,
   479  					&utils.RTTStats{},
   480  					nil,
   481  					utils.DefaultLogger.WithPrefix("server"),
   482  					protocol.Version1,
   483  				)
   484  
   485  				done := make(chan struct{})
   486  				go func() {
   487  					defer GinkgoRecover()
   488  					handshake(client, cChunkChan, server, sChunkChan)
   489  					close(done)
   490  				}()
   491  				Eventually(done).Should(BeClosed())
   492  
   493  				// inject an invalid session ticket
   494  				b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...)
   495  				err := client.HandleMessage(b, protocol.Encryption1RTT)
   496  				Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   497  				Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue())
   498  			})
   499  
   500  			It("uses session resumption", func() {
   501  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   502  				var state *tls.ClientSessionState
   503  				receivedSessionTicket := make(chan struct{})
   504  				csc.EXPECT().Get(gomock.Any())
   505  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   506  					state = css
   507  					close(receivedSessionTicket)
   508  				})
   509  				clientConf.ClientSessionCache = csc
   510  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   511  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   512  				clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
   513  					clientConf, serverConf,
   514  					clientOrigRTTStats, &utils.RTTStats{},
   515  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   516  					false,
   517  				)
   518  				Expect(clientErr).ToNot(HaveOccurred())
   519  				Expect(serverErr).ToNot(HaveOccurred())
   520  				Eventually(receivedSessionTicket).Should(BeClosed())
   521  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   522  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   523  				Expect(clientHelloWrittenChan).To(Receive(BeNil()))
   524  
   525  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   526  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   527  				clientRTTStats := &utils.RTTStats{}
   528  				clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
   529  					clientConf, serverConf,
   530  					clientRTTStats, &utils.RTTStats{},
   531  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   532  					false,
   533  				)
   534  				Expect(clientErr).ToNot(HaveOccurred())
   535  				Expect(serverErr).ToNot(HaveOccurred())
   536  				Eventually(receivedSessionTicket).Should(BeClosed())
   537  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   538  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   539  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   540  				Expect(clientHelloWrittenChan).To(Receive(BeNil()))
   541  			})
   542  
   543  			It("doesn't use session resumption if the server disabled it", func() {
   544  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   545  				var state *tls.ClientSessionState
   546  				receivedSessionTicket := make(chan struct{})
   547  				csc.EXPECT().Get(gomock.Any())
   548  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   549  					state = css
   550  					close(receivedSessionTicket)
   551  				})
   552  				clientConf.ClientSessionCache = csc
   553  				_, client, clientErr, server, serverErr := handshakeWithTLSConf(
   554  					clientConf, serverConf,
   555  					&utils.RTTStats{}, &utils.RTTStats{},
   556  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   557  					false,
   558  				)
   559  				Expect(clientErr).ToNot(HaveOccurred())
   560  				Expect(serverErr).ToNot(HaveOccurred())
   561  				Eventually(receivedSessionTicket).Should(BeClosed())
   562  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   563  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   564  
   565  				serverConf.SessionTicketsDisabled = true
   566  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   567  				_, client, clientErr, server, serverErr = handshakeWithTLSConf(
   568  					clientConf, serverConf,
   569  					&utils.RTTStats{}, &utils.RTTStats{},
   570  					&wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2},
   571  					false,
   572  				)
   573  				Expect(clientErr).ToNot(HaveOccurred())
   574  				Expect(serverErr).ToNot(HaveOccurred())
   575  				Eventually(receivedSessionTicket).Should(BeClosed())
   576  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   577  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   578  			})
   579  
   580  			It("uses 0-RTT", func() {
   581  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   582  				var state *tls.ClientSessionState
   583  				receivedSessionTicket := make(chan struct{})
   584  				csc.EXPECT().Get(gomock.Any())
   585  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   586  					state = css
   587  					close(receivedSessionTicket)
   588  				})
   589  				clientConf.ClientSessionCache = csc
   590  				const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored.
   591  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   592  				serverOrigRTTStats := newRTTStatsWithRTT(serverRTT)
   593  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   594  				const initialMaxData protocol.ByteCount = 1337
   595  				clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
   596  					clientConf, serverConf,
   597  					clientOrigRTTStats, serverOrigRTTStats,
   598  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   599  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   600  					true,
   601  				)
   602  				Expect(clientErr).ToNot(HaveOccurred())
   603  				Expect(serverErr).ToNot(HaveOccurred())
   604  				Eventually(receivedSessionTicket).Should(BeClosed())
   605  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   606  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   607  				Expect(clientHelloWrittenChan).To(Receive(BeNil()))
   608  
   609  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   610  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   611  
   612  				clientRTTStats := &utils.RTTStats{}
   613  				serverRTTStats := &utils.RTTStats{}
   614  				clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
   615  					clientConf, serverConf,
   616  					clientRTTStats, serverRTTStats,
   617  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   618  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   619  					true,
   620  				)
   621  				Expect(clientErr).ToNot(HaveOccurred())
   622  				Expect(serverErr).ToNot(HaveOccurred())
   623  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   624  				Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT))
   625  
   626  				var tp *wire.TransportParameters
   627  				Expect(clientHelloWrittenChan).To(Receive(&tp))
   628  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   629  
   630  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   631  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   632  				Expect(server.ConnectionState().Used0RTT).To(BeTrue())
   633  				Expect(client.ConnectionState().Used0RTT).To(BeTrue())
   634  			})
   635  
   636  			It("rejects 0-RTT, when the transport parameters changed", func() {
   637  				csc := mocktls.NewMockClientSessionCache(mockCtrl)
   638  				var state *tls.ClientSessionState
   639  				receivedSessionTicket := make(chan struct{})
   640  				csc.EXPECT().Get(gomock.Any())
   641  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) {
   642  					state = css
   643  					close(receivedSessionTicket)
   644  				})
   645  				clientConf.ClientSessionCache = csc
   646  				const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored.
   647  				clientOrigRTTStats := newRTTStatsWithRTT(clientRTT)
   648  				const initialMaxData protocol.ByteCount = 1337
   649  				clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf(
   650  					clientConf, serverConf,
   651  					clientOrigRTTStats, &utils.RTTStats{},
   652  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   653  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData},
   654  					true,
   655  				)
   656  				Expect(clientErr).ToNot(HaveOccurred())
   657  				Expect(serverErr).ToNot(HaveOccurred())
   658  				Eventually(receivedSessionTicket).Should(BeClosed())
   659  				Expect(server.ConnectionState().DidResume).To(BeFalse())
   660  				Expect(client.ConnectionState().DidResume).To(BeFalse())
   661  				Expect(clientHelloWrittenChan).To(Receive(BeNil()))
   662  
   663  				csc.EXPECT().Get(gomock.Any()).Return(state, true)
   664  				csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1)
   665  
   666  				clientRTTStats := &utils.RTTStats{}
   667  				clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf(
   668  					clientConf, serverConf,
   669  					clientRTTStats, &utils.RTTStats{},
   670  					&wire.TransportParameters{ActiveConnectionIDLimit: 2},
   671  					&wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1},
   672  					true,
   673  				)
   674  				Expect(clientErr).ToNot(HaveOccurred())
   675  				Expect(serverErr).ToNot(HaveOccurred())
   676  				Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT))
   677  
   678  				var tp *wire.TransportParameters
   679  				Expect(clientHelloWrittenChan).To(Receive(&tp))
   680  				Expect(tp.InitialMaxData).To(Equal(initialMaxData))
   681  
   682  				Expect(server.ConnectionState().DidResume).To(BeTrue())
   683  				Expect(client.ConnectionState().DidResume).To(BeTrue())
   684  				Expect(server.ConnectionState().Used0RTT).To(BeFalse())
   685  				Expect(client.ConnectionState().Used0RTT).To(BeFalse())
   686  			})
   687  		})
   688  	})
   689  })