github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/internal/handshake/updatable_aead.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/cipher"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/metacubex/quic-go/internal/protocol"
    12  	"github.com/metacubex/quic-go/internal/qerr"
    13  	"github.com/metacubex/quic-go/internal/utils"
    14  	"github.com/metacubex/quic-go/logging"
    15  )
    16  
    17  // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
    18  // It's a package-level variable to allow modifying it for testing purposes.
    19  var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
    20  
    21  // FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
    22  // It's a package-level variable to allow modifying it for testing purposes.
    23  var FirstKeyUpdateInterval uint64 = 100
    24  
    25  type updatableAEAD struct {
    26  	suite *cipherSuite
    27  
    28  	keyPhase           protocol.KeyPhase
    29  	largestAcked       protocol.PacketNumber
    30  	firstPacketNumber  protocol.PacketNumber
    31  	handshakeConfirmed bool
    32  
    33  	invalidPacketLimit uint64
    34  	invalidPacketCount uint64
    35  
    36  	// Time when the keys should be dropped. Keys are dropped on the next call to Open().
    37  	prevRcvAEADExpiry time.Time
    38  	prevRcvAEAD       cipher.AEAD
    39  
    40  	firstRcvdWithCurrentKey protocol.PacketNumber
    41  	firstSentWithCurrentKey protocol.PacketNumber
    42  	highestRcvdPN           protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
    43  	numRcvdWithCurrentKey   uint64
    44  	numSentWithCurrentKey   uint64
    45  	rcvAEAD                 cipher.AEAD
    46  	sendAEAD                cipher.AEAD
    47  	// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
    48  	aeadOverhead int
    49  
    50  	nextRcvAEAD           cipher.AEAD
    51  	nextSendAEAD          cipher.AEAD
    52  	nextRcvTrafficSecret  []byte
    53  	nextSendTrafficSecret []byte
    54  
    55  	headerDecrypter headerProtector
    56  	headerEncrypter headerProtector
    57  
    58  	rttStats *utils.RTTStats
    59  
    60  	tracer  *logging.ConnectionTracer
    61  	logger  utils.Logger
    62  	version protocol.Version
    63  
    64  	// use a single slice to avoid allocations
    65  	nonceBuf []byte
    66  }
    67  
    68  var (
    69  	_ ShortHeaderOpener = &updatableAEAD{}
    70  	_ ShortHeaderSealer = &updatableAEAD{}
    71  )
    72  
    73  func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
    74  	return &updatableAEAD{
    75  		firstPacketNumber:       protocol.InvalidPacketNumber,
    76  		largestAcked:            protocol.InvalidPacketNumber,
    77  		firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
    78  		firstSentWithCurrentKey: protocol.InvalidPacketNumber,
    79  		rttStats:                rttStats,
    80  		tracer:                  tracer,
    81  		logger:                  logger,
    82  		version:                 version,
    83  	}
    84  }
    85  
    86  func (a *updatableAEAD) rollKeys() {
    87  	if a.prevRcvAEAD != nil {
    88  		a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
    89  		if a.tracer != nil && a.tracer.DroppedKey != nil {
    90  			a.tracer.DroppedKey(a.keyPhase - 1)
    91  		}
    92  		a.prevRcvAEADExpiry = time.Time{}
    93  	}
    94  
    95  	a.keyPhase++
    96  	a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
    97  	a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
    98  	a.numRcvdWithCurrentKey = 0
    99  	a.numSentWithCurrentKey = 0
   100  	a.prevRcvAEAD = a.rcvAEAD
   101  	a.rcvAEAD = a.nextRcvAEAD
   102  	a.sendAEAD = a.nextSendAEAD
   103  
   104  	a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
   105  	a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
   106  	a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
   107  	a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
   108  }
   109  
   110  func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
   111  	d := 3 * a.rttStats.PTO(true)
   112  	a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
   113  	a.prevRcvAEADExpiry = now.Add(d)
   114  }
   115  
   116  func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
   117  	return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
   118  }
   119  
   120  // SetReadKey sets the read key.
   121  // For the client, this function is called before SetWriteKey.
   122  // For the server, this function is called after SetWriteKey.
   123  func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
   124  	a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
   125  	a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
   126  	if a.suite == nil {
   127  		a.setAEADParameters(a.rcvAEAD, suite)
   128  	}
   129  
   130  	a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
   131  	a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
   132  }
   133  
   134  // SetWriteKey sets the write key.
   135  // For the client, this function is called after SetReadKey.
   136  // For the server, this function is called before SetReadKey.
   137  func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
   138  	a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
   139  	a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
   140  	if a.suite == nil {
   141  		a.setAEADParameters(a.sendAEAD, suite)
   142  	}
   143  
   144  	a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
   145  	a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
   146  }
   147  
   148  func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
   149  	a.nonceBuf = make([]byte, aead.NonceSize())
   150  	a.aeadOverhead = aead.Overhead()
   151  	a.suite = suite
   152  	switch suite.ID {
   153  	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
   154  		a.invalidPacketLimit = protocol.InvalidPacketLimitAES
   155  	case tls.TLS_CHACHA20_POLY1305_SHA256:
   156  		a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
   157  	default:
   158  		panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
   159  	}
   160  }
   161  
   162  func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
   163  	return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
   164  }
   165  
   166  func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
   167  	dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
   168  	if err == ErrDecryptionFailed {
   169  		a.invalidPacketCount++
   170  		if a.invalidPacketCount >= a.invalidPacketLimit {
   171  			return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
   172  		}
   173  	}
   174  	if err == nil {
   175  		a.highestRcvdPN = utils.Max(a.highestRcvdPN, pn)
   176  	}
   177  	return dec, err
   178  }
   179  
// open is the internal counterpart of Open. It removes the packet protection
// but does not maintain invalidPacketCount or highestRcvdPN.
//
// Before decrypting, it drops the previous receive AEAD if that AEAD's drop
// deadline (set by startKeyDropTimer) lies before rcvTime. It then selects the
// AEAD by comparing the key phase bit kp with the current key phase:
//   - kp matches: the current receive AEAD is used. If keyPhase > 0, the first
//     such packet confirms a key update we initiated and starts the drop timer.
//   - kp differs and the packet belongs to the previous key phase (keyPhase > 0
//     and nothing was received with the current keys yet, or pn is below the
//     first packet received with them): the previous receive AEAD is used. If
//     that AEAD was already dropped, ErrKeysDropped is returned.
//   - otherwise the peer initiated a key update: the next receive AEAD is tried.
//     On success the keys are rolled and the drop timer is started. When
//     keyPhase > 0 and we haven't sent a packet with the current keys yet, the
//     update is rejected with a KeyUpdateError transport error.
//
// Authentication failures are returned as ErrDecryptionFailed.
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
	// Lazily drop the previous receive keys once their drop deadline has passed.
	if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
		a.prevRcvAEAD = nil
		a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
		a.prevRcvAEADExpiry = time.Time{}
		if a.tracer != nil && a.tracer.DroppedKey != nil {
			a.tracer.DroppedKey(a.keyPhase - 1)
		}
	}
	// The packet number, big-endian, goes into the last 8 bytes of the shared nonce buffer.
	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
	if kp != a.keyPhase.Bit() {
		// Parsed as (keyPhase > 0 && firstRcvdWithCurrentKey unset) || pn < firstRcvdWithCurrentKey:
		// in both cases the packet is treated as belonging to the previous key phase.
		if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
			if a.prevRcvAEAD == nil {
				return nil, ErrKeysDropped
			}
			// we updated the key, but the peer hasn't updated yet
			dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
			if err != nil {
				err = ErrDecryptionFailed
			}
			return dec, err
		}
		// try opening the packet with the next key phase
		dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
		if err != nil {
			return nil, ErrDecryptionFailed
		}
		// Opening succeeded. Check if the peer was allowed to update: after the first
		// key update, we must have sent at least one packet with the current keys.
		if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
			return nil, &qerr.TransportError{
				ErrorCode:    qerr.KeyUpdateError,
				ErrorMessage: "keys updated too quickly",
			}
		}
		a.rollKeys()
		a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
		// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
		// Start a timer to drop the previous key generation.
		a.startKeyDropTimer(rcvTime)
		if a.tracer != nil && a.tracer.UpdatedKey != nil {
			a.tracer.UpdatedKey(a.keyPhase, true)
		}
		// rollKeys reset this to InvalidPacketNumber; this packet is the first one with the new keys.
		a.firstRcvdWithCurrentKey = pn
		return dec, err
	}
	// rcvAEAD was built by createAEAD for the negotiated cipher suite (AES-GCM or
	// ChaCha20-Poly1305, see setAEADParameters), not only AES-GCM.
	// NOTE(review): createAEAD presumably XORs the nonce provided here with the IV — confirm there.
	dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
	if err != nil {
		return dec, ErrDecryptionFailed
	}
	a.numRcvdWithCurrentKey++
	if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
		// First packet received with the current keys. For keyPhase > 0 this means we
		// initiated the key update, and now we received the first packet protected with the new key phase.
		// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
		if a.keyPhase > 0 {
			a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
			a.startKeyDropTimer(rcvTime)
		}
		a.firstRcvdWithCurrentKey = pn
	}
	return dec, err
}
   243  
   244  func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
   245  	if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
   246  		a.firstSentWithCurrentKey = pn
   247  	}
   248  	if a.firstPacketNumber == protocol.InvalidPacketNumber {
   249  		a.firstPacketNumber = pn
   250  	}
   251  	a.numSentWithCurrentKey++
   252  	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
   253  	// The AEAD we're using here will be the qtls.aeadAESGCM13.
   254  	// It uses the nonce provided here and XOR it with the IV.
   255  	return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
   256  }
   257  
   258  func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
   259  	if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
   260  		pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
   261  		return &qerr.TransportError{
   262  			ErrorCode:    qerr.KeyUpdateError,
   263  			ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
   264  		}
   265  	}
   266  	a.largestAcked = pn
   267  	return nil
   268  }
   269  
   270  func (a *updatableAEAD) SetHandshakeConfirmed() {
   271  	a.handshakeConfirmed = true
   272  }
   273  
   274  func (a *updatableAEAD) updateAllowed() bool {
   275  	if !a.handshakeConfirmed {
   276  		return false
   277  	}
   278  	// the first key update is allowed as soon as the handshake is confirmed
   279  	return a.keyPhase == 0 ||
   280  		// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
   281  		(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
   282  			a.largestAcked != protocol.InvalidPacketNumber &&
   283  			a.largestAcked >= a.firstSentWithCurrentKey)
   284  }
   285  
   286  func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
   287  	if !a.updateAllowed() {
   288  		return false
   289  	}
   290  	// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
   291  	if a.keyPhase == 0 {
   292  		if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
   293  			return true
   294  		}
   295  	}
   296  	if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
   297  		a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
   298  		return true
   299  	}
   300  	if a.numSentWithCurrentKey >= KeyUpdateInterval {
   301  		a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
   302  		return true
   303  	}
   304  	return false
   305  }
   306  
   307  func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
   308  	if a.shouldInitiateKeyUpdate() {
   309  		a.rollKeys()
   310  		a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
   311  		if a.tracer != nil && a.tracer.UpdatedKey != nil {
   312  			a.tracer.UpdatedKey(a.keyPhase, false)
   313  		}
   314  	}
   315  	return a.keyPhase.Bit()
   316  }
   317  
   318  func (a *updatableAEAD) Overhead() int {
   319  	return a.aeadOverhead
   320  }
   321  
   322  func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
   323  	a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
   324  }
   325  
   326  func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
   327  	a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
   328  }
   329  
   330  func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
   331  	return a.firstPacketNumber
   332  }