github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/internal/handshake/updatable_aead_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/tls" 6 "fmt" 7 "testing" 8 "time" 9 10 mocklogging "github.com/daeuniverse/quic-go/internal/mocks/logging" 11 "github.com/daeuniverse/quic-go/internal/protocol" 12 "github.com/daeuniverse/quic-go/internal/qerr" 13 "github.com/daeuniverse/quic-go/internal/utils" 14 "github.com/daeuniverse/quic-go/logging" 15 16 . "github.com/onsi/ginkgo/v2" 17 . "github.com/onsi/gomega" 18 "go.uber.org/mock/gomock" 19 ) 20 21 var _ = Describe("Updatable AEAD", func() { 22 DescribeTable("ChaCha test vector", 23 func(v protocol.Version, expectedPayload, expectedPacket []byte) { 24 secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") 25 aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) 26 chacha := cipherSuites[2] 27 Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) 28 aead.SetWriteKey(chacha, secret) 29 const pnOffset = 1 30 header := splitHexString("4200bff4") 31 payloadOffset := len(header) 32 plaintext := splitHexString("01") 33 payload := aead.Seal(nil, plaintext, 654360564, header) 34 Expect(payload).To(Equal(expectedPayload)) 35 packet := append(header, payload...) 36 aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) 37 Expect(packet).To(Equal(expectedPacket)) 38 }, 39 Entry("QUIC v1", 40 protocol.Version1, 41 splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), 42 splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), 43 ), 44 Entry("QUIC v2", 45 protocol.Version2, 46 splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), 47 splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), 48 ), 49 ) 50 51 for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} { 52 v := ver 53 54 Context(fmt.Sprintf("using version %s", v), func() { 55 for i := range cipherSuites { 56 cs := cipherSuites[i] 57 58 Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { 59 var ( 60 client, server *updatableAEAD 61 serverTracer *mocklogging.MockConnectionTracer 62 rttStats *utils.RTTStats 63 ) 64 65 BeforeEach(func() { 66 var tr *logging.ConnectionTracer 67 tr, serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) 68 trafficSecret1 := make([]byte, 16) 69 trafficSecret2 := make([]byte, 16) 70 rand.Read(trafficSecret1) 71 rand.Read(trafficSecret2) 72 73 rttStats = utils.NewRTTStats() 74 client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) 75 server = newUpdatableAEAD(rttStats, tr, utils.DefaultLogger, v) 76 client.SetReadKey(cs, trafficSecret2) 77 client.SetWriteKey(cs, trafficSecret1) 78 server.SetReadKey(cs, trafficSecret1) 79 server.SetWriteKey(cs, trafficSecret2) 80 }) 81 82 Context("header protection", func() { 83 It("encrypts and decrypts the header", func() { 84 var lastFiveBitsDifferent int 85 for i := 0; i < 100; i++ { 86 sample := make([]byte, 16) 87 rand.Read(sample) 88 header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} 89 client.EncryptHeader(sample, &header[0], header[9:13]) 90 if header[0]&0x1f != 0xb5&0x1f { 91 lastFiveBitsDifferent++ 92 } 93 Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) 94 Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) 95 Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) 96 server.DecryptHeader(sample, &header[0], header[9:13]) 97 Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) 98 } 99 Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) 100 }) 101 }) 102 103 Context("message encryption", func() { 104 var msg, ad []byte 105 106 BeforeEach(func() { 107 msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") 108 ad = []byte("Donec in velit neque.") 109 }) 110 111 It("encrypts and decrypts a message", func() { 112 encrypted := server.Seal(nil, msg, 0x1337, ad) 113 opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) 114 Expect(err).ToNot(HaveOccurred()) 115 Expect(opened).To(Equal(msg)) 116 }) 117 118 It("saves the first packet number", func() { 119 client.Seal(nil, msg, 0x1337, ad) 120 Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) 121 client.Seal(nil, msg, 0x1338, ad) 122 Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) 123 }) 124 125 It("fails to open a message if the associated data is not the same", func() { 126 encrypted := client.Seal(nil, msg, 0x1337, ad) 127 _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) 128 Expect(err).To(MatchError(ErrDecryptionFailed)) 129 }) 130 131 It("fails to open a message if the packet number is not the same", func() { 132 encrypted := server.Seal(nil, msg, 0x1337, ad) 133 _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) 134 Expect(err).To(MatchError(ErrDecryptionFailed)) 135 }) 136 137 It("decodes the packet number", func() { 138 encrypted := server.Seal(nil, msg, 0x1337, ad) 139 _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) 140 Expect(err).ToNot(HaveOccurred()) 141 Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) 142 }) 143 144 It("ignores packets it can't decrypt for packet number derivation", func() { 145 encrypted := server.Seal(nil, msg, 0x1337, ad) 146 _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) 147 Expect(err).To(HaveOccurred()) 148 Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) 149 }) 150 151 It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { 152 client.invalidPacketLimit = 10 153 for i := 0; i < 9; i++ { 154 _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) 155 Expect(err).To(MatchError(ErrDecryptionFailed)) 156 } 157 _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) 158 Expect(err).To(HaveOccurred()) 159 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 160 Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) 161 }) 162 163 Context("key updates", func() { 164 Context("receiving key updates", func() { 165 It("updates keys", func() { 166 now := time.Now() 167 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 168 encrypted0 := server.Seal(nil, msg, 0x1337, ad) 169 server.rollKeys() 170 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 171 encrypted1 := server.Seal(nil, msg, 0x1337, ad) 172 Expect(encrypted0).ToNot(Equal(encrypted1)) 173 // expect opening to fail. The client didn't roll keys yet 174 _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) 175 Expect(err).To(MatchError(ErrDecryptionFailed)) 176 client.rollKeys() 177 decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) 178 Expect(err).ToNot(HaveOccurred()) 179 Expect(decrypted).To(Equal(msg)) 180 }) 181 182 It("updates the keys when receiving a packet with the next key phase", func() { 183 now := time.Now() 184 // receive the first packet at key phase zero 185 encrypted0 := client.Seal(nil, msg, 0x42, ad) 186 decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) 187 Expect(err).ToNot(HaveOccurred()) 188 Expect(decrypted).To(Equal(msg)) 189 // send one packet at key phase zero 190 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 191 _ = server.Seal(nil, msg, 0x1, ad) 192 // now received a message at key phase one 193 client.rollKeys() 194 encrypted1 := client.Seal(nil, msg, 0x43, ad) 195 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 196 decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) 197 Expect(err).ToNot(HaveOccurred()) 198 Expect(decrypted).To(Equal(msg)) 199 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 200 }) 201 202 It("opens a reordered packet with the old keys after an update", func() { 203 now := time.Now() 204 encrypted01 := client.Seal(nil, msg, 0x42, ad) 205 encrypted02 := client.Seal(nil, msg, 0x43, ad) 206 // receive the first packet with key phase 0 207 _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) 208 Expect(err).ToNot(HaveOccurred()) 209 // send one packet at key phase zero 210 _ = server.Seal(nil, msg, 0x1, ad) 211 // now receive a packet with key phase 1 212 client.rollKeys() 213 encrypted1 := client.Seal(nil, msg, 0x44, ad) 214 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 215 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 216 _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) 217 Expect(err).ToNot(HaveOccurred()) 218 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 219 // now receive a reordered packet with key phase 0 220 decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) 221 Expect(err).ToNot(HaveOccurred()) 222 Expect(decrypted).To(Equal(msg)) 223 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 224 }) 225 226 It("drops keys 3 PTOs after a key update", func() { 227 now := time.Now() 228 rttStats.UpdateRTT(10*time.Millisecond, 0, now) 229 pto := rttStats.PTO(true) 230 encrypted01 := client.Seal(nil, msg, 0x42, ad) 231 encrypted02 := client.Seal(nil, msg, 0x43, ad) 232 // receive the first packet with key phase 0 233 _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) 234 Expect(err).ToNot(HaveOccurred()) 235 // send one packet at key phase zero 236 _ = server.Seal(nil, msg, 0x1, ad) 237 // now receive a packet with key phase 1 238 client.rollKeys() 239 encrypted1 := client.Seal(nil, msg, 0x44, ad) 240 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 241 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 242 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 243 _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) 244 Expect(err).ToNot(HaveOccurred()) 245 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 246 // now receive a reordered packet with key phase 0 247 _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) 248 Expect(err).To(MatchError(ErrKeysDropped)) 249 }) 250 251 It("allows the first key update immediately", func() { 252 // receive a packet at key phase one, before having sent or received any packets at key phase 0 253 client.rollKeys() 254 encrypted1 := client.Seal(nil, msg, 0x1337, ad) 255 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 256 _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) 257 Expect(err).ToNot(HaveOccurred()) 258 }) 259 260 It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { 261 client.rollKeys() 262 encrypted := client.Seal(nil, msg, 0x1337, ad) 263 encrypted = encrypted[:len(encrypted)-1] 264 _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) 265 Expect(err).To(MatchError(ErrDecryptionFailed)) 266 }) 267 268 It("errors when the peer updates keys too frequently", func() { 269 server.rollKeys() 270 client.rollKeys() 271 // receive the first packet at key phase one 272 encrypted0 := client.Seal(nil, msg, 0x42, ad) 273 _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) 274 Expect(err).ToNot(HaveOccurred()) 275 // now receive a packet at key phase two, before having sent any packets 276 client.rollKeys() 277 encrypted1 := client.Seal(nil, msg, 0x42, ad) 278 _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) 279 Expect(err).To(MatchError(&qerr.TransportError{ 280 ErrorCode: qerr.KeyUpdateError, 281 ErrorMessage: "keys updated too quickly", 282 })) 283 }) 284 }) 285 286 Context("initiating key updates", func() { 287 const firstKeyUpdateInterval = 5 288 const keyUpdateInterval = 20 289 var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64 290 291 BeforeEach(func() { 292 origKeyUpdateInterval = KeyUpdateInterval 293 origFirstKeyUpdateInterval = FirstKeyUpdateInterval 294 KeyUpdateInterval = keyUpdateInterval 295 FirstKeyUpdateInterval = firstKeyUpdateInterval 296 server.SetHandshakeConfirmed() 297 }) 298 299 AfterEach(func() { 300 KeyUpdateInterval = origKeyUpdateInterval 301 FirstKeyUpdateInterval = origFirstKeyUpdateInterval 302 }) 303 304 It("initiates a key update after sealing the maximum number of packets, for the first update", func() { 305 for i := 0; i < firstKeyUpdateInterval; i++ { 306 pn := protocol.PacketNumber(i) 307 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 308 server.Seal(nil, msg, pn, ad) 309 } 310 // the first update is allowed without receiving an acknowledgement 311 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 312 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 313 }) 314 315 It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { 316 server.rollKeys() 317 client.rollKeys() 318 for i := 0; i < keyUpdateInterval; i++ { 319 pn := protocol.PacketNumber(i) 320 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 321 server.Seal(nil, msg, pn, ad) 322 } 323 // no update allowed before receiving an acknowledgement for the current key phase 324 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 325 // receive an ACK for a packet sent in key phase 0 326 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 327 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) 328 Expect(err).ToNot(HaveOccurred()) 329 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 330 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 331 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) 332 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 333 }) 334 335 It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { 336 // First make sure that we update our keys. 337 for i := 0; i < firstKeyUpdateInterval; i++ { 338 pn := protocol.PacketNumber(i) 339 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 340 server.Seal(nil, msg, pn, ad) 341 } 342 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 343 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 344 // Now that our keys are updated, send a packet using the new keys. 345 const nextPN = firstKeyUpdateInterval + 1 346 server.Seal(nil, msg, nextPN, ad) 347 // We haven't decrypted any packet in the new key phase yet. 348 // This means that the ACK must have been sent in the old key phase. 349 Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ 350 ErrorCode: qerr.KeyUpdateError, 351 ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", 352 })) 353 }) 354 355 It("doesn't error before actually sending a packet in the new key phase", func() { 356 // First make sure that we update our keys. 357 for i := 0; i < firstKeyUpdateInterval; i++ { 358 pn := protocol.PacketNumber(i) 359 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 360 server.Seal(nil, msg, pn, ad) 361 } 362 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 363 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) 364 Expect(err).ToNot(HaveOccurred()) 365 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 366 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 367 // Now that our keys are updated, send a packet using the new keys. 368 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 369 // We haven't decrypted any packet in the new key phase yet. 370 // This means that the ACK must have been sent in the old key phase. 371 Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) 372 }) 373 374 It("initiates a key update after opening the maximum number of packets, for the first update", func() { 375 for i := 0; i < firstKeyUpdateInterval; i++ { 376 pn := protocol.PacketNumber(i) 377 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 378 encrypted := client.Seal(nil, msg, pn, ad) 379 _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) 380 Expect(err).ToNot(HaveOccurred()) 381 } 382 // the first update is allowed without receiving an acknowledgement 383 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 384 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 385 }) 386 387 It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { 388 server.rollKeys() 389 client.rollKeys() 390 for i := 0; i < keyUpdateInterval; i++ { 391 pn := protocol.PacketNumber(i) 392 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 393 encrypted := client.Seal(nil, msg, pn, ad) 394 _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) 395 Expect(err).ToNot(HaveOccurred()) 396 } 397 // no update allowed before receiving an acknowledgement for the current key phase 398 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 399 server.Seal(nil, msg, 1, ad) 400 Expect(server.SetLargestAcked(1)).To(Succeed()) 401 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 402 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) 403 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 404 }) 405 406 It("drops keys 3 PTOs after a key update", func() { 407 now := time.Now() 408 for i := 0; i < firstKeyUpdateInterval; i++ { 409 pn := protocol.PacketNumber(i) 410 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 411 server.Seal(nil, msg, pn, ad) 412 } 413 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 414 _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) 415 Expect(err).ToNot(HaveOccurred()) 416 Expect(server.SetLargestAcked(0)).To(Succeed()) 417 // Now we've initiated the first key update. 418 // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there 419 threePTO := 3 * rttStats.PTO(false) 420 dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) 421 _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) 422 Expect(err).ToNot(HaveOccurred()) 423 // Now receive a packet with key phase 1. 424 // This should start the timer to drop the keys after 3 PTOs. 425 client.rollKeys() 426 dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) 427 t := now.Add(threePTO).Add(time.Second) 428 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 429 _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) 430 Expect(err).ToNot(HaveOccurred()) 431 // Make sure the keys are still here. 432 _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) 433 Expect(err).ToNot(HaveOccurred()) 434 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 435 _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) 436 Expect(err).To(MatchError(ErrKeysDropped)) 437 }) 438 439 It("doesn't drop the first key generation too early", func() { 440 now := time.Now() 441 data1 := client.Seal(nil, msg, 1, ad) 442 _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) 443 Expect(err).ToNot(HaveOccurred()) 444 for i := 0; i < firstKeyUpdateInterval; i++ { 445 pn := protocol.PacketNumber(i) 446 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 447 server.Seal(nil, msg, pn, ad) 448 Expect(server.SetLargestAcked(pn)).To(Succeed()) 449 } 450 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 451 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 452 // The server never received a packet at key phase 1. 453 // Make sure the key phase 0 is still there at a much later point. 454 data2 := client.Seal(nil, msg, 1, ad) 455 _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) 456 Expect(err).ToNot(HaveOccurred()) 457 }) 458 459 It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { 460 for i := 0; i < firstKeyUpdateInterval; i++ { 461 pn := protocol.PacketNumber(i) 462 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 463 server.Seal(nil, msg, pn, ad) 464 } 465 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 466 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) 467 Expect(err).ToNot(HaveOccurred()) 468 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 469 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 470 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 471 const nextPN = keyUpdateInterval + 1 472 // Send and receive an acknowledgement for a packet in key phase 1. 473 // We are now running a timer to drop the keys with 3 PTO. 474 server.Seal(nil, msg, nextPN, ad) 475 client.rollKeys() 476 dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) 477 now := time.Now() 478 _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) 479 Expect(err).ToNot(HaveOccurred()) 480 Expect(server.SetLargestAcked(nextPN)) 481 // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. 482 // This mean that we need to drop the keys for key phase 0 immediately. 483 client.rollKeys() 484 dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) 485 gomock.InOrder( 486 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), 487 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), 488 ) 489 _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) 490 Expect(err).ToNot(HaveOccurred()) 491 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 492 }) 493 494 It("drops keys early when we initiate another key update within the 3 PTO period", func() { 495 server.SetHandshakeConfirmed() 496 // send so many packets that we initiate the first key update 497 for i := 0; i < firstKeyUpdateInterval; i++ { 498 pn := protocol.PacketNumber(i) 499 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 500 server.Seal(nil, msg, pn, ad) 501 } 502 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 503 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) 504 Expect(err).ToNot(HaveOccurred()) 505 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 506 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 507 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 508 // send so many packets that we initiate the next key update 509 for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { 510 pn := protocol.PacketNumber(i) 511 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 512 server.Seal(nil, msg, pn, ad) 513 } 514 client.rollKeys() 515 b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) 516 now := time.Now() 517 _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) 518 Expect(err).ToNot(HaveOccurred()) 519 ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) 520 gomock.InOrder( 521 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), 522 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), 523 ) 524 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 525 // We haven't received an ACK for a packet sent in key phase 2 yet. 526 // Make sure we canceled the timer to drop the previous key phase. 527 b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) 528 _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) 529 Expect(err).ToNot(HaveOccurred()) 530 }) 531 }) 532 }) 533 }) 534 }) 535 } 536 }) 537 } 538 }) 539 540 func getClientAndServer() (client, server *updatableAEAD) { 541 trafficSecret1 := make([]byte, 16) 542 trafficSecret2 := make([]byte, 16) 543 rand.Read(trafficSecret1) 544 rand.Read(trafficSecret2) 545 546 cs := cipherSuites[0] 547 rttStats := utils.NewRTTStats() 548 client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) 549 server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) 550 client.SetReadKey(cs, trafficSecret2) 551 client.SetWriteKey(cs, trafficSecret1) 552 server.SetReadKey(cs, trafficSecret1) 553 server.SetWriteKey(cs, trafficSecret2) 554 return 555 } 556 557 func BenchmarkPacketEncryption(b *testing.B) { 558 client, _ := getClientAndServer() 559 const l = 1200 560 src := make([]byte, l) 561 rand.Read(src) 562 ad := make([]byte, 32) 563 rand.Read(ad) 564 565 for i := 0; i < b.N; i++ { 566 src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad) 567 } 568 } 569 570 func BenchmarkPacketDecryption(b *testing.B) { 571 client, server := getClientAndServer() 572 const l = 1200 573 src := make([]byte, l) 574 dst := make([]byte, l) 575 rand.Read(src) 576 ad := make([]byte, 32) 577 rand.Read(ad) 578 src = client.Seal(src[:0], src[:l], 1337, ad) 579 580 for i := 0; i < b.N; i++ { 581 if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil { 582 b.Fatalf("opening failed: %v", err) 583 } 584 } 585 } 586 587 func BenchmarkRollKeys(b *testing.B) { 588 client, _ := getClientAndServer() 589 for i := 0; i < b.N; i++ { 590 client.rollKeys() 591 } 592 if int(client.keyPhase) != b.N { 593 b.Fatal("didn't roll keys often enough") 594 } 595 }