github.com/fumiama/terasu@v0.0.0-20240507144117-547a591149c0/handshake.go (about)

     1  //go:build go1.21
     2  
     3  package terasu
     4  
     5  import (
     6  	"context"
     7  	"crypto"
     8  	"crypto/ecdh"
     9  	"crypto/tls"
    10  	"errors"
    11  	"hash"
    12  	"unsafe"
    13  )
    14  
// defaultConfig is bound via go:linkname to crypto/tls.defaultConfig.
// clientHandshake assigns its result to c.config when the connection has no
// *tls.Config of its own.
//
//go:linkname defaultConfig crypto/tls.defaultConfig
func defaultConfig() *tls.Config
    17  
// clientHelloMsg mirrors crypto/tls's unexported clientHelloMsg so that the
// value returned by the linknamed makeClientHello can be read (serverName,
// earlyData) and passed back into linknamed crypto/tls functions.
//
// Field order and field sizes must not be changed: the struct is accessed
// through pointers created by crypto/tls.
//
// NOTE(review): the upstream type likely declares keyShares with its own
// element type and has further fields after earlyData. This mirror is only safe
// while this package never allocates a clientHelloMsg itself and the prefix
// layout matches the targeted Go release — confirm when bumping Go versions.
type clientHelloMsg struct {
	raw                              []byte
	vers                             uint16
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []tls.CurveID
	supportedPoints                  []uint8
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []tls.SignatureScheme
	supportedSignatureAlgorithmsCert []tls.SignatureScheme
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	extendedMasterSecret             bool
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []byte // only the slice-header size matters for layout
	earlyData                        bool
}
    43  
// marshal is bound to crypto/tls.(*clientHelloMsg).marshal and returns the
// encoded ClientHello bytes that are written as a handshake record.
//
//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal
func marshal(m *clientHelloMsg) ([]byte, error)

// marshal is the method form of the linknamed marshal, so that
// *clientHelloMsg satisfies handshakeMessage.
func (m *clientHelloMsg) marshal() ([]byte, error) {
	return marshal(m)
}

// unmarshal is bound to crypto/tls.(*clientHelloMsg).unmarshal. It parses
// data into m and reports whether parsing succeeded.
//
//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal
func unmarshal(m *clientHelloMsg, data []byte) bool

// unmarshal is the method form of the linknamed unmarshal, so that
// *clientHelloMsg satisfies handshakeMessage.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	return unmarshal(m, data)
}
    57  
// makeClientHello is bound to crypto/tls.(*Conn).makeClientHello. It builds
// the ClientHello for c and returns it together with an *ecdh.PrivateKey.
// clientHandshake passes that key to the TLS 1.3 handshake state as ecdheKey.
//
// NOTE(review): the declared signature must match the internal function of the
// targeted Go release exactly; later releases may change its results.
//
//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello
func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error)

// makeClientHello is the method form of the linknamed makeClientHello.
func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
	return makeClientHello(c)
}
    64  
// A sessionState is a resumable session.
//
// NOTE(review): this appears to mirror only the leading fields of
// crypto/tls.SessionState. Values are obtained from the linknamed loadSession,
// and clientHandshake reads cipherSuite from them. Do not reorder or retype the
// fields, and do not allocate sessionState values in this package. The encoding
// description below covers the full upstream type, including fields
// (created_at, secret, ...) that are not declared here.
type sessionState struct {
	// Encoded as a SessionState (in the language of RFC 8446, Section 3).
	//
	//   enum { server(1), client(2) } SessionStateType;
	//
	//   opaque Certificate<1..2^24-1>;
	//
	//   Certificate CertificateChain<0..2^24-1>;
	//
	//   opaque Extra<0..2^24-1>;
	//
	//   struct {
	//       uint16 version;
	//       SessionStateType type;
	//       uint16 cipher_suite;
	//       uint64 created_at;
	//       opaque secret<1..2^8-1>;
	//       Extra extra<0..2^24-1>;
	//       uint8 ext_master_secret = { 0, 1 };
	//       uint8 early_data = { 0, 1 };
	//       CertificateEntry certificate_list<0..2^24-1>;
	//       CertificateChain verified_chains<0..2^24-1>; /* excluding leaf */
	//       select (SessionState.early_data) {
	//           case 0: Empty;
	//           case 1: opaque alpn<1..2^8-1>;
	//       };
	//       select (SessionState.type) {
	//           case server: Empty;
	//           case client: struct {
	//               select (SessionState.version) {
	//                   case VersionTLS10..VersionTLS12: Empty;
	//                   case VersionTLS13: struct {
	//                       uint64 use_by;
	//                       uint32 age_add;
	//                   };
	//               };
	//           };
	//       };
	//   } SessionState;
	//

	// Extra is ignored by crypto/tls, but is encoded by [tls.SessionState.Bytes]
	// and parsed by [tls.ParseSessionState].
	//
	// This allows [tls.Config.UnwrapSession]/[tls.Config.WrapSession] and
	// [tls.ClientSessionCache] implementations to store and retrieve additional
	// data alongside this session.
	//
	// To allow different layers in a protocol stack to share this field,
	// applications must only append to it, not replace it, and must use entries
	// that can be recognized even if out of order (for example, by starting
	// with an id and version prefix).
	Extra [][]byte

	// EarlyData indicates whether the ticket can be used for 0-RTT in a QUIC
	// connection. The application may set this to false if it is true to
	// decline to offer 0-RTT even if supported.
	EarlyData bool

	version     uint16 // protocol version of the session
	isClient    bool   // session type: client (true) or server (false)
	cipherSuite uint16 // looked up via cipherSuiteTLS13ByID when sending early data
}
   129  
// loadSession is bound to crypto/tls.(*Conn).loadSession. Given the
// ClientHello being built, it returns the session to resume (nil when there is
// none), the early secret, and the binder key. clientHandshake forwards the
// latter two to the TLS 1.3 handshake state.
//
// NOTE(review): presumably it may also fill PSK-related fields of hello —
// confirm against the targeted Go release.
//
//go:linkname loadSession crypto/tls.(*Conn).loadSession
func loadSession(c *_trsconn, hello *clientHelloMsg) (
	session *sessionState, earlySecret, binderKey []byte, err error,
)

// loadSession is the method form of the linknamed loadSession.
func (c *_trsconn) loadSession(hello *clientHelloMsg) (
	session *sessionState, earlySecret, binderKey []byte, err error,
) {
	return loadSession(c, hello)
}

// clientSessionCacheKey is bound to crypto/tls.(*Conn).clientSessionCacheKey.
// It returns the key under which c's session is stored in
// c.config.ClientSessionCache. clientHandshake treats "" as "no key".
//
//go:linkname clientSessionCacheKey crypto/tls.(*Conn).clientSessionCacheKey
func clientSessionCacheKey(c *_trsconn) string

// clientSessionCacheKey is the method form of the linknamed
// clientSessionCacheKey.
func (c *_trsconn) clientSessionCacheKey() string {
	return clientSessionCacheKey(c)
}
   147  
// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash
// algorithm to be used with HKDF. See RFC 8446, Appendix B.4.
//
// It mirrors crypto/tls's unexported cipherSuiteTLS13 so that the pointer
// returned by the linknamed cipherSuiteTLS13ByID can be read here (id, hash).
//
// NOTE(review): aead's result type is declared as any. This is layout-compatible
// only because a func value is a single pointer whatever its signature; the
// field must never be called from this package.
type cipherSuiteTLS13 struct {
	id     uint16
	keyLen int
	aead   func(key, fixedNonce []byte) any
	hash   crypto.Hash
}

// deriveSecret is bound to crypto/tls.(*cipherSuiteTLS13).deriveSecret. It
// derives a secret from secret, label, and the current state of transcript.
// clientHandshake uses it for the client early traffic secret.
//
//go:linkname deriveSecret crypto/tls.(*cipherSuiteTLS13).deriveSecret
func deriveSecret(c *cipherSuiteTLS13, secret []byte, label string, transcript hash.Hash) []byte

// deriveSecret is the method form of the linknamed deriveSecret.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
	return deriveSecret(c, secret, label, transcript)
}

// cipherSuiteTLS13ByID is bound to crypto/tls.cipherSuiteTLS13ByID and returns
// the TLS 1.3 suite descriptor for id.
//
// NOTE(review): presumably returns nil for unknown ids; clientHandshake
// dereferences the result without a nil check.
//
//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
   166  
// handshakeMessage is the method set a TLS handshake message provides:
// encoding to bytes and decoding from bytes.
//
// NOTE(review): values of this interface are passed into linknamed crypto/tls
// code (transcriptMsg). This presumably relies on the method set lining up with
// crypto/tls's own handshakeMessage, so keep the methods and signatures in sync.
type handshakeMessage interface {
	marshal() ([]byte, error)
	unmarshal([]byte) bool
}

// transcriptHash is the minimal writer the handshake transcript needs. A
// hash.Hash satisfies it; clientHandshake passes suite.hash.New().
type transcriptHash interface {
	Write([]byte) (int, error)
}

// transcriptMsg is bound to crypto/tls.transcriptMsg. It feeds msg into the
// transcript h (presumably its marshalled form — confirm).
//
//go:linkname transcriptMsg crypto/tls.transcriptMsg
func transcriptMsg(msg handshakeMessage, h transcriptHash) error

// clientEarlyTrafficLabel is the label passed to deriveSecret when deriving the
// client early traffic secret (RFC 8446, Section 7.1).
const clientEarlyTrafficLabel = "c e traffic"
   180  
// quicSetWriteSecret is bound to crypto/tls.(*Conn).quicSetWriteSecret. It
// installs secret as the write secret for the given QUIC encryption level and
// cipher suite. clientHandshake calls it at tls.QUICEncryptionLevelEarly when
// the ClientHello offers early data.
//
//go:linkname quicSetWriteSecret crypto/tls.(*Conn).quicSetWriteSecret
func quicSetWriteSecret(c *_trsconn, level tls.QUICEncryptionLevel, suite uint16, secret []byte)

// readHandshake is bound to crypto/tls.(*Conn).readHandshake and returns the
// next handshake message read from c. The dynamic type of the message is an
// unexported crypto/tls type.
//
// NOTE(review): when transcript is non-nil the message is presumably also
// written to it. clientHandshake passes nil when reading the ServerHello.
//
//go:linkname readHandshake crypto/tls.(*Conn).readHandshake
func readHandshake(c *_trsconn, transcript transcriptHash) (any, error)

// readHandshake is the method form of the linknamed readHandshake.
func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) {
	return readHandshake(c, transcript)
}
   190  
// serverHelloMsg mirrors the first three fields of crypto/tls's serverHelloMsg.
// clientHandshake obtains a pointer to one by reinterpreting the
// *tls.serverHelloMsg returned by readHandshake, and reads random to check the
// downgrade canaries.
//
// NOTE(review): linknamed crypto/tls code (e.g. pickTLSVersion) presumably
// accesses fields beyond this prefix, so never allocate a serverHelloMsg in
// this package.
type serverHelloMsg struct {
	raw    []byte
	vers   uint16
	random []byte
}
   196  
// sendAlert is bound to crypto/tls.(*Conn).sendAlert and sends the alert err to
// the peer on c.
//
//go:linkname sendAlert crypto/tls.(*Conn).sendAlert
func sendAlert(c *_trsconn, err alert) error

// sendAlert is the method form of the linknamed sendAlert.
func (c *_trsconn) sendAlert(err alert) error {
	return sendAlert(c, err)
}

// unexpectedMessageError is bound to crypto/tls.unexpectedMessageError. It
// builds the error returned when a handshake message of an unexpected type
// arrives. NOTE(review): wanted and got are presumably used only for their
// dynamic types.
//
//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError
func unexpectedMessageError(wanted, got any) error

// TLS alert descriptions used by clientHandshake (RFC 8446, Section 6).
const (
	alertUnexpectedMessage alert = 10
	alertIllegalParameter  alert = 47
)
   211  
// pickTLSVersion is bound to crypto/tls.(*Conn).pickTLSVersion. It validates the
// version offered in serverHello. clientHandshake reads c.vers right
// afterwards, so it presumably records the negotiated version there.
//
//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion
func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error

// pickTLSVersion is the method form of the linknamed pickTLSVersion.
func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error {
	return pickTLSVersion(c, serverHello)
}

// maxSupportedVersion is bound to crypto/tls.(*Config).maxSupportedVersion and
// returns the highest protocol version c allows for the given role.
//
//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion
func maxSupportedVersion(c *tls.Config, isClient bool) uint16

// roleClient is the isClient argument used when querying maxSupportedVersion
// on behalf of a client.
const roleClient = true
   223  
   224  const (
   225  	// downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server
   226  	// random as a downgrade protection if the server would be capable of
   227  	// negotiating a higher version. See RFC 8446, Section 4.1.3.
   228  	downgradeCanaryTLS12 = "DOWNGRD\x01"
   229  	downgradeCanaryTLS11 = "DOWNGRD\x00"
   230  )
   231  
// clientHandshakeStateTLS13 mirrors crypto/tls's TLS 1.3 client handshake
// state. clientHandshake allocates one here with a composite literal and hands
// it to the linknamed handshake13, which fills in the remaining fields. The
// layout must therefore match the upstream type exactly, including trailing
// fields.
//
// NOTE(review): certReq is typed *uintptr, presumably as a stand-in for
// crypto/tls's own message pointer type; only its pointer size matters. Keep
// this struct in sync with the targeted Go release, because an upstream field
// addition would make crypto/tls write past the end of this allocation.
type clientHandshakeStateTLS13 struct {
	c           *Conn
	ctx         context.Context
	serverHello *serverHelloMsg
	hello       *clientHelloMsg
	ecdheKey    *ecdh.PrivateKey

	session     *sessionState
	earlySecret []byte
	binderKey   []byte

	certReq       *uintptr
	usingPSK      bool
	sentDummyCCS  bool
	suite         *cipherSuiteTLS13
	transcript    hash.Hash
	masterSecret  []byte
	trafficSecret []byte // client_application_traffic_secret_0
}

// handshake13 is bound to crypto/tls.(*clientHandshakeStateTLS13).handshake
// and runs the remainder of a TLS 1.3 client handshake from hs.
//
//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake
func handshake13(hs *clientHandshakeStateTLS13) error

// handshake is the method form of the linknamed handshake13.
func (hs *clientHandshakeStateTLS13) handshake() error {
	return handshake13(hs)
}
   258  
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
//
// This is a layout mirror of crypto/tls's finishedHash. It is embedded by value
// in clientHandshakeState, so its size and field order must match the upstream
// type exactly. NOTE(review): keep in sync with the targeted Go release.
type finishedHash struct {
	client hash.Hash
	server hash.Hash

	// Prior to TLS 1.2, an additional MD5 hash is required.
	clientMD5 hash.Hash
	serverMD5 hash.Hash

	// In TLS 1.2, a full buffer is sadly required.
	buffer []byte

	version uint16
	prf     func(result, secret, label, seed []byte)
}
   275  
// clientHandshakeState mirrors crypto/tls's client handshake state for
// versions below TLS 1.3. clientHandshake allocates one here and hands it to
// the linknamed handshake, so the layout must match the upstream type exactly.
//
// NOTE(review): suite is typed *uintptr, presumably as a stand-in for
// crypto/tls's own cipher suite pointer type; only its pointer size matters.
type clientHandshakeState struct {
	c            *Conn
	ctx          context.Context
	serverHello  *serverHelloMsg
	hello        *clientHelloMsg
	suite        *uintptr
	finishedHash finishedHash
	masterSecret []byte
	session      *sessionState // the session being resumed
	ticket       []byte        // a fresh ticket received during this handshake
}

// handshake is bound to crypto/tls.(*clientHandshakeState).handshake and runs
// the remainder of a pre-TLS 1.3 client handshake from hs.
//
//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake
func handshake(hs *clientHandshakeState) error

// handshake is the method form of the linknamed handshake.
func (hs *clientHandshakeState) handshake() error {
	return handshake(hs)
}
   294  
   295  // writeHandshakeRecord writes a handshake message to the connection and updates
   296  // the record layer state. If transcript is non-nil the marshalled message is
   297  // written to it.
   298  func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) {
   299  	c.out.Lock()
   300  	defer c.out.Unlock()
   301  
   302  	data, err := msg.marshal()
   303  	if err != nil {
   304  		return 0, err
   305  	}
   306  	if transcript != nil {
   307  		transcript.Write(data)
   308  	}
   309  
   310  	return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data)
   311  }
   312  
// clientHandshake returns a client-handshake function for cout. It appears to
// mirror crypto/tls's (*Conn).clientHandshake, except that the ClientHello is
// sent through writeHandshakeRecord with the given firstFragmentLen.
//
// The returned function, in order:
//   - sets c.config to defaultConfig() when it is nil, and resets c.didResume;
//   - builds the ClientHello and loads a resumable session via linknamed
//     crypto/tls internals;
//   - writes the ClientHello and, when hello.earlyData is set, installs the
//     client early traffic secret at tls.QUICEncryptionLevelEarly;
//   - reads the ServerHello, lets pickTLSVersion negotiate c.vers, and aborts
//     with alertIllegalParameter if an RFC 8446 Section 4.1.3 downgrade canary
//     is present;
//   - continues with the TLS 1.3 or the pre-1.3 handshake state machine.
//
// If a session was loaded and the returned error is non-nil, the session cache
// entry under c.clientSessionCacheKey() is overwritten with nil.
//
// NOTE(review): the unsafe conversion from *Conn to *_trsconn assumes both
// types share a memory layout (presumably that of tls.Conn) — confirm where
// they are declared.
func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error {
	return func(ctx context.Context) (err error) {
		// Same connection, viewed through the type the linknamed helpers take.
		c := (*_trsconn)(unsafe.Pointer(cout))

		if c.config == nil {
			c.config = defaultConfig()
		}

		// This may be a renegotiation handshake, in which case some fields
		// need to be reset.
		c.didResume = false

		hello, ecdheKey, err := c.makeClientHello()
		if err != nil {
			return err
		}
		c.serverName = hello.serverName

		session, earlySecret, binderKey, err := c.loadSession(hello)
		if err != nil {
			return err
		}
		if session != nil {
			// err is the named result, so the deferred check sees the value
			// the function finally returns.
			defer func() {
				// If we got a handshake failure when resuming a session, throw away
				// the session ticket. See RFC 5077, Section 3.2.
				//
				// RFC 8446 makes no mention of dropping tickets on failure, but it
				// does require servers to abort on invalid binders, so we need to
				// delete tickets to recover from a corrupted PSK.
				if err != nil {
					if cacheKey := c.clientSessionCacheKey(); cacheKey != "" {
						c.config.ClientSessionCache.Put(cacheKey, nil)
					}
				}
			}()
		}

		// No transcript is passed for the ClientHello here; firstFragmentLen is
		// forwarded unchanged.
		if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil {
			return err
		}

		if hello.earlyData {
			// NOTE(review): assumes session is non-nil whenever earlyData is
			// set, and that cipherSuiteTLS13ByID finds the suite; neither is
			// checked here.
			suite := cipherSuiteTLS13ByID(session.cipherSuite)
			transcript := suite.hash.New()
			if err := transcriptMsg(hello, transcript); err != nil {
				return err
			}
			earlyTrafficSecret := suite.deriveSecret(earlySecret, clientEarlyTrafficLabel, transcript)
			quicSetWriteSecret(c, tls.QUICEncryptionLevelEarly, suite.id, earlyTrafficSecret)
		}

		// serverHelloMsg is not included in the transcript
		msg, err := c.readHandshake(nil)
		if err != nil {
			return err
		}

		// crypto/tls's serverHelloMsg type cannot be named from this package, so
		// isTypeEqual (defined elsewhere) is given the type name as a string.
		var serverHello *serverHelloMsg
		if !isTypeEqual(msg, "*tls.serverHelloMsg") {
			c.sendAlert(alertUnexpectedMessage)
			// serverHello is still nil here; presumably only its type is used.
			return unexpectedMessageError(serverHello, msg)
		}
		// Read the data word of the interface value msg, which follows the type
		// word in the gc runtime's two-word interface layout. Reinterpret the
		// *tls.serverHelloMsg as the local prefix mirror.
		serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)(
			unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))),
		))

		if err := c.pickTLSVersion(serverHello); err != nil {
			return err
		}

		// If we are negotiating a protocol version that's lower than what we
		// support, check for the server downgrade canaries.
		// See RFC 8446, Section 4.1.3.
		// The canary occupies the last 8 bytes of the 32-byte server random.
		maxVers := maxSupportedVersion(c.config, roleClient)
		tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12
		tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11
		if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) ||
			maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade {
			c.sendAlert(alertIllegalParameter)
			return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox")
		}

		if c.vers == tls.VersionTLS13 {
			hs := &clientHandshakeStateTLS13{
				c:           cout,
				ctx:         ctx,
				serverHello: serverHello,
				hello:       hello,
				ecdheKey:    ecdheKey,
				session:     session,
				earlySecret: earlySecret,
				binderKey:   binderKey,
			}

			// In TLS 1.3, session tickets are delivered after the handshake.
			return hs.handshake()
		}

		// Below TLS 1.3: ecdheKey, earlySecret and binderKey are not passed on.
		hs := &clientHandshakeState{
			c:           cout,
			ctx:         ctx,
			serverHello: serverHello,
			hello:       hello,
			session:     session,
		}

		if err := hs.handshake(); err != nil {
			return err
		}

		return nil
	}
}