github.com/sandwich-go/boost@v1.3.29/xencoding/encrypt/encrypt.go (about)

     1  package encrypt
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"github.com/sandwich-go/boost/xcrypto/algorithm/aes"
     7  	"github.com/sandwich-go/boost/xencoding"
     8  	"github.com/sandwich-go/boost/xpanic"
     9  )
    10  
    11  var (
    12  	errCodecMarshalParam   = errors.New("encrypt codec marshal must be []byte parameter")
    13  	errCodecUnmarshalParam = errors.New("encrypt codec unmarshal must be *[]byte parameter")
    14  	errCodecNoFound        = errors.New("encrypt codec not found")
    15  )
    16  
    17  const (
    18  	// NoneCodecName 无加解密效果名称,可以通过 encoding2.GetCodec(NoneCodecName) 获取对应的 Codec
    19  	NoneCodecName = "none_encrypt"
    20  	// AESCodecName aes 加解密名称,可以通过 encoding2.GetCodec(AESCodecName) 获取对应的 Codec
    21  	AESCodecName = "aes_encrypt"
    22  )
    23  
    24  // KeySetter key 的设置
    25  // 如果自定义的 Codec 实现了 KeySetter 接口,那么在 NewCodec 的时候,会将 key 设置进去
    26  type KeySetter interface {
    27  	SetKey(key []byte)
    28  }
    29  
    30  var (
    31  	// NoneCodec 无加解密效果
    32  	NoneCodec = noneCodec{}
    33  	// AESCodec aes 加解密
    34  	AESCodec = aesCodec{}
    35  )
    36  
    37  // SetKey 设置加解密 key
    38  func SetKey(key []byte) { AESCodec.key = key }
    39  
    40  var codecs = map[Type]xencoding.Codec{
    41  	NoneType: NoneCodec,
    42  	AESType:  AESCodec,
    43  }
    44  
    45  func init() {
    46  	for _, v := range codecs {
    47  		xencoding.RegisterCodec(v)
    48  	}
    49  }
    50  
    51  // Register 注册自定义的加解密 Codec ,该方法非协程安全
    52  func Register(t Type, codec xencoding.Codec) {
    53  	_, exists := codecs[t]
    54  	xpanic.WhenTrue(exists, "register called twice for codec, %d", t)
    55  	codecs[t] = codec
    56  	xencoding.RegisterCodec(codec)
    57  }
    58  
    59  type noneCodec struct{}
    60  
    61  func (c noneCodec) Name() string { return NoneCodecName }
    62  func (c noneCodec) Marshal(_ context.Context, v interface{}) ([]byte, error) {
    63  	if data, ok := v.([]byte); !ok {
    64  		return nil, errCodecMarshalParam
    65  	} else {
    66  		return data, nil
    67  	}
    68  }
    69  func (c noneCodec) Unmarshal(_ context.Context, bytes []byte, v interface{}) error {
    70  	v1, ok := v.(*[]byte)
    71  	if !ok {
    72  		return errCodecUnmarshalParam
    73  	}
    74  	*v1 = bytes
    75  	return nil
    76  }
    77  
    78  type aesCodec struct {
    79  	key []byte
    80  }
    81  
    82  func (c aesCodec) Name() string { return AESCodecName }
    83  func (c aesCodec) Marshal(_ context.Context, v interface{}) ([]byte, error) {
    84  	if data, ok := v.([]byte); !ok {
    85  		return nil, errCodecMarshalParam
    86  	} else {
    87  		return aes.Encrypt(data, c.key)
    88  	}
    89  }
    90  
    91  func (c aesCodec) Unmarshal(_ context.Context, bytes []byte, v interface{}) error {
    92  	v1, ok := v.(*[]byte)
    93  	if !ok {
    94  		return errCodecUnmarshalParam
    95  	}
    96  	data, err := aes.Decrypt(bytes, c.key)
    97  	if err != nil {
    98  		return err
    99  	}
   100  	*v1 = data
   101  	return nil
   102  }
   103  
// codec is the Codec returned by NewCodec. It embeds the resolved encryption
// Codec, defines its own Name (always "encrypt") and Marshal/Unmarshal that
// delegate to the embedded Codec, and records the Type it was built for.
type codec struct {
	// Codec is the underlying encryption codec that Marshal/Unmarshal delegate to.
	xencoding.Codec
	// encryptType is the Type passed to NewCodec.
	encryptType Type
}
   108  
   109  // NewCodec 通过类型创建解压缩 Codec
   110  func NewCodec(encryptType Type, key []byte) xencoding.Codec {
   111  	c := codec{encryptType: encryptType}
   112  	switch encryptType {
   113  	case AESType:
   114  		c.Codec = aesCodec{key: key}
   115  	case NoneType:
   116  		c.Codec = noneCodec{}
   117  	default:
   118  		c.Codec = codecs[encryptType]
   119  	}
   120  	if c.Codec == nil {
   121  		xpanic.WhenError(errCodecNoFound)
   122  	}
   123  	if cc, ok := c.Codec.(KeySetter); ok {
   124  		cc.SetKey(key)
   125  	}
   126  	return c
   127  }
   128  
   129  // Name 返回 Codec 名
   130  func (c codec) Name() string { return "encrypt" }
   131  
   132  // Marshal 编码
   133  func (c codec) Marshal(ctx context.Context, v interface{}) ([]byte, error) {
   134  	if c.Codec == nil {
   135  		return nil, errCodecNoFound
   136  	}
   137  	return c.Codec.Marshal(ctx, v)
   138  }
   139  
   140  // Unmarshal 解码
   141  func (c codec) Unmarshal(ctx context.Context, bytes []byte, v interface{}) error {
   142  	if c.Codec == nil {
   143  		return errCodecNoFound
   144  	}
   145  	return c.Codec.Unmarshal(ctx, bytes, v)
   146  }