github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/internal/util/aes_crypto.go (about)

     1  package util
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"fmt"
     7  )
     8  
     9  // 把整数 n 格式化成 4 字节的网络字节序
    10  func encodeNetworkByteOrder(b []byte, n uint32) {
    11  	b[0] = byte(n >> 24)
    12  	b[1] = byte(n >> 16)
    13  	b[2] = byte(n >> 8)
    14  	b[3] = byte(n)
    15  }
    16  
    17  // 从 4 字节的网络字节序里解析出整数
    18  func decodeNetworkByteOrder(b []byte) (n uint32) {
    19  	return uint32(b[0])<<24 |
    20  		uint32(b[1])<<16 |
    21  		uint32(b[2])<<8 |
    22  		uint32(b[3])
    23  }
    24  
    25  // ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
    26  func AESEncryptMsg(random, rawXMLMsg []byte, appId string, aesKey []byte) (ciphertext []byte) {
    27  	const (
    28  		BLOCK_SIZE = 32             // PKCS#7
    29  		BLOCK_MASK = BLOCK_SIZE - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
    30  	)
    31  
    32  	appIdOffset := 20 + len(rawXMLMsg)
    33  	contentLen := appIdOffset + len(appId)
    34  	amountToPad := BLOCK_SIZE - contentLen&BLOCK_MASK
    35  	plaintextLen := contentLen + amountToPad
    36  
    37  	plaintext := make([]byte, plaintextLen)
    38  
    39  	// 拼接
    40  	copy(plaintext[:16], random)
    41  	encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg)))
    42  	copy(plaintext[20:], rawXMLMsg)
    43  	copy(plaintext[appIdOffset:], appId)
    44  
    45  	// PKCS#7 补位
    46  	for i := contentLen; i < plaintextLen; i++ {
    47  		plaintext[i] = byte(amountToPad)
    48  	}
    49  
    50  	// 加密
    51  	block, err := aes.NewCipher(aesKey)
    52  	if err != nil {
    53  		panic(err)
    54  	}
    55  	mode := cipher.NewCBCEncrypter(block, aesKey[:16])
    56  	mode.CryptBlocks(plaintext, plaintext)
    57  
    58  	ciphertext = plaintext
    59  	return
    60  }
    61  
    62  // ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
    63  func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appId []byte, err error) {
    64  	const (
    65  		BLOCK_SIZE = 32             // PKCS#7
    66  		BLOCK_MASK = BLOCK_SIZE - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
    67  	)
    68  
    69  	if len(ciphertext) < BLOCK_SIZE {
    70  		err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
    71  		return
    72  	}
    73  	if len(ciphertext)&BLOCK_MASK != 0 {
    74  		err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
    75  		return
    76  	}
    77  
    78  	plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE
    79  
    80  	// 解密
    81  	block, err := aes.NewCipher(aesKey)
    82  	if err != nil {
    83  		panic(err)
    84  	}
    85  	mode := cipher.NewCBCDecrypter(block, aesKey[:16])
    86  	mode.CryptBlocks(plaintext, ciphertext)
    87  
    88  	// PKCS#7 去除补位
    89  	amountToPad := int(plaintext[len(plaintext)-1])
    90  	if amountToPad < 1 || amountToPad > BLOCK_SIZE {
    91  		err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
    92  		return
    93  	}
    94  	plaintext = plaintext[:len(plaintext)-amountToPad]
    95  
    96  	// 反拼接
    97  	// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
    98  	if len(plaintext) <= 20 {
    99  		err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
   100  		return
   101  	}
   102  	rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20]))
   103  	if rawXMLMsgLen < 0 {
   104  		err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen)
   105  		return
   106  	}
   107  	appIdOffset := 20 + rawXMLMsgLen
   108  	if len(plaintext) <= appIdOffset {
   109  		err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen)
   110  		return
   111  	}
   112  
   113  	random = plaintext[:16:20]
   114  	rawXMLMsg = plaintext[20:appIdOffset:appIdOffset]
   115  	appId = plaintext[appIdOffset:]
   116  	return
   117  }
   118  
   119  func AESDecryptData(cipherText []byte, aesKey []byte, iv []byte) (rawData []byte, err error) {
   120  
   121  	const (
   122  		BLOCK_SIZE = 32             // PKCS#7
   123  		BLOCK_MASK = BLOCK_SIZE - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
   124  	)
   125  
   126  	if len(cipherText) < BLOCK_SIZE {
   127  		err = fmt.Errorf("the length of ciphertext too short: %d", len(cipherText))
   128  		return
   129  	}
   130  
   131  	plaintext := make([]byte, len(cipherText)) // len(plaintext) >= BLOCK_SIZE
   132  
   133  	// 解密
   134  	block, err := aes.NewCipher(aesKey)
   135  	if err != nil {
   136  		panic(err)
   137  	}
   138  	mode := cipher.NewCBCDecrypter(block, iv)
   139  	mode.CryptBlocks(plaintext, cipherText)
   140  
   141  	// PKCS#7 去除补位
   142  	amountToPad := int(plaintext[len(plaintext)-1])
   143  	if amountToPad < 1 || amountToPad > BLOCK_SIZE {
   144  		err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
   145  		return
   146  	}
   147  	plaintext = plaintext[:len(plaintext)-amountToPad]
   148  
   149  	// 反拼接
   150  	// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
   151  	if len(plaintext) <= 20 {
   152  		err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
   153  		return
   154  	}
   155  
   156  	rawData = plaintext
   157  
   158  	return
   159  
   160  }