github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm2soft/sm2.go (about) 1 // Copyright 2022 s1ren@github.com/hxx258456. 2 3 /* 4 sm2soft 是sm2的纯软实现,基于tjfoc国密算法库`tjfoc/gmsm`做了少量修改。 5 对应版权声明: thrid_licenses/github.com/tjfoc/gmsm/版权声明 6 */ 7 8 package sm2soft 9 10 // reference to ecdsa 11 import ( 12 "bytes" 13 "crypto" 14 "crypto/elliptic" 15 "crypto/rand" 16 "encoding/asn1" 17 "encoding/binary" 18 "errors" 19 "io" 20 "math/big" 21 22 "github.com/hxx258456/ccgo/sm3" 23 "golang.org/x/crypto/cryptobyte" 24 cbasn1 "golang.org/x/crypto/cryptobyte/asn1" 25 ) 26 27 var ( 28 defaultUid = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} 29 C1C3C2 = 0 30 C1C2C3 = 1 31 ) 32 33 // PublicKey SM2公钥结构体 34 type PublicKey struct { 35 elliptic.Curve // 椭圆曲线 36 X, Y *big.Int // 公钥座标 37 } 38 39 func (pub *PublicKey) Equal(x crypto.PublicKey) bool { 40 xx, ok := x.(*PublicKey) 41 if !ok { 42 return false 43 } 44 return pub.X.Cmp(xx.X) == 0 && pub.Y.Cmp(xx.Y) == 0 && 45 // Standard library Curve implementations are singletons, so this check 46 // will work for those. Other Curves might be equivalent even if not 47 // singletons, but there is no definitive way to check for that, and 48 // better to err on the side of safety. 49 pub.Curve == xx.Curve 50 } 51 52 // PrivateKey SM2私钥结构体 53 type PrivateKey struct { 54 PublicKey // 公钥 55 D *big.Int // 私钥,[1,n-1]区间的随机数 56 } 57 58 type sm2Cipher struct { 59 XCoordinate *big.Int 60 YCoordinate *big.Int 61 HASH []byte 62 CipherText []byte 63 } 64 65 // Public The SM2's private key contains the public key 66 func (priv *PrivateKey) Public() crypto.PublicKey { 67 return &priv.PublicKey 68 } 69 70 func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool { 71 xx, ok := x.(*PrivateKey) 72 if !ok { 73 return false 74 } 75 return priv.PublicKey.Equal(&xx.PublicKey) && priv.D.Cmp(xx.D) == 0 76 } 77 78 var errZeroParam = errors.New("zero parameter") 79 var one = new(big.Int).SetInt64(1) 80 var two = new(big.Int).SetInt64(2) 81 82 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s 83 84 // Sign 使用priv私钥对签名内容摘要做SM2签名,参数signer目前没有使用,但仍要求传入外部对明文消息做摘要的散列算法。 85 // 使用priv私钥对签名内容摘要做SM2签名,参数signer目前没有使用,但仍要求传入外部对明文消息做摘要的散列算法。 86 // 返回的签名为DER字节数组,对(r,s)做了asn1编码。 87 // - random : 随机数获取用, 如 rand.Reader 88 // - signContentDigest : 签名内容摘要(散列值) 89 // - signer : 外部对签名内容进行摘要计算使用的散列函数 90 // 91 //goland:noinspection GoUnusedParameter 92 func (priv *PrivateKey) Sign(random io.Reader, signContentDigest []byte, signer crypto.SignerOpts) ([]byte, error) { 93 r, s, err := Sm2Sign(priv, signContentDigest, nil, random) 94 if err != nil { 95 return nil, err 96 } 97 var b cryptobyte.Builder 98 b.AddASN1(cbasn1.SEQUENCE, func(b *cryptobyte.Builder) { 99 b.AddASN1BigInt(r) 100 b.AddASN1BigInt(s) 101 }) 102 return b.Bytes() 103 } 104 105 // SignASN1 使用私钥priv对一个hash值进行签名。 106 // 返回的签名为DER字节数组,对(r,s)做了asn1编码。 107 // 108 //goland:noinspection GoUnusedExportedFunction 109 func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte) ([]byte, error) { 110 return priv.Sign(rand, hash, nil) 111 } 112 113 // Verify 使用pub公钥对签名sig做验签。 114 // - signContentDigest : 签名内容摘要(散列值) 115 // - sig : 签名DER字节数组(对(r,s)做了asn1编码,因此会先做asn1解码) 116 func (pub *PublicKey) Verify(signContentDigest []byte, sig []byte) bool { 117 var ( 118 r, s = &big.Int{}, &big.Int{} 119 inner cryptobyte.String 120 ) 121 input := cryptobyte.String(sig) 122 if !input.ReadASN1(&inner, cbasn1.SEQUENCE) || 123 !input.Empty() || 124 !inner.ReadASN1Integer(r) || 125 !inner.ReadASN1Integer(s) || 126 !inner.Empty() { 127 return false 128 } 129 return Sm2Verify(pub, signContentDigest, defaultUid, r, s) 130 } 131 132 // VerifyASN1 使用公钥pub对hash和sig进行验签。 133 // - pub 公钥 134 // - hash 签名内容摘要(散列值) 135 // - sig 签名DER字节数组(对(r,s)做了asn1编码,因此会先做asn1解码) 136 // 137 //goland:noinspection GoUnusedExportedFunction 138 func VerifyASN1(pub *PublicKey, hash, sig []byte) bool { 139 return pub.Verify(hash, sig) 140 } 141 142 // Sm3Digest 对签名内容进行SM3摘要计算,摘要计算前混入sm2椭圆曲线部分参数与公钥并预散列一次。 143 func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) { 144 if len(uid) == 0 { 145 uid = defaultUid 146 } 147 148 za, err := ZA(pub, uid) 149 if err != nil { 150 return nil, err 151 } 152 153 e, err := msgHash(za, msg) 154 if err != nil { 155 return nil, err 156 } 157 158 return e.Bytes(), nil 159 } 160 161 //****************************Encryption algorithm****************************// 162 163 // EncryptAsn1 sm2加密,C1C3C2,asn1编码 164 func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) { 165 return EncryptAsn1(pub, data, random) 166 } 167 168 // DecryptAsn1 sm2解密,C1C3C2,asn1解码 169 func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) { 170 return DecryptAsn1(priv, data) 171 } 172 173 // **************************Key agreement algorithm**************************// 174 175 // KeyExchangeB 协商第二部,用户B调用, 返回共享密钥k 176 func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) { 177 return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false) 178 } 179 180 // KeyExchangeA 协商第二部,用户A调用,返回共享密钥k 181 func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) { 182 return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true) 183 } 184 185 //****************************************************************************// 186 187 // Sm2Sign SM2签名 188 // - priv : 签名私钥 *sm2.PrivateKey 189 // - signContentDigest : 签名内容摘要(散列值) 190 // - uid : 内部混合摘要计算用uid, 长度16的字节数组,可以传 nil 191 // - random : 随机数获取用 192 func Sm2Sign(priv *PrivateKey, signContentDigest, uid []byte, random io.Reader) (r, s *big.Int, err error) { 193 // 对签名内容进行摘要计算 194 digest, err := priv.PublicKey.Sm3Digest(signContentDigest, uid) 195 if err != nil { 196 return nil, nil, err 197 } 198 e := new(big.Int).SetBytes(digest) 199 c := priv.PublicKey.Curve 200 N := c.Params().N 201 if N.Sign() == 0 { 202 return nil, nil, errZeroParam 203 } 204 var k *big.Int 205 // SM2签名实现 206 for { 207 for { 208 // 生成随机数k 209 k, err = randFieldElement(c, random) 210 if err != nil { 211 r = nil 212 return 213 } 214 // 计算P = k*G,返回值的x赋予了r 215 r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) 216 // 计算 r = (e + P(x)) mod n 217 // e + P(x) 218 r.Add(r, e) 219 // (e + P(x)) mod n 220 r.Mod(r, N) 221 if r.Sign() != 0 { 222 if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 { 223 break 224 } 225 } 226 227 } 228 // 计算 s = (((1 + d)^-1) (k-rd)) mod n 229 // rd 230 rD := new(big.Int).Mul(priv.D, r) 231 // k - rd 232 s = new(big.Int).Sub(k, rD) 233 // 1 + d 234 d1 := new(big.Int).Add(priv.D, one) 235 // (1 + d)^-1 236 d1Inv := new(big.Int).ModInverse(d1, N) 237 // ((1 + d)^-1) × (k-rd) 238 s.Mul(s, d1Inv) 239 // (((1 + d)^-1) (k-rd)) mod n 240 s.Mod(s, N) 241 if s.Sign() != 0 { 242 break 243 } 244 } 245 return 246 } 247 248 // Sm2Verify SM2验签 249 // - pub : 验签公钥, *sm2.PublicKey 250 // - signContentDigest : 签名内容摘要(散列值) 251 // - uid : 内部混合摘要计算用uid, 长度16的字节数组,可以传 nil 252 // - r, s : 签名 253 func Sm2Verify(pub *PublicKey, signContentDigest, uid []byte, r, s *big.Int) bool { 254 c := pub.Curve 255 N := c.Params().N 256 one := new(big.Int).SetInt64(1) 257 if r.Cmp(one) < 0 || s.Cmp(one) < 0 { 258 return false 259 } 260 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { 261 return false 262 } 263 if len(uid) == 0 { 264 uid = defaultUid 265 } 266 // 获取za: sm3(ENTLA || IDA || a || b || xG || yG || xA || yA) 267 za, err := ZA(pub, uid) 268 if err != nil { 269 return false 270 } 271 // 混合za与签名内容明文,并做sm3摘要 272 e, err := msgHash(za, signContentDigest) 273 if err != nil { 274 return false 275 } 276 // 计算 t = (r + s) mod n 277 t := new(big.Int).Add(r, s) 278 t.Mod(t, N) 279 if t.Sign() == 0 { 280 return false 281 } 282 var x *big.Int 283 // 计算 s*G 284 x1, y1 := c.ScalarBaseMult(s.Bytes()) 285 // 计算 t*pub 286 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) 287 // 计算 s*G + t*pub 结果只要x轴座标 288 x, _ = c.Add(x1, y1, x2, y2) 289 // 计算 e + x 290 x.Add(x, e) 291 // 计算 R = (e + x) mod n 292 x.Mod(x, N) 293 // 判断 R == r 294 return x.Cmp(r) == 0 295 } 296 297 // Verify SM2验签 298 // - pub : 验签公钥, *sm2.PublicKey 299 // - hash : 签名内容摘要(散列值) 300 // - r, s : 签名 301 // 302 //goland:noinspection GoUnusedExportedFunction 303 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { 304 return Sm2Verify(pub, hash, nil, r, s) 305 } 306 307 /* 308 za, err := ZA(pub, uid) 309 if err != nil { 310 return 311 } 312 e, err := msgHash(za, msg) 313 hash=e.getBytes() 314 */ 315 // 并非sm2验签 316 // func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { 317 // c := pub.Curve 318 // N := c.Params().N 319 320 // if r.Sign() <= 0 || s.Sign() <= 0 { 321 // return false 322 // } 323 // if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { 324 // return false 325 // } 326 327 // // 调整算法细节以实现SM2 328 // t := new(big.Int).Add(r, s) 329 // t.Mod(t, N) 330 // if t.Sign() == 0 { 331 // return false 332 // } 333 334 // var x *big.Int 335 // x1, y1 := c.ScalarBaseMult(s.Bytes()) 336 // x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) 337 // x, _ = c.Add(x1, y1, x2, y2) 338 339 // e := new(big.Int).SetBytes(hash) 340 // x.Add(x, e) 341 // x.Mod(x, N) 342 // return x.Cmp(r) == 0 343 // } 344 345 // Encrypt sm2非对称加密,支持C1C3C2(mode = 0)与C1C2C3(mode = 1)两种模式,默认使用C1C3C2模式。 346 // 不同的模式表示不同的密文结构,其中C1C2C3的意义: 347 // C1 : sm2椭圆曲线上的某个点,每次加密得到的点不一样 348 // C2 : 密文 349 // C3 : 明文加盐后的摘要 350 func Encrypt(pub *PublicKey, data []byte, random io.Reader, mode int) ([]byte, error) { 351 length := len(data) 352 for { 353 c := []byte{} 354 curve := pub.Curve 355 // 获取随机数k 356 k, err := randFieldElement(curve, random) 357 if err != nil { 358 return nil, err 359 } 360 // 计算点C1 = k*G ,因为k是随机数,所以C1每次加密都是随机的 361 x1, y1 := curve.ScalarBaseMult(k.Bytes()) 362 // 计算点(x2,y2) = k*pub,利用公钥计算出一个随机的点P 363 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) 364 x1Buf := x1.Bytes() 365 y1Buf := y1.Bytes() 366 x2Buf := x2.Bytes() 367 y2Buf := y2.Bytes() 368 // 填充满32个字节长度 369 if n := len(x1Buf); n < 32 { 370 x1Buf = append(zeroByteSlice()[:32-n], x1Buf...) 371 } 372 if n := len(y1Buf); n < 32 { 373 y1Buf = append(zeroByteSlice()[:32-n], y1Buf...) 374 } 375 if n := len(x2Buf); n < 32 { 376 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...) 377 } 378 if n := len(y2Buf); n < 32 { 379 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...) 380 } 381 // 填入C1(x) 382 c = append(c, x1Buf...) 383 // 填入C1(y) 384 c = append(c, y1Buf...) 385 386 // 计算C3 : 按 x2 data y2 的顺序混合数据并做sm3摘要 387 tm := []byte{} 388 tm = append(tm, x2Buf...) 389 tm = append(tm, data...) 390 tm = append(tm, y2Buf...) 391 h := sm3.Sm3Sum(tm) 392 // 填入C3 393 c = append(c, h...) 394 395 // 使用密钥派生函数kdf,基于P计算长度等于data长度的派生密钥 ct 396 ct, ok := kdf(length, x2Buf, y2Buf) 397 if !ok { 398 continue 399 } 400 // 填入ct 401 c = append(c, ct...) 402 // 利用ct对data进行异或加密,并覆盖c中对应内容 403 for i := 0; i < length; i++ { 404 c[96+i] ^= data[i] 405 } 406 407 // 此时c的内容是 c1c3c2,需要根据传入的参数mode判断是否需要重新排列。 408 switch mode { 409 case C1C3C2: 410 return append([]byte{0x04}, c...), nil 411 case C1C2C3: 412 // 如果是 C1C2C3 模式,那么需要将c切分后重新组装 413 c1 := make([]byte, 64) 414 c2 := make([]byte, len(c)-96) 415 c3 := make([]byte, 32) 416 // C1,即 x1Buf+y1Buf 417 copy(c1, c[:64]) 418 // C3,即 x2+data+y2混合后的SM3摘要 419 copy(c3, c[64:96]) 420 // C2,即 使用kdf派生出的密钥对data进行加密后的密文 421 copy(c2, c[96:]) 422 // 按C1C2C3的顺序组装结果 423 ciphertext := []byte{} 424 ciphertext = append(ciphertext, c1...) 425 ciphertext = append(ciphertext, c2...) 426 ciphertext = append(ciphertext, c3...) 427 return append([]byte{0x04}, ciphertext...), nil 428 default: 429 return append([]byte{0x04}, c...), nil 430 } 431 } 432 } 433 434 // DecryptAsn1 sm2解密,解析asn.1编码格式的密文内容 435 func Decrypt(priv *PrivateKey, data []byte, mode int) ([]byte, error) { 436 switch mode { 437 case C1C3C2: 438 data = data[1:] 439 case C1C2C3: 440 // C1C2C3重新组装为 C1C3C2 441 data = data[1:] 442 c1 := make([]byte, 64) 443 c2 := make([]byte, len(data)-96) 444 c3 := make([]byte, 32) 445 copy(c1, data[:64]) //x1,y1 446 copy(c2, data[64:len(data)-32]) //密文 447 copy(c3, data[len(data)-32:]) //hash 448 c := []byte{} 449 c = append(c, c1...) 450 c = append(c, c3...) 451 c = append(c, c2...) 452 data = c 453 default: 454 data = data[1:] 455 } 456 length := len(data) - 96 457 curve := priv.Curve 458 // 取出C1的x和y 459 x := new(big.Int).SetBytes(data[:32]) 460 y := new(big.Int).SetBytes(data[32:64]) 461 // 根据C1计算 P = d*C1 462 x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes()) 463 x2Buf := x2.Bytes() 464 y2Buf := y2.Bytes() 465 if n := len(x2Buf); n < 32 { 466 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...) 467 } 468 if n := len(y2Buf); n < 32 { 469 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...) 470 } 471 // 使用密钥派生函数kdf,基于P计算派生密钥 c 472 c, ok := kdf(length, x2Buf, y2Buf) 473 if !ok { 474 return nil, errors.New("decrypt: failed to decrypt") 475 } 476 // 使用派生密钥c对C2部分做异或计算解密 477 // 解密结果覆盖到c中,此时c即明文 478 for i := 0; i < length; i++ { 479 c[i] ^= data[i+96] 480 } 481 // 重新混合明文并计算摘要,与C3进行比较 482 tm := []byte{} 483 tm = append(tm, x2Buf...) 484 tm = append(tm, c...) 485 tm = append(tm, y2Buf...) 486 h := sm3.Sm3Sum(tm) 487 if !bytes.Equal(h, data[64:96]) { 488 return c, errors.New("decrypt: failed to decrypt") 489 } 490 return c, nil 491 } 492 493 // keyExchange 为SM2密钥交换算法的第二部和第三步复用部分,协商的双方均调用此函数计算共同的字节串 494 // klen: 密钥长度 495 // ida, idb: 协商双方的标识,ida为密钥协商算法发起方标识,idb为响应方标识 496 // pri: 函数调用者的密钥 497 // pub: 对方的公钥 498 // rpri: 函数调用者生成的临时SM2密钥 499 // rpub: 对方发来的临时SM2公钥 500 // thisIsA: 如果是A调用,文档中的协商第三步,设置为true,否则设置为false 501 // 返回 k 为klen长度的字节串 502 func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) { 503 curve := P256Sm2() 504 N := curve.Params().N 505 x2hat := keXHat(rpri.PublicKey.X) 506 x2rb := new(big.Int).Mul(x2hat, rpri.D) 507 tbt := new(big.Int).Add(pri.D, x2rb) 508 tb := new(big.Int).Mod(tbt, N) 509 if !curve.IsOnCurve(rpub.X, rpub.Y) { 510 err = errors.New("ra not on curve") 511 return 512 } 513 x1hat := keXHat(rpub.X) 514 ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes()) 515 vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1) 516 517 vx, vy := curve.ScalarMult(vxt, vyt, tb.Bytes()) 518 pza := pub 519 if thisISA { 520 pza = &pri.PublicKey 521 } 522 za, err := ZA(pza, ida) 523 if err != nil { 524 return 525 } 526 zero := new(big.Int) 527 if vx.Cmp(zero) == 0 || vy.Cmp(zero) == 0 { 528 err = errors.New("v is infinite") 529 return 530 } 531 pzb := pub 532 if !thisISA { 533 pzb = &pri.PublicKey 534 } 535 zb, _ := ZA(pzb, idb) 536 k, ok := kdf(klen, vx.Bytes(), vy.Bytes(), za, zb) 537 if !ok { 538 err = errors.New("kdf: zero key") 539 return 540 } 541 h1 := BytesCombine(vx.Bytes(), za, zb, rpub.X.Bytes(), rpub.Y.Bytes(), rpri.X.Bytes(), rpri.Y.Bytes()) 542 if !thisISA { 543 h1 = BytesCombine(vx.Bytes(), za, zb, rpri.X.Bytes(), rpri.Y.Bytes(), rpub.X.Bytes(), rpub.Y.Bytes()) 544 } 545 hash := sm3.Sm3Sum(h1) 546 h2 := BytesCombine([]byte{0x02}, vy.Bytes(), hash) 547 S1 := sm3.Sm3Sum(h2) 548 h3 := BytesCombine([]byte{0x03}, vy.Bytes(), hash) 549 S2 := sm3.Sm3Sum(h3) 550 return k, S1, S2, nil 551 } 552 553 func msgHash(za, msg []byte) (*big.Int, error) { 554 e := sm3.New() 555 e.Write(za) 556 e.Write(msg) 557 return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil 558 } 559 560 // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) 561 func ZA(pub *PublicKey, uid []byte) ([]byte, error) { 562 za := sm3.New() 563 uidLen := len(uid) 564 if uidLen >= 8192 { 565 return []byte{}, errors.New("SM2: uid too large") 566 } 567 Entla := uint16(8 * uidLen) 568 za.Write([]byte{byte((Entla >> 8) & 0xFF)}) 569 za.Write([]byte{byte(Entla & 0xFF)}) 570 if uidLen > 0 { 571 za.Write(uid) 572 } 573 za.Write(sm2P256ToBig(&sm2P256.a).Bytes()) 574 za.Write(sm2P256.B.Bytes()) 575 za.Write(sm2P256.Gx.Bytes()) 576 za.Write(sm2P256.Gy.Bytes()) 577 578 xBuf := pub.X.Bytes() 579 yBuf := pub.Y.Bytes() 580 if n := len(xBuf); n < 32 { 581 xBuf = append(zeroByteSlice()[:32-n], xBuf...) 582 } 583 if n := len(yBuf); n < 32 { 584 yBuf = append(zeroByteSlice()[:32-n], yBuf...) 585 } 586 za.Write(xBuf) 587 za.Write(yBuf) 588 return za.Sum(nil)[:32], nil 589 } 590 591 // 32byte 592 func zeroByteSlice() []byte { 593 return []byte{ 594 0, 0, 0, 0, 595 0, 0, 0, 0, 596 0, 0, 0, 0, 597 0, 0, 0, 0, 598 0, 0, 0, 0, 599 0, 0, 0, 0, 600 0, 0, 0, 0, 601 0, 0, 0, 0, 602 } 603 } 604 605 // EncryptAsn1 sm2加密,返回asn.1编码格式的密文内容 606 func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) { 607 cipher, err := Encrypt(pub, data, rand, C1C3C2) 608 if err != nil { 609 return nil, err 610 } 611 return CipherMarshal(cipher) 612 } 613 614 // DecryptAsn1 sm2解密,解析asn.1编码格式的密文内容 615 func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) { 616 cipher, err := CipherUnmarshal(data) 617 if err != nil { 618 return nil, err 619 } 620 return Decrypt(pub, cipher, C1C3C2) 621 } 622 623 // CipherMarshal sm2密文转asn.1编码格式 624 // 625 // sm2密文结构如下: 626 // - x 627 // - y 628 // - hash 629 // - CipherText 630 func CipherMarshal(data []byte) ([]byte, error) { 631 data = data[1:] 632 x := new(big.Int).SetBytes(data[:32]) 633 y := new(big.Int).SetBytes(data[32:64]) 634 hash := data[64:96] 635 cipherText := data[96:] 636 return asn1.Marshal(sm2Cipher{x, y, hash, cipherText}) 637 } 638 639 // CipherUnmarshal sm2密文asn.1编码格式转C1|C3|C2拼接格式 640 func CipherUnmarshal(data []byte) ([]byte, error) { 641 var cipher sm2Cipher 642 _, err := asn1.Unmarshal(data, &cipher) 643 if err != nil { 644 return nil, err 645 } 646 x := cipher.XCoordinate.Bytes() 647 y := cipher.YCoordinate.Bytes() 648 hash := cipher.HASH 649 if err != nil { 650 return nil, err 651 } 652 cipherText := cipher.CipherText 653 if err != nil { 654 return nil, err 655 } 656 if n := len(x); n < 32 { 657 x = append(zeroByteSlice()[:32-n], x...) 658 } 659 if n := len(y); n < 32 { 660 y = append(zeroByteSlice()[:32-n], y...) 661 } 662 c := []byte{} 663 c = append(c, x...) // x分量 664 c = append(c, y...) // y分 665 c = append(c, hash...) // x分量 666 c = append(c, cipherText...) // y分 667 return append([]byte{0x04}, c...), nil 668 } 669 670 // keXHat 计算 x = 2^w + (x & (2^w-1)) 671 // 密钥协商算法辅助函数 672 func keXHat(x *big.Int) (xul *big.Int) { 673 buf := x.Bytes() 674 for i := 0; i < len(buf)-16; i++ { 675 buf[i] = 0 676 } 677 if len(buf) >= 16 { 678 c := buf[len(buf)-16] 679 buf[len(buf)-16] = c & 0x7f 680 } 681 682 r := new(big.Int).SetBytes(buf) 683 _2w := new(big.Int).SetBytes([]byte{ 684 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 685 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) 686 return r.Add(r, _2w) 687 } 688 689 func BytesCombine(pBytes ...[]byte) []byte { 690 length := len(pBytes) 691 s := make([][]byte, length) 692 for index := 0; index < length; index++ { 693 s[index] = pBytes[index] 694 } 695 sep := []byte("") 696 return bytes.Join(s, sep) 697 } 698 699 func intToBytes(x int) []byte { 700 var buf = make([]byte, 4) 701 702 binary.BigEndian.PutUint32(buf, uint32(x)) 703 return buf 704 } 705 706 func kdf(length int, x ...[]byte) ([]byte, bool) { 707 var c []byte 708 709 ct := 1 710 h := sm3.New() 711 for i, j := 0, (length+31)/32; i < j; i++ { 712 h.Reset() 713 for _, xx := range x { 714 h.Write(xx) 715 } 716 h.Write(intToBytes(ct)) 717 hash := h.Sum(nil) 718 if i+1 == j && length%32 != 0 { 719 c = append(c, hash[:length%32]...) 720 } else { 721 c = append(c, hash...) 722 } 723 ct++ 724 } 725 for i := 0; i < length; i++ { 726 if c[i] != 0 { 727 return c, true 728 } 729 } 730 return c, false 731 } 732 733 // 选取一个位于[1~n-1]之间的随机数k,n是椭圆曲线的参数N 734 func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) { 735 if random == nil { 736 random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it. 737 } 738 params := c.Params() 739 b := make([]byte, params.BitSize/8+8) 740 _, err = io.ReadFull(random, b) 741 if err != nil { 742 return 743 } 744 k = new(big.Int).SetBytes(b) 745 n := new(big.Int).Sub(params.N, one) 746 k.Mod(k, n) 747 k.Add(k, one) 748 return 749 } 750 751 // GenerateKey 基于P256Sm2曲线生成sm2的公私钥 752 func GenerateKey(random io.Reader) (*PrivateKey, error) { 753 c := P256Sm2() 754 if random == nil { 755 random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it. 756 } 757 params := c.Params() 758 b := make([]byte, params.BitSize/8+8) 759 _, err := io.ReadFull(random, b) 760 if err != nil { 761 return nil, err 762 } 763 // 生成随机数k 764 k := new(big.Int).SetBytes(b) 765 // n = N - 2 766 n := new(big.Int).Sub(params.N, two) 767 // k = k mod n 768 k.Mod(k, n) 769 // k = k + 1 770 k.Add(k, one) 771 priv := new(PrivateKey) 772 // 设置曲线 773 priv.PublicKey.Curve = c 774 // 设置私钥 775 priv.D = k 776 // 公钥 = k * G 777 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) 778 779 return priv, nil 780 } 781 782 // ecdsa的实现中,使用zeroReader辅助生成一个cipher.StreamReader用作随机数生成的 rand io.Reader。 783 // sm2目前没有采用这种方式,所以这里将相关代码注释掉了。 784 785 // type zr struct { 786 // io.Reader 787 // } 788 789 // func (z *zr) Read(dst []byte) (n int, err error) { 790 // for i := range dst { 791 // dst[i] = 0 792 // } 793 // return len(dst), nil 794 // } 795 796 // var zeroReader = &zr{} 797 798 func getLastBit(a *big.Int) uint { 799 return a.Bit(0) 800 } 801 802 // Decrypt crypto.Decrypter 803 func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) { 804 return Decrypt(priv, msg, C1C3C2) 805 }