github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/internal/handshake/updatable_aead_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	mocklogging "github.com/metacubex/quic-go/internal/mocks/logging"
    11  	"github.com/metacubex/quic-go/internal/protocol"
    12  	"github.com/metacubex/quic-go/internal/qerr"
    13  	"github.com/metacubex/quic-go/internal/utils"
    14  	"github.com/metacubex/quic-go/logging"
    15  
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  	"go.uber.org/mock/gomock"
    19  )
    20  
    21  var _ = Describe("Updatable AEAD", func() {
    22  	DescribeTable("ChaCha test vector",
    23  		func(v protocol.Version, expectedPayload, expectedPacket []byte) {
    24  			secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
    25  			aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v)
    26  			chacha := cipherSuites[2]
    27  			Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256))
    28  			aead.SetWriteKey(chacha, secret)
    29  			const pnOffset = 1
    30  			header := splitHexString("4200bff4")
    31  			payloadOffset := len(header)
    32  			plaintext := splitHexString("01")
    33  			payload := aead.Seal(nil, plaintext, 654360564, header)
    34  			Expect(payload).To(Equal(expectedPayload))
    35  			packet := append(header, payload...)
    36  			aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset])
    37  			Expect(packet).To(Equal(expectedPacket))
    38  		},
    39  		Entry("QUIC v1",
    40  			protocol.Version1,
    41  			splitHexString("655e5cd55c41f69080575d7999c25a5bfb"),
    42  			splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"),
    43  		),
    44  		Entry("QUIC v2",
    45  			protocol.Version2,
    46  			splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"),
    47  			splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"),
    48  		),
    49  	)
    50  
    51  	for _, ver := range []protocol.Version{protocol.Version1, protocol.Version2} {
    52  		v := ver
    53  
    54  		Context(fmt.Sprintf("using version %s", v), func() {
    55  			for i := range cipherSuites {
    56  				cs := cipherSuites[i]
    57  
    58  				Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() {
    59  					var (
    60  						client, server *updatableAEAD
    61  						serverTracer   *mocklogging.MockConnectionTracer
    62  						rttStats       *utils.RTTStats
    63  					)
    64  
    65  					BeforeEach(func() {
    66  						var tr *logging.ConnectionTracer
    67  						tr, serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
    68  						trafficSecret1 := make([]byte, 16)
    69  						trafficSecret2 := make([]byte, 16)
    70  						rand.Read(trafficSecret1)
    71  						rand.Read(trafficSecret2)
    72  
    73  						rttStats = utils.NewRTTStats()
    74  						client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v)
    75  						server = newUpdatableAEAD(rttStats, tr, utils.DefaultLogger, v)
    76  						client.SetReadKey(cs, trafficSecret2)
    77  						client.SetWriteKey(cs, trafficSecret1)
    78  						server.SetReadKey(cs, trafficSecret1)
    79  						server.SetWriteKey(cs, trafficSecret2)
    80  					})
    81  
    82  					Context("header protection", func() {
    83  						It("encrypts and decrypts the header", func() {
    84  							var lastFiveBitsDifferent int
    85  							for i := 0; i < 100; i++ {
    86  								sample := make([]byte, 16)
    87  								rand.Read(sample)
    88  								header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
    89  								client.EncryptHeader(sample, &header[0], header[9:13])
    90  								if header[0]&0x1f != 0xb5&0x1f {
    91  									lastFiveBitsDifferent++
    92  								}
    93  								Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
    94  								Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
    95  								Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
    96  								server.DecryptHeader(sample, &header[0], header[9:13])
    97  								Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
    98  							}
    99  							Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
   100  						})
   101  					})
   102  
   103  					Context("message encryption", func() {
   104  						var msg, ad []byte
   105  
   106  						BeforeEach(func() {
   107  							msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
   108  							ad = []byte("Donec in velit neque.")
   109  						})
   110  
   111  						It("encrypts and decrypts a message", func() {
   112  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   113  							opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
   114  							Expect(err).ToNot(HaveOccurred())
   115  							Expect(opened).To(Equal(msg))
   116  						})
   117  
   118  						It("saves the first packet number", func() {
   119  							client.Seal(nil, msg, 0x1337, ad)
   120  							Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337)))
   121  							client.Seal(nil, msg, 0x1338, ad)
   122  							Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337)))
   123  						})
   124  
   125  						It("fails to open a message if the associated data is not the same", func() {
   126  							encrypted := client.Seal(nil, msg, 0x1337, ad)
   127  							_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
   128  							Expect(err).To(MatchError(ErrDecryptionFailed))
   129  						})
   130  
   131  						It("fails to open a message if the packet number is not the same", func() {
   132  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   133  							_, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
   134  							Expect(err).To(MatchError(ErrDecryptionFailed))
   135  						})
   136  
   137  						It("decodes the packet number", func() {
   138  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   139  							_, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
   140  							Expect(err).ToNot(HaveOccurred())
   141  							Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
   142  						})
   143  
   144  						It("ignores packets it can't decrypt for packet number derivation", func() {
   145  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   146  							_, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
   147  							Expect(err).To(HaveOccurred())
   148  							Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
   149  						})
   150  
   151  						It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() {
   152  							client.invalidPacketLimit = 10
   153  							for i := 0; i < 9; i++ {
   154  								_, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad"))
   155  								Expect(err).To(MatchError(ErrDecryptionFailed))
   156  							}
   157  							_, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad"))
   158  							Expect(err).To(HaveOccurred())
   159  							Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   160  							Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached))
   161  						})
   162  
   163  						Context("key updates", func() {
   164  							Context("receiving key updates", func() {
   165  								It("updates keys", func() {
   166  									now := time.Now()
   167  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   168  									encrypted0 := server.Seal(nil, msg, 0x1337, ad)
   169  									server.rollKeys()
   170  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   171  									encrypted1 := server.Seal(nil, msg, 0x1337, ad)
   172  									Expect(encrypted0).ToNot(Equal(encrypted1))
   173  									// expect opening to fail. The client didn't roll keys yet
   174  									_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
   175  									Expect(err).To(MatchError(ErrDecryptionFailed))
   176  									client.rollKeys()
   177  									decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
   178  									Expect(err).ToNot(HaveOccurred())
   179  									Expect(decrypted).To(Equal(msg))
   180  								})
   181  
   182  								It("updates the keys when receiving a packet with the next key phase", func() {
   183  									now := time.Now()
   184  									// receive the first packet at key phase zero
   185  									encrypted0 := client.Seal(nil, msg, 0x42, ad)
   186  									decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
   187  									Expect(err).ToNot(HaveOccurred())
   188  									Expect(decrypted).To(Equal(msg))
   189  									// send one packet at key phase zero
   190  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   191  									_ = server.Seal(nil, msg, 0x1, ad)
   192  									// now received a message at key phase one
   193  									client.rollKeys()
   194  									encrypted1 := client.Seal(nil, msg, 0x43, ad)
   195  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   196  									decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
   197  									Expect(err).ToNot(HaveOccurred())
   198  									Expect(decrypted).To(Equal(msg))
   199  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   200  								})
   201  
   202  								It("opens a reordered packet with the old keys after an update", func() {
   203  									now := time.Now()
   204  									encrypted01 := client.Seal(nil, msg, 0x42, ad)
   205  									encrypted02 := client.Seal(nil, msg, 0x43, ad)
   206  									// receive the first packet with key phase 0
   207  									_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
   208  									Expect(err).ToNot(HaveOccurred())
   209  									// send one packet at key phase zero
   210  									_ = server.Seal(nil, msg, 0x1, ad)
   211  									// now receive a packet with key phase 1
   212  									client.rollKeys()
   213  									encrypted1 := client.Seal(nil, msg, 0x44, ad)
   214  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   215  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   216  									_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
   217  									Expect(err).ToNot(HaveOccurred())
   218  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   219  									// now receive a reordered packet with key phase 0
   220  									decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
   221  									Expect(err).ToNot(HaveOccurred())
   222  									Expect(decrypted).To(Equal(msg))
   223  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   224  								})
   225  
   226  								It("drops keys 3 PTOs after a key update", func() {
   227  									now := time.Now()
   228  									rttStats.UpdateRTT(10*time.Millisecond, 0, now)
   229  									pto := rttStats.PTO(true)
   230  									encrypted01 := client.Seal(nil, msg, 0x42, ad)
   231  									encrypted02 := client.Seal(nil, msg, 0x43, ad)
   232  									// receive the first packet with key phase 0
   233  									_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
   234  									Expect(err).ToNot(HaveOccurred())
   235  									// send one packet at key phase zero
   236  									_ = server.Seal(nil, msg, 0x1, ad)
   237  									// now receive a packet with key phase 1
   238  									client.rollKeys()
   239  									encrypted1 := client.Seal(nil, msg, 0x44, ad)
   240  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   241  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   242  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   243  									_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
   244  									Expect(err).ToNot(HaveOccurred())
   245  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   246  									// now receive a reordered packet with key phase 0
   247  									_, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
   248  									Expect(err).To(MatchError(ErrKeysDropped))
   249  								})
   250  
   251  								It("allows the first key update immediately", func() {
   252  									// receive a packet at key phase one, before having sent or received any packets at key phase 0
   253  									client.rollKeys()
   254  									encrypted1 := client.Seal(nil, msg, 0x1337, ad)
   255  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   256  									_, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
   257  									Expect(err).ToNot(HaveOccurred())
   258  								})
   259  
   260  								It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
   261  									client.rollKeys()
   262  									encrypted := client.Seal(nil, msg, 0x1337, ad)
   263  									encrypted = encrypted[:len(encrypted)-1]
   264  									_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
   265  									Expect(err).To(MatchError(ErrDecryptionFailed))
   266  								})
   267  
   268  								It("errors when the peer updates keys too frequently", func() {
   269  									server.rollKeys()
   270  									client.rollKeys()
   271  									// receive the first packet at key phase one
   272  									encrypted0 := client.Seal(nil, msg, 0x42, ad)
   273  									_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
   274  									Expect(err).ToNot(HaveOccurred())
   275  									// now receive a packet at key phase two, before having sent any packets
   276  									client.rollKeys()
   277  									encrypted1 := client.Seal(nil, msg, 0x42, ad)
   278  									_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
   279  									Expect(err).To(MatchError(&qerr.TransportError{
   280  										ErrorCode:    qerr.KeyUpdateError,
   281  										ErrorMessage: "keys updated too quickly",
   282  									}))
   283  								})
   284  							})
   285  
   286  							Context("initiating key updates", func() {
   287  								const firstKeyUpdateInterval = 5
   288  								const keyUpdateInterval = 20
   289  								var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
   290  
   291  								BeforeEach(func() {
   292  									origKeyUpdateInterval = KeyUpdateInterval
   293  									origFirstKeyUpdateInterval = FirstKeyUpdateInterval
   294  									KeyUpdateInterval = keyUpdateInterval
   295  									FirstKeyUpdateInterval = firstKeyUpdateInterval
   296  									server.SetHandshakeConfirmed()
   297  								})
   298  
   299  								AfterEach(func() {
   300  									KeyUpdateInterval = origKeyUpdateInterval
   301  									FirstKeyUpdateInterval = origFirstKeyUpdateInterval
   302  								})
   303  
   304  								It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
   305  									for i := 0; i < firstKeyUpdateInterval; i++ {
   306  										pn := protocol.PacketNumber(i)
   307  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   308  										server.Seal(nil, msg, pn, ad)
   309  									}
   310  									// the first update is allowed without receiving an acknowledgement
   311  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   312  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   313  								})
   314  
   315  								It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() {
   316  									server.rollKeys()
   317  									client.rollKeys()
   318  									for i := 0; i < keyUpdateInterval; i++ {
   319  										pn := protocol.PacketNumber(i)
   320  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   321  										server.Seal(nil, msg, pn, ad)
   322  									}
   323  									// no update allowed before receiving an acknowledgement for the current key phase
   324  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   325  									// receive an ACK for a packet sent in key phase 0
   326  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   327  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
   328  									Expect(err).ToNot(HaveOccurred())
   329  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   330  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   331  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
   332  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   333  								})
   334  
   335  								It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
   336  									// First make sure that we update our keys.
   337  									for i := 0; i < firstKeyUpdateInterval; i++ {
   338  										pn := protocol.PacketNumber(i)
   339  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   340  										server.Seal(nil, msg, pn, ad)
   341  									}
   342  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   343  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   344  									// Now that our keys are updated, send a packet using the new keys.
   345  									const nextPN = firstKeyUpdateInterval + 1
   346  									server.Seal(nil, msg, nextPN, ad)
   347  									// We haven't decrypted any packet in the new key phase yet.
   348  									// This means that the ACK must have been sent in the old key phase.
   349  									Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{
   350  										ErrorCode:    qerr.KeyUpdateError,
   351  										ErrorMessage: "received ACK for key phase 1, but peer didn't update keys",
   352  									}))
   353  								})
   354  
   355  								It("doesn't error before actually sending a packet in the new key phase", func() {
   356  									// First make sure that we update our keys.
   357  									for i := 0; i < firstKeyUpdateInterval; i++ {
   358  										pn := protocol.PacketNumber(i)
   359  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   360  										server.Seal(nil, msg, pn, ad)
   361  									}
   362  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   363  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
   364  									Expect(err).ToNot(HaveOccurred())
   365  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   366  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   367  									// Now that our keys are updated, send a packet using the new keys.
   368  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   369  									// We haven't decrypted any packet in the new key phase yet.
   370  									// This means that the ACK must have been sent in the old key phase.
   371  									Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
   372  								})
   373  
   374  								It("initiates a key update after opening the maximum number of packets, for the first update", func() {
   375  									for i := 0; i < firstKeyUpdateInterval; i++ {
   376  										pn := protocol.PacketNumber(i)
   377  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   378  										encrypted := client.Seal(nil, msg, pn, ad)
   379  										_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
   380  										Expect(err).ToNot(HaveOccurred())
   381  									}
   382  									// the first update is allowed without receiving an acknowledgement
   383  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   384  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   385  								})
   386  
   387  								It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() {
   388  									server.rollKeys()
   389  									client.rollKeys()
   390  									for i := 0; i < keyUpdateInterval; i++ {
   391  										pn := protocol.PacketNumber(i)
   392  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   393  										encrypted := client.Seal(nil, msg, pn, ad)
   394  										_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad)
   395  										Expect(err).ToNot(HaveOccurred())
   396  									}
   397  									// no update allowed before receiving an acknowledgement for the current key phase
   398  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   399  									server.Seal(nil, msg, 1, ad)
   400  									Expect(server.SetLargestAcked(1)).To(Succeed())
   401  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   402  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
   403  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   404  								})
   405  
   406  								It("drops keys 3 PTOs after a key update", func() {
   407  									now := time.Now()
   408  									for i := 0; i < firstKeyUpdateInterval; i++ {
   409  										pn := protocol.PacketNumber(i)
   410  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   411  										server.Seal(nil, msg, pn, ad)
   412  									}
   413  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   414  									_, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad"))
   415  									Expect(err).ToNot(HaveOccurred())
   416  									Expect(server.SetLargestAcked(0)).To(Succeed())
   417  									// Now we've initiated the first key update.
   418  									// Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there
   419  									threePTO := 3 * rttStats.PTO(false)
   420  									dataKeyPhaseZero := client.Seal(nil, msg, 1, ad)
   421  									_, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad)
   422  									Expect(err).ToNot(HaveOccurred())
   423  									// Now receive a packet with key phase 1.
   424  									// This should start the timer to drop the keys after 3 PTOs.
   425  									client.rollKeys()
   426  									dataKeyPhaseOne := client.Seal(nil, msg, 10, ad)
   427  									t := now.Add(threePTO).Add(time.Second)
   428  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   429  									_, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad)
   430  									Expect(err).ToNot(HaveOccurred())
   431  									// Make sure the keys are still here.
   432  									_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad)
   433  									Expect(err).ToNot(HaveOccurred())
   434  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   435  									_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad)
   436  									Expect(err).To(MatchError(ErrKeysDropped))
   437  								})
   438  
   439  								It("doesn't drop the first key generation too early", func() {
   440  									now := time.Now()
   441  									data1 := client.Seal(nil, msg, 1, ad)
   442  									_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
   443  									Expect(err).ToNot(HaveOccurred())
   444  									for i := 0; i < firstKeyUpdateInterval; i++ {
   445  										pn := protocol.PacketNumber(i)
   446  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   447  										server.Seal(nil, msg, pn, ad)
   448  										Expect(server.SetLargestAcked(pn)).To(Succeed())
   449  									}
   450  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   451  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   452  									// The server never received a packet at key phase 1.
   453  									// Make sure the key phase 0 is still there at a much later point.
   454  									data2 := client.Seal(nil, msg, 1, ad)
   455  									_, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad)
   456  									Expect(err).ToNot(HaveOccurred())
   457  								})
   458  
   459  								It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
   460  									for i := 0; i < firstKeyUpdateInterval; i++ {
   461  										pn := protocol.PacketNumber(i)
   462  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   463  										server.Seal(nil, msg, pn, ad)
   464  									}
   465  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   466  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
   467  									Expect(err).ToNot(HaveOccurred())
   468  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   469  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   470  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   471  									const nextPN = keyUpdateInterval + 1
   472  									// Send and receive an acknowledgement for a packet in key phase 1.
   473  									// We are now running a timer to drop the keys with 3 PTO.
   474  									server.Seal(nil, msg, nextPN, ad)
   475  									client.rollKeys()
   476  									dataKeyPhaseOne := client.Seal(nil, msg, 2, ad)
   477  									now := time.Now()
   478  									_, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad)
   479  									Expect(err).ToNot(HaveOccurred())
   480  									Expect(server.SetLargestAcked(nextPN))
   481  									// Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over.
   482  									// This mean that we need to drop the keys for key phase 0 immediately.
   483  									client.rollKeys()
   484  									dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad)
   485  									gomock.InOrder(
   486  										serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
   487  										serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true),
   488  									)
   489  									_, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad)
   490  									Expect(err).ToNot(HaveOccurred())
   491  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   492  								})
   493  
   494  								It("drops keys early when we initiate another key update within the 3 PTO period", func() {
   495  									server.SetHandshakeConfirmed()
   496  									// send so many packets that we initiate the first key update
   497  									for i := 0; i < firstKeyUpdateInterval; i++ {
   498  										pn := protocol.PacketNumber(i)
   499  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   500  										server.Seal(nil, msg, pn, ad)
   501  									}
   502  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   503  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
   504  									Expect(err).ToNot(HaveOccurred())
   505  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   506  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   507  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   508  									// send so many packets that we initiate the next key update
   509  									for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ {
   510  										pn := protocol.PacketNumber(i)
   511  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   512  										server.Seal(nil, msg, pn, ad)
   513  									}
   514  									client.rollKeys()
   515  									b = client.Seal(nil, []byte("foobar"), 2, []byte("ad"))
   516  									now := time.Now()
   517  									_, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad"))
   518  									Expect(err).ToNot(HaveOccurred())
   519  									ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed())
   520  									gomock.InOrder(
   521  										serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
   522  										serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false),
   523  									)
   524  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   525  									// We haven't received an ACK for a packet sent in key phase 2 yet.
   526  									// Make sure we canceled the timer to drop the previous key phase.
   527  									b = client.Seal(nil, []byte("foobar"), 3, []byte("ad"))
   528  									_, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad"))
   529  									Expect(err).ToNot(HaveOccurred())
   530  								})
   531  							})
   532  						})
   533  					})
   534  				})
   535  			}
   536  		})
   537  	}
   538  })
   539  
   540  func getClientAndServer() (client, server *updatableAEAD) {
   541  	trafficSecret1 := make([]byte, 16)
   542  	trafficSecret2 := make([]byte, 16)
   543  	rand.Read(trafficSecret1)
   544  	rand.Read(trafficSecret2)
   545  
   546  	cs := cipherSuites[0]
   547  	rttStats := utils.NewRTTStats()
   548  	client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
   549  	server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
   550  	client.SetReadKey(cs, trafficSecret2)
   551  	client.SetWriteKey(cs, trafficSecret1)
   552  	server.SetReadKey(cs, trafficSecret1)
   553  	server.SetWriteKey(cs, trafficSecret2)
   554  	return
   555  }
   556  
   557  func BenchmarkPacketEncryption(b *testing.B) {
   558  	client, _ := getClientAndServer()
   559  	const l = 1200
   560  	src := make([]byte, l)
   561  	rand.Read(src)
   562  	ad := make([]byte, 32)
   563  	rand.Read(ad)
   564  
   565  	for i := 0; i < b.N; i++ {
   566  		src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad)
   567  	}
   568  }
   569  
   570  func BenchmarkPacketDecryption(b *testing.B) {
   571  	client, server := getClientAndServer()
   572  	const l = 1200
   573  	src := make([]byte, l)
   574  	dst := make([]byte, l)
   575  	rand.Read(src)
   576  	ad := make([]byte, 32)
   577  	rand.Read(ad)
   578  	src = client.Seal(src[:0], src[:l], 1337, ad)
   579  
   580  	for i := 0; i < b.N; i++ {
   581  		if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil {
   582  			b.Fatalf("opening failed: %v", err)
   583  		}
   584  	}
   585  }
   586  
   587  func BenchmarkRollKeys(b *testing.B) {
   588  	client, _ := getClientAndServer()
   589  	for i := 0; i < b.N; i++ {
   590  		client.rollKeys()
   591  	}
   592  	if int(client.keyPhase) != b.N {
   593  		b.Fatal("didn't roll keys often enough")
   594  	}
   595  }