github.com/mangodowner/go-gm@v0.0.0-20180818020936-8baa2bd4408c/src/crypto/sm2/sm2.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package sm2
    17  
    18  // reference to ecdsa
    19  import (
    20  	"bytes"
    21  	"crypto"
    22  	"crypto/aes"
    23  	"crypto/cipher"
    24  	"crypto/elliptic"
    25  	"crypto/rand"
    26  	"crypto/sha512"
    27  	"encoding/asn1"
    28  	"encoding/binary"
    29  	"errors"
    30  	"io"
    31  	"math/big"
    32  
    33  	"crypto/sm3"
    34  )
    35  
    36  const (
    37  	aesIV = "IV for <SM2> CTR"
    38  )
    39  
    40  type PublicKey struct {
    41  	elliptic.Curve
    42  	X, Y *big.Int
    43  }
    44  
    45  type PrivateKey struct {
    46  	PublicKey
    47  	D *big.Int
    48  }
    49  
    50  type sm2Signature struct {
    51  	R, S *big.Int
    52  }
    53  
    54  // The SM2's private key contains the public key
    55  func (priv *PrivateKey) Public() crypto.PublicKey {
    56  	return &priv.PublicKey
    57  }
    58  
    59  // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
    60  func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
    61  	r, s, err := Sign(priv, msg)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return asn1.Marshal(sm2Signature{r, s})
    66  }
    67  
    68  func (priv *PrivateKey) Decrypt(data []byte) ([]byte, error) {
    69  	return Decrypt(priv, data)
    70  }
    71  
    72  func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
    73  	var sm2Sign sm2Signature
    74  
    75  	_, err := asn1.Unmarshal(sign, &sm2Sign)
    76  	if err != nil {
    77  		return false
    78  	}
    79  	return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
    80  }
    81  
    82  func (pub *PublicKey) Encrypt(data []byte) ([]byte, error) {
    83  	return Encrypt(pub, data)
    84  }
    85  
    86  var one = new(big.Int).SetInt64(1)
    87  
    88  func intToBytes(x int) []byte {
    89  	var buf = make([]byte, 4)
    90  
    91  	binary.BigEndian.PutUint32(buf, uint32(x))
    92  	return buf
    93  }
    94  
    95  func kdf(x, y []byte, length int) ([]byte, bool) {
    96  	var c []byte
    97  
    98  	ct := 1
    99  	h := sm3.New()
   100  	x = append(x, y...)
   101  	for i, j := 0, (length+31)/32; i < j; i++ {
   102  		h.Reset()
   103  		h.Write(x)
   104  		h.Write(intToBytes(ct))
   105  		hash := h.Sum(nil)
   106  		if i+1 == j && length%32 != 0 {
   107  			c = append(c, hash[:length%32]...)
   108  		} else {
   109  			c = append(c, hash...)
   110  		}
   111  		ct++
   112  	}
   113  	for i := 0; i < length; i++ {
   114  		if c[i] != 0 {
   115  			return c, true
   116  		}
   117  	}
   118  	return c, false
   119  }
   120  
   121  func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
   122  	params := c.Params()
   123  	b := make([]byte, params.BitSize/8+8)
   124  	_, err = io.ReadFull(rand, b)
   125  	if err != nil {
   126  		return
   127  	}
   128  	k = new(big.Int).SetBytes(b)
   129  	n := new(big.Int).Sub(params.N, one)
   130  	k.Mod(k, n)
   131  	k.Add(k, one)
   132  	return
   133  }
   134  
   135  // GenerateKey generates a public and private key pair.
   136  func GenerateKey(c elliptic.Curve, rand io.Reader) (*PrivateKey, error) {
   137  	k, err := randFieldElement(c, rand)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	priv := new(PrivateKey)
   143  	priv.PublicKey.Curve = c
   144  	priv.D = k
   145  	priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
   146  	return priv, nil
   147  }
   148  
   149  var errZeroParam = errors.New("zero parameter")
   150  
   151  func Sign(priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
   152  	entropylen := (priv.Curve.Params().BitSize + 7) / 16
   153  	if entropylen > 32 {
   154  		entropylen = 32
   155  	}
   156  	entropy := make([]byte, entropylen)
   157  	_, err = io.ReadFull(rand.Reader, entropy)
   158  	if err != nil {
   159  		return
   160  	}
   161  
   162  	// Initialize an SHA-512 hash context; digest ...
   163  	md := sha512.New()
   164  	md.Write(priv.D.Bytes()) // the private key,
   165  	md.Write(entropy)        // the entropy,
   166  	md.Write(hash)           // and the input hash;
   167  	key := md.Sum(nil)[:32]  // and compute ChopMD-256(SHA-512),
   168  	// which is an indifferentiable MAC.
   169  
   170  	// Create an AES-CTR instance to use as a CSPRNG.
   171  	block, err := aes.NewCipher(key)
   172  	if err != nil {
   173  		return nil, nil, err
   174  	}
   175  
   176  	// Create a CSPRNG that xors a stream of zeros with
   177  	// the output of the AES-CTR instance.
   178  	csprng := cipher.StreamReader{
   179  		R: zeroReader,
   180  		S: cipher.NewCTR(block, []byte(aesIV)),
   181  	}
   182  
   183  	// See [NSA] 3.4.1
   184  	c := priv.PublicKey.Curve
   185  	N := c.Params().N
   186  	if N.Sign() == 0 {
   187  		return nil, nil, errZeroParam
   188  	}
   189  	var k *big.Int
   190  	e := new(big.Int).SetBytes(hash)
   191  	for { // 调整算法细节以实现SM2
   192  		for {
   193  			k, err = randFieldElement(c, csprng)
   194  			if err != nil {
   195  				r = nil
   196  				return
   197  			}
   198  			r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
   199  			r.Add(r, e)
   200  			r.Mod(r, N)
   201  			if r.Sign() != 0 {
   202  				break
   203  			}
   204  			if t := new(big.Int).Add(r, k); t.Cmp(N) == 0 {
   205  				break
   206  			}
   207  		}
   208  		rD := new(big.Int).Mul(priv.D, r)
   209  		s = new(big.Int).Sub(k, rD)
   210  		d1 := new(big.Int).Add(priv.D, one)
   211  		d1Inv := new(big.Int).ModInverse(d1, N)
   212  		s.Mul(s, d1Inv)
   213  		s.Mod(s, N)
   214  		if s.Sign() != 0 {
   215  			break
   216  		}
   217  	}
   218  	return
   219  }
   220  
   221  func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
   222  	c := pub.Curve
   223  	N := c.Params().N
   224  
   225  	if r.Sign() <= 0 || s.Sign() <= 0 {
   226  		return false
   227  	}
   228  	if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
   229  		return false
   230  	}
   231  
   232  	// 调整算法细节以实现SM2
   233  	t := new(big.Int).Add(r, s)
   234  	t.Mod(t, N)
   235  	if N.Sign() == 0 {
   236  		return false
   237  	}
   238  
   239  	var x *big.Int
   240  	x1, y1 := c.ScalarBaseMult(s.Bytes())
   241  	x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
   242  	x, _ = c.Add(x1, y1, x2, y2)
   243  
   244  	e := new(big.Int).SetBytes(hash)
   245  	x.Add(x, e)
   246  	x.Mod(x, N)
   247  	return x.Cmp(r) == 0
   248  }
   249  
   250  // 32byte
   251  var zeroByteSlice = []byte{
   252  	0, 0, 0, 0,
   253  	0, 0, 0, 0,
   254  	0, 0, 0, 0,
   255  	0, 0, 0, 0,
   256  	0, 0, 0, 0,
   257  	0, 0, 0, 0,
   258  	0, 0, 0, 0,
   259  	0, 0, 0, 0,
   260  }
   261  
   262  /*
   263   * sm2密文结构如下:
   264   *  x
   265   *  y
   266   *  hash
   267   *  CipherText
   268   */
   269  func Encrypt(pub *PublicKey, data []byte) ([]byte, error) {
   270  	lenx1 := 0
   271  	leny1 := 0
   272  	lenx2 := 0
   273  	leny2 := 0
   274  	length := len(data)
   275  	for {
   276  		c := []byte{}
   277  		curve := pub.Curve
   278  		k, err := randFieldElement(curve, rand.Reader)
   279  		if err != nil {
   280  			return nil, err
   281  		}
   282  		x1, y1 := curve.ScalarBaseMult(k.Bytes())
   283  		x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
   284  		lenx1 = len(x1.Bytes())
   285  		leny1 = len(y1.Bytes())
   286  		lenx2 = len(x2.Bytes())
   287  		leny2 = len(y2.Bytes())
   288  		if lenx1 < 32 {
   289  			c = append(c, zeroByteSlice[:(32-lenx1)]...)
   290  		}
   291  		c = append(c, x1.Bytes()...) // x分量
   292  		if leny1 < 32 {
   293  			c = append(c, zeroByteSlice[:(32-leny1)]...)
   294  		}
   295  		c = append(c, y1.Bytes()...) // y分量
   296  		tm := []byte{}
   297  		if lenx2 < 32 {
   298  			tm = append(tm, zeroByteSlice[:(32-lenx2)]...)
   299  		}
   300  		tm = append(tm, x2.Bytes()...)
   301  		tm = append(tm, data...)
   302  		if leny2 < 32 {
   303  			tm = append(tm, zeroByteSlice[:(32-leny2)]...)
   304  		}
   305  		tm = append(tm, y2.Bytes()...)
   306  		h := sm3.Sm3Sum(tm)
   307  		c = append(c, h...)
   308  		ct, ok := kdf(x2.Bytes(), y2.Bytes(), length) // 密文
   309  		if !ok {
   310  			continue
   311  		}
   312  		c = append(c, ct...)
   313  		for i := 0; i < length; i++ {
   314  			c[96+i] ^= data[i]
   315  		}
   316  		return c, nil
   317  	}
   318  }
   319  
   320  func Decrypt(priv *PrivateKey, data []byte) ([]byte, error) {
   321  	length := len(data) - 96
   322  	curve := priv.Curve
   323  	x := new(big.Int).SetBytes(data[:32])
   324  	y := new(big.Int).SetBytes(data[32:64])
   325  	x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
   326  	c, ok := kdf(x2.Bytes(), y2.Bytes(), length)
   327  	if !ok {
   328  		return nil, errors.New("Decrypt: failed to decrypt")
   329  	}
   330  	for i := 0; i < length; i++ {
   331  		c[i] ^= data[i+96]
   332  	}
   333  	tm := []byte{}
   334  	tm = append(tm, x2.Bytes()...)
   335  	tm = append(tm, c...)
   336  	tm = append(tm, y2.Bytes()...)
   337  	h := sm3.Sm3Sum(tm)
   338  	if bytes.Compare(h, data[64:96]) != 0 {
   339  		return c, errors.New("Decrypt: failed to decrypt")
   340  	}
   341  	return c, nil
   342  }
   343  
   344  type zr struct {
   345  	io.Reader
   346  }
   347  
   348  func (z *zr) Read(dst []byte) (n int, err error) {
   349  	for i := range dst {
   350  		dst[i] = 0
   351  	}
   352  	return len(dst), nil
   353  }
   354  
   355  var zeroReader = &zr{}