github.com/lingyao2333/mo-zero@v1.4.1/core/codec/rsa.go (about)

     1  package codec
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"crypto/x509"
     7  	"encoding/base64"
     8  	"encoding/pem"
     9  	"errors"
    10  	"os"
    11  )
    12  
    13  var (
    14  	// ErrPrivateKey indicates the invalid private key.
    15  	ErrPrivateKey = errors.New("private key error")
    16  	// ErrPublicKey indicates the invalid public key.
    17  	ErrPublicKey = errors.New("failed to parse PEM block containing the public key")
    18  	// ErrNotRsaKey indicates the invalid RSA key.
    19  	ErrNotRsaKey = errors.New("key type is not RSA")
    20  )
    21  
    22  type (
    23  	// RsaDecrypter represents a RSA decrypter.
    24  	RsaDecrypter interface {
    25  		Decrypt(input []byte) ([]byte, error)
    26  		DecryptBase64(input string) ([]byte, error)
    27  	}
    28  
    29  	// RsaEncrypter represents a RSA encrypter.
    30  	RsaEncrypter interface {
    31  		Encrypt(input []byte) ([]byte, error)
    32  	}
    33  
    34  	rsaBase struct {
    35  		bytesLimit int
    36  	}
    37  
    38  	rsaDecrypter struct {
    39  		rsaBase
    40  		privateKey *rsa.PrivateKey
    41  	}
    42  
    43  	rsaEncrypter struct {
    44  		rsaBase
    45  		publicKey *rsa.PublicKey
    46  	}
    47  )
    48  
    49  // NewRsaDecrypter returns a RsaDecrypter with the given file.
    50  func NewRsaDecrypter(file string) (RsaDecrypter, error) {
    51  	content, err := os.ReadFile(file)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	block, _ := pem.Decode(content)
    57  	if block == nil {
    58  		return nil, ErrPrivateKey
    59  	}
    60  
    61  	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	return &rsaDecrypter{
    67  		rsaBase: rsaBase{
    68  			bytesLimit: privateKey.N.BitLen() >> 3,
    69  		},
    70  		privateKey: privateKey,
    71  	}, nil
    72  }
    73  
    74  func (r *rsaDecrypter) Decrypt(input []byte) ([]byte, error) {
    75  	return r.crypt(input, func(block []byte) ([]byte, error) {
    76  		return rsaDecryptBlock(r.privateKey, block)
    77  	})
    78  }
    79  
    80  func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
    81  	if len(input) == 0 {
    82  		return nil, nil
    83  	}
    84  
    85  	base64Decoded, err := base64.StdEncoding.DecodeString(input)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	return r.Decrypt(base64Decoded)
    91  }
    92  
    93  // NewRsaEncrypter returns a RsaEncrypter with the given key.
    94  func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
    95  	block, _ := pem.Decode(key)
    96  	if block == nil {
    97  		return nil, ErrPublicKey
    98  	}
    99  
   100  	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	switch pubKey := pub.(type) {
   106  	case *rsa.PublicKey:
   107  		return &rsaEncrypter{
   108  			rsaBase: rsaBase{
   109  				// https://www.ietf.org/rfc/rfc2313.txt
   110  				// The length of the data D shall not be more than k-11 octets, which is
   111  				// positive since the length k of the modulus is at least 12 octets.
   112  				bytesLimit: (pubKey.N.BitLen() >> 3) - 11,
   113  			},
   114  			publicKey: pubKey,
   115  		}, nil
   116  	default:
   117  		return nil, ErrNotRsaKey
   118  	}
   119  }
   120  
   121  func (r *rsaEncrypter) Encrypt(input []byte) ([]byte, error) {
   122  	return r.crypt(input, func(block []byte) ([]byte, error) {
   123  		return rsaEncryptBlock(r.publicKey, block)
   124  	})
   125  }
   126  
   127  func (r *rsaBase) crypt(input []byte, cryptFn func([]byte) ([]byte, error)) ([]byte, error) {
   128  	var result []byte
   129  	inputLen := len(input)
   130  
   131  	for i := 0; i*r.bytesLimit < inputLen; i++ {
   132  		start := r.bytesLimit * i
   133  		var stop int
   134  		if r.bytesLimit*(i+1) > inputLen {
   135  			stop = inputLen
   136  		} else {
   137  			stop = r.bytesLimit * (i + 1)
   138  		}
   139  		bs, err := cryptFn(input[start:stop])
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  
   144  		result = append(result, bs...)
   145  	}
   146  
   147  	return result, nil
   148  }
   149  
   150  func rsaDecryptBlock(privateKey *rsa.PrivateKey, block []byte) ([]byte, error) {
   151  	return rsa.DecryptPKCS1v15(rand.Reader, privateKey, block)
   152  }
   153  
   154  func rsaEncryptBlock(publicKey *rsa.PublicKey, msg []byte) ([]byte, error) {
   155  	return rsa.EncryptPKCS1v15(rand.Reader, publicKey, msg)
   156  }