github.com/wfusion/gofusion@v1.1.14/common/utils/cipher/encrypt.go (about) 1 package cipher 2 3 import ( 4 "crypto/cipher" 5 "fmt" 6 "io" 7 8 "github.com/wfusion/gofusion/common/utils" 9 ) 10 11 func EncryptBytesFunc(algo Algorithm, mode Mode, key, iv []byte) ( 12 enc func(src []byte) (dst []byte, err error), err error) { 13 bm, err := getEncrypter(algo, mode, key, iv) 14 if err != nil { 15 return 16 } 17 mode = bm.CipherMode() 18 plainBlockSize := bm.PlainBlockSize() 19 return func(src []byte) (dst []byte, err error) { 20 if mode.ShouldPadding() { 21 src = PKCS7Pad(src, plainBlockSize) 22 } 23 return encrypt(bm, src) 24 }, err 25 } 26 27 func EncryptStreamFunc(algo Algorithm, mode Mode, key, iv []byte) ( 28 enc func(dst io.Writer, src io.Reader) (err error), err error) { 29 var fn func(w io.Writer) io.WriteCloser 30 31 _, err = utils.Catch(func() { fn = NewEncFunc(algo, mode, key, iv) }) 32 if err != nil { 33 return nil, err 34 } 35 36 enc = func(dst io.Writer, src io.Reader) (err error) { 37 buf, cb := utils.BytesPool.Get(defaultBlockSize * blockSizeTimes) 38 defer cb() 39 40 wrapper := fn(dst) 41 defer utils.CloseAnyway(wrapper) 42 43 _, err = io.CopyBuffer(wrapper, src, buf) 44 return 45 } 46 47 return 48 } 49 50 func encrypt(bm blockMode, src []byte) (dst []byte, err error) { 51 plainBlockSize := bm.PlainBlockSize() 52 cipherBlockSize := bm.CipherBlockSize() 53 defers := make([]func(), 0, 3) 54 defer func() { 55 for _, cb := range defers { 56 cb() 57 } 58 }() 59 60 w, cb := utils.BytesBufferPool.Get(nil) 61 defers = append(defers, cb) 62 63 if plainBlockSize != 0 && cipherBlockSize != 0 { 64 w.Grow((len(src) / plainBlockSize) * cipherBlockSize) 65 } else { 66 plainBlockSize, cipherBlockSize = len(src), len(src) 67 w.Grow(cipherBlockSize) 68 } 69 70 sealed, cb := utils.BytesPool.Get(cipherBlockSize) 71 defers = append(defers, cb) 72 73 buf, cb := utils.BytesPool.Get(plainBlockSize) 74 defers = append(defers, cb) 75 76 var n, blockSize int 77 for len(src) > 0 { 78 blockSize = utils.Min(plainBlockSize, len(src)) 79 n, err = bm.CryptBlocks(sealed[:cipherBlockSize], src[:blockSize], buf) 80 if err != nil { 81 return 82 } 83 if _, err = w.Write(sealed[:n]); err != nil { 84 return 85 } 86 src = src[blockSize:] 87 } 88 89 bs := w.Bytes() 90 dst = make([]byte, len(bs)) 91 copy(dst, bs) 92 return 93 } 94 95 func getEncrypter(algo Algorithm, mode Mode, key, iv []byte) (bm blockMode, err error) { 96 if blockMapping, ok := cipherBlockMapping[algo]; ok { 97 var cipherBlock cipher.Block 98 cipherBlock, err = blockMapping(key) 99 if err != nil { 100 return 101 } 102 modeMapping, ok := encryptModeMapping[mode] 103 if !ok { 104 return nil, fmt.Errorf("unknown cipher mode %+v", mode) 105 } 106 bm, err = modeMapping(cipherBlock, iv) 107 if err != nil { 108 return 109 } 110 } 111 112 // stream 113 if bm == nil { 114 blockMapping, ok := streamEncryptMapping[algo] 115 if !ok { 116 return nil, fmt.Errorf("unknown cipher algorithm %+v", algo) 117 } 118 if bm, err = blockMapping(key); err != nil { 119 return 120 } 121 } 122 123 if bm == nil { 124 return nil, fmt.Errorf("unknown cipher algorithm(%+v) or mode(%+v)", algo, mode) 125 } 126 127 return 128 } 129 130 type enc struct { 131 bm blockMode 132 w io.Writer 133 134 n int 135 buf []byte 136 cb func() 137 138 sealed, sealBuf []byte 139 } 140 141 func NewEncFunc(algo Algorithm, mode Mode, key, iv []byte) func(w io.Writer) io.WriteCloser { 142 bm, err := getEncrypter(algo, mode, key, iv) 143 if err != nil { 144 panic(err) 145 } 146 if !bm.CipherMode().SupportStream() { 147 panic(ErrNotSupportStream) 148 } 149 150 var ( 151 buf, sealed, sealBuf []byte 152 cb func() 153 ) 154 if bm.PlainBlockSize() > 0 { 155 var bcb, scb, sbcb func() 156 buf, bcb = utils.BytesPool.Get(bm.PlainBlockSize()) 157 sealed, scb = utils.BytesPool.Get(bm.CipherBlockSize()) 158 sealBuf, sbcb = utils.BytesPool.Get(bm.PlainBlockSize()) 159 cb = func() { bcb(); scb(); sbcb() } 160 } 161 162 return func(w io.Writer) io.WriteCloser { 163 return &enc{ 164 bm: bm, 165 w: w, 166 n: 0, 167 buf: buf, 168 cb: cb, 169 sealed: sealed, 170 sealBuf: sealBuf, 171 } 172 } 173 } 174 175 func (e *enc) Write(p []byte) (n int, err error) { 176 if e.buf == nil { 177 var dst []byte 178 dst, err = encrypt(e.bm, p) 179 if err != nil { 180 return 181 } 182 n = len(p) 183 _, err = e.w.Write(dst) 184 return 185 } 186 187 written := 0 188 for len(p) > 0 { 189 nCopy := utils.Min(len(p), len(e.buf)-e.n) 190 copy(e.buf[e.n:], p[:nCopy]) 191 e.n += nCopy 192 written += nCopy 193 p = p[nCopy:] 194 195 if e.n == len(e.buf) { 196 if err = e.flush(); err != nil { 197 return written, err 198 } 199 } 200 } 201 202 return written, nil 203 } 204 205 func (e *enc) flush() (err error) { 206 if e.buf == nil { 207 return 208 } 209 210 n, err := e.bm.CryptBlocks(e.sealed, e.buf[:e.n], e.sealBuf) 211 if err != nil { 212 return 213 } 214 215 _, err = e.w.Write(e.sealed[:n]) 216 if err != nil { 217 return err 218 } 219 220 e.n = 0 221 return 222 } 223 224 func (e *enc) Flush() (err error) { 225 defer utils.FlushAnyway(e.w) 226 if e.n > 0 { 227 err = e.flush() 228 } 229 return 230 } 231 232 func (e *enc) Close() (err error) { 233 defer func() { 234 utils.CloseAnyway(e.w) 235 if e.cb != nil { 236 e.cb() 237 } 238 }() 239 240 return e.Flush() 241 }