github.com/hellobchain/newcryptosm@v0.0.0-20221019060107-edb949a317e9/tls/key_agreement.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package tls
    17  
    18  import (
    19  	"crypto"
    20  	"crypto/elliptic"
    21  	"crypto/md5"
    22  	"crypto/rsa"
    23  	"crypto/sha1"
    24  	"encoding/asn1"
    25  	"errors"
    26  	"github.com/hellobchain/newcryptosm"
    27  	"github.com/hellobchain/newcryptosm/ecdsa"
    28  	"github.com/hellobchain/newcryptosm/sm2"
    29  	"github.com/hellobchain/newcryptosm/sm3"
    30  	"github.com/hellobchain/newcryptosm/x509"
    31  	"io"
    32  	"math/big"
    33  
    34  	"golang.org/x/crypto/curve25519"
    35  )
    36  
    37  var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
    38  var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
    39  
    40  // rsaKeyAgreement implements the standard TLS key agreement where the client
    41  // encrypts the pre-master secret to the server's public key.
    42  type rsaKeyAgreement struct{}
    43  
    44  func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
    45  	return nil, nil
    46  }
    47  
    48  func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
    49  	if len(ckx.ciphertext) < 2 {
    50  		return nil, errClientKeyExchange
    51  	}
    52  
    53  	ciphertext := ckx.ciphertext
    54  	if version != VersionSSL30 {
    55  		ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
    56  		if ciphertextLen != len(ckx.ciphertext)-2 {
    57  			return nil, errClientKeyExchange
    58  		}
    59  		ciphertext = ckx.ciphertext[2:]
    60  	}
    61  	priv, ok := cert.PrivateKey.(crypto.Decrypter)
    62  	if !ok {
    63  		return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
    64  	}
    65  	// Perform constant time RSA PKCS#1 v1.5 decryption
    66  	preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	// We don't check the version number in the premaster secret. For one,
    71  	// by checking it, we would leak information about the validity of the
    72  	// encrypted pre-master secret. Secondly, it provides only a small
    73  	// benefit against a downgrade attack and some implementations send the
    74  	// wrong version anyway. See the discussion at the end of section
    75  	// 7.4.7.1 of RFC 4346.
    76  	return preMasterSecret, nil
    77  }
    78  
    79  func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
    80  	return errors.New("tls: unexpected ServerKeyExchange")
    81  }
    82  
    83  func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
    84  	preMasterSecret := make([]byte, 48)
    85  	preMasterSecret[0] = byte(clientHello.vers >> 8)
    86  	preMasterSecret[1] = byte(clientHello.vers)
    87  	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
    88  	if err != nil {
    89  		return nil, nil, err
    90  	}
    91  
    92  	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
    93  	if err != nil {
    94  		return nil, nil, err
    95  	}
    96  	ckx := new(clientKeyExchangeMsg)
    97  	ckx.ciphertext = make([]byte, len(encrypted)+2)
    98  	ckx.ciphertext[0] = byte(len(encrypted) >> 8)
    99  	ckx.ciphertext[1] = byte(len(encrypted))
   100  	copy(ckx.ciphertext[2:], encrypted)
   101  	return preMasterSecret, ckx, nil
   102  }
   103  
   104  // sha1Hash calculates a SHA1 hash over the given byte slices.
   105  func sha1Hash(slices [][]byte) []byte {
   106  	hsha1 := sha1.New()
   107  	for _, slice := range slices {
   108  		hsha1.Write(slice)
   109  	}
   110  	return hsha1.Sum(nil)
   111  }
   112  
   113  // sm3Hash calculates a sm3 hash over the given byte slices.
   114  func sm3Hash(slices [][]byte) []byte {
   115  	hsm3 := sm3.New()
   116  	for _, slice := range slices {
   117  		hsm3.Write(slice)
   118  	}
   119  	return hsm3.Sum(nil)
   120  }
   121  
   122  // md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
   123  // concatenation of an MD5 and SHA1 hash.
   124  func md5SHA1Hash(slices [][]byte) []byte {
   125  	md5sha1 := make([]byte, md5.Size+sha1.Size)
   126  	hmd5 := md5.New()
   127  	for _, slice := range slices {
   128  		hmd5.Write(slice)
   129  	}
   130  	copy(md5sha1, hmd5.Sum(nil))
   131  	copy(md5sha1[md5.Size:], sha1Hash(slices))
   132  	return md5sha1
   133  }
   134  
   135  // hashForServerKeyExchange hashes the given slices and returns their digest
   136  // and the identifier of the hash function used. The sigAndHash argument is
   137  // only used for >= TLS 1.2 and precisely identifies the hash function to use.
   138  func hashForServerKeyExchange(sigAndHash signatureAndHash, version uint16, slices ...[]byte) ([]byte, newcryptosm.Hash, error) {
   139  	if version >= VersionTLS12 {
   140  		if !isSupportedSignatureAndHash(sigAndHash, supportedSignatureAlgorithms) {
   141  			return nil, newcryptosm.Hash(0), errors.New("tls: unsupported hash function used by peer")
   142  		}
   143  		hashFunc, err := lookupTLSHash(sigAndHash.hash)
   144  		if err != nil {
   145  			return nil, newcryptosm.Hash(0), err
   146  		}
   147  		h := hashFunc.New()
   148  		for _, slice := range slices {
   149  			h.Write(slice)
   150  		}
   151  		digest := h.Sum(nil)
   152  		return digest, hashFunc, nil
   153  	}
   154  	if sigAndHash.signature == signatureECDSA {
   155  		return sha1Hash(slices), newcryptosm.SHA1, nil
   156  	}
   157  	if sigAndHash.signature == signatureSM2 {
   158  		return sm3Hash(slices), newcryptosm.SM3, nil
   159  	}
   160  	return md5SHA1Hash(slices), newcryptosm.MD5SHA1, nil
   161  }
   162  
   163  // pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a
   164  // ServerKeyExchange given the signature type being used and the client's
   165  // advertised list of supported signature and hash combinations.
   166  func pickTLS12HashForSignature(sigType uint8, clientList []signatureAndHash) (uint8, error) {
   167  	if len(clientList) == 0 {
   168  		// If the client didn't specify any signature_algorithms
   169  		// extension then we can assume that it supports SHA1. See
   170  		// http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
   171  		return hashSHA1, nil
   172  	}
   173  
   174  	for _, sigAndHash := range clientList {
   175  		if sigAndHash.signature != sigType {
   176  			continue
   177  		}
   178  		if isSupportedSignatureAndHash(sigAndHash, supportedSignatureAlgorithms) {
   179  			return sigAndHash.hash, nil
   180  		}
   181  	}
   182  
   183  	return 0, errors.New("tls: client doesn't support any common hash functions")
   184  }
   185  
   186  func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
   187  	switch id {
   188  	case CurveP256:
   189  		return elliptic.P256(), true
   190  	case CurveP384:
   191  		return elliptic.P384(), true
   192  	case CurveP521:
   193  		return elliptic.P521(), true
   194  	case CureP256SM2:
   195  		return sm2.SM2(), true
   196  	default:
   197  		return nil, false
   198  	}
   199  
   200  }
   201  
   202  // ecdheRSAKeyAgreement implements a TLS key agreement where the server
   203  // generates a ephemeral EC public/private key pair and signs it. The
   204  // pre-master secret is then calculated using ECDH. The signature may
   205  // either be ECDSA or RSA.
   206  type ecdheKeyAgreement struct {
   207  	version    uint16
   208  	sigType    uint8
   209  	privateKey []byte
   210  	curveid    CurveID
   211  
   212  	// publicKey is used to store the peer's public value when X25519 is
   213  	// being used.
   214  	publicKey []byte
   215  	// x and y are used to store the peer's public value when one of the
   216  	// NIST curves is being used.
   217  	x, y *big.Int
   218  }
   219  
   220  func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
   221  	preferredCurves := config.curvePreferences()
   222  
   223  NextCandidate:
   224  	for _, candidate := range preferredCurves {
   225  		for _, c := range clientHello.supportedCurves {
   226  			if candidate == c {
   227  				ka.curveid = c
   228  				break NextCandidate
   229  			}
   230  		}
   231  	}
   232  
   233  	if ka.curveid == 0 {
   234  		return nil, errors.New("tls: no supported elliptic curves offered")
   235  	}
   236  
   237  	var sm2Public []byte
   238  
   239  	if ka.curveid == X25519 {
   240  		var scalar, public [32]byte
   241  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   242  			return nil, err
   243  		}
   244  
   245  		curve25519.ScalarBaseMult(&public, &scalar)
   246  		ka.privateKey = scalar[:]
   247  		sm2Public = public[:]
   248  	} else {
   249  		curve, ok := curveForCurveID(ka.curveid)
   250  		if !ok {
   251  			return nil, errors.New("tls: preferredCurves includes unsupported curve")
   252  		}
   253  
   254  		var x, y *big.Int
   255  		var err error
   256  		ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand())
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  		sm2Public = elliptic.Marshal(curve, x, y) //kG
   261  	}
   262  
   263  	// http://tools.ietf.org/html/rfc4492#section-5.4
   264  	serverECDHParams := make([]byte, 1+2+1+len(sm2Public))
   265  	serverECDHParams[0] = 3 // named curve
   266  	serverECDHParams[1] = byte(ka.curveid >> 8)
   267  	serverECDHParams[2] = byte(ka.curveid)
   268  	serverECDHParams[3] = byte(len(sm2Public))
   269  	copy(serverECDHParams[4:], sm2Public)
   270  
   271  	sigAndHash := signatureAndHash{signature: ka.sigType}
   272  
   273  	if ka.version >= VersionTLS12 {
   274  		var err error
   275  		if sigAndHash.hash, err = pickTLS12HashForSignature(ka.sigType, clientHello.signatureAndHashes); err != nil {
   276  			return nil, err
   277  		}
   278  	}
   279  
   280  	digest, hashFunc, err := hashForServerKeyExchange(sigAndHash, ka.version, clientHello.random, hello.random, serverECDHParams)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  
   285  	priv, ok := cert.PrivateKey.(crypto.Signer)
   286  	if !ok {
   287  		return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
   288  	}
   289  	var sig []byte
   290  	switch ka.sigType {
   291  	case signatureSM2:
   292  		_, ok := priv.Public().(*ecdsa.PublicKey)
   293  		if !ok {
   294  			return nil, errors.New("tls: ECDHE SM2 requires an SM2 server key")
   295  		}
   296  	case signatureECDSA:
   297  		_, ok := priv.Public().(*ecdsa.PublicKey)
   298  		if !ok {
   299  			return nil, errors.New("tls: ECDHE ECDSA requires an ECDSA server key")
   300  		}
   301  	case signatureRSA:
   302  		_, ok := priv.Public().(*rsa.PublicKey)
   303  		if !ok {
   304  			return nil, errors.New("tls: ECDHE RSA requires a RSA server key")
   305  		}
   306  	default:
   307  		return nil, errors.New("tls: unknown ECDHE signature algorithm")
   308  	}
   309  	sig, err = priv.Sign(config.rand(), digest, hashFunc)
   310  	if err != nil {
   311  		return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
   312  	}
   313  
   314  	skx := new(serverKeyExchangeMsg)
   315  	sigAndHashLen := 0
   316  	if ka.version >= VersionTLS12 {
   317  		sigAndHashLen = 2
   318  	}
   319  	skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
   320  	copy(skx.key, serverECDHParams)
   321  	k := skx.key[len(serverECDHParams):]
   322  	if ka.version >= VersionTLS12 {
   323  		k[0] = sigAndHash.hash
   324  		k[1] = sigAndHash.signature
   325  		k = k[2:]
   326  	}
   327  	k[0] = byte(len(sig) >> 8)
   328  	k[1] = byte(len(sig))
   329  	copy(k[2:], sig)
   330  
   331  	return skx, nil
   332  }
   333  
   334  func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
   335  	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
   336  		return nil, errClientKeyExchange
   337  	}
   338  
   339  	if ka.curveid == X25519 {
   340  		if len(ckx.ciphertext) != 1+32 {
   341  			return nil, errClientKeyExchange
   342  		}
   343  
   344  		var theirPublic, sharedKey, scalar [32]byte
   345  		copy(theirPublic[:], ckx.ciphertext[1:])
   346  		copy(scalar[:], ka.privateKey)
   347  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   348  		return sharedKey[:], nil
   349  	}
   350  
   351  	curve, ok := curveForCurveID(ka.curveid)
   352  	if !ok {
   353  		panic("internal error")
   354  	}
   355  	x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:])
   356  	if x == nil {
   357  		return nil, errClientKeyExchange
   358  	}
   359  	if !curve.IsOnCurve(x, y) {
   360  		return nil, errClientKeyExchange
   361  	}
   362  	x, _ = curve.ScalarMult(x, y, ka.privateKey)
   363  	preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3)
   364  	xBytes := x.Bytes()
   365  	copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   366  
   367  	return preMasterSecret, nil
   368  }
   369  
   370  func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
   371  	if len(skx.key) < 4 {
   372  		return errServerKeyExchange
   373  	}
   374  	if skx.key[0] != 3 { // named curve
   375  		return errors.New("tls: server selected unsupported curve")
   376  	}
   377  	ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
   378  
   379  	publicLen := int(skx.key[3])
   380  	if publicLen+4 > len(skx.key) {
   381  		return errServerKeyExchange
   382  	}
   383  	serverECDHParams := skx.key[:4+publicLen]
   384  	publicKey := serverECDHParams[4:]
   385  
   386  	sig := skx.key[4+publicLen:]
   387  	if len(sig) < 2 {
   388  		return errServerKeyExchange
   389  	}
   390  
   391  	if ka.curveid == X25519 {
   392  		if len(publicKey) != 32 {
   393  			return errors.New("tls: bad X25519 public value")
   394  		}
   395  		ka.publicKey = publicKey
   396  	} else {
   397  		curve, ok := curveForCurveID(ka.curveid)
   398  		if !ok {
   399  			return errors.New("tls: server selected unsupported curve")
   400  		}
   401  
   402  		ka.x, ka.y = elliptic.Unmarshal(curve, publicKey)
   403  		if ka.x == nil {
   404  			return errServerKeyExchange
   405  		}
   406  		if !curve.IsOnCurve(ka.x, ka.y) {
   407  			return errServerKeyExchange
   408  		}
   409  	}
   410  
   411  	sigAndHash := signatureAndHash{signature: ka.sigType}
   412  	if ka.version >= VersionTLS12 {
   413  		// handle SignatureAndHashAlgorithm
   414  		sigAndHash = signatureAndHash{hash: sig[0], signature: sig[1]}
   415  		if sigAndHash.signature != ka.sigType {
   416  			return errServerKeyExchange
   417  		}
   418  		sig = sig[2:]
   419  		if len(sig) < 2 {
   420  			return errServerKeyExchange
   421  		}
   422  	}
   423  	sigLen := int(sig[0])<<8 | int(sig[1])
   424  	if sigLen+2 != len(sig) {
   425  		return errServerKeyExchange
   426  	}
   427  	sig = sig[2:]
   428  
   429  	digest, hashFunc, err := hashForServerKeyExchange(sigAndHash, ka.version, clientHello.random, serverHello.random, serverECDHParams)
   430  	if err != nil {
   431  		return err
   432  	}
   433  	switch ka.sigType {
   434  	case signatureSM2:
   435  		pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
   436  		if !ok {
   437  			return errors.New("tls: ECDHE SM2 requires a SM2 server public key")
   438  		}
   439  		sm2Sig := new(sm2Signature)
   440  		if _, err := asn1.Unmarshal(sig, sm2Sig); err != nil {
   441  			return err
   442  		}
   443  		if sm2Sig.R.Sign() <= 0 || sm2Sig.S.Sign() <= 0 {
   444  			return errors.New("tls: SM2 signature contained zero or negative values")
   445  		}
   446  		if !ecdsa.Verify(&ecdsa.PublicKey{
   447  			X:     pubKey.X,
   448  			Y:     pubKey.Y,
   449  			Curve: pubKey.Curve,
   450  		}, digest, sm2Sig.R, sm2Sig.S) {
   451  			return errors.New("tls: SM2 verification failure")
   452  		}
   453  	case signatureECDSA:
   454  		pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
   455  		if !ok {
   456  			return errors.New("tls: ECDHE ECDSA requires a ECDSA server public key")
   457  		}
   458  		ecdsaSig := new(ecdsaSignature)
   459  		if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
   460  			return err
   461  		}
   462  		if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
   463  			return errors.New("tls: ECDSA signature contained zero or negative values")
   464  		}
   465  		if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) {
   466  			return errors.New("tls: ECDSA verification failure")
   467  		}
   468  	case signatureRSA:
   469  		pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
   470  		if !ok {
   471  			return errors.New("tls: ECDHE RSA requires a RSA server public key")
   472  		}
   473  		if err := rsa.VerifyPKCS1v15(pubKey, crypto.Hash(hashFunc), digest, sig); err != nil {
   474  			return err
   475  		}
   476  	default:
   477  		return errors.New("tls: unknown ECDHE signature algorithm")
   478  	}
   479  
   480  	return nil
   481  }
   482  
   483  func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
   484  	if ka.curveid == 0 {
   485  		return nil, nil, errors.New("tls: missing ServerKeyExchange message")
   486  	}
   487  
   488  	var serialized, preMasterSecret []byte
   489  
   490  	if ka.curveid == X25519 {
   491  		var ourPublic, theirPublic, sharedKey, scalar [32]byte
   492  
   493  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   494  			return nil, nil, err
   495  		}
   496  
   497  		copy(theirPublic[:], ka.publicKey)
   498  		curve25519.ScalarBaseMult(&ourPublic, &scalar)
   499  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   500  		serialized = ourPublic[:]
   501  		preMasterSecret = sharedKey[:]
   502  	} else {
   503  		curve, ok := curveForCurveID(ka.curveid)
   504  		if !ok {
   505  			panic("internal error")
   506  		}
   507  		priv, mx, my, err := elliptic.GenerateKey(curve, config.rand())
   508  		if err != nil {
   509  			return nil, nil, err
   510  		}
   511  		x, _ := curve.ScalarMult(ka.x, ka.y, priv)
   512  		preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3)
   513  		xBytes := x.Bytes()
   514  		copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   515  
   516  		serialized = elliptic.Marshal(curve, mx, my)
   517  	}
   518  
   519  	ckx := new(clientKeyExchangeMsg)
   520  	ckx.ciphertext = make([]byte, 1+len(serialized))
   521  	ckx.ciphertext[0] = byte(len(serialized))
   522  	copy(ckx.ciphertext[1:], serialized)
   523  
   524  	return preMasterSecret, ckx, nil
   525  }
   526  
   527  // sm2KeyAgreement implements a TLS key agreement where the server
   528  // generates a ephemeral SM2 public/private key pair and signs it. The
   529  // pre-master secret is then calculated using SM2. The signature may
   530  // be SM2.
   531  type sm2KeyAgreement struct {
   532  	version    uint16
   533  	sigType    uint8
   534  	privateKey []byte
   535  	curveid    CurveID
   536  
   537  	// publicKey is used to store the peer's public value when X25519 is
   538  	// being used.
   539  	publicKey []byte
   540  	// x and y are used to store the peer's public value when one of the
   541  	// NIST curves is being used.
   542  	x, y *big.Int
   543  }
   544  
   545  func (ka *sm2KeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
   546  	preferredCurves := config.curvePreferences()
   547  
   548  NextCandidate:
   549  	for _, candidate := range preferredCurves {
   550  		for _, c := range clientHello.supportedCurves {
   551  			if candidate == c {
   552  				ka.curveid = c
   553  				break NextCandidate
   554  			}
   555  		}
   556  	}
   557  
   558  	if ka.curveid == 0 {
   559  		return nil, errors.New("tls: no supported elliptic curves offered")
   560  	}
   561  
   562  	var sm2Public []byte
   563  
   564  	if ka.curveid == X25519 {
   565  		var scalar, public [32]byte
   566  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   567  			return nil, err
   568  		}
   569  
   570  		curve25519.ScalarBaseMult(&public, &scalar)
   571  		ka.privateKey = scalar[:]
   572  		sm2Public = public[:]
   573  	} else {
   574  		curve, ok := curveForCurveID(ka.curveid)
   575  		if !ok {
   576  			return nil, errors.New("tls: preferredCurves includes unsupported curve")
   577  		}
   578  
   579  		var x, y *big.Int
   580  		var err error
   581  		ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand())
   582  		if err != nil {
   583  			return nil, err
   584  		}
   585  		sm2Public = elliptic.Marshal(curve, x, y) //kG
   586  	}
   587  
   588  	// http://tools.ietf.org/html/rfc4492#section-5.4
   589  	serverECDHParams := make([]byte, 1+2+1+len(sm2Public))
   590  	serverECDHParams[0] = 3 // named curve
   591  	serverECDHParams[1] = byte(ka.curveid >> 8)
   592  	serverECDHParams[2] = byte(ka.curveid)
   593  	serverECDHParams[3] = byte(len(sm2Public))
   594  	copy(serverECDHParams[4:], sm2Public)
   595  
   596  	sigAndHash := signatureAndHash{signature: ka.sigType}
   597  
   598  	if ka.version >= VersionTLS12 {
   599  		var err error
   600  		if sigAndHash.hash, err = pickTLS12HashForSignature(ka.sigType, clientHello.signatureAndHashes); err != nil {
   601  			return nil, err
   602  		}
   603  	}
   604  
   605  	digest, hashFunc, err := hashForServerKeyExchange(sigAndHash, ka.version, clientHello.random, hello.random, serverECDHParams)
   606  	if err != nil {
   607  		return nil, err
   608  	}
   609  
   610  	priv, ok := cert.PrivateKey.(crypto.Signer)
   611  	if !ok {
   612  		return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
   613  	}
   614  	var sig []byte
   615  	switch ka.sigType {
   616  	case signatureSM2:
   617  		_, ok := priv.Public().(*ecdsa.PublicKey)
   618  		if !ok {
   619  			return nil, errors.New("tls: ECDHE SM2 requires an SM2 server key")
   620  		}
   621  	case signatureECDSA:
   622  		_, ok := priv.Public().(*ecdsa.PublicKey)
   623  		if !ok {
   624  			return nil, errors.New("tls: ECDHE ECDSA requires an ECDSA server key")
   625  		}
   626  	case signatureRSA:
   627  		_, ok := priv.Public().(*rsa.PublicKey)
   628  		if !ok {
   629  			return nil, errors.New("tls: ECDHE RSA requires a RSA server key")
   630  		}
   631  	default:
   632  		return nil, errors.New("tls: unknown ECDHE signature algorithm")
   633  	}
   634  	sig, err = priv.Sign(config.rand(), digest, hashFunc)
   635  	if err != nil {
   636  		return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
   637  	}
   638  
   639  	skx := new(serverKeyExchangeMsg)
   640  	sigAndHashLen := 0
   641  	if ka.version >= VersionTLS12 {
   642  		sigAndHashLen = 2
   643  	}
   644  	skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
   645  	copy(skx.key, serverECDHParams)
   646  	k := skx.key[len(serverECDHParams):]
   647  	if ka.version >= VersionTLS12 {
   648  		k[0] = sigAndHash.hash
   649  		k[1] = sigAndHash.signature
   650  		k = k[2:]
   651  	}
   652  	k[0] = byte(len(sig) >> 8)
   653  	k[1] = byte(len(sig))
   654  	copy(k[2:], sig)
   655  
   656  	return skx, nil
   657  }
   658  
   659  func (ka *sm2KeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
   660  	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
   661  		return nil, errClientKeyExchange
   662  	}
   663  
   664  	if ka.curveid == X25519 {
   665  		if len(ckx.ciphertext) != 1+32 {
   666  			return nil, errClientKeyExchange
   667  		}
   668  
   669  		var theirPublic, sharedKey, scalar [32]byte
   670  		copy(theirPublic[:], ckx.ciphertext[1:])
   671  		copy(scalar[:], ka.privateKey)
   672  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   673  		return sharedKey[:], nil
   674  	}
   675  
   676  	curve, ok := curveForCurveID(ka.curveid)
   677  	if !ok {
   678  		panic("internal error")
   679  	}
   680  	x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:])
   681  	if x == nil {
   682  		return nil, errClientKeyExchange
   683  	}
   684  	if !curve.IsOnCurve(x, y) {
   685  		return nil, errClientKeyExchange
   686  	}
   687  	x, _ = curve.ScalarMult(x, y, ka.privateKey)
   688  	preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3)
   689  	xBytes := x.Bytes()
   690  	copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   691  
   692  	return preMasterSecret, nil
   693  }
   694  
   695  func (ka *sm2KeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
   696  	if len(skx.key) < 4 {
   697  		return errServerKeyExchange
   698  	}
   699  	if skx.key[0] != 3 { // named curve
   700  		return errors.New("tls: server selected unsupported curve")
   701  	}
   702  	ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
   703  
   704  	publicLen := int(skx.key[3])
   705  	if publicLen+4 > len(skx.key) {
   706  		return errServerKeyExchange
   707  	}
   708  	serverECDHParams := skx.key[:4+publicLen]
   709  	publicKey := serverECDHParams[4:]
   710  
   711  	sig := skx.key[4+publicLen:]
   712  	if len(sig) < 2 {
   713  		return errServerKeyExchange
   714  	}
   715  
   716  	if ka.curveid == X25519 {
   717  		if len(publicKey) != 32 {
   718  			return errors.New("tls: bad X25519 public value")
   719  		}
   720  		ka.publicKey = publicKey
   721  	} else {
   722  		curve, ok := curveForCurveID(ka.curveid)
   723  		if !ok {
   724  			return errors.New("tls: server selected unsupported curve")
   725  		}
   726  
   727  		ka.x, ka.y = elliptic.Unmarshal(curve, publicKey)
   728  		if ka.x == nil {
   729  			return errServerKeyExchange
   730  		}
   731  		if !curve.IsOnCurve(ka.x, ka.y) {
   732  			return errServerKeyExchange
   733  		}
   734  	}
   735  
   736  	sigAndHash := signatureAndHash{signature: ka.sigType}
   737  	if ka.version >= VersionTLS12 {
   738  		// handle SignatureAndHashAlgorithm
   739  		sigAndHash = signatureAndHash{hash: sig[0], signature: sig[1]}
   740  		if sigAndHash.signature != ka.sigType {
   741  			return errServerKeyExchange
   742  		}
   743  		sig = sig[2:]
   744  		if len(sig) < 2 {
   745  			return errServerKeyExchange
   746  		}
   747  	}
   748  	sigLen := int(sig[0])<<8 | int(sig[1])
   749  	if sigLen+2 != len(sig) {
   750  		return errServerKeyExchange
   751  	}
   752  	sig = sig[2:]
   753  
   754  	digest, hashFunc, err := hashForServerKeyExchange(sigAndHash, ka.version, clientHello.random, serverHello.random, serverECDHParams)
   755  	if err != nil {
   756  		return err
   757  	}
   758  	switch ka.sigType {
   759  	case signatureSM2:
   760  		pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
   761  		if !ok {
   762  			return errors.New("tls: ECDHE SM2 requires a SM2 server public key")
   763  		}
   764  		sm2Sig := new(sm2Signature)
   765  		if _, err := asn1.Unmarshal(sig, sm2Sig); err != nil {
   766  			return err
   767  		}
   768  		if sm2Sig.R.Sign() <= 0 || sm2Sig.S.Sign() <= 0 {
   769  			return errors.New("tls: SM2 signature contained zero or negative values")
   770  		}
   771  		if !ecdsa.Verify(&ecdsa.PublicKey{
   772  			X:     pubKey.X,
   773  			Y:     pubKey.Y,
   774  			Curve: pubKey.Curve,
   775  		}, digest, sm2Sig.R, sm2Sig.S) {
   776  			return errors.New("tls: SM2 verification failure")
   777  		}
   778  	case signatureECDSA:
   779  		pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
   780  		if !ok {
   781  			return errors.New("tls: ECDHE ECDSA requires a ECDSA server public key")
   782  		}
   783  		ecdsaSig := new(ecdsaSignature)
   784  		if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
   785  			return err
   786  		}
   787  		if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
   788  			return errors.New("tls: ECDSA signature contained zero or negative values")
   789  		}
   790  		if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) {
   791  			return errors.New("tls: ECDSA verification failure")
   792  		}
   793  	case signatureRSA:
   794  		pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
   795  		if !ok {
   796  			return errors.New("tls: ECDHE RSA requires a RSA server public key")
   797  		}
   798  		if err := rsa.VerifyPKCS1v15(pubKey, crypto.Hash(hashFunc), digest, sig); err != nil {
   799  			return err
   800  		}
   801  	default:
   802  		return errors.New("tls: unknown ECDHE signature algorithm")
   803  	}
   804  
   805  	return nil
   806  }
   807  
   808  func (ka *sm2KeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
   809  	if ka.curveid == 0 {
   810  		return nil, nil, errors.New("tls: missing ServerKeyExchange message")
   811  	}
   812  
   813  	var serialized, preMasterSecret []byte
   814  
   815  	if ka.curveid == X25519 {
   816  		var ourPublic, theirPublic, sharedKey, scalar [32]byte
   817  
   818  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   819  			return nil, nil, err
   820  		}
   821  
   822  		copy(theirPublic[:], ka.publicKey)
   823  		curve25519.ScalarBaseMult(&ourPublic, &scalar)
   824  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   825  		serialized = ourPublic[:]
   826  		preMasterSecret = sharedKey[:]
   827  	} else {
   828  		curve, ok := curveForCurveID(ka.curveid)
   829  		if !ok {
   830  			panic("internal error")
   831  		}
   832  		priv, mx, my, err := elliptic.GenerateKey(curve, config.Rand)
   833  		if err != nil {
   834  			return nil, nil, err
   835  		}
   836  		x, _ := curve.ScalarMult(ka.x, ka.y, priv)
   837  		preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3)
   838  		xBytes := x.Bytes()
   839  		copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   840  
   841  		serialized = elliptic.Marshal(curve, mx, my)
   842  	}
   843  
   844  	ckx := new(clientKeyExchangeMsg)
   845  	ckx.ciphertext = make([]byte, 1+len(serialized))
   846  	ckx.ciphertext[0] = byte(len(serialized))
   847  	copy(ckx.ciphertext[1:], serialized)
   848  
   849  	return preMasterSecret, ckx, nil
   850  }