github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/gmtls/key_agreement.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package gmtls
    17  
    18  import (
    19  	"crypto"
    20  	"crypto/elliptic"
    21  	"crypto/md5"
    22  	"crypto/rsa"
    23  	"crypto/sha1"
    24  	"errors"
    25  	"io"
    26  	"math/big"
    27  
    28  	"github.com/Hyperledger-TWGC/tjfoc-gm/x509"
    29  
    30  	"golang.org/x/crypto/curve25519"
    31  )
    32  
    33  var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
    34  var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
    35  
    36  // rsaKeyAgreement implements the standard TLS key agreement where the client
    37  // encrypts the pre-master secret to the server's public key.
    38  type rsaKeyAgreement struct{}
    39  
    40  func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, signCert, cipherCert *Certificate,
    41  	clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
    42  	return nil, nil
    43  }
    44  
    45  func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
    46  	if len(ckx.ciphertext) < 2 {
    47  		return nil, errClientKeyExchange
    48  	}
    49  
    50  	ciphertext := ckx.ciphertext
    51  	if version != VersionSSL30 {
    52  		ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
    53  		if ciphertextLen != len(ckx.ciphertext)-2 {
    54  			return nil, errClientKeyExchange
    55  		}
    56  		ciphertext = ckx.ciphertext[2:]
    57  	}
    58  	priv, ok := cert.PrivateKey.(crypto.Decrypter)
    59  	if !ok {
    60  		return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
    61  	}
    62  	// Perform constant time RSA PKCS#1 v1.5 decryption
    63  	preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	// We don't check the version number in the premaster secret. For one,
    68  	// by checking it, we would leak information about the validity of the
    69  	// encrypted pre-master secret. Secondly, it provides only a small
    70  	// benefit against a downgrade attack and some implementations send the
    71  	// wrong version anyway. See the discussion at the end of section
    72  	// 7.4.7.1 of RFC 4346.
    73  	return preMasterSecret, nil
    74  }
    75  
    76  func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
    77  	return errors.New("tls: unexpected ServerKeyExchange")
    78  }
    79  
    80  func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
    81  	preMasterSecret := make([]byte, 48)
    82  	preMasterSecret[0] = byte(clientHello.vers >> 8)
    83  	preMasterSecret[1] = byte(clientHello.vers)
    84  	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
    85  	if err != nil {
    86  		return nil, nil, err
    87  	}
    88  
    89  	encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
    90  	if err != nil {
    91  		return nil, nil, err
    92  	}
    93  	ckx := new(clientKeyExchangeMsg)
    94  	ckx.ciphertext = make([]byte, len(encrypted)+2)
    95  	ckx.ciphertext[0] = byte(len(encrypted) >> 8)
    96  	ckx.ciphertext[1] = byte(len(encrypted))
    97  	copy(ckx.ciphertext[2:], encrypted)
    98  	return preMasterSecret, ckx, nil
    99  }
   100  
   101  // sha1Hash calculates a SHA1 hash over the given byte slices.
   102  func sha1Hash(slices [][]byte) []byte {
   103  	hsha1 := sha1.New()
   104  	for _, slice := range slices {
   105  		hsha1.Write(slice)
   106  	}
   107  	return hsha1.Sum(nil)
   108  }
   109  
   110  // md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the
   111  // concatenation of an MD5 and SHA1 hash.
   112  func md5SHA1Hash(slices [][]byte) []byte {
   113  	md5sha1 := make([]byte, md5.Size+sha1.Size)
   114  	hmd5 := md5.New()
   115  	for _, slice := range slices {
   116  		hmd5.Write(slice)
   117  	}
   118  	copy(md5sha1, hmd5.Sum(nil))
   119  	copy(md5sha1[md5.Size:], sha1Hash(slices))
   120  	return md5sha1
   121  }
   122  
   123  // hashForServerKeyExchange hashes the given slices and returns their digest
   124  // using the given hash function (for >= TLS 1.2) or using a default based on
   125  // the sigType (for earlier TLS versions).
   126  func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) ([]byte, error) {
   127  	if version >= VersionTLS12 {
   128  		h := hashFunc.New()
   129  		for _, slice := range slices {
   130  			h.Write(slice)
   131  		}
   132  		digest := h.Sum(nil)
   133  		return digest, nil
   134  	}
   135  	if sigType == signatureECDSA {
   136  		return sha1Hash(slices), nil
   137  	}
   138  	return md5SHA1Hash(slices), nil
   139  }
   140  
   141  func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
   142  	switch id {
   143  	case CurveP256:
   144  		return elliptic.P256(), true
   145  	case CurveP384:
   146  		return elliptic.P384(), true
   147  	case CurveP521:
   148  		return elliptic.P521(), true
   149  	default:
   150  		return nil, false
   151  	}
   152  
   153  }
   154  
   155  // ecdheKeyAgreement implements a TLS key agreement where the server
   156  // generates an ephemeral EC public/private key pair and signs it. The
   157  // pre-master secret is then calculated using ECDH. The signature may
   158  // either be ECDSA or RSA.
   159  type ecdheKeyAgreement struct {
   160  	version    uint16
   161  	isRSA      bool
   162  	privateKey []byte
   163  	curveid    CurveID
   164  
   165  	// publicKey is used to store the peer's public value when X25519 is
   166  	// being used.
   167  	publicKey []byte
   168  	// x and y are used to store the peer's public value when one of the
   169  	// NIST curves is being used.
   170  	x, y *big.Int
   171  }
   172  
   173  func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, signCert, cipherCert *Certificate,
   174  	clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
   175  	preferredCurves := config.curvePreferences()
   176  
   177  NextCandidate:
   178  	for _, candidate := range preferredCurves {
   179  		for _, c := range clientHello.supportedCurves {
   180  			if candidate == c {
   181  				ka.curveid = c
   182  				break NextCandidate
   183  			}
   184  		}
   185  	}
   186  
   187  	if ka.curveid == 0 {
   188  		return nil, errors.New("tls: no supported elliptic curves offered")
   189  	}
   190  
   191  	var ecdhePublic []byte
   192  
   193  	if ka.curveid == X25519 {
   194  		var scalar, public [32]byte
   195  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   196  			return nil, err
   197  		}
   198  
   199  		curve25519.ScalarBaseMult(&public, &scalar)
   200  		ka.privateKey = scalar[:]
   201  		ecdhePublic = public[:]
   202  	} else {
   203  		curve, ok := curveForCurveID(ka.curveid)
   204  		if !ok {
   205  			return nil, errors.New("tls: preferredCurves includes unsupported curve")
   206  		}
   207  
   208  		var x, y *big.Int
   209  		var err error
   210  		ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand())
   211  		if err != nil {
   212  			return nil, err
   213  		}
   214  		ecdhePublic = elliptic.Marshal(curve, x, y)
   215  	}
   216  
   217  	// https://tools.ietf.org/html/rfc4492#section-5.4
   218  	serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic))
   219  	serverECDHParams[0] = 3 // named curve
   220  	serverECDHParams[1] = byte(ka.curveid >> 8)
   221  	serverECDHParams[2] = byte(ka.curveid)
   222  	serverECDHParams[3] = byte(len(ecdhePublic))
   223  	copy(serverECDHParams[4:], ecdhePublic)
   224  
   225  	priv, ok := signCert.PrivateKey.(crypto.Signer)
   226  	if !ok {
   227  		return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
   228  	}
   229  
   230  	signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(priv.Public(), clientHello.supportedSignatureAlgorithms, supportedSignatureAlgorithms, ka.version)
   231  	if err != nil {
   232  		return nil, err
   233  	}
   234  	if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
   235  		return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
   236  	}
   237  
   238  	digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, hello.random, serverECDHParams)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	signOpts := crypto.SignerOpts(hashFunc)
   244  	if sigType == signatureRSAPSS {
   245  		signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc}
   246  	}
   247  	sig, err := priv.Sign(config.rand(), digest, signOpts)
   248  	if err != nil {
   249  		return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
   250  	}
   251  
   252  	skx := new(serverKeyExchangeMsg)
   253  	sigAndHashLen := 0
   254  	if ka.version >= VersionTLS12 {
   255  		sigAndHashLen = 2
   256  	}
   257  	skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
   258  	copy(skx.key, serverECDHParams)
   259  	k := skx.key[len(serverECDHParams):]
   260  	if ka.version >= VersionTLS12 {
   261  		k[0] = byte(signatureAlgorithm >> 8)
   262  		k[1] = byte(signatureAlgorithm)
   263  		k = k[2:]
   264  	}
   265  	k[0] = byte(len(sig) >> 8)
   266  	k[1] = byte(len(sig))
   267  	copy(k[2:], sig)
   268  
   269  	return skx, nil
   270  }
   271  
   272  func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
   273  	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
   274  		return nil, errClientKeyExchange
   275  	}
   276  
   277  	if ka.curveid == X25519 {
   278  		if len(ckx.ciphertext) != 1+32 {
   279  			return nil, errClientKeyExchange
   280  		}
   281  
   282  		var theirPublic, sharedKey, scalar [32]byte
   283  		copy(theirPublic[:], ckx.ciphertext[1:])
   284  		copy(scalar[:], ka.privateKey)
   285  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   286  		return sharedKey[:], nil
   287  	}
   288  
   289  	curve, ok := curveForCurveID(ka.curveid)
   290  	if !ok {
   291  		panic("internal error")
   292  	}
   293  	x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:]) // Unmarshal also checks whether the given point is on the curve
   294  	if x == nil {
   295  		return nil, errClientKeyExchange
   296  	}
   297  	x, _ = curve.ScalarMult(x, y, ka.privateKey)
   298  	preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3)
   299  	xBytes := x.Bytes()
   300  	copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   301  
   302  	return preMasterSecret, nil
   303  }
   304  
   305  func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
   306  	if len(skx.key) < 4 {
   307  		return errServerKeyExchange
   308  	}
   309  	if skx.key[0] != 3 { // named curve
   310  		return errors.New("tls: server selected unsupported curve")
   311  	}
   312  	ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
   313  
   314  	publicLen := int(skx.key[3])
   315  	if publicLen+4 > len(skx.key) {
   316  		return errServerKeyExchange
   317  	}
   318  	serverECDHParams := skx.key[:4+publicLen]
   319  	publicKey := serverECDHParams[4:]
   320  
   321  	sig := skx.key[4+publicLen:]
   322  	if len(sig) < 2 {
   323  		return errServerKeyExchange
   324  	}
   325  
   326  	if ka.curveid == X25519 {
   327  		if len(publicKey) != 32 {
   328  			return errors.New("tls: bad X25519 public value")
   329  		}
   330  		ka.publicKey = publicKey
   331  	} else {
   332  		curve, ok := curveForCurveID(ka.curveid)
   333  		if !ok {
   334  			return errors.New("tls: server selected unsupported curve")
   335  		}
   336  		ka.x, ka.y = elliptic.Unmarshal(curve, publicKey) // Unmarshal also checks whether the given point is on the curve
   337  		if ka.x == nil {
   338  			return errServerKeyExchange
   339  		}
   340  	}
   341  
   342  	var signatureAlgorithm SignatureScheme
   343  	if ka.version >= VersionTLS12 {
   344  		// handle SignatureAndHashAlgorithm
   345  		signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
   346  		sig = sig[2:]
   347  		if len(sig) < 2 {
   348  			return errServerKeyExchange
   349  		}
   350  	}
   351  	_, sigType, hashFunc, err := pickSignatureAlgorithm(cert.PublicKey, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version)
   352  	if err != nil {
   353  		return err
   354  	}
   355  	if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
   356  		return errServerKeyExchange
   357  	}
   358  
   359  	sigLen := int(sig[0])<<8 | int(sig[1])
   360  	if sigLen+2 != len(sig) {
   361  		return errServerKeyExchange
   362  	}
   363  	sig = sig[2:]
   364  
   365  	digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, serverHello.random, serverECDHParams)
   366  	if err != nil {
   367  		return err
   368  	}
   369  	return verifyHandshakeSignature(sigType, cert.PublicKey, hashFunc, digest, sig)
   370  }
   371  
   372  func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
   373  	if ka.curveid == 0 {
   374  		return nil, nil, errors.New("tls: missing ServerKeyExchange message")
   375  	}
   376  
   377  	var serialized, preMasterSecret []byte
   378  
   379  	if ka.curveid == X25519 {
   380  		var ourPublic, theirPublic, sharedKey, scalar [32]byte
   381  
   382  		if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
   383  			return nil, nil, err
   384  		}
   385  
   386  		copy(theirPublic[:], ka.publicKey)
   387  		curve25519.ScalarBaseMult(&ourPublic, &scalar)
   388  		curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
   389  		serialized = ourPublic[:]
   390  		preMasterSecret = sharedKey[:]
   391  	} else {
   392  		curve, ok := curveForCurveID(ka.curveid)
   393  		if !ok {
   394  			panic("internal error")
   395  		}
   396  		priv, mx, my, err := elliptic.GenerateKey(curve, config.rand())
   397  		if err != nil {
   398  			return nil, nil, err
   399  		}
   400  		x, _ := curve.ScalarMult(ka.x, ka.y, priv)
   401  		preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3)
   402  		xBytes := x.Bytes()
   403  		copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
   404  
   405  		serialized = elliptic.Marshal(curve, mx, my)
   406  	}
   407  
   408  	ckx := new(clientKeyExchangeMsg)
   409  	ckx.ciphertext = make([]byte, 1+len(serialized))
   410  	ckx.ciphertext[0] = byte(len(serialized))
   411  	copy(ckx.ciphertext[1:], serialized)
   412  
   413  	return preMasterSecret, ckx, nil
   414  }