github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/internal/handshake/crypto_setup_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/rsa" 6 "crypto/tls" 7 "crypto/x509" 8 "crypto/x509/pkix" 9 "math/big" 10 "net" 11 "reflect" 12 "runtime" 13 "strings" 14 "time" 15 16 mocktls "github.com/metacubex/quic-go/internal/mocks/tls" 17 "github.com/metacubex/quic-go/internal/protocol" 18 "github.com/metacubex/quic-go/internal/qerr" 19 "github.com/metacubex/quic-go/internal/testdata" 20 "github.com/metacubex/quic-go/internal/utils" 21 "github.com/metacubex/quic-go/internal/wire" 22 23 . "github.com/onsi/ginkgo/v2" 24 . "github.com/onsi/gomega" 25 "go.uber.org/mock/gomock" 26 ) 27 28 const ( 29 typeClientHello = 1 30 typeNewSessionTicket = 4 31 ) 32 33 var _ = Describe("Crypto Setup TLS", func() { 34 generateCert := func() tls.Certificate { 35 priv, err := rsa.GenerateKey(rand.Reader, 2048) 36 Expect(err).ToNot(HaveOccurred()) 37 tmpl := &x509.Certificate{ 38 SerialNumber: big.NewInt(1), 39 Subject: pkix.Name{}, 40 SignatureAlgorithm: x509.SHA256WithRSA, 41 NotBefore: time.Now(), 42 NotAfter: time.Now().Add(time.Hour), // valid for an hour 43 BasicConstraintsValid: true, 44 } 45 certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv) 46 Expect(err).ToNot(HaveOccurred()) 47 return tls.Certificate{ 48 PrivateKey: priv, 49 Certificate: [][]byte{certDER}, 50 } 51 } 52 53 var clientConf, serverConf *tls.Config 54 55 BeforeEach(func() { 56 serverConf = testdata.GetTLSConfig() 57 serverConf.NextProtos = []string{"crypto-setup"} 58 clientConf = &tls.Config{ 59 ServerName: "localhost", 60 RootCAs: testdata.GetRootCA(), 61 NextProtos: []string{"crypto-setup"}, 62 } 63 }) 64 65 It("handles qtls errors occurring before during ClientHello generation", func() { 66 tlsConf := testdata.GetTLSConfig() 67 tlsConf.InsecureSkipVerify = true 68 tlsConf.NextProtos = []string{""} 69 cl := NewCryptoSetupClient( 70 protocol.ConnectionID{}, 71 &wire.TransportParameters{}, 72 tlsConf, 73 false, 74 &utils.RTTStats{}, 75 nil, 76 utils.DefaultLogger.WithPrefix("client"), 77 protocol.Version1, 78 ) 79 80 Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ 81 ErrorCode: qerr.InternalError, 82 ErrorMessage: "tls: invalid NextProtos value", 83 })) 84 }) 85 86 It("errors when a message is received at the wrong encryption level", func() { 87 var token protocol.StatelessResetToken 88 server := NewCryptoSetupServer( 89 protocol.ConnectionID{}, 90 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 91 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 92 &wire.TransportParameters{StatelessResetToken: &token}, 93 testdata.GetTLSConfig(), 94 false, 95 &utils.RTTStats{}, 96 nil, 97 utils.DefaultLogger.WithPrefix("server"), 98 protocol.Version1, 99 ) 100 101 Expect(server.StartHandshake()).To(Succeed()) 102 103 fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) 104 // wrong encryption level 105 err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) 106 Expect(err).To(HaveOccurred()) 107 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 108 }) 109 110 Context("filling in a net.Conn in tls.ClientHelloInfo", func() { 111 var ( 112 local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42} 113 remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} 114 ) 115 116 It("wraps GetCertificate", func() { 117 var localAddr, remoteAddr net.Addr 118 tlsConf := &tls.Config{ 119 GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { 120 localAddr = info.Conn.LocalAddr() 121 remoteAddr = info.Conn.RemoteAddr() 122 cert := generateCert() 123 return &cert, nil 124 }, 125 } 126 addConnToClientHelloInfo(tlsConf, local, remote) 127 _, err := tlsConf.GetCertificate(&tls.ClientHelloInfo{}) 128 Expect(err).ToNot(HaveOccurred()) 129 Expect(localAddr).To(Equal(local)) 130 Expect(remoteAddr).To(Equal(remote)) 131 }) 132 133 It("wraps GetConfigForClient", func() { 134 var localAddr, remoteAddr net.Addr 135 tlsConf := &tls.Config{ 136 GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) { 137 localAddr = info.Conn.LocalAddr() 138 remoteAddr = info.Conn.RemoteAddr() 139 return &tls.Config{}, nil 140 }, 141 } 142 addConnToClientHelloInfo(tlsConf, local, remote) 143 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 144 Expect(err).ToNot(HaveOccurred()) 145 Expect(localAddr).To(Equal(local)) 146 Expect(remoteAddr).To(Equal(remote)) 147 Expect(conf).ToNot(BeNil()) 148 Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) 149 }) 150 151 It("wraps GetConfigForClient, recursively", func() { 152 var localAddr, remoteAddr net.Addr 153 tlsConf := &tls.Config{} 154 var innerConf *tls.Config 155 getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam 156 localAddr = info.Conn.LocalAddr() 157 remoteAddr = info.Conn.RemoteAddr() 158 cert := generateCert() 159 return &cert, nil 160 } 161 tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) { 162 innerConf = tlsConf.Clone() 163 // set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config 164 innerConf.MaxVersion = tls.VersionTLS12 165 innerConf.GetCertificate = getCert 166 return innerConf, nil 167 } 168 addConnToClientHelloInfo(tlsConf, local, remote) 169 conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) 170 Expect(err).ToNot(HaveOccurred()) 171 Expect(conf).ToNot(BeNil()) 172 Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13)) 173 _, err = conf.GetCertificate(&tls.ClientHelloInfo{}) 174 Expect(err).ToNot(HaveOccurred()) 175 Expect(localAddr).To(Equal(local)) 176 Expect(remoteAddr).To(Equal(remote)) 177 // make sure that the tls.Config returned by GetConfigForClient isn't modified 178 Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue()) 179 Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12)) 180 }) 181 }) 182 183 Context("doing the handshake", func() { 184 newRTTStatsWithRTT := func(rtt time.Duration) *utils.RTTStats { 185 rttStats := &utils.RTTStats{} 186 rttStats.UpdateRTT(rtt, 0, time.Now()) 187 ExpectWithOffset(1, rttStats.SmoothedRTT()).To(Equal(rtt)) 188 return rttStats 189 } 190 191 // The clientEvents and serverEvents contain all events that were not processed by the function, 192 // i.e. not EventWriteInitialData, EventWriteHandshakeData, EventHandshakeComplete. 193 handshake := func(client, server CryptoSetup) (clientEvents []Event, clientErr error, serverEvents []Event, serverErr error) { 194 Expect(client.StartHandshake()).To(Succeed()) 195 Expect(server.StartHandshake()).To(Succeed()) 196 197 var clientHandshakeComplete, serverHandshakeComplete bool 198 199 for { 200 clientLoop: 201 for { 202 ev := client.NextEvent() 203 switch ev.Kind { 204 case EventNoEvent: 205 break clientLoop 206 case EventWriteInitialData: 207 if err := server.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 208 serverErr = err 209 return 210 } 211 case EventWriteHandshakeData: 212 if err := server.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 213 serverErr = err 214 return 215 } 216 case EventHandshakeComplete: 217 clientHandshakeComplete = true 218 default: 219 clientEvents = append(clientEvents, ev) 220 } 221 } 222 223 serverLoop: 224 for { 225 ev := server.NextEvent() 226 switch ev.Kind { 227 case EventNoEvent: 228 break serverLoop 229 case EventWriteInitialData: 230 if err := client.HandleMessage(ev.Data, protocol.EncryptionInitial); err != nil { 231 clientErr = err 232 return 233 } 234 case EventWriteHandshakeData: 235 if err := client.HandleMessage(ev.Data, protocol.EncryptionHandshake); err != nil { 236 clientErr = err 237 return 238 } 239 case EventHandshakeComplete: 240 serverHandshakeComplete = true 241 ticket, err := server.GetSessionTicket() 242 Expect(err).ToNot(HaveOccurred()) 243 if ticket != nil { 244 Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) 245 } 246 default: 247 serverEvents = append(serverEvents, ev) 248 } 249 } 250 251 if clientHandshakeComplete && serverHandshakeComplete { 252 break 253 } 254 } 255 return 256 } 257 258 handshakeWithTLSConf := func( 259 clientConf, serverConf *tls.Config, 260 clientRTTStats, serverRTTStats *utils.RTTStats, 261 clientTransportParameters, serverTransportParameters *wire.TransportParameters, 262 enable0RTT bool, 263 ) (CryptoSetup /* client */, []Event /* more client events */, error, /* client error */ 264 CryptoSetup /* server */, []Event /* more server events */, error, /* server error */ 265 ) { 266 client := NewCryptoSetupClient( 267 protocol.ConnectionID{}, 268 clientTransportParameters, 269 clientConf, 270 enable0RTT, 271 clientRTTStats, 272 nil, 273 utils.DefaultLogger.WithPrefix("client"), 274 protocol.Version1, 275 ) 276 277 if serverTransportParameters.StatelessResetToken == nil { 278 var token protocol.StatelessResetToken 279 serverTransportParameters.StatelessResetToken = &token 280 } 281 server := NewCryptoSetupServer( 282 protocol.ConnectionID{}, 283 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 284 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 285 serverTransportParameters, 286 serverConf, 287 enable0RTT, 288 serverRTTStats, 289 nil, 290 utils.DefaultLogger.WithPrefix("server"), 291 protocol.Version1, 292 ) 293 cEvents, cErr, sEvents, sErr := handshake(client, server) 294 return client, cEvents, cErr, server, sEvents, sErr 295 } 296 297 It("handshakes", func() { 298 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 299 clientConf, serverConf, 300 &utils.RTTStats{}, &utils.RTTStats{}, 301 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 302 false, 303 ) 304 Expect(clientErr).ToNot(HaveOccurred()) 305 Expect(serverErr).ToNot(HaveOccurred()) 306 }) 307 308 It("performs a HelloRetryRequst", func() { 309 serverConf.CurvePreferences = []tls.CurveID{tls.CurveP384} 310 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 311 clientConf, serverConf, 312 &utils.RTTStats{}, &utils.RTTStats{}, 313 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 314 false, 315 ) 316 Expect(clientErr).ToNot(HaveOccurred()) 317 Expect(serverErr).ToNot(HaveOccurred()) 318 }) 319 320 It("handshakes with client auth", func() { 321 clientConf.Certificates = []tls.Certificate{generateCert()} 322 serverConf.ClientAuth = tls.RequireAnyClientCert 323 _, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 324 clientConf, serverConf, 325 &utils.RTTStats{}, &utils.RTTStats{}, 326 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 327 false, 328 ) 329 Expect(clientErr).ToNot(HaveOccurred()) 330 Expect(serverErr).ToNot(HaveOccurred()) 331 }) 332 333 It("receives transport parameters", func() { 334 cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 42 * time.Second} 335 client := NewCryptoSetupClient( 336 protocol.ConnectionID{}, 337 cTransportParameters, 338 clientConf, 339 false, 340 &utils.RTTStats{}, 341 nil, 342 utils.DefaultLogger.WithPrefix("client"), 343 protocol.Version1, 344 ) 345 346 var token protocol.StatelessResetToken 347 sTransportParameters := &wire.TransportParameters{ 348 MaxIdleTimeout: 1337 * time.Second, 349 StatelessResetToken: &token, 350 ActiveConnectionIDLimit: 2, 351 } 352 server := NewCryptoSetupServer( 353 protocol.ConnectionID{}, 354 &net.UDPAddr{IP: net.IPv6loopback, Port: 1234}, 355 &net.UDPAddr{IP: net.IPv6loopback, Port: 4321}, 356 sTransportParameters, 357 serverConf, 358 false, 359 &utils.RTTStats{}, 360 nil, 361 utils.DefaultLogger.WithPrefix("server"), 362 protocol.Version1, 363 ) 364 365 clientEvents, cErr, serverEvents, sErr := handshake(client, server) 366 Expect(cErr).ToNot(HaveOccurred()) 367 Expect(sErr).ToNot(HaveOccurred()) 368 var clientReceivedTransportParameters *wire.TransportParameters 369 for _, ev := range clientEvents { 370 if ev.Kind == EventReceivedTransportParameters { 371 clientReceivedTransportParameters = ev.TransportParameters 372 } 373 } 374 Expect(clientReceivedTransportParameters).ToNot(BeNil()) 375 Expect(clientReceivedTransportParameters.MaxIdleTimeout).To(Equal(1337 * time.Second)) 376 377 var serverReceivedTransportParameters *wire.TransportParameters 378 for _, ev := range serverEvents { 379 if ev.Kind == EventReceivedTransportParameters { 380 serverReceivedTransportParameters = ev.TransportParameters 381 } 382 } 383 Expect(serverReceivedTransportParameters).ToNot(BeNil()) 384 Expect(serverReceivedTransportParameters.MaxIdleTimeout).To(Equal(42 * time.Second)) 385 }) 386 387 Context("with session tickets", func() { 388 It("errors when the NewSessionTicket is sent at the wrong encryption level", func() { 389 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 390 clientConf, serverConf, 391 &utils.RTTStats{}, &utils.RTTStats{}, 392 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 393 false, 394 ) 395 Expect(clientErr).ToNot(HaveOccurred()) 396 Expect(serverErr).ToNot(HaveOccurred()) 397 398 // inject an invalid session ticket 399 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 400 err := client.HandleMessage(b, protocol.EncryptionHandshake) 401 Expect(err).To(HaveOccurred()) 402 Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) 403 }) 404 405 It("errors when handling the NewSessionTicket fails", func() { 406 client, _, clientErr, _, _, serverErr := handshakeWithTLSConf( 407 clientConf, serverConf, 408 &utils.RTTStats{}, &utils.RTTStats{}, 409 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 410 false, 411 ) 412 Expect(clientErr).ToNot(HaveOccurred()) 413 Expect(serverErr).ToNot(HaveOccurred()) 414 415 // inject an invalid session ticket 416 b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) 417 err := client.HandleMessage(b, protocol.Encryption1RTT) 418 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 419 Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) 420 }) 421 422 It("uses session resumption", func() { 423 csc := mocktls.NewMockClientSessionCache(mockCtrl) 424 var state *tls.ClientSessionState 425 receivedSessionTicket := make(chan struct{}) 426 csc.EXPECT().Get(gomock.Any()) 427 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 428 state = css 429 close(receivedSessionTicket) 430 }) 431 clientConf.ClientSessionCache = csc 432 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 433 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 434 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 435 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 436 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 437 clientConf, serverConf, 438 clientOrigRTTStats, serverOrigRTTStats, 439 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 440 false, 441 ) 442 Expect(clientErr).ToNot(HaveOccurred()) 443 Expect(serverErr).ToNot(HaveOccurred()) 444 Eventually(receivedSessionTicket).Should(BeClosed()) 445 Expect(server.ConnectionState().DidResume).To(BeFalse()) 446 Expect(client.ConnectionState().DidResume).To(BeFalse()) 447 448 csc.EXPECT().Get(gomock.Any()).Return(state, true) 449 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 450 clientRTTStats := &utils.RTTStats{} 451 serverRTTStats := &utils.RTTStats{} 452 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 453 clientConf, serverConf, 454 clientRTTStats, serverRTTStats, 455 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 456 false, 457 ) 458 Expect(clientErr).ToNot(HaveOccurred()) 459 Expect(serverErr).ToNot(HaveOccurred()) 460 Eventually(receivedSessionTicket).Should(BeClosed()) 461 Expect(server.ConnectionState().DidResume).To(BeTrue()) 462 Expect(client.ConnectionState().DidResume).To(BeTrue()) 463 if !strings.Contains(runtime.Version(), "go1.20") { 464 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 465 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 466 } 467 }) 468 469 It("doesn't use session resumption if the server disabled it", func() { 470 csc := mocktls.NewMockClientSessionCache(mockCtrl) 471 var state *tls.ClientSessionState 472 receivedSessionTicket := make(chan struct{}) 473 csc.EXPECT().Get(gomock.Any()) 474 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 475 state = css 476 close(receivedSessionTicket) 477 }) 478 clientConf.ClientSessionCache = csc 479 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 480 clientConf, serverConf, 481 &utils.RTTStats{}, &utils.RTTStats{}, 482 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 483 false, 484 ) 485 Expect(clientErr).ToNot(HaveOccurred()) 486 Expect(serverErr).ToNot(HaveOccurred()) 487 Eventually(receivedSessionTicket).Should(BeClosed()) 488 Expect(server.ConnectionState().DidResume).To(BeFalse()) 489 Expect(client.ConnectionState().DidResume).To(BeFalse()) 490 491 serverConf.SessionTicketsDisabled = true 492 csc.EXPECT().Get(gomock.Any()).Return(state, true) 493 client, _, clientErr, server, _, serverErr = handshakeWithTLSConf( 494 clientConf, serverConf, 495 &utils.RTTStats{}, &utils.RTTStats{}, 496 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 497 false, 498 ) 499 Expect(clientErr).ToNot(HaveOccurred()) 500 Expect(serverErr).ToNot(HaveOccurred()) 501 Eventually(receivedSessionTicket).Should(BeClosed()) 502 Expect(server.ConnectionState().DidResume).To(BeFalse()) 503 Expect(client.ConnectionState().DidResume).To(BeFalse()) 504 }) 505 506 It("uses 0-RTT", func() { 507 csc := mocktls.NewMockClientSessionCache(mockCtrl) 508 var state *tls.ClientSessionState 509 receivedSessionTicket := make(chan struct{}) 510 csc.EXPECT().Get(gomock.Any()) 511 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 512 state = css 513 close(receivedSessionTicket) 514 }) 515 clientConf.ClientSessionCache = csc 516 const serverRTT = 25 * time.Millisecond // RTT as measured by the server. Should be restored. 517 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 518 serverOrigRTTStats := newRTTStatsWithRTT(serverRTT) 519 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 520 const initialMaxData protocol.ByteCount = 1337 521 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 522 clientConf, serverConf, 523 clientOrigRTTStats, serverOrigRTTStats, 524 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 525 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 526 true, 527 ) 528 Expect(clientErr).ToNot(HaveOccurred()) 529 Expect(serverErr).ToNot(HaveOccurred()) 530 Eventually(receivedSessionTicket).Should(BeClosed()) 531 Expect(server.ConnectionState().DidResume).To(BeFalse()) 532 Expect(client.ConnectionState().DidResume).To(BeFalse()) 533 534 csc.EXPECT().Get(gomock.Any()).Return(state, true) 535 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 536 537 clientRTTStats := &utils.RTTStats{} 538 serverRTTStats := &utils.RTTStats{} 539 client, clientEvents, clientErr, server, serverEvents, serverErr := handshakeWithTLSConf( 540 clientConf, serverConf, 541 clientRTTStats, serverRTTStats, 542 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 543 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 544 true, 545 ) 546 Expect(clientErr).ToNot(HaveOccurred()) 547 Expect(serverErr).ToNot(HaveOccurred()) 548 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 549 Expect(serverRTTStats.SmoothedRTT()).To(Equal(serverRTT)) 550 551 var tp *wire.TransportParameters 552 var clientReceived0RTTKeys bool 553 for _, ev := range clientEvents { 554 switch ev.Kind { 555 case EventRestoredTransportParameters: 556 tp = ev.TransportParameters 557 case EventReceivedReadKeys: 558 clientReceived0RTTKeys = true 559 } 560 } 561 Expect(clientReceived0RTTKeys).To(BeTrue()) 562 Expect(tp).ToNot(BeNil()) 563 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 564 565 var serverReceived0RTTKeys bool 566 for _, ev := range serverEvents { 567 switch ev.Kind { 568 case EventReceivedReadKeys: 569 serverReceived0RTTKeys = true 570 } 571 } 572 Expect(serverReceived0RTTKeys).To(BeTrue()) 573 574 Expect(server.ConnectionState().DidResume).To(BeTrue()) 575 Expect(client.ConnectionState().DidResume).To(BeTrue()) 576 Expect(server.ConnectionState().Used0RTT).To(BeTrue()) 577 Expect(client.ConnectionState().Used0RTT).To(BeTrue()) 578 }) 579 580 It("rejects 0-RTT, when the transport parameters changed", func() { 581 csc := mocktls.NewMockClientSessionCache(mockCtrl) 582 var state *tls.ClientSessionState 583 receivedSessionTicket := make(chan struct{}) 584 csc.EXPECT().Get(gomock.Any()) 585 csc.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, css *tls.ClientSessionState) { 586 state = css 587 close(receivedSessionTicket) 588 }) 589 clientConf.ClientSessionCache = csc 590 const clientRTT = 30 * time.Millisecond // RTT as measured by the client. Should be restored. 591 clientOrigRTTStats := newRTTStatsWithRTT(clientRTT) 592 const initialMaxData protocol.ByteCount = 1337 593 client, _, clientErr, server, _, serverErr := handshakeWithTLSConf( 594 clientConf, serverConf, 595 clientOrigRTTStats, &utils.RTTStats{}, 596 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 597 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData}, 598 true, 599 ) 600 Expect(clientErr).ToNot(HaveOccurred()) 601 Expect(serverErr).ToNot(HaveOccurred()) 602 Eventually(receivedSessionTicket).Should(BeClosed()) 603 Expect(server.ConnectionState().DidResume).To(BeFalse()) 604 Expect(client.ConnectionState().DidResume).To(BeFalse()) 605 606 csc.EXPECT().Get(gomock.Any()).Return(state, true) 607 csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) 608 609 clientRTTStats := &utils.RTTStats{} 610 client, clientEvents, clientErr, server, _, serverErr := handshakeWithTLSConf( 611 clientConf, serverConf, 612 clientRTTStats, &utils.RTTStats{}, 613 &wire.TransportParameters{ActiveConnectionIDLimit: 2}, 614 &wire.TransportParameters{ActiveConnectionIDLimit: 2, InitialMaxData: initialMaxData - 1}, 615 true, 616 ) 617 Expect(clientErr).ToNot(HaveOccurred()) 618 Expect(serverErr).ToNot(HaveOccurred()) 619 Expect(clientRTTStats.SmoothedRTT()).To(Equal(clientRTT)) 620 621 var tp *wire.TransportParameters 622 var clientReceived0RTTKeys bool 623 for _, ev := range clientEvents { 624 switch ev.Kind { 625 case EventRestoredTransportParameters: 626 tp = ev.TransportParameters 627 case EventReceivedReadKeys: 628 clientReceived0RTTKeys = true 629 } 630 } 631 Expect(clientReceived0RTTKeys).To(BeTrue()) 632 Expect(tp).ToNot(BeNil()) 633 Expect(tp.InitialMaxData).To(Equal(initialMaxData)) 634 635 Expect(server.ConnectionState().DidResume).To(BeTrue()) 636 Expect(client.ConnectionState().DidResume).To(BeTrue()) 637 Expect(server.ConnectionState().Used0RTT).To(BeFalse()) 638 Expect(client.ConnectionState().Used0RTT).To(BeFalse()) 639 }) 640 }) 641 }) 642 })