github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/internal/handshake/crypto_setup.go (about)

     1  package handshake
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/ooni/psiphon/tunnel-core/psiphon/common/prng"
    14  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    15  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qerr"
    16  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qtls"
    17  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    18  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/wire"
    19  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/logging"
    20  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/quicvarint"
    21  )
    22  
    23  // TLS unexpected_message alert
    24  const alertUnexpectedMessage uint8 = 10
    25  
    26  type messageType uint8
    27  
    28  // TLS handshake message types.
    29  const (
    30  	typeClientHello         messageType = 1
    31  	typeServerHello         messageType = 2
    32  	typeNewSessionTicket    messageType = 4
    33  	typeEncryptedExtensions messageType = 8
    34  	typeCertificate         messageType = 11
    35  	typeCertificateRequest  messageType = 13
    36  	typeCertificateVerify   messageType = 15
    37  	typeFinished            messageType = 20
    38  )
    39  
    40  func (m messageType) String() string {
    41  	switch m {
    42  	case typeClientHello:
    43  		return "ClientHello"
    44  	case typeServerHello:
    45  		return "ServerHello"
    46  	case typeNewSessionTicket:
    47  		return "NewSessionTicket"
    48  	case typeEncryptedExtensions:
    49  		return "EncryptedExtensions"
    50  	case typeCertificate:
    51  		return "Certificate"
    52  	case typeCertificateRequest:
    53  		return "CertificateRequest"
    54  	case typeCertificateVerify:
    55  		return "CertificateVerify"
    56  	case typeFinished:
    57  		return "Finished"
    58  	default:
    59  		return fmt.Sprintf("unknown message type: %d", m)
    60  	}
    61  }
    62  
    63  const clientSessionStateRevision = 3
    64  
    65  type conn struct {
    66  	localAddr, remoteAddr net.Addr
    67  	version               protocol.VersionNumber
    68  }
    69  
    70  var _ ConnWithVersion = &conn{}
    71  
    72  func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
    73  	return &conn{
    74  		localAddr:  local,
    75  		remoteAddr: remote,
    76  		version:    version,
    77  	}
    78  }
    79  
    80  var _ net.Conn = &conn{}
    81  
    82  func (c *conn) Read([]byte) (int, error)               { return 0, nil }
    83  func (c *conn) Write([]byte) (int, error)              { return 0, nil }
    84  func (c *conn) Close() error                           { return nil }
    85  func (c *conn) RemoteAddr() net.Addr                   { return c.remoteAddr }
    86  func (c *conn) LocalAddr() net.Addr                    { return c.localAddr }
    87  func (c *conn) SetReadDeadline(time.Time) error        { return nil }
    88  func (c *conn) SetWriteDeadline(time.Time) error       { return nil }
    89  func (c *conn) SetDeadline(time.Time) error            { return nil }
    90  func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
    91  
// cryptoSetup implements CryptoSetup on top of a qtls connection. It queues
// received handshake messages for qtls (ReadHandshakeMessage), writes the
// records qtls produces to the Initial/Handshake crypto streams (WriteRecord),
// and holds the packet protection keys qtls installs for each encryption level
// (SetReadKey/SetWriteKey).
type cryptoSetup struct {
	// TLS configuration and the qtls connection that drives the handshake.
	tlsConf   *tls.Config
	extraConf *qtls.ExtraConfig
	conn      *qtls.Conn

	version protocol.VersionNumber

	// Hand-off of handshake messages: HandleMessage pushes onto messageChan,
	// qtls pulls them via ReadHandshakeMessage, which signals
	// isReadingHandshakeMessage on every call except the first.
	messageChan               chan []byte
	isReadingHandshakeMessage chan struct{}
	readFirstHandshakeMessage bool

	// Our transport parameters, the peer's parsed ones (once received), and
	// the channel on which the extension handler delivers the peer's raw ones.
	ourParams  *wire.TransportParameters
	peerParams *wire.TransportParameters
	paramsChan <-chan []byte

	runner handshakeRunner

	// alertChan receives the alerts passed to SendAlert.
	alertChan chan uint8
	// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
	handshakeDone chan struct{}
	// is closed when Close() is called
	closeChan chan struct{}

	// Client-side 0-RTT state: transport parameters restored from a session
	// ticket, and the one-shot notification sent after the ClientHello is written.
	zeroRTTParameters      *wire.TransportParameters
	clientHelloWritten     bool
	clientHelloWrittenChan chan *wire.TransportParameters

	rttStats *utils.RTTStats

	tracer logging.ConnectionTracer
	logger utils.Logger

	perspective protocol.Perspective

	mutex sync.Mutex // protects all members below

	// Zero until RunHandshake observes a successful handshake.
	handshakeCompleteTime time.Time

	readEncLevel  protocol.EncryptionLevel
	writeEncLevel protocol.EncryptionLevel

	zeroRTTOpener LongHeaderOpener // only set for the server
	zeroRTTSealer LongHeaderSealer // only set for the client

	// Initial level: crypto stream and keys (nil once dropped).
	initialStream io.Writer
	initialOpener LongHeaderOpener
	initialSealer LongHeaderSealer

	// Handshake level: crypto stream and keys (nil until installed by qtls,
	// and again after SetHandshakeConfirmed drops them).
	handshakeStream io.Writer
	handshakeOpener LongHeaderOpener
	handshakeSealer LongHeaderSealer

	// 1-RTT keys live in aead; the flags record whether each direction is installed.
	aead          *updatableAEAD
	has1RTTSealer bool
	has1RTTOpener bool
}

var (
	_ qtls.RecordLayer = &cryptoSetup{}
	_ CryptoSetup      = &cryptoSetup{}
)
   153  
   154  // NewCryptoSetupClient creates a new crypto setup for the client
   155  func NewCryptoSetupClient(
   156  	initialStream io.Writer,
   157  	handshakeStream io.Writer,
   158  	connID protocol.ConnectionID,
   159  	localAddr net.Addr,
   160  	remoteAddr net.Addr,
   161  	tp *wire.TransportParameters,
   162  	runner handshakeRunner,
   163  	tlsConf *tls.Config,
   164  	clientHelloSeed *prng.Seed,
   165  	getClientHelloRandom func() ([]byte, error),
   166  	enable0RTT bool,
   167  	rttStats *utils.RTTStats,
   168  	tracer logging.ConnectionTracer,
   169  	logger utils.Logger,
   170  	version protocol.VersionNumber,
   171  ) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
   172  
   173  	// [Psiphon]
   174  	// Instantiate the PRNG here as it's used in sequence in two places:
   175  	// TransportParameters.Marshal, for the quic_transport_parameters extension;
   176  	// and then in qtls.clientHelloMsg.marshal.
   177  	var clientHelloPRNG *prng.PRNG
   178  	if clientHelloSeed != nil {
   179  		clientHelloPRNG = prng.NewPRNGWithSeed(clientHelloSeed)
   180  	}
   181  
   182  	cs, clientHelloWritten := newCryptoSetup(
   183  		initialStream,
   184  		handshakeStream,
   185  		connID,
   186  		tp,
   187  		runner,
   188  		tlsConf,
   189  		enable0RTT,
   190  		rttStats,
   191  		tracer,
   192  		logger,
   193  		protocol.PerspectiveClient,
   194  		version,
   195  
   196  		// [Psiphon]
   197  		clientHelloPRNG,
   198  	)
   199  
   200  	// [Psiphon]
   201  	cs.extraConf.ClientHelloPRNG = clientHelloPRNG
   202  	cs.extraConf.GetClientHelloRandom = getClientHelloRandom
   203  
   204  	cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
   205  	return cs, clientHelloWritten
   206  }
   207  
   208  // NewCryptoSetupServer creates a new crypto setup for the server
   209  func NewCryptoSetupServer(
   210  	initialStream io.Writer,
   211  	handshakeStream io.Writer,
   212  	connID protocol.ConnectionID,
   213  	localAddr net.Addr,
   214  	remoteAddr net.Addr,
   215  	tp *wire.TransportParameters,
   216  	runner handshakeRunner,
   217  	tlsConf *tls.Config,
   218  	enable0RTT bool,
   219  	rttStats *utils.RTTStats,
   220  	tracer logging.ConnectionTracer,
   221  	logger utils.Logger,
   222  	version protocol.VersionNumber,
   223  ) CryptoSetup {
   224  	cs, _ := newCryptoSetup(
   225  		initialStream,
   226  		handshakeStream,
   227  		connID,
   228  		tp,
   229  		runner,
   230  		tlsConf,
   231  		enable0RTT,
   232  		rttStats,
   233  		tracer,
   234  		logger,
   235  		protocol.PerspectiveServer,
   236  		version,
   237  
   238  		// [Psiphon]
   239  		nil,
   240  	)
   241  	cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
   242  	return cs
   243  }
   244  
   245  func newCryptoSetup(
   246  	initialStream io.Writer,
   247  	handshakeStream io.Writer,
   248  	connID protocol.ConnectionID,
   249  	tp *wire.TransportParameters,
   250  	runner handshakeRunner,
   251  	tlsConf *tls.Config,
   252  	enable0RTT bool,
   253  	rttStats *utils.RTTStats,
   254  	tracer logging.ConnectionTracer,
   255  	logger utils.Logger,
   256  	perspective protocol.Perspective,
   257  	version protocol.VersionNumber,
   258  	clientHelloPRNG *prng.PRNG,
   259  ) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
   260  	initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
   261  	if tracer != nil {
   262  		tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
   263  		tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
   264  	}
   265  	extHandler := newExtensionHandler(tp.Marshal(perspective, clientHelloPRNG), perspective, version)
   266  	cs := &cryptoSetup{
   267  		tlsConf:                   tlsConf,
   268  		initialStream:             initialStream,
   269  		initialSealer:             initialSealer,
   270  		initialOpener:             initialOpener,
   271  		handshakeStream:           handshakeStream,
   272  		aead:                      newUpdatableAEAD(rttStats, tracer, logger),
   273  		readEncLevel:              protocol.EncryptionInitial,
   274  		writeEncLevel:             protocol.EncryptionInitial,
   275  		runner:                    runner,
   276  		ourParams:                 tp,
   277  		paramsChan:                extHandler.TransportParameters(),
   278  		rttStats:                  rttStats,
   279  		tracer:                    tracer,
   280  		logger:                    logger,
   281  		perspective:               perspective,
   282  		handshakeDone:             make(chan struct{}),
   283  		alertChan:                 make(chan uint8),
   284  		clientHelloWrittenChan:    make(chan *wire.TransportParameters, 1),
   285  		messageChan:               make(chan []byte, 100),
   286  		isReadingHandshakeMessage: make(chan struct{}),
   287  		closeChan:                 make(chan struct{}),
   288  		version:                   version,
   289  	}
   290  	var maxEarlyData uint32
   291  	if enable0RTT {
   292  		maxEarlyData = 0xffffffff
   293  	}
   294  	cs.extraConf = &qtls.ExtraConfig{
   295  		GetExtensions:              extHandler.GetExtensions,
   296  		ReceivedExtensions:         extHandler.ReceivedExtensions,
   297  		AlternativeRecordLayer:     cs,
   298  		EnforceNextProtoSelection:  true,
   299  		MaxEarlyData:               maxEarlyData,
   300  		Accept0RTT:                 cs.accept0RTT,
   301  		Rejected0RTT:               cs.rejected0RTT,
   302  		Enable0RTT:                 enable0RTT,
   303  		GetAppDataForSessionState:  cs.marshalDataForSessionState,
   304  		SetAppDataFromSessionState: cs.handleDataFromSessionState,
   305  	}
   306  	return cs, cs.clientHelloWrittenChan
   307  }
   308  
   309  func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
   310  	initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
   311  	h.initialSealer = initialSealer
   312  	h.initialOpener = initialOpener
   313  	if h.tracer != nil {
   314  		h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
   315  		h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
   316  	}
   317  }
   318  
   319  func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
   320  	return h.aead.SetLargestAcked(pn)
   321  }
   322  
// RunHandshake runs qtls's Handshake in a separate goroutine and blocks until
// one of three things happens:
//   - the handshake succeeds: the completion time is recorded (under h.mutex)
//     and runner.OnHandshakeComplete is called;
//   - Close is called: it waits for the handshake goroutine to exit;
//   - an alert arrives via SendAlert: it waits for Handshake's error and
//     reports both to the runner as a crypto error.
func (h *cryptoSetup) RunHandshake() {
	// handshakeComplete is closed on success; handshakeErrChan carries the
	// error otherwise. It is buffered so the goroutine can always exit, even if
	// nobody receives the error (e.g. after Close).
	handshakeComplete := make(chan struct{})
	handshakeErrChan := make(chan error, 1)
	go func() {
		// handshakeDone is closed however Handshake returns; HandleMessage,
		// Close and handlePostHandshakeMessage wait on it.
		defer close(h.handshakeDone)
		if err := h.conn.Handshake(); err != nil {
			handshakeErrChan <- err
			return
		}
		close(handshakeComplete)
	}()

	select {
	case <-handshakeComplete: // return when the handshake is done
		h.mutex.Lock()
		h.handshakeCompleteTime = time.Now()
		h.mutex.Unlock()
		h.runner.OnHandshakeComplete()
	case <-h.closeChan:
		// wait until the Handshake() go routine has returned
		<-h.handshakeDone
	case alert := <-h.alertChan:
		// NOTE(review): assumes an alert sent during the handshake is always
		// followed by Handshake returning a non-nil error; otherwise this
		// receive would block forever.
		handshakeErr := <-handshakeErrChan
		h.onError(alert, handshakeErr.Error())
	}
}
   350  
   351  func (h *cryptoSetup) onError(alert uint8, message string) {
   352  	h.runner.OnError(qerr.NewCryptoError(alert, message))
   353  }
   354  
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
// It must only be called once (closing closeChan a second time would panic).
// It blocks until the goroutine started by RunHandshake has returned.
// NOTE(review): if RunHandshake was never called, handshakeDone is never
// closed and this blocks forever — confirm callers always run the handshake.
func (h *cryptoSetup) Close() error {
	close(h.closeChan)
	// wait until qtls.Handshake() actually returned
	<-h.handshakeDone
	return nil
}
   364  
// HandleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It reports whether it is done with messages on this encryption level.
//
// The message's level is validated and the message is queued on messageChan
// for qtls. A 1-RTT (post-handshake) message is then processed synchronously by
// handlePostHandshakeMessage, and the result is false. Any other message is
// followed by a wait: until qtls asks for the next message, the handshake
// goroutine exits, or Close is called. During that wait, the peer's transport
// parameters (if delivered on paramsChan) are processed.
// NOTE(review): data must be non-empty; data[0] is read unconditionally.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
	msgType := messageType(data[0])
	h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
	if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
		h.onError(alertUnexpectedMessage, err.Error())
		return false
	}
	// messageChan is buffered (capacity 100); qtls drains it via ReadHandshakeMessage.
	h.messageChan <- data
	if encLevel == protocol.Encryption1RTT {
		h.handlePostHandshakeMessage()
		return false
	}
readLoop:
	for {
		select {
		case data := <-h.paramsChan:
			// A nil delivery means the peer sent no transport parameters.
			// 0x6d (109) is presumably TLS missing_extension (RFC 8446) — confirm.
			if data == nil {
				h.onError(0x6d, "missing quic_transport_parameters extension")
			} else {
				h.handleTransportParameters(data)
			}
		case <-h.isReadingHandshakeMessage:
			// qtls called ReadHandshakeMessage again, i.e. it wants the next message.
			break readLoop
		case <-h.handshakeDone:
			break readLoop
		case <-h.closeChan:
			break readLoop
		}
	}
	// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
	// but only if a handshake opener and sealer was created.
	// Otherwise, a HelloRetryRequest was performed.
	// We're done with the Handshake encryption level after processing the Finished message.
	return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
		msgType == typeFinished
}
   404  
   405  func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
   406  	var expected protocol.EncryptionLevel
   407  	switch msgType {
   408  	case typeClientHello,
   409  		typeServerHello:
   410  		expected = protocol.EncryptionInitial
   411  	case typeEncryptedExtensions,
   412  		typeCertificate,
   413  		typeCertificateRequest,
   414  		typeCertificateVerify,
   415  		typeFinished:
   416  		expected = protocol.EncryptionHandshake
   417  	case typeNewSessionTicket:
   418  		expected = protocol.Encryption1RTT
   419  	default:
   420  		return fmt.Errorf("unexpected handshake message: %d", msgType)
   421  	}
   422  	if encLevel != expected {
   423  		return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
   424  	}
   425  	return nil
   426  }
   427  
   428  func (h *cryptoSetup) handleTransportParameters(data []byte) {
   429  	var tp wire.TransportParameters
   430  	if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
   431  		h.runner.OnError(&qerr.TransportError{
   432  			ErrorCode:    qerr.TransportParameterError,
   433  			ErrorMessage: err.Error(),
   434  		})
   435  	}
   436  	h.peerParams = &tp
   437  	h.runner.OnReceivedParams(h.peerParams)
   438  }
   439  
   440  // must be called after receiving the transport parameters
   441  func (h *cryptoSetup) marshalDataForSessionState() []byte {
   442  	buf := &bytes.Buffer{}
   443  	quicvarint.Write(buf, clientSessionStateRevision)
   444  	quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds()))
   445  	h.peerParams.MarshalForSessionTicket(buf)
   446  	return buf.Bytes()
   447  }
   448  
   449  func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
   450  	tp, err := h.handleDataFromSessionStateImpl(data)
   451  	if err != nil {
   452  		h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
   453  		return
   454  	}
   455  	h.zeroRTTParameters = tp
   456  }
   457  
   458  func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
   459  	r := bytes.NewReader(data)
   460  	ver, err := quicvarint.Read(r)
   461  	if err != nil {
   462  		return nil, err
   463  	}
   464  	if ver != clientSessionStateRevision {
   465  		return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
   466  	}
   467  	rtt, err := quicvarint.Read(r)
   468  	if err != nil {
   469  		return nil, err
   470  	}
   471  	h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
   472  	var tp wire.TransportParameters
   473  	if err := tp.UnmarshalFromSessionTicket(r); err != nil {
   474  		return nil, err
   475  	}
   476  	return &tp, nil
   477  }
   478  
   479  // only valid for the server
   480  func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
   481  	var appData []byte
   482  	// Save transport parameters to the session ticket if we're allowing 0-RTT.
   483  	if h.extraConf.MaxEarlyData > 0 {
   484  		appData = (&sessionTicket{
   485  			Parameters: h.ourParams,
   486  			RTT:        h.rttStats.SmoothedRTT(),
   487  		}).Marshal()
   488  	}
   489  	return h.conn.GetSessionTicket(appData)
   490  }
   491  
   492  // accept0RTT is called for the server when receiving the client's session ticket.
   493  // It decides whether to accept 0-RTT.
   494  func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
   495  	var t sessionTicket
   496  	if err := t.Unmarshal(sessionTicketData); err != nil {
   497  		h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error())
   498  		return false
   499  	}
   500  	valid := h.ourParams.ValidFor0RTT(t.Parameters)
   501  	if valid {
   502  		h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
   503  		h.rttStats.SetInitialRTT(t.RTT)
   504  	} else {
   505  		h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
   506  	}
   507  	return valid
   508  }
   509  
   510  // rejected0RTT is called for the client when the server rejects 0-RTT.
   511  func (h *cryptoSetup) rejected0RTT() {
   512  	h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
   513  
   514  	h.mutex.Lock()
   515  	had0RTTKeys := h.zeroRTTSealer != nil
   516  	h.zeroRTTSealer = nil
   517  	h.mutex.Unlock()
   518  
   519  	if had0RTTKeys {
   520  		h.runner.DropKeys(protocol.Encryption0RTT)
   521  	}
   522  }
   523  
   524  func (h *cryptoSetup) handlePostHandshakeMessage() {
   525  	// make sure the handshake has already completed
   526  	<-h.handshakeDone
   527  
   528  	done := make(chan struct{})
   529  	defer close(done)
   530  
   531  	// h.alertChan is an unbuffered channel.
   532  	// If an error occurs during conn.HandlePostHandshakeMessage,
   533  	// it will be sent on this channel.
   534  	// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
   535  	alertChan := make(chan uint8, 1)
   536  	go func() {
   537  		<-h.isReadingHandshakeMessage
   538  		select {
   539  		case alert := <-h.alertChan:
   540  			alertChan <- alert
   541  		case <-done:
   542  		}
   543  	}()
   544  
   545  	if err := h.conn.HandlePostHandshakeMessage(); err != nil {
   546  		select {
   547  		case <-h.closeChan:
   548  		case alert := <-alertChan:
   549  			h.onError(alert, err.Error())
   550  		}
   551  	}
   552  }
   553  
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available on messageChan.
// Every call except the first starts by signaling isReadingHandshakeMessage.
// That signal is what HandleMessage's read loop and the alert forwarder in
// handlePostHandshakeMessage wait for. Once Close has been called, it returns
// an error instead of a message.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
	if !h.readFirstHandshakeMessage {
		// Nobody is waiting for a signal before the first message.
		h.readFirstHandshakeMessage = true
	} else {
		select {
		case h.isReadingHandshakeMessage <- struct{}{}:
		case <-h.closeChan:
			return nil, errors.New("error while handling the handshake message")
		}
	}
	select {
	case msg := <-h.messageChan:
		return msg, nil
	case <-h.closeChan:
		return nil, errors.New("error while handling the handshake message")
	}
}
   573  
// SetReadKey is called by qtls to install read (decryption) keys for encLevel,
// derived from suite and trafficSecret:
//   - 0-RTT: installs zeroRTTOpener (server only; panics on the client).
//   - Handshake: installs handshakeOpener and moves readEncLevel to Handshake.
//   - Application: passes the secret to the updatable AEAD, marks the 1-RTT
//     opener as available and moves readEncLevel to 1-RTT.
//
// Any other level panics. The fields are updated under h.mutex. The tracer, if
// set, is notified only after the mutex has been released.
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
	h.mutex.Lock()
	switch encLevel {
	case qtls.Encryption0RTT:
		if h.perspective == protocol.PerspectiveClient {
			panic("Received 0-RTT read key for the client")
		}
		h.zeroRTTOpener = newLongHeaderOpener(
			createAEAD(suite, trafficSecret),
			newHeaderProtector(suite, trafficSecret, true),
		)
		h.mutex.Unlock()
		h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
		if h.tracer != nil {
			h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
		}
		// 0-RTT does not change readEncLevel, so return before the shared tail.
		return
	case qtls.EncryptionHandshake:
		h.readEncLevel = protocol.EncryptionHandshake
		// h.dropInitialKeys is passed as a callback.
		// NOTE(review): presumably invoked by the opener once Handshake keys
		// are first used successfully — confirm in newHandshakeOpener.
		h.handshakeOpener = newHandshakeOpener(
			createAEAD(suite, trafficSecret),
			newHeaderProtector(suite, trafficSecret, true),
			h.dropInitialKeys,
			h.perspective,
		)
		h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
	case qtls.EncryptionApplication:
		h.readEncLevel = protocol.Encryption1RTT
		h.aead.SetReadKey(suite, trafficSecret)
		h.has1RTTOpener = true
		h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
	default:
		panic("unexpected read encryption level")
	}
	h.mutex.Unlock()
	if h.tracer != nil {
		// readEncLevel holds the level set in the switch above.
		h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
	}
}
   613  
// SetWriteKey is called by qtls to install write (encryption) keys for
// encLevel, derived from suite and trafficSecret:
//   - 0-RTT: installs zeroRTTSealer (client only; panics on the server).
//   - Handshake: installs handshakeSealer and moves writeEncLevel to Handshake.
//   - Application: passes the secret to the updatable AEAD, marks the 1-RTT
//     sealer as available, moves writeEncLevel to 1-RTT, and discards any
//     remaining 0-RTT sealer.
//
// Any other level panics. The fields are updated under h.mutex. The
// UpdatedKeyFromTLS tracer callback runs after the mutex has been released,
// whereas DroppedEncryptionLevel runs while it is still held.
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
	h.mutex.Lock()
	switch encLevel {
	case qtls.Encryption0RTT:
		if h.perspective == protocol.PerspectiveServer {
			panic("Received 0-RTT write key for the server")
		}
		h.zeroRTTSealer = newLongHeaderSealer(
			createAEAD(suite, trafficSecret),
			newHeaderProtector(suite, trafficSecret, true),
		)
		h.mutex.Unlock()
		h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
		if h.tracer != nil {
			h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
		}
		// 0-RTT does not change writeEncLevel, so return before the shared tail.
		return
	case qtls.EncryptionHandshake:
		h.writeEncLevel = protocol.EncryptionHandshake
		// h.dropInitialKeys is passed as a callback.
		// NOTE(review): presumably invoked by the sealer once Handshake keys
		// are first used — confirm in newHandshakeSealer.
		h.handshakeSealer = newHandshakeSealer(
			createAEAD(suite, trafficSecret),
			newHeaderProtector(suite, trafficSecret, true),
			h.dropInitialKeys,
			h.perspective,
		)
		h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
	case qtls.EncryptionApplication:
		h.writeEncLevel = protocol.Encryption1RTT
		h.aead.SetWriteKey(suite, trafficSecret)
		h.has1RTTSealer = true
		h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
		// Once 1-RTT keys are available, the client stops using its 0-RTT sealer.
		if h.zeroRTTSealer != nil {
			h.zeroRTTSealer = nil
			h.logger.Debugf("Dropping 0-RTT keys.")
			if h.tracer != nil {
				h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
			}
		}
	default:
		panic("unexpected write encryption level")
	}
	h.mutex.Unlock()
	if h.tracer != nil {
		// writeEncLevel holds the level set in the switch above.
		h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
	}
}
   660  
// WriteRecord is called when TLS writes data.
// It writes p to the crypto stream of the current write encryption level
// (Initial or Handshake) while holding h.mutex; any other level panics.
// On the client, the first Initial write also signals clientHelloWrittenChan.
// The value sent is the restored 0-RTT transport parameters if both the 0-RTT
// sealer and those parameters are present, and nil otherwise.
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
	h.mutex.Lock()
	defer h.mutex.Unlock()

	//nolint:exhaustive // TLS records can only be written for Initial and Handshake.
	switch h.writeEncLevel {
	case protocol.EncryptionInitial:
		// assume that the first WriteRecord call contains the ClientHello
		n, err := h.initialStream.Write(p)
		if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
			h.clientHelloWritten = true
			// clientHelloWrittenChan has capacity 1 and this is its only send
			// (guarded by clientHelloWritten), so it cannot block.
			if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
				h.logger.Debugf("Doing 0-RTT.")
				h.clientHelloWrittenChan <- h.zeroRTTParameters
			} else {
				h.logger.Debugf("Not doing 0-RTT.")
				h.clientHelloWrittenChan <- nil
			}
		}
		return n, err
	case protocol.EncryptionHandshake:
		return h.handshakeStream.Write(p)
	default:
		panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
	}
}
   688  
   689  func (h *cryptoSetup) SendAlert(alert uint8) {
   690  	select {
   691  	case h.alertChan <- alert:
   692  	case <-h.closeChan:
   693  		// no need to send an alert when we've already closed
   694  	}
   695  }
   696  
   697  // used a callback in the handshakeSealer and handshakeOpener
   698  func (h *cryptoSetup) dropInitialKeys() {
   699  	h.mutex.Lock()
   700  	h.initialOpener = nil
   701  	h.initialSealer = nil
   702  	h.mutex.Unlock()
   703  	h.runner.DropKeys(protocol.EncryptionInitial)
   704  	h.logger.Debugf("Dropping Initial keys.")
   705  }
   706  
   707  func (h *cryptoSetup) SetHandshakeConfirmed() {
   708  	h.aead.SetHandshakeConfirmed()
   709  	// drop Handshake keys
   710  	var dropped bool
   711  	h.mutex.Lock()
   712  	if h.handshakeOpener != nil {
   713  		h.handshakeOpener = nil
   714  		h.handshakeSealer = nil
   715  		dropped = true
   716  	}
   717  	h.mutex.Unlock()
   718  	if dropped {
   719  		h.runner.DropKeys(protocol.EncryptionHandshake)
   720  		h.logger.Debugf("Dropping Handshake keys.")
   721  	}
   722  }
   723  
   724  func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
   725  	h.mutex.Lock()
   726  	defer h.mutex.Unlock()
   727  
   728  	if h.initialSealer == nil {
   729  		return nil, ErrKeysDropped
   730  	}
   731  	return h.initialSealer, nil
   732  }
   733  
   734  func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
   735  	h.mutex.Lock()
   736  	defer h.mutex.Unlock()
   737  
   738  	if h.zeroRTTSealer == nil {
   739  		return nil, ErrKeysDropped
   740  	}
   741  	return h.zeroRTTSealer, nil
   742  }
   743  
   744  func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
   745  	h.mutex.Lock()
   746  	defer h.mutex.Unlock()
   747  
   748  	if h.handshakeSealer == nil {
   749  		if h.initialSealer == nil {
   750  			return nil, ErrKeysDropped
   751  		}
   752  		return nil, ErrKeysNotYetAvailable
   753  	}
   754  	return h.handshakeSealer, nil
   755  }
   756  
   757  func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
   758  	h.mutex.Lock()
   759  	defer h.mutex.Unlock()
   760  
   761  	if !h.has1RTTSealer {
   762  		return nil, ErrKeysNotYetAvailable
   763  	}
   764  	return h.aead, nil
   765  }
   766  
   767  func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
   768  	h.mutex.Lock()
   769  	defer h.mutex.Unlock()
   770  
   771  	if h.initialOpener == nil {
   772  		return nil, ErrKeysDropped
   773  	}
   774  	return h.initialOpener, nil
   775  }
   776  
   777  func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
   778  	h.mutex.Lock()
   779  	defer h.mutex.Unlock()
   780  
   781  	if h.zeroRTTOpener == nil {
   782  		if h.initialOpener != nil {
   783  			return nil, ErrKeysNotYetAvailable
   784  		}
   785  		// if the initial opener is also not available, the keys were already dropped
   786  		return nil, ErrKeysDropped
   787  	}
   788  	return h.zeroRTTOpener, nil
   789  }
   790  
   791  func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
   792  	h.mutex.Lock()
   793  	defer h.mutex.Unlock()
   794  
   795  	if h.handshakeOpener == nil {
   796  		if h.initialOpener != nil {
   797  			return nil, ErrKeysNotYetAvailable
   798  		}
   799  		// if the initial opener is also not available, the keys were already dropped
   800  		return nil, ErrKeysDropped
   801  	}
   802  	return h.handshakeOpener, nil
   803  }
   804  
   805  func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
   806  	h.mutex.Lock()
   807  	defer h.mutex.Unlock()
   808  
   809  	if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
   810  		h.zeroRTTOpener = nil
   811  		h.logger.Debugf("Dropping 0-RTT keys.")
   812  		if h.tracer != nil {
   813  			h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
   814  		}
   815  	}
   816  
   817  	if !h.has1RTTOpener {
   818  		return nil, ErrKeysNotYetAvailable
   819  	}
   820  	return h.aead, nil
   821  }
   822  
   823  func (h *cryptoSetup) ConnectionState() ConnectionState {
   824  	return qtls.GetConnectionState(h.conn)
   825  }