github.com/emmansun/gmsm@v0.29.1/sm2/sm2_legacy.go (about) 1 package sm2 2 3 import ( 4 "crypto/ecdsa" 5 "crypto/elliptic" 6 _subtle "crypto/subtle" 7 "errors" 8 "fmt" 9 "io" 10 "math/big" 11 "strings" 12 13 "github.com/emmansun/gmsm/internal/subtle" 14 "github.com/emmansun/gmsm/sm2/sm2ec" 15 "github.com/emmansun/gmsm/sm3" 16 "golang.org/x/crypto/cryptobyte" 17 "golang.org/x/crypto/cryptobyte/asn1" 18 ) 19 20 // This file contains a math/big implementation of SM2 DSA/Encryption that is only used for 21 // deprecated custom curves. 22 23 // A invertible implements fast inverse in GF(N). 24 type invertible interface { 25 // Inverse returns the inverse of k mod Params().N. 26 Inverse(k *big.Int) *big.Int 27 } 28 29 // A combinedMult implements fast combined multiplication for verification. 30 type combinedMult interface { 31 // CombinedMult returns [s1]G + [s2]P where G is the generator. 32 CombinedMult(bigX, bigY *big.Int, baseScalar, scalar []byte) (x, y *big.Int) 33 } 34 35 // hashToInt converts a hash value to an integer. Per FIPS 186-4, Section 6.4, 36 // we use the left-most bits of the hash to match the bit-length of the order of 37 // the curve. This also performs Step 5 of SEC 1, Version 2.0, Section 4.1.3. 38 func hashToInt(hash []byte, c elliptic.Curve) *big.Int { 39 orderBits := c.Params().N.BitLen() 40 orderBytes := (orderBits + 7) / 8 41 if len(hash) > orderBytes { 42 hash = hash[:orderBytes] 43 } 44 45 ret := new(big.Int).SetBytes(hash) 46 excess := len(hash)*8 - orderBits 47 if excess > 0 { 48 ret.Rsh(ret, uint(excess)) 49 } 50 return ret 51 } 52 53 var errZeroParam = errors.New("zero parameter") 54 55 // Sign signs a hash (which should be the result of hashing a larger message) 56 // using the private key, priv. If the hash is longer than the bit-length of the 57 // private key's curve order, the hash will be truncated to that length. It 58 // returns the signature as a pair of integers. Most applications should use 59 // SignASN1 instead of dealing directly with r, s. 60 // 61 // Compliance with GB/T 32918.2-2016 regardless it's SM2 curve or not. 62 func Sign(rand io.Reader, priv *ecdsa.PrivateKey, hash []byte) (r, s *big.Int, err error) { 63 key := new(PrivateKey) 64 key.PrivateKey = *priv 65 sig, err := SignASN1(rand, key, hash, nil) 66 if err != nil { 67 return nil, nil, err 68 } 69 70 r, s = new(big.Int), new(big.Int) 71 var inner cryptobyte.String 72 input := cryptobyte.String(sig) 73 if !input.ReadASN1(&inner, asn1.SEQUENCE) || 74 !input.Empty() || 75 !inner.ReadASN1Integer(r) || 76 !inner.ReadASN1Integer(s) || 77 !inner.Empty() { 78 return nil, nil, errors.New("invalid ASN.1 from SignASN1") 79 } 80 return r, s, nil 81 } 82 83 func signLegacy(priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) { 84 // See [NSA] 3.4.1 85 c := priv.PublicKey.Curve 86 N := c.Params().N 87 if N.Sign() == 0 { 88 return nil, errZeroParam 89 } 90 var k, r, s *big.Int 91 e := hashToInt(hash, c) 92 for { 93 for { 94 k, err = randFieldElement(c, rand) 95 if err != nil { 96 return nil, err 97 } 98 99 r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) // (x, y) = k*G 100 r.Add(r, e) // r = x + e 101 r.Mod(r, N) // r = (x + e) mod N 102 if r.Sign() != 0 { 103 t := new(big.Int).Add(r, k) 104 if t.Cmp(N) != 0 { // if r != 0 && (r + k) != N then ok 105 break 106 } 107 } 108 } 109 s = new(big.Int).Mul(priv.D, r) 110 s = new(big.Int).Sub(k, s) 111 dp1 := new(big.Int).Add(priv.D, one) 112 113 var dp1Inv *big.Int 114 115 if in, ok := priv.Curve.(invertible); ok { 116 dp1Inv = in.Inverse(dp1) 117 } else { 118 dp1Inv = fermatInverse(dp1, N) // N != 0 119 } 120 121 s.Mul(s, dp1Inv) 122 s.Mod(s, N) // N != 0 123 if s.Sign() != 0 { 124 break 125 } 126 } 127 128 return encodeSignature(r.Bytes(), s.Bytes()) 129 } 130 131 // fermatInverse calculates the inverse of k in GF(P) using Fermat's method 132 // (exponentiation modulo P - 2, per Euler's theorem). This has better 133 // constant-time properties than Euclid's method (implemented in 134 // math/big.Int.ModInverse and FIPS 186-4, Appendix C.1) although math/big 135 // itself isn't strictly constant-time so it's not perfect. 136 func fermatInverse(k, N *big.Int) *big.Int { 137 two := big.NewInt(2) 138 nMinus2 := new(big.Int).Sub(N, two) 139 return new(big.Int).Exp(k, nMinus2, N) 140 } 141 142 // SignWithSM2 follow sm2 dsa standards for hash part, compliance with GB/T 32918.2-2016. 143 func SignWithSM2(rand io.Reader, priv *ecdsa.PrivateKey, uid, msg []byte) (r, s *big.Int, err error) { 144 digest, err := CalculateSM2Hash(&priv.PublicKey, msg, uid) 145 if err != nil { 146 return nil, nil, err 147 } 148 149 return Sign(rand, priv, digest) 150 } 151 152 // Verify verifies the signature in r, s of hash using the public key, pub. Its 153 // return value records whether the signature is valid. Most applications should 154 // use VerifyASN1 instead of dealing directly with r, s. 155 // 156 // Compliance with GB/T 32918.2-2016 regardless it's SM2 curve or not. 157 // Caller should make sure the hash's correctness. 158 func Verify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) bool { 159 if r.Sign() <= 0 || s.Sign() <= 0 { 160 return false 161 } 162 sig, err := encodeSignature(r.Bytes(), s.Bytes()) 163 if err != nil { 164 return false 165 } 166 return VerifyASN1(pub, hash, sig) 167 } 168 169 func verifyLegacy(pub *ecdsa.PublicKey, hash, sig []byte) bool { 170 rBytes, sBytes, err := parseSignature(sig) 171 if err != nil { 172 return false 173 } 174 r, s := new(big.Int).SetBytes(rBytes), new(big.Int).SetBytes(sBytes) 175 176 c := pub.Curve 177 N := c.Params().N 178 179 if r.Sign() <= 0 || s.Sign() <= 0 { 180 return false 181 } 182 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { 183 return false 184 } 185 e := hashToInt(hash, c) 186 t := new(big.Int).Add(r, s) 187 t.Mod(t, N) 188 if t.Sign() == 0 { 189 return false 190 } 191 192 var x *big.Int 193 if opt, ok := c.(combinedMult); ok { 194 x, _ = opt.CombinedMult(pub.X, pub.Y, s.Bytes(), t.Bytes()) 195 } else { 196 x1, y1 := c.ScalarBaseMult(s.Bytes()) 197 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) 198 x, _ = c.Add(x1, y1, x2, y2) 199 } 200 201 x.Add(x, e) 202 x.Mod(x, N) 203 return x.Cmp(r) == 0 204 } 205 206 // VerifyWithSM2 verifies the signature in r, s of raw msg and uid using the public key, pub. 207 // It returns value records whether the signature is valid. Compliance with GB/T 32918.2-2016. 208 func VerifyWithSM2(pub *ecdsa.PublicKey, uid, msg []byte, r, s *big.Int) bool { 209 digest, err := CalculateSM2Hash(pub, msg, uid) 210 if err != nil { 211 return false 212 } 213 return Verify(pub, digest, r, s) 214 } 215 216 var ( 217 one = new(big.Int).SetInt64(1) 218 ) 219 220 // randFieldElement returns a random element of the order of the given 221 // curve using the procedure given in FIPS 186-4, Appendix B.5.2. 222 func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) { 223 // See randomPoint for notes on the algorithm. This has to match, or s390x 224 // signatures will come out different from other architectures, which will 225 // break TLS recorded tests. 226 for { 227 N := c.Params().N 228 b := make([]byte, (N.BitLen()+7)/8) 229 if _, err = io.ReadFull(rand, b); err != nil { 230 return 231 } 232 if excess := len(b)*8 - N.BitLen(); excess > 0 { 233 b[0] >>= excess 234 } 235 k = new(big.Int).SetBytes(b) 236 if k.Sign() != 0 && k.Cmp(N) < 0 { 237 return 238 } 239 } 240 } 241 242 func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) { 243 curve := pub.Curve 244 msgLen := len(msg) 245 246 var retryCount int = 0 247 for { 248 //A1, generate random k 249 k, err := randFieldElement(curve, random) 250 if err != nil { 251 return nil, err 252 } 253 254 //A2, calculate C1 = k * G 255 x1, y1 := curve.ScalarBaseMult(k.Bytes()) 256 c1 := opts.pointMarshalMode.mashal(curve, x1, y1) 257 258 //A4, calculate k * P (point of Public Key) 259 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) 260 261 //A5, calculate t=KDF(x2||y2, klen) 262 c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) 263 if subtle.ConstantTimeAllZero(c2) == 1 { 264 retryCount++ 265 if retryCount > maxRetryLimit { 266 return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount) 267 } 268 continue 269 } 270 271 //A6, C2 = M + t; 272 subtle.XORBytes(c2, msg, c2) 273 274 //A7, C3 = hash(x2||M||y2) 275 c3 := calculateC3(curve, x2, y2, msg) 276 277 if opts.ciphertextEncoding == ENCODING_PLAIN { 278 if opts.ciphertextSplicingOrder == C1C3C2 { 279 // c1 || c3 || c2 280 return append(append(c1, c3...), c2...), nil 281 } 282 // c1 || c2 || c3 283 return append(append(c1, c2...), c3...), nil 284 } 285 // ASN.1 format will force C3 C2 order 286 return mashalASN1Ciphertext(x1, y1, c2, c3) 287 } 288 } 289 290 func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { 291 md := sm3.New() 292 md.Write(toBytes(curve, x2)) 293 md.Write(msg) 294 md.Write(toBytes(curve, y2)) 295 return md.Sum(nil) 296 } 297 298 func mashalASN1Ciphertext(x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) { 299 var b cryptobyte.Builder 300 b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { 301 b.AddASN1BigInt(x1) 302 b.AddASN1BigInt(y1) 303 b.AddASN1OctetString(c3) 304 b.AddASN1OctetString(c2) 305 }) 306 return b.Bytes() 307 } 308 309 // ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format 310 func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { 311 if opts == nil { 312 opts = defaultEncrypterOpts 313 } 314 x1, y1, c2, c3, err := unmarshalASN1Ciphertext((ciphertext)) 315 if err != nil { 316 return nil, err 317 } 318 curve := sm2ec.P256() 319 c1 := opts.pointMarshalMode.mashal(curve, x1, y1) 320 if opts.ciphertextSplicingOrder == C1C3C2 { 321 // c1 || c3 || c2 322 return append(append(c1, c3...), c2...), nil 323 } 324 // c1 || c2 || c3 325 return append(append(c1, c2...), c3...), nil 326 } 327 328 // PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format 329 func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { 330 if ciphertext[0] == 0x30 { 331 return nil, errors.New("sm2: invalid plain encoding ciphertext") 332 } 333 curve := sm2ec.P256() 334 ciphertextLen := len(ciphertext) 335 if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { 336 return nil, errCiphertextTooShort 337 } 338 // get C1, and check C1 339 x1, y1, c3Start, err := bytes2Point(curve, ciphertext) 340 if err != nil { 341 return nil, err 342 } 343 344 var c2, c3 []byte 345 346 if from == C1C3C2 { 347 c2 = ciphertext[c3Start+sm3.Size:] 348 c3 = ciphertext[c3Start : c3Start+sm3.Size] 349 } else { 350 c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] 351 c3 = ciphertext[ciphertextLen-sm3.Size:] 352 } 353 return mashalASN1Ciphertext(x1, y1, c2, c3) 354 } 355 356 // AdjustCiphertextSplicingOrder utility method to change c2 c3 order 357 func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) { 358 curve := sm2ec.P256() 359 if from == to { 360 return ciphertext, nil 361 } 362 ciphertextLen := len(ciphertext) 363 if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { 364 return nil, errCiphertextTooShort 365 } 366 367 // get C1, and check C1 368 _, _, c3Start, err := bytes2Point(curve, ciphertext) 369 if err != nil { 370 return nil, err 371 } 372 373 var c1, c2, c3 []byte 374 375 c1 = ciphertext[:c3Start] 376 if from == C1C3C2 { 377 c2 = ciphertext[c3Start+sm3.Size:] 378 c3 = ciphertext[c3Start : c3Start+sm3.Size] 379 } else { 380 c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] 381 c3 = ciphertext[ciphertextLen-sm3.Size:] 382 } 383 384 result := make([]byte, ciphertextLen) 385 copy(result, c1) 386 if to == C1C3C2 { 387 // c1 || c3 || c2 388 copy(result[c3Start:], c3) 389 copy(result[c3Start+sm3.Size:], c2) 390 } else { 391 // c1 || c2 || c3 392 copy(result[c3Start:], c2) 393 copy(result[ciphertextLen-sm3.Size:], c3) 394 } 395 return result, nil 396 } 397 398 func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) { 399 x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext) 400 if err != nil { 401 return nil, ErrDecryption 402 } 403 return rawDecrypt(priv, x1, y1, c2, c3) 404 } 405 406 func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) { 407 curve := priv.Curve 408 x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) 409 msgLen := len(c2) 410 msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) 411 if subtle.ConstantTimeAllZero(c2) == 1 { 412 return nil, ErrDecryption 413 } 414 415 //B5, calculate msg = c2 ^ t 416 subtle.XORBytes(msg, c2, msg) 417 418 u := calculateC3(curve, x2, y2, msg) 419 if _subtle.ConstantTimeCompare(u, c3) == 1 { 420 return msg, nil 421 } 422 return nil, ErrDecryption 423 } 424 425 func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { 426 splicingOrder := C1C3C2 427 if opts != nil { 428 if opts.ciphertextEncoding == ENCODING_ASN1 { 429 return decryptASN1(priv, ciphertext) 430 } 431 splicingOrder = opts.cipherTextSplicingOrder 432 } 433 if ciphertext[0] == 0x30 { 434 return decryptASN1(priv, ciphertext) 435 } 436 ciphertextLen := len(ciphertext) 437 curve := priv.Curve 438 // B1, get C1, and check C1 439 x1, y1, c3Start, err := bytes2Point(curve, ciphertext) 440 if err != nil { 441 return nil, ErrDecryption 442 } 443 444 //B4, calculate t=KDF(x2||y2, klen) 445 var c2, c3 []byte 446 if splicingOrder == C1C3C2 { 447 c2 = ciphertext[c3Start+sm3.Size:] 448 c3 = ciphertext[c3Start : c3Start+sm3.Size] 449 } else { 450 c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] 451 c3 = ciphertext[ciphertextLen-sm3.Size:] 452 } 453 454 return rawDecrypt(priv, x1, y1, c2, c3) 455 } 456 457 func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { 458 if len(bytes) < 1+(curve.Params().BitSize/8) { 459 return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes)) 460 } 461 format := bytes[0] 462 byteLen := (curve.Params().BitSize + 7) >> 3 463 switch format { 464 case uncompressed, hybrid06, hybrid07: // what's the hybrid format purpose? 465 if len(bytes) < 1+byteLen*2 { 466 return nil, nil, 0, fmt.Errorf("sm2: invalid point uncompressed/hybrid form bytes length %d", len(bytes)) 467 } 468 data := make([]byte, 1+byteLen*2) 469 data[0] = uncompressed 470 copy(data[1:], bytes[1:1+byteLen*2]) 471 x, y := sm2ec.Unmarshal(curve, data) 472 if x == nil || y == nil { 473 return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name) 474 } 475 return x, y, 1 + byteLen*2, nil 476 case compressed02, compressed03: 477 if len(bytes) < 1+byteLen { 478 return nil, nil, 0, fmt.Errorf("sm2: invalid point compressed form bytes length %d", len(bytes)) 479 } 480 // Make sure it's NIST curve or SM2 P-256 curve 481 if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, sm2ec.P256().Params().Name) { 482 // y² = x³ - 3x + b, prime curves 483 x, y := sm2ec.UnmarshalCompressed(curve, bytes[:1+byteLen]) 484 if x == nil || y == nil { 485 return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name) 486 } 487 return x, y, 1 + byteLen, nil 488 } 489 return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name) 490 } 491 return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format) 492 } 493 494 func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { 495 switch mode { 496 case MarshalCompressed: 497 return elliptic.MarshalCompressed(curve, x, y) 498 case MarshalHybrid: 499 buffer := elliptic.Marshal(curve, x, y) 500 buffer[0] = byte(y.Bit(0)) | hybrid06 501 return buffer 502 default: 503 return elliptic.Marshal(curve, x, y) 504 } 505 }