github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/internal/handshake/updatable_aead_test.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	mocklogging "github.com/mikelsr/quic-go/internal/mocks/logging"
    11  	"github.com/mikelsr/quic-go/internal/protocol"
    12  	"github.com/mikelsr/quic-go/internal/qerr"
    13  	"github.com/mikelsr/quic-go/internal/utils"
    14  
    15  	"github.com/golang/mock/gomock"
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  )
    19  
    20  var _ = Describe("Updatable AEAD", func() {
    21  	DescribeTable("ChaCha test vector",
    22  		func(v protocol.VersionNumber, expectedPayload, expectedPacket []byte) {
    23  			secret := splitHexString("9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b")
    24  			aead := newUpdatableAEAD(&utils.RTTStats{}, nil, nil, v)
    25  			chacha := cipherSuites[2]
    26  			Expect(chacha.ID).To(Equal(tls.TLS_CHACHA20_POLY1305_SHA256))
    27  			aead.SetWriteKey(chacha, secret)
    28  			const pnOffset = 1
    29  			header := splitHexString("4200bff4")
    30  			payloadOffset := len(header)
    31  			plaintext := splitHexString("01")
    32  			payload := aead.Seal(nil, plaintext, 654360564, header)
    33  			Expect(payload).To(Equal(expectedPayload))
    34  			packet := append(header, payload...)
    35  			aead.EncryptHeader(packet[pnOffset+4:pnOffset+4+16], &packet[0], packet[pnOffset:payloadOffset])
    36  			Expect(packet).To(Equal(expectedPacket))
    37  		},
    38  		Entry("QUIC v1",
    39  			protocol.Version1,
    40  			splitHexString("655e5cd55c41f69080575d7999c25a5bfb"),
    41  			splitHexString("4cfe4189655e5cd55c41f69080575d7999c25a5bfb"),
    42  		),
    43  		Entry("QUIC v2",
    44  			protocol.Version2,
    45  			splitHexString("0ae7b6b932bc27d786f4bc2bb20f2162ba"),
    46  			splitHexString("5558b1c60ae7b6b932bc27d786f4bc2bb20f2162ba"),
    47  		),
    48  	)
    49  
    50  	for _, ver := range []protocol.VersionNumber{protocol.Version1, protocol.Version2} {
    51  		v := ver
    52  
    53  		Context(fmt.Sprintf("using version %s", v), func() {
    54  			for i := range cipherSuites {
    55  				cs := cipherSuites[i]
    56  
    57  				Context(fmt.Sprintf("using %s", tls.CipherSuiteName(cs.ID)), func() {
    58  					var (
    59  						client, server *updatableAEAD
    60  						serverTracer   *mocklogging.MockConnectionTracer
    61  						rttStats       *utils.RTTStats
    62  					)
    63  
    64  					BeforeEach(func() {
    65  						serverTracer = mocklogging.NewMockConnectionTracer(mockCtrl)
    66  						trafficSecret1 := make([]byte, 16)
    67  						trafficSecret2 := make([]byte, 16)
    68  						rand.Read(trafficSecret1)
    69  						rand.Read(trafficSecret2)
    70  
    71  						rttStats = utils.NewRTTStats()
    72  						client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, v)
    73  						server = newUpdatableAEAD(rttStats, serverTracer, utils.DefaultLogger, v)
    74  						client.SetReadKey(cs, trafficSecret2)
    75  						client.SetWriteKey(cs, trafficSecret1)
    76  						server.SetReadKey(cs, trafficSecret1)
    77  						server.SetWriteKey(cs, trafficSecret2)
    78  					})
    79  
    80  					Context("header protection", func() {
    81  						It("encrypts and decrypts the header", func() {
    82  							var lastFiveBitsDifferent int
    83  							for i := 0; i < 100; i++ {
    84  								sample := make([]byte, 16)
    85  								rand.Read(sample)
    86  								header := []byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}
    87  								client.EncryptHeader(sample, &header[0], header[9:13])
    88  								if header[0]&0x1f != 0xb5&0x1f {
    89  									lastFiveBitsDifferent++
    90  								}
    91  								Expect(header[0] & 0xe0).To(Equal(byte(0xb5 & 0xe0)))
    92  								Expect(header[1:9]).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7, 8}))
    93  								Expect(header[9:13]).ToNot(Equal([]byte{0xde, 0xad, 0xbe, 0xef}))
    94  								server.DecryptHeader(sample, &header[0], header[9:13])
    95  								Expect(header).To(Equal([]byte{0xb5, 1, 2, 3, 4, 5, 6, 7, 8, 0xde, 0xad, 0xbe, 0xef}))
    96  							}
    97  							Expect(lastFiveBitsDifferent).To(BeNumerically(">", 75))
    98  						})
    99  					})
   100  
   101  					Context("message encryption", func() {
   102  						var msg, ad []byte
   103  
   104  						BeforeEach(func() {
   105  							msg = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
   106  							ad = []byte("Donec in velit neque.")
   107  						})
   108  
   109  						It("encrypts and decrypts a message", func() {
   110  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   111  							opened, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
   112  							Expect(err).ToNot(HaveOccurred())
   113  							Expect(opened).To(Equal(msg))
   114  						})
   115  
   116  						It("saves the first packet number", func() {
   117  							client.Seal(nil, msg, 0x1337, ad)
   118  							Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337)))
   119  							client.Seal(nil, msg, 0x1338, ad)
   120  							Expect(client.FirstPacketNumber()).To(Equal(protocol.PacketNumber(0x1337)))
   121  						})
   122  
   123  						It("fails to open a message if the associated data is not the same", func() {
   124  							encrypted := client.Seal(nil, msg, 0x1337, ad)
   125  							_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, []byte("wrong ad"))
   126  							Expect(err).To(MatchError(ErrDecryptionFailed))
   127  						})
   128  
   129  						It("fails to open a message if the packet number is not the same", func() {
   130  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   131  							_, err := client.Open(nil, encrypted, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
   132  							Expect(err).To(MatchError(ErrDecryptionFailed))
   133  						})
   134  
   135  						It("decodes the packet number", func() {
   136  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   137  							_, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
   138  							Expect(err).ToNot(HaveOccurred())
   139  							Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
   140  						})
   141  
   142  						It("ignores packets it can't decrypt for packet number derivation", func() {
   143  							encrypted := server.Seal(nil, msg, 0x1337, ad)
   144  							_, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
   145  							Expect(err).To(HaveOccurred())
   146  							Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
   147  						})
   148  
   149  						It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() {
   150  							client.invalidPacketLimit = 10
   151  							for i := 0; i < 9; i++ {
   152  								_, err := client.Open(nil, []byte("foobar"), time.Now(), protocol.PacketNumber(i), protocol.KeyPhaseZero, []byte("ad"))
   153  								Expect(err).To(MatchError(ErrDecryptionFailed))
   154  							}
   155  							_, err := client.Open(nil, []byte("foobar"), time.Now(), 10, protocol.KeyPhaseZero, []byte("ad"))
   156  							Expect(err).To(HaveOccurred())
   157  							Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{}))
   158  							Expect(err.(*qerr.TransportError).ErrorCode).To(Equal(qerr.AEADLimitReached))
   159  						})
   160  
   161  						Context("key updates", func() {
   162  							Context("receiving key updates", func() {
   163  								It("updates keys", func() {
   164  									now := time.Now()
   165  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   166  									encrypted0 := server.Seal(nil, msg, 0x1337, ad)
   167  									server.rollKeys()
   168  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   169  									encrypted1 := server.Seal(nil, msg, 0x1337, ad)
   170  									Expect(encrypted0).ToNot(Equal(encrypted1))
   171  									// expect opening to fail. The client didn't roll keys yet
   172  									_, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseZero, ad)
   173  									Expect(err).To(MatchError(ErrDecryptionFailed))
   174  									client.rollKeys()
   175  									decrypted, err := client.Open(nil, encrypted1, now, 0x1337, protocol.KeyPhaseOne, ad)
   176  									Expect(err).ToNot(HaveOccurred())
   177  									Expect(decrypted).To(Equal(msg))
   178  								})
   179  
   180  								It("updates the keys when receiving a packet with the next key phase", func() {
   181  									now := time.Now()
   182  									// receive the first packet at key phase zero
   183  									encrypted0 := client.Seal(nil, msg, 0x42, ad)
   184  									decrypted, err := server.Open(nil, encrypted0, now, 0x42, protocol.KeyPhaseZero, ad)
   185  									Expect(err).ToNot(HaveOccurred())
   186  									Expect(decrypted).To(Equal(msg))
   187  									// send one packet at key phase zero
   188  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   189  									_ = server.Seal(nil, msg, 0x1, ad)
   190  									// now received a message at key phase one
   191  									client.rollKeys()
   192  									encrypted1 := client.Seal(nil, msg, 0x43, ad)
   193  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   194  									decrypted, err = server.Open(nil, encrypted1, now, 0x43, protocol.KeyPhaseOne, ad)
   195  									Expect(err).ToNot(HaveOccurred())
   196  									Expect(decrypted).To(Equal(msg))
   197  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   198  								})
   199  
   200  								It("opens a reordered packet with the old keys after an update", func() {
   201  									now := time.Now()
   202  									encrypted01 := client.Seal(nil, msg, 0x42, ad)
   203  									encrypted02 := client.Seal(nil, msg, 0x43, ad)
   204  									// receive the first packet with key phase 0
   205  									_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
   206  									Expect(err).ToNot(HaveOccurred())
   207  									// send one packet at key phase zero
   208  									_ = server.Seal(nil, msg, 0x1, ad)
   209  									// now receive a packet with key phase 1
   210  									client.rollKeys()
   211  									encrypted1 := client.Seal(nil, msg, 0x44, ad)
   212  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   213  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   214  									_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
   215  									Expect(err).ToNot(HaveOccurred())
   216  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   217  									// now receive a reordered packet with key phase 0
   218  									decrypted, err := server.Open(nil, encrypted02, now, 0x43, protocol.KeyPhaseZero, ad)
   219  									Expect(err).ToNot(HaveOccurred())
   220  									Expect(decrypted).To(Equal(msg))
   221  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   222  								})
   223  
   224  								It("drops keys 3 PTOs after a key update", func() {
   225  									now := time.Now()
   226  									rttStats.UpdateRTT(10*time.Millisecond, 0, now)
   227  									pto := rttStats.PTO(true)
   228  									encrypted01 := client.Seal(nil, msg, 0x42, ad)
   229  									encrypted02 := client.Seal(nil, msg, 0x43, ad)
   230  									// receive the first packet with key phase 0
   231  									_, err := server.Open(nil, encrypted01, now, 0x42, protocol.KeyPhaseZero, ad)
   232  									Expect(err).ToNot(HaveOccurred())
   233  									// send one packet at key phase zero
   234  									_ = server.Seal(nil, msg, 0x1, ad)
   235  									// now receive a packet with key phase 1
   236  									client.rollKeys()
   237  									encrypted1 := client.Seal(nil, msg, 0x44, ad)
   238  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   239  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   240  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   241  									_, err = server.Open(nil, encrypted1, now, 0x44, protocol.KeyPhaseOne, ad)
   242  									Expect(err).ToNot(HaveOccurred())
   243  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   244  									// now receive a reordered packet with key phase 0
   245  									_, err = server.Open(nil, encrypted02, now.Add(3*pto).Add(time.Nanosecond), 0x43, protocol.KeyPhaseZero, ad)
   246  									Expect(err).To(MatchError(ErrKeysDropped))
   247  								})
   248  
   249  								It("allows the first key update immediately", func() {
   250  									// receive a packet at key phase one, before having sent or received any packets at key phase 0
   251  									client.rollKeys()
   252  									encrypted1 := client.Seal(nil, msg, 0x1337, ad)
   253  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   254  									_, err := server.Open(nil, encrypted1, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
   255  									Expect(err).ToNot(HaveOccurred())
   256  								})
   257  
   258  								It("only errors when the peer starts with key phase 1 if decrypting the packet succeeds", func() {
   259  									client.rollKeys()
   260  									encrypted := client.Seal(nil, msg, 0x1337, ad)
   261  									encrypted = encrypted[:len(encrypted)-1]
   262  									_, err := server.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseOne, ad)
   263  									Expect(err).To(MatchError(ErrDecryptionFailed))
   264  								})
   265  
   266  								It("errors when the peer updates keys too frequently", func() {
   267  									server.rollKeys()
   268  									client.rollKeys()
   269  									// receive the first packet at key phase one
   270  									encrypted0 := client.Seal(nil, msg, 0x42, ad)
   271  									_, err := server.Open(nil, encrypted0, time.Now(), 0x42, protocol.KeyPhaseOne, ad)
   272  									Expect(err).ToNot(HaveOccurred())
   273  									// now receive a packet at key phase two, before having sent any packets
   274  									client.rollKeys()
   275  									encrypted1 := client.Seal(nil, msg, 0x42, ad)
   276  									_, err = server.Open(nil, encrypted1, time.Now(), 0x42, protocol.KeyPhaseZero, ad)
   277  									Expect(err).To(MatchError(&qerr.TransportError{
   278  										ErrorCode:    qerr.KeyUpdateError,
   279  										ErrorMessage: "keys updated too quickly",
   280  									}))
   281  								})
   282  							})
   283  
   284  							Context("initiating key updates", func() {
   285  								const firstKeyUpdateInterval = 5
   286  								const keyUpdateInterval = 20
   287  								var origKeyUpdateInterval, origFirstKeyUpdateInterval uint64
   288  
   289  								BeforeEach(func() {
   290  									origKeyUpdateInterval = KeyUpdateInterval
   291  									origFirstKeyUpdateInterval = FirstKeyUpdateInterval
   292  									KeyUpdateInterval = keyUpdateInterval
   293  									FirstKeyUpdateInterval = firstKeyUpdateInterval
   294  									server.SetHandshakeConfirmed()
   295  								})
   296  
   297  								AfterEach(func() {
   298  									KeyUpdateInterval = origKeyUpdateInterval
   299  									FirstKeyUpdateInterval = origFirstKeyUpdateInterval
   300  								})
   301  
   302  								It("initiates a key update after sealing the maximum number of packets, for the first update", func() {
   303  									for i := 0; i < firstKeyUpdateInterval; i++ {
   304  										pn := protocol.PacketNumber(i)
   305  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   306  										server.Seal(nil, msg, pn, ad)
   307  									}
   308  									// the first update is allowed without receiving an acknowledgement
   309  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   310  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   311  								})
   312  
   313  								It("initiates a key update after sealing the maximum number of packets, for subsequent updates", func() {
   314  									server.rollKeys()
   315  									client.rollKeys()
   316  									for i := 0; i < keyUpdateInterval; i++ {
   317  										pn := protocol.PacketNumber(i)
   318  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   319  										server.Seal(nil, msg, pn, ad)
   320  									}
   321  									// no update allowed before receiving an acknowledgement for the current key phase
   322  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   323  									// receive an ACK for a packet sent in key phase 0
   324  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   325  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseOne, []byte("ad"))
   326  									Expect(err).ToNot(HaveOccurred())
   327  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   328  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   329  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
   330  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   331  								})
   332  
   333  								It("errors if the peer acknowledges a packet sent in the next key phase using the old key phase", func() {
   334  									// First make sure that we update our keys.
   335  									for i := 0; i < firstKeyUpdateInterval; i++ {
   336  										pn := protocol.PacketNumber(i)
   337  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   338  										server.Seal(nil, msg, pn, ad)
   339  									}
   340  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   341  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   342  									// Now that our keys are updated, send a packet using the new keys.
   343  									const nextPN = firstKeyUpdateInterval + 1
   344  									server.Seal(nil, msg, nextPN, ad)
   345  									// We haven't decrypted any packet in the new key phase yet.
   346  									// This means that the ACK must have been sent in the old key phase.
   347  									Expect(server.SetLargestAcked(nextPN)).To(MatchError(&qerr.TransportError{
   348  										ErrorCode:    qerr.KeyUpdateError,
   349  										ErrorMessage: "received ACK for key phase 1, but peer didn't update keys",
   350  									}))
   351  								})
   352  
   353  								It("doesn't error before actually sending a packet in the new key phase", func() {
   354  									// First make sure that we update our keys.
   355  									for i := 0; i < firstKeyUpdateInterval; i++ {
   356  										pn := protocol.PacketNumber(i)
   357  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   358  										server.Seal(nil, msg, pn, ad)
   359  									}
   360  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   361  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
   362  									Expect(err).ToNot(HaveOccurred())
   363  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   364  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   365  									// Now that our keys are updated, send a packet using the new keys.
   366  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   367  									// We haven't decrypted any packet in the new key phase yet.
   368  									// This means that the ACK must have been sent in the old key phase.
   369  									Expect(server.SetLargestAcked(1)).ToNot(HaveOccurred())
   370  								})
   371  
   372  								It("initiates a key update after opening the maximum number of packets, for the first update", func() {
   373  									for i := 0; i < firstKeyUpdateInterval; i++ {
   374  										pn := protocol.PacketNumber(i)
   375  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   376  										encrypted := client.Seal(nil, msg, pn, ad)
   377  										_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseZero, ad)
   378  										Expect(err).ToNot(HaveOccurred())
   379  									}
   380  									// the first update is allowed without receiving an acknowledgement
   381  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   382  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   383  								})
   384  
   385  								It("initiates a key update after opening the maximum number of packets, for subsequent updates", func() {
   386  									server.rollKeys()
   387  									client.rollKeys()
   388  									for i := 0; i < keyUpdateInterval; i++ {
   389  										pn := protocol.PacketNumber(i)
   390  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   391  										encrypted := client.Seal(nil, msg, pn, ad)
   392  										_, err := server.Open(nil, encrypted, time.Now(), pn, protocol.KeyPhaseOne, ad)
   393  										Expect(err).ToNot(HaveOccurred())
   394  									}
   395  									// no update allowed before receiving an acknowledgement for the current key phase
   396  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   397  									server.Seal(nil, msg, 1, ad)
   398  									Expect(server.SetLargestAcked(1)).To(Succeed())
   399  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   400  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false)
   401  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   402  								})
   403  
   404  								It("drops keys 3 PTOs after a key update", func() {
   405  									now := time.Now()
   406  									for i := 0; i < firstKeyUpdateInterval; i++ {
   407  										pn := protocol.PacketNumber(i)
   408  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   409  										server.Seal(nil, msg, pn, ad)
   410  									}
   411  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   412  									_, err := server.Open(nil, b, now, 1, protocol.KeyPhaseZero, []byte("ad"))
   413  									Expect(err).ToNot(HaveOccurred())
   414  									Expect(server.SetLargestAcked(0)).To(Succeed())
   415  									// Now we've initiated the first key update.
   416  									// Decrypt a message sent from the client more than 3 PTO later to make sure the key is still there
   417  									threePTO := 3 * rttStats.PTO(false)
   418  									dataKeyPhaseZero := client.Seal(nil, msg, 1, ad)
   419  									_, err = server.Open(nil, dataKeyPhaseZero, now.Add(threePTO).Add(time.Second), 1, protocol.KeyPhaseZero, ad)
   420  									Expect(err).ToNot(HaveOccurred())
   421  									// Now receive a packet with key phase 1.
   422  									// This should start the timer to drop the keys after 3 PTOs.
   423  									client.rollKeys()
   424  									dataKeyPhaseOne := client.Seal(nil, msg, 10, ad)
   425  									t := now.Add(threePTO).Add(time.Second)
   426  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), true)
   427  									_, err = server.Open(nil, dataKeyPhaseOne, t, 10, protocol.KeyPhaseOne, ad)
   428  									Expect(err).ToNot(HaveOccurred())
   429  									// Make sure the keys are still here.
   430  									_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO*9/10), 1, protocol.KeyPhaseZero, ad)
   431  									Expect(err).ToNot(HaveOccurred())
   432  									serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0))
   433  									_, err = server.Open(nil, dataKeyPhaseZero, t.Add(threePTO).Add(time.Nanosecond), 1, protocol.KeyPhaseZero, ad)
   434  									Expect(err).To(MatchError(ErrKeysDropped))
   435  								})
   436  
   437  								It("doesn't drop the first key generation too early", func() {
   438  									now := time.Now()
   439  									data1 := client.Seal(nil, msg, 1, ad)
   440  									_, err := server.Open(nil, data1, now, 1, protocol.KeyPhaseZero, ad)
   441  									Expect(err).ToNot(HaveOccurred())
   442  									for i := 0; i < firstKeyUpdateInterval; i++ {
   443  										pn := protocol.PacketNumber(i)
   444  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   445  										server.Seal(nil, msg, pn, ad)
   446  										Expect(server.SetLargestAcked(pn)).To(Succeed())
   447  									}
   448  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   449  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   450  									// The server never received a packet at key phase 1.
   451  									// Make sure the key phase 0 is still there at a much later point.
   452  									data2 := client.Seal(nil, msg, 1, ad)
   453  									_, err = server.Open(nil, data2, now.Add(10*rttStats.PTO(true)), 1, protocol.KeyPhaseZero, ad)
   454  									Expect(err).ToNot(HaveOccurred())
   455  								})
   456  
   457  								It("drops keys early when the peer forces initiates a key update within the 3 PTO period", func() {
   458  									for i := 0; i < firstKeyUpdateInterval; i++ {
   459  										pn := protocol.PacketNumber(i)
   460  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   461  										server.Seal(nil, msg, pn, ad)
   462  									}
   463  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   464  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
   465  									Expect(err).ToNot(HaveOccurred())
   466  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   467  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   468  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   469  									const nextPN = keyUpdateInterval + 1
   470  									// Send and receive an acknowledgement for a packet in key phase 1.
   471  									// We are now running a timer to drop the keys with 3 PTO.
   472  									server.Seal(nil, msg, nextPN, ad)
   473  									client.rollKeys()
   474  									dataKeyPhaseOne := client.Seal(nil, msg, 2, ad)
   475  									now := time.Now()
   476  									_, err = server.Open(nil, dataKeyPhaseOne, now, 2, protocol.KeyPhaseOne, ad)
   477  									Expect(err).ToNot(HaveOccurred())
   478  									Expect(server.SetLargestAcked(nextPN))
   479  									// Now the client sends us a packet in key phase 2, forcing us to update keys before the 3 PTO period is over.
   480  									// This mean that we need to drop the keys for key phase 0 immediately.
   481  									client.rollKeys()
   482  									dataKeyPhaseTwo := client.Seal(nil, msg, 3, ad)
   483  									gomock.InOrder(
   484  										serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
   485  										serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), true),
   486  									)
   487  									_, err = server.Open(nil, dataKeyPhaseTwo, now, 3, protocol.KeyPhaseZero, ad)
   488  									Expect(err).ToNot(HaveOccurred())
   489  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   490  								})
   491  
   492  								It("drops keys early when we initiate another key update within the 3 PTO period", func() {
   493  									server.SetHandshakeConfirmed()
   494  									// send so many packets that we initiate the first key update
   495  									for i := 0; i < firstKeyUpdateInterval; i++ {
   496  										pn := protocol.PacketNumber(i)
   497  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   498  										server.Seal(nil, msg, pn, ad)
   499  									}
   500  									b := client.Seal(nil, []byte("foobar"), 1, []byte("ad"))
   501  									_, err := server.Open(nil, b, time.Now(), 1, protocol.KeyPhaseZero, []byte("ad"))
   502  									Expect(err).ToNot(HaveOccurred())
   503  									ExpectWithOffset(1, server.SetLargestAcked(0)).To(Succeed())
   504  									serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(1), false)
   505  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   506  									// send so many packets that we initiate the next key update
   507  									for i := keyUpdateInterval; i < 2*keyUpdateInterval; i++ {
   508  										pn := protocol.PacketNumber(i)
   509  										Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseOne))
   510  										server.Seal(nil, msg, pn, ad)
   511  									}
   512  									client.rollKeys()
   513  									b = client.Seal(nil, []byte("foobar"), 2, []byte("ad"))
   514  									now := time.Now()
   515  									_, err = server.Open(nil, b, now, 2, protocol.KeyPhaseOne, []byte("ad"))
   516  									Expect(err).ToNot(HaveOccurred())
   517  									ExpectWithOffset(1, server.SetLargestAcked(keyUpdateInterval)).To(Succeed())
   518  									gomock.InOrder(
   519  										serverTracer.EXPECT().DroppedKey(protocol.KeyPhase(0)),
   520  										serverTracer.EXPECT().UpdatedKey(protocol.KeyPhase(2), false),
   521  									)
   522  									Expect(server.KeyPhase()).To(Equal(protocol.KeyPhaseZero))
   523  									// We haven't received an ACK for a packet sent in key phase 2 yet.
   524  									// Make sure we canceled the timer to drop the previous key phase.
   525  									b = client.Seal(nil, []byte("foobar"), 3, []byte("ad"))
   526  									_, err = server.Open(nil, b, now.Add(10*rttStats.PTO(true)), 3, protocol.KeyPhaseOne, []byte("ad"))
   527  									Expect(err).ToNot(HaveOccurred())
   528  								})
   529  							})
   530  						})
   531  					})
   532  				})
   533  			}
   534  		})
   535  	}
   536  })
   537  
   538  func getClientAndServer() (client, server *updatableAEAD) {
   539  	trafficSecret1 := make([]byte, 16)
   540  	trafficSecret2 := make([]byte, 16)
   541  	rand.Read(trafficSecret1)
   542  	rand.Read(trafficSecret2)
   543  
   544  	cs := cipherSuites[0]
   545  	rttStats := utils.NewRTTStats()
   546  	client = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
   547  	server = newUpdatableAEAD(rttStats, nil, utils.DefaultLogger, protocol.Version1)
   548  	client.SetReadKey(cs, trafficSecret2)
   549  	client.SetWriteKey(cs, trafficSecret1)
   550  	server.SetReadKey(cs, trafficSecret1)
   551  	server.SetWriteKey(cs, trafficSecret2)
   552  	return
   553  }
   554  
   555  func BenchmarkPacketEncryption(b *testing.B) {
   556  	client, _ := getClientAndServer()
   557  	const l = 1200
   558  	src := make([]byte, l)
   559  	rand.Read(src)
   560  	ad := make([]byte, 32)
   561  	rand.Read(ad)
   562  
   563  	for i := 0; i < b.N; i++ {
   564  		src = client.Seal(src[:0], src[:l], protocol.PacketNumber(i), ad)
   565  	}
   566  }
   567  
   568  func BenchmarkPacketDecryption(b *testing.B) {
   569  	client, server := getClientAndServer()
   570  	const l = 1200
   571  	src := make([]byte, l)
   572  	dst := make([]byte, l)
   573  	rand.Read(src)
   574  	ad := make([]byte, 32)
   575  	rand.Read(ad)
   576  	src = client.Seal(src[:0], src[:l], 1337, ad)
   577  
   578  	for i := 0; i < b.N; i++ {
   579  		if _, err := server.Open(dst[:0], src, time.Time{}, 1337, protocol.KeyPhaseZero, ad); err != nil {
   580  			b.Fatalf("opening failed: %v", err)
   581  		}
   582  	}
   583  }
   584  
   585  func BenchmarkRollKeys(b *testing.B) {
   586  	client, _ := getClientAndServer()
   587  	for i := 0; i < b.N; i++ {
   588  		client.rollKeys()
   589  	}
   590  	if int(client.keyPhase) != b.N {
   591  		b.Fatal("didn't roll keys often enough")
   592  	}
   593  }