github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/internal/handshake/updatable_aead.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/cipher"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/crypto_turnoff"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/utils"
    15  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    16  )
    17  
    18  // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
    19  // It's a package-level variable to allow modifying it for testing purposes.
    20  var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
    21  
    22  // FirstKeyUpdateInterval is the maximum number of packets we send or receive before initiating the first key update.
    23  // It's a package-level variable to allow modifying it for testing purposes.
    24  var FirstKeyUpdateInterval uint64 = 100
    25  
// updatableAEAD seals and opens short header packets (it implements ShortHeaderSealer and
// ShortHeaderOpener) and manages key updates. It tracks the current key phase and keeps the
// keys for the next phase derived in advance. It also retains the previous phase's receive key
// until a drop timer expires.
type updatableAEAD struct {
	// suite is the cipher suite in use; setAEADParameters sets it when the first key is installed.
	suite *cipherSuite

	// keyPhase counts key generations and is incremented by rollKeys. Only its Bit() value is
	// compared against incoming packets and returned by KeyPhase().
	keyPhase           protocol.KeyPhase
	// largestAcked is the largest acknowledged packet number reported via SetLargestAcked.
	largestAcked       protocol.PacketNumber
	// firstPacketNumber is the first packet number ever passed to Seal.
	firstPacketNumber  protocol.PacketNumber
	// handshakeConfirmed gates key updates: none are initiated before SetHandshakeConfirmed.
	handshakeConfirmed bool

	// invalidPacketLimit is the suite-dependent number of packets that may fail authentication.
	// Open returns an AEADLimitReached error once invalidPacketCount reaches it.
	invalidPacketLimit uint64
	invalidPacketCount uint64

	// Time when the keys should be dropped. Keys are dropped on the next call to Open().
	prevRcvAEADExpiry time.Time
	// prevRcvAEAD is the receive key of the previous key phase, kept to open reordered packets.
	prevRcvAEAD       cipher.AEAD

	// Per-key-phase bookkeeping, reset by rollKeys. The first*WithCurrentKey fields hold
	// protocol.InvalidPacketNumber until a packet has been received/sent with the current key.
	firstRcvdWithCurrentKey protocol.PacketNumber
	firstSentWithCurrentKey protocol.PacketNumber
	highestRcvdPN           protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
	numRcvdWithCurrentKey   uint64
	numSentWithCurrentKey   uint64
	rcvAEAD                 cipher.AEAD
	sendAEAD                cipher.AEAD
	// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
	aeadOverhead int

	// Keys and traffic secrets for the next key phase. They are derived in advance so that a
	// peer-initiated update can be detected by trial decryption with nextRcvAEAD.
	nextRcvAEAD           cipher.AEAD
	nextSendAEAD          cipher.AEAD
	nextRcvTrafficSecret  []byte
	nextSendTrafficSecret []byte

	// Header protection is set up in SetReadKey/SetWriteKey and is not changed by rollKeys.
	headerDecrypter headerProtector
	headerEncrypter headerProtector

	// rttStats supplies the PTO used to compute the key drop timer.
	rttStats *utils.RTTStats

	// tracer may be nil; every use is nil-checked.
	tracer  *logging.ConnectionTracer
	logger  utils.Logger
	version protocol.Version

	// use a single slice to avoid allocations
	nonceBuf []byte
}
    68  
    69  var (
    70  	_ ShortHeaderOpener = &updatableAEAD{}
    71  	_ ShortHeaderSealer = &updatableAEAD{}
    72  )
    73  
    74  func newUpdatableAEAD(rttStats *utils.RTTStats, tracer *logging.ConnectionTracer, logger utils.Logger, version protocol.Version) *updatableAEAD {
    75  	return &updatableAEAD{
    76  		firstPacketNumber:       protocol.InvalidPacketNumber,
    77  		largestAcked:            protocol.InvalidPacketNumber,
    78  		firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
    79  		firstSentWithCurrentKey: protocol.InvalidPacketNumber,
    80  		rttStats:                rttStats,
    81  		tracer:                  tracer,
    82  		logger:                  logger,
    83  		version:                 version,
    84  	}
    85  }
    86  
    87  func (a *updatableAEAD) rollKeys() {
    88  	if a.prevRcvAEAD != nil {
    89  		a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
    90  		if a.tracer != nil && a.tracer.DroppedKey != nil {
    91  			a.tracer.DroppedKey(a.keyPhase - 1)
    92  		}
    93  		a.prevRcvAEADExpiry = time.Time{}
    94  	}
    95  
    96  	a.keyPhase++
    97  	a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
    98  	a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
    99  	a.numRcvdWithCurrentKey = 0
   100  	a.numSentWithCurrentKey = 0
   101  	a.prevRcvAEAD = a.rcvAEAD
   102  	a.rcvAEAD = a.nextRcvAEAD
   103  	a.sendAEAD = a.nextSendAEAD
   104  
   105  	a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
   106  	a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
   107  	a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret, a.version)
   108  	a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret, a.version)
   109  }
   110  
   111  func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
   112  	d := 3 * a.rttStats.PTO(true)
   113  	a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
   114  	a.prevRcvAEADExpiry = now.Add(d)
   115  }
   116  
   117  func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
   118  	return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
   119  }
   120  
   121  // SetReadKey sets the read key.
   122  // For the client, this function is called before SetWriteKey.
   123  // For the server, this function is called after SetWriteKey.
   124  func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) {
   125  	a.rcvAEAD = createAEAD(suite, trafficSecret, a.version)
   126  	a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
   127  	if a.suite == nil {
   128  		a.setAEADParameters(a.rcvAEAD, suite)
   129  	}
   130  
   131  	a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
   132  	a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret, a.version)
   133  }
   134  
   135  // SetWriteKey sets the write key.
   136  // For the client, this function is called after SetReadKey.
   137  // For the server, this function is called before SetReadKey.
   138  func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) {
   139  	a.sendAEAD = createAEAD(suite, trafficSecret, a.version)
   140  	a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version)
   141  	if a.suite == nil {
   142  		a.setAEADParameters(a.sendAEAD, suite)
   143  	}
   144  
   145  	a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
   146  	a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version)
   147  }
   148  
   149  func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) {
   150  	a.nonceBuf = make([]byte, aead.NonceSize())
   151  	a.aeadOverhead = aead.Overhead()
   152  	a.suite = suite
   153  	switch suite.ID {
   154  	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
   155  		a.invalidPacketLimit = protocol.InvalidPacketLimitAES
   156  	case tls.TLS_CHACHA20_POLY1305_SHA256:
   157  		a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
   158  
   159  	// NO_CRYPTO_TAG
   160  	// based on https://pkg.go.dev/crypto/tls#pkg-constants 0x0000 is not used for any other cipher suite
   161  	case 0x0000:
   162  		print("no AEAD parameters need to be set for 0x0000\n")
   163  
   164  	default:
   165  		panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
   166  	}
   167  }
   168  
   169  func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
   170  	return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
   171  }
   172  
   173  func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
   174  
   175  	// NO_CRYPTO_TAG
   176  	if crypto_turnoff.CRYPTO_TURNED_OFF {
   177  		return src, nil
   178  	}
   179  
   180  	dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
   181  	if err == ErrDecryptionFailed {
   182  		a.invalidPacketCount++
   183  		if a.invalidPacketCount >= a.invalidPacketLimit {
   184  			return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
   185  		}
   186  	}
   187  	if err == nil {
   188  		a.highestRcvdPN = max(a.highestRcvdPN, pn)
   189  	}
   190  	return dec, err
   191  }
   192  
// open performs the decryption for Open. It selects the key generation from the key phase bit kp
// and the packet number pn, and it detects and applies peer-initiated key updates.
//
// Possible errors:
//   - ErrKeysDropped: the packet belongs to the previous key phase, whose keys are already dropped.
//   - ErrDecryptionFailed: authentication failed with the selected key.
//   - *qerr.TransportError with KeyUpdateError: the peer moved to the next key phase before any
//     packet had been sent with the current keys (keys updated too quickly).
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
	// Lazily drop the previous generation's receive key once the deadline set by startKeyDropTimer has passed.
	if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
		a.prevRcvAEAD = nil
		a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
		a.prevRcvAEADExpiry = time.Time{}
		if a.tracer != nil && a.tracer.DroppedKey != nil {
			a.tracer.DroppedKey(a.keyPhase - 1)
		}
	}
	// The packet number fills the last 8 bytes of the shared nonce buffer.
	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
	if kp != a.keyPhase.Bit() {
		// The key phase bit differs from ours, so the packet uses either the previous or the next key generation.
		// Precedence: (keyPhase > 0 && firstRcvdWithCurrentKey invalid) || pn < firstRcvdWithCurrentKey.
		// Either we rolled keys and have not yet received anything with the new key, or the packet
		// predates the first one received with the current key (e.g. reordered). Use the previous key.
		if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
			if a.prevRcvAEAD == nil {
				return nil, ErrKeysDropped
			}
			// we updated the key, but the peer hasn't updated yet
			dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
			if err != nil {
				err = ErrDecryptionFailed
			}
			return dec, err
		}
		// try opening the packet with the next key phase
		dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
		if err != nil {
			return nil, ErrDecryptionFailed
		}
		// Opening succeeded. Check if the peer was allowed to update.
		// After the initial phase, the peer may only update once we have sent a packet with the current key.
		if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
			return nil, &qerr.TransportError{
				ErrorCode:    qerr.KeyUpdateError,
				ErrorMessage: "keys updated too quickly",
			}
		}
		a.rollKeys()
		a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
		// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
		// Start a timer to drop the previous key generation.
		a.startKeyDropTimer(rcvTime)
		if a.tracer != nil && a.tracer.UpdatedKey != nil {
			a.tracer.UpdatedKey(a.keyPhase, true)
		}
		a.firstRcvdWithCurrentKey = pn
		return dec, err
	}
	// The AEAD we're using here will be the qtls.aeadAESGCM13.
	// It uses the nonce provided here and XOR it with the IV.
	// NOTE(review): other suites (ChaCha20, 0x0000) use a different AEAD; the comment above is presumably AES-GCM specific.
	dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
	if err != nil {
		return dec, ErrDecryptionFailed
	}
	a.numRcvdWithCurrentKey++
	if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
		// We initiated the key update, and now we received the first packet protected with the new key phase.
		// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
		if a.keyPhase > 0 {
			a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
			a.startKeyDropTimer(rcvTime)
		}
		a.firstRcvdWithCurrentKey = pn
	}
	return dec, err
}
   256  
   257  func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
   258  	if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
   259  		a.firstSentWithCurrentKey = pn
   260  	}
   261  	if a.firstPacketNumber == protocol.InvalidPacketNumber {
   262  		a.firstPacketNumber = pn
   263  	}
   264  	a.numSentWithCurrentKey++
   265  	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
   266  	// The AEAD we're using here will be the qtls.aeadAESGCM13.
   267  	// It uses the nonce provided here and XOR it with the IV.
   268  	return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
   269  }
   270  
   271  func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
   272  	if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
   273  		pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
   274  		return &qerr.TransportError{
   275  			ErrorCode:    qerr.KeyUpdateError,
   276  			ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
   277  		}
   278  	}
   279  	a.largestAcked = pn
   280  	return nil
   281  }
   282  
   283  func (a *updatableAEAD) SetHandshakeConfirmed() {
   284  	a.handshakeConfirmed = true
   285  }
   286  
   287  func (a *updatableAEAD) updateAllowed() bool {
   288  	if !a.handshakeConfirmed {
   289  		return false
   290  	}
   291  	// the first key update is allowed as soon as the handshake is confirmed
   292  	return a.keyPhase == 0 ||
   293  		// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
   294  		(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
   295  			a.largestAcked != protocol.InvalidPacketNumber &&
   296  			a.largestAcked >= a.firstSentWithCurrentKey)
   297  }
   298  
   299  func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
   300  	if !a.updateAllowed() {
   301  		return false
   302  	}
   303  	// Initiate the first key update shortly after the handshake, in order to exercise the key update mechanism.
   304  	if a.keyPhase == 0 {
   305  		if a.numRcvdWithCurrentKey >= FirstKeyUpdateInterval || a.numSentWithCurrentKey >= FirstKeyUpdateInterval {
   306  			return true
   307  		}
   308  	}
   309  	if a.numRcvdWithCurrentKey >= KeyUpdateInterval {
   310  		a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
   311  		return true
   312  	}
   313  	if a.numSentWithCurrentKey >= KeyUpdateInterval {
   314  		a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
   315  		return true
   316  	}
   317  	return false
   318  }
   319  
   320  func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
   321  	if a.shouldInitiateKeyUpdate() {
   322  		a.rollKeys()
   323  		a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
   324  		if a.tracer != nil && a.tracer.UpdatedKey != nil {
   325  			a.tracer.UpdatedKey(a.keyPhase, false)
   326  		}
   327  	}
   328  	return a.keyPhase.Bit()
   329  }
   330  
   331  func (a *updatableAEAD) Overhead() int {
   332  	return a.aeadOverhead
   333  }
   334  
   335  func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
   336  	a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
   337  }
   338  
   339  func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
   340  	a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
   341  }
   342  
   343  func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
   344  	return a.firstPacketNumber
   345  }