github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/internal/handshake/crypto_setup_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "crypto/x509/pkix" 9 "math/big" 10 "time" 11 12 mocktls "github.com/mikelsr/quic-go/internal/mocks/tls" 13 "github.com/mikelsr/quic-go/internal/protocol" 14 "github.com/mikelsr/quic-go/internal/qerr" 15 "github.com/mikelsr/quic-go/internal/testdata" 16 "github.com/mikelsr/quic-go/internal/utils" 17 "github.com/mikelsr/quic-go/internal/wire" 18 19 "github.com/golang/mock/gomock" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 ) 24 25 const ( 26 typeClientHello = 1 27 typeNewSessionTicket = 4 28 ) 29 30 type chunk struct { 31 data []byte 32 encLevel protocol.EncryptionLevel 33 } 34 35 type stream struct { 36 encLevel protocol.EncryptionLevel 37 chunkChan chan<- chunk 38 } 39 40 func newStream(chunkChan chan<- chunk, encLevel protocol.EncryptionLevel) *stream { 41 return &stream{ 42 chunkChan: chunkChan, 43 encLevel: encLevel, 44 } 45 } 46 47 func (s *stream) Write(b []byte) (int, error) { 48 data := make([]byte, len(b)) 49 copy(data, b) 50 select { 51 case s.chunkChan <- chunk{data: data, encLevel: s.encLevel}: 52 default: 53 panic("chunkChan too small") 54 } 55 return len(b), nil 56 } 57 58 var _ = Describe("Crypto Setup TLS", func() { 59 var clientConf, serverConf *tls.Config 60 61 // unparam incorrectly complains that the first argument is never used. 62 //nolint:unparam 63 initStreams := func() (chan chunk, *stream /* initial */, *stream /* handshake */) { 64 chunkChan := make(chan chunk, 100) 65 initialStream := newStream(chunkChan, protocol.EncryptionInitial) 66 handshakeStream := newStream(chunkChan, protocol.EncryptionHandshake) 67 return chunkChan, initialStream, handshakeStream 68 } 69 70 BeforeEach(func() { 71 serverConf = testdata.GetTLSConfig() 72 serverConf.NextProtos = []string{"crypto-setup"} 73 clientConf = &tls.Config{ 74 ServerName: "localhost", 75 RootCAs: testdata.GetRootCA(), 76 NextProtos: []string{"crypto-setup"}, 77 } 78 }) 79 80 It("handles qtls errors occurring before during ClientHello generation", func() { 81 _, sInitialStream, sHandshakeStream := initStreams() 82 tlsConf := testdata.GetTLSConfig() 83 tlsConf.InsecureSkipVerify = true 84 tlsConf.NextProtos = []string{""} 85 cl, _ := NewCryptoSetupClient( 86 sInitialStream, 87 sHandshakeStream, 88 nil, 89 protocol.ConnectionID{}, 90 &wire.TransportParameters{}, 91 NewMockHandshakeRunner(mockCtrl), 92 tlsConf, 93 false, 94 &utils.RTTStats{}, 95 nil, 96 utils.DefaultLogger.WithPrefix("client"), 97 protocol.Version1, 98 ) 99 100 Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ 101 ErrorCode: qerr.InternalError, 102 ErrorMessage: "tls: invalid NextProtos value", 103 })) 104 }) 105 106 It("errors when a message is received at the wrong encryption level", func() { 107 _, sInitialStream, sHandshakeStream := initStreams() 108 runner := NewMockHandshakeRunner(mockCtrl) 109 var token protocol.StatelessResetToken 110 server := NewCryptoSetupServer( 111 sInitialStream, 112 sHandshakeStream, 113 nil, 114 protocol.ConnectionID{}, 115 &wire.TransportParameters{StatelessResetToken: &token}, 116 runner, 117 testdata.GetTLSConfig(), 118 false, 119 &utils.RTTStats{}, 120 nil, 121 utils.DefaultLogger.WithPrefix("server"), 122 protocol.Version1, 123 ) 124 125 Expect(server.StartHandshake()).To(Succeed()) 126 127 fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) 128 // wrong encryption level 129 err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) 130 Expect(err).To(HaveOccurred()) 131 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 132 }) 133 134 Context("doing the handshake", func() { 135 generateCert := func() tls.Certificate { 136 priv, err := rsa.GenerateKey(rand.Reader, 2048) 137 Expect(err).ToNot(HaveOccurred()) 138 tmpl := &x509.Certificate{ 139 SerialNumber: big.NewInt(1), 140 Subject: pkix.Name{}, 141 SignatureAlgorithm: x509.SHA256WithRSA, 142 NotBefore: time.Now(), 143 NotAfter: time.Now().Add(time.Hour), // valid for an hour 144 BasicConstraintsValid: true, 145 } 146 certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) 147 Expect(err).ToNot(HaveOccurred()) 148 return tls.Certificate{ 149 PrivateKey: priv, 150 Certificate: [][]byte{certDER}, 151 } 152 } 153 154 newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { 155 rttStats := &utils.RTTStats{} 156 rttStats.UpdateRTT(rtt, 0, time.Now()) 157 ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) 158 return rttStats 159 } 160 161 handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) { 162 Expect(client.StartHandshake()).To(Succeed()) 163 Expect(server.StartHandshake()).To(Succeed()) 164 165 for { 166 select { 167 case c := <-cChunkChan: 168 Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed()) 169 continue 170 default: 171 } 172 select { 173 case c := <-sChunkChan: 174 Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed()) 175 continue 176 default: 177 } 178 // no more messages to send from client and server. Handshake complete? 179 break 180 } 181 182 ticket, err := server.GetSessionTicket() 183 Expect(err).ToNot(HaveOccurred()) 184 if ticket != nil { 185 Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) 186 } 187 } 188 189 handshakeWithTLSConf := func( 190 clientConf, serverConf *tls.Config, 191 clientRTTStats, serverRTTStats *utils.RTTStats, 192 clientTransportParameters, serverTransportParameters *wire.TransportParameters, 193 enable0RTT bool, 194 ) (<-chan *wire.TransportParameters /* clientHelloWrittenChan */, CryptoSetup /* client */, error /* client error */, CryptoSetup /* server */, error /* server error */) { 195 var cHandshakeComplete bool 196 cChunkChan, cInitialStream, cHandshakeStream := initStreams() 197 cErrChan := make(chan error, 1) 198 cRunner := NewMockHandshakeRunner(mockCtrl) 199 cRunner.EXPECT().OnReceivedParams(gomock.Any()) 200 cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise 201 cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) 202 cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) 203 client, clientHelloWrittenChan := NewCryptoSetupClient( 204 cInitialStream, 205 cHandshakeStream, 206 nil, 207 protocol.ConnectionID{}, 208 clientTransportParameters, 209 cRunner, 210 clientConf, 211 enable0RTT, 212 clientRTTStats, 213 nil, 214 utils.DefaultLogger.WithPrefix("client"), 215 protocol.Version1, 216 ) 217 218 var sHandshakeComplete bool 219 sChunkChan, sInitialStream, sHandshakeStream := initStreams() 220 sErrChan := make(chan error, 1) 221 sRunner := NewMockHandshakeRunner(mockCtrl) 222 sRunner.EXPECT().OnReceivedParams(gomock.Any()) 223 sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise 224 sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) 225 if serverTransportParameters.StatelessResetToken == nil { 226 var token protocol.StatelessResetToken 227 serverTransportParameters.StatelessResetToken = &token 228 } 229 server := NewCryptoSetupServer( 230 sInitialStream, 231 sHandshakeStream, 232 nil, 233 protocol.ConnectionID{}, 234 serverTransportParameters, 235 sRunner, 236 serverConf, 237 enable0RTT, 238 serverRTTStats, 239 nil, 240 utils.DefaultLogger.WithPrefix("server"), 241 protocol.Version1, 242 ) 243 244 handshake(client, cChunkChan, server, sChunkChan) 245 var cErr, sErr error 246 select { 247 case sErr = <-sErrChan: 248 default: 249 Expect(sHandshakeComplete).To(BeTrue()) 250 } 251 select { 252 case cErr = <-cErrChan: 253 default: 254 Expect(cHandshakeComplete).To(BeTrue()) 255 } 256 return clientHelloWrittenChan, client, cErr, server, sErr 257 } 258 259 It("handshakes", func() { 260 _, _, clientErr, _, serverErr := handshakeWithTLSConf( 261 clientConf, serverConf, 262 &utils.RTTStats{}, &utils.RTTStats{}, 263 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 264 false, 265 ) 266 Expect(clientErr).ToNot(HaveOccurred()) 267 Expect(serverErr).ToNot(HaveOccurred()) 268 }) 269 270 It("performs a HelloRetryRequst", func() { 271 serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} 272 _, _, clientErr, _, serverErr := handshakeWithTLSConf( 273 clientConf, serverConf, 274 &utils.RTTStats{}, &utils.RTTStats{}, 275 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 276 false, 277 ) 278 Expect(clientErr).ToNot(HaveOccurred()) 279 Expect(serverErr).ToNot(HaveOccurred()) 280 }) 281 282 It("handshakes with client auth", func() { 283 clientConf.Certificates = []tls.Certificate{generateCert()} 284 serverConf.ClientAuth = tls.RequireAnyClientCert 285 _, _, clientErr, _, serverErr := handshakeWithTLSConf( 286 clientConf, serverConf, 287 &utils.RTTStats{}, &utils.RTTStats{}, 288 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 289 false, 290 ) 291 Expect(clientErr).ToNot(HaveOccurred()) 292 Expect(serverErr).ToNot(HaveOccurred()) 293 }) 294 295 It("signals when it has written the ClientHello", func() { 296 runner := NewMockHandshakeRunner(mockCtrl) 297 cChunkChan, cInitialStream, cHandshakeStream := initStreams() 298 client, chChan := NewCryptoSetupClient( 299 cInitialStream, 300 cHandshakeStream, 301 nil, 302 protocol.ConnectionID{}, 303 &wire.TransportParameters{}, 304 runner, 305 &tls.Config{InsecureSkipVerify: true}, 306 false, 307 &utils.RTTStats{}, 308 nil, 309 utils.DefaultLogger.WithPrefix("client"), 310 protocol.Version1, 311 ) 312 313 Expect(client.StartHandshake()).To(Succeed()) 314 var ch chunk 315 Eventually(cChunkChan).Should(Receive(&ch)) 316 Eventually(chChan).Should(Receive(BeNil())) 317 // make sure the whole ClientHello was written 318 Expect(len(ch.data)).To(BeNumerically(">=", 4)) 319 Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello)) 320 length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) 321 Expect(len(ch.data) - 4).To(Equal(length)) 322 }) 323 324 It("receives transport parameters", func() { 325 var cTransportParametersRcvd, sTransportParametersRcvd *wire.TransportParameters 326 cChunkChan, cInitialStream, cHandshakeStream := initStreams() 327 cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second} 328 cRunner := NewMockHandshakeRunner(mockCtrl) 329 cRunner.EXPECT().OnReceivedReadKeys().Times(2) 330 cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) 331 cRunner.EXPECT().OnHandshakeComplete() 332 client, _ := NewCryptoSetupClient( 333 cInitialStream, 334 cHandshakeStream, 335 nil, 336 protocol.ConnectionID{}, 337 cTransportParameters, 338 cRunner, 339 clientConf, 340 false, 341 &utils.RTTStats{}, 342 nil, 343 utils.DefaultLogger.WithPrefix("client"), 344 protocol.Version1, 345 ) 346 347 sChunkChan, sInitialStream, sHandshakeStream := initStreams() 348 var token protocol.StatelessResetToken 349 sRunner := NewMockHandshakeRunner(mockCtrl) 350 sRunner.EXPECT().OnReceivedReadKeys().Times(2) 351 sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) 352 sRunner.EXPECT().OnHandshakeComplete() 353 sTransportParameters := &wire.TransportParameters{ 354 MaxIdleTimeout: 0x1337 * time.Second, 355 StatelessResetToken: &token, 356 ActiveConnectionIDLimit: 2, 357 } 358 server := NewCryptoSetupServer( 359 sInitialStream, 360 sHandshakeStream, 361 nil, 362 protocol.ConnectionID{}, 363 sTransportParameters, 364 sRunner, 365 serverConf, 366 false, 367 &utils.RTTStats{}, 368 nil, 369 utils.DefaultLogger.WithPrefix("server"), 370 protocol.Version1, 371 ) 372 373 done := make(chan struct{}) 374 go func() { 375 defer GinkgoRecover() 376 handshake(client, cChunkChan, server, sChunkChan) 377 close(done) 378 }() 379 Eventually(done).Should(BeClosed()) 380 Expect(cTransportParametersRcvd.MaxIdleTimeout).To(Equal(cTransportParameters.MaxIdleTimeout)) 381 Expect(sTransportParametersRcvd).ToNot(BeNil()) 382 Expect(sTransportParametersRcvd.MaxIdleTimeout).To(Equal(sTransportParameters.MaxIdleTimeout)) 383 }) 384 385 Context("with session tickets", func() { 386 It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { 387 cChunkChan, cInitialStream, cHandshakeStream := initStreams() 388 cRunner := NewMockHandshakeRunner(mockCtrl) 389 cRunner.EXPECT().OnReceivedParams(gomock.Any()) 390 cRunner.EXPECT().OnReceivedReadKeys().Times(2) 391 cRunner.EXPECT().OnHandshakeComplete() 392 client, _ := NewCryptoSetupClient( 393 cInitialStream, 394 cHandshakeStream, 395 nil, 396 protocol.ConnectionID{}, 397 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 398 cRunner, 399 clientConf, 400 false, 401 &utils.RTTStats{}, 402 nil, 403 utils.DefaultLogger.WithPrefix("client"), 404 protocol.Version1, 405 ) 406 407 sChunkChan, sInitialStream, sHandshakeStream := initStreams() 408 sRunner := NewMockHandshakeRunner(mockCtrl) 409 sRunner.EXPECT().OnReceivedParams(gomock.Any()) 410 sRunner.EXPECT().OnReceivedReadKeys().Times(2) 411 sRunner.EXPECT().OnHandshakeComplete() 412 var token protocol.StatelessResetToken 413 server := NewCryptoSetupServer( 414 sInitialStream, 415 sHandshakeStream, 416 nil, 417 protocol.ConnectionID{}, 418 &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, 419 sRunner, 420 serverConf, 421 false, 422 &utils.RTTStats{}, 423 nil, 424 utils.DefaultLogger.WithPrefix("server"), 425 protocol.Version1, 426 ) 427 428 done := make(chan struct{}) 429 go func() { 430 defer GinkgoRecover() 431 handshake(client, cChunkChan, server, sChunkChan) 432 close(done) 433 }() 434 Eventually(done).Should(BeClosed()) 435 436 // inject an invalid session ticket 437 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 438 err := client.HandleMessage(b, protocol.EncryptionHandshake) 439 Expect(err).To(HaveOccurred()) 440 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 441 }) 442 443 It("errors when handling the NewSessionTicket fails", func() { 444 cChunkChan, cInitialStream, cHandshakeStream := initStreams() 445 cRunner := NewMockHandshakeRunner(mockCtrl) 446 cRunner.EXPECT().OnReceivedParams(gomock.Any()) 447 cRunner.EXPECT().OnReceivedReadKeys().Times(2) 448 cRunner.EXPECT().OnHandshakeComplete() 449 client, _ := NewCryptoSetupClient( 450 cInitialStream, 451 cHandshakeStream, 452 nil, 453 protocol.ConnectionID{}, 454 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 455 cRunner, 456 clientConf, 457 false, 458 &utils.RTTStats{}, 459 nil, 460 utils.DefaultLogger.WithPrefix("client"), 461 protocol.Version1, 462 ) 463 464 sChunkChan, sInitialStream, sHandshakeStream := initStreams() 465 sRunner := NewMockHandshakeRunner(mockCtrl) 466 sRunner.EXPECT().OnReceivedParams(gomock.Any()) 467 sRunner.EXPECT().OnReceivedReadKeys().Times(2) 468 sRunner.EXPECT().OnHandshakeComplete() 469 var token protocol.StatelessResetToken 470 server := NewCryptoSetupServer( 471 sInitialStream, 472 sHandshakeStream, 473 nil, 474 protocol.ConnectionID{}, 475 &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, 476 sRunner, 477 serverConf, 478 false, 479 &utils.RTTStats{}, 480 nil, 481 utils.DefaultLogger.WithPrefix("server"), 482 protocol.Version1, 483 ) 484 485 done := make(chan struct{}) 486 go func() { 487 defer GinkgoRecover() 488 handshake(client, cChunkChan, server, sChunkChan) 489 close(done) 490 }() 491 Eventually(done).Should(BeClosed()) 492 493 // inject an invalid session ticket 494 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 495 err := client.HandleMessage(b, protocol.Encryption1RTT) 496 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 497 Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) 498 }) 499 500 It("uses session resumption", func() { 501 csc := mocktls.NewMockClientSessionCache(mockCtrl) 502 var state *tls.ClientSessionState 503 receivedSessionTicket := make(chan struct{}) 504 csc.EXPECT().Get(gomock.Any()) 505 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 506 state = css 507 close(receivedSessionTicket) 508 }) 509 clientConf.ClientSessionCache = csc 510 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 511 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 512 clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( 513 clientConf, serverConf, 514 clientOrigRTTStats, &utils.RTTStats{}, 515 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 516 false, 517 ) 518 Expect(clientErr).ToNot(HaveOccurred()) 519 Expect(serverErr).ToNot(HaveOccurred()) 520 Eventually(receivedSessionTicket).Should(BeClosed()) 521 Expect(server.ConnectionState().DidResume).To(BeFalse()) 522 Expect(client.ConnectionState().DidResume).To(BeFalse()) 523 Expect(clientHelloWrittenChan).To(Receive(BeNil())) 524 525 csc.EXPECT().Get(gomock.Any()).Return(state, true) 526 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 527 clientRTTStats := &utils.RTTStats{} 528 clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( 529 clientConf, serverConf, 530 clientRTTStats, &utils.RTTStats{}, 531 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 532 false, 533 ) 534 Expect(clientErr).ToNot(HaveOccurred()) 535 Expect(serverErr).ToNot(HaveOccurred()) 536 Eventually(receivedSessionTicket).Should(BeClosed()) 537 Expect(server.ConnectionState().DidResume).To(BeTrue()) 538 Expect(client.ConnectionState().DidResume).To(BeTrue()) 539 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 540 Expect(clientHelloWrittenChan).To(Receive(BeNil())) 541 }) 542 543 It("doesn't use session resumption if the server disabled it", func() { 544 csc := mocktls.NewMockClientSessionCache(mockCtrl) 545 var state *tls.ClientSessionState 546 receivedSessionTicket := make(chan struct{}) 547 csc.EXPECT().Get(gomock.Any()) 548 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 549 state = css 550 close(receivedSessionTicket) 551 }) 552 clientConf.ClientSessionCache = csc 553 _, client, clientErr, server, serverErr := handshakeWithTLSConf( 554 clientConf, serverConf, 555 &utils.RTTStats{}, &utils.RTTStats{}, 556 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 557 false, 558 ) 559 Expect(clientErr).ToNot(HaveOccurred()) 560 Expect(serverErr).ToNot(HaveOccurred()) 561 Eventually(receivedSessionTicket).Should(BeClosed()) 562 Expect(server.ConnectionState().DidResume).To(BeFalse()) 563 Expect(client.ConnectionState().DidResume).To(BeFalse()) 564 565 serverConf.SessionTicketsDisabled = true 566 csc.EXPECT().Get(gomock.Any()).Return(state, true) 567 _, client, clientErr, server, serverErr = handshakeWithTLSConf( 568 clientConf, serverConf, 569 &utils.RTTStats{}, &utils.RTTStats{}, 570 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 571 false, 572 ) 573 Expect(clientErr).ToNot(HaveOccurred()) 574 Expect(serverErr).ToNot(HaveOccurred()) 575 Eventually(receivedSessionTicket).Should(BeClosed()) 576 Expect(server.ConnectionState().DidResume).To(BeFalse()) 577 Expect(client.ConnectionState().DidResume).To(BeFalse()) 578 }) 579 580 It("uses 0-RTT", func() { 581 csc := mocktls.NewMockClientSessionCache(mockCtrl) 582 var state *tls.ClientSessionState 583 receivedSessionTicket := make(chan struct{}) 584 csc.EXPECT().Get(gomock.Any()) 585 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 586 state = css 587 close(receivedSessionTicket) 588 }) 589 clientConf.ClientSessionCache = csc 590 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 591 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 592 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 593 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 594 const initialMaxData protocol.ByteCount = 1337 595 clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( 596 clientConf, serverConf, 597 clientOrigRTTStats, serverOrigRTTStats, 598 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 599 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 600 true, 601 ) 602 Expect(clientErr).ToNot(HaveOccurred()) 603 Expect(serverErr).ToNot(HaveOccurred()) 604 Eventually(receivedSessionTicket).Should(BeClosed()) 605 Expect(server.ConnectionState().DidResume).To(BeFalse()) 606 Expect(client.ConnectionState().DidResume).To(BeFalse()) 607 Expect(clientHelloWrittenChan).To(Receive(BeNil())) 608 609 csc.EXPECT().Get(gomock.Any()).Return(state, true) 610 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 611 612 clientRTTStats := &utils.RTTStats{} 613 serverRTTStats := &utils.RTTStats{} 614 clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( 615 clientConf, serverConf, 616 clientRTTStats, serverRTTStats, 617 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 618 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 619 true, 620 ) 621 Expect(clientErr).ToNot(HaveOccurred()) 622 Expect(serverErr).ToNot(HaveOccurred()) 623 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 624 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 625 626 var tp *wire.TransportParameters 627 Expect(clientHelloWrittenChan).To(Receive(&tp)) 628 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 629 630 Expect(server.ConnectionState().DidResume).To(BeTrue()) 631 Expect(client.ConnectionState().DidResume).To(BeTrue()) 632 Expect(server.ConnectionState().Used0RTT).To(BeTrue()) 633 Expect(client.ConnectionState().Used0RTT).To(BeTrue()) 634 }) 635 636 It("rejects 0-RTT, when the transport parameters changed", func() { 637 csc := mocktls.NewMockClientSessionCache(mockCtrl) 638 var state *tls.ClientSessionState 639 receivedSessionTicket := make(chan struct{}) 640 csc.EXPECT().Get(gomock.Any()) 641 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 642 state = css 643 close(receivedSessionTicket) 644 }) 645 clientConf.ClientSessionCache = csc 646 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 647 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 648 const initialMaxData protocol.ByteCount = 1337 649 clientHelloWrittenChan, client, clientErr, server, serverErr := handshakeWithTLSConf( 650 clientConf, serverConf, 651 clientOrigRTTStats, &utils.RTTStats{}, 652 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 653 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 654 true, 655 ) 656 Expect(clientErr).ToNot(HaveOccurred()) 657 Expect(serverErr).ToNot(HaveOccurred()) 658 Eventually(receivedSessionTicket).Should(BeClosed()) 659 Expect(server.ConnectionState().DidResume).To(BeFalse()) 660 Expect(client.ConnectionState().DidResume).To(BeFalse()) 661 Expect(clientHelloWrittenChan).To(Receive(BeNil())) 662 663 csc.EXPECT().Get(gomock.Any()).Return(state, true) 664 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 665 666 clientRTTStats := &utils.RTTStats{} 667 clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( 668 clientConf, serverConf, 669 clientRTTStats, &utils.RTTStats{}, 670 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 671 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1}, 672 true, 673 ) 674 Expect(clientErr).ToNot(HaveOccurred()) 675 Expect(serverErr).ToNot(HaveOccurred()) 676 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 677 678 var tp *wire.TransportParameters 679 Expect(clientHelloWrittenChan).To(Receive(&tp)) 680 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 681 682 Expect(server.ConnectionState().DidResume).To(BeTrue()) 683 Expect(client.ConnectionState().DidResume).To(BeTrue()) 684 Expect(server.ConnectionState().Used0RTT).To(BeFalse()) 685 Expect(client.ConnectionState().Used0RTT).To(BeFalse()) 686 }) 687 }) 688 }) 689 })