github.com/epfl-dcsl/gotee@v0.0.0-20200909122901-014b35f5e5e9/src/crypto/tls/key_agreement.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"crypto"
     9  	"crypto/ecdsa"
    10  	"crypto/elliptic"
    11  	"crypto/md5"
    12  	"crypto/rsa"
    13  	"crypto/sha1"
    14  	"crypto/x509"
    15  	"encoding/asn1"
    16  	"errors"
    17  	"io"
    18  	"math/big"
    19  	"teecomm"
    20  
    21  	"golang_org/x/crypto/curve25519"
    22  )
    23  
    24  var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
    25  var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
    26  
    27  // rsaKeyAgreement implements the standard TLS key agreement where the client
    28  // encrypts the pre-master secret to the server's public key.
    29  type rsaKeyAgreement struct{}
    30  
    31  func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
    32  	return nil, nil
    33  }
    34  
    35  func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
    36  	if len(ckx.ciphertext) < 2 {
    37  		return nil, errClientKeyExchange
    38  	}
    39  
    40  	ciphertext := ckx.ciphertext
    41  	if version != VersionSSL30 {
    42  		ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
    43  		if ciphertextLen != len(ckx.ciphertext)-2 {
    44  			return nil, errClientKeyExchange
    45  		}
    46  		ciphertext = ckx.ciphertext[2:]
    47  	}
    48  	priv, ok := cert.PrivateKey.(crypto.Decrypter)
    49  	if !ok {
    50  		return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
    51  	}
    52  	// Perform constant time RSA PKCS#1 v1.5 decryption
    53  	var preMasterSecret []byte
    54  	var err error
    55  	if cert.DecrChan == nil {
    56  		preMasterSecret, err = priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  	} else {
    61  		preMasterSecret = make([]byte, 48)
    62  		done := make(chan bool)
    63  		key, ok := cert.PrivateKey.(*rsa.PrivateKey)
    64  		if !ok {
    65  			panic("Unable to type cast crypto.PrivateKey to rsa.PrivateKey")
    66  		}
    67  		req := teecomm.DecrRequestMsg{
    68  			key, ciphertext,
    69  			&rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48},
    70  			preMasterSecret, done}
    71  		cert.DecrChan <- req
    72  		_ = <-done
    73  	}
    74  
    75  	// We don't check the version number in the premaster secret. For one,
    76  	// by checking it, we would leak information about the validity of the
    77  	// encrypted pre-master secret. Secondly, it provides only a small
    78  	// benefit against a downgrade attack and some implementations send the
    79  	// wrong version anyway. See the discussion at the end of section
    80  	// 7.4.7.1 of RFC 4346.
    81  	return preMasterSecret, nil
    82  }
    83  
    84  func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
    85  	return errors.New("tls: unexpected ServerKeyExchange")
    86  }
    87  
    88  func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
    89  	preMasterSecret := make([]byte, 48)
    90  	preMasterSecret[0] = byte(clientHello.vers >> 8)
    91  	preMasterSecret[1] = byte(clientHello.vers)
    92  	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
    93  	if err != nil {
    94  		return nil, nil, err
    95  	}
    96  
    97  	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
    98  	if err != nil {
    99  		return nil, nil, err
   100  	}
   101  	ckx := new(clientKeyExchangeMsg)
   102  	ckx.ciphertext = make([]byte, len(encrypted)+2)
   103  	ckx.ciphertext[0] = byte(len(encrypted) >> 8)
   104  	ckx.ciphertext[1] = byte(len(encrypted))
   105  	copy(ckx.ciphertext[2:], encrypted)
   106  	return preMasterSecret, ckx, nil
   107  }
   108  
   109  // sha1Hash calculates a SHA1 hash over the given byte slices.
   110  func sha1Hash(slices [][]byte) []byte {
   111  	hsha1 := sha1.New()
   112  	for _, slice := range slices {
   113  		hsha1.Write(slice)
   114  	}
   115  	return hsha1.Sum(nil)
   116  }
   117  
   118  // md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
   119  // concatenation of an MD5 and SHA1 hash.
   120  func md5SHA1Hash(slices [][]byte) []byte {
   121  	md5sha1 := make([]byte, md5.Size+sha1.Size)
   122  	hmd5 := md5.New()
   123  	for _, slice := range slices {
   124  		hmd5.Write(slice)
   125  	}
   126  	copy(md5sha1, hmd5.Sum(nil))
   127  	copy(md5sha1[md5.Size:], sha1Hash(slices))
   128  	return md5sha1
   129  }
   130  
   131  // hashForServerKeyExchange hashes the given slices and returns their digest
   132  // and the identifier of the hash function used. The signatureAlgorithm argument
   133  // is only used for >= TLS 1.2 and identifies the hash function to use.
   134  func hashForServerKeyExchange(sigType uint8, signatureAlgorithm SignatureScheme, version uint16, slices ...[]byte) ([]byte, crypto.Hash, error) {
   135  	if version >= VersionTLS12 {
   136  		if !isSupportedSignatureAlgorithm(signatureAlgorithm, supportedSignatureAlgorithms) {
   137  			return nil, crypto.Hash(0), errors.New("tls: unsupported hash function used by peer")
   138  		}
   139  		hashFunc, err := lookupTLSHash(signatureAlgorithm)
   140  		if err != nil {
   141  			return nil, crypto.Hash(0), err
   142  		}
   143  		h := hashFunc.New()
   144  		for _, slice := range slices {
   145  			h.Write(slice)
   146  		}
   147  		digest := h.Sum(nil)
   148  		return digest, hashFunc, nil
   149  	}
   150  	if sigType == signatureECDSA {
   151  		return sha1Hash(slices), crypto.SHA1, nil
   152  	}
   153  	return md5SHA1Hash(slices), crypto.MD5SHA1, nil
   154  }
   155  
   156  // pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a
   157  // ServerKeyExchange given the signature type being used and the client's
   158  // advertised list of supported signature and hash combinations.
   159  func pickTLS12HashForSignature(sigType uint8, clientList []SignatureScheme) (SignatureScheme, error) {
   160  	if len(clientList) == 0 {
   161  		// If the client didn't specify any signature_algorithms
   162  		// extension then we can assume that it supports SHA1. See
   163  		// http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   164  		switch sigType {
   165  		case signatureRSA:
   166  			return PKCS1WithSHA1, nil
   167  		case signatureECDSA:
   168  			return ECDSAWithSHA1, nil
   169  		default:
   170  			return 0, errors.New("tls: unknown signature algorithm")
   171  		}
   172  	}
   173  
   174  	for _, sigAlg := range clientList {
   175  		if signatureFromSignatureScheme(sigAlg) != sigType {
   176  			continue
   177  		}
   178  		if isSupportedSignatureAlgorithm(sigAlg, supportedSignatureAlgorithms) {
   179  			return sigAlg, nil
   180  		}
   181  	}
   182  
   183  	return 0, errors.New("tls: client doesn't support any common hash functions")
   184  }
   185  
   186  func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
   187  	switch id {
   188  	case CurveP256:
   189  		return elliptic.P256(), true
   190  	case CurveP384:
   191  		return elliptic.P384(), true
   192  	case CurveP521:
   193  		return elliptic.P521(), true
   194  	default:
   195  		return nil, false
   196  	}
   197  
   198  }
   199  
   200  // ecdheRSAKeyAgreement implements a TLS key agreement where the server
   201  // generates an ephemeral EC public/private key pair and signs it. The
   202  // pre-master secret is then calculated using ECDH. The signature may
   203  // either be ECDSA or RSA.
   204  type ecdheKeyAgreement struct {
   205  	version    uint16
   206  	sigType    uint8
   207  	privateKey []byte
   208  	curveid    CurveID
   209  
   210  	// publicKey is used to store the peer's public value when X25519 is
   211  	// being used.
   212  	publicKey []byte
   213  	// x and y are used to store the peer's public value when one of the
   214  	// NIST curves is being used.
   215  	x, y *big.Int
   216  }
   217  
   218  func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
   219  	preferredCurves := config.curvePreferences()
   220  
   221  NextCandidate:
   222  	for _, candidate := range preferredCurves {
   223  		for _, c := range clientHello.supportedCurves {
   224  			if candidate == c {
   225  				ka.curveid = c
   226  				break NextCandidate
   227  			}
   228  		}
   229  	}
   230  
   231  	if ka.curveid == 0 {
   232  		return nil, errors.New("tls: no supported elliptic curves offered")
   233  	}
   234  
   235  	var ecdhePublic []byte
   236  
   237  	if ka.curveid == X25519 {
   238  		var scalar, public [32]byte
   239  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   240  			return nil, err
   241  		}
   242  
   243  		curve25519.ScalarBaseMult(&public, &scalar)
   244  		ka.privateKey = scalar[:]
   245  		ecdhePublic = public[:]
   246  	} else {
   247  		curve, ok := curveForCurveID(ka.curveid)
   248  		if !ok {
   249  			return nil, errors.New("tls: preferredCurves includes unsupported curve")
   250  		}
   251  
   252  		var x, y *big.Int
   253  		var err error
   254  		ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand())
   255  		if err != nil {
   256  			return nil, err
   257  		}
   258  		ecdhePublic = elliptic.Marshal(curve, x, y)
   259  	}
   260  
   261  	// http://tools.ietf.org/html/rfc4492#section-5.4
   262  	serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic))
   263  	serverECDHParams[0] = 3 // named curve
   264  	serverECDHParams[1] = byte(ka.curveid >> 8)
   265  	serverECDHParams[2] = byte(ka.curveid)
   266  	serverECDHParams[3] = byte(len(ecdhePublic))
   267  	copy(serverECDHParams[4:], ecdhePublic)
   268  
   269  	var signatureAlgorithm SignatureScheme
   270  
   271  	if ka.version >= VersionTLS12 {
   272  		var err error
   273  		signatureAlgorithm, err = pickTLS12HashForSignature(ka.sigType, clientHello.supportedSignatureAlgorithms)
   274  		if err != nil {
   275  			return nil, err
   276  		}
   277  	}
   278  
   279  	digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, signatureAlgorithm, ka.version, clientHello.random, hello.random, serverECDHParams)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  
   284  	priv, ok := cert.PrivateKey.(crypto.Signer)
   285  	if !ok {
   286  		return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
   287  	}
   288  	var sig []byte
   289  	switch ka.sigType {
   290  	case signatureECDSA:
   291  		_, ok := priv.Public().(*ecdsa.PublicKey)
   292  		if !ok {
   293  			return nil, errors.New("tls: ECDHE ECDSA requires an ECDSA server key")
   294  		}
   295  	case signatureRSA:
   296  		_, ok := priv.Public().(*rsa.PublicKey)
   297  		if !ok {
   298  			return nil, errors.New("tls: ECDHE RSA requires a RSA server key")
   299  		}
   300  	default:
   301  		return nil, errors.New("tls: unknown ECDHE signature algorithm")
   302  	}
   303  	sig, err = priv.Sign(config.rand(), digest, hashFunc)
   304  	if err != nil {
   305  		return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
   306  	}
   307  
   308  	skx := new(serverKeyExchangeMsg)
   309  	sigAndHashLen := 0
   310  	if ka.version >= VersionTLS12 {
   311  		sigAndHashLen = 2
   312  	}
   313  	skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
   314  	copy(skx.key, serverECDHParams)
   315  	k := skx.key[len(serverECDHParams):]
   316  	if ka.version >= VersionTLS12 {
   317  		k[0] = byte(signatureAlgorithm >> 8)
   318  		k[1] = byte(signatureAlgorithm)
   319  		k = k[2:]
   320  	}
   321  	k[0] = byte(len(sig) >> 8)
   322  	k[1] = byte(len(sig))
   323  	copy(k[2:], sig)
   324  
   325  	return skx, nil
   326  }
   327  
   328  func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
   329  	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
   330  		return nil, errClientKeyExchange
   331  	}
   332  
   333  	if ka.curveid == X25519 {
   334  		if len(ckx.ciphertext) != 1+32 {
   335  			return nil, errClientKeyExchange
   336  		}
   337  
   338  		var theirPublic, sharedKey, scalar [32]byte
   339  		copy(theirPublic[:], ckx.ciphertext[1:])
   340  		copy(scalar[:], ka.privateKey)
   341  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   342  		return sharedKey[:], nil
   343  	}
   344  
   345  	curve, ok := curveForCurveID(ka.curveid)
   346  	if !ok {
   347  		panic("internal error")
   348  	}
   349  	x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:]) // Unmarshal also checks whether the given point is on the curve
   350  	if x == nil {
   351  		return nil, errClientKeyExchange
   352  	}
   353  	x, _ = curve.ScalarMult(x, y, ka.privateKey)
   354  	preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3)
   355  	xBytes := x.Bytes()
   356  	copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   357  
   358  	return preMasterSecret, nil
   359  }
   360  
   361  func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
   362  	if len(skx.key) < 4 {
   363  		return errServerKeyExchange
   364  	}
   365  	if skx.key[0] != 3 { // named curve
   366  		return errors.New("tls: server selected unsupported curve")
   367  	}
   368  	ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
   369  
   370  	publicLen := int(skx.key[3])
   371  	if publicLen+4 > len(skx.key) {
   372  		return errServerKeyExchange
   373  	}
   374  	serverECDHParams := skx.key[:4+publicLen]
   375  	publicKey := serverECDHParams[4:]
   376  
   377  	sig := skx.key[4+publicLen:]
   378  	if len(sig) < 2 {
   379  		return errServerKeyExchange
   380  	}
   381  
   382  	if ka.curveid == X25519 {
   383  		if len(publicKey) != 32 {
   384  			return errors.New("tls: bad X25519 public value")
   385  		}
   386  		ka.publicKey = publicKey
   387  	} else {
   388  		curve, ok := curveForCurveID(ka.curveid)
   389  		if !ok {
   390  			return errors.New("tls: server selected unsupported curve")
   391  		}
   392  		ka.x, ka.y = elliptic.Unmarshal(curve, publicKey) // Unmarshal also checks whether the given point is on the curve
   393  		if ka.x == nil {
   394  			return errServerKeyExchange
   395  		}
   396  	}
   397  
   398  	var signatureAlgorithm SignatureScheme
   399  	if ka.version >= VersionTLS12 {
   400  		// handle SignatureAndHashAlgorithm
   401  		signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
   402  		if signatureFromSignatureScheme(signatureAlgorithm) != ka.sigType {
   403  			return errServerKeyExchange
   404  		}
   405  		sig = sig[2:]
   406  		if len(sig) < 2 {
   407  			return errServerKeyExchange
   408  		}
   409  	}
   410  	sigLen := int(sig[0])<<8 | int(sig[1])
   411  	if sigLen+2 != len(sig) {
   412  		return errServerKeyExchange
   413  	}
   414  	sig = sig[2:]
   415  
   416  	digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, signatureAlgorithm, ka.version, clientHello.random, serverHello.random, serverECDHParams)
   417  	if err != nil {
   418  		return err
   419  	}
   420  	switch ka.sigType {
   421  	case signatureECDSA:
   422  		pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
   423  		if !ok {
   424  			return errors.New("tls: ECDHE ECDSA requires a ECDSA server public key")
   425  		}
   426  		ecdsaSig := new(ecdsaSignature)
   427  		if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
   428  			return err
   429  		}
   430  		if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
   431  			return errors.New("tls: ECDSA signature contained zero or negative values")
   432  		}
   433  		if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) {
   434  			return errors.New("tls: ECDSA verification failure")
   435  		}
   436  	case signatureRSA:
   437  		pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
   438  		if !ok {
   439  			return errors.New("tls: ECDHE RSA requires a RSA server public key")
   440  		}
   441  		if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil {
   442  			return err
   443  		}
   444  	default:
   445  		return errors.New("tls: unknown ECDHE signature algorithm")
   446  	}
   447  
   448  	return nil
   449  }
   450  
   451  func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
   452  	if ka.curveid == 0 {
   453  		return nil, nil, errors.New("tls: missing ServerKeyExchange message")
   454  	}
   455  
   456  	var serialized, preMasterSecret []byte
   457  
   458  	if ka.curveid == X25519 {
   459  		var ourPublic, theirPublic, sharedKey, scalar [32]byte
   460  
   461  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   462  			return nil, nil, err
   463  		}
   464  
   465  		copy(theirPublic[:], ka.publicKey)
   466  		curve25519.ScalarBaseMult(&ourPublic, &scalar)
   467  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   468  		serialized = ourPublic[:]
   469  		preMasterSecret = sharedKey[:]
   470  	} else {
   471  		curve, ok := curveForCurveID(ka.curveid)
   472  		if !ok {
   473  			panic("internal error")
   474  		}
   475  		priv, mx, my, err := elliptic.GenerateKey(curve, config.rand())
   476  		if err != nil {
   477  			return nil, nil, err
   478  		}
   479  		x, _ := curve.ScalarMult(ka.x, ka.y, priv)
   480  		preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3)
   481  		xBytes := x.Bytes()
   482  		copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   483  
   484  		serialized = elliptic.Marshal(curve, mx, my)
   485  	}
   486  
   487  	ckx := new(clientKeyExchangeMsg)
   488  	ckx.ciphertext = make([]byte, 1+len(serialized))
   489  	ckx.ciphertext[0] = byte(len(serialized))
   490  	copy(ckx.ciphertext[1:], serialized)
   491  
   492  	return preMasterSecret, ckx, nil
   493  }