github.com/mangodowner/go-gm@v0.0.0-20180818020936-8baa2bd4408c/src/crypto/sm2/sm2.go (about) 1 /* 2 Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 package sm2 17 18 // reference to ecdsa 19 import ( 20 "bytes" 21 "crypto" 22 "crypto/aes" 23 "crypto/cipher" 24 "crypto/elliptic" 25 "crypto/rand" 26 "crypto/sha512" 27 "encoding/asn1" 28 "encoding/binary" 29 "errors" 30 "io" 31 "math/big" 32 33 "crypto/sm3" 34 ) 35 36 const ( 37 aesIV = "IV for <SM2> CTR" 38 ) 39 40 type PublicKey struct { 41 elliptic.Curve 42 X, Y *big.Int 43 } 44 45 type PrivateKey struct { 46 PublicKey 47 D *big.Int 48 } 49 50 type sm2Signature struct { 51 R, S *big.Int 52 } 53 54 // The SM2's private key contains the public key 55 func (priv *PrivateKey) Public() crypto.PublicKey { 56 return &priv.PublicKey 57 } 58 59 // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s 60 func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) { 61 r, s, err := Sign(priv, msg) 62 if err != nil { 63 return nil, err 64 } 65 return asn1.Marshal(sm2Signature{r, s}) 66 } 67 68 func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) { 69 return Decrypt(priv, data) 70 } 71 72 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool { 73 var sm2Sign sm2Signature 74 75 _, err := asn1.Unmarshal(sign, &sm2Sign) 76 if err != nil { 77 return false 78 } 79 return Verify(pub, msg, sm2Sign.R, sm2Sign.S) 80 } 81 82 func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) { 83 return Encrypt(pub, data) 84 } 85 86 var one = new(big.Int).SetInt64(1) 87 88 func intToBytes(x int) []byte { 89 var buf = make([]byte, 4) 90 91 binary.BigEndian.PutUint32(buf, uint32(x)) 92 return buf 93 } 94 95 func kdf(x, y []byte, length int) ([]byte, bool) { 96 var c []byte 97 98 ct := 1 99 h := sm3.New() 100 x = append(x, y...) 101 for i, j := 0, (length+31)/32; i < j; i++ { 102 h.Reset() 103 h.Write(x) 104 h.Write(intToBytes(ct)) 105 hash := h.Sum(nil) 106 if i+1 == j && length%32 != 0 { 107 c = append(c, hash[:length%32]...) 108 } else { 109 c = append(c, hash...) 110 } 111 ct++ 112 } 113 for i := 0; i < length; i++ { 114 if c[i] != 0 { 115 return c, true 116 } 117 } 118 return c, false 119 } 120 121 func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) { 122 params := c.Params() 123 b := make([]byte, params.BitSize/8+8) 124 _, err = io.ReadFull(rand, b) 125 if err != nil { 126 return 127 } 128 k = new(big.Int).SetBytes(b) 129 n := new(big.Int).Sub(params.N, one) 130 k.Mod(k, n) 131 k.Add(k, one) 132 return 133 } 134 135 // GenerateKey generates a public and private key pair. 136 func GenerateKey(c elliptic.Curve, rand io.Reader) (*PrivateKey, error) { 137 k, err := randFieldElement(c, rand) 138 if err != nil { 139 return nil, err 140 } 141 142 priv := new(PrivateKey) 143 priv.PublicKey.Curve = c 144 priv.D = k 145 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes()) 146 return priv, nil 147 } 148 149 var errZeroParam = errors.New("zero parameter") 150 151 func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) { 152 entropylen := (priv.Curve.Params().BitSize + 7) / 16 153 if entropylen > 32 { 154 entropylen = 32 155 } 156 entropy := make([]byte, entropylen) 157 _, err = io.ReadFull(rand.Reader, entropy) 158 if err != nil { 159 return 160 } 161 162 // Initialize an SHA-512 hash context; digest ... 163 md := sha512.New() 164 md.Write(priv.D.Bytes()) // the private key, 165 md.Write(entropy) // the entropy, 166 md.Write(hash) // and the input hash; 167 key := md.Sum(nil)[:32] // and compute ChopMD-256(SHA-512), 168 // which is an indifferentiable MAC. 169 170 // Create an AES-CTR instance to use as a CSPRNG. 171 block, err := aes.NewCipher(key) 172 if err != nil { 173 return nil, nil, err 174 } 175 176 // Create a CSPRNG that xors a stream of zeros with 177 // the output of the AES-CTR instance. 178 csprng := cipher.StreamReader{ 179 R: zeroReader, 180 S: cipher.NewCTR(block, []byte(aesIV)), 181 } 182 183 // See [NSA] 3.4.1 184 c := priv.PublicKey.Curve 185 N := c.Params().N 186 if N.Sign() == 0 { 187 return nil, nil, errZeroParam 188 } 189 var k *big.Int 190 e := new(big.Int).SetBytes(hash) 191 for { // 调整算法细节以实现SM2 192 for { 193 k, err = randFieldElement(c, csprng) 194 if err != nil { 195 r = nil 196 return 197 } 198 r, _ = priv.Curve.ScalarBaseMult(k.Bytes()) 199 r.Add(r, e) 200 r.Mod(r, N) 201 if r.Sign() != 0 { 202 break 203 } 204 if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 { 205 break 206 } 207 } 208 rD := new(big.Int).Mul(priv.D, r) 209 s = new(big.Int).Sub(k, rD) 210 d1 := new(big.Int).Add(priv.D, one) 211 d1Inv := new(big.Int).ModInverse(d1, N) 212 s.Mul(s, d1Inv) 213 s.Mod(s, N) 214 if s.Sign() != 0 { 215 break 216 } 217 } 218 return 219 } 220 221 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { 222 c := pub.Curve 223 N := c.Params().N 224 225 if r.Sign() <= 0 || s.Sign() <= 0 { 226 return false 227 } 228 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 { 229 return false 230 } 231 232 // 调整算法细节以实现SM2 233 t := new(big.Int).Add(r, s) 234 t.Mod(t, N) 235 if N.Sign() == 0 { 236 return false 237 } 238 239 var x *big.Int 240 x1, y1 := c.ScalarBaseMult(s.Bytes()) 241 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes()) 242 x, _ = c.Add(x1, y1, x2, y2) 243 244 e := new(big.Int).SetBytes(hash) 245 x.Add(x, e) 246 x.Mod(x, N) 247 return x.Cmp(r) == 0 248 } 249 250 // 32byte 251 var zeroByteSlice = []byte{ 252 0, 0, 0, 0, 253 0, 0, 0, 0, 254 0, 0, 0, 0, 255 0, 0, 0, 0, 256 0, 0, 0, 0, 257 0, 0, 0, 0, 258 0, 0, 0, 0, 259 0, 0, 0, 0, 260 } 261 262 /* 263 * sm2密文结构如下: 264 * x 265 * y 266 * hash 267 * CipherText 268 */ 269 func Encrypt(pub *PublicKey, data []byte) ([]byte, error) { 270 lenx1 := 0 271 leny1 := 0 272 lenx2 := 0 273 leny2 := 0 274 length := len(data) 275 for { 276 c := []byte{} 277 curve := pub.Curve 278 k, err := randFieldElement(curve, rand.Reader) 279 if err != nil { 280 return nil, err 281 } 282 x1, y1 := curve.ScalarBaseMult(k.Bytes()) 283 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) 284 lenx1 = len(x1.Bytes()) 285 leny1 = len(y1.Bytes()) 286 lenx2 = len(x2.Bytes()) 287 leny2 = len(y2.Bytes()) 288 if lenx1 < 32 { 289 c = append(c, zeroByteSlice[:(32-lenx1)]...) 290 } 291 c = append(c, x1.Bytes()...) // x分量 292 if leny1 < 32 { 293 c = append(c, zeroByteSlice[:(32-leny1)]...) 294 } 295 c = append(c, y1.Bytes()...) // y分量 296 tm := []byte{} 297 if lenx2 < 32 { 298 tm = append(tm, zeroByteSlice[:(32-lenx2)]...) 299 } 300 tm = append(tm, x2.Bytes()...) 301 tm = append(tm, data...) 302 if leny2 < 32 { 303 tm = append(tm, zeroByteSlice[:(32-leny2)]...) 304 } 305 tm = append(tm, y2.Bytes()...) 306 h := sm3.Sm3Sum(tm) 307 c = append(c, h...) 308 ct, ok := kdf(x2.Bytes(), y2.Bytes(), length) // 密文 309 if !ok { 310 continue 311 } 312 c = append(c, ct...) 313 for i := 0; i < length; i++ { 314 c[96+i] ^= data[i] 315 } 316 return c, nil 317 } 318 } 319 320 func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) { 321 length := len(data) - 96 322 curve := priv.Curve 323 x := new(big.Int).SetBytes(data[:32]) 324 y := new(big.Int).SetBytes(data[32:64]) 325 x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes()) 326 c, ok := kdf(x2.Bytes(), y2.Bytes(), length) 327 if !ok { 328 return nil, errors.New("Decrypt: failed to decrypt") 329 } 330 for i := 0; i < length; i++ { 331 c[i] ^= data[i+96] 332 } 333 tm := []byte{} 334 tm = append(tm, x2.Bytes()...) 335 tm = append(tm, c...) 336 tm = append(tm, y2.Bytes()...) 337 h := sm3.Sm3Sum(tm) 338 if bytes.Compare(h, data[64:96]) != 0 { 339 return c, errors.New("Decrypt: failed to decrypt") 340 } 341 return c, nil 342 } 343 344 type zr struct { 345 io.Reader 346 } 347 348 func (z *zr) Read(dst []byte) (n int, err error) { 349 for i := range dst { 350 dst[i] = 0 351 } 352 return len(dst), nil 353 } 354 355 var zeroReader = &zr{}