github.com/fastwego/offiaccount@v1.0.1/util/aes_crypto.go (about)

     1  // Copyright 2014 chanxuehong(chanxuehong@gmail.com)
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"crypto/aes"
    19  	"crypto/cipher"
    20  	"encoding/base64"
    21  	"fmt"
    22  )
    23  
    24  // 把整数 n 格式化成 4 字节的网络字节序
    25  func encodeNetworkByteOrder(b []byte, n uint32) {
    26  	b[0] = byte(n >> 24)
    27  	b[1] = byte(n >> 16)
    28  	b[2] = byte(n >> 8)
    29  	b[3] = byte(n)
    30  }
    31  
    32  // 从 4 字节的网络字节序里解析出整数
    33  func decodeNetworkByteOrder(b []byte) (n uint32) {
    34  	return uint32(b[0])<<24 |
    35  		uint32(b[1])<<16 |
    36  		uint32(b[2])<<8 |
    37  		uint32(b[3])
    38  }
    39  
    40  // AESEncryptMsg 消息加密
    41  // ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
    42  func AESEncryptMsg(random, rawXMLMsg []byte, appId string, encodingAESKey string) (ciphertext string) {
    43  	aesKey, _ := base64.StdEncoding.DecodeString(encodingAESKey + "=")
    44  	const (
    45  		BLOCK_SIZE = 32             // PKCS#7
    46  		BLOCK_MASK = BLOCK_SIZE - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
    47  	)
    48  
    49  	appIdOffset := 20 + len(rawXMLMsg)
    50  	contentLen := appIdOffset + len(appId)
    51  	amountToPad := BLOCK_SIZE - contentLen&BLOCK_MASK
    52  	plaintextLen := contentLen + amountToPad
    53  
    54  	plaintext := make([]byte, plaintextLen)
    55  
    56  	// 拼接
    57  	copy(plaintext[:16], random)
    58  	encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg)))
    59  	copy(plaintext[20:], rawXMLMsg)
    60  	copy(plaintext[appIdOffset:], appId)
    61  
    62  	// PKCS#7 补位
    63  	for i := contentLen; i < plaintextLen; i++ {
    64  		plaintext[i] = byte(amountToPad)
    65  	}
    66  
    67  	// 加密
    68  	block, err := aes.NewCipher(aesKey)
    69  	if err != nil {
    70  		panic(err)
    71  	}
    72  	mode := cipher.NewCBCEncrypter(block, aesKey[:16])
    73  	mode.CryptBlocks(plaintext, plaintext)
    74  
    75  	return base64.StdEncoding.EncodeToString(plaintext)
    76  }
    77  
    78  // AESDecryptMsg 消息解密
    79  // ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
    80  func AESDecryptMsg(base64CipherText string, encodingAESKey string) (random, rawXMLMsg, appId []byte, err error) {
    81  	ciphertext, err := base64.StdEncoding.DecodeString(base64CipherText)
    82  	if err != nil {
    83  		return
    84  	}
    85  
    86  	aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
    87  	if err != nil {
    88  		return
    89  	}
    90  
    91  	const (
    92  		BLOCK_SIZE = 32             // PKCS#7
    93  		BLOCK_MASK = BLOCK_SIZE - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
    94  	)
    95  
    96  	if len(ciphertext) < BLOCK_SIZE {
    97  		err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
    98  		return
    99  	}
   100  	if len(ciphertext)&BLOCK_MASK != 0 {
   101  		err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
   102  		return
   103  	}
   104  
   105  	plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE
   106  
   107  	// 解密
   108  	block, err := aes.NewCipher(aesKey)
   109  	if err != nil {
   110  		return
   111  	}
   112  	mode := cipher.NewCBCDecrypter(block, aesKey[:16])
   113  	mode.CryptBlocks(plaintext, ciphertext)
   114  
   115  	// PKCS#7 去除补位
   116  	amountToPad := int(plaintext[len(plaintext)-1])
   117  	if amountToPad < 1 || amountToPad > BLOCK_SIZE {
   118  		err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
   119  		return
   120  	}
   121  	plaintext = plaintext[:len(plaintext)-amountToPad]
   122  
   123  	// 反拼接
   124  	// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
   125  	if len(plaintext) <= 20 {
   126  		err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
   127  		return
   128  	}
   129  	rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20]))
   130  	if rawXMLMsgLen < 0 {
   131  		err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen)
   132  		return
   133  	}
   134  	appIdOffset := 20 + rawXMLMsgLen
   135  	if len(plaintext) <= appIdOffset {
   136  		err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen)
   137  		return
   138  	}
   139  
   140  	random = plaintext[:16:20]
   141  	rawXMLMsg = plaintext[20:appIdOffset:appIdOffset]
   142  	appId = plaintext[appIdOffset:]
   143  	return
   144  }
   145  
   146  // AESDecryptData 数据解密
   147  func AESDecryptData(cipherText []byte, aesKey []byte, iv []byte) (rawData []byte, err error) {
   148  
   149  	const (
   150  		BLOCK_SIZE = 32             // PKCS#7
   151  		BLOCK_MASK = BLOCK_SIZE - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
   152  	)
   153  
   154  	if len(cipherText) < BLOCK_SIZE {
   155  		err = fmt.Errorf("the length of ciphertext too short: %d", len(cipherText))
   156  		return
   157  	}
   158  
   159  	plaintext := make([]byte, len(cipherText)) // len(plaintext) >= BLOCK_SIZE
   160  
   161  	// 解密
   162  	block, err := aes.NewCipher(aesKey)
   163  	if err != nil {
   164  		panic(err)
   165  	}
   166  	mode := cipher.NewCBCDecrypter(block, iv)
   167  	mode.CryptBlocks(plaintext, cipherText)
   168  
   169  	// PKCS#7 去除补位
   170  	amountToPad := int(plaintext[len(plaintext)-1])
   171  	if amountToPad < 1 || amountToPad > BLOCK_SIZE {
   172  		err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
   173  		return
   174  	}
   175  	plaintext = plaintext[:len(plaintext)-amountToPad]
   176  
   177  	// 反拼接
   178  	// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
   179  	if len(plaintext) <= 20 {
   180  		err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
   181  		return
   182  	}
   183  
   184  	rawData = plaintext
   185  
   186  	return
   187  
   188  }