// Source: github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/internal/handshake/updatable_aead.go

package handshake

import (
	"crypto"
	"crypto/cipher"
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"time"

	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qerr"
	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qtls"
	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/logging"
)

// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
// It's a package-level variable to allow modifying it for testing purposes.
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval

// updatableAEAD seals and opens packets with a pair of AEADs (one per
// direction) that can be rotated to a new key phase. It keeps up to three
// generations of receive keys (previous, current, next) and two generations of
// send keys (current, next), plus the counters needed to decide when a key
// update may be initiated or accepted.
type updatableAEAD struct {
	// Cipher suite in use. It is nil until the first of SetReadKey / SetWriteKey
	// has been called (see setAEADParameters).
	suite *qtls.CipherSuiteTLS13

	// keyPhase is incremented by rollKeys. largestAcked is written by
	// SetLargestAcked, firstPacketNumber by the first call to Seal, and
	// handshakeConfirmed by SetHandshakeConfirmed.
	keyPhase           protocol.KeyPhase
	largestAcked       protocol.PacketNumber
	firstPacketNumber  protocol.PacketNumber
	handshakeConfirmed bool

	// keyUpdateInterval is copied from KeyUpdateInterval at construction time.
	// invalidPacketLimit is chosen per cipher suite in setAEADParameters;
	// invalidPacketCount is incremented by Open on every ErrDecryptionFailed.
	keyUpdateInterval  uint64
	invalidPacketLimit uint64
	invalidPacketCount uint64

	// Time when the keys should be dropped. Keys are dropped on the next call to Open().
	prevRcvAEADExpiry time.Time
	prevRcvAEAD       cipher.AEAD

	// State of the current key phase. The first* fields and the num* counters
	// are reset by rollKeys.
	firstRcvdWithCurrentKey protocol.PacketNumber
	firstSentWithCurrentKey protocol.PacketNumber
	highestRcvdPN           protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
	numRcvdWithCurrentKey   uint64
	numSentWithCurrentKey   uint64
	rcvAEAD                 cipher.AEAD
	sendAEAD                cipher.AEAD
	// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
	aeadOverhead int

	// AEADs and traffic secrets for the next key phase. They are derived ahead
	// of time in SetReadKey / SetWriteKey and again on every rollKeys.
	nextRcvAEAD           cipher.AEAD
	nextSendAEAD          cipher.AEAD
	nextRcvTrafficSecret  []byte
	nextSendTrafficSecret []byte

	// Header protectors. They are assigned in SetReadKey / SetWriteKey and are
	// not touched by rollKeys.
	headerDecrypter headerProtector
	headerEncrypter headerProtector

	// Used by startKeyDropTimer to compute how long old keys are retained.
	rttStats *utils.RTTStats

	// tracer may be nil; every use in this file is guarded by a nil check.
	tracer logging.ConnectionTracer
	logger utils.Logger

	// use a single slice to avoid allocations
	nonceBuf []byte
}

// Compile-time checks that *updatableAEAD implements both interfaces.
var (
	_ ShortHeaderOpener = &updatableAEAD{}
	_ ShortHeaderSealer = &updatableAEAD{}
)

// newUpdatableAEAD returns an updatableAEAD without any keys installed.
// All packet-number markers start out as protocol.InvalidPacketNumber, and the
// key update interval is taken from the package-level KeyUpdateInterval.
// SetReadKey and SetWriteKey install the keys.
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD {
	return &updatableAEAD{
		firstPacketNumber:       protocol.InvalidPacketNumber,
		largestAcked:            protocol.InvalidPacketNumber,
		firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
		firstSentWithCurrentKey: protocol.InvalidPacketNumber,
		keyUpdateInterval:       KeyUpdateInterval,
		rttStats:                rttStats,
		tracer:                  tracer,
		logger:                  logger,
	}
}

// rollKeys advances to the next key phase: the current receive AEAD becomes
// the previous one, the precomputed "next" AEADs become current, the
// per-phase counters are reset, and the secrets and AEADs for the phase after
// that are derived.
func (a *updatableAEAD) rollKeys() {
	// A previous-generation receive AEAD is still held. It is overwritten
	// below, so report it as dropped and clear its pending expiry.
	if a.prevRcvAEAD != nil {
		a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
		if a.tracer != nil {
			a.tracer.DroppedKey(a.keyPhase - 1)
		}
		a.prevRcvAEADExpiry = time.Time{}
	}

	a.keyPhase++
	a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
	a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
	a.numRcvdWithCurrentKey = 0
	a.numSentWithCurrentKey = 0
	a.prevRcvAEAD = a.rcvAEAD
	a.rcvAEAD = a.nextRcvAEAD
	a.sendAEAD = a.nextSendAEAD

	// Each next secret is derived from the secret that was just promoted.
	a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
	a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
	a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret)
	a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
}

// startKeyDropTimer schedules the previous receive keys to expire at now plus
// three times the value returned by rttStats.PTO(true). No timer is actually
// started: the expiry is only checked on the next call to open.
func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
	d := 3 * a.rttStats.PTO(true)
	a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
	a.prevRcvAEADExpiry = now.Add(d)
}

// getNextTrafficSecret derives the traffic secret of the following key phase
// from ts, using hkdfExpandLabel with the label "quic ku" and an output length
// of hash.Size() bytes.
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
	return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
}

// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
	a.rcvAEAD = createAEAD(suite, trafficSecret)
	a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false)
	// Whichever of SetReadKey / SetWriteKey runs first records the suite and
	// the suite-dependent parameters.
	if a.suite == nil {
		a.setAEADParameters(a.rcvAEAD, suite)
	}

	// Precompute the receive keys for the next key phase.
	a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
	a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret)
}

// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetReadKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
	a.sendAEAD = createAEAD(suite, trafficSecret)
	a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false)
	// Whichever of SetReadKey / SetWriteKey runs first records the suite and
	// the suite-dependent parameters.
	if a.suite == nil {
		a.setAEADParameters(a.sendAEAD, suite)
	}

	// Precompute the send keys for the next key phase.
	a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
	a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
}

// setAEADParameters records the cipher suite and the values derived from it:
// it allocates the reusable nonce buffer (aead.NonceSize() bytes), caches
// aead.Overhead(), and selects the invalid-packet limit for the suite.
// It panics if the suite ID is not one of the three TLS 1.3 suites handled here.
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
	a.nonceBuf = make([]byte, aead.NonceSize())
	a.aeadOverhead = aead.Overhead()
	a.suite = suite
	switch suite.ID {
	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
		a.invalidPacketLimit = protocol.InvalidPacketLimitAES
	case tls.TLS_CHACHA20_POLY1305_SHA256:
		a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
	default:
		panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
	}
}

// DecodePacketNumber reconstructs the full packet number from the truncated
// value on the wire, using the highest packet number opened so far as reference.
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
	return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}

// Open decrypts src via open and maintains two pieces of bookkeeping:
//   - every ErrDecryptionFailed increments invalidPacketCount; once the count
//     reaches invalidPacketLimit, a qerr.AEADLimitReached transport error is
//     returned instead of ErrDecryptionFailed,
//   - on success, highestRcvdPN is raised to pn if pn is larger.
//
// All other errors from open are passed through unchanged.
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
	dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
	if err == ErrDecryptionFailed {
		a.invalidPacketCount++
		if a.invalidPacketCount >= a.invalidPacketLimit {
			return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
		}
	}
	if err == nil {
		a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
	}
	return dec, err
}

// open picks the receive key generation for the packet from its key phase bit
// kp and packet number pn, and decrypts src with it:
//   - kp matches the current phase: the current key is used,
//   - kp differs and the packet predates the current phase: the previous key
//     is used, or ErrKeysDropped is returned if that key is gone,
//   - kp differs otherwise: the next key is tried. If it works, the peer has
//     initiated a key update and the keys are rolled.
//
// Before that, the previous key is dropped if its expiry time has passed.
// Any AEAD failure is reported as ErrDecryptionFailed.
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
	// Drop the previous key generation once its scheduled expiry has passed.
	if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
		a.prevRcvAEAD = nil
		a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
		a.prevRcvAEADExpiry = time.Time{}
		if a.tracer != nil {
			a.tracer.DroppedKey(a.keyPhase - 1)
		}
	}
	// Write the packet number big-endian into the last 8 bytes of the reused nonce buffer.
	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
	if kp != a.keyPhase.Bit() {
		// && binds tighter than ||, so this reads:
		// (keyPhase > 0 && nothing received with the current key yet) || pn < firstRcvdWithCurrentKey.
		if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
			if a.prevRcvAEAD == nil {
				return nil, ErrKeysDropped
			}
			// we updated the key, but the peer hasn't updated yet
			dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
			if err != nil {
				err = ErrDecryptionFailed
			}
			return dec, err
		}
		// try opening the packet with the next key phase
		dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
		if err != nil {
			return nil, ErrDecryptionFailed
		}
		// Opening succeeded. Check if the peer was allowed to update.
		// The update is rejected if we are past the initial key phase and have
		// not yet sent any packet with the current key.
		if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
			return nil, &qerr.TransportError{
				ErrorCode:    qerr.KeyUpdateError,
				ErrorMessage: "keys updated too quickly",
			}
		}
		a.rollKeys()
		a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
		// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
		// Start a timer to drop the previous key generation.
		a.startKeyDropTimer(rcvTime)
		if a.tracer != nil {
			a.tracer.UpdatedKey(a.keyPhase, true)
		}
		// rollKeys reset this marker; this packet is the first one received in the new phase.
		a.firstRcvdWithCurrentKey = pn
		return dec, err
	}
	// The AEAD we're using here will be the qtls.aeadAESGCM13.
	// It uses the nonce provided here and XOR it with the IV.
	dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
	if err != nil {
		return dec, ErrDecryptionFailed
	}
	a.numRcvdWithCurrentKey++
	if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
		// We initiated the key update, and now we received the first packet protected with the new key phase.
		// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
		if a.keyPhase > 0 {
			a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
			a.startKeyDropTimer(rcvTime)
		}
		a.firstRcvdWithCurrentKey = pn
	}
	return dec, err
}

// Seal encrypts src with the current send key, using pn to build the nonce and
// ad as associated data, and returns the result of the AEAD's Seal on dst.
// It also records the first packet number sent overall and the first one sent
// in the current key phase, and counts the packet towards the current phase.
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
	if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
		a.firstSentWithCurrentKey = pn
	}
	if a.firstPacketNumber == protocol.InvalidPacketNumber {
		a.firstPacketNumber = pn
	}
	a.numSentWithCurrentKey++
	// Write the packet number big-endian into the last 8 bytes of the reused nonce buffer.
	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
	// The AEAD we're using here will be the qtls.aeadAESGCM13.
	// It uses the nonce provided here and XOR it with the IV.
	return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
}

// SetLargestAcked records pn as the largest acknowledged packet number.
// It returns a qerr.KeyUpdateError transport error, without recording pn, if
// pn acknowledges a packet sent in the current key phase while no packet has
// been received with the current key.
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
	if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
		pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
		return &qerr.TransportError{
			ErrorCode:    qerr.KeyUpdateError,
			ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
		}
	}
	a.largestAcked = pn
	return nil
}

// SetHandshakeConfirmed marks the handshake as confirmed. updateAllowed
// returns false until this has been called.
func (a *updatableAEAD) SetHandshakeConfirmed() {
	a.handshakeConfirmed = true
}

// updateAllowed reports whether we may initiate a key update right now.
func (a *updatableAEAD) updateAllowed() bool {
	if !a.handshakeConfirmed {
		return false
	}
	// the first key update is allowed as soon as the handshake is confirmed
	return a.keyPhase == 0 ||
		// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
		(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
			a.largestAcked != protocol.InvalidPacketNumber &&
			a.largestAcked >= a.firstSentWithCurrentKey)
}

// shouldInitiateKeyUpdate reports whether a key update is both allowed and
// due. It is due once the number of packets received or sent with the current
// key has reached keyUpdateInterval.
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
	if !a.updateAllowed() {
		return false
	}
	if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
		a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
		return true
	}
	if a.numSentWithCurrentKey >= a.keyUpdateInterval {
		a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
		return true
	}
	return false
}

// KeyPhase returns the key phase bit to use for the next packet.
// This is not a pure getter: if a key update is due, it first rolls the keys
// and notifies the tracer, then returns the bit of the new phase.
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
	if a.shouldInitiateKeyUpdate() {
		a.rollKeys()
		a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
		if a.tracer != nil {
			a.tracer.UpdatedKey(a.keyPhase, false)
		}
	}
	return a.keyPhase.Bit()
}

// Overhead returns the AEAD overhead cached by setAEADParameters.
func (a *updatableAEAD) Overhead() int {
	return a.aeadOverhead
}

// EncryptHeader applies header protection by delegating to the header
// protector installed by SetWriteKey.
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
	a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
}

// DecryptHeader removes header protection by delegating to the header
// protector installed by SetReadKey.
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
	a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
}

// FirstPacketNumber returns the packet number passed to the first call to
// Seal, or protocol.InvalidPacketNumber if Seal has not been called yet.
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
	return a.firstPacketNumber
}