github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/internal/handshake/updatable_aead.go (about)

     1  package handshake
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/cipher"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    12  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qerr"
    13  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qtls"
    14  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    15  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/logging"
    16  )
    17  
    18  // KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
    19  // It's a package-level variable to allow modifying it for testing purposes.
    20  var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
    21  
// updatableAEAD protects and unprotects packets and performs key updates.
// It implements ShortHeaderOpener and ShortHeaderSealer. For each direction it
// holds the AEAD of the current key phase plus a precomputed AEAD for the next
// key phase; on the receive side it also retains the previous phase's AEAD
// until its drop deadline, so packets the peer sent before switching keys can
// still be opened.
type updatableAEAD struct {
	// suite is set by setAEADParameters on the first SetReadKey/SetWriteKey
	// call and is reused by rollKeys to derive the keys of later phases.
	suite *qtls.CipherSuiteTLS13

	// keyPhase is the current key phase; rollKeys increments it.
	// largestAcked is recorded by SetLargestAcked, firstPacketNumber is the
	// first packet number passed to Seal, and no locally initiated key update
	// happens before handshakeConfirmed is set.
	keyPhase           protocol.KeyPhase
	largestAcked       protocol.PacketNumber
	firstPacketNumber  protocol.PacketNumber
	handshakeConfirmed bool

	// keyUpdateInterval is copied from KeyUpdateInterval at construction.
	// invalidPacketLimit is chosen per cipher suite; once invalidPacketCount
	// (ErrDecryptionFailed results counted by Open) reaches it, Open fails
	// with an AEADLimitReached transport error.
	keyUpdateInterval  uint64
	invalidPacketLimit uint64
	invalidPacketCount uint64

	// prevRcvAEAD is the receive AEAD of the previous key phase (nil if none).
	// prevRcvAEADExpiry is the time when the keys should be dropped; the zero
	// value means no drop timer is armed. Keys are dropped on the next call to Open().
	prevRcvAEADExpiry time.Time
	prevRcvAEAD       cipher.AEAD

	// State of the current key phase. The first packet numbers and the
	// counters are reset by rollKeys; highestRcvdPN is not.
	firstRcvdWithCurrentKey protocol.PacketNumber
	firstSentWithCurrentKey protocol.PacketNumber
	highestRcvdPN           protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
	numRcvdWithCurrentKey   uint64
	numSentWithCurrentKey   uint64
	rcvAEAD                 cipher.AEAD
	sendAEAD                cipher.AEAD
	// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
	aeadOverhead int

	// Keys and traffic secrets of the next key phase, derived ahead of time by
	// SetReadKey/SetWriteKey and re-derived by every rollKeys call.
	nextRcvAEAD           cipher.AEAD
	nextSendAEAD          cipher.AEAD
	nextRcvTrafficSecret  []byte
	nextSendTrafficSecret []byte

	// Header protection is derived from the traffic secrets passed to
	// SetReadKey/SetWriteKey; rollKeys does not change it.
	headerDecrypter headerProtector
	headerEncrypter headerProtector

	// rttStats supplies the PTO used for the key drop delay (3 * PTO).
	rttStats *utils.RTTStats

	// tracer may be nil; every use is guarded by a nil check.
	tracer logging.ConnectionTracer
	logger utils.Logger

	// use a single slice to avoid allocations
	// (sized to the AEAD's nonce size; the packet number goes into its last 8 bytes)
	nonceBuf []byte
}
    64  
    65  var (
    66  	_ ShortHeaderOpener = &updatableAEAD{}
    67  	_ ShortHeaderSealer = &updatableAEAD{}
    68  )
    69  
    70  func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD {
    71  	return &updatableAEAD{
    72  		firstPacketNumber:       protocol.InvalidPacketNumber,
    73  		largestAcked:            protocol.InvalidPacketNumber,
    74  		firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
    75  		firstSentWithCurrentKey: protocol.InvalidPacketNumber,
    76  		keyUpdateInterval:       KeyUpdateInterval,
    77  		rttStats:                rttStats,
    78  		tracer:                  tracer,
    79  		logger:                  logger,
    80  	}
    81  }
    82  
    83  func (a *updatableAEAD) rollKeys() {
    84  	if a.prevRcvAEAD != nil {
    85  		a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
    86  		if a.tracer != nil {
    87  			a.tracer.DroppedKey(a.keyPhase - 1)
    88  		}
    89  		a.prevRcvAEADExpiry = time.Time{}
    90  	}
    91  
    92  	a.keyPhase++
    93  	a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
    94  	a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
    95  	a.numRcvdWithCurrentKey = 0
    96  	a.numSentWithCurrentKey = 0
    97  	a.prevRcvAEAD = a.rcvAEAD
    98  	a.rcvAEAD = a.nextRcvAEAD
    99  	a.sendAEAD = a.nextSendAEAD
   100  
   101  	a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
   102  	a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
   103  	a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret)
   104  	a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
   105  }
   106  
   107  func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
   108  	d := 3 * a.rttStats.PTO(true)
   109  	a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
   110  	a.prevRcvAEADExpiry = now.Add(d)
   111  }
   112  
   113  func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
   114  	return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
   115  }
   116  
   117  // For the client, this function is called before SetWriteKey.
   118  // For the server, this function is called after SetWriteKey.
   119  func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
   120  	a.rcvAEAD = createAEAD(suite, trafficSecret)
   121  	a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false)
   122  	if a.suite == nil {
   123  		a.setAEADParameters(a.rcvAEAD, suite)
   124  	}
   125  
   126  	a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
   127  	a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret)
   128  }
   129  
   130  // For the client, this function is called after SetReadKey.
   131  // For the server, this function is called before SetWriteKey.
   132  func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
   133  	a.sendAEAD = createAEAD(suite, trafficSecret)
   134  	a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false)
   135  	if a.suite == nil {
   136  		a.setAEADParameters(a.sendAEAD, suite)
   137  	}
   138  
   139  	a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
   140  	a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
   141  }
   142  
   143  func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
   144  	a.nonceBuf = make([]byte, aead.NonceSize())
   145  	a.aeadOverhead = aead.Overhead()
   146  	a.suite = suite
   147  	switch suite.ID {
   148  	case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
   149  		a.invalidPacketLimit = protocol.InvalidPacketLimitAES
   150  	case tls.TLS_CHACHA20_POLY1305_SHA256:
   151  		a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
   152  	default:
   153  		panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
   154  	}
   155  }
   156  
   157  func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
   158  	return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
   159  }
   160  
   161  func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
   162  	dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
   163  	if err == ErrDecryptionFailed {
   164  		a.invalidPacketCount++
   165  		if a.invalidPacketCount >= a.invalidPacketLimit {
   166  			return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
   167  		}
   168  	}
   169  	if err == nil {
   170  		a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
   171  	}
   172  	return dec, err
   173  }
   174  
   175  func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
   176  	if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
   177  		a.prevRcvAEAD = nil
   178  		a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
   179  		a.prevRcvAEADExpiry = time.Time{}
   180  		if a.tracer != nil {
   181  			a.tracer.DroppedKey(a.keyPhase - 1)
   182  		}
   183  	}
   184  	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
   185  	if kp != a.keyPhase.Bit() {
   186  		if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
   187  			if a.prevRcvAEAD == nil {
   188  				return nil, ErrKeysDropped
   189  			}
   190  			// we updated the key, but the peer hasn't updated yet
   191  			dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
   192  			if err != nil {
   193  				err = ErrDecryptionFailed
   194  			}
   195  			return dec, err
   196  		}
   197  		// try opening the packet with the next key phase
   198  		dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
   199  		if err != nil {
   200  			return nil, ErrDecryptionFailed
   201  		}
   202  		// Opening succeeded. Check if the peer was allowed to update.
   203  		if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
   204  			return nil, &qerr.TransportError{
   205  				ErrorCode:    qerr.KeyUpdateError,
   206  				ErrorMessage: "keys updated too quickly",
   207  			}
   208  		}
   209  		a.rollKeys()
   210  		a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
   211  		// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
   212  		// Start a timer to drop the previous key generation.
   213  		a.startKeyDropTimer(rcvTime)
   214  		if a.tracer != nil {
   215  			a.tracer.UpdatedKey(a.keyPhase, true)
   216  		}
   217  		a.firstRcvdWithCurrentKey = pn
   218  		return dec, err
   219  	}
   220  	// The AEAD we're using here will be the qtls.aeadAESGCM13.
   221  	// It uses the nonce provided here and XOR it with the IV.
   222  	dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
   223  	if err != nil {
   224  		return dec, ErrDecryptionFailed
   225  	}
   226  	a.numRcvdWithCurrentKey++
   227  	if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
   228  		// We initiated the key updated, and now we received the first packet protected with the new key phase.
   229  		// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
   230  		if a.keyPhase > 0 {
   231  			a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
   232  			a.startKeyDropTimer(rcvTime)
   233  		}
   234  		a.firstRcvdWithCurrentKey = pn
   235  	}
   236  	return dec, err
   237  }
   238  
   239  func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
   240  	if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
   241  		a.firstSentWithCurrentKey = pn
   242  	}
   243  	if a.firstPacketNumber == protocol.InvalidPacketNumber {
   244  		a.firstPacketNumber = pn
   245  	}
   246  	a.numSentWithCurrentKey++
   247  	binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
   248  	// The AEAD we're using here will be the qtls.aeadAESGCM13.
   249  	// It uses the nonce provided here and XOR it with the IV.
   250  	return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
   251  }
   252  
   253  func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
   254  	if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
   255  		pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
   256  		return &qerr.TransportError{
   257  			ErrorCode:    qerr.KeyUpdateError,
   258  			ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
   259  		}
   260  	}
   261  	a.largestAcked = pn
   262  	return nil
   263  }
   264  
   265  func (a *updatableAEAD) SetHandshakeConfirmed() {
   266  	a.handshakeConfirmed = true
   267  }
   268  
   269  func (a *updatableAEAD) updateAllowed() bool {
   270  	if !a.handshakeConfirmed {
   271  		return false
   272  	}
   273  	// the first key update is allowed as soon as the handshake is confirmed
   274  	return a.keyPhase == 0 ||
   275  		// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
   276  		(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
   277  			a.largestAcked != protocol.InvalidPacketNumber &&
   278  			a.largestAcked >= a.firstSentWithCurrentKey)
   279  }
   280  
   281  func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
   282  	if !a.updateAllowed() {
   283  		return false
   284  	}
   285  	if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
   286  		a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
   287  		return true
   288  	}
   289  	if a.numSentWithCurrentKey >= a.keyUpdateInterval {
   290  		a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
   291  		return true
   292  	}
   293  	return false
   294  }
   295  
   296  func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
   297  	if a.shouldInitiateKeyUpdate() {
   298  		a.rollKeys()
   299  		a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
   300  		if a.tracer != nil {
   301  			a.tracer.UpdatedKey(a.keyPhase, false)
   302  		}
   303  	}
   304  	return a.keyPhase.Bit()
   305  }
   306  
   307  func (a *updatableAEAD) Overhead() int {
   308  	return a.aeadOverhead
   309  }
   310  
   311  func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
   312  	a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
   313  }
   314  
   315  func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
   316  	a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
   317  }
   318  
   319  func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
   320  	return a.firstPacketNumber
   321  }