github.com/aaabigfish/gopkg@v1.1.0/crypto/rsa.go (about) 1 package crypto 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "crypto/rsa" 7 "crypto/x509" 8 "encoding/pem" 9 "errors" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "math/big" 14 "os" 15 "runtime/debug" 16 ) 17 18 var ( 19 ErrDataToLarge = errors.New("message too long for RSA public key size") 20 ErrDataLen = errors.New("data length error") 21 ErrDataBroken = errors.New("data broken, first byte is not zero") 22 ErrKeyPairDismatch = errors.New("data is not encrypted by the private key") 23 ErrDecryption = errors.New("decryption error") 24 ErrPublicKey = errors.New("public key error") 25 ErrPrivateKey = errors.New("private key error") 26 ) 27 // rsa 加解密 28 type rsaCrypto struct { 29 // pem 格式公钥 30 publicKey []byte 31 32 // pem 格式私钥 33 privateKey []byte 34 35 // rsa 公钥 36 rsaPriKey *rsa.PrivateKey 37 38 // rsa 私钥 39 rsaPubKey *rsa.PublicKey 40 } 41 42 // 创建 rsa 实例 43 // 公钥 pubKey 和 私钥 priKey 必须传一个,没值的传 nil 44 // 加解密时公私钥必须是一对 45 func NewRsa(pubKey, priKey []byte) *rsaCrypto { 46 defer func() { 47 if err := recover(); err != nil { 48 fmt.Println(err) 49 debug.PrintStack() 50 os.Exit(-2) 51 } 52 }() 53 54 if len(pubKey) == 0 && len(priKey) == 0 { 55 panic("public key or private key is needed") 56 } 57 rc := &rsaCrypto{ 58 publicKey: pubKey, 59 privateKey: priKey, 60 } 61 62 var err error 63 if len(pubKey) > 0 { 64 rc.rsaPubKey, err = getRsaPublicKey(pubKey) 65 if err != nil { 66 panic(err.Error()) 67 } 68 } 69 if len(priKey) > 0 { 70 rc.rsaPriKey, err = getRsaPrivateKey(priKey) 71 if err != nil { 72 panic(err.Error()) 73 } 74 } 75 return rc 76 } 77 78 // rsa 公钥加密 79 func (r *rsaCrypto) EncryptWithPublicKey(data []byte) ([]byte, error) { 80 // 解密 pem 格式公钥 81 block, _ := pem.Decode(r.publicKey) 82 if block == nil { 83 return nil, ErrPublicKey 84 } 85 86 // 解析公钥 87 pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) 88 if err != nil { 89 return nil, err 90 } 91 92 pub := pubInterface.(*rsa.PublicKey) 93 return rsa.EncryptPKCS1v15(rand.Reader, pub, data) 94 } 95 96 // rsa 私钥解密 97 func (r *rsaCrypto) DecryptWithPrivateKey(ciphertext []byte) ([]byte, error) { 98 block, _ := pem.Decode(r.privateKey) 99 if block == nil { 100 return nil, ErrPrivateKey 101 } 102 103 pri, err := x509.ParsePKCS1PrivateKey(block.Bytes) 104 if err != nil { 105 p, err := x509.ParsePKCS8PrivateKey(block.Bytes) 106 if err != nil { 107 return nil, err 108 } 109 pri = p.(*rsa.PrivateKey) 110 } 111 return rsa.DecryptPKCS1v15(rand.Reader, pri, ciphertext) 112 } 113 114 // rsa 私钥加密 115 func (r *rsaCrypto) EncryptWithPrivateKey(data []byte) ([]byte, error) { 116 out := bytes.NewBuffer(nil) 117 err := r.privKeyIO(bytes.NewReader(data), out) 118 if err != nil { 119 return nil, err 120 } 121 return ioutil.ReadAll(out) 122 } 123 124 // rsa 公钥解密 125 func (r *rsaCrypto) DecryptWithPublicKey(ciphertext []byte) ([]byte, error) { 126 out := bytes.NewBuffer(nil) 127 err := r.pubKeyIO(bytes.NewReader(ciphertext), out) 128 if err != nil { 129 return nil, err 130 } 131 return ioutil.ReadAll(out) 132 } 133 134 // 公钥解密 reader 135 func (r *rsaCrypto) pubKeyIO(in io.Reader, w io.Writer) (err error) { 136 k := (r.rsaPubKey.N.BitLen() + 7) / 8 137 buf := make([]byte, k) 138 var b []byte 139 size := 0 140 for { 141 size, err = in.Read(buf) 142 if err != nil { 143 if err == io.EOF { 144 return nil 145 } 146 return err 147 } 148 if size < k { 149 b = buf[:size] 150 } else { 151 b = buf 152 } 153 b, err = r.pubKeyDecrypt(b) 154 if err != nil { 155 return err 156 } 157 if _, err = w.Write(b); err != nil { 158 return err 159 } 160 } 161 return nil 162 } 163 164 // 私钥加密 reader 165 func (r *rsaCrypto) privKeyIO(re io.Reader, w io.Writer) (err error) { 166 k := (r.rsaPriKey.N.BitLen()+7)/8 - 11 167 buf := make([]byte, k) 168 var b []byte 169 size := 0 170 for { 171 size, err = re.Read(buf) 172 if err != nil { 173 if err == io.EOF { 174 return nil 175 } 176 return err 177 } 178 if size < k { 179 b = buf[:size] 180 } else { 181 b = buf 182 } 183 b, err = r.priKeyEncrypt(rand.Reader, b) 184 if err != nil { 185 return err 186 } 187 if _, err = w.Write(b); err != nil { 188 return err 189 } 190 } 191 return nil 192 } 193 194 // 私钥加密 195 func (r *rsaCrypto) priKeyEncrypt(rand io.Reader, hashed []byte) ([]byte, error) { 196 hl := len(hashed) 197 k := (r.rsaPriKey.N.BitLen() + 7) / 8 198 if k < hl+11 { 199 return nil, ErrDataLen 200 } 201 em := make([]byte, k) 202 em[1] = 1 203 for i := 2; i < k-hl-1; i++ { 204 em[i] = 0xff 205 } 206 copy(em[k-hl:k], hashed) 207 m := new(big.Int).SetBytes(em) 208 c, err := decrypt(rand, r.rsaPriKey, m) 209 if err != nil { 210 return nil, err 211 } 212 copyWithLeftPad(em, c.Bytes()) 213 return em, nil 214 } 215 216 // 公钥解密 217 func (r *rsaCrypto) pubKeyDecrypt(data []byte) ([]byte, error) { 218 k := (r.rsaPubKey.N.BitLen() + 7) / 8 219 if k != len(data) { 220 return nil, ErrDataLen 221 } 222 m := new(big.Int).SetBytes(data) 223 if m.Cmp(r.rsaPubKey.N) > 0 { 224 return nil, ErrDataToLarge 225 } 226 m.Exp(m, big.NewInt(int64(r.rsaPubKey.E)), r.rsaPubKey.N) 227 d := leftPad(m.Bytes(), k) 228 if d[0] != 0 { 229 return nil, ErrDataBroken 230 } 231 if d[1] != 0 && d[1] != 1 { 232 return nil, ErrKeyPairDismatch 233 } 234 var i = 2 235 for ; i < len(d); i++ { 236 if d[i] == 0 { 237 break 238 } 239 } 240 i++ 241 if i == len(d) { 242 return nil, nil 243 } 244 return d[i:], nil 245 } 246 247 // 获取 rsa 私钥 248 func getRsaPrivateKey(privateKey []byte) (*rsa.PrivateKey, error) { 249 block, _ := pem.Decode(privateKey) 250 if block == nil { 251 return nil, ErrPrivateKey 252 } 253 pri, err := x509.ParsePKCS1PrivateKey(block.Bytes) 254 if err == nil { 255 return pri, nil 256 } 257 p, err := x509.ParsePKCS8PrivateKey(block.Bytes) 258 if err != nil { 259 return nil, err 260 } 261 return p.(*rsa.PrivateKey), nil 262 } 263 264 // 设置 rsa 公钥 265 func getRsaPublicKey(publicKey []byte) (*rsa.PublicKey, error) { 266 block, _ := pem.Decode(publicKey) 267 if block == nil { 268 return nil, ErrPublicKey 269 } 270 // x509 parse public key 271 pub, err := x509.ParsePKIXPublicKey(block.Bytes) 272 if err != nil { 273 return nil, err 274 } 275 return pub.(*rsa.PublicKey), nil 276 } 277 278 // 从 crypto/rsa 复制 279 var bigZero = big.NewInt(0) 280 var bigOne = big.NewInt(1) 281 282 // 从 crypto/rsa 复制 283 func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) { 284 if c.Cmp(priv.N) > 0 { 285 err = ErrDecryption 286 return 287 } 288 var ir *big.Int 289 if random != nil { 290 var r *big.Int 291 292 for { 293 r, err = rand.Int(random, priv.N) 294 if err != nil { 295 return 296 } 297 if r.Cmp(bigZero) == 0 { 298 r = bigOne 299 } 300 var ok bool 301 ir, ok = modInverse(r, priv.N) 302 if ok { 303 break 304 } 305 } 306 bigE := big.NewInt(int64(priv.E)) 307 rpowe := new(big.Int).Exp(r, bigE, priv.N) 308 cCopy := new(big.Int).Set(c) 309 cCopy.Mul(cCopy, rpowe) 310 cCopy.Mod(cCopy, priv.N) 311 c = cCopy 312 } 313 if priv.Precomputed.Dp == nil { 314 m = new(big.Int).Exp(c, priv.D, priv.N) 315 } else { 316 m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0]) 317 m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1]) 318 m.Sub(m, m2) 319 if m.Sign() < 0 { 320 m.Add(m, priv.Primes[0]) 321 } 322 m.Mul(m, priv.Precomputed.Qinv) 323 m.Mod(m, priv.Primes[0]) 324 m.Mul(m, priv.Primes[1]) 325 m.Add(m, m2) 326 327 for i, values := range priv.Precomputed.CRTValues { 328 prime := priv.Primes[2+i] 329 m2.Exp(c, values.Exp, prime) 330 m2.Sub(m2, m) 331 m2.Mul(m2, values.Coeff) 332 m2.Mod(m2, prime) 333 if m2.Sign() < 0 { 334 m2.Add(m2, prime) 335 } 336 m2.Mul(m2, values.R) 337 m.Add(m, m2) 338 } 339 } 340 if ir != nil { 341 m.Mul(m, ir) 342 m.Mod(m, priv.N) 343 } 344 345 return 346 } 347 348 // 从 crypto/rsa 复制 349 func copyWithLeftPad(dest, src []byte) { 350 numPaddingBytes := len(dest) - len(src) 351 for i := 0; i < numPaddingBytes; i++ { 352 dest[i] = 0 353 } 354 copy(dest[numPaddingBytes:], src) 355 } 356 357 // 从 crypto/rsa 复制 358 func leftPad(input []byte, size int) (out []byte) { 359 n := len(input) 360 if n > size { 361 n = size 362 } 363 out = make([]byte, size) 364 copy(out[len(out)-n:], input) 365 return 366 } 367 368 // 从 crypto/rsa 复制 369 func modInverse(a, n *big.Int) (ia *big.Int, ok bool) { 370 g := new(big.Int) 371 x := new(big.Int) 372 y := new(big.Int) 373 g.GCD(x, y, a, n) 374 if g.Cmp(bigOne) != 0 { 375 return 376 } 377 if x.Cmp(bigOne) < 0 { 378 x.Add(x, n) 379 } 380 return x, true 381 }