github.com/wfusion/gofusion@v1.1.14/common/utils/cipher/decrypt.go (about)

     1  package cipher
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/wfusion/gofusion/common/utils"
    10  )
    11  
    12  func DecryptBytesFunc(algo Algorithm, mode Mode, key, iv []byte) (
    13  	enc func(src []byte) (dst []byte, err error), err error) {
    14  	bm, err := getDecrypter(algo, mode, key, iv)
    15  	if err != nil {
    16  		return
    17  	}
    18  	mode = bm.CipherMode()
    19  	return func(src []byte) (dst []byte, err error) {
    20  		dst, err = decrypt(bm, src)
    21  		if err != nil {
    22  			return
    23  		}
    24  		if mode.ShouldPadding() {
    25  			dst, err = PKCS7Unpad(dst)
    26  		}
    27  		return
    28  	}, err
    29  }
    30  
    31  func DecryptStreamFunc(algo Algorithm, mode Mode, key, iv []byte) (
    32  	dec func(dst io.Writer, src io.Reader) (err error), err error) {
    33  	var fn func(src io.Reader) io.ReadCloser
    34  
    35  	_, err = utils.Catch(func() { fn = NewDecFunc(algo, mode, key, iv) })
    36  	if err != nil {
    37  		return
    38  	}
    39  
    40  	dec = func(dst io.Writer, src io.Reader) (err error) {
    41  		buf, cb := utils.BytesPool.Get(defaultBlockSize * blockSizeTimes)
    42  		defer cb()
    43  
    44  		wrapper := fn(src)
    45  		defer utils.CloseAnyway(wrapper)
    46  		_, err = io.CopyBuffer(dst, wrapper, buf)
    47  		return
    48  	}
    49  
    50  	return
    51  }
    52  
    53  func decrypt(bm blockMode, src []byte) (dst []byte, err error) {
    54  	plainBlockSize := bm.PlainBlockSize()
    55  	cipherBlockSize := bm.CipherBlockSize()
    56  	defers := make([]func(), 0, 3)
    57  
    58  	w, cb := utils.BytesBufferPool.Get(nil)
    59  	defers = append(defers, cb)
    60  
    61  	if plainBlockSize != 0 && cipherBlockSize != 0 {
    62  		w.Grow((len(src) / cipherBlockSize) * plainBlockSize)
    63  	} else {
    64  		plainBlockSize, cipherBlockSize = len(src), len(src)
    65  		w.Grow(plainBlockSize)
    66  	}
    67  
    68  	unsealed, cb := utils.BytesPool.Get(plainBlockSize)
    69  	defers = append(defers, cb)
    70  
    71  	buf, cb := utils.BytesPool.Get(plainBlockSize)
    72  	defers = append(defers, cb)
    73  
    74  	var n, blockSize int
    75  	for len(src) > 0 {
    76  		blockSize = utils.Min(cipherBlockSize, len(src))
    77  		n, err = bm.CryptBlocks(unsealed[:plainBlockSize], src[:blockSize], buf)
    78  		if err != nil {
    79  			return
    80  		}
    81  		if _, err = w.Write(unsealed[:n]); err != nil {
    82  			return
    83  		}
    84  		src = src[blockSize:]
    85  	}
    86  
    87  	bs := w.Bytes()
    88  	dst = make([]byte, len(bs))
    89  	copy(dst, bs)
    90  	return
    91  }
    92  
    93  func getDecrypter(algo Algorithm, mode Mode, key, iv []byte) (bm blockMode, err error) {
    94  	if blockMapping, ok := cipherBlockMapping[algo]; ok {
    95  		var cipherBlock cipher.Block
    96  		cipherBlock, err = blockMapping(key)
    97  		if err != nil {
    98  			return
    99  		}
   100  		modeMapping, ok := decryptModeMapping[mode]
   101  		if !ok {
   102  			return nil, fmt.Errorf("unknown cipher mode %+v", mode)
   103  		}
   104  		bm, err = modeMapping(cipherBlock, iv)
   105  		if err != nil {
   106  			return
   107  		}
   108  	}
   109  
   110  	// stream
   111  	if bm == nil {
   112  		blockMapping, ok := streamDecryptMapping[algo]
   113  		if !ok {
   114  			return nil, fmt.Errorf("unknown cipher algorithm %+v", algo)
   115  		}
   116  		if bm, err = blockMapping(key); err != nil {
   117  			return
   118  		}
   119  	}
   120  
   121  	if bm == nil {
   122  		return nil, fmt.Errorf("unknown cipher algorithm(%+v) or mode(%+v)", algo, mode)
   123  	}
   124  
   125  	return
   126  }
   127  
   128  type dec struct {
   129  	bm blockMode
   130  	r  io.Reader
   131  
   132  	buf  []byte
   133  	cb   func()
   134  	n    int // current position in buf
   135  	end  int // end of data in buf
   136  	size int
   137  
   138  	eof                 bool
   139  	unsealed, unsealBuf []byte
   140  }
   141  
   142  func NewDecFunc(algo Algorithm, mode Mode, key, iv []byte) func(r io.Reader) io.ReadCloser {
   143  	bm, err := getDecrypter(algo, mode, key, iv)
   144  	if err != nil {
   145  		panic(err)
   146  	}
   147  	if !bm.CipherMode().SupportStream() {
   148  		panic(ErrNotSupportStream)
   149  	}
   150  
   151  	var (
   152  		buf, unsealed, unsealBuf []byte
   153  		cb                       func()
   154  		size                     int
   155  	)
   156  
   157  	if size = bm.CipherBlockSize(); size > 0 {
   158  		var bcb, ucb, ubcb func()
   159  		buf, bcb = utils.BytesPool.Get(size)
   160  		unsealed, ucb = utils.BytesPool.Get(bm.PlainBlockSize())
   161  		unsealBuf, ubcb = utils.BytesPool.Get(bm.PlainBlockSize())
   162  		cb = func() { bcb(); ucb(); ubcb() }
   163  	}
   164  
   165  	return func(r io.Reader) io.ReadCloser {
   166  		return &dec{
   167  			bm:        bm,
   168  			r:         r,
   169  			buf:       buf,
   170  			cb:        cb,
   171  			n:         0,
   172  			end:       0,
   173  			size:      size,
   174  			eof:       false,
   175  			unsealed:  unsealed,
   176  			unsealBuf: unsealBuf,
   177  		}
   178  	}
   179  }
   180  
   181  func (d *dec) Read(p []byte) (n int, err error) {
   182  	var (
   183  		nr, nc int
   184  		dst    []byte
   185  	)
   186  
   187  	if d.buf == nil {
   188  		n, err = d.r.Read(p)
   189  		d.eof = errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
   190  		if err != nil && !d.eof {
   191  			return
   192  		}
   193  		dst, err = decrypt(d.bm, p[:n])
   194  		if err != nil {
   195  			return
   196  		}
   197  		n = copy(p[:len(dst)], dst)
   198  		if d.eof {
   199  			err = io.EOF
   200  		}
   201  		return
   202  	}
   203  
   204  	for len(p) > 0 {
   205  		// read from buffer
   206  		if length := d.end - d.n; length > 0 {
   207  			copied := utils.Min(length, len(p))
   208  			n += copy(p[:copied], d.buf[d.n:d.n+copied])
   209  			d.n += copied
   210  			p = p[copied:]
   211  			continue
   212  		}
   213  
   214  		// buffer is empty, write new buffer
   215  		if d.eof {
   216  			return n, io.EOF
   217  		}
   218  		d.n = 0
   219  		d.end = 0
   220  		for {
   221  			nr, err = d.r.Read(d.buf[d.n:d.size])
   222  			d.eof = errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
   223  			if err != nil && !d.eof {
   224  				return
   225  			}
   226  
   227  			d.n += nr
   228  			if d.n < d.size && !d.eof {
   229  				continue
   230  			}
   231  
   232  			nc, err = d.bm.CryptBlocks(d.unsealed, d.buf[:d.n], d.unsealBuf)
   233  			if err != nil {
   234  				return
   235  			}
   236  
   237  			d.end += copy(d.buf[:nc], d.unsealed[:nc])
   238  			d.n = 0
   239  			break
   240  		}
   241  	}
   242  
   243  	return
   244  }
   245  
   246  func (d *dec) Close() (err error) {
   247  	utils.CloseAnyway(d.r)
   248  	if d.cb != nil {
   249  		d.cb()
   250  	}
   251  	return
   252  }