github.com/songzhibin97/gkit@v1.2.13/encrypt/aes/aes.go (about)

     1  package aes
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"encoding/base64"
     8  	"unsafe"
     9  )
    10  
    11  const defaultKey = "gkit"
    12  
    13  func PadKey(s string) string {
    14  	if s == "" {
    15  		s = defaultKey
    16  	}
    17  	ps := []byte(s)
    18  	ls := len(ps)
    19  
    20  	if ls > 32 {
    21  		return string(ps[:32])
    22  	}
    23  	idx := 0
    24  	for i := ls; !(i == 16 || i == 24 || i == 32); i++ {
    25  		ps = append(ps, s[idx])
    26  		idx = (idx + 1) % ls
    27  	}
    28  
    29  	return string(ps)
    30  }
    31  
    32  func Encrypt(orig string, key string) string {
    33  	defer func() {
    34  		if err := recover(); err != nil {
    35  			//fmt.Println("Encrypt Err", orig, err)
    36  			return
    37  		}
    38  	}()
    39  	// 转成字节数组
    40  	origData := []byte(orig)
    41  	k := []byte(key)
    42  	// 分组秘钥
    43  	// NewCipher该函数限制了输入k的长度必须为16, 24或者32
    44  	block, _ := aes.NewCipher(k)
    45  	// 获取秘钥块的长度
    46  	blockSize := block.BlockSize()
    47  	// 补全码
    48  	origData = PKCS7Padding(origData, blockSize)
    49  	// 加密模式
    50  	blockMode := cipher.NewCBCEncrypter(block, k[:blockSize])
    51  	// 创建数组
    52  	cryted := make([]byte, len(origData))
    53  	// 加密
    54  	blockMode.CryptBlocks(cryted, origData)
    55  	return base64.StdEncoding.EncodeToString(cryted)
    56  }
    57  
    58  func Decrypt(cryted string, key string) string {
    59  	defer func() {
    60  		if err := recover(); err != nil {
    61  			//fmt.Println("Decrypt Err", cryted, err)
    62  			return
    63  		}
    64  	}()
    65  	// 转成字节数组
    66  	crytedByte, _ := base64.StdEncoding.DecodeString(cryted)
    67  	k := []byte(key)
    68  	// 分组秘钥
    69  	block, _ := aes.NewCipher(k)
    70  	// 获取秘钥块的长度
    71  	blockSize := block.BlockSize()
    72  	// 加密模式
    73  	blockMode := cipher.NewCBCDecrypter(block, k[:blockSize])
    74  	// 创建数组
    75  	orig := make([]byte, len(crytedByte))
    76  	// 解密
    77  	if len(crytedByte)%block.BlockSize() != 0 {
    78  		return ""
    79  	}
    80  	if len(orig) < len(crytedByte) {
    81  		return ""
    82  	}
    83  	if InexactOverlap(orig[:len(crytedByte)], crytedByte) {
    84  		return ""
    85  	}
    86  
    87  	blockMode.CryptBlocks(orig, crytedByte)
    88  	// 去补全码
    89  	orig = PKCS7UnPadding(orig)
    90  	return string(orig)
    91  }
    92  
    93  // PKCS7Padding 补码
    94  // AES加密数据块分组长度必须为128bit(byte[16]),密钥长度可以是128bit(byte[16])、192bit(byte[24])、256bit(byte[32])中的任意一个。
    95  func PKCS7Padding(ciphertext []byte, blocksize int) []byte {
    96  	padding := blocksize - len(ciphertext)%blocksize
    97  	padText := bytes.Repeat([]byte{byte(padding)}, padding)
    98  	return append(ciphertext, padText...)
    99  }
   100  
   101  // PKCS7UnPadding 去码
   102  func PKCS7UnPadding(origData []byte) []byte {
   103  	length := len(origData)
   104  	unPadding := int(origData[length-1])
   105  	return origData[:(length - unPadding)]
   106  }
   107  
   108  func InexactOverlap(x, y []byte) bool {
   109  	if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] {
   110  		return false
   111  	}
   112  	return AnyOverlap(x, y)
   113  }
   114  
   115  func AnyOverlap(x, y []byte) bool {
   116  	return len(x) > 0 && len(y) > 0 &&
   117  		uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) &&
   118  		uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1]))
   119  }