github.com/aaabigfish/gopkg@v1.1.0/crypto/rsa.go (about)

     1  package crypto
     2  
     3  import (
     4      "bytes"
     5      "crypto/rand"
     6      "crypto/rsa"
     7      "crypto/x509"
     8      "encoding/pem"
     9      "errors"
    10      "fmt"
    11      "io"
    12      "io/ioutil"
    13      "math/big"
    14      "os"
    15      "runtime/debug"
    16  )
    17  
    18  var (
    19      ErrDataToLarge     = errors.New("message too long for RSA public key size")
    20      ErrDataLen         = errors.New("data length error")
    21      ErrDataBroken      = errors.New("data broken, first byte is not zero")
    22      ErrKeyPairDismatch = errors.New("data is not encrypted by the private key")
    23      ErrDecryption      = errors.New("decryption error")
    24      ErrPublicKey       = errors.New("public key error")
    25      ErrPrivateKey      = errors.New("private key error")
    26  )
    27  // rsa 加解密
    28  type rsaCrypto struct {
    29      // pem 格式公钥
    30      publicKey []byte
    31  
    32      // pem 格式私钥
    33      privateKey []byte
    34  
    35      // rsa 公钥
    36      rsaPriKey *rsa.PrivateKey
    37  
    38      // rsa 私钥
    39      rsaPubKey *rsa.PublicKey
    40  }
    41  
    42  // 创建 rsa 实例
    43  // 公钥 pubKey 和 私钥 priKey 必须传一个,没值的传 nil
    44  // 加解密时公私钥必须是一对
    45  func NewRsa(pubKey, priKey []byte) *rsaCrypto {
    46      defer func() {
    47          if err := recover(); err != nil {
    48              fmt.Println(err)
    49              debug.PrintStack()
    50              os.Exit(-2)
    51          }
    52      }()
    53  
    54      if len(pubKey) == 0 && len(priKey) == 0 {
    55          panic("public key or private key is needed")
    56      }
    57      rc := &rsaCrypto{
    58          publicKey:  pubKey,
    59          privateKey: priKey,
    60      }
    61  
    62      var err error
    63      if len(pubKey) > 0 {
    64          rc.rsaPubKey, err = getRsaPublicKey(pubKey)
    65          if err != nil {
    66              panic(err.Error())
    67          }
    68      }
    69      if len(priKey) > 0 {
    70          rc.rsaPriKey, err = getRsaPrivateKey(priKey)
    71          if err != nil {
    72              panic(err.Error())
    73          }
    74      }
    75      return rc
    76  }
    77  
    78  // rsa 公钥加密
    79  func (r *rsaCrypto) EncryptWithPublicKey(data []byte) ([]byte, error) {
    80      // 解密 pem 格式公钥
    81      block, _ := pem.Decode(r.publicKey)
    82      if block == nil {
    83          return nil, ErrPublicKey
    84      }
    85  
    86      // 解析公钥
    87      pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
    88      if err != nil {
    89          return nil, err
    90      }
    91  
    92      pub := pubInterface.(*rsa.PublicKey)
    93      return rsa.EncryptPKCS1v15(rand.Reader, pub, data)
    94  }
    95  
    96  // rsa 私钥解密
    97  func (r *rsaCrypto) DecryptWithPrivateKey(ciphertext []byte) ([]byte, error) {
    98      block, _ := pem.Decode(r.privateKey)
    99      if block == nil {
   100          return nil, ErrPrivateKey
   101      }
   102  
   103      pri, err := x509.ParsePKCS1PrivateKey(block.Bytes)
   104      if err != nil {
   105          p, err := x509.ParsePKCS8PrivateKey(block.Bytes)
   106          if err != nil {
   107              return nil, err
   108          }
   109          pri = p.(*rsa.PrivateKey)
   110      }
   111      return rsa.DecryptPKCS1v15(rand.Reader, pri, ciphertext)
   112  }
   113  
   114  // rsa 私钥加密
   115  func (r *rsaCrypto) EncryptWithPrivateKey(data []byte) ([]byte, error) {
   116      out := bytes.NewBuffer(nil)
   117      err := r.privKeyIO(bytes.NewReader(data), out)
   118      if err != nil {
   119          return nil, err
   120      }
   121      return ioutil.ReadAll(out)
   122  }
   123  
   124  // rsa 公钥解密
   125  func (r *rsaCrypto) DecryptWithPublicKey(ciphertext []byte) ([]byte, error) {
   126      out := bytes.NewBuffer(nil)
   127      err := r.pubKeyIO(bytes.NewReader(ciphertext), out)
   128      if err != nil {
   129          return nil, err
   130      }
   131      return ioutil.ReadAll(out)
   132  }
   133  
   134  // 公钥解密 reader
   135  func (r *rsaCrypto) pubKeyIO(in io.Reader, w io.Writer) (err error) {
   136      k := (r.rsaPubKey.N.BitLen() + 7) / 8
   137      buf := make([]byte, k)
   138      var b []byte
   139      size := 0
   140      for {
   141          size, err = in.Read(buf)
   142          if err != nil {
   143              if err == io.EOF {
   144                  return nil
   145              }
   146              return err
   147          }
   148          if size < k {
   149              b = buf[:size]
   150          } else {
   151              b = buf
   152          }
   153          b, err = r.pubKeyDecrypt(b)
   154          if err != nil {
   155              return err
   156          }
   157          if _, err = w.Write(b); err != nil {
   158              return err
   159          }
   160      }
   161      return nil
   162  }
   163  
   164  // 私钥加密 reader
   165  func (r *rsaCrypto) privKeyIO(re io.Reader, w io.Writer) (err error) {
   166      k := (r.rsaPriKey.N.BitLen()+7)/8 - 11
   167      buf := make([]byte, k)
   168      var b []byte
   169      size := 0
   170      for {
   171          size, err = re.Read(buf)
   172          if err != nil {
   173              if err == io.EOF {
   174                  return nil
   175              }
   176              return err
   177          }
   178          if size < k {
   179              b = buf[:size]
   180          } else {
   181              b = buf
   182          }
   183          b, err = r.priKeyEncrypt(rand.Reader, b)
   184          if err != nil {
   185              return err
   186          }
   187          if _, err = w.Write(b); err != nil {
   188              return err
   189          }
   190      }
   191      return nil
   192  }
   193  
   194  // 私钥加密
   195  func (r *rsaCrypto) priKeyEncrypt(rand io.Reader, hashed []byte) ([]byte, error) {
   196      hl := len(hashed)
   197      k := (r.rsaPriKey.N.BitLen() + 7) / 8
   198      if k < hl+11 {
   199          return nil, ErrDataLen
   200      }
   201      em := make([]byte, k)
   202      em[1] = 1
   203      for i := 2; i < k-hl-1; i++ {
   204          em[i] = 0xff
   205      }
   206      copy(em[k-hl:k], hashed)
   207      m := new(big.Int).SetBytes(em)
   208      c, err := decrypt(rand, r.rsaPriKey, m)
   209      if err != nil {
   210          return nil, err
   211      }
   212      copyWithLeftPad(em, c.Bytes())
   213      return em, nil
   214  }
   215  
   216  // 公钥解密
   217  func (r *rsaCrypto) pubKeyDecrypt(data []byte) ([]byte, error) {
   218      k := (r.rsaPubKey.N.BitLen() + 7) / 8
   219      if k != len(data) {
   220          return nil, ErrDataLen
   221      }
   222      m := new(big.Int).SetBytes(data)
   223      if m.Cmp(r.rsaPubKey.N) > 0 {
   224          return nil, ErrDataToLarge
   225      }
   226      m.Exp(m, big.NewInt(int64(r.rsaPubKey.E)), r.rsaPubKey.N)
   227      d := leftPad(m.Bytes(), k)
   228      if d[0] != 0 {
   229          return nil, ErrDataBroken
   230      }
   231      if d[1] != 0 && d[1] != 1 {
   232          return nil, ErrKeyPairDismatch
   233      }
   234      var i = 2
   235      for ; i < len(d); i++ {
   236          if d[i] == 0 {
   237              break
   238          }
   239      }
   240      i++
   241      if i == len(d) {
   242          return nil, nil
   243      }
   244      return d[i:], nil
   245  }
   246  
   247  // 获取 rsa 私钥
   248  func getRsaPrivateKey(privateKey []byte) (*rsa.PrivateKey, error) {
   249      block, _ := pem.Decode(privateKey)
   250      if block == nil {
   251          return nil, ErrPrivateKey
   252      }
   253      pri, err := x509.ParsePKCS1PrivateKey(block.Bytes)
   254      if err == nil {
   255          return pri, nil
   256      }
   257      p, err := x509.ParsePKCS8PrivateKey(block.Bytes)
   258      if err != nil {
   259          return nil, err
   260      }
   261      return p.(*rsa.PrivateKey), nil
   262  }
   263  
   264  // 设置 rsa 公钥
   265  func getRsaPublicKey(publicKey []byte) (*rsa.PublicKey, error) {
   266      block, _ := pem.Decode(publicKey)
   267      if block == nil {
   268          return nil, ErrPublicKey
   269      }
   270      // x509 parse public key
   271      pub, err := x509.ParsePKIXPublicKey(block.Bytes)
   272      if err != nil {
   273          return nil, err
   274      }
   275      return pub.(*rsa.PublicKey), nil
   276  }
   277  
   278  // 从 crypto/rsa 复制
   279  var bigZero = big.NewInt(0)
   280  var bigOne = big.NewInt(1)
   281  
// decrypt performs the raw RSA private-key operation m = c^D mod N.
// Copied from crypto/rsa.
//
// If random is non-nil, blinding is applied. A random r that is invertible
// mod N is chosen (ir = r^-1 mod N), the exponentiation runs on c*r^E mod N,
// and the result is multiplied by ir at the end. When priv.Precomputed.Dp is
// set, the exponentiation uses the CRT values, including any additional
// primes; otherwise it computes c^D mod N directly.
//
// It returns ErrDecryption if c is greater than N, and any error from
// rand.Int while choosing the blinding factor.
//
// NOTE(review): c == N is not rejected (the check is > 0, not >= 0). The only
// caller in this file, priKeyEncrypt, passes a block starting 0x00 0x01, which
// is always < N.
// NOTE(review): math/big exponentiation is presumably not constant-time;
// consider delegating to crypto/rsa instead of this copy. Confirm before changing.
func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) {
	if c.Cmp(priv.N) > 0 {
		err = ErrDecryption
		return
	}
	// ir holds r^-1 mod N when blinding is used; nil otherwise.
	var ir *big.Int
	if random != nil {
		var r *big.Int

		// Draw blinding factors until one is invertible mod N.
		for {
			r, err = rand.Int(random, priv.N)
			if err != nil {
				return
			}
			// 0 has no inverse; replace it with 1 (i.e. no blinding).
			if r.Cmp(bigZero) == 0 {
				r = bigOne
			}
			var ok bool
			ir, ok = modInverse(r, priv.N)
			if ok {
				break
			}
		}
		// Blind the input: c = c * r^E mod N, computed on a copy so the
		// caller's c is not modified.
		bigE := big.NewInt(int64(priv.E))
		rpowe := new(big.Int).Exp(r, bigE, priv.N)
		cCopy := new(big.Int).Set(c)
		cCopy.Mul(cCopy, rpowe)
		cCopy.Mod(cCopy, priv.N)
		c = cCopy
	}
	if priv.Precomputed.Dp == nil {
		// No CRT values available: compute c^D mod N directly.
		m = new(big.Int).Exp(c, priv.D, priv.N)
	} else {
		// CRT recombination for the first two primes p = Primes[0] and
		// q = Primes[1]: m = m2 + q*((m1 - m2)*Qinv mod p),
		// where m1 = c^Dp mod p and m2 = c^Dq mod q.
		m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
		m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
		m.Sub(m, m2)
		if m.Sign() < 0 {
			m.Add(m, priv.Primes[0])
		}
		m.Mul(m, priv.Precomputed.Qinv)
		m.Mod(m, priv.Primes[0])
		m.Mul(m, priv.Primes[1])
		m.Add(m, m2)

		// Fold in each additional prime of a multi-prime key using its
		// precomputed exponent, coefficient and R value.
		for i, values := range priv.Precomputed.CRTValues {
			prime := priv.Primes[2+i]
			m2.Exp(c, values.Exp, prime)
			m2.Sub(m2, m)
			m2.Mul(m2, values.Coeff)
			m2.Mod(m2, prime)
			if m2.Sign() < 0 {
				m2.Add(m2, prime)
			}
			m2.Mul(m2, values.R)
			m.Add(m, m2)
		}
	}
	if ir != nil {
		// Unblind: m = m * r^-1 mod N.
		m.Mul(m, ir)
		m.Mod(m, priv.N)
	}

	return
}
   347  
   348  // 从 crypto/rsa 复制
   349  func copyWithLeftPad(dest, src []byte) {
   350      numPaddingBytes := len(dest) - len(src)
   351      for i := 0; i < numPaddingBytes; i++ {
   352          dest[i] = 0
   353      }
   354      copy(dest[numPaddingBytes:], src)
   355  }
   356  
   357  // 从 crypto/rsa 复制
   358  func leftPad(input []byte, size int) (out []byte) {
   359      n := len(input)
   360      if n > size {
   361          n = size
   362      }
   363      out = make([]byte, size)
   364      copy(out[len(out)-n:], input)
   365      return
   366  }
   367  
   368  // 从 crypto/rsa 复制
   369  func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
   370      g := new(big.Int)
   371      x := new(big.Int)
   372      y := new(big.Int)
   373      g.GCD(x, y, a, n)
   374      if g.Cmp(bigOne) != 0 {
   375          return
   376      }
   377      if x.Cmp(bigOne) < 0 {
   378          x.Add(x, n)
   379      }
   380      return x, true
   381  }