github.com/MerlinKodo/quic-go@v0.39.2/internal/handshake/crypto_setup_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "crypto/x509/pkix" 9 "math/big" 10 "net" 11 "runtime" 12 "strings" 13 "time" 14 15 mocktls "github.com/MerlinKodo/quic-go/internal/mocks/tls" 16 "github.com/MerlinKodo/quic-go/internal/protocol" 17 "github.com/MerlinKodo/quic-go/internal/qerr" 18 "github.com/MerlinKodo/quic-go/internal/testdata" 19 "github.com/MerlinKodo/quic-go/internal/utils" 20 "github.com/MerlinKodo/quic-go/internal/wire" 21 22 . "github.com/onsi/ginkgo/v2" 23 . "github.com/onsi/gomega" 24 "go.uber.org/mock/gomock" 25 ) 26 27 const ( 28 typeClientHello = 1 29 typeNewSessionTicket = 4 30 ) 31 32 var _ = Describe("Crypto Setup TLS", func() { 33 generateCert := func() tls.Certificate { 34 priv, err := rsa.GenerateKey(rand.Reader, 2048) 35 Expect(err).ToNot(HaveOccurred()) 36 tmpl := &x509.Certificate{ 37 SerialNumber: big.NewInt(1), 38 Subject: pkix.Name{}, 39 SignatureAlgorithm: x509.SHA256WithRSA, 40 NotBefore: time.Now(), 41 NotAfter: time.Now().Add(time.Hour), // valid for an hour 42 BasicConstraintsValid: true, 43 } 44 certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) 45 Expect(err).ToNot(HaveOccurred()) 46 return tls.Certificate{ 47 PrivateKey: priv, 48 Certificate: [][]byte{certDER}, 49 } 50 } 51 52 var clientConf, serverConf *tls.Config 53 54 BeforeEach(func() { 55 serverConf = testdata.GetTLSConfig() 56 serverConf.NextProtos = []string{"crypto-setup"} 57 clientConf = &tls.Config{ 58 ServerName: "localhost", 59 RootCAs: testdata.GetRootCA(), 60 NextProtos: []string{"crypto-setup"}, 61 } 62 }) 63 64 It("handles qtls errors occurring before during ClientHello generation", func() { 65 tlsConf := testdata.GetTLSConfig() 66 tlsConf.InsecureSkipVerify = true 67 tlsConf.NextProtos = []string{""} 68 cl := NewCryptoSetupClient( 69 protocol.ConnectionID{}, 70 &wire.TransportParameters{}, 71 tlsConf, 72 false, 73 &utils.RTTStats{}, 74 nil, 75 utils.DefaultLogger.WithPrefix("client"), 76 protocol.Version1, 77 ) 78 79 Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ 80 ErrorCode: qerr.InternalError, 81 ErrorMessage: "tls: invalid NextProtos value", 82 })) 83 }) 84 85 It("errors when a message is received at the wrong encryption level", func() { 86 var token protocol.StatelessResetToken 87 server := NewCryptoSetupServer( 88 protocol.ConnectionID{}, 89 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 90 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 91 &wire.TransportParameters{StatelessResetToken: &token}, 92 testdata.GetTLSConfig(), 93 false, 94 &utils.RTTStats{}, 95 nil, 96 utils.DefaultLogger.WithPrefix("server"), 97 protocol.Version1, 98 ) 99 100 Expect(server.StartHandshake()).To(Succeed()) 101 102 fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) 103 // wrong encryption level 104 err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) 105 Expect(err).To(HaveOccurred()) 106 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 107 }) 108 109 Context("filling in a net.Conn in tls.ClientHelloInfo", func() { 110 var ( 111 local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} 112 remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 113 ) 114 115 It("wraps GetCertificate", func() { 116 var localAddr, remoteAddr net.Addr 117 tlsConf := &tls.Config{ 118 GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 119 localAddr = info.Conn.LocalAddr() 120 remoteAddr = info.Conn.RemoteAddr() 121 cert := generateCert() 122 return &cert, nil 123 }, 124 } 125 addConnToClientHelloInfo(tlsConf, local, remote) 126 _, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{}) 127 Expect(err).ToNot(HaveOccurred()) 128 Expect(localAddr).To(Equal(local)) 129 Expect(remoteAddr).To(Equal(remote)) 130 }) 131 132 It("wraps GetConfigForClient", func() { 133 var localAddr, remoteAddr net.Addr 134 tlsConf := &tls.Config{ 135 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 136 localAddr = info.Conn.LocalAddr() 137 remoteAddr = info.Conn.RemoteAddr() 138 return &tls.Config{}, nil 139 }, 140 } 141 addConnToClientHelloInfo(tlsConf, local, remote) 142 _, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 143 Expect(err).ToNot(HaveOccurred()) 144 Expect(localAddr).To(Equal(local)) 145 Expect(remoteAddr).To(Equal(remote)) 146 }) 147 148 It("wraps GetConfigForClient, recursively", func() { 149 var localAddr, remoteAddr net.Addr 150 tlsConf := &tls.Config{} 151 tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { 152 conf := tlsConf.Clone() 153 conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 154 localAddr = info.Conn.LocalAddr() 155 remoteAddr = info.Conn.RemoteAddr() 156 cert := generateCert() 157 return &cert, nil 158 } 159 return conf, nil 160 } 161 addConnToClientHelloInfo(tlsConf, local, remote) 162 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 163 Expect(err).ToNot(HaveOccurred()) 164 _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) 165 Expect(err).ToNot(HaveOccurred()) 166 Expect(localAddr).To(Equal(local)) 167 Expect(remoteAddr).To(Equal(remote)) 168 }) 169 }) 170 171 Context("doing the handshake", func() { 172 newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { 173 rttStats := &utils.RTTStats{} 174 rttStats.UpdateRTT(rtt, 0, time.Now()) 175 ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) 176 return rttStats 177 } 178 179 // The clientEvents and serverEvents contain all events that were not processed by the function, 180 // i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete. 181 handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) { 182 Expect(client.StartHandshake()).To(Succeed()) 183 Expect(server.StartHandshake()).To(Succeed()) 184 185 var clientHandshakeComplete, serverHandshakeComplete bool 186 187 for { 188 clientLoop: 189 for { 190 ev := client.NextEvent() 191 //nolint:exhaustive // only need to process a few events 192 switch ev.Kind { 193 case EventNoEvent: 194 break clientLoop 195 case EventWriteInitialData: 196 if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 197 serverErr = err 198 return 199 } 200 case EventWriteHandshakeData: 201 if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 202 serverErr = err 203 return 204 } 205 case EventHandshakeComplete: 206 clientHandshakeComplete = true 207 default: 208 clientEvents = append(clientEvents, ev) 209 } 210 } 211 212 serverLoop: 213 for { 214 ev := server.NextEvent() 215 //nolint:exhaustive // only need to process a few events 216 switch ev.Kind { 217 case EventNoEvent: 218 break serverLoop 219 case EventWriteInitialData: 220 if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 221 clientErr = err 222 return 223 } 224 case EventWriteHandshakeData: 225 if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 226 clientErr = err 227 return 228 } 229 case EventHandshakeComplete: 230 serverHandshakeComplete = true 231 ticket, err := server.GetSessionTicket() 232 Expect(err).ToNot(HaveOccurred()) 233 if ticket != nil { 234 Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) 235 } 236 default: 237 serverEvents = append(serverEvents, ev) 238 } 239 } 240 241 if clientHandshakeComplete && serverHandshakeComplete { 242 break 243 } 244 } 245 return 246 } 247 248 handshakeWithTLSConf := func( 249 clientConf, serverConf *tls.Config, 250 clientRTTStats, serverRTTStats *utils.RTTStats, 251 clientTransportParameters, serverTransportParameters *wire.TransportParameters, 252 enable0RTT bool, 253 ) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */ 254 CryptoSetup /* server */, []Event /* more server events */, error, /* server error */ 255 ) { 256 client := NewCryptoSetupClient( 257 protocol.ConnectionID{}, 258 clientTransportParameters, 259 clientConf, 260 enable0RTT, 261 clientRTTStats, 262 nil, 263 utils.DefaultLogger.WithPrefix("client"), 264 protocol.Version1, 265 ) 266 267 if serverTransportParameters.StatelessResetToken == nil { 268 var token protocol.StatelessResetToken 269 serverTransportParameters.StatelessResetToken = &token 270 } 271 server := NewCryptoSetupServer( 272 protocol.ConnectionID{}, 273 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 274 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 275 serverTransportParameters, 276 serverConf, 277 enable0RTT, 278 serverRTTStats, 279 nil, 280 utils.DefaultLogger.WithPrefix("server"), 281 protocol.Version1, 282 ) 283 cEvents, cErr, sEvents, sErr := handshake(client, server) 284 return client, cEvents, cErr, server, sEvents, sErr 285 } 286 287 It("handshakes", func() { 288 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 289 clientConf, serverConf, 290 &utils.RTTStats{}, &utils.RTTStats{}, 291 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 292 false, 293 ) 294 Expect(clientErr).ToNot(HaveOccurred()) 295 Expect(serverErr).ToNot(HaveOccurred()) 296 }) 297 298 It("performs a HelloRetryRequst", func() { 299 serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} 300 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 301 clientConf, serverConf, 302 &utils.RTTStats{}, &utils.RTTStats{}, 303 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 304 false, 305 ) 306 Expect(clientErr).ToNot(HaveOccurred()) 307 Expect(serverErr).ToNot(HaveOccurred()) 308 }) 309 310 It("handshakes with client auth", func() { 311 clientConf.Certificates = []tls.Certificate{generateCert()} 312 serverConf.ClientAuth = tls.RequireAnyClientCert 313 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 314 clientConf, serverConf, 315 &utils.RTTStats{}, &utils.RTTStats{}, 316 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 317 false, 318 ) 319 Expect(clientErr).ToNot(HaveOccurred()) 320 Expect(serverErr).ToNot(HaveOccurred()) 321 }) 322 323 It("receives transport parameters", func() { 324 cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second} 325 client := NewCryptoSetupClient( 326 protocol.ConnectionID{}, 327 cTransportParameters, 328 clientConf, 329 false, 330 &utils.RTTStats{}, 331 nil, 332 utils.DefaultLogger.WithPrefix("client"), 333 protocol.Version1, 334 ) 335 336 var token protocol.StatelessResetToken 337 sTransportParameters := &wire.TransportParameters{ 338 MaxIdleTimeout: 1337 * time.Second, 339 StatelessResetToken: &token, 340 ActiveConnectionIDLimit: 2, 341 } 342 server := NewCryptoSetupServer( 343 protocol.ConnectionID{}, 344 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 345 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 346 sTransportParameters, 347 serverConf, 348 false, 349 &utils.RTTStats{}, 350 nil, 351 utils.DefaultLogger.WithPrefix("server"), 352 protocol.Version1, 353 ) 354 355 clientEvents, cErr, serverEvents, sErr := handshake(client, server) 356 Expect(cErr).ToNot(HaveOccurred()) 357 Expect(sErr).ToNot(HaveOccurred()) 358 var clientReceivedTransportParameters *wire.TransportParameters 359 for _, ev := range clientEvents { 360 if ev.Kind == EventReceivedTransportParameters { 361 clientReceivedTransportParameters = ev.TransportParameters 362 } 363 } 364 Expect(clientReceivedTransportParameters).ToNot(BeNil()) 365 Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second)) 366 367 var serverReceivedTransportParameters *wire.TransportParameters 368 for _, ev := range serverEvents { 369 if ev.Kind == EventReceivedTransportParameters { 370 serverReceivedTransportParameters = ev.TransportParameters 371 } 372 } 373 Expect(serverReceivedTransportParameters).ToNot(BeNil()) 374 Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second)) 375 }) 376 377 Context("with session tickets", func() { 378 It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { 379 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 380 clientConf, serverConf, 381 &utils.RTTStats{}, &utils.RTTStats{}, 382 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 383 false, 384 ) 385 Expect(clientErr).ToNot(HaveOccurred()) 386 Expect(serverErr).ToNot(HaveOccurred()) 387 388 // inject an invalid session ticket 389 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 390 err := client.HandleMessage(b, protocol.EncryptionHandshake) 391 Expect(err).To(HaveOccurred()) 392 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 393 }) 394 395 It("errors when handling the NewSessionTicket fails", func() { 396 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 397 clientConf, serverConf, 398 &utils.RTTStats{}, &utils.RTTStats{}, 399 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 400 false, 401 ) 402 Expect(clientErr).ToNot(HaveOccurred()) 403 Expect(serverErr).ToNot(HaveOccurred()) 404 405 // inject an invalid session ticket 406 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 407 err := client.HandleMessage(b, protocol.Encryption1RTT) 408 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 409 Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) 410 }) 411 412 It("uses session resumption", func() { 413 csc := mocktls.NewMockClientSessionCache(mockCtrl) 414 var state *tls.ClientSessionState 415 receivedSessionTicket := make(chan struct{}) 416 csc.EXPECT().Get(gomock.Any()) 417 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 418 state = css 419 close(receivedSessionTicket) 420 }) 421 clientConf.ClientSessionCache = csc 422 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 423 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 424 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 425 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 426 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 427 clientConf, serverConf, 428 clientOrigRTTStats, serverOrigRTTStats, 429 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 430 false, 431 ) 432 Expect(clientErr).ToNot(HaveOccurred()) 433 Expect(serverErr).ToNot(HaveOccurred()) 434 Eventually(receivedSessionTicket).Should(BeClosed()) 435 Expect(server.ConnectionState().DidResume).To(BeFalse()) 436 Expect(client.ConnectionState().DidResume).To(BeFalse()) 437 438 csc.EXPECT().Get(gomock.Any()).Return(state, true) 439 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 440 clientRTTStats := &utils.RTTStats{} 441 serverRTTStats := &utils.RTTStats{} 442 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 443 clientConf, serverConf, 444 clientRTTStats, serverRTTStats, 445 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 446 false, 447 ) 448 Expect(clientErr).ToNot(HaveOccurred()) 449 Expect(serverErr).ToNot(HaveOccurred()) 450 Eventually(receivedSessionTicket).Should(BeClosed()) 451 Expect(server.ConnectionState().DidResume).To(BeTrue()) 452 Expect(client.ConnectionState().DidResume).To(BeTrue()) 453 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 454 if !strings.Contains(runtime.Version(), "go1.20") { 455 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 456 } 457 }) 458 459 It("doesn't use session resumption if the server disabled it", func() { 460 csc := mocktls.NewMockClientSessionCache(mockCtrl) 461 var state *tls.ClientSessionState 462 receivedSessionTicket := make(chan struct{}) 463 csc.EXPECT().Get(gomock.Any()) 464 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 465 state = css 466 close(receivedSessionTicket) 467 }) 468 clientConf.ClientSessionCache = csc 469 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 470 clientConf, serverConf, 471 &utils.RTTStats{}, &utils.RTTStats{}, 472 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 473 false, 474 ) 475 Expect(clientErr).ToNot(HaveOccurred()) 476 Expect(serverErr).ToNot(HaveOccurred()) 477 Eventually(receivedSessionTicket).Should(BeClosed()) 478 Expect(server.ConnectionState().DidResume).To(BeFalse()) 479 Expect(client.ConnectionState().DidResume).To(BeFalse()) 480 481 serverConf.SessionTicketsDisabled = true 482 csc.EXPECT().Get(gomock.Any()).Return(state, true) 483 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 484 clientConf, serverConf, 485 &utils.RTTStats{}, &utils.RTTStats{}, 486 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 487 false, 488 ) 489 Expect(clientErr).ToNot(HaveOccurred()) 490 Expect(serverErr).ToNot(HaveOccurred()) 491 Eventually(receivedSessionTicket).Should(BeClosed()) 492 Expect(server.ConnectionState().DidResume).To(BeFalse()) 493 Expect(client.ConnectionState().DidResume).To(BeFalse()) 494 }) 495 496 It("uses 0-RTT", func() { 497 csc := mocktls.NewMockClientSessionCache(mockCtrl) 498 var state *tls.ClientSessionState 499 receivedSessionTicket := make(chan struct{}) 500 csc.EXPECT().Get(gomock.Any()) 501 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 502 state = css 503 close(receivedSessionTicket) 504 }) 505 clientConf.ClientSessionCache = csc 506 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 507 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 508 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 509 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 510 const initialMaxData protocol.ByteCount = 1337 511 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 512 clientConf, serverConf, 513 clientOrigRTTStats, serverOrigRTTStats, 514 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 515 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 516 true, 517 ) 518 Expect(clientErr).ToNot(HaveOccurred()) 519 Expect(serverErr).ToNot(HaveOccurred()) 520 Eventually(receivedSessionTicket).Should(BeClosed()) 521 Expect(server.ConnectionState().DidResume).To(BeFalse()) 522 Expect(client.ConnectionState().DidResume).To(BeFalse()) 523 524 csc.EXPECT().Get(gomock.Any()).Return(state, true) 525 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 526 527 clientRTTStats := &utils.RTTStats{} 528 serverRTTStats := &utils.RTTStats{} 529 client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf( 530 clientConf, serverConf, 531 clientRTTStats, serverRTTStats, 532 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 533 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 534 true, 535 ) 536 Expect(clientErr).ToNot(HaveOccurred()) 537 Expect(serverErr).ToNot(HaveOccurred()) 538 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 539 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 540 541 var tp *wire.TransportParameters 542 var clientReceived0RTTKeys bool 543 for _, ev := range clientEvents { 544 //nolint:exhaustive // only need to process a few events 545 switch ev.Kind { 546 case EventRestoredTransportParameters: 547 tp = ev.TransportParameters 548 case EventReceivedReadKeys: 549 clientReceived0RTTKeys = true 550 } 551 } 552 Expect(clientReceived0RTTKeys).To(BeTrue()) 553 Expect(tp).ToNot(BeNil()) 554 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 555 556 var serverReceived0RTTKeys bool 557 for _, ev := range serverEvents { 558 //nolint:exhaustive // only need to process a few events 559 switch ev.Kind { 560 case EventReceivedReadKeys: 561 serverReceived0RTTKeys = true 562 } 563 } 564 Expect(serverReceived0RTTKeys).To(BeTrue()) 565 566 Expect(server.ConnectionState().DidResume).To(BeTrue()) 567 Expect(client.ConnectionState().DidResume).To(BeTrue()) 568 Expect(server.ConnectionState().Used0RTT).To(BeTrue()) 569 Expect(client.ConnectionState().Used0RTT).To(BeTrue()) 570 }) 571 572 It("rejects 0-RTT, when the transport parameters changed", func() { 573 csc := mocktls.NewMockClientSessionCache(mockCtrl) 574 var state *tls.ClientSessionState 575 receivedSessionTicket := make(chan struct{}) 576 csc.EXPECT().Get(gomock.Any()) 577 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 578 state = css 579 close(receivedSessionTicket) 580 }) 581 clientConf.ClientSessionCache = csc 582 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 583 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 584 const initialMaxData protocol.ByteCount = 1337 585 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 586 clientConf, serverConf, 587 clientOrigRTTStats, &utils.RTTStats{}, 588 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 589 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 590 true, 591 ) 592 Expect(clientErr).ToNot(HaveOccurred()) 593 Expect(serverErr).ToNot(HaveOccurred()) 594 Eventually(receivedSessionTicket).Should(BeClosed()) 595 Expect(server.ConnectionState().DidResume).To(BeFalse()) 596 Expect(client.ConnectionState().DidResume).To(BeFalse()) 597 598 csc.EXPECT().Get(gomock.Any()).Return(state, true) 599 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 600 601 clientRTTStats := &utils.RTTStats{} 602 client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf( 603 clientConf, serverConf, 604 clientRTTStats, &utils.RTTStats{}, 605 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 606 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1}, 607 true, 608 ) 609 Expect(clientErr).ToNot(HaveOccurred()) 610 Expect(serverErr).ToNot(HaveOccurred()) 611 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 612 613 var tp *wire.TransportParameters 614 var clientReceived0RTTKeys bool 615 for _, ev := range clientEvents { 616 //nolint:exhaustive // only need to process a few events 617 switch ev.Kind { 618 case EventRestoredTransportParameters: 619 tp = ev.TransportParameters 620 case EventReceivedReadKeys: 621 clientReceived0RTTKeys = true 622 } 623 } 624 Expect(clientReceived0RTTKeys).To(BeTrue()) 625 Expect(tp).ToNot(BeNil()) 626 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 627 628 Expect(server.ConnectionState().DidResume).To(BeTrue()) 629 Expect(client.ConnectionState().DidResume).To(BeTrue()) 630 Expect(server.ConnectionState().Used0RTT).To(BeFalse()) 631 Expect(client.ConnectionState().Used0RTT).To(BeFalse()) 632 }) 633 }) 634 }) 635 })