github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/internal/handshake/updatable_aead.go (about) 1 package handshake 2 3 import ( 4 "crypto" 5 "crypto/cipher" 6 "crypto/tls" 7 "encoding/binary" 8 "fmt" 9 "time" 10 11 "github.com/apernet/quic-go/internal/protocol" 12 "github.com/apernet/quic-go/internal/qerr" 13 "github.com/apernet/quic-go/internal/utils" 14 "github.com/apernet/quic-go/logging" 15 ) 16 17 // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. 18 // It's a package-level variable to allow modifying it for testing purposes. 19 var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval 20 21 // FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. 22 // It's a package-level variable to allow modifying it for testing purposes. 23 var FirstKeyUpdateInterval uint64 = 100 24 25 type updatableAEAD struct { 26 suite *cipherSuite 27 28 keyPhase protocol.KeyPhase 29 largestAcked protocol.PacketNumber 30 firstPacketNumber protocol.PacketNumber 31 handshakeConfirmed bool 32 33 invalidPacketLimit uint64 34 invalidPacketCount uint64 35 36 // Time when the keys should be dropped. Keys are dropped on the next call to Open(). 37 prevRcvAEADExpiry time.Time 38 prevRcvAEAD cipher.AEAD 39 40 firstRcvdWithCurrentKey protocol.PacketNumber 41 firstSentWithCurrentKey protocol.PacketNumber 42 highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) 43 numRcvdWithCurrentKey uint64 44 numSentWithCurrentKey uint64 45 rcvAEAD cipher.AEAD 46 sendAEAD cipher.AEAD 47 // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). 48 aeadOverhead int 49 50 nextRcvAEAD cipher.AEAD 51 nextSendAEAD cipher.AEAD 52 nextRcvTrafficSecret []byte 53 nextSendTrafficSecret []byte 54 55 headerDecrypter headerProtector 56 headerEncrypter headerProtector 57 58 rttStats *utils.RTTStats 59 60 tracer *logging.ConnectionTracer 61 logger utils.Logger 62 version protocol.Version 63 64 // use a single slice to avoid allocations 65 nonceBuf []byte 66 } 67 68 var ( 69 _ ShortHeaderOpener = &updatableAEAD{} 70 _ ShortHeaderSealer = &updatableAEAD{} 71 ) 72 73 func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD { 74 return &updatableAEAD{ 75 firstPacketNumber: protocol.InvalidPacketNumber, 76 largestAcked: protocol.InvalidPacketNumber, 77 firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, 78 firstSentWithCurrentKey: protocol.InvalidPacketNumber, 79 rttStats: rttStats, 80 tracer: tracer, 81 logger: logger, 82 version: version, 83 } 84 } 85 86 func (a *updatableAEAD) rollKeys() { 87 if a.prevRcvAEAD != nil { 88 a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) 89 if a.tracer != nil && a.tracer.DroppedKey != nil { 90 a.tracer.DroppedKey(a.keyPhase - 1) 91 } 92 a.prevRcvAEADExpiry = time.Time{} 93 } 94 95 a.keyPhase++ 96 a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber 97 a.firstSentWithCurrentKey = protocol.InvalidPacketNumber 98 a.numRcvdWithCurrentKey = 0 99 a.numSentWithCurrentKey = 0 100 a.prevRcvAEAD = a.rcvAEAD 101 a.rcvAEAD = a.nextRcvAEAD 102 a.sendAEAD = a.nextSendAEAD 103 104 a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) 105 a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) 106 a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) 107 a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) 108 } 109 110 func (a *updatableAEAD) startKeyDropTimer(now time.Time) { 111 d := 3 * a.rttStats.PTO(true) 112 a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) 113 a.prevRcvAEADExpiry = now.Add(d) 114 } 115 116 func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { 117 return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) 118 } 119 120 // SetReadKey sets the read key. 121 // For the client, this function is called before SetWriteKey. 122 // For the server, this function is called after SetWriteKey. 123 func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) { 124 a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) 125 a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) 126 if a.suite == nil { 127 a.setAEADParameters(a.rcvAEAD, suite) 128 } 129 130 a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) 131 a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) 132 } 133 134 // SetWriteKey sets the write key. 135 // For the client, this function is called after SetReadKey. 136 // For the server, this function is called before SetReadKey. 137 func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) { 138 a.sendAEAD = createAEAD(suite, trafficSecret, a.version) 139 a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) 140 if a.suite == nil { 141 a.setAEADParameters(a.sendAEAD, suite) 142 } 143 144 a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) 145 a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) 146 } 147 148 func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) { 149 a.nonceBuf = make([]byte, aead.NonceSize()) 150 a.aeadOverhead = aead.Overhead() 151 a.suite = suite 152 switch suite.ID { 153 case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: 154 a.invalidPacketLimit = protocol.InvalidPacketLimitAES 155 case tls.TLS_CHACHA20_POLY1305_SHA256: 156 a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha 157 default: 158 panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) 159 } 160 } 161 162 func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { 163 return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) 164 } 165 166 func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { 167 dec, err := a.open(dst, src, rcvTime, pn, kp, ad) 168 if err == ErrDecryptionFailed { 169 a.invalidPacketCount++ 170 if a.invalidPacketCount >= a.invalidPacketLimit { 171 return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} 172 } 173 } 174 if err == nil { 175 a.highestRcvdPN = max(a.highestRcvdPN, pn) 176 } 177 return dec, err 178 } 179 180 func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { 181 if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { 182 a.prevRcvAEAD = nil 183 a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) 184 a.prevRcvAEADExpiry = time.Time{} 185 if a.tracer != nil && a.tracer.DroppedKey != nil { 186 a.tracer.DroppedKey(a.keyPhase - 1) 187 } 188 } 189 binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) 190 if kp != a.keyPhase.Bit() { 191 if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { 192 if a.prevRcvAEAD == nil { 193 return nil, ErrKeysDropped 194 } 195 // we updated the key, but the peer hasn't updated yet 196 dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) 197 if err != nil { 198 err = ErrDecryptionFailed 199 } 200 return dec, err 201 } 202 // try opening the packet with the next key phase 203 dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) 204 if err != nil { 205 return nil, ErrDecryptionFailed 206 } 207 // Opening succeeded. Check if the peer was allowed to update. 208 if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { 209 return nil, &qerr.TransportError{ 210 ErrorCode: qerr.KeyUpdateError, 211 ErrorMessage: "keys updated too quickly", 212 } 213 } 214 a.rollKeys() 215 a.logger.Debugf("Peer updated keys to %d", a.keyPhase) 216 // The peer initiated this key update. It's safe to drop the keys for the previous generation now. 217 // Start a timer to drop the previous key generation. 218 a.startKeyDropTimer(rcvTime) 219 if a.tracer != nil && a.tracer.UpdatedKey != nil { 220 a.tracer.UpdatedKey(a.keyPhase, true) 221 } 222 a.firstRcvdWithCurrentKey = pn 223 return dec, err 224 } 225 // The AEAD we're using here will be the qtls.aeadAESGCM13. 226 // It uses the nonce provided here and XOR it with the IV. 227 dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) 228 if err != nil { 229 return dec, ErrDecryptionFailed 230 } 231 a.numRcvdWithCurrentKey++ 232 if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { 233 // We initiated the key updated, and now we received the first packet protected with the new key phase. 234 // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. 235 if a.keyPhase > 0 { 236 a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) 237 a.startKeyDropTimer(rcvTime) 238 } 239 a.firstRcvdWithCurrentKey = pn 240 } 241 return dec, err 242 } 243 244 func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { 245 if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { 246 a.firstSentWithCurrentKey = pn 247 } 248 if a.firstPacketNumber == protocol.InvalidPacketNumber { 249 a.firstPacketNumber = pn 250 } 251 a.numSentWithCurrentKey++ 252 binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) 253 // The AEAD we're using here will be the qtls.aeadAESGCM13. 254 // It uses the nonce provided here and XOR it with the IV. 255 return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) 256 } 257 258 func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { 259 if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && 260 pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { 261 return &qerr.TransportError{ 262 ErrorCode: qerr.KeyUpdateError, 263 ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), 264 } 265 } 266 a.largestAcked = pn 267 return nil 268 } 269 270 func (a *updatableAEAD) SetHandshakeConfirmed() { 271 a.handshakeConfirmed = true 272 } 273 274 func (a *updatableAEAD) updateAllowed() bool { 275 if !a.handshakeConfirmed { 276 return false 277 } 278 // the first key update is allowed as soon as the handshake is confirmed 279 return a.keyPhase == 0 || 280 // subsequent key updates as soon as a packet sent with that key phase has been acknowledged 281 (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && 282 a.largestAcked != protocol.InvalidPacketNumber && 283 a.largestAcked >= a.firstSentWithCurrentKey) 284 } 285 286 func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { 287 if !a.updateAllowed() { 288 return false 289 } 290 // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism. 291 if a.keyPhase == 0 { 292 if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval { 293 return true 294 } 295 } 296 if a.numRcvdWithCurrentKey >= KeyUpdateInterval { 297 a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) 298 return true 299 } 300 if a.numSentWithCurrentKey >= KeyUpdateInterval { 301 a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) 302 return true 303 } 304 return false 305 } 306 307 func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { 308 if a.shouldInitiateKeyUpdate() { 309 a.rollKeys() 310 a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) 311 if a.tracer != nil && a.tracer.UpdatedKey != nil { 312 a.tracer.UpdatedKey(a.keyPhase, false) 313 } 314 } 315 return a.keyPhase.Bit() 316 } 317 318 func (a *updatableAEAD) Overhead() int { 319 return a.aeadOverhead 320 } 321 322 func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { 323 a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) 324 } 325 326 func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { 327 a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) 328 } 329 330 func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { 331 return a.firstPacketNumber 332 }