github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/sm2/sm2.go (about) 1 /* 2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 package sm2 17 18 // reference to ecdsa 19 import ( 20 "bytes" 21 "crypto" 22 "crypto/elliptic" 23 "crypto/rand" 24 "encoding/asn1" 25 "encoding/binary" 26 "errors" 27 "io" 28 "math/big" 29 30 "github.com/Hyperledger-TWGC/tjfoc-gm/sm3" 31 ) 32 33 var ( 34 default_uid = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} 35 ) 36 37 type PublicKey struct { 38 elliptic.Curve 39 X, Y *big.Int 40 } 41 42 type PrivateKey struct { 43 PublicKey 44 D *big.Int 45 } 46 47 type sm2Signature struct { 48 R, S *big.Int 49 } 50 type sm2Cipher struct { 51 XCoordinate *big.Int 52 YCoordinate *big.Int 53 HASH []byte 54 CipherText []byte 55 } 56 57 // The SM2's private key contains the public key 58 func (priv *PrivateKey) Public() crypto.PublicKey { 59 return &priv.PublicKey 60 } 61 62 var errZeroParam = errors.New("zero parameter") 63 var one = new(big.Int).SetInt64(1) 64 var two = new(big.Int).SetInt64(2) 65 66 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s 67 func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) { 68 r, s, err := Sm2Sign(priv, msg, nil, random) 69 if err != nil { 70 return nil, err 71 } 72 return asn1.Marshal(sm2Signature{r, s}) 73 } 74 75 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool { 76 var sm2Sign sm2Signature 77 _, err := asn1.Unmarshal(sign, &sm2Sign) 78 if err != nil { 79 return false 80 } 81 return Sm2Verify(pub, msg, default_uid, sm2Sign.R, sm2Sign.S) 82 } 83 84 func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) { 85 if len(uid) == 0 { 86 uid = default_uid 87 } 88 89 za, err := ZA(pub, uid) 90 if err != nil { 91 return nil, err 92 } 93 94 e, err := msgHash(za, msg) 95 if err != nil { 96 return nil, err 97 } 98 99 return e.Bytes(), nil 100 } 101 102 //****************************Encryption algorithm****************************// 103 func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) { 104 return EncryptAsn1(pub, data, random) 105 } 106 107 func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) { 108 return DecryptAsn1(priv, data) 109 } 110 111 //**************************Key agreement algorithm**************************// 112 // KeyExchangeB 协商第二部,用户B调用, 返回共享密钥k 113 func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) { 114 return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false) 115 } 116 117 // KeyExchangeA 协商第二部,用户A调用,返回共享密钥k 118 func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) { 119 return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true) 120 } 121 122 //****************************************************************************// 123 124 func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) { 125 digest, err := priv.PublicKey.Sm3Digest(msg, uid) 126 if err != nil { 127 return nil, nil, err 128 } 129 e := new(big.Int).SetBytes(digest) 130 c := priv.PublicKey.Curve 131 N := c.Params().N 132 if N.Sign() == 0 { 133 return nil, nil, errZeroParam 134 } 135 var k *big.Int 136 for { // 调整算法细节以实现SM2 137 for { 138 k, err = randFieldElement(c, random) 139 if err != nil { 140 r = nil 141 return 142 } 143 r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) 144 r.Add(r, e) 145 r.Mod(r, N) 146 if r.Sign() != 0 { 147 if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 { 148 break 149 } 150 } 151 152 } 153 rD := new(big.Int).Mul(priv.D, r) 154 s = new(big.Int).Sub(k, rD) 155 d1 := new(big.Int).Add(priv.D, one) 156 d1Inv := new(big.Int).ModInverse(d1, N) 157 s.Mul(s, d1Inv) 158 s.Mod(s, N) 159 if s.Sign() != 0 { 160 break 161 } 162 } 163 return 164 } 165 func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool { 166 c := pub.Curve 167 N := c.Params().N 168 one := new(big.Int).SetInt64(1) 169 if r.Cmp(one) < 0 || s.Cmp(one) < 0 { 170 return false 171 } 172 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { 173 return false 174 } 175 if len(uid) == 0 { 176 uid = default_uid 177 } 178 za, err := ZA(pub, uid) 179 if err != nil { 180 return false 181 } 182 e, err := msgHash(za, msg) 183 if err != nil { 184 return false 185 } 186 t := new(big.Int).Add(r, s) 187 t.Mod(t, N) 188 if t.Sign() == 0 { 189 return false 190 } 191 var x *big.Int 192 x1, y1 := c.ScalarBaseMult(s.Bytes()) 193 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) 194 x, _ = c.Add(x1, y1, x2, y2) 195 196 x.Add(x, e) 197 x.Mod(x, N) 198 return x.Cmp(r) == 0 199 } 200 201 /* 202 za, err := ZA(pub, uid) 203 if err != nil { 204 return 205 } 206 e, err := msgHash(za, msg) 207 hash=e.getBytes() 208 */ 209 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { 210 c := pub.Curve 211 N := c.Params().N 212 213 if r.Sign() <= 0 || s.Sign() <= 0 { 214 return false 215 } 216 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { 217 return false 218 } 219 220 // 调整算法细节以实现SM2 221 t := new(big.Int).Add(r, s) 222 t.Mod(t, N) 223 if t.Sign() == 0 { 224 return false 225 } 226 227 var x *big.Int 228 x1, y1 := c.ScalarBaseMult(s.Bytes()) 229 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) 230 x, _ = c.Add(x1, y1, x2, y2) 231 232 e := new(big.Int).SetBytes(hash) 233 x.Add(x, e) 234 x.Mod(x, N) 235 return x.Cmp(r) == 0 236 } 237 238 /* 239 * sm2密文结构如下: 240 * x 241 * y 242 * hash 243 * CipherText 244 */ 245 func Encrypt(pub *PublicKey, data []byte, random io.Reader) ([]byte, error) { 246 length := len(data) 247 for { 248 c := []byte{} 249 curve := pub.Curve 250 k, err := randFieldElement(curve, random) 251 if err != nil { 252 return nil, err 253 } 254 x1, y1 := curve.ScalarBaseMult(k.Bytes()) 255 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) 256 x1Buf := x1.Bytes() 257 y1Buf := y1.Bytes() 258 x2Buf := x2.Bytes() 259 y2Buf := y2.Bytes() 260 if n := len(x1Buf); n < 32 { 261 x1Buf = append(zeroByteSlice()[:32-n], x1Buf...) 262 } 263 if n := len(y1Buf); n < 32 { 264 y1Buf = append(zeroByteSlice()[:32-n], y1Buf...) 265 } 266 if n := len(x2Buf); n < 32 { 267 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...) 268 } 269 if n := len(y2Buf); n < 32 { 270 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...) 271 } 272 c = append(c, x1Buf...) // x分量 273 c = append(c, y1Buf...) // y分量 274 tm := []byte{} 275 tm = append(tm, x2Buf...) 276 tm = append(tm, data...) 277 tm = append(tm, y2Buf...) 278 h := sm3.Sm3Sum(tm) 279 c = append(c, h...) 280 ct, ok := kdf(length, x2Buf, y2Buf) // 密文 281 if !ok { 282 continue 283 } 284 c = append(c, ct...) 285 for i := 0; i < length; i++ { 286 c[96+i] ^= data[i] 287 } 288 return append([]byte{0x04}, c...), nil 289 } 290 } 291 292 func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) { 293 data = data[1:] 294 length := len(data) - 96 295 curve := priv.Curve 296 x := new(big.Int).SetBytes(data[:32]) 297 y := new(big.Int).SetBytes(data[32:64]) 298 x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes()) 299 x2Buf := x2.Bytes() 300 y2Buf := y2.Bytes() 301 if n := len(x2Buf); n < 32 { 302 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...) 303 } 304 if n := len(y2Buf); n < 32 { 305 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...) 306 } 307 c, ok := kdf(length, x2Buf, y2Buf) 308 if !ok { 309 return nil, errors.New("Decrypt: failed to decrypt") 310 } 311 for i := 0; i < length; i++ { 312 c[i] ^= data[i+96] 313 } 314 tm := []byte{} 315 tm = append(tm, x2Buf...) 316 tm = append(tm, c...) 317 tm = append(tm, y2Buf...) 318 h := sm3.Sm3Sum(tm) 319 if bytes.Compare(h, data[64:96]) != 0 { 320 return c, errors.New("Decrypt: failed to decrypt") 321 } 322 return c, nil 323 } 324 325 // keyExchange 为SM2密钥交换算法的第二部和第三步复用部分,协商的双方均调用此函数计算共同的字节串 326 // klen: 密钥长度 327 // ida, idb: 协商双方的标识,ida为密钥协商算法发起方标识,idb为响应方标识 328 // pri: 函数调用者的密钥 329 // pub: 对方的公钥 330 // rpri: 函数调用者生成的临时SM2密钥 331 // rpub: 对方发来的临时SM2公钥 332 // thisIsA: 如果是A调用,文档中的协商第三步,设置为true,否则设置为false 333 // 返回 k 为klen长度的字节串 334 func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) { 335 curve := P256Sm2() 336 N := curve.Params().N 337 x2hat := keXHat(rpri.PublicKey.X) 338 x2rb := new(big.Int).Mul(x2hat, rpri.D) 339 tbt := new(big.Int).Add(pri.D, x2rb) 340 tb := new(big.Int).Mod(tbt, N) 341 if !curve.IsOnCurve(rpub.X, rpub.Y) { 342 err = errors.New("Ra not on curve") 343 return 344 } 345 x1hat := keXHat(rpub.X) 346 ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes()) 347 vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1) 348 349 vx, vy := curve.ScalarMult(vxt, vyt, tb.Bytes()) 350 pza := pub 351 if thisISA { 352 pza = &pri.PublicKey 353 } 354 za, err := ZA(pza, ida) 355 if err != nil { 356 return 357 } 358 zero := new(big.Int) 359 if vx.Cmp(zero) == 0 || vy.Cmp(zero) == 0 { 360 err = errors.New("V is infinite") 361 } 362 pzb := pub 363 if !thisISA { 364 pzb = &pri.PublicKey 365 } 366 zb, err := ZA(pzb, idb) 367 k, ok := kdf(klen, vx.Bytes(), vy.Bytes(), za, zb) 368 if !ok { 369 err = errors.New("kdf: zero key") 370 return 371 } 372 h1 := BytesCombine(vx.Bytes(), za, zb, rpub.X.Bytes(), rpub.Y.Bytes(), rpri.X.Bytes(), rpri.Y.Bytes()) 373 if !thisISA { 374 h1 = BytesCombine(vx.Bytes(), za, zb, rpri.X.Bytes(), rpri.Y.Bytes(), rpub.X.Bytes(), rpub.Y.Bytes()) 375 } 376 hash := sm3.Sm3Sum(h1) 377 h2 := BytesCombine([]byte{0x02}, vy.Bytes(), hash) 378 S1 := sm3.Sm3Sum(h2) 379 h3 := BytesCombine([]byte{0x03}, vy.Bytes(), hash) 380 S2 := sm3.Sm3Sum(h3) 381 return k, S1, S2, nil 382 } 383 384 func msgHash(za, msg []byte) (*big.Int, error) { 385 e := sm3.New() 386 e.Write(za) 387 e.Write(msg) 388 return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil 389 } 390 391 // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) 392 func ZA(pub *PublicKey, uid []byte) ([]byte, error) { 393 za := sm3.New() 394 uidLen := len(uid) 395 if uidLen >= 8192 { 396 return []byte{}, errors.New("SM2: uid too large") 397 } 398 Entla := uint16(8 * uidLen) 399 za.Write([]byte{byte((Entla >> 8) & 0xFF)}) 400 za.Write([]byte{byte(Entla & 0xFF)}) 401 if uidLen > 0 { 402 za.Write(uid) 403 } 404 za.Write(sm2P256ToBig(&sm2P256.a).Bytes()) 405 za.Write(sm2P256.B.Bytes()) 406 za.Write(sm2P256.Gx.Bytes()) 407 za.Write(sm2P256.Gy.Bytes()) 408 409 xBuf := pub.X.Bytes() 410 yBuf := pub.Y.Bytes() 411 if n := len(xBuf); n < 32 { 412 xBuf = append(zeroByteSlice()[:32-n], xBuf...) 413 } 414 if n := len(yBuf); n < 32 { 415 yBuf = append(zeroByteSlice()[:32-n], yBuf...) 416 } 417 za.Write(xBuf) 418 za.Write(yBuf) 419 return za.Sum(nil)[:32], nil 420 } 421 422 // 32byte 423 func zeroByteSlice() []byte { 424 return []byte{ 425 0, 0, 0, 0, 426 0, 0, 0, 0, 427 0, 0, 0, 0, 428 0, 0, 0, 0, 429 0, 0, 0, 0, 430 0, 0, 0, 0, 431 0, 0, 0, 0, 432 0, 0, 0, 0, 433 } 434 } 435 436 /* 437 sm2加密,返回asn.1编码格式的密文内容 438 */ 439 func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) { 440 cipher, err := Encrypt(pub, data, rand) 441 if err != nil { 442 return nil, err 443 } 444 return CipherMarshal(cipher) 445 } 446 447 /* 448 sm2解密,解析asn.1编码格式的密文内容 449 */ 450 func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) { 451 cipher, err := CipherUnmarshal(data) 452 if err != nil { 453 return nil, err 454 } 455 return Decrypt(pub, cipher) 456 } 457 458 /* 459 *sm2密文转asn.1编码格式 460 *sm2密文结构如下: 461 * x 462 * y 463 * hash 464 * CipherText 465 */ 466 func CipherMarshal(data []byte) ([]byte, error) { 467 data = data[1:] 468 x := new(big.Int).SetBytes(data[:32]) 469 y := new(big.Int).SetBytes(data[32:64]) 470 hash := data[64:96] 471 cipherText := data[96:] 472 return asn1.Marshal(sm2Cipher{x, y, hash, cipherText}) 473 } 474 475 /* 476 sm2密文asn.1编码格式转C1|C3|C2拼接格式 477 */ 478 func CipherUnmarshal(data []byte) ([]byte, error) { 479 var cipher sm2Cipher 480 _, err := asn1.Unmarshal(data, &cipher) 481 if err != nil { 482 return nil, err 483 } 484 x := cipher.XCoordinate.Bytes() 485 y := cipher.YCoordinate.Bytes() 486 hash := cipher.HASH 487 if err != nil { 488 return nil, err 489 } 490 cipherText := cipher.CipherText 491 if err != nil { 492 return nil, err 493 } 494 c := []byte{} 495 c = append(c, x...) // x分量 496 c = append(c, y...) // y分 497 c = append(c, hash...) // x分量 498 c = append(c, cipherText...) // y分 499 return append([]byte{0x04}, c...), nil 500 } 501 502 // keXHat 计算 x = 2^w + (x & (2^w-1)) 503 // 密钥协商算法辅助函数 504 func keXHat(x *big.Int) (xul *big.Int) { 505 buf := x.Bytes() 506 for i := 0; i < len(buf)-16; i++ { 507 buf[i] = 0 508 } 509 if len(buf) >= 16 { 510 c := buf[len(buf)-16] 511 buf[len(buf)-16] = c & 0x7f 512 } 513 514 r := new(big.Int).SetBytes(buf) 515 _2w := new(big.Int).SetBytes([]byte{ 516 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 517 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) 518 return r.Add(r, _2w) 519 } 520 521 func BytesCombine(pBytes ...[]byte) []byte { 522 len := len(pBytes) 523 s := make([][]byte, len) 524 for index := 0; index < len; index++ { 525 s[index] = pBytes[index] 526 } 527 sep := []byte("") 528 return bytes.Join(s, sep) 529 } 530 531 func intToBytes(x int) []byte { 532 var buf = make([]byte, 4) 533 534 binary.BigEndian.PutUint32(buf, uint32(x)) 535 return buf 536 } 537 538 func kdf(length int, x ...[]byte) ([]byte, bool) { 539 var c []byte 540 541 ct := 1 542 h := sm3.New() 543 for i, j := 0, (length+31)/32; i < j; i++ { 544 h.Reset() 545 for _, xx := range x { 546 h.Write(xx) 547 } 548 h.Write(intToBytes(ct)) 549 hash := h.Sum(nil) 550 if i+1 == j && length%32 != 0 { 551 c = append(c, hash[:length%32]...) 552 } else { 553 c = append(c, hash...) 554 } 555 ct++ 556 } 557 for i := 0; i < length; i++ { 558 if c[i] != 0 { 559 return c, true 560 } 561 } 562 return c, false 563 } 564 565 func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) { 566 if random == nil { 567 random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it. 568 } 569 params := c.Params() 570 b := make([]byte, params.BitSize/8+8) 571 _, err = io.ReadFull(random, b) 572 if err != nil { 573 return 574 } 575 k = new(big.Int).SetBytes(b) 576 n := new(big.Int).Sub(params.N, one) 577 k.Mod(k, n) 578 k.Add(k, one) 579 return 580 } 581 582 func GenerateKey(random io.Reader) (*PrivateKey, error) { 583 c := P256Sm2() 584 if random == nil { 585 random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it. 586 } 587 params := c.Params() 588 b := make([]byte, params.BitSize/8+8) 589 _, err := io.ReadFull(random, b) 590 if err != nil { 591 return nil, err 592 } 593 594 k := new(big.Int).SetBytes(b) 595 n := new(big.Int).Sub(params.N, two) 596 k.Mod(k, n) 597 k.Add(k, one) 598 priv := new(PrivateKey) 599 priv.PublicKey.Curve = c 600 priv.D = k 601 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) 602 603 return priv, nil 604 } 605 606 type zr struct { 607 io.Reader 608 } 609 610 func (z *zr) Read(dst []byte) (n int, err error) { 611 for i := range dst { 612 dst[i] = 0 613 } 614 return len(dst), nil 615 } 616 617 var zeroReader = &zr{} 618 619 func getLastBit(a *big.Int) uint { 620 return a.Bit(0) 621 } 622 623 // crypto.Decrypter 624 func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error){ 625 return Decrypt(priv, msg) 626 } 627