github.com/polarismesh/polaris@v1.17.8/plugin/crypto/aes/aes.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package aes
    19  
    20  import (
    21  	"bytes"
    22  	"crypto/aes"
    23  	"crypto/cipher"
    24  	"crypto/rand"
    25  	"encoding/base64"
    26  	"errors"
    27  
    28  	"github.com/polarismesh/polaris/plugin"
    29  )
    30  
    31  const (
    32  	// PluginName plugin name
    33  	PluginName = "AES"
    34  )
    35  
    36  func init() {
    37  	plugin.RegisterPlugin(PluginName, &AESCrypto{})
    38  }
    39  
    40  // AESCrypto AES crypto
    41  type AESCrypto struct {
    42  }
    43  
    44  // Name 返回插件名字
    45  func (h *AESCrypto) Name() string {
    46  	return PluginName
    47  }
    48  
    49  // Destroy 销毁插件
    50  func (h *AESCrypto) Destroy() error {
    51  	return nil
    52  }
    53  
    54  // Initialize 插件初始化
    55  func (h *AESCrypto) Initialize(c *plugin.ConfigEntry) error {
    56  	return nil
    57  }
    58  
    59  // GenerateKey generate key
    60  func (c *AESCrypto) GenerateKey() ([]byte, error) {
    61  	key := make([]byte, 16)
    62  	_, err := rand.Read(key)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	return key, nil
    67  }
    68  
    69  // Encrypt AES encrypt plaintext and base64 encode ciphertext
    70  func (c *AESCrypto) Encrypt(plaintext string, key []byte) (string, error) {
    71  	if plaintext == "" {
    72  		return "", nil
    73  	}
    74  	ciphertext, err := c.doEncrypt([]byte(plaintext), key)
    75  	if err != nil {
    76  		return "", err
    77  	}
    78  	return base64.StdEncoding.EncodeToString(ciphertext), nil
    79  }
    80  
    81  // Decrypt base64 decode ciphertext and AES decrypt
    82  func (c *AESCrypto) Decrypt(ciphertext string, key []byte) (string, error) {
    83  	if ciphertext == "" {
    84  		return "", nil
    85  	}
    86  	ciphertextBytes, err := base64.StdEncoding.DecodeString(ciphertext)
    87  	if err != nil {
    88  		return "", err
    89  	}
    90  	plaintext, err := c.doDecrypt(ciphertextBytes, key)
    91  	if err != nil {
    92  		return "", err
    93  	}
    94  	return string(plaintext), nil
    95  }
    96  
    97  // Encrypt AES encryption
    98  func (c *AESCrypto) doEncrypt(plaintext []byte, key []byte) ([]byte, error) {
    99  	block, err := aes.NewCipher(key)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	blockSize := block.BlockSize()
   104  	paddingData := pkcs7Padding(plaintext, blockSize)
   105  	ciphertext := make([]byte, len(paddingData))
   106  	blockMode := cipher.NewCBCEncrypter(block, key[:blockSize])
   107  	blockMode.CryptBlocks(ciphertext, paddingData)
   108  	return ciphertext, nil
   109  }
   110  
   111  // Decrypt AES decryption
   112  func (c *AESCrypto) doDecrypt(ciphertext []byte, key []byte) ([]byte, error) {
   113  	block, err := aes.NewCipher(key)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	blockSize := block.BlockSize()
   118  	blockMode := cipher.NewCBCDecrypter(block, key[:blockSize])
   119  	paddingPlaintext := make([]byte, len(ciphertext))
   120  	blockMode.CryptBlocks(paddingPlaintext, ciphertext)
   121  	plaintext, err := pkcs7UnPadding(paddingPlaintext)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	return plaintext, nil
   126  }
   127  
   128  func pkcs7Padding(data []byte, blockSize int) []byte {
   129  	padding := blockSize - len(data)%blockSize
   130  	padText := bytes.Repeat([]byte{byte(padding)}, padding)
   131  	return append(data, padText...)
   132  }
   133  
   134  func pkcs7UnPadding(data []byte) ([]byte, error) {
   135  	length := len(data)
   136  	if length == 0 {
   137  		return nil, errors.New("invalid encryption data")
   138  	}
   139  	unPadding := int(data[length-1])
   140  	if unPadding > length {
   141  		return nil, errors.New("invalid encryption data")
   142  	}
   143  	return data[:(length - unPadding)], nil
   144  }