github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/internal/handshake/updatable_aead.go (about) 1 package handshake 2 3 import ( 4 "crypto" 5 "crypto/cipher" 6 "crypto/tls" 7 "encoding/binary" 8 "fmt" 9 "time" 10 11 "github.com/danielpfeifer02/quic-go-prio-packs/crypto_turnoff" 12 "github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol" 13 "github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr" 14 "github.com/danielpfeifer02/quic-go-prio-packs/internal/utils" 15 "github.com/danielpfeifer02/quic-go-prio-packs/logging" 16 ) 17 18 // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update. 19 // It's a package-level variable to allow modifying it for testing purposes. 20 var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval 21 22 // FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update. 23 // It's a package-level variable to allow modifying it for testing purposes. 24 var FirstKeyUpdateInterval uint64 = 100 25 26 type updatableAEAD struct { 27 suite *cipherSuite 28 29 keyPhase protocol.KeyPhase 30 largestAcked protocol.PacketNumber 31 firstPacketNumber protocol.PacketNumber 32 handshakeConfirmed bool 33 34 invalidPacketLimit uint64 35 invalidPacketCount uint64 36 37 // Time when the keys should be dropped. Keys are dropped on the next call to Open(). 38 prevRcvAEADExpiry time.Time 39 prevRcvAEAD cipher.AEAD 40 41 firstRcvdWithCurrentKey protocol.PacketNumber 42 firstSentWithCurrentKey protocol.PacketNumber 43 highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected) 44 numRcvdWithCurrentKey uint64 45 numSentWithCurrentKey uint64 46 rcvAEAD cipher.AEAD 47 sendAEAD cipher.AEAD 48 // caches cipher.AEAD.Overhead(). This speeds up calls to Overhead(). 49 aeadOverhead int 50 51 nextRcvAEAD cipher.AEAD 52 nextSendAEAD cipher.AEAD 53 nextRcvTrafficSecret []byte 54 nextSendTrafficSecret []byte 55 56 headerDecrypter headerProtector 57 headerEncrypter headerProtector 58 59 rttStats *utils.RTTStats 60 61 tracer *logging.ConnectionTracer 62 logger utils.Logger 63 version protocol.Version 64 65 // use a single slice to avoid allocations 66 nonceBuf []byte 67 } 68 69 var ( 70 _ ShortHeaderOpener = &updatableAEAD{} 71 _ ShortHeaderSealer = &updatableAEAD{} 72 ) 73 74 func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD { 75 return &updatableAEAD{ 76 firstPacketNumber: protocol.InvalidPacketNumber, 77 largestAcked: protocol.InvalidPacketNumber, 78 firstRcvdWithCurrentKey: protocol.InvalidPacketNumber, 79 firstSentWithCurrentKey: protocol.InvalidPacketNumber, 80 rttStats: rttStats, 81 tracer: tracer, 82 logger: logger, 83 version: version, 84 } 85 } 86 87 func (a *updatableAEAD) rollKeys() { 88 if a.prevRcvAEAD != nil { 89 a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry) 90 if a.tracer != nil && a.tracer.DroppedKey != nil { 91 a.tracer.DroppedKey(a.keyPhase - 1) 92 } 93 a.prevRcvAEADExpiry = time.Time{} 94 } 95 96 a.keyPhase++ 97 a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber 98 a.firstSentWithCurrentKey = protocol.InvalidPacketNumber 99 a.numRcvdWithCurrentKey = 0 100 a.numSentWithCurrentKey = 0 101 a.prevRcvAEAD = a.rcvAEAD 102 a.rcvAEAD = a.nextRcvAEAD 103 a.sendAEAD = a.nextSendAEAD 104 105 a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret) 106 a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret) 107 a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version) 108 a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version) 109 } 110 111 func (a *updatableAEAD) startKeyDropTimer(now time.Time) { 112 d := 3 * a.rttStats.PTO(true) 113 a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d) 114 a.prevRcvAEADExpiry = now.Add(d) 115 } 116 117 func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte { 118 return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size()) 119 } 120 121 // SetReadKey sets the read key. 122 // For the client, this function is called before SetWriteKey. 123 // For the server, this function is called after SetWriteKey. 124 func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) { 125 a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) 126 a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) 127 if a.suite == nil { 128 a.setAEADParameters(a.rcvAEAD, suite) 129 } 130 131 a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) 132 a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version) 133 } 134 135 // SetWriteKey sets the write key. 136 // For the client, this function is called after SetReadKey. 137 // For the server, this function is called before SetReadKey. 138 func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) { 139 a.sendAEAD = createAEAD(suite, trafficSecret, a.version) 140 a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) 141 if a.suite == nil { 142 a.setAEADParameters(a.sendAEAD, suite) 143 } 144 145 a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret) 146 a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) 147 } 148 149 func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) { 150 a.nonceBuf = make([]byte, aead.NonceSize()) 151 a.aeadOverhead = aead.Overhead() 152 a.suite = suite 153 switch suite.ID { 154 case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: 155 a.invalidPacketLimit = protocol.InvalidPacketLimitAES 156 case tls.TLS_CHACHA20_POLY1305_SHA256: 157 a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha 158 159 // NO_CRYPTO_TAG 160 // based on https://pkg.go.dev/crypto/tls#pkg-constants 0x0000 is not used for any other cipher suite 161 case 0x0000: 162 print("no AEAD parameters need to be set for 0x0000\n") 163 164 default: 165 panic(fmt.Sprintf("unknown cipher suite %d", suite.ID)) 166 } 167 } 168 169 func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber { 170 return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN) 171 } 172 173 func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { 174 175 // NO_CRYPTO_TAG 176 if crypto_turnoff.CRYPTO_TURNED_OFF { 177 return src, nil 178 } 179 180 dec, err := a.open(dst, src, rcvTime, pn, kp, ad) 181 if err == ErrDecryptionFailed { 182 a.invalidPacketCount++ 183 if a.invalidPacketCount >= a.invalidPacketLimit { 184 return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached} 185 } 186 } 187 if err == nil { 188 a.highestRcvdPN = max(a.highestRcvdPN, pn) 189 } 190 return dec, err 191 } 192 193 func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) { 194 if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) { 195 a.prevRcvAEAD = nil 196 a.logger.Debugf("Dropping key phase %d", a.keyPhase-1) 197 a.prevRcvAEADExpiry = time.Time{} 198 if a.tracer != nil && a.tracer.DroppedKey != nil { 199 a.tracer.DroppedKey(a.keyPhase - 1) 200 } 201 } 202 binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) 203 if kp != a.keyPhase.Bit() { 204 if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey { 205 if a.prevRcvAEAD == nil { 206 return nil, ErrKeysDropped 207 } 208 // we updated the key, but the peer hasn't updated yet 209 dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad) 210 if err != nil { 211 err = ErrDecryptionFailed 212 } 213 return dec, err 214 } 215 // try opening the packet with the next key phase 216 dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad) 217 if err != nil { 218 return nil, ErrDecryptionFailed 219 } 220 // Opening succeeded. Check if the peer was allowed to update. 221 if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { 222 return nil, &qerr.TransportError{ 223 ErrorCode: qerr.KeyUpdateError, 224 ErrorMessage: "keys updated too quickly", 225 } 226 } 227 a.rollKeys() 228 a.logger.Debugf("Peer updated keys to %d", a.keyPhase) 229 // The peer initiated this key update. It's safe to drop the keys for the previous generation now. 230 // Start a timer to drop the previous key generation. 231 a.startKeyDropTimer(rcvTime) 232 if a.tracer != nil && a.tracer.UpdatedKey != nil { 233 a.tracer.UpdatedKey(a.keyPhase, true) 234 } 235 a.firstRcvdWithCurrentKey = pn 236 return dec, err 237 } 238 // The AEAD we're using here will be the qtls.aeadAESGCM13. 239 // It uses the nonce provided here and XOR it with the IV. 240 dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad) 241 if err != nil { 242 return dec, ErrDecryptionFailed 243 } 244 a.numRcvdWithCurrentKey++ 245 if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber { 246 // We initiated the key updated, and now we received the first packet protected with the new key phase. 247 // Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys. 248 if a.keyPhase > 0 { 249 a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase) 250 a.startKeyDropTimer(rcvTime) 251 } 252 a.firstRcvdWithCurrentKey = pn 253 } 254 return dec, err 255 } 256 257 func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte { 258 if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber { 259 a.firstSentWithCurrentKey = pn 260 } 261 if a.firstPacketNumber == protocol.InvalidPacketNumber { 262 a.firstPacketNumber = pn 263 } 264 a.numSentWithCurrentKey++ 265 binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn)) 266 // The AEAD we're using here will be the qtls.aeadAESGCM13. 267 // It uses the nonce provided here and XOR it with the IV. 268 return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad) 269 } 270 271 func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error { 272 if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && 273 pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 { 274 return &qerr.TransportError{ 275 ErrorCode: qerr.KeyUpdateError, 276 ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase), 277 } 278 } 279 a.largestAcked = pn 280 return nil 281 } 282 283 func (a *updatableAEAD) SetHandshakeConfirmed() { 284 a.handshakeConfirmed = true 285 } 286 287 func (a *updatableAEAD) updateAllowed() bool { 288 if !a.handshakeConfirmed { 289 return false 290 } 291 // the first key update is allowed as soon as the handshake is confirmed 292 return a.keyPhase == 0 || 293 // subsequent key updates as soon as a packet sent with that key phase has been acknowledged 294 (a.firstSentWithCurrentKey != protocol.InvalidPacketNumber && 295 a.largestAcked != protocol.InvalidPacketNumber && 296 a.largestAcked >= a.firstSentWithCurrentKey) 297 } 298 299 func (a *updatableAEAD) shouldInitiateKeyUpdate() bool { 300 if !a.updateAllowed() { 301 return false 302 } 303 // Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism. 304 if a.keyPhase == 0 { 305 if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval { 306 return true 307 } 308 } 309 if a.numRcvdWithCurrentKey >= KeyUpdateInterval { 310 a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1) 311 return true 312 } 313 if a.numSentWithCurrentKey >= KeyUpdateInterval { 314 a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1) 315 return true 316 } 317 return false 318 } 319 320 func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit { 321 if a.shouldInitiateKeyUpdate() { 322 a.rollKeys() 323 a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase) 324 if a.tracer != nil && a.tracer.UpdatedKey != nil { 325 a.tracer.UpdatedKey(a.keyPhase, false) 326 } 327 } 328 return a.keyPhase.Bit() 329 } 330 331 func (a *updatableAEAD) Overhead() int { 332 return a.aeadOverhead 333 } 334 335 func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { 336 a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes) 337 } 338 339 func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) { 340 a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes) 341 } 342 343 func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber { 344 return a.firstPacketNumber 345 }