github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/coder/cipher.go (about)

     1  package coder
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"encoding/base64"
     7  	"encoding/hex"
     8  	"errors"
     9  	"fmt"
    10  )
    11  
    12  type Crypto interface {
    13  	Encrypt(plainText []byte) (string, error)
    14  	Decrypt(cipherText string) (string, error)
    15  }
    16  
    17  type Cipher struct {
    18  	GroupMode  int
    19  	FillMode   FillMode
    20  	DecodeType int
    21  	Key        []byte
    22  	Iv         []byte
    23  	Output     CipherText
    24  }
    25  
    26  func (c *Cipher) Encrypt(block cipher.Block, plainData []byte) (err error) {
    27  	c.Output = make([]byte, len(plainData))
    28  	if c.GroupMode == CBCMode {
    29  		cipher.NewCBCEncrypter(block, c.Iv).CryptBlocks(c.Output, plainData)
    30  		return
    31  	}
    32  	if c.GroupMode == ECBMode {
    33  		c.NewECBEncrypter(block, plainData)
    34  		return
    35  	}
    36  	return
    37  }
    38  
    39  func (c *Cipher) Decrypt(block cipher.Block, cipherData []byte) (err error) {
    40  	c.Output = make([]byte, len(cipherData))
    41  	if c.GroupMode == CBCMode {
    42  		cipher.NewCBCDecrypter(block, c.Iv).CryptBlocks(c.Output, cipherData)
    43  		return
    44  	}
    45  	if c.GroupMode == ECBMode {
    46  		c.NewECBDecrypter(block, cipherData)
    47  		return
    48  	}
    49  	return
    50  }
    51  
    52  // Encode default print format is base64
    53  func (c *Cipher) Encode() string {
    54  	if c.DecodeType == PrintHex {
    55  		return c.Output.hexEncode()
    56  	} else {
    57  		return c.Output.base64Encode()
    58  	}
    59  }
    60  
    61  func (c *Cipher) Decode(cipherText string) ([]byte, error) {
    62  	if c.DecodeType == PrintBase64 {
    63  		return base64Decode(cipherText)
    64  	} else if c.DecodeType == PrintHex {
    65  		return hexDecode(cipherText)
    66  	} else {
    67  		return nil, errors.New("unsupported print type")
    68  	}
    69  }
    70  
    71  func (c *Cipher) Fill(plainText []byte, blockSize int) []byte {
    72  	if c.FillMode == PkcsZero {
    73  		return c.FillMode.zeroPadding(plainText, blockSize)
    74  	} else {
    75  		return c.FillMode.pkcs7Padding(plainText, blockSize)
    76  	}
    77  }
    78  
    79  func (c *Cipher) UnFill(plainText []byte) (data []byte, err error) {
    80  	defer func() {
    81  		if r := recover(); r != nil {
    82  			var ok bool
    83  			err, ok = r.(error)
    84  			if !ok {
    85  				err = fmt.Errorf("%v", r)
    86  			}
    87  		}
    88  	}()
    89  	if c.FillMode == Pkcs7 {
    90  		return c.FillMode.pkcsUnPadding(plainText), nil
    91  	} else if c.FillMode == PkcsZero {
    92  		return c.FillMode.unZeroPadding(plainText), nil
    93  	} else {
    94  		return nil, errors.New("unsupported fill mode")
    95  	}
    96  }
    97  
    98  func (c *Cipher) NewECBEncrypter(block cipher.Block, plainData []byte) {
    99  	tempText := c.Output
   100  	for len(plainData) > 0 {
   101  		block.Encrypt(tempText, plainData[:block.BlockSize()])
   102  		plainData = plainData[block.BlockSize():]
   103  		tempText = tempText[block.BlockSize():]
   104  	}
   105  }
   106  
   107  func (c *Cipher) NewECBDecrypter(block cipher.Block, cipherData []byte) {
   108  	tempText := c.Output
   109  	for len(cipherData) > 0 {
   110  		block.Decrypt(tempText, cipherData[:block.BlockSize()])
   111  		cipherData = cipherData[block.BlockSize():]
   112  		tempText = tempText[block.BlockSize():]
   113  	}
   114  }
   115  
   116  const (
   117  	CBCMode = iota
   118  	CFBMode
   119  	CTRMode
   120  	ECBMode
   121  	OFBMode
   122  )
   123  
   124  type FillMode int
   125  
   126  const (
   127  	PkcsZero FillMode = iota
   128  	Pkcs7
   129  )
   130  
   131  func (fm FillMode) pkcs7Padding(plainText []byte, blockSize int) []byte {
   132  	paddingSize := blockSize - len(plainText)%blockSize
   133  	paddingText := bytes.Repeat([]byte{byte(paddingSize)}, paddingSize)
   134  	return append(plainText, paddingText...)
   135  }
   136  
   137  func (fm FillMode) pkcsUnPadding(plainText []byte) []byte {
   138  	length := len(plainText)
   139  	number := int(plainText[length-1])
   140  	return plainText[:length-number]
   141  }
   142  
   143  func (fm FillMode) zeroPadding(plainText []byte, blockSize int) []byte {
   144  	if plainText[len(plainText)-1] == 0 {
   145  		return nil
   146  	}
   147  	paddingSize := blockSize - len(plainText)%blockSize
   148  	paddingText := bytes.Repeat([]byte{byte(0)}, paddingSize)
   149  	return append(plainText, paddingText...)
   150  }
   151  
   152  func (fm FillMode) unZeroPadding(plainText []byte) []byte {
   153  	length := len(plainText)
   154  	count := 1
   155  	for i := length - 1; i > 0; i-- {
   156  		if plainText[i] == 0 && plainText[i-1] == plainText[i] {
   157  			count++
   158  		}
   159  	}
   160  	return plainText[:length-count]
   161  }
   162  
   163  type CipherText []byte
   164  
   165  const (
   166  	PrintHex = iota
   167  	PrintBase64
   168  )
   169  
   170  func (ct CipherText) hexEncode() string {
   171  	return hex.EncodeToString(ct)
   172  }
   173  
   174  func (ct CipherText) base64Encode() string {
   175  	return base64.StdEncoding.EncodeToString(ct)
   176  }
   177  
   178  func hexDecode(cipherText string) ([]byte, error) {
   179  	return hex.DecodeString(cipherText)
   180  }
   181  
   182  func base64Decode(cipherText string) ([]byte, error) {
   183  	return base64.StdEncoding.DecodeString(cipherText)
   184  }