gitee.com/runner.mei/dm@v0.0.0-20220207044607-a9ba0dc20bf7/security/zzf.go (about)

     1  /*
     2   * Copyright (c) 2000-2018, 达梦数据库有限公司.
     3   * All rights reserved.
     4   */
     5  
     6  package security
     7  
     8  import (
     9  	"crypto/md5"
    10  	"errors"
    11  	"fmt"
    12  	"reflect"
    13  	"unsafe"
    14  )
    15  
    16  type ThirdPartCipher struct {
    17  	encryptType int    // 外部加密算法id
    18  	encryptName string // 外部加密算法名称
    19  	hashType    int
    20  	key         []byte
    21  	cipherCount int // 外部加密算法个数
    22  	//innerId		int // 外部加密算法内部id
    23  	blockSize   int // 分组块大小
    24  	khSize      int // key/hash大小
    25  }
    26  
    27  func NewThirdPartCipher(encryptType int, key []byte, cipherPath string, hashType int) (ThirdPartCipher, error) {
    28  	var tpc = ThirdPartCipher{
    29  		encryptType: encryptType,
    30  		key:         key,
    31  		hashType:    hashType,
    32  		cipherCount: -1,
    33  	}
    34  	var err error
    35  	err = initThirdPartCipher(cipherPath)
    36  	if err != nil {
    37  		return tpc, err
    38  	}
    39  	tpc.getCount()
    40  	tpc.getInfo()
    41  	return tpc, nil
    42  }
    43  
    44  func (tpc *ThirdPartCipher) getCount() int {
    45  	if tpc.cipherCount == -1 {
    46  		tpc.cipherCount = cipherGetCount()
    47  	}
    48  	return tpc.cipherCount
    49  }
    50  
    51  func (tpc *ThirdPartCipher) getInfo() {
    52  	var cipher_id, ty, blk_size, kh_size int
    53  	//var strptr, _ = syscall.UTF16PtrFromString(tpc.encryptName)
    54  	var strptr *uint16 = new(uint16)
    55  	for i := 1; i <= tpc.getCount(); i++ {
    56  		cipherGetInfo(uintptr(i), uintptr(unsafe.Pointer(&cipher_id)), uintptr(unsafe.Pointer(&strptr)),
    57  			uintptr(unsafe.Pointer(&ty)), uintptr(unsafe.Pointer(&blk_size)), uintptr(unsafe.Pointer(&kh_size)))
    58  		if tpc.encryptType == cipher_id {
    59  			tpc.blockSize = blk_size
    60  			tpc.khSize = kh_size
    61  			tpc.encryptName = string(uintptr2bytes(uintptr(unsafe.Pointer(strptr))))
    62  			return
    63  		}
    64  	}
    65  	panic(fmt.Sprintf("ThirdPartyCipher: cipher id:%d not found", tpc.encryptType))
    66  }
    67  
    68  func (tpc ThirdPartCipher) Encrypt(plaintext []byte, genDigest bool) []byte {
    69  	var tmp_para uintptr
    70  	cipherEncryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para)
    71  
    72  	ciphertextLen := cipherGetCipherTextSize(uintptr(tpc.encryptType), tmp_para, uintptr(len(plaintext)))
    73  
    74  	ciphertext := make([]byte, ciphertextLen)
    75  	ret := cipherEncrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)),
    76  		uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)))
    77  	ciphertext = ciphertext[:ret]
    78  
    79  	cipherClean(uintptr(tpc.encryptType), tmp_para)
    80  	// md5摘要
    81  	if genDigest {
    82  		digest := md5.Sum(plaintext)
    83  		encrypt := ciphertext
    84  		ciphertext = make([]byte, len(encrypt)+len(digest))
    85  		copy(ciphertext[:len(encrypt)], encrypt)
    86  		copy(ciphertext[len(encrypt):], digest[:])
    87  	}
    88  	return ciphertext
    89  }
    90  
    91  func (tpc ThirdPartCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) {
    92  	var ret []byte
    93  	if checkDigest {
    94  		var digest = ciphertext[len(ciphertext)-MD5_DIGEST_SIZE:]
    95  		ret = ciphertext[:len(ciphertext)-MD5_DIGEST_SIZE]
    96  		ret = tpc.decrypt(ret)
    97  		var msgDigest = md5.Sum(ret)
    98  		if !reflect.DeepEqual(msgDigest[:], digest) {
    99  			return nil, errors.New("Decrypt failed/Digest not match\n")
   100  		}
   101  	} else {
   102  		ret = tpc.decrypt(ciphertext)
   103  	}
   104  	return ret, nil
   105  }
   106  
   107  func (tpc ThirdPartCipher) decrypt(ciphertext []byte) []byte {
   108  	var tmp_para uintptr
   109  
   110  	cipherDecryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para)
   111  
   112  	plaintext := make([]byte, len(ciphertext))
   113  	ret := cipherDecrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)),
   114  		uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)))
   115  	plaintext = plaintext[:ret]
   116  
   117  	cipherClean(uintptr(tpc.encryptType), tmp_para)
   118  	return plaintext
   119  }
   120  
   121  func addBufSize(buf []byte, newCap int) []byte {
   122  	newBuf := make([]byte, newCap)
   123  	copy(newBuf, buf)
   124  	return newBuf
   125  }
   126  
   127  func uintptr2bytes(p uintptr) []byte {
   128  	buf := make([]byte, 64)
   129  	i := 0
   130  	for b := (*byte)(unsafe.Pointer(p)); *b != 0; i++ {
   131  		if i > cap(buf) {
   132  			buf = addBufSize(buf, i * 2)
   133  		}
   134  		buf[i] = *b
   135  		// byte占1字节
   136  		p ++
   137  		b = (*byte)(unsafe.Pointer(p))
   138  	}
   139  	return buf[:i]
   140  }