github.com/wfusion/gofusion@v1.1.14/common/utils/cipher/decrypt.go (about) 1 package cipher 2 3 import ( 4 "crypto/cipher" 5 "errors" 6 "fmt" 7 "io" 8 9 "github.com/wfusion/gofusion/common/utils" 10 ) 11 12 func DecryptBytesFunc(algo Algorithm, mode Mode, key, iv []byte) ( 13 enc func(src []byte) (dst []byte, err error), err error) { 14 bm, err := getDecrypter(algo, mode, key, iv) 15 if err != nil { 16 return 17 } 18 mode = bm.CipherMode() 19 return func(src []byte) (dst []byte, err error) { 20 dst, err = decrypt(bm, src) 21 if err != nil { 22 return 23 } 24 if mode.ShouldPadding() { 25 dst, err = PKCS7Unpad(dst) 26 } 27 return 28 }, err 29 } 30 31 func DecryptStreamFunc(algo Algorithm, mode Mode, key, iv []byte) ( 32 dec func(dst io.Writer, src io.Reader) (err error), err error) { 33 var fn func(src io.Reader) io.ReadCloser 34 35 _, err = utils.Catch(func() { fn = NewDecFunc(algo, mode, key, iv) }) 36 if err != nil { 37 return 38 } 39 40 dec = func(dst io.Writer, src io.Reader) (err error) { 41 buf, cb := utils.BytesPool.Get(defaultBlockSize * blockSizeTimes) 42 defer cb() 43 44 wrapper := fn(src) 45 defer utils.CloseAnyway(wrapper) 46 _, err = io.CopyBuffer(dst, wrapper, buf) 47 return 48 } 49 50 return 51 } 52 53 func decrypt(bm blockMode, src []byte) (dst []byte, err error) { 54 plainBlockSize := bm.PlainBlockSize() 55 cipherBlockSize := bm.CipherBlockSize() 56 defers := make([]func(), 0, 3) 57 58 w, cb := utils.BytesBufferPool.Get(nil) 59 defers = append(defers, cb) 60 61 if plainBlockSize != 0 && cipherBlockSize != 0 { 62 w.Grow((len(src) / cipherBlockSize) * plainBlockSize) 63 } else { 64 plainBlockSize, cipherBlockSize = len(src), len(src) 65 w.Grow(plainBlockSize) 66 } 67 68 unsealed, cb := utils.BytesPool.Get(plainBlockSize) 69 defers = append(defers, cb) 70 71 buf, cb := utils.BytesPool.Get(plainBlockSize) 72 defers = append(defers, cb) 73 74 var n, blockSize int 75 for len(src) > 0 { 76 blockSize = utils.Min(cipherBlockSize, len(src)) 77 n, err = bm.CryptBlocks(unsealed[:plainBlockSize], src[:blockSize], buf) 78 if err != nil { 79 return 80 } 81 if _, err = w.Write(unsealed[:n]); err != nil { 82 return 83 } 84 src = src[blockSize:] 85 } 86 87 bs := w.Bytes() 88 dst = make([]byte, len(bs)) 89 copy(dst, bs) 90 return 91 } 92 93 func getDecrypter(algo Algorithm, mode Mode, key, iv []byte) (bm blockMode, err error) { 94 if blockMapping, ok := cipherBlockMapping[algo]; ok { 95 var cipherBlock cipher.Block 96 cipherBlock, err = blockMapping(key) 97 if err != nil { 98 return 99 } 100 modeMapping, ok := decryptModeMapping[mode] 101 if !ok { 102 return nil, fmt.Errorf("unknown cipher mode %+v", mode) 103 } 104 bm, err = modeMapping(cipherBlock, iv) 105 if err != nil { 106 return 107 } 108 } 109 110 // stream 111 if bm == nil { 112 blockMapping, ok := streamDecryptMapping[algo] 113 if !ok { 114 return nil, fmt.Errorf("unknown cipher algorithm %+v", algo) 115 } 116 if bm, err = blockMapping(key); err != nil { 117 return 118 } 119 } 120 121 if bm == nil { 122 return nil, fmt.Errorf("unknown cipher algorithm(%+v) or mode(%+v)", algo, mode) 123 } 124 125 return 126 } 127 128 type dec struct { 129 bm blockMode 130 r io.Reader 131 132 buf []byte 133 cb func() 134 n int // current position in buf 135 end int // end of data in buf 136 size int 137 138 eof bool 139 unsealed, unsealBuf []byte 140 } 141 142 func NewDecFunc(algo Algorithm, mode Mode, key, iv []byte) func(r io.Reader) io.ReadCloser { 143 bm, err := getDecrypter(algo, mode, key, iv) 144 if err != nil { 145 panic(err) 146 } 147 if !bm.CipherMode().SupportStream() { 148 panic(ErrNotSupportStream) 149 } 150 151 var ( 152 buf, unsealed, unsealBuf []byte 153 cb func() 154 size int 155 ) 156 157 if size = bm.CipherBlockSize(); size > 0 { 158 var bcb, ucb, ubcb func() 159 buf, bcb = utils.BytesPool.Get(size) 160 unsealed, ucb = utils.BytesPool.Get(bm.PlainBlockSize()) 161 unsealBuf, ubcb = utils.BytesPool.Get(bm.PlainBlockSize()) 162 cb = func() { bcb(); ucb(); ubcb() } 163 } 164 165 return func(r io.Reader) io.ReadCloser { 166 return &dec{ 167 bm: bm, 168 r: r, 169 buf: buf, 170 cb: cb, 171 n: 0, 172 end: 0, 173 size: size, 174 eof: false, 175 unsealed: unsealed, 176 unsealBuf: unsealBuf, 177 } 178 } 179 } 180 181 func (d *dec) Read(p []byte) (n int, err error) { 182 var ( 183 nr, nc int 184 dst []byte 185 ) 186 187 if d.buf == nil { 188 n, err = d.r.Read(p) 189 d.eof = errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) 190 if err != nil && !d.eof { 191 return 192 } 193 dst, err = decrypt(d.bm, p[:n]) 194 if err != nil { 195 return 196 } 197 n = copy(p[:len(dst)], dst) 198 if d.eof { 199 err = io.EOF 200 } 201 return 202 } 203 204 for len(p) > 0 { 205 // read from buffer 206 if length := d.end - d.n; length > 0 { 207 copied := utils.Min(length, len(p)) 208 n += copy(p[:copied], d.buf[d.n:d.n+copied]) 209 d.n += copied 210 p = p[copied:] 211 continue 212 } 213 214 // buffer is empty, write new buffer 215 if d.eof { 216 return n, io.EOF 217 } 218 d.n = 0 219 d.end = 0 220 for { 221 nr, err = d.r.Read(d.buf[d.n:d.size]) 222 d.eof = errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) 223 if err != nil && !d.eof { 224 return 225 } 226 227 d.n += nr 228 if d.n < d.size && !d.eof { 229 continue 230 } 231 232 nc, err = d.bm.CryptBlocks(d.unsealed, d.buf[:d.n], d.unsealBuf) 233 if err != nil { 234 return 235 } 236 237 d.end += copy(d.buf[:nc], d.unsealed[:nc]) 238 d.n = 0 239 break 240 } 241 } 242 243 return 244 } 245 246 func (d *dec) Close() (err error) { 247 utils.CloseAnyway(d.r) 248 if d.cb != nil { 249 d.cb() 250 } 251 return 252 }