github.com/flyinox/gosm@v0.0.0-20171117061539-16768cb62077/src/crypto/sm/sm2/sm2.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// Package sm2 implements SM2 key generation, signing and verification as specified in the Chinese national cryptography standards.
     6  package sm2
     7  
     8  import (
     9  	"crypto"
    10  	"crypto/elliptic"
    11  	"encoding/asn1"
    12  	"errors"
    13  	"io"
    14  	"math/big"
    15  )
    16  
    17  type PublicKey struct {
    18  	elliptic.Curve
    19  	X, Y *big.Int
    20  }
    21  
    22  type PrivateKey struct {
    23  	PublicKey
    24  	D *big.Int
    25  }
    26  
    27  type sm2Signature struct {
    28  	R, S *big.Int
    29  }
    30  
    31  // The SM2's private key contains the public key
    32  func (priv *PrivateKey) Public() crypto.PublicKey {
    33  	return &priv.PublicKey
    34  }
    35  
    36  func (priv *PrivateKey) Sign(rand io.Reader, msg []byte, opts crypto.SignerOpts) ([]byte, error) {
    37  	r, s, err := Sign(rand, priv, msg)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	return asn1.Marshal(sm2Signature{r, s})
    42  }
    43  
    44  func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
    45  	var sm2Sign sm2Signature
    46  	_, err := asn1.Unmarshal(sign, &sm2Sign)
    47  	if err != nil {
    48  		return false
    49  	}
    50  	return Verify(pub, msg, sm2Sign.R, sm2Sign.S)
    51  }
    52  
    53  var one = new(big.Int).SetInt64(1)
    54  
    55  func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
    56  	params := c.Params()
    57  	b := make([]byte, params.BitSize/8+8)
    58  	_, err = io.ReadFull(rand, b)
    59  	if err != nil {
    60  		return
    61  	}
    62  	k = new(big.Int).SetBytes(b)
    63  	n := new(big.Int).Sub(params.N, one)
    64  	k.Mod(k, n)
    65  	k.Add(k, one)
    66  	return
    67  }
    68  
    69  func GenerateKey(rand io.Reader) (*PrivateKey, error) {
    70  	c := elliptic.P256Sm2()
    71  	k, err := randFieldElement(c, rand)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	priv := new(PrivateKey)
    76  	priv.PublicKey.Curve = c
    77  	priv.D = k
    78  	priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
    79  	return priv, nil
    80  }
    81  
    82  var errZeroParam = errors.New("zero parameter")
    83  
    84  func generateRandK(rand io.Reader, c elliptic.Curve) (k *big.Int) {
    85  	var one = new(big.Int).SetInt64(1)
    86  	params := c.Params()
    87  	b := make([]byte, params.BitSize/8+8)
    88  	_, err := io.ReadFull(rand, b)
    89  	if err != nil {
    90  		return
    91  	}
    92  	k = new(big.Int).SetBytes(b)
    93  	n := new(big.Int).Sub(params.N, one)
    94  	k.Mod(k, n)
    95  	k.Add(k, one)
    96  	return
    97  }
    98  
    99  func Sign(rand io.Reader, priv *PrivateKey, hash []byte) (r, s *big.Int, err error) {
   100  	var one = new(big.Int).SetInt64(1)
   101  	if len(hash) < 32 {
   102  		err = errors.New("The length of hash has short than what SM2 need.")
   103  		return
   104  	}
   105  	var tmp []byte = hash[0:32]
   106  	e := new(big.Int).SetBytes(tmp)
   107  	k := generateRandK(rand, priv.PublicKey.Curve)
   108  
   109  	x1, _ := priv.PublicKey.Curve.ScalarBaseMult(k.Bytes())
   110  
   111  	n := priv.PublicKey.Curve.Params().N
   112  
   113  	r = new(big.Int).Add(e, x1)
   114  
   115  	r.Mod(r, n)
   116  
   117  	s1 := new(big.Int).Mul(r, priv.D)
   118  	s1.Mod(s1, n)
   119  	s1.Sub(k, s1)
   120  	s1.Mod(s1, n)
   121  
   122  	s2 := new(big.Int).Add(one, priv.D)
   123  	s2.Mod(s2, n)
   124  	s2.ModInverse(s2, n)
   125  	s = new(big.Int).Mul(s1, s2)
   126  	s.Mod(s, n)
   127  
   128  	return
   129  }
   130  
   131  func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
   132  	c := pub.Curve
   133  	N := c.Params().N
   134  
   135  	if r.Sign() <= 0 || s.Sign() <= 0 {
   136  		return false
   137  	}
   138  	if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
   139  		return false
   140  	}
   141  
   142  	n := pub.Curve.Params().N
   143  	e := new(big.Int).SetBytes(hash)
   144  	t := new(big.Int).Add(r, s)
   145  	x11, y11 := pub.Curve.ScalarMult(pub.X, pub.Y, t.Bytes())
   146  	x12, y12 := pub.Curve.ScalarBaseMult(s.Bytes())
   147  	x1, _ := pub.Curve.Add(x11, y11, x12, y12)
   148  	x := new(big.Int).Add(e, x1)
   149  	x = x.Mod(x, n)
   150  
   151  	return x.Cmp(r) == 0
   152  }
   153  
   154  type zr struct {
   155  	io.Reader
   156  }
   157  
   158  func (z *zr) Read(dst []byte) (n int, err error) {
   159  	for i := range dst {
   160  		dst[i] = 0
   161  	}
   162  	return len(dst), nil
   163  }
   164  
   165  var zeroReader = &zr{}