github.com/wfusion/gofusion@v1.1.14/common/utils/cipher/encrypt.go (about)

     1  package cipher
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"fmt"
     6  	"io"
     7  
     8  	"github.com/wfusion/gofusion/common/utils"
     9  )
    10  
    11  func EncryptBytesFunc(algo Algorithm, mode Mode, key, iv []byte) (
    12  	enc func(src []byte) (dst []byte, err error), err error) {
    13  	bm, err := getEncrypter(algo, mode, key, iv)
    14  	if err != nil {
    15  		return
    16  	}
    17  	mode = bm.CipherMode()
    18  	plainBlockSize := bm.PlainBlockSize()
    19  	return func(src []byte) (dst []byte, err error) {
    20  		if mode.ShouldPadding() {
    21  			src = PKCS7Pad(src, plainBlockSize)
    22  		}
    23  		return encrypt(bm, src)
    24  	}, err
    25  }
    26  
    27  func EncryptStreamFunc(algo Algorithm, mode Mode, key, iv []byte) (
    28  	enc func(dst io.Writer, src io.Reader) (err error), err error) {
    29  	var fn func(w io.Writer) io.WriteCloser
    30  
    31  	_, err = utils.Catch(func() { fn = NewEncFunc(algo, mode, key, iv) })
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	enc = func(dst io.Writer, src io.Reader) (err error) {
    37  		buf, cb := utils.BytesPool.Get(defaultBlockSize * blockSizeTimes)
    38  		defer cb()
    39  
    40  		wrapper := fn(dst)
    41  		defer utils.CloseAnyway(wrapper)
    42  
    43  		_, err = io.CopyBuffer(wrapper, src, buf)
    44  		return
    45  	}
    46  
    47  	return
    48  }
    49  
    50  func encrypt(bm blockMode, src []byte) (dst []byte, err error) {
    51  	plainBlockSize := bm.PlainBlockSize()
    52  	cipherBlockSize := bm.CipherBlockSize()
    53  	defers := make([]func(), 0, 3)
    54  	defer func() {
    55  		for _, cb := range defers {
    56  			cb()
    57  		}
    58  	}()
    59  
    60  	w, cb := utils.BytesBufferPool.Get(nil)
    61  	defers = append(defers, cb)
    62  
    63  	if plainBlockSize != 0 && cipherBlockSize != 0 {
    64  		w.Grow((len(src) / plainBlockSize) * cipherBlockSize)
    65  	} else {
    66  		plainBlockSize, cipherBlockSize = len(src), len(src)
    67  		w.Grow(cipherBlockSize)
    68  	}
    69  
    70  	sealed, cb := utils.BytesPool.Get(cipherBlockSize)
    71  	defers = append(defers, cb)
    72  
    73  	buf, cb := utils.BytesPool.Get(plainBlockSize)
    74  	defers = append(defers, cb)
    75  
    76  	var n, blockSize int
    77  	for len(src) > 0 {
    78  		blockSize = utils.Min(plainBlockSize, len(src))
    79  		n, err = bm.CryptBlocks(sealed[:cipherBlockSize], src[:blockSize], buf)
    80  		if err != nil {
    81  			return
    82  		}
    83  		if _, err = w.Write(sealed[:n]); err != nil {
    84  			return
    85  		}
    86  		src = src[blockSize:]
    87  	}
    88  
    89  	bs := w.Bytes()
    90  	dst = make([]byte, len(bs))
    91  	copy(dst, bs)
    92  	return
    93  }
    94  
    95  func getEncrypter(algo Algorithm, mode Mode, key, iv []byte) (bm blockMode, err error) {
    96  	if blockMapping, ok := cipherBlockMapping[algo]; ok {
    97  		var cipherBlock cipher.Block
    98  		cipherBlock, err = blockMapping(key)
    99  		if err != nil {
   100  			return
   101  		}
   102  		modeMapping, ok := encryptModeMapping[mode]
   103  		if !ok {
   104  			return nil, fmt.Errorf("unknown cipher mode %+v", mode)
   105  		}
   106  		bm, err = modeMapping(cipherBlock, iv)
   107  		if err != nil {
   108  			return
   109  		}
   110  	}
   111  
   112  	// stream
   113  	if bm == nil {
   114  		blockMapping, ok := streamEncryptMapping[algo]
   115  		if !ok {
   116  			return nil, fmt.Errorf("unknown cipher algorithm %+v", algo)
   117  		}
   118  		if bm, err = blockMapping(key); err != nil {
   119  			return
   120  		}
   121  	}
   122  
   123  	if bm == nil {
   124  		return nil, fmt.Errorf("unknown cipher algorithm(%+v) or mode(%+v)", algo, mode)
   125  	}
   126  
   127  	return
   128  }
   129  
   130  type enc struct {
   131  	bm blockMode
   132  	w  io.Writer
   133  
   134  	n   int
   135  	buf []byte
   136  	cb  func()
   137  
   138  	sealed, sealBuf []byte
   139  }
   140  
   141  func NewEncFunc(algo Algorithm, mode Mode, key, iv []byte) func(w io.Writer) io.WriteCloser {
   142  	bm, err := getEncrypter(algo, mode, key, iv)
   143  	if err != nil {
   144  		panic(err)
   145  	}
   146  	if !bm.CipherMode().SupportStream() {
   147  		panic(ErrNotSupportStream)
   148  	}
   149  
   150  	var (
   151  		buf, sealed, sealBuf []byte
   152  		cb                   func()
   153  	)
   154  	if bm.PlainBlockSize() > 0 {
   155  		var bcb, scb, sbcb func()
   156  		buf, bcb = utils.BytesPool.Get(bm.PlainBlockSize())
   157  		sealed, scb = utils.BytesPool.Get(bm.CipherBlockSize())
   158  		sealBuf, sbcb = utils.BytesPool.Get(bm.PlainBlockSize())
   159  		cb = func() { bcb(); scb(); sbcb() }
   160  	}
   161  
   162  	return func(w io.Writer) io.WriteCloser {
   163  		return &enc{
   164  			bm:      bm,
   165  			w:       w,
   166  			n:       0,
   167  			buf:     buf,
   168  			cb:      cb,
   169  			sealed:  sealed,
   170  			sealBuf: sealBuf,
   171  		}
   172  	}
   173  }
   174  
   175  func (e *enc) Write(p []byte) (n int, err error) {
   176  	if e.buf == nil {
   177  		var dst []byte
   178  		dst, err = encrypt(e.bm, p)
   179  		if err != nil {
   180  			return
   181  		}
   182  		n = len(p)
   183  		_, err = e.w.Write(dst)
   184  		return
   185  	}
   186  
   187  	written := 0
   188  	for len(p) > 0 {
   189  		nCopy := utils.Min(len(p), len(e.buf)-e.n)
   190  		copy(e.buf[e.n:], p[:nCopy])
   191  		e.n += nCopy
   192  		written += nCopy
   193  		p = p[nCopy:]
   194  
   195  		if e.n == len(e.buf) {
   196  			if err = e.flush(); err != nil {
   197  				return written, err
   198  			}
   199  		}
   200  	}
   201  
   202  	return written, nil
   203  }
   204  
   205  func (e *enc) flush() (err error) {
   206  	if e.buf == nil {
   207  		return
   208  	}
   209  
   210  	n, err := e.bm.CryptBlocks(e.sealed, e.buf[:e.n], e.sealBuf)
   211  	if err != nil {
   212  		return
   213  	}
   214  
   215  	_, err = e.w.Write(e.sealed[:n])
   216  	if err != nil {
   217  		return err
   218  	}
   219  
   220  	e.n = 0
   221  	return
   222  }
   223  
   224  func (e *enc) Flush() (err error) {
   225  	defer utils.FlushAnyway(e.w)
   226  	if e.n > 0 {
   227  		err = e.flush()
   228  	}
   229  	return
   230  }
   231  
   232  func (e *enc) Close() (err error) {
   233  	defer func() {
   234  		utils.CloseAnyway(e.w)
   235  		if e.cb != nil {
   236  			e.cb()
   237  		}
   238  	}()
   239  
   240  	return e.Flush()
   241  }