github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zstring/aes.go (about)

     1  package zstring
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"strings"
    12  )
    13  
    14  func assKeyPadding(key string) []byte {
    15  	k := String2Bytes(key)
    16  	l := len(k)
    17  	switch l {
    18  	case 16, 24, 32:
    19  		return k
    20  	default:
    21  		if l < 16 {
    22  			return append(k, String2Bytes(strings.Repeat(" ", 16-l))...)
    23  		} else if l < 24 {
    24  			return append(k, String2Bytes(strings.Repeat(" ", 24-l))...)
    25  		} else if l < 32 {
    26  			return append(k, String2Bytes(strings.Repeat(" ", 32-l))...)
    27  		}
    28  		return k[0:32]
    29  	}
    30  }
    31  
    32  // PKCS7Padding PKCS7 fill mode
    33  func PKCS7Padding(ciphertext []byte, blockSize int) []byte {
    34  	padding := blockSize - len(ciphertext)%blockSize
    35  	pad := bytes.Repeat([]byte{byte(padding)}, padding)
    36  	return append(ciphertext, pad...)
    37  }
    38  
    39  // PKCS7UnPadding Reverse operation of padding to delete the padding string
    40  func PKCS7UnPadding(origData []byte) ([]byte, error) {
    41  	length := len(origData)
    42  	if length == 0 {
    43  		return nil, errors.New("encryption string error")
    44  	} else {
    45  		u := int(origData[length-1])
    46  		return origData[:(length - u)], nil
    47  	}
    48  }
    49  
    50  // AesEncrypt aes encryption
    51  func AesEncrypt(plainText []byte, key string, iv ...string) (ciphertext []byte,
    52  	err error) {
    53  	var k []byte
    54  	var block cipher.Block
    55  	if len(iv) > 0 {
    56  		k = String2Bytes(iv[0])
    57  		block, err = aes.NewCipher(String2Bytes(key))
    58  	} else {
    59  		k = assKeyPadding(key)
    60  		block, err = aes.NewCipher(k)
    61  	}
    62  	if err == nil {
    63  		blockSize := block.BlockSize()
    64  		plainText = PKCS7Padding(plainText, blockSize)
    65  		blocMode := cipher.NewCBCEncrypter(block, k[:blockSize])
    66  		ciphertext = make([]byte, len(plainText))
    67  		blocMode.CryptBlocks(ciphertext, plainText)
    68  	}
    69  	return
    70  }
    71  
    72  // AesDecrypt aes decryption
    73  func AesDecrypt(cipherText []byte, key string, iv ...string) (plainText []byte, err error) {
    74  	var (
    75  		block cipher.Block
    76  		k     []byte
    77  	)
    78  	if len(iv) > 0 {
    79  		k = String2Bytes(iv[0])
    80  		block, err = aes.NewCipher(String2Bytes(key))
    81  	} else {
    82  		k = assKeyPadding(key)
    83  		block, err = aes.NewCipher(k)
    84  	}
    85  
    86  	if err == nil {
    87  		blockSize := block.BlockSize()
    88  		blockMode := cipher.NewCBCDecrypter(block, k[:blockSize])
    89  		plainText = make([]byte, len(cipherText))
    90  		defer func() {
    91  			if e := recover(); e != nil {
    92  				var ok bool
    93  				err, ok = e.(error)
    94  				if !ok {
    95  					err = fmt.Errorf("%s", e)
    96  				}
    97  			}
    98  		}()
    99  		blockMode.CryptBlocks(plainText, cipherText)
   100  		if err == nil {
   101  			plainText, err = PKCS7UnPadding(plainText)
   102  		}
   103  	}
   104  	return
   105  }
   106  
   107  // AesEncryptString Aes Encrypt to String
   108  func AesEncryptString(plainText string, key string, iv ...string) (string, error) {
   109  	str := ""
   110  	c, err := AesEncrypt(String2Bytes(plainText), key, iv...)
   111  	if err == nil {
   112  		str = Bytes2String(Base64Encode(c))
   113  	}
   114  	return str, nil
   115  }
   116  
   117  // AesDecryptString Aes Decrypt to String
   118  func AesDecryptString(cipherText string, key string, iv ...string) (string,
   119  	error) {
   120  	base64Byte, _ := Base64Decode(String2Bytes(cipherText))
   121  	origData, err := AesDecrypt(base64Byte, key, iv...)
   122  	if err != nil {
   123  		return "", err
   124  	}
   125  	return Bytes2String(origData), nil
   126  }
   127  
   128  func AesGCMEncrypt(plaintext []byte, key string) (ciphertext []byte, err error) {
   129  	var (
   130  		block  cipher.Block
   131  		aesGCM cipher.AEAD
   132  	)
   133  
   134  	block, err = aes.NewCipher(String2Bytes(key))
   135  	if err != nil {
   136  		return
   137  	}
   138  
   139  	aesGCM, err = cipher.NewGCM(block)
   140  	if err != nil {
   141  		return
   142  	}
   143  
   144  	nonce := make([]byte, aesGCM.NonceSize())
   145  	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
   146  		return
   147  	}
   148  
   149  	ciphertext = aesGCM.Seal(nonce, nonce, plaintext, nil)
   150  	return
   151  }
   152  
   153  func AesGCMDecrypt(ciphertext []byte, key string) (plaintext []byte, err error) {
   154  	if len(ciphertext) == 0 {
   155  		return nil, errors.New("ciphertext is empty")
   156  	}
   157  
   158  	var (
   159  		block  cipher.Block
   160  		aesGCM cipher.AEAD
   161  	)
   162  	block, err = aes.NewCipher(String2Bytes(key))
   163  	if err != nil {
   164  		return
   165  	}
   166  
   167  	aesGCM, err = cipher.NewGCM(block)
   168  	if err != nil {
   169  		return
   170  	}
   171  
   172  	nonceSize := aesGCM.NonceSize()
   173  	if len(ciphertext) < nonceSize {
   174  		return nil, errors.New("ciphertext is too short")
   175  	}
   176  	nonce, text := ciphertext[:nonceSize], ciphertext[nonceSize:]
   177  
   178  	return aesGCM.Open(nil, nonce, text, nil)
   179  }
   180  
   181  func AesGCMEncryptString(plainText string, key string) (string, error) {
   182  	str := ""
   183  	c, err := AesGCMEncrypt(String2Bytes(plainText), key)
   184  	if err == nil {
   185  		str = Bytes2String(Base64Encode(c))
   186  	}
   187  	return str, err
   188  }
   189  
   190  func AesGCMDecryptString(cipherText string, key string) (string,
   191  	error) {
   192  	base64Byte, _ := Base64Decode(String2Bytes(cipherText))
   193  	origData, err := AesGCMDecrypt(base64Byte, key)
   194  	if err != nil {
   195  		return "", err
   196  	}
   197  	return Bytes2String(origData), nil
   198  }