github.com/emmansun/gmsm@v0.29.1/sm2/sm2_legacy.go (about)

     1  package sm2
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	_subtle "crypto/subtle"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"math/big"
    11  	"strings"
    12  
    13  	"github.com/emmansun/gmsm/internal/subtle"
    14  	"github.com/emmansun/gmsm/sm2/sm2ec"
    15  	"github.com/emmansun/gmsm/sm3"
    16  	"golang.org/x/crypto/cryptobyte"
    17  	"golang.org/x/crypto/cryptobyte/asn1"
    18  )
    19  
    20  // This file contains a math/big implementation of SM2 DSA/Encryption that is only used for
    21  // deprecated custom curves.
    22  
    23  // A invertible implements fast inverse in GF(N).
    24  type invertible interface {
    25  	// Inverse returns the inverse of k mod Params().N.
    26  	Inverse(k *big.Int) *big.Int
    27  }
    28  
    29  // A combinedMult implements fast combined multiplication for verification.
    30  type combinedMult interface {
    31  	// CombinedMult returns [s1]G + [s2]P where G is the generator.
    32  	CombinedMult(bigX, bigY *big.Int, baseScalar, scalar []byte) (x, y *big.Int)
    33  }
    34  
    35  // hashToInt converts a hash value to an integer. Per FIPS 186-4, Section 6.4,
    36  // we use the left-most bits of the hash to match the bit-length of the order of
    37  // the curve. This also performs Step 5 of SEC 1, Version 2.0, Section 4.1.3.
    38  func hashToInt(hash []byte, c elliptic.Curve) *big.Int {
    39  	orderBits := c.Params().N.BitLen()
    40  	orderBytes := (orderBits + 7) / 8
    41  	if len(hash) > orderBytes {
    42  		hash = hash[:orderBytes]
    43  	}
    44  
    45  	ret := new(big.Int).SetBytes(hash)
    46  	excess := len(hash)*8 - orderBits
    47  	if excess > 0 {
    48  		ret.Rsh(ret, uint(excess))
    49  	}
    50  	return ret
    51  }
    52  
    53  var errZeroParam = errors.New("zero parameter")
    54  
    55  // Sign signs a hash (which should be the result of hashing a larger message)
    56  // using the private key, priv. If the hash is longer than the bit-length of the
    57  // private key's curve order, the hash will be truncated to that length. It
    58  // returns the signature as a pair of integers. Most applications should use
    59  // SignASN1 instead of dealing directly with r, s.
    60  //
    61  // Compliance with GB/T 32918.2-2016 regardless it's SM2 curve or not.
    62  func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, err error) {
    63  	key := new(PrivateKey)
    64  	key.PrivateKey = *priv
    65  	sig, err := SignASN1(rand, key, hash, nil)
    66  	if err != nil {
    67  		return nil, nil, err
    68  	}
    69  
    70  	r, s = new(big.Int), new(big.Int)
    71  	var inner cryptobyte.String
    72  	input := cryptobyte.String(sig)
    73  	if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
    74  		!input.Empty() ||
    75  		!inner.ReadASN1Integer(r) ||
    76  		!inner.ReadASN1Integer(s) ||
    77  		!inner.Empty() {
    78  		return nil, nil, errors.New("invalid ASN.1 from SignASN1")
    79  	}
    80  	return r, s, nil
    81  }
    82  
    83  func signLegacy(priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) {
    84  	// See [NSA] 3.4.1
    85  	c := priv.PublicKey.Curve
    86  	N := c.Params().N
    87  	if N.Sign() == 0 {
    88  		return nil, errZeroParam
    89  	}
    90  	var k, r, s *big.Int
    91  	e := hashToInt(hash, c)
    92  	for {
    93  		for {
    94  			k, err = randFieldElement(c, rand)
    95  			if err != nil {
    96  				return nil, err
    97  			}
    98  
    99  			r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) // (x, y) = k*G
   100  			r.Add(r, e)                                 // r = x + e
   101  			r.Mod(r, N)                                 // r = (x + e) mod N
   102  			if r.Sign() != 0 {
   103  				t := new(big.Int).Add(r, k)
   104  				if t.Cmp(N) != 0 { // if r != 0 && (r + k) != N then ok
   105  					break
   106  				}
   107  			}
   108  		}
   109  		s = new(big.Int).Mul(priv.D, r)
   110  		s = new(big.Int).Sub(k, s)
   111  		dp1 := new(big.Int).Add(priv.D, one)
   112  
   113  		var dp1Inv *big.Int
   114  
   115  		if in, ok := priv.Curve.(invertible); ok {
   116  			dp1Inv = in.Inverse(dp1)
   117  		} else {
   118  			dp1Inv = fermatInverse(dp1, N) // N != 0
   119  		}
   120  
   121  		s.Mul(s, dp1Inv)
   122  		s.Mod(s, N) // N != 0
   123  		if s.Sign() != 0 {
   124  			break
   125  		}
   126  	}
   127  
   128  	return encodeSignature(r.Bytes(), s.Bytes())
   129  }
   130  
   131  // fermatInverse calculates the inverse of k in GF(P) using Fermat's method
   132  // (exponentiation modulo P - 2, per Euler's theorem). This has better
   133  // constant-time properties than Euclid's method (implemented in
   134  // math/big.Int.ModInverse and FIPS 186-4, Appendix C.1) although math/big
   135  // itself isn't strictly constant-time so it's not perfect.
   136  func fermatInverse(k, N *big.Int) *big.Int {
   137  	two := big.NewInt(2)
   138  	nMinus2 := new(big.Int).Sub(N, two)
   139  	return new(big.Int).Exp(k, nMinus2, N)
   140  }
   141  
   142  // SignWithSM2 follow sm2 dsa standards for hash part, compliance with GB/T 32918.2-2016.
   143  func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s *big.Int, err error) {
   144  	digest, err := CalculateSM2Hash(&priv.PublicKey, msg, uid)
   145  	if err != nil {
   146  		return nil, nil, err
   147  	}
   148  
   149  	return Sign(rand, priv, digest)
   150  }
   151  
   152  // Verify verifies the signature in r, s of hash using the public key, pub. Its
   153  // return value records whether the signature is valid. Most applications should
   154  // use VerifyASN1 instead of dealing directly with r, s.
   155  //
   156  // Compliance with GB/T 32918.2-2016 regardless it's SM2 curve or not.
   157  // Caller should make sure the hash's correctness.
   158  func Verify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool {
   159  	if r.Sign() <= 0 || s.Sign() <= 0 {
   160  		return false
   161  	}
   162  	sig, err := encodeSignature(r.Bytes(), s.Bytes())
   163  	if err != nil {
   164  		return false
   165  	}
   166  	return VerifyASN1(pub, hash, sig)
   167  }
   168  
   169  func verifyLegacy(pub *ecdsa.PublicKey, hash, sig []byte) bool {
   170  	rBytes, sBytes, err := parseSignature(sig)
   171  	if err != nil {
   172  		return false
   173  	}
   174  	r, s := new(big.Int).SetBytes(rBytes), new(big.Int).SetBytes(sBytes)
   175  
   176  	c := pub.Curve
   177  	N := c.Params().N
   178  
   179  	if r.Sign() <= 0 || s.Sign() <= 0 {
   180  		return false
   181  	}
   182  	if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
   183  		return false
   184  	}
   185  	e := hashToInt(hash, c)
   186  	t := new(big.Int).Add(r, s)
   187  	t.Mod(t, N)
   188  	if t.Sign() == 0 {
   189  		return false
   190  	}
   191  
   192  	var x *big.Int
   193  	if opt, ok := c.(combinedMult); ok {
   194  		x, _ = opt.CombinedMult(pub.X, pub.Y, s.Bytes(), t.Bytes())
   195  	} else {
   196  		x1, y1 := c.ScalarBaseMult(s.Bytes())
   197  		x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
   198  		x, _ = c.Add(x1, y1, x2, y2)
   199  	}
   200  
   201  	x.Add(x, e)
   202  	x.Mod(x, N)
   203  	return x.Cmp(r) == 0
   204  }
   205  
   206  // VerifyWithSM2 verifies the signature in r, s of raw msg and uid using the public key, pub.
   207  // It returns value records whether the signature is valid. Compliance with GB/T 32918.2-2016.
   208  func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool {
   209  	digest, err := CalculateSM2Hash(pub, msg, uid)
   210  	if err != nil {
   211  		return false
   212  	}
   213  	return Verify(pub, digest, r, s)
   214  }
   215  
   216  var (
   217  	one = new(big.Int).SetInt64(1)
   218  )
   219  
   220  // randFieldElement returns a random element of the order of the given
   221  // curve using the procedure given in FIPS 186-4, Appendix B.5.2.
   222  func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
   223  	// See randomPoint for notes on the algorithm. This has to match, or s390x
   224  	// signatures will come out different from other architectures, which will
   225  	// break TLS recorded tests.
   226  	for {
   227  		N := c.Params().N
   228  		b := make([]byte, (N.BitLen()+7)/8)
   229  		if _, err = io.ReadFull(rand, b); err != nil {
   230  			return
   231  		}
   232  		if excess := len(b)*8 - N.BitLen(); excess > 0 {
   233  			b[0] >>= excess
   234  		}
   235  		k = new(big.Int).SetBytes(b)
   236  		if k.Sign() != 0 && k.Cmp(N) < 0 {
   237  			return
   238  		}
   239  	}
   240  }
   241  
   242  func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
   243  	curve := pub.Curve
   244  	msgLen := len(msg)
   245  
   246  	var retryCount int = 0
   247  	for {
   248  		//A1, generate random k
   249  		k, err := randFieldElement(curve, random)
   250  		if err != nil {
   251  			return nil, err
   252  		}
   253  
   254  		//A2, calculate C1 = k * G
   255  		x1, y1 := curve.ScalarBaseMult(k.Bytes())
   256  		c1 := opts.pointMarshalMode.mashal(curve, x1, y1)
   257  
   258  		//A4, calculate k * P (point of Public Key)
   259  		x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
   260  
   261  		//A5, calculate t=KDF(x2||y2, klen)
   262  		c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
   263  		if subtle.ConstantTimeAllZero(c2) == 1 {
   264  			retryCount++
   265  			if retryCount > maxRetryLimit {
   266  				return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount)
   267  			}
   268  			continue
   269  		}
   270  
   271  		//A6, C2 = M + t;
   272  		subtle.XORBytes(c2, msg, c2)
   273  
   274  		//A7, C3 = hash(x2||M||y2)
   275  		c3 := calculateC3(curve, x2, y2, msg)
   276  
   277  		if opts.ciphertextEncoding == ENCODING_PLAIN {
   278  			if opts.ciphertextSplicingOrder == C1C3C2 {
   279  				// c1 || c3 || c2
   280  				return append(append(c1, c3...), c2...), nil
   281  			}
   282  			// c1 || c2 || c3
   283  			return append(append(c1, c2...), c3...), nil
   284  		}
   285  		// ASN.1 format will force C3 C2 order
   286  		return mashalASN1Ciphertext(x1, y1, c2, c3)
   287  	}
   288  }
   289  
   290  func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
   291  	md := sm3.New()
   292  	md.Write(toBytes(curve, x2))
   293  	md.Write(msg)
   294  	md.Write(toBytes(curve, y2))
   295  	return md.Sum(nil)
   296  }
   297  
   298  func mashalASN1Ciphertext(x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) {
   299  	var b cryptobyte.Builder
   300  	b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
   301  		b.AddASN1BigInt(x1)
   302  		b.AddASN1BigInt(y1)
   303  		b.AddASN1OctetString(c3)
   304  		b.AddASN1OctetString(c2)
   305  	})
   306  	return b.Bytes()
   307  }
   308  
   309  // ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format
   310  func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
   311  	if opts == nil {
   312  		opts = defaultEncrypterOpts
   313  	}
   314  	x1, y1, c2, c3, err := unmarshalASN1Ciphertext((ciphertext))
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  	curve := sm2ec.P256()
   319  	c1 := opts.pointMarshalMode.mashal(curve, x1, y1)
   320  	if opts.ciphertextSplicingOrder == C1C3C2 {
   321  		// c1 || c3 || c2
   322  		return append(append(c1, c3...), c2...), nil
   323  	}
   324  	// c1 || c2 || c3
   325  	return append(append(c1, c2...), c3...), nil
   326  }
   327  
   328  // PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format
   329  func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) {
   330  	if ciphertext[0] == 0x30 {
   331  		return nil, errors.New("sm2: invalid plain encoding ciphertext")
   332  	}
   333  	curve := sm2ec.P256()
   334  	ciphertextLen := len(ciphertext)
   335  	if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
   336  		return nil, errCiphertextTooShort
   337  	}
   338  	// get C1, and check C1
   339  	x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	var c2, c3 []byte
   345  
   346  	if from == C1C3C2 {
   347  		c2 = ciphertext[c3Start+sm3.Size:]
   348  		c3 = ciphertext[c3Start : c3Start+sm3.Size]
   349  	} else {
   350  		c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
   351  		c3 = ciphertext[ciphertextLen-sm3.Size:]
   352  	}
   353  	return mashalASN1Ciphertext(x1, y1, c2, c3)
   354  }
   355  
   356  // AdjustCiphertextSplicingOrder utility method to change c2 c3 order
   357  func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) {
   358  	curve := sm2ec.P256()
   359  	if from == to {
   360  		return ciphertext, nil
   361  	}
   362  	ciphertextLen := len(ciphertext)
   363  	if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
   364  		return nil, errCiphertextTooShort
   365  	}
   366  
   367  	// get C1, and check C1
   368  	_, _, c3Start, err := bytes2Point(curve, ciphertext)
   369  	if err != nil {
   370  		return nil, err
   371  	}
   372  
   373  	var c1, c2, c3 []byte
   374  
   375  	c1 = ciphertext[:c3Start]
   376  	if from == C1C3C2 {
   377  		c2 = ciphertext[c3Start+sm3.Size:]
   378  		c3 = ciphertext[c3Start : c3Start+sm3.Size]
   379  	} else {
   380  		c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
   381  		c3 = ciphertext[ciphertextLen-sm3.Size:]
   382  	}
   383  
   384  	result := make([]byte, ciphertextLen)
   385  	copy(result, c1)
   386  	if to == C1C3C2 {
   387  		// c1 || c3 || c2
   388  		copy(result[c3Start:], c3)
   389  		copy(result[c3Start+sm3.Size:], c2)
   390  	} else {
   391  		// c1 || c2 || c3
   392  		copy(result[c3Start:], c2)
   393  		copy(result[ciphertextLen-sm3.Size:], c3)
   394  	}
   395  	return result, nil
   396  }
   397  
   398  func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
   399  	x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
   400  	if err != nil {
   401  		return nil, ErrDecryption
   402  	}
   403  	return rawDecrypt(priv, x1, y1, c2, c3)
   404  }
   405  
   406  func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) {
   407  	curve := priv.Curve
   408  	x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
   409  	msgLen := len(c2)
   410  	msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
   411  	if subtle.ConstantTimeAllZero(c2) == 1 {
   412  		return nil, ErrDecryption
   413  	}
   414  
   415  	//B5, calculate msg = c2 ^ t
   416  	subtle.XORBytes(msg, c2, msg)
   417  
   418  	u := calculateC3(curve, x2, y2, msg)
   419  	if _subtle.ConstantTimeCompare(u, c3) == 1 {
   420  		return msg, nil
   421  	}
   422  	return nil, ErrDecryption
   423  }
   424  
   425  func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
   426  	splicingOrder := C1C3C2
   427  	if opts != nil {
   428  		if opts.ciphertextEncoding == ENCODING_ASN1 {
   429  			return decryptASN1(priv, ciphertext)
   430  		}
   431  		splicingOrder = opts.cipherTextSplicingOrder
   432  	}
   433  	if ciphertext[0] == 0x30 {
   434  		return decryptASN1(priv, ciphertext)
   435  	}
   436  	ciphertextLen := len(ciphertext)
   437  	curve := priv.Curve
   438  	// B1, get C1, and check C1
   439  	x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
   440  	if err != nil {
   441  		return nil, ErrDecryption
   442  	}
   443  
   444  	//B4, calculate t=KDF(x2||y2, klen)
   445  	var c2, c3 []byte
   446  	if splicingOrder == C1C3C2 {
   447  		c2 = ciphertext[c3Start+sm3.Size:]
   448  		c3 = ciphertext[c3Start : c3Start+sm3.Size]
   449  	} else {
   450  		c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
   451  		c3 = ciphertext[ciphertextLen-sm3.Size:]
   452  	}
   453  
   454  	return rawDecrypt(priv, x1, y1, c2, c3)
   455  }
   456  
   457  func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
   458  	if len(bytes) < 1+(curve.Params().BitSize/8) {
   459  		return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes))
   460  	}
   461  	format := bytes[0]
   462  	byteLen := (curve.Params().BitSize + 7) >> 3
   463  	switch format {
   464  	case uncompressed, hybrid06, hybrid07: // what's the hybrid format purpose?
   465  		if len(bytes) < 1+byteLen*2 {
   466  			return nil, nil, 0, fmt.Errorf("sm2: invalid point uncompressed/hybrid form bytes length %d", len(bytes))
   467  		}
   468  		data := make([]byte, 1+byteLen*2)
   469  		data[0] = uncompressed
   470  		copy(data[1:], bytes[1:1+byteLen*2])
   471  		x, y := sm2ec.Unmarshal(curve, data)
   472  		if x == nil || y == nil {
   473  			return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name)
   474  		}
   475  		return x, y, 1 + byteLen*2, nil
   476  	case compressed02, compressed03:
   477  		if len(bytes) < 1+byteLen {
   478  			return nil, nil, 0, fmt.Errorf("sm2: invalid point compressed form bytes length %d", len(bytes))
   479  		}
   480  		// Make sure it's NIST curve or SM2 P-256 curve
   481  		if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, sm2ec.P256().Params().Name) {
   482  			// y² = x³ - 3x + b, prime curves
   483  			x, y := sm2ec.UnmarshalCompressed(curve, bytes[:1+byteLen])
   484  			if x == nil || y == nil {
   485  				return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name)
   486  			}
   487  			return x, y, 1 + byteLen, nil
   488  		}
   489  		return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name)
   490  	}
   491  	return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format)
   492  }
   493  
   494  func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
   495  	switch mode {
   496  	case MarshalCompressed:
   497  		return elliptic.MarshalCompressed(curve, x, y)
   498  	case MarshalHybrid:
   499  		buffer := elliptic.Marshal(curve, x, y)
   500  		buffer[0] = byte(y.Bit(0)) | hybrid06
   501  		return buffer
   502  	default:
   503  		return elliptic.Marshal(curve, x, y)
   504  	}
   505  }