github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/internal/handshake/updatable_aead_test.go (about) 1 package handshake 2 3 import ( 4 "crypto/rand" 5 "crypto/tls" 6 "fmt" 7 "testing" 8 "time" 9 10 mocklogging "github.com/mikelsr/quic-go/internal/mocks/logging" 11 "github.com/mikelsr/quic-go/internal/protocol" 12 "github.com/mikelsr/quic-go/internal/qerr" 13 "github.com/mikelsr/quic-go/internal/utils" 14 15 "github.com/golang/mock/gomock" 16 . "github.com/onsi/ginkgo/v2" 17 . "github.com/onsi/gomega" 18 ) 19 20 var _ = Describe("Updatable AEAD", func() { 21 DescribeTable("ChaCha test vector", 22 func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) { 23 secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b") 24 aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v) 25 chacha := cipherSuites[2] 26 Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256)) 27 aead.SetWriteKey(chacha, secret) 28 const pnOffset = 1 29 header := splitHexString("4200bff4") 30 payloadOffset := len(header) 31 plaintext := splitHexString("01") 32 payload := aead.Seal(nil, plaintext, 654360564, header) 33 Expect(payload).To(Equal(expectedPayload)) 34 packet := append(header, payload...) 35 aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset]) 36 Expect(packet).To(Equal(expectedPacket)) 37 }, 38 Entry("QUIC v1", 39 protocol.Version1, 40 splitHexString("655e5cd55c41f69080575d7999c25a5bfb"), 41 splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"), 42 ), 43 Entry("QUIC v2", 44 protocol.Version2, 45 splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"), 46 splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"), 47 ), 48 ) 49 50 for _, ver := range []protocol.VersionNumber{protocol.Version1, protocol.Version2} { 51 v := ver 52 53 Context(fmt.Sprintf("using version %s", v), func() { 54 for i := range cipherSuites { 55 cs := cipherSuites[i] 56 57 Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() { 58 var ( 59 client, server *updatableAEAD 60 serverTracer *mocklogging.MockConnectionTracer 61 rttStats *utils.RTTStats 62 ) 63 64 BeforeEach(func() { 65 serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl) 66 trafficSecret1 := make([]byte, 16) 67 trafficSecret2 := make([]byte, 16) 68 rand.Read(trafficSecret1) 69 rand.Read(trafficSecret2) 70 71 rttStats = utils.NewRTTStats() 72 client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v) 73 server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v) 74 client.SetReadKey(cs, trafficSecret2) 75 client.SetWriteKey(cs, trafficSecret1) 76 server.SetReadKey(cs, trafficSecret1) 77 server.SetWriteKey(cs, trafficSecret2) 78 }) 79 80 Context("header protection", func() { 81 It("encrypts and decrypts the header", func() { 82 var lastFiveBitsDifferent int 83 for i := 0; i < 100; i++ { 84 sample := make([]byte, 16) 85 rand.Read(sample) 86 header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef} 87 client.EncryptHeader(sample, &header[0], header[9:13]) 88 if header[0]&0x1f != 0xb5&0x1f { 89 lastFiveBitsDifferent++ 90 } 91 Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0))) 92 Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8})) 93 Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) 94 server.DecryptHeader(sample, &header[0], header[9:13]) 95 Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef})) 96 } 97 Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75)) 98 }) 99 }) 100 101 Context("message encryption", func() { 102 var msg, ad []byte 103 104 BeforeEach(func() { 105 msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") 106 ad = []byte("Donec in velit neque.") 107 }) 108 109 It("encrypts and decrypts a message", func() { 110 encrypted := server.Seal(nil, msg, 0x1337, ad) 111 opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) 112 Expect(err).ToNot(HaveOccurred()) 113 Expect(opened).To(Equal(msg)) 114 }) 115 116 It("saves the first packet number", func() { 117 client.Seal(nil, msg, 0x1337, ad) 118 Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) 119 client.Seal(nil, msg, 0x1338, ad) 120 Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337))) 121 }) 122 123 It("fails to open a message if the associated data is not the same", func() { 124 encrypted := client.Seal(nil, msg, 0x1337, ad) 125 _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad")) 126 Expect(err).To(MatchError(ErrDecryptionFailed)) 127 }) 128 129 It("fails to open a message if the packet number is not the same", func() { 130 encrypted := server.Seal(nil, msg, 0x1337, ad) 131 _, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad) 132 Expect(err).To(MatchError(ErrDecryptionFailed)) 133 }) 134 135 It("decodes the packet number", func() { 136 encrypted := server.Seal(nil, msg, 0x1337, ad) 137 _, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad) 138 Expect(err).ToNot(HaveOccurred()) 139 Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338)) 140 }) 141 142 It("ignores packets it can't decrypt for packet number derivation", func() { 143 encrypted := server.Seal(nil, msg, 0x1337, ad) 144 _, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad) 145 Expect(err).To(HaveOccurred()) 146 Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38)) 147 }) 148 149 It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() { 150 client.invalidPacketLimit = 10 151 for i := 0; i < 9; i++ { 152 _, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad")) 153 Expect(err).To(MatchError(ErrDecryptionFailed)) 154 } 155 _, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad")) 156 Expect(err).To(HaveOccurred()) 157 Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) 158 Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached)) 159 }) 160 161 Context("key updates", func() { 162 Context("receiving key updates", func() { 163 It("updates keys", func() { 164 now := time.Now() 165 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 166 encrypted0 := server.Seal(nil, msg, 0x1337, ad) 167 server.rollKeys() 168 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 169 encrypted1 := server.Seal(nil, msg, 0x1337, ad) 170 Expect(encrypted0).ToNot(Equal(encrypted1)) 171 // expect opening to fail. The client didn't roll keys yet 172 _, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad) 173 Expect(err).To(MatchError(ErrDecryptionFailed)) 174 client.rollKeys() 175 decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad) 176 Expect(err).ToNot(HaveOccurred()) 177 Expect(decrypted).To(Equal(msg)) 178 }) 179 180 It("updates the keys when receiving a packet with the next key phase", func() { 181 now := time.Now() 182 // receive the first packet at key phase zero 183 encrypted0 := client.Seal(nil, msg, 0x42, ad) 184 decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad) 185 Expect(err).ToNot(HaveOccurred()) 186 Expect(decrypted).To(Equal(msg)) 187 // send one packet at key phase zero 188 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 189 _ = server.Seal(nil, msg, 0x1, ad) 190 // now received a message at key phase one 191 client.rollKeys() 192 encrypted1 := client.Seal(nil, msg, 0x43, ad) 193 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 194 decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad) 195 Expect(err).ToNot(HaveOccurred()) 196 Expect(decrypted).To(Equal(msg)) 197 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 198 }) 199 200 It("opens a reordered packet with the old keys after an update", func() { 201 now := time.Now() 202 encrypted01 := client.Seal(nil, msg, 0x42, ad) 203 encrypted02 := client.Seal(nil, msg, 0x43, ad) 204 // receive the first packet with key phase 0 205 _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) 206 Expect(err).ToNot(HaveOccurred()) 207 // send one packet at key phase zero 208 _ = server.Seal(nil, msg, 0x1, ad) 209 // now receive a packet with key phase 1 210 client.rollKeys() 211 encrypted1 := client.Seal(nil, msg, 0x44, ad) 212 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 213 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 214 _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) 215 Expect(err).ToNot(HaveOccurred()) 216 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 217 // now receive a reordered packet with key phase 0 218 decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad) 219 Expect(err).ToNot(HaveOccurred()) 220 Expect(decrypted).To(Equal(msg)) 221 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 222 }) 223 224 It("drops keys 3 PTOs after a key update", func() { 225 now := time.Now() 226 rttStats.UpdateRTT(10*time.Millisecond, 0, now) 227 pto := rttStats.PTO(true) 228 encrypted01 := client.Seal(nil, msg, 0x42, ad) 229 encrypted02 := client.Seal(nil, msg, 0x43, ad) 230 // receive the first packet with key phase 0 231 _, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad) 232 Expect(err).ToNot(HaveOccurred()) 233 // send one packet at key phase zero 234 _ = server.Seal(nil, msg, 0x1, ad) 235 // now receive a packet with key phase 1 236 client.rollKeys() 237 encrypted1 := client.Seal(nil, msg, 0x44, ad) 238 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 239 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 240 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 241 _, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad) 242 Expect(err).ToNot(HaveOccurred()) 243 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 244 // now receive a reordered packet with key phase 0 245 _, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad) 246 Expect(err).To(MatchError(ErrKeysDropped)) 247 }) 248 249 It("allows the first key update immediately", func() { 250 // receive a packet at key phase one, before having sent or received any packets at key phase 0 251 client.rollKeys() 252 encrypted1 := client.Seal(nil, msg, 0x1337, ad) 253 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 254 _, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) 255 Expect(err).ToNot(HaveOccurred()) 256 }) 257 258 It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() { 259 client.rollKeys() 260 encrypted := client.Seal(nil, msg, 0x1337, ad) 261 encrypted = encrypted[:len(encrypted)-1] 262 _, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad) 263 Expect(err).To(MatchError(ErrDecryptionFailed)) 264 }) 265 266 It("errors when the peer updates keys too frequently", func() { 267 server.rollKeys() 268 client.rollKeys() 269 // receive the first packet at key phase one 270 encrypted0 := client.Seal(nil, msg, 0x42, ad) 271 _, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad) 272 Expect(err).ToNot(HaveOccurred()) 273 // now receive a packet at key phase two, before having sent any packets 274 client.rollKeys() 275 encrypted1 := client.Seal(nil, msg, 0x42, ad) 276 _, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad) 277 Expect(err).To(MatchError(&qerr.TransportError{ 278 ErrorCode: qerr.KeyUpdateError, 279 ErrorMessage: "keys updated too quickly", 280 })) 281 }) 282 }) 283 284 Context("initiating key updates", func() { 285 const firstKeyUpdateInterval = 5 286 const keyUpdateInterval = 20 287 var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64 288 289 BeforeEach(func() { 290 origKeyUpdateInterval = KeyUpdateInterval 291 origFirstKeyUpdateInterval = FirstKeyUpdateInterval 292 KeyUpdateInterval = keyUpdateInterval 293 FirstKeyUpdateInterval = firstKeyUpdateInterval 294 server.SetHandshakeConfirmed() 295 }) 296 297 AfterEach(func() { 298 KeyUpdateInterval = origKeyUpdateInterval 299 FirstKeyUpdateInterval = origFirstKeyUpdateInterval 300 }) 301 302 It("initiates a key update after sealing the maximum number of packets, for the first update", func() { 303 for i := 0; i < firstKeyUpdateInterval; i++ { 304 pn := protocol.PacketNumber(i) 305 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 306 server.Seal(nil, msg, pn, ad) 307 } 308 // the first update is allowed without receiving an acknowledgement 309 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 310 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 311 }) 312 313 It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() { 314 server.rollKeys() 315 client.rollKeys() 316 for i := 0; i < keyUpdateInterval; i++ { 317 pn := protocol.PacketNumber(i) 318 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 319 server.Seal(nil, msg, pn, ad) 320 } 321 // no update allowed before receiving an acknowledgement for the current key phase 322 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 323 // receive an ACK for a packet sent in key phase 0 324 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 325 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad")) 326 Expect(err).ToNot(HaveOccurred()) 327 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 328 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 329 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) 330 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 331 }) 332 333 It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() { 334 // First make sure that we update our keys. 335 for i := 0; i < firstKeyUpdateInterval; i++ { 336 pn := protocol.PacketNumber(i) 337 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 338 server.Seal(nil, msg, pn, ad) 339 } 340 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 341 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 342 // Now that our keys are updated, send a packet using the new keys. 343 const nextPN = firstKeyUpdateInterval + 1 344 server.Seal(nil, msg, nextPN, ad) 345 // We haven't decrypted any packet in the new key phase yet. 346 // This means that the ACK must have been sent in the old key phase. 347 Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{ 348 ErrorCode: qerr.KeyUpdateError, 349 ErrorMessage: "received ACK for key phase 1, but peer didn't update keys", 350 })) 351 }) 352 353 It("doesn't error before actually sending a packet in the new key phase", func() { 354 // First make sure that we update our keys. 355 for i := 0; i < firstKeyUpdateInterval; i++ { 356 pn := protocol.PacketNumber(i) 357 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 358 server.Seal(nil, msg, pn, ad) 359 } 360 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 361 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) 362 Expect(err).ToNot(HaveOccurred()) 363 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 364 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 365 // Now that our keys are updated, send a packet using the new keys. 366 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 367 // We haven't decrypted any packet in the new key phase yet. 368 // This means that the ACK must have been sent in the old key phase. 369 Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred()) 370 }) 371 372 It("initiates a key update after opening the maximum number of packets, for the first update", func() { 373 for i := 0; i < firstKeyUpdateInterval; i++ { 374 pn := protocol.PacketNumber(i) 375 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 376 encrypted := client.Seal(nil, msg, pn, ad) 377 _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad) 378 Expect(err).ToNot(HaveOccurred()) 379 } 380 // the first update is allowed without receiving an acknowledgement 381 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 382 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 383 }) 384 385 It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() { 386 server.rollKeys() 387 client.rollKeys() 388 for i := 0; i < keyUpdateInterval; i++ { 389 pn := protocol.PacketNumber(i) 390 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 391 encrypted := client.Seal(nil, msg, pn, ad) 392 _, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad) 393 Expect(err).ToNot(HaveOccurred()) 394 } 395 // no update allowed before receiving an acknowledgement for the current key phase 396 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 397 server.Seal(nil, msg, 1, ad) 398 Expect(server.SetLargestAcked(1)).To(Succeed()) 399 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 400 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false) 401 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 402 }) 403 404 It("drops keys 3 PTOs after a key update", func() { 405 now := time.Now() 406 for i := 0; i < firstKeyUpdateInterval; i++ { 407 pn := protocol.PacketNumber(i) 408 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 409 server.Seal(nil, msg, pn, ad) 410 } 411 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 412 _, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad")) 413 Expect(err).ToNot(HaveOccurred()) 414 Expect(server.SetLargestAcked(0)).To(Succeed()) 415 // Now we've initiated the first key update. 416 // Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there 417 threePTO := 3 * rttStats.PTO(false) 418 dataKeyPhaseZero := client.Seal(nil, msg, 1, ad) 419 _, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad) 420 Expect(err).ToNot(HaveOccurred()) 421 // Now receive a packet with key phase 1. 422 // This should start the timer to drop the keys after 3 PTOs. 423 client.rollKeys() 424 dataKeyPhaseOne := client.Seal(nil, msg, 10, ad) 425 t := now.Add(threePTO).Add(time.Second) 426 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true) 427 _, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad) 428 Expect(err).ToNot(HaveOccurred()) 429 // Make sure the keys are still here. 430 _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad) 431 Expect(err).ToNot(HaveOccurred()) 432 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)) 433 _, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad) 434 Expect(err).To(MatchError(ErrKeysDropped)) 435 }) 436 437 It("doesn't drop the first key generation too early", func() { 438 now := time.Now() 439 data1 := client.Seal(nil, msg, 1, ad) 440 _, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad) 441 Expect(err).ToNot(HaveOccurred()) 442 for i := 0; i < firstKeyUpdateInterval; i++ { 443 pn := protocol.PacketNumber(i) 444 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 445 server.Seal(nil, msg, pn, ad) 446 Expect(server.SetLargestAcked(pn)).To(Succeed()) 447 } 448 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 449 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 450 // The server never received a packet at key phase 1. 451 // Make sure the key phase 0 is still there at a much later point. 452 data2 := client.Seal(nil, msg, 1, ad) 453 _, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad) 454 Expect(err).ToNot(HaveOccurred()) 455 }) 456 457 It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() { 458 for i := 0; i < firstKeyUpdateInterval; i++ { 459 pn := protocol.PacketNumber(i) 460 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 461 server.Seal(nil, msg, pn, ad) 462 } 463 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 464 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) 465 Expect(err).ToNot(HaveOccurred()) 466 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 467 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 468 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 469 const nextPN = keyUpdateInterval + 1 470 // Send and receive an acknowledgement for a packet in key phase 1. 471 // We are now running a timer to drop the keys with 3 PTO. 472 server.Seal(nil, msg, nextPN, ad) 473 client.rollKeys() 474 dataKeyPhaseOne := client.Seal(nil, msg, 2, ad) 475 now := time.Now() 476 _, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad) 477 Expect(err).ToNot(HaveOccurred()) 478 Expect(server.SetLargestAcked(nextPN)) 479 // Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over. 480 // This mean that we need to drop the keys for key phase 0 immediately. 481 client.rollKeys() 482 dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad) 483 gomock.InOrder( 484 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), 485 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true), 486 ) 487 _, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad) 488 Expect(err).ToNot(HaveOccurred()) 489 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 490 }) 491 492 It("drops keys early when we initiate another key update within the 3 PTO period", func() { 493 server.SetHandshakeConfirmed() 494 // send so many packets that we initiate the first key update 495 for i := 0; i < firstKeyUpdateInterval; i++ { 496 pn := protocol.PacketNumber(i) 497 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 498 server.Seal(nil, msg, pn, ad) 499 } 500 b := client.Seal(nil, []byte("foobar"), 1, []byte("ad")) 501 _, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad")) 502 Expect(err).ToNot(HaveOccurred()) 503 ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed()) 504 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false) 505 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 506 // send so many packets that we initiate the next key update 507 for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ { 508 pn := protocol.PacketNumber(i) 509 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne)) 510 server.Seal(nil, msg, pn, ad) 511 } 512 client.rollKeys() 513 b = client.Seal(nil, []byte("foobar"), 2, []byte("ad")) 514 now := time.Now() 515 _, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad")) 516 Expect(err).ToNot(HaveOccurred()) 517 ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed()) 518 gomock.InOrder( 519 serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)), 520 serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false), 521 ) 522 Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero)) 523 // We haven't received an ACK for a packet sent in key phase 2 yet. 524 // Make sure we canceled the timer to drop the previous key phase. 525 b = client.Seal(nil, []byte("foobar"), 3, []byte("ad")) 526 _, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad")) 527 Expect(err).ToNot(HaveOccurred()) 528 }) 529 }) 530 }) 531 }) 532 }) 533 } 534 }) 535 } 536 }) 537 538 func getClientAndServer() (client, server *updatableAEAD) { 539 trafficSecret1 := make([]byte, 16) 540 trafficSecret2 := make([]byte, 16) 541 rand.Read(trafficSecret1) 542 rand.Read(trafficSecret2) 543 544 cs := cipherSuites[0] 545 rttStats := utils.NewRTTStats() 546 client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) 547 server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1) 548 client.SetReadKey(cs, trafficSecret2) 549 client.SetWriteKey(cs, trafficSecret1) 550 server.SetReadKey(cs, trafficSecret1) 551 server.SetWriteKey(cs, trafficSecret2) 552 return 553 } 554 555 func BenchmarkPacketEncryption(b *testing.B) { 556 client, _ := getClientAndServer() 557 const l = 1200 558 src := make([]byte, l) 559 rand.Read(src) 560 ad := make([]byte, 32) 561 rand.Read(ad) 562 563 for i := 0; i < b.N; i++ { 564 src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad) 565 } 566 } 567 568 func BenchmarkPacketDecryption(b *testing.B) { 569 client, server := getClientAndServer() 570 const l = 1200 571 src := make([]byte, l) 572 dst := make([]byte, l) 573 rand.Read(src) 574 ad := make([]byte, 32) 575 rand.Read(ad) 576 src = client.Seal(src[:0], src[:l], 1337, ad) 577 578 for i := 0; i < b.N; i++ { 579 if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil { 580 b.Fatalf("opening failed: %v", err) 581 } 582 } 583 } 584 585 func BenchmarkRollKeys(b *testing.B) { 586 client, _ := getClientAndServer() 587 for i := 0; i < b.N; i++ { 588 client.rollKeys() 589 } 590 if int(client.keyPhase) != b.N { 591 b.Fatal("didn't roll keys often enough") 592 } 593 }