github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/gm_key_agreement.go (about)

     1  // Copyright 2010 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gmtls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto"
    10  	"crypto/ecdsa"
    11  	"crypto/elliptic"
    12  	"encoding/asn1"
    13  	"errors"
    14  	"io"
    15  	"math/big"
    16  
    17  	"github.com/Hyperledger-TWGC/tjfoc-gm/sm2"
    18  	"github.com/Hyperledger-TWGC/tjfoc-gm/x509"
    19  
    20  	"golang.org/x/crypto/curve25519"
    21  )
    22  
    23  //// hashForServerKeyExchange hashes the given slices and returns their digest
    24  //// using the given hash function (for >= TLS 1.2) or using a default based on
    25  //// the sigType (for earlier TLS versions).
    26  //func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) ([]byte, error) {
    27  //	if version >= VersionTLS12 {
    28  //		h := hashFunc.New()
    29  //		for _, slice := range slices {
    30  //			h.Write(slice)
    31  //		}
    32  //		digest := h.Sum(nil)
    33  //		return digest, nil
    34  //	}
    35  //	if sigType == signatureECDSA {
    36  //		return sha1Hash(slices), nil
    37  //	}
    38  //	return md5SHA1Hash(slices), nil
    39  //}
    40  //
    41  //func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
    42  //	switch id {
    43  //	case CurveP256:
    44  //		return elliptic.P256(), true
    45  //	case CurveP384:
    46  //		return elliptic.P384(), true
    47  //	case CurveP521:
    48  //		return elliptic.P521(), true
    49  //	default:
    50  //		return nil, false
    51  //	}
    52  //
    53  //}
    54  
    55  // ecdheKeyAgreementGM implements a TLS key agreement where the server
    56  // generates an ephemeral SM2 public/private key pair and signs it. The
    57  // pre-master secret is then calculated using ECDH.
    58  type ecdheKeyAgreementGM struct {
    59  	version    uint16
    60  	privateKey []byte
    61  	curveid    CurveID
    62  
    63  	// publicKey is used to store the peer's public value when X25519 is
    64  	// being used.
    65  	publicKey []byte
    66  	// x and y are used to store the peer's public value when one of the
    67  	// NIST curves is being used.
    68  	x, y *big.Int
    69  }
    70  
    71  func (ka *ecdheKeyAgreementGM) generateServerKeyExchange(config *Config, signCert, cipherCert *Certificate,
    72  	clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
    73  	panic("")
    74  	//	preferredCurves := config.curvePreferences()
    75  	//
    76  	//NextCandidate:
    77  	//	for _, candidate := range preferredCurves {
    78  	//		for _, c := range clientHello.supportedCurves {
    79  	//			if candidate == c {
    80  	//				ka.curveid = c
    81  	//				break NextCandidate
    82  	//			}
    83  	//		}
    84  	//	}
    85  	//
    86  	//	if ka.curveid == 0 {
    87  	//		return nil, errors.New("tls: no supported elliptic curves offered")
    88  	//	}
    89  	//
    90  	//	var ecdhePublic []byte
    91  	//
    92  	//	if ka.curveid == X25519 {
    93  	//		var scalar, public [32]byte
    94  	//		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
    95  	//			return nil, err
    96  	//		}
    97  	//
    98  	//		curve25519.ScalarBaseMult(&public, &scalar)
    99  	//		ka.privateKey = scalar[:]
   100  	//		ecdhePublic = public[:]
   101  	//	} else {
   102  	//		curve, ok := curveForCurveID(ka.curveid)
   103  	//		if !ok {
   104  	//			return nil, errors.New("tls: preferredCurves includes unsupported curve")
   105  	//		}
   106  	//
   107  	//		var x, y *big.Int
   108  	//		var err error
   109  	//		ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand())
   110  	//		if err != nil {
   111  	//			return nil, err
   112  	//		}
   113  	//		ecdhePublic = elliptic.Marshal(curve, x, y)
   114  	//	}
   115  	//
   116  	//	// https://tools.ietf.org/html/rfc4492#section-5.4
   117  	//	serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic))
   118  	//	serverECDHParams[0] = 3 // named curve
   119  	//	serverECDHParams[1] = byte(ka.curveid >> 8)
   120  	//	serverECDHParams[2] = byte(ka.curveid)
   121  	//	serverECDHParams[3] = byte(len(ecdhePublic))
   122  	//	copy(serverECDHParams[4:], ecdhePublic)
   123  	//
   124  	//	priv, ok := cert.PrivateKey.(crypto.Signer)
   125  	//	if !ok {
   126  	//		return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
   127  	//	}
   128  	//
   129  	//	signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(priv.Public(), clientHello.supportedSignatureAlgorithms, supportedSignatureAlgorithms, ka.version)
   130  	//	if err != nil {
   131  	//		return nil, err
   132  	//	}
   133  	//	if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
   134  	//		return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
   135  	//	}
   136  	//
   137  	//	digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, hello.random, serverECDHParams)
   138  	//	if err != nil {
   139  	//		return nil, err
   140  	//	}
   141  	//
   142  	//	signOpts := crypto.SignerOpts(hashFunc)
   143  	//	if sigType == signatureRSAPSS {
   144  	//		signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc}
   145  	//	}
   146  	//	sig, err := priv.Sign(config.rand(), digest, signOpts)
   147  	//	if err != nil {
   148  	//		return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
   149  	//	}
   150  	//
   151  	//	skx := new(serverKeyExchangeMsg)
   152  	//	sigAndHashLen := 0
   153  	//	if ka.version >= VersionTLS12 {
   154  	//		sigAndHashLen = 2
   155  	//	}
   156  	//	skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
   157  	//	copy(skx.key, serverECDHParams)
   158  	//	k := skx.key[len(serverECDHParams):]
   159  	//	if ka.version >= VersionTLS12 {
   160  	//		k[0] = byte(signatureAlgorithm >> 8)
   161  	//		k[1] = byte(signatureAlgorithm)
   162  	//		k = k[2:]
   163  	//	}
   164  	//	k[0] = byte(len(sig) >> 8)
   165  	//	k[1] = byte(len(sig))
   166  	//	copy(k[2:], sig)
   167  	//
   168  	//	return skx, nil
   169  }
   170  
   171  func (ka *ecdheKeyAgreementGM) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
   172  	panic("")
   173  	//	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
   174  	//		return nil, errClientKeyExchange
   175  	//	}
   176  	//
   177  	//	if ka.curveid == X25519 {
   178  	//		if len(ckx.ciphertext) != 1+32 {
   179  	//			return nil, errClientKeyExchange
   180  	//		}
   181  	//
   182  	//		var theirPublic, sharedKey, scalar [32]byte
   183  	//		copy(theirPublic[:], ckx.ciphertext[1:])
   184  	//		copy(scalar[:], ka.privateKey)
   185  	//		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   186  	//		return sharedKey[:], nil
   187  	//	}
   188  	//
   189  	//	curve, ok := curveForCurveID(ka.curveid)
   190  	//	if !ok {
   191  	//		panic("internal error")
   192  	//	}
   193  	//	x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:]) // Unmarshal also checks whether the given point is on the curve
   194  	//	if x == nil {
   195  	//		return nil, errClientKeyExchange
   196  	//	}
   197  	//	x, _ = curve.ScalarMult(x, y, ka.privateKey)
   198  	//	preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3)
   199  	//	xBytes := x.Bytes()
   200  	//	copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   201  	//
   202  	//	return preMasterSecret, nil
   203  }
   204  
   205  func (ka *ecdheKeyAgreementGM) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
   206  	if len(skx.key) < 4 {
   207  		return errServerKeyExchange
   208  	}
   209  	if skx.key[0] != 3 { // named curve
   210  		return errors.New("tls: server selected unsupported curve")
   211  	}
   212  	ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
   213  
   214  	publicLen := int(skx.key[3])
   215  	if publicLen+4 > len(skx.key) {
   216  		return errServerKeyExchange
   217  	}
   218  	serverECDHParams := skx.key[:4+publicLen]
   219  	publicKey := serverECDHParams[4:]
   220  
   221  	sig := skx.key[4+publicLen:]
   222  	if len(sig) < 2 {
   223  		return errServerKeyExchange
   224  	}
   225  
   226  	//according to GMT0024, we don't care about
   227  	curve := sm2.P256Sm2()
   228  	ka.x, ka.y = elliptic.Unmarshal(curve, publicKey) // Unmarshal also checks whether the given point is on the curve
   229  	if ka.x == nil {
   230  		return errServerKeyExchange
   231  	}
   232  
   233  	var signatureAlgorithm SignatureScheme
   234  	_, sigType, hashFunc, err := pickSignatureAlgorithm(cert.PublicKey, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version)
   235  
   236  	sigLen := int(sig[0])<<8 | int(sig[1])
   237  	if sigLen+2 != len(sig) {
   238  		return errServerKeyExchange
   239  	}
   240  	sig = sig[2:]
   241  
   242  	digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, serverHello.random, serverECDHParams)
   243  	if err != nil {
   244  		return err
   245  	}
   246  	return verifyHandshakeSignature(sigType, cert.PublicKey, hashFunc, digest, sig)
   247  }
   248  
   249  func (ka *ecdheKeyAgreementGM) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
   250  	if ka.curveid == 0 {
   251  		return nil, nil, errors.New("tls: missing ServerKeyExchange message")
   252  	}
   253  
   254  	var serialized, preMasterSecret []byte
   255  
   256  	if ka.curveid == X25519 {
   257  		var ourPublic, theirPublic, sharedKey, scalar [32]byte
   258  
   259  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   260  			return nil, nil, err
   261  		}
   262  
   263  		copy(theirPublic[:], ka.publicKey)
   264  		curve25519.ScalarBaseMult(&ourPublic, &scalar)
   265  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   266  		serialized = ourPublic[:]
   267  		preMasterSecret = sharedKey[:]
   268  	} else {
   269  		curve, ok := curveForCurveID(ka.curveid)
   270  		if !ok {
   271  			panic("internal error")
   272  		}
   273  		priv, mx, my, err := elliptic.GenerateKey(curve, config.rand())
   274  		if err != nil {
   275  			return nil, nil, err
   276  		}
   277  		x, _ := curve.ScalarMult(ka.x, ka.y, priv)
   278  		preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3)
   279  		xBytes := x.Bytes()
   280  		copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   281  
   282  		serialized = elliptic.Marshal(curve, mx, my)
   283  	}
   284  
   285  	ckx := new(clientKeyExchangeMsg)
   286  	ckx.ciphertext = make([]byte, 1+len(serialized))
   287  	ckx.ciphertext[0] = byte(len(serialized))
   288  	copy(ckx.ciphertext[1:], serialized)
   289  
   290  	return preMasterSecret, ckx, nil
   291  }
   292  
   293  // eccKeyAgreementGM implements a TLS key agreement where the server
   294  // generates an ephemeral SM2 public/private key pair and signs it. The
   295  // pre-master secret is then calculated using ECDH.
   296  type eccKeyAgreementGM struct {
   297  	version    uint16
   298  	privateKey []byte
   299  	curveid    CurveID
   300  
   301  	// publicKey is used to store the peer's public value when X25519 is
   302  	// being used.
   303  	publicKey []byte
   304  	// x and y are used to store the peer's public value when one of the
   305  	// NIST curves is being used.
   306  	x, y *big.Int
   307  
   308  	//cert for encipher referred to GMT0024
   309  	encipherCert *x509.Certificate
   310  }
   311  
   312  func (ka *eccKeyAgreementGM) generateServerKeyExchange(config *Config, signCert, cipherCert *Certificate,
   313  	clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
   314  	// mod by syl only one cert
   315  	//digest := ka.hashForServerKeyExchange(clientHello.random, hello.random, cert.Certificate[1])
   316  	digest := ka.hashForServerKeyExchange(clientHello.random, hello.random, cipherCert.Certificate[0])
   317  
   318  	priv, ok := signCert.PrivateKey.(crypto.Signer)
   319  	if !ok {
   320  		return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
   321  	}
   322  	sig, err := priv.Sign(config.rand(), digest, nil)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  
   327  	len := len(sig)
   328  
   329  	ske := new(serverKeyExchangeMsg)
   330  	ske.key = make([]byte, len+2)
   331  	ske.key[0] = byte(len >> 8)
   332  	ske.key[1] = byte(len)
   333  	copy(ske.key[2:], sig)
   334  
   335  	return ske, nil
   336  }
   337  
   338  func (ka *eccKeyAgreementGM) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
   339  	if len(ckx.ciphertext) == 0 {
   340  		return nil, errClientKeyExchange
   341  	}
   342  
   343  	if int(ckx.ciphertext[0]<<8|ckx.ciphertext[1]) != len(ckx.ciphertext)-2 {
   344  		return nil, errClientKeyExchange
   345  	}
   346  
   347  	cipher := ckx.ciphertext[2:]
   348  
   349  	decrypter, ok := cert.PrivateKey.(crypto.Decrypter)
   350  	if !ok {
   351  		return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
   352  	}
   353  
   354  	plain, err := decrypter.Decrypt(config.rand(), cipher, nil)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  
   359  	if len(plain) != 48 {
   360  		return nil, errClientKeyExchange
   361  	}
   362  
   363  	//we do not examine the version here according to openssl practice
   364  	return plain, nil
   365  }
   366  
   367  func (ka *eccKeyAgreementGM) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
   368  	if len(skx.key) <= 2 {
   369  		return errServerKeyExchange
   370  	}
   371  	sigLen := int(skx.key[0]<<8 | skx.key[1])
   372  	if sigLen+2 != len(skx.key) {
   373  		return errServerKeyExchange
   374  	}
   375  	sig := skx.key[2:]
   376  	//sig := skx.key[:]
   377  
   378  	digest := ka.hashForServerKeyExchange(clientHello.random, serverHello.random, ka.encipherCert.Raw)
   379  
   380  	//verify
   381  	pubKey, _ := cert.PublicKey.(*ecdsa.PublicKey)
   382  	if pubKey.Curve != sm2.P256Sm2() {
   383  		return errors.New("tls: sm2 signing requires a sm2 public key")
   384  	}
   385  
   386  	ecdsaSig := new(ecdsaSignature)
   387  	rest, err := asn1.Unmarshal(sig, ecdsaSig)
   388  	if err != nil {
   389  		return err
   390  	}
   391  	if len(rest) != 0 {
   392  		return errors.New("tls:processServerKeyExchange: sm2 get signature failed")
   393  	}
   394  	if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
   395  		return errors.New("tls: processServerKeyExchange: sm2 signature contained zero or negative values")
   396  	}
   397  
   398  	sm2PubKey := sm2.PublicKey{
   399  		Curve: pubKey.Curve,
   400  		X:     pubKey.X,
   401  		Y:     pubKey.Y,
   402  	}
   403  
   404  	if !sm2PubKey.Verify(digest, sig) {
   405  		return errors.New("tls: processServerKeyExchange: sm2 verification failure")
   406  	}
   407  
   408  	return nil
   409  }
   410  
   411  func (ka *eccKeyAgreementGM) hashForServerKeyExchange(slices ...[]byte) []byte {
   412  	buffer := new(bytes.Buffer)
   413  	for i, slice := range slices {
   414  		if i == 2 {
   415  			buffer.Write([]byte{byte(len(slice) >> 16), byte(len(slice) >> 8), byte(len(slice))})
   416  		}
   417  		buffer.Write(slice)
   418  	}
   419  	return buffer.Bytes()
   420  }
   421  
   422  func (ka *eccKeyAgreementGM) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
   423  	preMasterSecret := make([]byte, 48)
   424  	preMasterSecret[0] = byte(clientHello.vers >> 8)
   425  	preMasterSecret[1] = byte(clientHello.vers)
   426  	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
   427  	if err != nil {
   428  		return nil, nil, err
   429  	}
   430  	pubKey := ka.encipherCert.PublicKey.(*ecdsa.PublicKey)
   431  	sm2PubKey := &sm2.PublicKey{Curve: pubKey.Curve, X: pubKey.X, Y: pubKey.Y}
   432  	encrypted, err := sm2.Encrypt(sm2PubKey, preMasterSecret, config.rand())
   433  	if err != nil {
   434  		return nil, nil, err
   435  	}
   436  	ckx := new(clientKeyExchangeMsg)
   437  	ckx.ciphertext = make([]byte, len(encrypted)+2)
   438  	ckx.ciphertext[0] = byte(len(encrypted) >> 8)
   439  	ckx.ciphertext[1] = byte(len(encrypted))
   440  	copy(ckx.ciphertext[2:], encrypted)
   441  	return preMasterSecret, ckx, nil
   442  }