github.com/fumiama/terasu@v0.0.0-20240507144117-547a591149c0/handshake_1.20.go (about)

     1  //go:build !go1.21
     2  
     3  package terasu
     4  
     5  import (
     6  	"context"
     7  	"crypto/ecdh"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"errors"
    11  	"hash"
    12  	"time"
    13  	"unsafe"
    14  )
    15  
// defaultConfig is bound via go:linkname to crypto/tls's unexported
// defaultConfig function. clientHandshake uses it as the fallback when the
// connection has no *tls.Config. The compiler does not check this signature
// against the target symbol, so it must match the standard library exactly.
//
//go:linkname defaultConfig crypto/tls.defaultConfig
func defaultConfig() *tls.Config
    18  
// clientHelloMsg mirrors the leading fields of crypto/tls's unexported
// clientHelloMsg. Pointers to it are produced by the linked makeClientHello
// and handed back to linked crypto/tls code (marshal, unmarshal, loadSession,
// the handshake states), which operates on the real struct. The declared
// fields must therefore keep the standard library's order and types for the
// Go release this file targets.
// NOTE(review): presumably only a prefix of the real struct is declared, so
// never allocate or copy this type by value.
type clientHelloMsg struct {
	raw                []byte
	vers               uint16
	random             []byte
	sessionId          []byte
	cipherSuites       []uint16
	compressionMethods []uint8
	serverName         string // read by clientHandshake to set c.serverName
}
    28  
// marshal is bound via go:linkname to crypto/tls.(*clientHelloMsg).marshal;
// the mirrored *clientHelloMsg is passed as the receiver argument.
//
//go:linkname marshal crypto/tls.(*clientHelloMsg).marshal
func marshal(m *clientHelloMsg) ([]byte, error)

// marshal encodes the ClientHello by delegating to the linked crypto/tls
// implementation. Together with unmarshal it makes *clientHelloMsg satisfy
// handshakeMessage; writeHandshakeRecord calls it to get the bytes it sends.
func (m *clientHelloMsg) marshal() ([]byte, error) {
	return marshal(m)
}
    35  
// unmarshal is bound via go:linkname to crypto/tls.(*clientHelloMsg).unmarshal;
// the mirrored *clientHelloMsg is passed as the receiver argument.
//
//go:linkname unmarshal crypto/tls.(*clientHelloMsg).unmarshal
func unmarshal(m *clientHelloMsg, data []byte) bool

// unmarshal parses data into m by delegating to crypto/tls. The bool result
// presumably reports whether parsing succeeded. It exists so that
// *clientHelloMsg satisfies handshakeMessage.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	return unmarshal(m, data)
}
    42  
// makeClientHello is bound via go:linkname to crypto/tls.(*Conn).makeClientHello.
// The *_trsconn is passed where crypto/tls expects its *Conn receiver.
// NOTE(review): this presumably relies on _trsconn mirroring tls.Conn's
// layout, which is declared elsewhere in the package.
//
//go:linkname makeClientHello crypto/tls.(*Conn).makeClientHello
func makeClientHello(c *_trsconn) (*clientHelloMsg, *ecdh.PrivateKey, error)

// makeClientHello builds the ClientHello for c by delegating to crypto/tls.
// It also returns an ECDH private key, which clientHandshake forwards to the
// TLS 1.3 handshake state.
// NOTE(review): the key is presumably nil when no key share was generated;
// confirm against the targeted Go release.
func (c *_trsconn) makeClientHello() (*clientHelloMsg, *ecdh.PrivateKey, error) {
	return makeClientHello(c)
}
    49  
// sessionState mirrors crypto/tls.ClientSessionState, the state a client
// needs to resume a TLS session. The linked loadSession returns values of
// this type. clientHandshake converts a *sessionState to
// *tls.ClientSessionState with unsafe.Pointer before storing it in
// Config.ClientSessionCache. The field order and types must therefore match
// the standard library's definition for the Go release this file targets.
type sessionState struct {
	sessionTicket      []uint8               // Encrypted ticket used for session resumption with server
	vers               uint16                // TLS version negotiated for the session
	cipherSuite        uint16                // Ciphersuite negotiated for the session
	masterSecret       []byte                // Full handshake MasterSecret, or TLS 1.3 resumption_master_secret
	serverCertificates []*x509.Certificate   // Certificate chain presented by the server
	verifiedChains     [][]*x509.Certificate // Certificate chains we built for verification
	receivedAt         time.Time             // When the session ticket was received from the server
	ocspResponse       []byte                // Stapled OCSP response presented by the server
	scts               [][]byte              // SCTs presented by the server

	// TLS 1.3 fields.
	nonce  []byte    // Ticket nonce sent by the server, to derive PSK
	useBy  time.Time // Expiration of the ticket lifetime as set by the server
	ageAdd uint32    // Random obfuscation factor for sending the ticket age
}
    68  
// loadSession is bound via go:linkname to crypto/tls.(*Conn).loadSession.
// The *_trsconn is passed where crypto/tls expects its *Conn receiver.
//
//go:linkname loadSession crypto/tls.(*Conn).loadSession
func loadSession(c *_trsconn, hello *clientHelloMsg) (cacheKey string,
	session *sessionState, earlySecret, binderKey []byte, err error,
)

// loadSession looks up a resumable session for hello by delegating to
// crypto/tls. clientHandshake uses the results as follows:
//   - cacheKey is the key for Config.ClientSessionCache.
//   - session, earlySecret and binderKey are forwarded to the TLS 1.3
//     handshake state.
//   - session is also forwarded to the pre-1.3 state.
//
// clientHandshake only arms its cache-eviction defer when cacheKey is
// non-empty and session is non-nil.
func (c *_trsconn) loadSession(hello *clientHelloMsg) (cacheKey string,
	session *sessionState, earlySecret, binderKey []byte, err error,
) {
	return loadSession(c, hello)
}
    79  
// handshakeMessage is the method set crypto/tls requires of a handshake
// message. *clientHelloMsg satisfies it through its marshal/unmarshal
// wrappers. Values of this interface are passed to linked crypto/tls
// functions such as transcriptMsg, so the method set must match the standard
// library's unexported handshakeMessage interface.
type handshakeMessage interface {
	// marshal returns the encoded message; writeHandshakeRecord sends it.
	marshal() ([]byte, error)
	// unmarshal parses the given bytes into the message, presumably reporting success.
	unmarshal([]byte) bool
}
    84  
// transcriptHash is the minimal writer interface used to accumulate the
// handshake transcript; any hash.Hash satisfies it. A nil transcriptHash
// means nothing is recorded. writeHandshakeRecord skips the write when it is
// nil, and clientHandshake calls readHandshake(nil) for the ServerHello.
type transcriptHash interface {
	Write([]byte) (int, error)
}
    88  
// transcriptMsg is bound via go:linkname to crypto/tls.transcriptMsg, which
// presumably marshals msg and writes the bytes to h.
// NOTE(review): it has no caller in this part of the file; it may be used
// elsewhere in the package.
//
//go:linkname transcriptMsg crypto/tls.transcriptMsg
func transcriptMsg(msg handshakeMessage, h transcriptHash) error
    91  
// readHandshake is bound via go:linkname to crypto/tls.(*Conn).readHandshake.
// The *_trsconn is passed where crypto/tls expects its *Conn receiver.
//
//go:linkname readHandshake crypto/tls.(*Conn).readHandshake
func readHandshake(c *_trsconn, transcript transcriptHash) (any, error)

// readHandshake reads the next handshake message via crypto/tls. The result
// is an interface holding a pointer to one of crypto/tls's unexported message
// types. clientHandshake inspects its dynamic type with isTypeEqual before
// reinterpreting it. Passing a nil transcript keeps the message out of any
// transcript hash.
func (c *_trsconn) readHandshake(transcript transcriptHash) (any, error) {
	return readHandshake(c, transcript)
}
    98  
// serverHelloMsg mirrors the leading fields of crypto/tls's unexported
// serverHelloMsg. clientHandshake extracts a pointer to the real message from
// readHandshake's result via unsafe, then reads random through this view.
// The same pointer is also handed to pickTLSVersion and the handshake states.
// These fields must keep the standard library's order and types.
// NOTE(review): presumably only a prefix of the real struct is declared, so
// never allocate or copy this type by value.
type serverHelloMsg struct {
	raw    []byte
	vers   uint16
	random []byte
}
   104  
// sendAlert is bound via go:linkname to crypto/tls.(*Conn).sendAlert.
// The *_trsconn is passed where crypto/tls expects its *Conn receiver.
//
//go:linkname sendAlert crypto/tls.(*Conn).sendAlert
func sendAlert(c *_trsconn, err alert) error

// sendAlert asks crypto/tls to send alert err on the connection and returns
// its error. The callers in clientHandshake ignore this result and return
// their own, more specific error instead.
func (c *_trsconn) sendAlert(err alert) error {
	return sendAlert(c, err)
}
   111  
// unexpectedMessageError is bound via go:linkname to crypto/tls's helper that
// builds the error for a handshake message of an unexpected type. wanted and
// got are presumably used only for their dynamic types; clientHandshake
// passes a nil *serverHelloMsg as wanted.
//
//go:linkname unexpectedMessageError crypto/tls.unexpectedMessageError
func unexpectedMessageError(wanted, got any) error
   114  
   115  const (
   116  	alertUnexpectedMessage alert = 10
   117  	alertIllegalParameter  alert = 47
   118  )
   119  
// pickTLSVersion is bound via go:linkname to crypto/tls.(*Conn).pickTLSVersion.
// The *_trsconn is passed where crypto/tls expects its *Conn receiver.
//
//go:linkname pickTLSVersion crypto/tls.(*Conn).pickTLSVersion
func pickTLSVersion(c *_trsconn, serverHello *serverHelloMsg) error

// pickTLSVersion delegates version selection for serverHello to crypto/tls.
// clientHandshake reads c.vers right after a successful call, so the linked
// function presumably validates the server's choice and records it there.
func (c *_trsconn) pickTLSVersion(serverHello *serverHelloMsg) error {
	return pickTLSVersion(c, serverHello)
}
   126  
// maxSupportedVersion is bound via go:linkname to
// crypto/tls.(*Config).maxSupportedVersion; the *tls.Config is passed as the
// receiver argument. clientHandshake calls it with roleClient to get the
// highest version the config allows, which it needs for the downgrade-canary
// check.
//
//go:linkname maxSupportedVersion crypto/tls.(*Config).maxSupportedVersion
func maxSupportedVersion(c *tls.Config, isClient bool) uint16
   129  
   130  const roleClient = true
   131  
   132  const (
   133  	// downgradeCanaryTLS12 or downgradeCanaryTLS11 is embedded in the server
   134  	// random as a downgrade protection if the server would be capable of
   135  	// negotiating a higher version. See RFC 8446, Section 4.1.3.
   136  	downgradeCanaryTLS12 = "DOWNGRD\x01"
   137  	downgradeCanaryTLS11 = "DOWNGRD\x00"
   138  )
   139  
// clientHandshakeStateTLS13 mirrors crypto/tls's unexported
// clientHandshakeStateTLS13. clientHandshake allocates it and passes it to the
// linked handshake13. Every field must therefore match the standard library's
// order, types and sizes for the Go release this file targets. certReq and
// suite point to crypto/tls types that are not mirrored here. They are
// declared as *uintptr, presumably as opaque pointer-sized placeholders that
// this package never dereferences.
type clientHandshakeStateTLS13 struct {
	c           *Conn
	ctx         context.Context
	serverHello *serverHelloMsg
	hello       *clientHelloMsg
	ecdheKey    *ecdh.PrivateKey

	session     *sessionState
	earlySecret []byte
	binderKey   []byte

	certReq       *uintptr
	usingPSK      bool
	sentDummyCCS  bool
	suite         *uintptr
	transcript    hash.Hash
	masterSecret  []byte
	trafficSecret []byte // client_application_traffic_secret_0
}
   159  
// handshake13 is bound via go:linkname to
// crypto/tls.(*clientHandshakeStateTLS13).handshake; hs is passed as the
// receiver argument.
//
//go:linkname handshake13 crypto/tls.(*clientHandshakeStateTLS13).handshake
func handshake13(hs *clientHandshakeStateTLS13) error

// handshake completes a TLS 1.3 client handshake in crypto/tls. It uses the
// ServerHello, ClientHello, ECDH key and resumption material that
// clientHandshake stored in hs.
func (hs *clientHandshakeStateTLS13) handshake() error {
	return handshake13(hs)
}
   166  
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
//
// It mirrors crypto/tls's unexported finishedHash and is embedded by value in
// clientHandshakeState. Its size and field layout must match the standard
// library exactly, or every later field of clientHandshakeState would be
// misplaced.
type finishedHash struct {
	client hash.Hash
	server hash.Hash

	// Prior to TLS 1.2, an additional MD5 hash is required.
	clientMD5 hash.Hash
	serverMD5 hash.Hash

	// In TLS 1.2, a full buffer is sadly required.
	buffer []byte

	version uint16
	prf     func(result, secret, label, seed []byte)
}
   183  
// clientHandshakeState mirrors crypto/tls's unexported clientHandshakeState.
// It is used when the negotiated version is below TLS 1.3. clientHandshake
// allocates it and passes it to the linked handshake, then reads hs.session
// to decide whether to cache a new session. The field layout must therefore
// match the standard library's for the Go release this file targets. suite
// is an opaque *uintptr placeholder, presumably for crypto/tls's cipher-suite
// pointer.
type clientHandshakeState struct {
	c            *Conn
	ctx          context.Context
	serverHello  *serverHelloMsg
	hello        *clientHelloMsg
	suite        *uintptr
	finishedHash finishedHash
	masterSecret []byte
	session      *sessionState // the session being resumed
	ticket       []byte        // a fresh ticket received during this handshake
}
   195  
// handshake is bound via go:linkname to
// crypto/tls.(*clientHandshakeState).handshake; hs is passed as the receiver
// argument.
//
//go:linkname handshake crypto/tls.(*clientHandshakeState).handshake
func handshake(hs *clientHandshakeState) error

// handshake completes a pre-TLS 1.3 client handshake in crypto/tls using the
// state in hs. It may replace hs.session, which clientHandshake then compares
// against the session it offered.
func (hs *clientHandshakeState) handshake() error {
	return handshake(hs)
}
   202  
   203  // writeHandshakeRecord writes a handshake message to the connection and updates
   204  // the record layer state. If transcript is non-nil the marshalled message is
   205  // written to it.
   206  func (c *_trsconn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash, firstFragmentLen uint8) (int, error) {
   207  	c.out.Lock()
   208  	defer c.out.Unlock()
   209  
   210  	data, err := msg.marshal()
   211  	if err != nil {
   212  		return 0, err
   213  	}
   214  	if transcript != nil {
   215  		transcript.Write(data)
   216  	}
   217  
   218  	return c.writeRecordLocked(recordTypeHandshake, firstFragmentLen, data)
   219  }
   220  
// clientHandshake returns a function that performs the client side of a TLS
// handshake on cout. The flow appears to be adapted from crypto/tls's
// (*Conn).clientHandshake. The visible differences are:
//   - the ClientHello is sent through writeHandshakeRecord with
//     firstFragmentLen;
//   - the remaining steps are delegated to crypto/tls internals bound via
//     go:linkname.
//
// The returned closure reinterprets cout as *_trsconn with unsafe.Pointer.
// NOTE(review): this assumes *Conn and *_trsconn share tls.Conn's memory
// layout; confirm where those types are declared.
//
// If a cached session was offered and the handshake fails, the deferred
// function clears its cache entry. After a successful pre-TLS 1.3 handshake
// that produced a different session, the new session is stored in
// Config.ClientSessionCache.
func (cout *Conn) clientHandshake(firstFragmentLen uint8) func(context.Context) error {
	return func(ctx context.Context) (err error) {
		c := (*_trsconn)(unsafe.Pointer(cout))

		// Fall back to crypto/tls's default configuration when none is set.
		if c.config == nil {
			c.config = defaultConfig()
		}

		// This may be a renegotiation handshake, in which case some fields
		// need to be reset.
		c.didResume = false

		hello, ecdheKey, err := c.makeClientHello()
		if err != nil {
			return err
		}
		c.serverName = hello.serverName

		cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello)
		if err != nil {
			return err
		}
		if cacheKey != "" && session != nil {
			defer func() {
				// If we got a handshake failure when resuming a session, throw away
				// the session ticket. See RFC 5077, Section 3.2.
				//
				// RFC 8446 makes no mention of dropping tickets on failure, but it
				// does require servers to abort on invalid binders, so we need to
				// delete tickets to recover from a corrupted PSK.
				//
				// err is the closure's named result, so this sees the error
				// finally returned, even from returns of shadowed err values.
				if err != nil {
					c.config.ClientSessionCache.Put(cacheKey, nil)
				}
			}()
		}

		// The ClientHello is not written to any transcript here (nil).
		// firstFragmentLen is passed on to the record layer.
		if _, err := c.writeHandshakeRecord(hello, nil, firstFragmentLen); err != nil {
			return err
		}

		// serverHelloMsg is not included in the transcript
		msg, err := c.readHandshake(nil)
		if err != nil {
			return err
		}

		// serverHello is still nil on the error path below. It is passed only
		// so the error can name the expected message type.
		// NOTE(review): that name presumably renders as this package's mirror
		// type rather than crypto/tls's serverHelloMsg. isTypeEqual is
		// declared elsewhere and presumably compares msg's dynamic type by name.
		var serverHello *serverHelloMsg
		if !isTypeEqual(msg, "*tls.serverHelloMsg") {
			c.sendAlert(alertUnexpectedMessage)
			return unexpectedMessageError(serverHello, msg)
		}
		// Read the data word of the interface value msg, which is the second
		// pointer-sized word after the type word. Reinterpret it as the local
		// serverHelloMsg mirror. This relies on the gc toolchain's two-word
		// interface representation and on serverHelloMsg matching the start
		// of crypto/tls's struct.
		serverHello = (*serverHelloMsg)(*(*unsafe.Pointer)(
			unsafe.Add(unsafe.Pointer(&msg), unsafe.Sizeof(uintptr(0))),
		))

		// pickTLSVersion presumably sets c.vers, which is read below.
		if err := c.pickTLSVersion(serverHello); err != nil {
			return err
		}

		// If we are negotiating a protocol version that's lower than what we
		// support, check for the server downgrade canaries.
		// See RFC 8446, Section 4.1.3.
		maxVers := maxSupportedVersion(c.config, roleClient)
		tls12Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS12
		tls11Downgrade := string(serverHello.random[24:]) == downgradeCanaryTLS11
		if maxVers == tls.VersionTLS13 && c.vers <= tls.VersionTLS12 && (tls12Downgrade || tls11Downgrade) ||
			maxVers == tls.VersionTLS12 && c.vers <= tls.VersionTLS11 && tls11Downgrade {
			c.sendAlert(alertIllegalParameter)
			return errors.New("tls: downgrade attempt detected, possibly due to a MitM attack or a broken middlebox")
		}

		if c.vers == tls.VersionTLS13 {
			hs := &clientHandshakeStateTLS13{
				c:           cout,
				ctx:         ctx,
				serverHello: serverHello,
				hello:       hello,
				ecdheKey:    ecdheKey,
				session:     session,
				earlySecret: earlySecret,
				binderKey:   binderKey,
			}

			// In TLS 1.3, session tickets are delivered after the handshake.
			return hs.handshake()
		}

		// Pre-TLS 1.3 path.
		hs := &clientHandshakeState{
			c:           cout,
			ctx:         ctx,
			serverHello: serverHello,
			hello:       hello,
			session:     session,
		}

		if err := hs.handshake(); err != nil {
			return err
		}

		// If we had a successful handshake and hs.session is different from
		// the one already cached - cache a new one.
		// The unsafe conversion relies on sessionState mirroring
		// tls.ClientSessionState's layout.
		if cacheKey != "" && hs.session != nil && session != hs.session {
			c.config.ClientSessionCache.Put(cacheKey, (*tls.ClientSessionState)(unsafe.Pointer(hs.session)))
		}

		return nil
	}
}