gitee.com/lh-her-team/common@v1.5.1/crypto/asym/sm2/sk.go (about)

     1  package sm2
     2  
     3  import (
     4  	"bytes"
     5  	crypto2 "crypto"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/elliptic"
     9  	"crypto/rand"
    10  	"crypto/sha512"
    11  	"encoding/asn1"
    12  	"encoding/pem"
    13  	"fmt"
    14  	"io"
    15  	"math/big"
    16  
    17  	"gitee.com/lh-her-team/common/crypto"
    18  	"gitee.com/lh-her-team/common/crypto/hash"
    19  	tjsm2 "github.com/tjfoc/gmsm/sm2"
    20  )
    21  
    22  var defaultSM2Opts = &crypto.EncOpts{
    23  	EncodingType: "",
    24  	BlockMode:    "",
    25  	EnableMAC:    false,
    26  	Hash:         0,
    27  	Label:        nil,
    28  	EnableASN1:   true,
    29  }
    30  
// PrivateKey wraps a tjfoc/gmsm SM2 private key.
type PrivateKey struct {
	// K is the wrapped key. Bytes reports an error when it is nil; the other
	// methods in this file dereference it without checking.
	K *tjsm2.PrivateKey
}
    34  
// Sig holds the two integers of a signature. Sign and SignWithOpts pass a
// value of this type to asn1.Marshal to produce the encoded signature.
type Sig struct {
	R *big.Int `json:"r"`
	S *big.Int `json:"s"`
}
    39  
    40  func (sk *PrivateKey) Bytes() ([]byte, error) {
    41  	if sk.K == nil {
    42  		return nil, fmt.Errorf("private key is nil")
    43  	}
    44  	return MarshalPKCS8PrivateKey(sk.K)
    45  }
    46  
    47  func (sk *PrivateKey) PublicKey() crypto.PublicKey {
    48  	return &PublicKey{K: &sk.K.PublicKey}
    49  }
    50  
const (
	// aesIV is the fixed 16-byte initialization vector SM2Sign passes to
	// cipher.NewCTR when building its AES-CTR based random stream.
	aesIV = "IV for <SM2> CTR"
)
    54  
    55  type zr struct {
    56  	io.Reader
    57  }
    58  
    59  func (z *zr) Read(dst []byte) (n int, err error) {
    60  	for i := range dst {
    61  		dst[i] = 0
    62  	}
    63  	return len(dst), nil
    64  }
    65  
    66  var zeroReader = &zr{}
    67  var one = new(big.Int).SetInt64(1)
    68  
    69  func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
    70  	params := c.Params()
    71  	b := make([]byte, params.BitSize/8+8)
    72  	_, err = io.ReadFull(rand, b)
    73  	if err != nil {
    74  		return
    75  	}
    76  	k = new(big.Int).SetBytes(b)
    77  	n := new(big.Int).Sub(params.N, one)
    78  	k.Mod(k, n)
    79  	k.Add(k, one)
    80  	return
    81  }
    82  
    83  func SM2Sign(priv *tjsm2.PrivateKey, hash []byte) (r, s *big.Int, err error) {
    84  	entropylen := (priv.Curve.Params().BitSize + 7) / 16
    85  	if entropylen > 32 {
    86  		entropylen = 32
    87  	}
    88  	entropy := make([]byte, entropylen)
    89  	_, err = io.ReadFull(rand.Reader, entropy)
    90  	if err != nil {
    91  		return
    92  	}
    93  	// Initialize an SHA-512 hash context; digest ...
    94  	md := sha512.New()
    95  	md.Write(priv.D.Bytes()) // the private key,
    96  	md.Write(entropy)        // the entropy,
    97  	md.Write(hash)           // and the input hash;
    98  	key := md.Sum(nil)[:32]  // and compute ChopMD-256(SHA-512),
    99  	// which is an indifferentiable MAC.
   100  	// Create an AES-CTR instance to use as a CSPRNG.
   101  	block, err := aes.NewCipher(key)
   102  	if err != nil {
   103  		return nil, nil, err
   104  	}
   105  	// Create a CSPRNG that xors a stream of zeros with
   106  	// the output of the AES-CTR instance.
   107  	csprng := cipher.StreamReader{
   108  		R: zeroReader,
   109  		S: cipher.NewCTR(block, []byte(aesIV)),
   110  	}
   111  	// See [NSA] 3.4.1
   112  	c := priv.PublicKey.Curve
   113  	n := c.Params().N
   114  	if n.Sign() == 0 {
   115  		return nil, nil, fmt.Errorf("zero parameter")
   116  	}
   117  	var k *big.Int
   118  	e := new(big.Int).SetBytes(hash)
   119  	for { // 调整算法细节以实现SM2
   120  		for {
   121  			k, err = randFieldElement(c, csprng)
   122  			if err != nil {
   123  				r = nil
   124  				return
   125  			}
   126  			r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
   127  			r.Add(r, e)
   128  			r.Mod(r, n)
   129  			if r.Sign() != 0 {
   130  				if t := new(big.Int).Add(r, k); t.Cmp(n) != 0 {
   131  					break
   132  				}
   133  			}
   134  		}
   135  		rD := new(big.Int).Mul(priv.D, r)
   136  		s = new(big.Int).Sub(k, rD)
   137  		d1 := new(big.Int).Add(priv.D, one)
   138  		d1Inv := new(big.Int).ModInverse(d1, n)
   139  		s.Mul(s, d1Inv)
   140  		s.Mod(s, n)
   141  		if s.Sign() != 0 {
   142  			break
   143  		}
   144  	}
   145  	return
   146  }
   147  
   148  func (sk *PrivateKey) Sign(digest []byte) ([]byte, error) {
   149  	var (
   150  		r, s *big.Int
   151  		err  error
   152  	)
   153  	r, s, err = SM2Sign(sk.K, digest[:])
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  	return asn1.Marshal(Sig{R: r, S: s})
   158  }
   159  
   160  func (sk *PrivateKey) SignWithOpts(msg []byte, opts *crypto.SignOpts) ([]byte, error) {
   161  	if opts == nil {
   162  		return sk.Sign(msg)
   163  	}
   164  	if opts.Hash == crypto.HASH_TYPE_SM3 && sk.Type() == crypto.SM2 {
   165  		uid := opts.UID
   166  		if len(uid) == 0 {
   167  			uid = crypto.CRYPTO_DEFAULT_UID
   168  		}
   169  		r, s, err := tjsm2.Sm2Sign(sk.K, msg, []byte(uid), rand.Reader)
   170  		if err != nil {
   171  			return nil, fmt.Errorf("fail to sign with SM2-SM3: [%v]", err)
   172  		}
   173  		return asn1.Marshal(Sig{R: r, S: s})
   174  	}
   175  	dgst, err := hash.Get(opts.Hash, msg)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	return sk.Sign(dgst)
   180  }
   181  
   182  func (sk *PrivateKey) Type() crypto.KeyType {
   183  	return sk.PublicKey().Type()
   184  }
   185  
   186  func (sk *PrivateKey) String() (string, error) {
   187  	skDER, err := sk.Bytes()
   188  	if err != nil {
   189  		return "", err
   190  	}
   191  	block := &pem.Block{
   192  		Type:  "PRIVATE KEY",
   193  		Bytes: skDER,
   194  	}
   195  	buf := new(bytes.Buffer)
   196  	if err = pem.Encode(buf, block); err != nil {
   197  		return "", err
   198  	}
   199  	return buf.String(), nil
   200  }
   201  
   202  func (sk *PrivateKey) ToStandardKey() crypto2.PrivateKey {
   203  	return sk.K
   204  }
   205  
   206  func (sk *PrivateKey) Decrypt(ciphertext []byte) ([]byte, error) {
   207  	return sk.DecryptWithOpts(ciphertext, defaultSM2Opts)
   208  }
   209  
   210  func (sk *PrivateKey) DecryptWithOpts(ciphertext []byte, opts *crypto.EncOpts) ([]byte, error) {
   211  	if opts == nil || opts.EnableASN1 {
   212  		return tjsm2.DecryptAsn1(sk.K, ciphertext)
   213  	}
   214  	return tjsm2.Decrypt(sk.K, ciphertext, tjsm2.C1C3C2)
   215  }
   216  
   217  func (sk *PrivateKey) EncryptKey() crypto.EncryptKey {
   218  	return &PublicKey{&sk.K.PublicKey}
   219  }
   220  
   221  func New(keyType crypto.KeyType) (crypto.PrivateKey, error) {
   222  	pri, err := tjsm2.GenerateKey(rand.Reader)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	return &PrivateKey{K: pri}, nil
   227  }