gitee.com/curryzheng/dm@v0.0.1/security/zzf.go (about)

     1  /*
     2   * Copyright (c) 2000-2018, 达梦数据库有限公司.
     3   * All rights reserved.
     4   */
     5  
     6  package security
     7  
     8  import (
     9  	"bytes"
    10  	"crypto/aes"
    11  	"crypto/cipher"
    12  	"crypto/des"
    13  	"crypto/md5"
    14  	"crypto/rc4"
    15  	"errors"
    16  	"reflect"
    17  )
    18  
    19  type SymmCipher struct {
    20  	encryptCipher interface{} //cipher.BlockMode | cipher.Stream
    21  	decryptCipher interface{} //cipher.BlockMode | cipher.Stream
    22  	key           []byte
    23  	block         cipher.Block // 分组加密算法
    24  	algorithmType int
    25  	workMode      int
    26  	needPadding   bool
    27  }
    28  
    29  func NewSymmCipher(algorithmID int, key []byte) (SymmCipher, error) {
    30  	var sc SymmCipher
    31  	var err error
    32  	sc.key = key
    33  	sc.algorithmType = algorithmID & ALGO_MASK
    34  	sc.workMode = algorithmID & WORK_MODE_MASK
    35  	switch sc.algorithmType {
    36  	case AES128:
    37  		if sc.block, err = aes.NewCipher(key[:16]); err != nil {
    38  			return sc, err
    39  		}
    40  	case AES192:
    41  		if sc.block, err = aes.NewCipher(key[:24]); err != nil {
    42  			return sc, err
    43  		}
    44  	case AES256:
    45  		if sc.block, err = aes.NewCipher(key[:32]); err != nil {
    46  			return sc, err
    47  		}
    48  	case DES:
    49  		if sc.block, err = des.NewCipher(key[:8]); err != nil {
    50  			return sc, err
    51  		}
    52  	case DES3:
    53  		var tripleDESKey []byte
    54  		tripleDESKey = append(tripleDESKey, key[:16]...)
    55  		tripleDESKey = append(tripleDESKey, key[:8]...)
    56  		if sc.block, err = des.NewTripleDESCipher(tripleDESKey); err != nil {
    57  			return sc, err
    58  		}
    59  	case RC4:
    60  		if sc.encryptCipher, err = rc4.NewCipher(key[:16]); err != nil {
    61  			return sc, err
    62  		}
    63  		if sc.decryptCipher, err = rc4.NewCipher(key[:16]); err != nil {
    64  			return sc, err
    65  		}
    66  		return sc, nil
    67  	default:
    68  		return sc, errors.New("invalidCipher")
    69  	}
    70  	blockSize := sc.block.BlockSize()
    71  	if sc.encryptCipher, err = sc.getEncrypter(sc.workMode, sc.block, defaultIV[:blockSize]); err != nil {
    72  		return sc, err
    73  	}
    74  	if sc.decryptCipher, err = sc.getDecrypter(sc.workMode, sc.block, defaultIV[:blockSize]); err != nil {
    75  		return sc, err
    76  	}
    77  	return sc, nil
    78  }
    79  
    80  func (sc SymmCipher) Encrypt(plaintext []byte, genDigest bool) []byte {
    81  	// 执行过加密后,IV值变了,需要重新初始化encryptCipher对象(因为没有类似resetIV的方法)
    82  	if sc.algorithmType != RC4 {
    83  		sc.encryptCipher, _ = sc.getEncrypter(sc.workMode, sc.block, defaultIV[:sc.block.BlockSize()])
    84  	} else {
    85  		sc.encryptCipher, _ = rc4.NewCipher(sc.key[:16])
    86  	}
    87  	// 填充
    88  	var paddingtext = make([]byte, len(plaintext))
    89  	copy(paddingtext, plaintext)
    90  	if sc.needPadding {
    91  		paddingtext = pkcs5Padding(paddingtext)
    92  	}
    93  
    94  	ret := make([]byte, len(paddingtext))
    95  
    96  	if v, ok := sc.encryptCipher.(cipher.Stream); ok {
    97  		v.XORKeyStream(ret, paddingtext)
    98  	} else if v, ok := sc.encryptCipher.(cipher.BlockMode); ok {
    99  		v.CryptBlocks(ret, paddingtext)
   100  	}
   101  
   102  	// md5摘要
   103  	if genDigest {
   104  		digest := md5.Sum(plaintext)
   105  		encrypt := ret
   106  		ret = make([]byte, len(encrypt)+len(digest))
   107  		copy(ret[:len(encrypt)], encrypt)
   108  		copy(ret[len(encrypt):], digest[:])
   109  	}
   110  	return ret
   111  }
   112  
   113  func (sc SymmCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) {
   114  	// 执行过解密后,IV值变了,需要重新初始化decryptCipher对象(因为没有类似resetIV的方法)
   115  	if sc.algorithmType != RC4 {
   116  		sc.decryptCipher, _ = sc.getDecrypter(sc.workMode, sc.block, defaultIV[:sc.block.BlockSize()])
   117  	} else {
   118  		sc.decryptCipher, _ = rc4.NewCipher(sc.key[:16])
   119  	}
   120  	var ret []byte
   121  	if checkDigest {
   122  		var digest = ciphertext[len(ciphertext)-MD5_DIGEST_SIZE:]
   123  		ret = ciphertext[:len(ciphertext)-MD5_DIGEST_SIZE]
   124  		ret = sc.decrypt(ret)
   125  		var msgDigest = md5.Sum(ret)
   126  		if !reflect.DeepEqual(msgDigest[:], digest) {
   127  			return nil, errors.New("Decrypt failed/Digest not match\n")
   128  		}
   129  	} else {
   130  		ret = sc.decrypt(ciphertext)
   131  	}
   132  	return ret, nil
   133  }
   134  
   135  func (sc SymmCipher) decrypt(ciphertext []byte) []byte {
   136  	ret := make([]byte, len(ciphertext))
   137  	if v, ok := sc.decryptCipher.(cipher.Stream); ok {
   138  		v.XORKeyStream(ret, ciphertext)
   139  	} else if v, ok := sc.decryptCipher.(cipher.BlockMode); ok {
   140  		v.CryptBlocks(ret, ciphertext)
   141  	}
   142  	// 去除填充
   143  	if sc.needPadding {
   144  		ret = pkcs5UnPadding(ret)
   145  	}
   146  	return ret
   147  }
   148  
   149  func (sc *SymmCipher) getEncrypter(workMode int, block cipher.Block, iv []byte) (ret interface{}, err error) {
   150  	switch workMode {
   151  	case ECB_MODE:
   152  		ret = NewECBEncrypter(block)
   153  		sc.needPadding = true
   154  	case CBC_MODE:
   155  		ret = cipher.NewCBCEncrypter(block, iv)
   156  		sc.needPadding = true
   157  	case CFB_MODE:
   158  		ret = cipher.NewCFBEncrypter(block, iv)
   159  		sc.needPadding = false
   160  	case OFB_MODE:
   161  		ret = cipher.NewOFB(block, iv)
   162  		sc.needPadding = false
   163  	default:
   164  		err = errors.New("invalidCipherMode")
   165  	}
   166  	return
   167  }
   168  
   169  func (sc *SymmCipher) getDecrypter(workMode int, block cipher.Block, iv []byte) (ret interface{}, err error) {
   170  	switch workMode {
   171  	case ECB_MODE:
   172  		ret = NewECBDecrypter(block)
   173  		sc.needPadding = true
   174  	case CBC_MODE:
   175  		ret = cipher.NewCBCDecrypter(block, iv)
   176  		sc.needPadding = true
   177  	case CFB_MODE:
   178  		ret = cipher.NewCFBDecrypter(block, iv)
   179  		sc.needPadding = false
   180  	case OFB_MODE:
   181  		ret = cipher.NewOFB(block, iv)
   182  		sc.needPadding = false
   183  	default:
   184  		err = errors.New("invalidCipherMode")
   185  	}
   186  	return
   187  }
   188  
   189  // 补码
   190  func pkcs77Padding(ciphertext []byte, blocksize int) []byte {
   191  	padding := blocksize - len(ciphertext)%blocksize
   192  	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
   193  	return append(ciphertext, padtext...)
   194  }
   195  
   196  // 去码
   197  func pkcs7UnPadding(origData []byte) []byte {
   198  	length := len(origData)
   199  	unpadding := int(origData[length-1])
   200  	return origData[:length-unpadding]
   201  }
   202  
   203  // 补码
   204  func pkcs5Padding(ciphertext []byte) []byte {
   205  	return pkcs77Padding(ciphertext, 8)
   206  }
   207  
   208  // 去码
   209  func pkcs5UnPadding(ciphertext []byte) []byte {
   210  	return pkcs7UnPadding(ciphertext)
   211  }