github.com/quic-go/quic-go@v0.44.0/internal/handshake/crypto_setup_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "crypto/x509/pkix" 9 "math/big" 10 "net" 11 "reflect" 12 "time" 13 14 mocktls "github.com/quic-go/quic-go/internal/mocks/tls" 15 "github.com/quic-go/quic-go/internal/protocol" 16 "github.com/quic-go/quic-go/internal/qerr" 17 "github.com/quic-go/quic-go/internal/testdata" 18 "github.com/quic-go/quic-go/internal/utils" 19 "github.com/quic-go/quic-go/internal/wire" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 "go.uber.org/mock/gomock" 24 ) 25 26 const ( 27 typeClientHello = 1 28 typeNewSessionTicket = 4 29 ) 30 31 var _ = Describe("Crypto Setup TLS", func() { 32 generateCert := func() tls.Certificate { 33 priv, err := rsa.GenerateKey(rand.Reader, 2048) 34 Expect(err).ToNot(HaveOccurred()) 35 tmpl := &x509.Certificate{ 36 SerialNumber: big.NewInt(1), 37 Subject: pkix.Name{}, 38 SignatureAlgorithm: x509.SHA256WithRSA, 39 NotBefore: time.Now(), 40 NotAfter: time.Now().Add(time.Hour), // valid for an hour 41 BasicConstraintsValid: true, 42 } 43 certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) 44 Expect(err).ToNot(HaveOccurred()) 45 return tls.Certificate{ 46 PrivateKey: priv, 47 Certificate: [][]byte{certDER}, 48 } 49 } 50 51 var clientConf, serverConf *tls.Config 52 53 BeforeEach(func() { 54 serverConf = testdata.GetTLSConfig() 55 serverConf.NextProtos = []string{"crypto-setup"} 56 clientConf = &tls.Config{ 57 ServerName: "localhost", 58 RootCAs: testdata.GetRootCA(), 59 NextProtos: []string{"crypto-setup"}, 60 } 61 }) 62 63 It("handles qtls errors occurring before during ClientHello generation", func() { 64 tlsConf := testdata.GetTLSConfig() 65 tlsConf.InsecureSkipVerify = true 66 tlsConf.NextProtos = []string{""} 67 cl := NewCryptoSetupClient( 68 protocol.ConnectionID{}, 69 &wire.TransportParameters{}, 70 tlsConf, 71 false, 72 &utils.RTTStats{}, 73 nil, 74 utils.DefaultLogger.WithPrefix("client"), 75 protocol.Version1, 76 ) 77 78 Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ 79 ErrorCode: qerr.InternalError, 80 ErrorMessage: "tls: invalid NextProtos value", 81 })) 82 }) 83 84 It("errors when a message is received at the wrong encryption level", func() { 85 var token protocol.StatelessResetToken 86 server := NewCryptoSetupServer( 87 protocol.ConnectionID{}, 88 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 89 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 90 &wire.TransportParameters{StatelessResetToken: &token}, 91 testdata.GetTLSConfig(), 92 false, 93 &utils.RTTStats{}, 94 nil, 95 utils.DefaultLogger.WithPrefix("server"), 96 protocol.Version1, 97 ) 98 99 Expect(server.StartHandshake()).To(Succeed()) 100 101 fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) 102 // wrong encryption level 103 err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) 104 Expect(err).To(HaveOccurred()) 105 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 106 }) 107 108 Context("filling in a net.Conn in tls.ClientHelloInfo", func() { 109 var ( 110 local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} 111 remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 112 ) 113 114 It("wraps GetCertificate", func() { 115 var localAddr, remoteAddr net.Addr 116 tlsConf := &tls.Config{ 117 GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 118 localAddr = info.Conn.LocalAddr() 119 remoteAddr = info.Conn.RemoteAddr() 120 cert := generateCert() 121 return &cert, nil 122 }, 123 } 124 addConnToClientHelloInfo(tlsConf, local, remote) 125 _, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{}) 126 Expect(err).ToNot(HaveOccurred()) 127 Expect(localAddr).To(Equal(local)) 128 Expect(remoteAddr).To(Equal(remote)) 129 }) 130 131 It("wraps GetConfigForClient", func() { 132 var localAddr, remoteAddr net.Addr 133 tlsConf := &tls.Config{ 134 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 135 localAddr = info.Conn.LocalAddr() 136 remoteAddr = info.Conn.RemoteAddr() 137 return &tls.Config{}, nil 138 }, 139 } 140 addConnToClientHelloInfo(tlsConf, local, remote) 141 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 142 Expect(err).ToNot(HaveOccurred()) 143 Expect(localAddr).To(Equal(local)) 144 Expect(remoteAddr).To(Equal(remote)) 145 Expect(conf).ToNot(BeNil()) 146 Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) 147 }) 148 149 It("wraps GetConfigForClient, recursively", func() { 150 var localAddr, remoteAddr net.Addr 151 tlsConf := &tls.Config{} 152 var innerConf *tls.Config 153 getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam 154 localAddr = info.Conn.LocalAddr() 155 remoteAddr = info.Conn.RemoteAddr() 156 cert := generateCert() 157 return &cert, nil 158 } 159 tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { 160 innerConf = tlsConf.Clone() 161 // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config 162 innerConf.MaxVersion = tls.VersionTLS12 163 innerConf.GetCertificate = getCert 164 return innerConf, nil 165 } 166 addConnToClientHelloInfo(tlsConf, local, remote) 167 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 168 Expect(err).ToNot(HaveOccurred()) 169 Expect(conf).ToNot(BeNil()) 170 Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) 171 _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) 172 Expect(err).ToNot(HaveOccurred()) 173 Expect(localAddr).To(Equal(local)) 174 Expect(remoteAddr).To(Equal(remote)) 175 // make sure that the tls.Config returned by GetConfigForClient isn't modified 176 Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) 177 Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) 178 }) 179 }) 180 181 Context("doing the handshake", func() { 182 newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { 183 rttStats := &utils.RTTStats{} 184 rttStats.UpdateRTT(rtt, 0, time.Now()) 185 ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) 186 return rttStats 187 } 188 189 // The clientEvents and serverEvents contain all events that were not processed by the function, 190 // i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete. 191 handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) { 192 Expect(client.StartHandshake()).To(Succeed()) 193 Expect(server.StartHandshake()).To(Succeed()) 194 195 var clientHandshakeComplete, serverHandshakeComplete bool 196 197 for { 198 clientLoop: 199 for { 200 ev := client.NextEvent() 201 switch ev.Kind { 202 case EventNoEvent: 203 break clientLoop 204 case EventWriteInitialData: 205 if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 206 serverErr = err 207 return 208 } 209 case EventWriteHandshakeData: 210 if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 211 serverErr = err 212 return 213 } 214 case EventHandshakeComplete: 215 clientHandshakeComplete = true 216 default: 217 clientEvents = append(clientEvents, ev) 218 } 219 } 220 221 serverLoop: 222 for { 223 ev := server.NextEvent() 224 switch ev.Kind { 225 case EventNoEvent: 226 break serverLoop 227 case EventWriteInitialData: 228 if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 229 clientErr = err 230 return 231 } 232 case EventWriteHandshakeData: 233 if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 234 clientErr = err 235 return 236 } 237 case EventHandshakeComplete: 238 serverHandshakeComplete = true 239 ticket, err := server.GetSessionTicket() 240 Expect(err).ToNot(HaveOccurred()) 241 if ticket != nil { 242 Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) 243 } 244 default: 245 serverEvents = append(serverEvents, ev) 246 } 247 } 248 249 if clientHandshakeComplete && serverHandshakeComplete { 250 break 251 } 252 } 253 return 254 } 255 256 handshakeWithTLSConf := func( 257 clientConf, serverConf *tls.Config, 258 clientRTTStats, serverRTTStats *utils.RTTStats, 259 clientTransportParameters, serverTransportParameters *wire.TransportParameters, 260 enable0RTT bool, 261 ) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */ 262 CryptoSetup /* server */, []Event /* more server events */, error, /* server error */ 263 ) { 264 client := NewCryptoSetupClient( 265 protocol.ConnectionID{}, 266 clientTransportParameters, 267 clientConf, 268 enable0RTT, 269 clientRTTStats, 270 nil, 271 utils.DefaultLogger.WithPrefix("client"), 272 protocol.Version1, 273 ) 274 275 if serverTransportParameters.StatelessResetToken == nil { 276 var token protocol.StatelessResetToken 277 serverTransportParameters.StatelessResetToken = &token 278 } 279 server := NewCryptoSetupServer( 280 protocol.ConnectionID{}, 281 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 282 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 283 serverTransportParameters, 284 serverConf, 285 enable0RTT, 286 serverRTTStats, 287 nil, 288 utils.DefaultLogger.WithPrefix("server"), 289 protocol.Version1, 290 ) 291 cEvents, cErr, sEvents, sErr := handshake(client, server) 292 return client, cEvents, cErr, server, sEvents, sErr 293 } 294 295 It("handshakes", func() { 296 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 297 clientConf, serverConf, 298 &utils.RTTStats{}, &utils.RTTStats{}, 299 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 300 false, 301 ) 302 Expect(clientErr).ToNot(HaveOccurred()) 303 Expect(serverErr).ToNot(HaveOccurred()) 304 }) 305 306 It("performs a HelloRetryRequst", func() { 307 serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} 308 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 309 clientConf, serverConf, 310 &utils.RTTStats{}, &utils.RTTStats{}, 311 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 312 false, 313 ) 314 Expect(clientErr).ToNot(HaveOccurred()) 315 Expect(serverErr).ToNot(HaveOccurred()) 316 }) 317 318 It("handshakes with client auth", func() { 319 clientConf.Certificates = []tls.Certificate{generateCert()} 320 serverConf.ClientAuth = tls.RequireAnyClientCert 321 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 322 clientConf, serverConf, 323 &utils.RTTStats{}, &utils.RTTStats{}, 324 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 325 false, 326 ) 327 Expect(clientErr).ToNot(HaveOccurred()) 328 Expect(serverErr).ToNot(HaveOccurred()) 329 }) 330 331 It("receives transport parameters", func() { 332 cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second} 333 client := NewCryptoSetupClient( 334 protocol.ConnectionID{}, 335 cTransportParameters, 336 clientConf, 337 false, 338 &utils.RTTStats{}, 339 nil, 340 utils.DefaultLogger.WithPrefix("client"), 341 protocol.Version1, 342 ) 343 344 var token protocol.StatelessResetToken 345 sTransportParameters := &wire.TransportParameters{ 346 MaxIdleTimeout: 1337 * time.Second, 347 StatelessResetToken: &token, 348 ActiveConnectionIDLimit: 2, 349 } 350 server := NewCryptoSetupServer( 351 protocol.ConnectionID{}, 352 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 353 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 354 sTransportParameters, 355 serverConf, 356 false, 357 &utils.RTTStats{}, 358 nil, 359 utils.DefaultLogger.WithPrefix("server"), 360 protocol.Version1, 361 ) 362 363 clientEvents, cErr, serverEvents, sErr := handshake(client, server) 364 Expect(cErr).ToNot(HaveOccurred()) 365 Expect(sErr).ToNot(HaveOccurred()) 366 var clientReceivedTransportParameters *wire.TransportParameters 367 for _, ev := range clientEvents { 368 if ev.Kind == EventReceivedTransportParameters { 369 clientReceivedTransportParameters = ev.TransportParameters 370 } 371 } 372 Expect(clientReceivedTransportParameters).ToNot(BeNil()) 373 Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second)) 374 375 var serverReceivedTransportParameters *wire.TransportParameters 376 for _, ev := range serverEvents { 377 if ev.Kind == EventReceivedTransportParameters { 378 serverReceivedTransportParameters = ev.TransportParameters 379 } 380 } 381 Expect(serverReceivedTransportParameters).ToNot(BeNil()) 382 Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second)) 383 }) 384 385 Context("with session tickets", func() { 386 It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { 387 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 388 clientConf, serverConf, 389 &utils.RTTStats{}, &utils.RTTStats{}, 390 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 391 false, 392 ) 393 Expect(clientErr).ToNot(HaveOccurred()) 394 Expect(serverErr).ToNot(HaveOccurred()) 395 396 // inject an invalid session ticket 397 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 398 err := client.HandleMessage(b, protocol.EncryptionHandshake) 399 Expect(err).To(HaveOccurred()) 400 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 401 }) 402 403 It("errors when handling the NewSessionTicket fails", func() { 404 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 405 clientConf, serverConf, 406 &utils.RTTStats{}, &utils.RTTStats{}, 407 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 408 false, 409 ) 410 Expect(clientErr).ToNot(HaveOccurred()) 411 Expect(serverErr).ToNot(HaveOccurred()) 412 413 // inject an invalid session ticket 414 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 415 err := client.HandleMessage(b, protocol.Encryption1RTT) 416 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 417 Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) 418 }) 419 420 It("uses session resumption", func() { 421 csc := mocktls.NewMockClientSessionCache(mockCtrl) 422 var state *tls.ClientSessionState 423 receivedSessionTicket := make(chan struct{}) 424 csc.EXPECT().Get(gomock.Any()) 425 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 426 state = css 427 close(receivedSessionTicket) 428 }) 429 clientConf.ClientSessionCache = csc 430 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 431 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 432 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 433 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 434 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 435 clientConf, serverConf, 436 clientOrigRTTStats, serverOrigRTTStats, 437 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 438 false, 439 ) 440 Expect(clientErr).ToNot(HaveOccurred()) 441 Expect(serverErr).ToNot(HaveOccurred()) 442 Eventually(receivedSessionTicket).Should(BeClosed()) 443 Expect(server.ConnectionState().DidResume).To(BeFalse()) 444 Expect(client.ConnectionState().DidResume).To(BeFalse()) 445 446 csc.EXPECT().Get(gomock.Any()).Return(state, true) 447 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 448 clientRTTStats := &utils.RTTStats{} 449 serverRTTStats := &utils.RTTStats{} 450 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 451 clientConf, serverConf, 452 clientRTTStats, serverRTTStats, 453 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 454 false, 455 ) 456 Expect(clientErr).ToNot(HaveOccurred()) 457 Expect(serverErr).ToNot(HaveOccurred()) 458 Eventually(receivedSessionTicket).Should(BeClosed()) 459 Expect(server.ConnectionState().DidResume).To(BeTrue()) 460 Expect(client.ConnectionState().DidResume).To(BeTrue()) 461 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 462 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 463 }) 464 465 It("doesn't use session resumption if the server disabled it", func() { 466 csc := mocktls.NewMockClientSessionCache(mockCtrl) 467 var state *tls.ClientSessionState 468 receivedSessionTicket := make(chan struct{}) 469 csc.EXPECT().Get(gomock.Any()) 470 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 471 state = css 472 close(receivedSessionTicket) 473 }) 474 clientConf.ClientSessionCache = csc 475 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 476 clientConf, serverConf, 477 &utils.RTTStats{}, &utils.RTTStats{}, 478 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 479 false, 480 ) 481 Expect(clientErr).ToNot(HaveOccurred()) 482 Expect(serverErr).ToNot(HaveOccurred()) 483 Eventually(receivedSessionTicket).Should(BeClosed()) 484 Expect(server.ConnectionState().DidResume).To(BeFalse()) 485 Expect(client.ConnectionState().DidResume).To(BeFalse()) 486 487 serverConf.SessionTicketsDisabled = true 488 csc.EXPECT().Get(gomock.Any()).Return(state, true) 489 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 490 clientConf, serverConf, 491 &utils.RTTStats{}, &utils.RTTStats{}, 492 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 493 false, 494 ) 495 Expect(clientErr).ToNot(HaveOccurred()) 496 Expect(serverErr).ToNot(HaveOccurred()) 497 Eventually(receivedSessionTicket).Should(BeClosed()) 498 Expect(server.ConnectionState().DidResume).To(BeFalse()) 499 Expect(client.ConnectionState().DidResume).To(BeFalse()) 500 }) 501 502 It("uses 0-RTT", func() { 503 csc := mocktls.NewMockClientSessionCache(mockCtrl) 504 var state *tls.ClientSessionState 505 receivedSessionTicket := make(chan struct{}) 506 csc.EXPECT().Get(gomock.Any()) 507 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 508 state = css 509 close(receivedSessionTicket) 510 }) 511 clientConf.ClientSessionCache = csc 512 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 513 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 514 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 515 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 516 const initialMaxData protocol.ByteCount = 1337 517 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 518 clientConf, serverConf, 519 clientOrigRTTStats, serverOrigRTTStats, 520 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 521 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 522 true, 523 ) 524 Expect(clientErr).ToNot(HaveOccurred()) 525 Expect(serverErr).ToNot(HaveOccurred()) 526 Eventually(receivedSessionTicket).Should(BeClosed()) 527 Expect(server.ConnectionState().DidResume).To(BeFalse()) 528 Expect(client.ConnectionState().DidResume).To(BeFalse()) 529 530 csc.EXPECT().Get(gomock.Any()).Return(state, true) 531 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 532 533 clientRTTStats := &utils.RTTStats{} 534 serverRTTStats := &utils.RTTStats{} 535 client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf( 536 clientConf, serverConf, 537 clientRTTStats, serverRTTStats, 538 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 539 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 540 true, 541 ) 542 Expect(clientErr).ToNot(HaveOccurred()) 543 Expect(serverErr).ToNot(HaveOccurred()) 544 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 545 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 546 547 var tp *wire.TransportParameters 548 var clientReceived0RTTKeys bool 549 for _, ev := range clientEvents { 550 switch ev.Kind { 551 case EventRestoredTransportParameters: 552 tp = ev.TransportParameters 553 case EventReceivedReadKeys: 554 clientReceived0RTTKeys = true 555 } 556 } 557 Expect(clientReceived0RTTKeys).To(BeTrue()) 558 Expect(tp).ToNot(BeNil()) 559 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 560 561 var serverReceived0RTTKeys bool 562 for _, ev := range serverEvents { 563 switch ev.Kind { 564 case EventReceivedReadKeys: 565 serverReceived0RTTKeys = true 566 } 567 } 568 Expect(serverReceived0RTTKeys).To(BeTrue()) 569 570 Expect(server.ConnectionState().DidResume).To(BeTrue()) 571 Expect(client.ConnectionState().DidResume).To(BeTrue()) 572 Expect(server.ConnectionState().Used0RTT).To(BeTrue()) 573 Expect(client.ConnectionState().Used0RTT).To(BeTrue()) 574 }) 575 576 It("rejects 0-RTT, when the transport parameters changed", func() { 577 csc := mocktls.NewMockClientSessionCache(mockCtrl) 578 var state *tls.ClientSessionState 579 receivedSessionTicket := make(chan struct{}) 580 csc.EXPECT().Get(gomock.Any()) 581 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 582 state = css 583 close(receivedSessionTicket) 584 }) 585 clientConf.ClientSessionCache = csc 586 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 587 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 588 const initialMaxData protocol.ByteCount = 1337 589 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 590 clientConf, serverConf, 591 clientOrigRTTStats, &utils.RTTStats{}, 592 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 593 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 594 true, 595 ) 596 Expect(clientErr).ToNot(HaveOccurred()) 597 Expect(serverErr).ToNot(HaveOccurred()) 598 Eventually(receivedSessionTicket).Should(BeClosed()) 599 Expect(server.ConnectionState().DidResume).To(BeFalse()) 600 Expect(client.ConnectionState().DidResume).To(BeFalse()) 601 602 csc.EXPECT().Get(gomock.Any()).Return(state, true) 603 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 604 605 clientRTTStats := &utils.RTTStats{} 606 client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf( 607 clientConf, serverConf, 608 clientRTTStats, &utils.RTTStats{}, 609 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 610 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1}, 611 true, 612 ) 613 Expect(clientErr).ToNot(HaveOccurred()) 614 Expect(serverErr).ToNot(HaveOccurred()) 615 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 616 617 var tp *wire.TransportParameters 618 var clientReceived0RTTKeys bool 619 for _, ev := range clientEvents { 620 switch ev.Kind { 621 case EventRestoredTransportParameters: 622 tp = ev.TransportParameters 623 case EventReceivedReadKeys: 624 clientReceived0RTTKeys = true 625 } 626 } 627 Expect(clientReceived0RTTKeys).To(BeTrue()) 628 Expect(tp).ToNot(BeNil()) 629 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 630 631 Expect(server.ConnectionState().DidResume).To(BeTrue()) 632 Expect(client.ConnectionState().DidResume).To(BeTrue()) 633 Expect(server.ConnectionState().Used0RTT).To(BeFalse()) 634 Expect(client.ConnectionState().Used0RTT).To(BeFalse()) 635 }) 636 }) 637 }) 638 })