github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/internal/handshake/crypto_setup_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "crypto/x509/pkix" 9 "math/big" 10 "net" 11 "reflect" 12 "time" 13 14 mocktls "github.com/daeuniverse/quic-go/internal/mocks/tls" 15 "github.com/daeuniverse/quic-go/internal/protocol" 16 "github.com/daeuniverse/quic-go/internal/qerr" 17 "github.com/daeuniverse/quic-go/internal/testdata" 18 "github.com/daeuniverse/quic-go/internal/utils" 19 "github.com/daeuniverse/quic-go/internal/wire" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 "go.uber.org/mock/gomock" 24 ) 25 26 const ( 27 typeClientHello = 1 28 typeNewSessionTicket = 4 29 ) 30 31 var _ = Describe("Crypto Setup TLS", func() { 32 generateCert := func() tls.Certificate { 33 priv, err := rsa.GenerateKey(rand.Reader, 2048) 34 Expect(err).ToNot(HaveOccurred()) 35 tmpl := &x509.Certificate{ 36 SerialNumber: big.NewInt(1), 37 Subject: pkix.Name{}, 38 SignatureAlgorithm: x509.SHA256WithRSA, 39 NotBefore: time.Now(), 40 NotAfter: time.Now().Add(time.Hour), // valid for an hour 41 BasicConstraintsValid: true, 42 } 43 certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) 44 Expect(err).ToNot(HaveOccurred()) 45 return tls.Certificate{ 46 PrivateKey: priv, 47 Certificate: [][]byte{certDER}, 48 } 49 } 50 51 var clientConf, serverConf *tls.Config 52 53 BeforeEach(func() { 54 serverConf = testdata.GetTLSConfig() 55 serverConf.NextProtos = []string{"crypto-setup"} 56 clientConf = &tls.Config{ 57 ServerName: "localhost", 58 RootCAs: testdata.GetRootCA(), 59 NextProtos: []string{"crypto-setup"}, 60 } 61 }) 62 63 It("handles qtls errors occurring before during ClientHello generation", func() { 64 tlsConf := testdata.GetTLSConfig() 65 tlsConf.InsecureSkipVerify = true 66 tlsConf.NextProtos = []string{""} 67 cl := NewCryptoSetupClient( 68 protocol.ConnectionID{}, 69 &wire.TransportParameters{}, 70 tlsConf, 71 false, 72 &utils.RTTStats{}, 73 nil, 74 utils.DefaultLogger.WithPrefix("client"), 75 protocol.Version1, 76 ) 77 78 Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ 79 ErrorCode: qerr.InternalError, 80 ErrorMessage: "tls: invalid NextProtos value", 81 })) 82 }) 83 84 It("errors when a message is received at the wrong encryption level", func() { 85 var token protocol.StatelessResetToken 86 server := NewCryptoSetupServer( 87 protocol.ConnectionID{}, 88 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 89 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 90 &wire.TransportParameters{StatelessResetToken: &token}, 91 testdata.GetTLSConfig(), 92 false, 93 &utils.RTTStats{}, 94 nil, 95 utils.DefaultLogger.WithPrefix("server"), 96 protocol.Version1, 97 ) 98 99 Expect(server.StartHandshake()).To(Succeed()) 100 101 fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) 102 // wrong encryption level 103 err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) 104 Expect(err).To(HaveOccurred()) 105 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 106 }) 107 108 Context("filling in a net.Conn in tls.ClientHelloInfo", func() { 109 var ( 110 local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} 111 remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 112 ) 113 114 It("wraps GetCertificate", func() { 115 var localAddr, remoteAddr net.Addr 116 tlsConf := &tls.Config{ 117 GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 118 localAddr = info.Conn.LocalAddr() 119 remoteAddr = info.Conn.RemoteAddr() 120 cert := generateCert() 121 return &cert, nil 122 }, 123 } 124 addConnToClientHelloInfo(tlsConf, local, remote) 125 _, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{}) 126 Expect(err).ToNot(HaveOccurred()) 127 Expect(localAddr).To(Equal(local)) 128 Expect(remoteAddr).To(Equal(remote)) 129 }) 130 131 It("wraps GetConfigForClient", func() { 132 var localAddr, remoteAddr net.Addr 133 tlsConf := &tls.Config{ 134 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 135 localAddr = info.Conn.LocalAddr() 136 remoteAddr = info.Conn.RemoteAddr() 137 return &tls.Config{}, nil 138 }, 139 } 140 addConnToClientHelloInfo(tlsConf, local, remote) 141 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 142 Expect(err).ToNot(HaveOccurred()) 143 Expect(localAddr).To(Equal(local)) 144 Expect(remoteAddr).To(Equal(remote)) 145 Expect(conf).ToNot(BeNil()) 146 Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) 147 }) 148 149 It("wraps GetConfigForClient, recursively", func() { 150 var localAddr, remoteAddr net.Addr 151 tlsConf := &tls.Config{} 152 var innerConf *tls.Config 153 getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam 154 localAddr = info.Conn.LocalAddr() 155 remoteAddr = info.Conn.RemoteAddr() 156 cert := generateCert() 157 return &cert, nil 158 } 159 tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { 160 innerConf = tlsConf.Clone() 161 // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config 162 innerConf.MaxVersion = tls.VersionTLS12 163 innerConf.GetCertificate = getCert 164 return innerConf, nil 165 } 166 addConnToClientHelloInfo(tlsConf, local, remote) 167 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 168 Expect(err).ToNot(HaveOccurred()) 169 Expect(conf).ToNot(BeNil()) 170 Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) 171 _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) 172 Expect(err).ToNot(HaveOccurred()) 173 Expect(localAddr).To(Equal(local)) 174 Expect(remoteAddr).To(Equal(remote)) 175 // make sure that the tls.Config returned by GetConfigForClient isn't modified 176 Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) 177 Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) 178 }) 179 }) 180 181 Context("doing the handshake", func() { 182 newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { 183 rttStats := &utils.RTTStats{} 184 rttStats.UpdateRTT(rtt, 0, time.Now()) 185 ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) 186 return rttStats 187 } 188 189 // The clientEvents and serverEvents contain all events that were not processed by the function, 190 // i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete. 191 handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) { 192 Expect(client.StartHandshake()).To(Succeed()) 193 Expect(server.StartHandshake()).To(Succeed()) 194 195 var clientHandshakeComplete, serverHandshakeComplete bool 196 197 for { 198 clientLoop: 199 for { 200 ev := client.NextEvent() 201 //nolint:exhaustive // only need to process a few events 202 switch ev.Kind { 203 case EventNoEvent: 204 break clientLoop 205 case EventWriteInitialData: 206 if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 207 serverErr = err 208 return 209 } 210 case EventWriteHandshakeData: 211 if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 212 serverErr = err 213 return 214 } 215 case EventHandshakeComplete: 216 clientHandshakeComplete = true 217 default: 218 clientEvents = append(clientEvents, ev) 219 } 220 } 221 222 serverLoop: 223 for { 224 ev := server.NextEvent() 225 //nolint:exhaustive // only need to process a few events 226 switch ev.Kind { 227 case EventNoEvent: 228 break serverLoop 229 case EventWriteInitialData: 230 if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 231 clientErr = err 232 return 233 } 234 case EventWriteHandshakeData: 235 if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 236 clientErr = err 237 return 238 } 239 case EventHandshakeComplete: 240 serverHandshakeComplete = true 241 ticket, err := server.GetSessionTicket() 242 Expect(err).ToNot(HaveOccurred()) 243 if ticket != nil { 244 Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) 245 } 246 default: 247 serverEvents = append(serverEvents, ev) 248 } 249 } 250 251 if clientHandshakeComplete && serverHandshakeComplete { 252 break 253 } 254 } 255 return 256 } 257 258 handshakeWithTLSConf := func( 259 clientConf, serverConf *tls.Config, 260 clientRTTStats, serverRTTStats *utils.RTTStats, 261 clientTransportParameters, serverTransportParameters *wire.TransportParameters, 262 enable0RTT bool, 263 ) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */ 264 CryptoSetup /* server */, []Event /* more server events */, error, /* server error */ 265 ) { 266 client := NewCryptoSetupClient( 267 protocol.ConnectionID{}, 268 clientTransportParameters, 269 clientConf, 270 enable0RTT, 271 clientRTTStats, 272 nil, 273 utils.DefaultLogger.WithPrefix("client"), 274 protocol.Version1, 275 ) 276 277 if serverTransportParameters.StatelessResetToken == nil { 278 var token protocol.StatelessResetToken 279 serverTransportParameters.StatelessResetToken = &token 280 } 281 server := NewCryptoSetupServer( 282 protocol.ConnectionID{}, 283 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 284 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 285 serverTransportParameters, 286 serverConf, 287 enable0RTT, 288 serverRTTStats, 289 nil, 290 utils.DefaultLogger.WithPrefix("server"), 291 protocol.Version1, 292 ) 293 cEvents, cErr, sEvents, sErr := handshake(client, server) 294 return client, cEvents, cErr, server, sEvents, sErr 295 } 296 297 It("handshakes", func() { 298 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 299 clientConf, serverConf, 300 &utils.RTTStats{}, &utils.RTTStats{}, 301 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 302 false, 303 ) 304 Expect(clientErr).ToNot(HaveOccurred()) 305 Expect(serverErr).ToNot(HaveOccurred()) 306 }) 307 308 It("performs a HelloRetryRequst", func() { 309 serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} 310 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 311 clientConf, serverConf, 312 &utils.RTTStats{}, &utils.RTTStats{}, 313 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 314 false, 315 ) 316 Expect(clientErr).ToNot(HaveOccurred()) 317 Expect(serverErr).ToNot(HaveOccurred()) 318 }) 319 320 It("handshakes with client auth", func() { 321 clientConf.Certificates = []tls.Certificate{generateCert()} 322 serverConf.ClientAuth = tls.RequireAnyClientCert 323 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 324 clientConf, serverConf, 325 &utils.RTTStats{}, &utils.RTTStats{}, 326 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 327 false, 328 ) 329 Expect(clientErr).ToNot(HaveOccurred()) 330 Expect(serverErr).ToNot(HaveOccurred()) 331 }) 332 333 It("receives transport parameters", func() { 334 cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second} 335 client := NewCryptoSetupClient( 336 protocol.ConnectionID{}, 337 cTransportParameters, 338 clientConf, 339 false, 340 &utils.RTTStats{}, 341 nil, 342 utils.DefaultLogger.WithPrefix("client"), 343 protocol.Version1, 344 ) 345 346 var token protocol.StatelessResetToken 347 sTransportParameters := &wire.TransportParameters{ 348 MaxIdleTimeout: 1337 * time.Second, 349 StatelessResetToken: &token, 350 ActiveConnectionIDLimit: 2, 351 } 352 server := NewCryptoSetupServer( 353 protocol.ConnectionID{}, 354 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 355 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 356 sTransportParameters, 357 serverConf, 358 false, 359 &utils.RTTStats{}, 360 nil, 361 utils.DefaultLogger.WithPrefix("server"), 362 protocol.Version1, 363 ) 364 365 clientEvents, cErr, serverEvents, sErr := handshake(client, server) 366 Expect(cErr).ToNot(HaveOccurred()) 367 Expect(sErr).ToNot(HaveOccurred()) 368 var clientReceivedTransportParameters *wire.TransportParameters 369 for _, ev := range clientEvents { 370 if ev.Kind == EventReceivedTransportParameters { 371 clientReceivedTransportParameters = ev.TransportParameters 372 } 373 } 374 Expect(clientReceivedTransportParameters).ToNot(BeNil()) 375 Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second)) 376 377 var serverReceivedTransportParameters *wire.TransportParameters 378 for _, ev := range serverEvents { 379 if ev.Kind == EventReceivedTransportParameters { 380 serverReceivedTransportParameters = ev.TransportParameters 381 } 382 } 383 Expect(serverReceivedTransportParameters).ToNot(BeNil()) 384 Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second)) 385 }) 386 387 Context("with session tickets", func() { 388 It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { 389 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 390 clientConf, serverConf, 391 &utils.RTTStats{}, &utils.RTTStats{}, 392 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 393 false, 394 ) 395 Expect(clientErr).ToNot(HaveOccurred()) 396 Expect(serverErr).ToNot(HaveOccurred()) 397 398 // inject an invalid session ticket 399 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 400 err := client.HandleMessage(b, protocol.EncryptionHandshake) 401 Expect(err).To(HaveOccurred()) 402 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 403 }) 404 405 It("errors when handling the NewSessionTicket fails", func() { 406 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 407 clientConf, serverConf, 408 &utils.RTTStats{}, &utils.RTTStats{}, 409 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 410 false, 411 ) 412 Expect(clientErr).ToNot(HaveOccurred()) 413 Expect(serverErr).ToNot(HaveOccurred()) 414 415 // inject an invalid session ticket 416 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 417 err := client.HandleMessage(b, protocol.Encryption1RTT) 418 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 419 Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) 420 }) 421 422 It("uses session resumption", func() { 423 csc := mocktls.NewMockClientSessionCache(mockCtrl) 424 var state *tls.ClientSessionState 425 receivedSessionTicket := make(chan struct{}) 426 csc.EXPECT().Get(gomock.Any()) 427 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 428 state = css 429 close(receivedSessionTicket) 430 }) 431 clientConf.ClientSessionCache = csc 432 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 433 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 434 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 435 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 436 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 437 clientConf, serverConf, 438 clientOrigRTTStats, serverOrigRTTStats, 439 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 440 false, 441 ) 442 Expect(clientErr).ToNot(HaveOccurred()) 443 Expect(serverErr).ToNot(HaveOccurred()) 444 Eventually(receivedSessionTicket).Should(BeClosed()) 445 Expect(server.ConnectionState().DidResume).To(BeFalse()) 446 Expect(client.ConnectionState().DidResume).To(BeFalse()) 447 448 csc.EXPECT().Get(gomock.Any()).Return(state, true) 449 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 450 clientRTTStats := &utils.RTTStats{} 451 serverRTTStats := &utils.RTTStats{} 452 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 453 clientConf, serverConf, 454 clientRTTStats, serverRTTStats, 455 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 456 false, 457 ) 458 Expect(clientErr).ToNot(HaveOccurred()) 459 Expect(serverErr).ToNot(HaveOccurred()) 460 Eventually(receivedSessionTicket).Should(BeClosed()) 461 Expect(server.ConnectionState().DidResume).To(BeTrue()) 462 Expect(client.ConnectionState().DidResume).To(BeTrue()) 463 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 464 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 465 }) 466 467 It("doesn't use session resumption if the server disabled it", func() { 468 csc := mocktls.NewMockClientSessionCache(mockCtrl) 469 var state *tls.ClientSessionState 470 receivedSessionTicket := make(chan struct{}) 471 csc.EXPECT().Get(gomock.Any()) 472 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 473 state = css 474 close(receivedSessionTicket) 475 }) 476 clientConf.ClientSessionCache = csc 477 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 478 clientConf, serverConf, 479 &utils.RTTStats{}, &utils.RTTStats{}, 480 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 481 false, 482 ) 483 Expect(clientErr).ToNot(HaveOccurred()) 484 Expect(serverErr).ToNot(HaveOccurred()) 485 Eventually(receivedSessionTicket).Should(BeClosed()) 486 Expect(server.ConnectionState().DidResume).To(BeFalse()) 487 Expect(client.ConnectionState().DidResume).To(BeFalse()) 488 489 serverConf.SessionTicketsDisabled = true 490 csc.EXPECT().Get(gomock.Any()).Return(state, true) 491 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 492 clientConf, serverConf, 493 &utils.RTTStats{}, &utils.RTTStats{}, 494 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 495 false, 496 ) 497 Expect(clientErr).ToNot(HaveOccurred()) 498 Expect(serverErr).ToNot(HaveOccurred()) 499 Eventually(receivedSessionTicket).Should(BeClosed()) 500 Expect(server.ConnectionState().DidResume).To(BeFalse()) 501 Expect(client.ConnectionState().DidResume).To(BeFalse()) 502 }) 503 504 It("uses 0-RTT", func() { 505 csc := mocktls.NewMockClientSessionCache(mockCtrl) 506 var state *tls.ClientSessionState 507 receivedSessionTicket := make(chan struct{}) 508 csc.EXPECT().Get(gomock.Any()) 509 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 510 state = css 511 close(receivedSessionTicket) 512 }) 513 clientConf.ClientSessionCache = csc 514 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 515 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 516 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 517 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 518 const initialMaxData protocol.ByteCount = 1337 519 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 520 clientConf, serverConf, 521 clientOrigRTTStats, serverOrigRTTStats, 522 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 523 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 524 true, 525 ) 526 Expect(clientErr).ToNot(HaveOccurred()) 527 Expect(serverErr).ToNot(HaveOccurred()) 528 Eventually(receivedSessionTicket).Should(BeClosed()) 529 Expect(server.ConnectionState().DidResume).To(BeFalse()) 530 Expect(client.ConnectionState().DidResume).To(BeFalse()) 531 532 csc.EXPECT().Get(gomock.Any()).Return(state, true) 533 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 534 535 clientRTTStats := &utils.RTTStats{} 536 serverRTTStats := &utils.RTTStats{} 537 client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf( 538 clientConf, serverConf, 539 clientRTTStats, serverRTTStats, 540 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 541 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 542 true, 543 ) 544 Expect(clientErr).ToNot(HaveOccurred()) 545 Expect(serverErr).ToNot(HaveOccurred()) 546 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 547 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 548 549 var tp *wire.TransportParameters 550 var clientReceived0RTTKeys bool 551 for _, ev := range clientEvents { 552 //nolint:exhaustive // only need to process a few events 553 switch ev.Kind { 554 case EventRestoredTransportParameters: 555 tp = ev.TransportParameters 556 case EventReceivedReadKeys: 557 clientReceived0RTTKeys = true 558 } 559 } 560 Expect(clientReceived0RTTKeys).To(BeTrue()) 561 Expect(tp).ToNot(BeNil()) 562 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 563 564 var serverReceived0RTTKeys bool 565 for _, ev := range serverEvents { 566 //nolint:exhaustive // only need to process a few events 567 switch ev.Kind { 568 case EventReceivedReadKeys: 569 serverReceived0RTTKeys = true 570 } 571 } 572 Expect(serverReceived0RTTKeys).To(BeTrue()) 573 574 Expect(server.ConnectionState().DidResume).To(BeTrue()) 575 Expect(client.ConnectionState().DidResume).To(BeTrue()) 576 Expect(server.ConnectionState().Used0RTT).To(BeTrue()) 577 Expect(client.ConnectionState().Used0RTT).To(BeTrue()) 578 }) 579 580 It("rejects 0-RTT, when the transport parameters changed", func() { 581 csc := mocktls.NewMockClientSessionCache(mockCtrl) 582 var state *tls.ClientSessionState 583 receivedSessionTicket := make(chan struct{}) 584 csc.EXPECT().Get(gomock.Any()) 585 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 586 state = css 587 close(receivedSessionTicket) 588 }) 589 clientConf.ClientSessionCache = csc 590 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 591 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 592 const initialMaxData protocol.ByteCount = 1337 593 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 594 clientConf, serverConf, 595 clientOrigRTTStats, &utils.RTTStats{}, 596 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 597 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 598 true, 599 ) 600 Expect(clientErr).ToNot(HaveOccurred()) 601 Expect(serverErr).ToNot(HaveOccurred()) 602 Eventually(receivedSessionTicket).Should(BeClosed()) 603 Expect(server.ConnectionState().DidResume).To(BeFalse()) 604 Expect(client.ConnectionState().DidResume).To(BeFalse()) 605 606 csc.EXPECT().Get(gomock.Any()).Return(state, true) 607 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 608 609 clientRTTStats := &utils.RTTStats{} 610 client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf( 611 clientConf, serverConf, 612 clientRTTStats, &utils.RTTStats{}, 613 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 614 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1}, 615 true, 616 ) 617 Expect(clientErr).ToNot(HaveOccurred()) 618 Expect(serverErr).ToNot(HaveOccurred()) 619 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 620 621 var tp *wire.TransportParameters 622 var clientReceived0RTTKeys bool 623 for _, ev := range clientEvents { 624 //nolint:exhaustive // only need to process a few events 625 switch ev.Kind { 626 case EventRestoredTransportParameters: 627 tp = ev.TransportParameters 628 case EventReceivedReadKeys: 629 clientReceived0RTTKeys = true 630 } 631 } 632 Expect(clientReceived0RTTKeys).To(BeTrue()) 633 Expect(tp).ToNot(BeNil()) 634 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 635 636 Expect(server.ConnectionState().DidResume).To(BeTrue()) 637 Expect(client.ConnectionState().DidResume).To(BeTrue()) 638 Expect(server.ConnectionState().Used0RTT).To(BeFalse()) 639 Expect(client.ConnectionState().Used0RTT).To(BeFalse()) 640 }) 641 }) 642 }) 643 })