github.com/emmansun/gmsm@v0.29.1/sm9/sm9.go (about)

     1  // Package sm9 implements ShangMi(SM) sm9 digital signature, encryption and key exchange algorithms.
     2  package sm9
     3  
     4  import (
     5  	"crypto"
     6  	goSubtle "crypto/subtle"
     7  	"encoding/binary"
     8  	"errors"
     9  	"io"
    10  	"math/big"
    11  
    12  	"github.com/emmansun/gmsm/internal/bigmod"
    13  	"github.com/emmansun/gmsm/internal/randutil"
    14  	"github.com/emmansun/gmsm/internal/subtle"
    15  	"github.com/emmansun/gmsm/sm3"
    16  	"github.com/emmansun/gmsm/sm9/bn256"
    17  	"golang.org/x/crypto/cryptobyte"
    18  	"golang.org/x/crypto/cryptobyte/asn1"
    19  )
    20  
    21  // SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification
    22  
    23  var orderNat, _ = bigmod.NewModulusFromBig(bn256.Order)
    24  var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
    25  var bigOne = big.NewInt(1)
    26  var bigOneNat *bigmod.Nat
    27  var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne))
    28  
    29  func init() {
    30  	bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat)
    31  }
    32  
    33  type hashMode byte
    34  
    35  const (
    36  	// hashmode used in h1: 0x01
    37  	H1 hashMode = 1 + iota
    38  	// hashmode used in h2: 0x02
    39  	H2
    40  )
    41  
    42  type encryptType byte
    43  
    44  const (
    45  	ENC_TYPE_XOR encryptType = 0
    46  	ENC_TYPE_ECB encryptType = 1
    47  	ENC_TYPE_CBC encryptType = 2
    48  	ENC_TYPE_OFB encryptType = 4
    49  	ENC_TYPE_CFB encryptType = 8
    50  )
    51  
    52  // hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
    53  func hash(z []byte, h hashMode) *bigmod.Nat {
    54  	md := sm3.New()
    55  	var ha [64]byte
    56  	var countBytes [4]byte
    57  	var ct uint32 = 1
    58  
    59  	binary.BigEndian.PutUint32(countBytes[:], ct)
    60  	md.Write([]byte{byte(h)})
    61  	md.Write(z)
    62  	md.Write(countBytes[:])
    63  	copy(ha[:], md.Sum(nil))
    64  	ct++
    65  	md.Reset()
    66  
    67  	binary.BigEndian.PutUint32(countBytes[:], ct)
    68  	md.Write([]byte{byte(h)})
    69  	md.Write(z)
    70  	md.Write(countBytes[:])
    71  	copy(ha[sm3.Size:], md.Sum(nil))
    72  
    73  	k := new(big.Int).SetBytes(ha[:40])
    74  	kNat := bigmod.NewNat().SetBig(k)
    75  	kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
    76  	kNat.Add(bigOneNat, orderNat)
    77  	return kNat
    78  }
    79  
    80  func hashH1(z []byte) *bigmod.Nat {
    81  	return hash(z, H1)
    82  }
    83  
    84  func hashH2(z []byte) *bigmod.Nat {
    85  	return hash(z, H2)
    86  }
    87  
    88  func randomScalar(rand io.Reader) (k *bigmod.Nat, err error) {
    89  	k = bigmod.NewNat()
    90  	for {
    91  		b := make([]byte, orderNat.Size())
    92  		if _, err = io.ReadFull(rand, b); err != nil {
    93  			return
    94  		}
    95  
    96  		// Mask off any excess bits to increase the chance of hitting a value in
    97  		// (0, N). These are the most dangerous lines in the package and maybe in
    98  		// the library: a single bit of bias in the selection of nonces would likely
    99  		// lead to key recovery, but no tests would fail. Look but DO NOT TOUCH.
   100  		if excess := len(b)*8 - orderNat.BitLen(); excess > 0 {
   101  			// Just to be safe, assert that this only happens for the one curve that
   102  			// doesn't have a round number of bits.
   103  			if excess != 0 {
   104  				panic("sm9: internal error: unexpectedly masking off bits")
   105  			}
   106  			b[0] >>= excess
   107  		}
   108  
   109  		// FIPS 186-4 makes us check k <= N - 2 and then add one.
   110  		// Checking 0 < k <= N - 1 is strictly equivalent.
   111  		// None of this matters anyway because the chance of selecting
   112  		// zero is cryptographically negligible.
   113  		if _, err = k.SetBytes(b, orderNat); err == nil && k.IsZero() == 0 {
   114  			break
   115  		}
   116  	}
   117  	return
   118  }
   119  
   120  // Sign signs a hash (which should be the result of hashing a larger message)
   121  // using the user dsa key. It returns the signature as a pair of h and s.
   122  // Please use SignASN1 instead.
   123  //
   124  // The signature is randomized. Most applications should use [crypto/rand.Reader]
   125  // as rand. Note that the returned signature does not depend deterministically on
   126  // the bytes read from rand, and may change between calls and/or between versions.
   127  func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn256.G1, err error) {
   128  	sig, err := SignASN1(rand, priv, hash)
   129  	if err != nil {
   130  		return nil, nil, err
   131  	}
   132  	return parseSignatureLegacy(sig)
   133  }
   134  
   135  // Sign signs digest with user's DSA key, reading randomness from rand. The opts argument
   136  // is not currently used but, in keeping with the crypto.Signer interface.
   137  // The result is SM9Signature ASN.1 format.
   138  //
   139  // The signature is randomized. Most applications should use [crypto/rand.Reader]
   140  // as rand. Note that the returned signature does not depend deterministically on
   141  // the bytes read from rand, and may change between calls and/or between versions.
   142  func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.SignerOpts) ([]byte, error) {
   143  	return SignASN1(rand, priv, hash)
   144  }
   145  
   146  // SignASN1 signs a hash (which should be the result of hashing a larger message)
   147  // using the private key, priv. It returns the ASN.1 encoded signature of type SM9Signature.
   148  //
   149  // The signature is randomized. Most applications should use [crypto/rand.Reader]
   150  // as rand. Note that the returned signature does not depend deterministically on
   151  // the bytes read from rand, and may change between calls and/or between versions.
   152  func SignASN1(rand io.Reader, priv *SignPrivateKey, hash []byte) ([]byte, error) {
   153  	var (
   154  		hNat *bigmod.Nat
   155  		s    *bn256.G1
   156  	)
   157  	randutil.MaybeReadByte(rand)
   158  	for {
   159  		r, err := randomScalar(rand)
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  
   164  		w, err := priv.SignMasterPublicKey.ScalarBaseMult(r.Bytes(orderNat))
   165  		if err != nil {
   166  			return nil, err
   167  		}
   168  
   169  		var buffer []byte
   170  		buffer = append(buffer, hash...)
   171  		buffer = append(buffer, w.Marshal()...)
   172  
   173  		hNat = hashH2(buffer)
   174  		r.Sub(hNat, orderNat)
   175  
   176  		if r.IsZero() == 0 {
   177  			s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat))
   178  			if err != nil {
   179  				return nil, err
   180  			}
   181  			break
   182  		}
   183  	}
   184  
   185  	return encodeSignature(hNat.Bytes(orderNat), s)
   186  }
   187  
   188  // Verify verifies the signature in h, s of hash using the master dsa public key and user id, uid and hid.
   189  // Its return value records whether the signature is valid. Please use VerifyASN1 instead.
   190  func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.Int, s *bn256.G1) bool {
   191  	if h.Sign() <= 0 {
   192  		return false
   193  	}
   194  	sig, err := encodeSignature(h.Bytes(), s)
   195  	if err != nil {
   196  		return false
   197  	}
   198  	return VerifyASN1(pub, uid, hid, hash, sig)
   199  }
   200  
   201  func encodeSignature(hBytes []byte, s *bn256.G1) ([]byte, error) {
   202  	var b cryptobyte.Builder
   203  	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
   204  		b.AddASN1OctetString(hBytes)
   205  		b.AddASN1BitString(s.MarshalUncompressed())
   206  	})
   207  	return b.Bytes()
   208  }
   209  
   210  func parseSignature(sig []byte) ([]byte, *bn256.G1, error) {
   211  	var (
   212  		hBytes []byte
   213  		sBytes []byte
   214  		inner  cryptobyte.String
   215  	)
   216  	input := cryptobyte.String(sig)
   217  	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
   218  		!input.Empty() ||
   219  		!inner.ReadASN1Bytes(&hBytes, asn1.OCTET_STRING) ||
   220  		!inner.ReadASN1BitStringAsBytes(&sBytes) ||
   221  		!inner.Empty() {
   222  		return nil, nil, errors.New("invalid ASN.1")
   223  	}
   224  	if sBytes[0] != 4 {
   225  		return nil, nil, errors.New("sm9: invalid point format")
   226  	}
   227  	s := new(bn256.G1)
   228  	_, err := s.Unmarshal(sBytes[1:])
   229  	if err != nil {
   230  		return nil, nil, err
   231  	}
   232  	return hBytes, s, nil
   233  }
   234  
   235  func parseSignatureLegacy(sig []byte) (*big.Int, *bn256.G1, error) {
   236  	hBytes, s, err := parseSignature(sig)
   237  	if err != nil {
   238  		return nil, nil, err
   239  	}
   240  	return new(big.Int).SetBytes(hBytes), s, nil
   241  }
   242  
   243  // VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the
   244  // public key, pub. Its return value records whether the signature is valid.
   245  func VerifyASN1(pub *SignMasterPublicKey, uid []byte, hid byte, hash, sig []byte) bool {
   246  	h, s, err := parseSignature(sig)
   247  	if err != nil {
   248  		return false
   249  	}
   250  	if !s.IsOnCurve() {
   251  		return false
   252  	}
   253  
   254  	hNat, err := bigmod.NewNat().SetBytes(h, orderNat)
   255  	if err != nil {
   256  		return false
   257  	}
   258  	if hNat.IsZero() == 1 {
   259  		return false
   260  	}
   261  
   262  	t, err := pub.ScalarBaseMult(hNat.Bytes(orderNat))
   263  	if err != nil {
   264  		return false
   265  	}
   266  
   267  	// user sign public key p generation
   268  	p := pub.GenerateUserPublicKey(uid, hid)
   269  
   270  	u := bn256.Pair(s, p)
   271  	w := new(bn256.GT).Add(u, t)
   272  
   273  	var buffer []byte
   274  	buffer = append(buffer, hash...)
   275  	buffer = append(buffer, w.Marshal()...)
   276  	h2 := hashH2(buffer)
   277  
   278  	return h2.Equal(hNat) == 1
   279  }
   280  
   281  // Verify verifies the ASN.1 encoded signature, sig, of hash using the
   282  // public key, pub. Its return value records whether the signature is valid.
   283  func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, sig []byte) bool {
   284  	return VerifyASN1(pub, uid, hid, hash, sig)
   285  }
   286  
   287  // WrapKey generates and wraps key with reciever's uid and system hid, returns generated key and cipher.
   288  //
   289  // The rand parameter is used as a source of entropy to ensure that
   290  // calls this function twice doesn't result in the same key.
   291  // Most applications should use [crypto/rand.Reader] as random.
   292  func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *bn256.G1, err error) {
   293  	q := pub.GenerateUserPublicKey(uid, hid)
   294  	var (
   295  		r *bigmod.Nat
   296  		w *bn256.GT
   297  	)
   298  	for {
   299  		r, err = randomScalar(rand)
   300  		if err != nil {
   301  			return
   302  		}
   303  
   304  		rBytes := r.Bytes(orderNat)
   305  		cipher, err = new(bn256.G1).ScalarMult(q, rBytes)
   306  		if err != nil {
   307  			return
   308  		}
   309  
   310  		w, err = pub.ScalarBaseMult(rBytes)
   311  		if err != nil {
   312  			return
   313  		}
   314  		var buffer []byte
   315  		buffer = append(buffer, cipher.Marshal()...)
   316  		buffer = append(buffer, w.Marshal()...)
   317  		buffer = append(buffer, uid...)
   318  
   319  		key = sm3.Kdf(buffer, kLen)
   320  		if subtle.ConstantTimeAllZero(key) == 0 {
   321  			break
   322  		}
   323  	}
   324  	return
   325  }
   326  
   327  // WrapKey wraps key and converts the cipher as ASN1 format, SM9PublicKey1 definition.
   328  //
   329  // The rand parameter is used as a source of entropy to ensure that
   330  // calls this function twice doesn't result in the same key.
   331  // Most applications should use [crypto/rand.Reader] as random.
   332  func (pub *EncryptMasterPublicKey) WrapKey(rand io.Reader, uid []byte, hid byte, kLen int) ([]byte, []byte, error) {
   333  	key, cipher, err := WrapKey(rand, pub, uid, hid, kLen)
   334  	if err != nil {
   335  		return nil, nil, err
   336  	}
   337  	var b cryptobyte.Builder
   338  	b.AddASN1BitString(cipher.MarshalUncompressed())
   339  	cipherASN1, err := b.Bytes()
   340  
   341  	return key, cipherASN1, err
   342  }
   343  
   344  // WrapKeyASN1 wraps key and converts the result of SM9KeyPackage as ASN1 format. according
   345  // SM9 cryptographic algorithm application specification, SM9KeyPackage defnition.
   346  //
   347  // The rand parameter is used as a source of entropy to ensure that
   348  // calls this function twice doesn't result in the same key.
   349  // Most applications should use [crypto/rand.Reader] as random.
   350  func (pub *EncryptMasterPublicKey) WrapKeyASN1(rand io.Reader, uid []byte, hid byte, kLen int) ([]byte, error) {
   351  	key, cipher, err := WrapKey(rand, pub, uid, hid, kLen)
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  	var b cryptobyte.Builder
   356  	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
   357  		b.AddASN1OctetString(key)
   358  		b.AddASN1BitString(cipher.MarshalUncompressed())
   359  	})
   360  	return b.Bytes()
   361  }
   362  
   363  // UnmarshalSM9KeyPackage is an utility to unmarshal SM9KeyPackage
   364  func UnmarshalSM9KeyPackage(der []byte) ([]byte, *bn256.G1, error) {
   365  	input := cryptobyte.String(der)
   366  	var (
   367  		key         []byte
   368  		cipherBytes []byte
   369  		inner       cryptobyte.String
   370  	)
   371  	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
   372  		!input.Empty() ||
   373  		!inner.ReadASN1Bytes(&key, asn1.OCTET_STRING) ||
   374  		!inner.ReadASN1BitStringAsBytes(&cipherBytes) ||
   375  		!inner.Empty() {
   376  		return nil, nil, errors.New("sm9: invalid SM9KeyPackage asn.1 data")
   377  	}
   378  	g, err := unmarshalG1(cipherBytes)
   379  	if err != nil {
   380  		return nil, nil, err
   381  	}
   382  	return key, g, nil
   383  }
   384  
   385  // ErrDecryption represents a failure to decrypt a message.
   386  // It is deliberately vague to avoid adaptive attacks.
   387  var ErrDecryption = errors.New("sm9: decryption error")
   388  
   389  // ErrEmptyPlaintext represents a failure to encrypt an empty message.
   390  var ErrEmptyPlaintext = errors.New("sm9: empty plaintext")
   391  
   392  // UnwrapKey unwraps key from cipher, user id and aligned key length
   393  func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int) ([]byte, error) {
   394  	if !cipher.IsOnCurve() {
   395  		return nil, ErrDecryption
   396  	}
   397  
   398  	w := bn256.Pair(cipher, priv.PrivateKey)
   399  
   400  	var buffer []byte
   401  	buffer = append(buffer, cipher.Marshal()...)
   402  	buffer = append(buffer, w.Marshal()...)
   403  	buffer = append(buffer, uid...)
   404  
   405  	key := sm3.Kdf(buffer, kLen)
   406  	if subtle.ConstantTimeAllZero(key) == 1 {
   407  		return nil, ErrDecryption
   408  	}
   409  	return key, nil
   410  }
   411  
   412  // UnwrapKey unwraps key from cipherDer, user id and aligned key length.
   413  // cipherDer is SM9PublicKey1 format according SM9 cryptographic algorithm application specification.
   414  func (priv *EncryptPrivateKey) UnwrapKey(uid, cipherDer []byte, kLen int) ([]byte, error) {
   415  	var bytes []byte
   416  	input := cryptobyte.String(cipherDer)
   417  	if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() {
   418  		return nil, ErrDecryption
   419  	}
   420  	g, err := unmarshalG1(bytes)
   421  	if err != nil {
   422  		return nil, ErrDecryption
   423  	}
   424  	return UnwrapKey(priv, uid, g, kLen)
   425  }
   426  
   427  // Encrypt encrypts plaintext, returns ciphertext with format C1||C3||C2.
   428  func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts EncrypterOpts) ([]byte, error) {
   429  	c1, c2, c3, err := encrypt(rand, pub, uid, hid, plaintext, opts)
   430  	if err != nil {
   431  		return nil, err
   432  	}
   433  	ciphertext := append(c1.Marshal(), c3...)
   434  	ciphertext = append(ciphertext, c2...)
   435  	return ciphertext, nil
   436  }
   437  
   438  func encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts EncrypterOpts) (c1 *bn256.G1, c2, c3 []byte, err error) {
   439  	if opts == nil {
   440  		opts = DefaultEncrypterOpts
   441  	}
   442  	if len(plaintext) == 0 {
   443  		return nil, nil, nil, ErrEmptyPlaintext
   444  	}
   445  	key1Len := opts.GetKeySize(plaintext)
   446  	key, c1, err := WrapKey(rand, pub, uid, hid, key1Len+sm3.Size)
   447  	if err != nil {
   448  		return nil, nil, nil, err
   449  	}
   450  	c2, err = opts.Encrypt(rand, key[:key1Len], plaintext)
   451  	if err != nil {
   452  		return nil, nil, nil, err
   453  	}
   454  
   455  	hash := sm3.New()
   456  	hash.Write(c2)
   457  	hash.Write(key[key1Len:])
   458  	c3 = hash.Sum(nil)
   459  
   460  	return
   461  }
   462  
   463  // EncryptASN1 encrypts plaintext and returns ciphertext with ASN.1 format according
   464  // SM9 cryptographic algorithm application specification, SM9Cipher definition.
   465  func EncryptASN1(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts EncrypterOpts) ([]byte, error) {
   466  	return pub.Encrypt(rand, uid, hid, plaintext, opts)
   467  }
   468  
   469  // Encrypt encrypts plaintext and returns ciphertext with ASN.1 format according
   470  // SM9 cryptographic algorithm application specification, SM9Cipher definition.
   471  func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, plaintext []byte, opts EncrypterOpts) ([]byte, error) {
   472  	if opts == nil {
   473  		opts = DefaultEncrypterOpts
   474  	}
   475  	c1, c2, c3, err := encrypt(rand, pub, uid, hid, plaintext, opts)
   476  	if err != nil {
   477  		return nil, err
   478  	}
   479  
   480  	var b cryptobyte.Builder
   481  	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
   482  		b.AddASN1Int64(int64(opts.GetEncryptType()))
   483  		b.AddASN1BitString(c1.MarshalUncompressed())
   484  		b.AddASN1OctetString(c3)
   485  		b.AddASN1OctetString(c2)
   486  	})
   487  	return b.Bytes()
   488  }
   489  
   490  // Decrypt decrypts chipher, the ciphertext should be with format C1||C3||C2
   491  func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts) ([]byte, error) {
   492  	if opts == nil {
   493  		opts = DefaultEncrypterOpts
   494  	}
   495  
   496  	c := &bn256.G1{}
   497  	c3c2, err := c.Unmarshal(ciphertext)
   498  	if err != nil {
   499  		return nil, ErrDecryption
   500  	}
   501  
   502  	_ = c3c2[sm3.Size] // bounds check elimination hint
   503  	c3 := c3c2[:sm3.Size]
   504  	c2 := c3c2[sm3.Size:]
   505  	key1Len := opts.GetKeySize(c2)
   506  
   507  	key, err := UnwrapKey(priv, uid, c, key1Len+sm3.Size)
   508  	if err != nil {
   509  		return nil, err
   510  	}
   511  	_ = key[key1Len] // bounds check elimination hint
   512  	return decrypt(c, key[:key1Len], key[key1Len:], c2, c3, opts)
   513  }
   514  
   515  func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) {
   516  	hash := sm3.New()
   517  	hash.Write(c2)
   518  	hash.Write(key2)
   519  	c32 := hash.Sum(nil)
   520  
   521  	if goSubtle.ConstantTimeCompare(c3, c32) != 1 {
   522  		return nil, ErrDecryption
   523  	}
   524  
   525  	return opts.Decrypt(key1, c2)
   526  }
   527  
   528  // DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according
   529  // SM9 cryptographic algorithm application specification, SM9Cipher definition.
   530  func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) {
   531  	if len(ciphertext) <= 32+65 {
   532  		return nil, errors.New("sm9: ciphertext too short")
   533  	}
   534  	var (
   535  		encType int
   536  		c3Bytes []byte
   537  		c1Bytes []byte
   538  		c2Bytes []byte
   539  		inner   cryptobyte.String
   540  	)
   541  	input := cryptobyte.String(ciphertext)
   542  	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
   543  		!input.Empty() ||
   544  		!inner.ReadASN1Integer(&encType) ||
   545  		!inner.ReadASN1BitStringAsBytes(&c1Bytes) ||
   546  		!inner.ReadASN1Bytes(&c3Bytes, asn1.OCTET_STRING) ||
   547  		!inner.ReadASN1Bytes(&c2Bytes, asn1.OCTET_STRING) ||
   548  		!inner.Empty() {
   549  		return nil, errors.New("sm9: invalid ciphertext asn.1 data")
   550  	}
   551  	// We just make assumption block cipher is SM4 and padding scheme is pkcs7
   552  	opts := shangMiEncrypterOpts(encryptType(encType))
   553  	if opts == nil {
   554  		return nil, ErrDecryption
   555  	}
   556  	c, err := unmarshalG1(c1Bytes)
   557  	if err != nil {
   558  		return nil, ErrDecryption
   559  	}
   560  
   561  	key1Len := opts.GetKeySize(c2Bytes)
   562  	key, err := UnwrapKey(priv, uid, c, key1Len+sm3.Size)
   563  	if err != nil {
   564  		return nil, err
   565  	}
   566  
   567  	_ = key[key1Len] // bounds check elimination hint
   568  	return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts)
   569  }
   570  
   571  // Decrypt decrypts chipher, the ciphertext should be with format C1||C3||C2
   572  func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte, opts EncrypterOpts) ([]byte, error) {
   573  	return Decrypt(priv, uid, ciphertext, opts)
   574  }
   575  
   576  // DecryptASN1 decrypts chipher, the ciphertext should be with ASN.1 format according
   577  // SM9 cryptographic algorithm application specification, SM9Cipher definition.
   578  func (priv *EncryptPrivateKey) DecryptASN1(uid, ciphertext []byte) ([]byte, error) {
   579  	return DecryptASN1(priv, uid, ciphertext)
   580  }
   581  
   582  // KeyExchange represents key exchange struct, include internal stat in whole key exchange flow.
   583  // Initiator's flow will be: NewKeyExchange -> InitKeyExchange -> transmission -> ConfirmResponder
   584  // Responder's flow will be: NewKeyExchange -> waiting ... -> RepondKeyExchange -> transmission -> ConfirmInitiator
   585  type KeyExchange struct {
   586  	genSignature bool               // control the optional sign/verify step triggered by responsder
   587  	keyLength    int                // key length
   588  	privateKey   *EncryptPrivateKey // owner's encryption private key
   589  	uid          []byte             // owner uid
   590  	peerUID      []byte             // peer uid
   591  	r            *bigmod.Nat        // random which will be used to compute secret
   592  	secret       *bn256.G1          // generated secret which will be passed to peer
   593  	peerSecret   *bn256.G1          // received peer's secret
   594  	g1           *bn256.GT          // internal state which will be used when compute the key and signature
   595  	g2           *bn256.GT          // internal state which will be used when compute the key and signature
   596  	g3           *bn256.GT          // internal state which will be used when compute the key and signature
   597  }
   598  
   599  // NewKeyExchange creates one new KeyExchange object
   600  func NewKeyExchange(priv *EncryptPrivateKey, uid, peerUID []byte, keyLen int, genSignature bool) *KeyExchange {
   601  	ke := &KeyExchange{}
   602  	ke.genSignature = genSignature
   603  	ke.keyLength = keyLen
   604  	ke.privateKey = priv
   605  	ke.uid = uid
   606  	ke.peerUID = peerUID
   607  	return ke
   608  }
   609  
   610  // Destroy clears all internal state and Ephemeral private/public keys
   611  func (ke *KeyExchange) Destroy() {
   612  	if ke.r != nil {
   613  		ke.r.SetBytes([]byte{0}, orderNat)
   614  	}
   615  	if ke.g1 != nil {
   616  		ke.g1.SetOne()
   617  	}
   618  	if ke.g2 != nil {
   619  		ke.g2.SetOne()
   620  	}
   621  	if ke.g3 != nil {
   622  		ke.g3.SetOne()
   623  	}
   624  }
   625  
   626  func initKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat) {
   627  	pubB := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid)
   628  	ke.r = r
   629  	rA, err := new(bn256.G1).ScalarMult(pubB, ke.r.Bytes(orderNat))
   630  	if err != nil {
   631  		panic(err)
   632  	}
   633  	ke.secret = rA
   634  }
   635  
   636  // InitKeyExchange generates random with responder uid, for initiator's step A1-A4
   637  func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) (*bn256.G1, error) {
   638  	r, err := randomScalar(rand)
   639  	if err != nil {
   640  		return nil, err
   641  	}
   642  	initKeyExchange(ke, hid, r)
   643  	return ke.secret, nil
   644  }
   645  
   646  func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
   647  	var buffer []byte
   648  	hash := sm3.New()
   649  	hash.Write(ke.g2.Marshal())
   650  	hash.Write(ke.g3.Marshal())
   651  	if isResponder {
   652  		hash.Write(ke.peerUID)
   653  		hash.Write(ke.uid)
   654  		hash.Write(ke.peerSecret.Marshal())
   655  		hash.Write(ke.secret.Marshal())
   656  	} else {
   657  		hash.Write(ke.uid)
   658  		hash.Write(ke.peerUID)
   659  		hash.Write(ke.secret.Marshal())
   660  		hash.Write(ke.peerSecret.Marshal())
   661  	}
   662  	buffer = hash.Sum(nil)
   663  	hash.Reset()
   664  	hash.Write([]byte{prefix})
   665  	hash.Write(ke.g1.Marshal())
   666  	hash.Write(buffer)
   667  	return hash.Sum(nil)
   668  }
   669  
   670  func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
   671  	var buffer []byte
   672  	if isResponder {
   673  		buffer = append(buffer, ke.peerUID...)
   674  		buffer = append(buffer, ke.uid...)
   675  		buffer = append(buffer, ke.peerSecret.Marshal()...)
   676  		buffer = append(buffer, ke.secret.Marshal()...)
   677  	} else {
   678  		buffer = append(buffer, ke.uid...)
   679  		buffer = append(buffer, ke.peerUID...)
   680  		buffer = append(buffer, ke.secret.Marshal()...)
   681  		buffer = append(buffer, ke.peerSecret.Marshal()...)
   682  	}
   683  	buffer = append(buffer, ke.g1.Marshal()...)
   684  	buffer = append(buffer, ke.g2.Marshal()...)
   685  	buffer = append(buffer, ke.g3.Marshal()...)
   686  
   687  	return sm3.Kdf(buffer, ke.keyLength), nil
   688  }
   689  
   690  func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {
   691  	if !rA.IsOnCurve() {
   692  		return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key")
   693  	}
   694  	ke.peerSecret = rA
   695  	pubA := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid)
   696  	ke.r = r
   697  	rBytes := r.Bytes(orderNat)
   698  	rB, err := new(bn256.G1).ScalarMult(pubA, rBytes)
   699  	if err != nil {
   700  		return nil, nil, err
   701  	}
   702  	ke.secret = rB
   703  
   704  	ke.g1 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey)
   705  	ke.g3 = &bn256.GT{}
   706  	g3, err := bn256.ScalarMultGT(ke.g1, rBytes)
   707  	if err != nil {
   708  		return nil, nil, err
   709  	}
   710  	ke.g3 = g3
   711  
   712  	g2, err := ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(rBytes)
   713  	if err != nil {
   714  		return nil, nil, err
   715  	}
   716  	ke.g2 = g2
   717  
   718  	if !ke.genSignature {
   719  		return ke.secret, nil, nil
   720  	}
   721  
   722  	return ke.secret, ke.sign(true, 0x82), nil
   723  }
   724  
   725  // RepondKeyExchange when responder receive rA, for responder's step B1-B7
   726  func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, hid byte, rA *bn256.G1) (*bn256.G1, []byte, error) {
   727  	r, err := randomScalar(rand)
   728  	if err != nil {
   729  		return nil, nil, err
   730  	}
   731  	return respondKeyExchange(ke, hid, r, rA)
   732  }
   733  
   734  // ConfirmResponder for initiator's step A5-A7
   735  func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, []byte, error) {
   736  	if !rB.IsOnCurve() {
   737  		return nil, nil, errors.New("sm9: invalid responder's ephemeral public key")
   738  	}
   739  	// step 5
   740  	ke.peerSecret = rB
   741  	g1, err := ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r.Bytes(orderNat))
   742  	if err != nil {
   743  		return nil, nil, err
   744  	}
   745  	ke.g1 = g1
   746  	ke.g2 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey)
   747  	ke.g3 = &bn256.GT{}
   748  	g3, err := bn256.ScalarMultGT(ke.g2, ke.r.Bytes(orderNat))
   749  	if err != nil {
   750  		return nil, nil, err
   751  	}
   752  	ke.g3 = g3
   753  	// step 6, verify signature
   754  	if len(sB) > 0 {
   755  		signature := ke.sign(false, 0x82)
   756  		if goSubtle.ConstantTimeCompare(signature, sB) != 1 {
   757  			return nil, nil, errors.New("sm9: invalid responder's signature")
   758  		}
   759  	}
   760  	key, err := ke.generateSharedKey(false)
   761  	if err != nil {
   762  		return nil, nil, err
   763  	}
   764  	if !ke.genSignature {
   765  		return key, nil, nil
   766  	}
   767  	return key, ke.sign(false, 0x83), nil
   768  }
   769  
   770  // ConfirmInitiator for responder's step B8
   771  func (ke *KeyExchange) ConfirmInitiator(s1 []byte) ([]byte, error) {
   772  	if s1 != nil {
   773  		buffer := ke.sign(true, 0x83)
   774  		if goSubtle.ConstantTimeCompare(buffer, s1) != 1 {
   775  			return nil, errors.New("sm9: invalid initiator's signature")
   776  		}
   777  	}
   778  	return ke.generateSharedKey(true)
   779  }