gitee.com/chunanyong/dm@v1.8.12/security/zzf.go (about)

     1  /*
     2   * Copyright (c) 2000-2018, 达梦数据库有限公司.
     3   * All rights reserved.
     4   */
     5  
     6  package security
     7  
     8  import (
     9  	"crypto/md5"
    10  	"errors"
    11  	"fmt"
    12  	"reflect"
    13  	"unsafe"
    14  )
    15  
// ThirdPartCipher describes one cipher provided by an external (third-party)
// encryption library. The configuration fields are set by NewThirdPartCipher;
// encryptName, blockSize and khSize are filled in by getInfo.
type ThirdPartCipher struct {
	encryptType int    // external cipher algorithm id, matched against the ids reported by the library
	encryptName string // external cipher algorithm name, read from the library by getInfo
	hashType    int    // hash type passed to NewThirdPartCipher; not read in this file
	key         []byte // key handed to the library's encrypt/decrypt init routines
	cipherCount int    // number of ciphers the library exposes; -1 until queried by getCount
	//innerId		int // internal id of the external cipher algorithm (unused)
	blockSize int // cipher block size reported by the library
	khSize    int // key/hash size reported by the library
}
    26  
    27  func NewThirdPartCipher(encryptType int, key []byte, cipherPath string, hashType int) (ThirdPartCipher, error) {
    28  	var tpc = ThirdPartCipher{
    29  		encryptType: encryptType,
    30  		key:         key,
    31  		hashType:    hashType,
    32  		cipherCount: -1,
    33  	}
    34  	var err error
    35  	err = initThirdPartCipher(cipherPath)
    36  	if err != nil {
    37  		return tpc, err
    38  	}
    39  	tpc.getCount()
    40  	if err = tpc.getInfo(); err != nil {
    41  		return tpc, err
    42  	}
    43  	return tpc, nil
    44  }
    45  
    46  func (tpc *ThirdPartCipher) getCount() int {
    47  	if tpc.cipherCount == -1 {
    48  		tpc.cipherCount = cipherGetCount()
    49  	}
    50  	return tpc.cipherCount
    51  }
    52  
    53  func (tpc *ThirdPartCipher) getInfo() error {
    54  	var cipher_id, ty, blk_size, kh_size int
    55  	//var strptr, _ = syscall.UTF16PtrFromString(tpc.encryptName)
    56  	var strptr *uint16 = new(uint16)
    57  	for i := 1; i <= tpc.getCount(); i++ {
    58  		cipherGetInfo(uintptr(i), uintptr(unsafe.Pointer(&cipher_id)), uintptr(unsafe.Pointer(&strptr)),
    59  			uintptr(unsafe.Pointer(&ty)), uintptr(unsafe.Pointer(&blk_size)), uintptr(unsafe.Pointer(&kh_size)))
    60  		if tpc.encryptType == cipher_id {
    61  			tpc.blockSize = blk_size
    62  			tpc.khSize = kh_size
    63  			tpc.encryptName = string(uintptr2bytes(uintptr(unsafe.Pointer(strptr))))
    64  			return nil
    65  		}
    66  	}
    67  	return fmt.Errorf("ThirdPartyCipher: cipher id:%d not found", tpc.encryptType)
    68  }
    69  
    70  func (tpc ThirdPartCipher) Encrypt(plaintext []byte, genDigest bool) []byte {
    71  	var tmp_para uintptr
    72  	cipherEncryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para)
    73  
    74  	ciphertextLen := cipherGetCipherTextSize(uintptr(tpc.encryptType), tmp_para, uintptr(len(plaintext)))
    75  
    76  	ciphertext := make([]byte, ciphertextLen)
    77  	ret := cipherEncrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)),
    78  		uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)))
    79  	ciphertext = ciphertext[:ret]
    80  
    81  	cipherClean(uintptr(tpc.encryptType), tmp_para)
    82  	// md5摘要
    83  	if genDigest {
    84  		digest := md5.Sum(plaintext)
    85  		encrypt := ciphertext
    86  		ciphertext = make([]byte, len(encrypt)+len(digest))
    87  		copy(ciphertext[:len(encrypt)], encrypt)
    88  		copy(ciphertext[len(encrypt):], digest[:])
    89  	}
    90  	return ciphertext
    91  }
    92  
    93  func (tpc ThirdPartCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) {
    94  	var ret []byte
    95  	if checkDigest {
    96  		var digest = ciphertext[len(ciphertext)-MD5_DIGEST_SIZE:]
    97  		ret = ciphertext[:len(ciphertext)-MD5_DIGEST_SIZE]
    98  		ret = tpc.decrypt(ret)
    99  		var msgDigest = md5.Sum(ret)
   100  		if !reflect.DeepEqual(msgDigest[:], digest) {
   101  			return nil, errors.New("Decrypt failed/Digest not match\n")
   102  		}
   103  	} else {
   104  		ret = tpc.decrypt(ciphertext)
   105  	}
   106  	return ret, nil
   107  }
   108  
   109  func (tpc ThirdPartCipher) decrypt(ciphertext []byte) []byte {
   110  	var tmp_para uintptr
   111  
   112  	cipherDecryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para)
   113  
   114  	plaintext := make([]byte, len(ciphertext))
   115  	ret := cipherDecrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)),
   116  		uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)))
   117  	plaintext = plaintext[:ret]
   118  
   119  	cipherClean(uintptr(tpc.encryptType), tmp_para)
   120  	return plaintext
   121  }
   122  
   123  func addBufSize(buf []byte, newCap int) []byte {
   124  	newBuf := make([]byte, newCap)
   125  	copy(newBuf, buf)
   126  	return newBuf
   127  }
   128  
   129  func uintptr2bytes(p uintptr) []byte {
   130  	buf := make([]byte, 64)
   131  	i := 0
   132  	for b := (*byte)(unsafe.Pointer(p)); *b != 0; i++ {
   133  		if i > cap(buf) {
   134  			buf = addBufSize(buf, i*2)
   135  		}
   136  		buf[i] = *b
   137  		// byte占1字节
   138  		p++
   139  		b = (*byte)(unsafe.Pointer(p))
   140  	}
   141  	return buf[:i]
   142  }