gitee.com/chunanyong/dm@v1.8.12/security/zzf.go (about) 1 /* 2 * Copyright (c) 2000-2018, 达梦数据库有限公司. 3 * All rights reserved. 4 */ 5 6 package security 7 8 import ( 9 "crypto/md5" 10 "errors" 11 "fmt" 12 "reflect" 13 "unsafe" 14 ) 15 16 type ThirdPartCipher struct { 17 encryptType int // 外部加密算法id 18 encryptName string // 外部加密算法名称 19 hashType int 20 key []byte 21 cipherCount int // 外部加密算法个数 22 //innerId int // 外部加密算法内部id 23 blockSize int // 分组块大小 24 khSize int // key/hash大小 25 } 26 27 func NewThirdPartCipher(encryptType int, key []byte, cipherPath string, hashType int) (ThirdPartCipher, error) { 28 var tpc = ThirdPartCipher{ 29 encryptType: encryptType, 30 key: key, 31 hashType: hashType, 32 cipherCount: -1, 33 } 34 var err error 35 err = initThirdPartCipher(cipherPath) 36 if err != nil { 37 return tpc, err 38 } 39 tpc.getCount() 40 if err = tpc.getInfo(); err != nil { 41 return tpc, err 42 } 43 return tpc, nil 44 } 45 46 func (tpc *ThirdPartCipher) getCount() int { 47 if tpc.cipherCount == -1 { 48 tpc.cipherCount = cipherGetCount() 49 } 50 return tpc.cipherCount 51 } 52 53 func (tpc *ThirdPartCipher) getInfo() error { 54 var cipher_id, ty, blk_size, kh_size int 55 //var strptr, _ = syscall.UTF16PtrFromString(tpc.encryptName) 56 var strptr *uint16 = new(uint16) 57 for i := 1; i <= tpc.getCount(); i++ { 58 cipherGetInfo(uintptr(i), uintptr(unsafe.Pointer(&cipher_id)), uintptr(unsafe.Pointer(&strptr)), 59 uintptr(unsafe.Pointer(&ty)), uintptr(unsafe.Pointer(&blk_size)), uintptr(unsafe.Pointer(&kh_size))) 60 if tpc.encryptType == cipher_id { 61 tpc.blockSize = blk_size 62 tpc.khSize = kh_size 63 tpc.encryptName = string(uintptr2bytes(uintptr(unsafe.Pointer(strptr)))) 64 return nil 65 } 66 } 67 return fmt.Errorf("ThirdPartyCipher: cipher id:%d not found", tpc.encryptType) 68 } 69 70 func (tpc ThirdPartCipher) Encrypt(plaintext []byte, genDigest bool) []byte { 71 var tmp_para uintptr 72 cipherEncryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para) 73 74 ciphertextLen := cipherGetCipherTextSize(uintptr(tpc.encryptType), tmp_para, uintptr(len(plaintext))) 75 76 ciphertext := make([]byte, ciphertextLen) 77 ret := cipherEncrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext)), 78 uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext))) 79 ciphertext = ciphertext[:ret] 80 81 cipherClean(uintptr(tpc.encryptType), tmp_para) 82 // md5摘要 83 if genDigest { 84 digest := md5.Sum(plaintext) 85 encrypt := ciphertext 86 ciphertext = make([]byte, len(encrypt)+len(digest)) 87 copy(ciphertext[:len(encrypt)], encrypt) 88 copy(ciphertext[len(encrypt):], digest[:]) 89 } 90 return ciphertext 91 } 92 93 func (tpc ThirdPartCipher) Decrypt(ciphertext []byte, checkDigest bool) ([]byte, error) { 94 var ret []byte 95 if checkDigest { 96 var digest = ciphertext[len(ciphertext)-MD5_DIGEST_SIZE:] 97 ret = ciphertext[:len(ciphertext)-MD5_DIGEST_SIZE] 98 ret = tpc.decrypt(ret) 99 var msgDigest = md5.Sum(ret) 100 if !reflect.DeepEqual(msgDigest[:], digest) { 101 return nil, errors.New("Decrypt failed/Digest not match\n") 102 } 103 } else { 104 ret = tpc.decrypt(ciphertext) 105 } 106 return ret, nil 107 } 108 109 func (tpc ThirdPartCipher) decrypt(ciphertext []byte) []byte { 110 var tmp_para uintptr 111 112 cipherDecryptInit(uintptr(tpc.encryptType), uintptr(unsafe.Pointer(&tpc.key[0])), uintptr(len(tpc.key)), tmp_para) 113 114 plaintext := make([]byte, len(ciphertext)) 115 ret := cipherDecrypt(uintptr(tpc.encryptType), tmp_para, uintptr(unsafe.Pointer(&ciphertext[0])), uintptr(len(ciphertext)), 116 uintptr(unsafe.Pointer(&plaintext[0])), uintptr(len(plaintext))) 117 plaintext = plaintext[:ret] 118 119 cipherClean(uintptr(tpc.encryptType), tmp_para) 120 return plaintext 121 } 122 123 func addBufSize(buf []byte, newCap int) []byte { 124 newBuf := make([]byte, newCap) 125 copy(newBuf, buf) 126 return newBuf 127 } 128 129 func uintptr2bytes(p uintptr) []byte { 130 buf := make([]byte, 64) 131 i := 0 132 for b := (*byte)(unsafe.Pointer(p)); *b != 0; i++ { 133 if i > cap(buf) { 134 buf = addBufSize(buf, i*2) 135 } 136 buf[i] = *b 137 // byte占1字节 138 p++ 139 b = (*byte)(unsafe.Pointer(p)) 140 } 141 return buf[:i] 142 }