github.com/code-to-go/safepool.lib@v0.0.0-20221205180519-ee25e63c226e/security/aescrypt.go (about)

     1  package security
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/md5"
     7  	"crypto/rand"
     8  	"crypto/sha256"
     9  	"encoding/binary"
    10  	"errors"
    11  	"io"
    12  
    13  	"github.com/code-to-go/safepool.lib/core"
    14  
    15  	"github.com/zenazn/pkcs7pad"
    16  )
    17  
    18  func GenerateBytesKey(size int) []byte {
    19  	key := make([]byte, size)
    20  	_, err := rand.Read(key)
    21  	if err != nil {
    22  		panic(err)
    23  	}
    24  	return key
    25  }
    26  
    27  func EncryptBlock(key []byte, nonce []byte, data []byte) ([]byte, error) {
    28  	block, err := newBlock(key)
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  
    33  	data = pkcs7pad.Pad(data, aes.BlockSize)
    34  	cipherdata := make([]byte, len(data))
    35  
    36  	mode := cipher.NewCBCEncrypter(block, nonce)
    37  	mode.CryptBlocks(cipherdata, data)
    38  	return cipherdata, nil
    39  }
    40  
    41  func DecryptBlock(key []byte, nonce []byte, cipherdata []byte) ([]byte, error) {
    42  	block, err := newBlock(key)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	data := make([]byte, len(cipherdata))
    48  	mode := cipher.NewCBCDecrypter(block, nonce)
    49  	mode.CryptBlocks(data, cipherdata)
    50  
    51  	data, err = pkcs7pad.Unpad(data)
    52  	if core.IsErr(err, "invalid padding in AES decrypted data: %v") {
    53  		return nil, err
    54  	}
    55  	return data, nil
    56  }
    57  
    58  type StreamReader struct {
    59  	loc    int
    60  	header []byte
    61  	r      cipher.StreamReader
    62  }
    63  
    64  func (sr *StreamReader) Read(p []byte) (n int, err error) {
    65  	if sr.loc < 8+aes.BlockSize {
    66  		m := copy(p[sr.loc:], sr.header)
    67  		sr.loc += m
    68  		n, err = sr.r.Read(p[m:])
    69  		return m + n, err
    70  	} else {
    71  		return sr.r.Read(p)
    72  	}
    73  }
    74  
    75  // EncryptedWriter wraps w with an OFB cipher stream.
    76  func EncryptingReader(keyId uint64, keyFunc func(uint64) []byte, r io.Reader) (*StreamReader, error) {
    77  
    78  	header := make([]byte, 8+aes.BlockSize)
    79  	binary.LittleEndian.PutUint64(header, keyId)
    80  
    81  	// generate random initial value
    82  	if _, err := io.ReadFull(rand.Reader, header[8:]); err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	value := keyFunc(keyId)
    87  	if value == nil {
    88  		return nil, errors.New("unknown encryption key")
    89  	}
    90  
    91  	block, err := newBlock(value)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	stream := cipher.NewOFB(block, header[8:])
    97  	return &StreamReader{
    98  		header: header,
    99  		r:      cipher.StreamReader{S: stream, R: r},
   100  	}, nil
   101  }
   102  
   103  type StreamWriter struct {
   104  	loc     int
   105  	header  []byte
   106  	keyFunc func(uint64) []byte
   107  	w       *cipher.StreamWriter
   108  }
   109  
   110  func (sr *StreamWriter) Write(p []byte) (n int, err error) {
   111  	if sr.w.S == nil {
   112  		m := copy(sr.header[sr.loc:], p)
   113  		sr.loc += m
   114  
   115  		if sr.loc == 8+aes.BlockSize {
   116  			keyId := binary.LittleEndian.Uint64(sr.header)
   117  			value := sr.keyFunc(keyId)
   118  			if value == nil {
   119  				return 0, errors.New("unknown encryption key")
   120  			}
   121  
   122  			block, err := newBlock(value)
   123  			if err != nil {
   124  				return 0, err
   125  			}
   126  
   127  			iv := sr.header[8:]
   128  			sr.w.S = cipher.NewOFB(block, iv)
   129  		}
   130  		return sr.w.Write(p[m:])
   131  	} else {
   132  		return sr.w.Write(p)
   133  	}
   134  }
   135  
   136  // EncryptedWriter wraps w with an OFB cipher stream.
   137  func DecryptingWriter(keyFunc func(uint64) []byte, w io.Writer) (*StreamWriter, error) {
   138  	return &StreamWriter{
   139  		keyFunc: keyFunc,
   140  		header:  make([]byte, 8+aes.BlockSize),
   141  		w:       &cipher.StreamWriter{S: nil, W: w},
   142  	}, nil
   143  }
   144  
   145  func newBlock(key []byte) (cipher.Block, error) {
   146  	sh := sha256.Sum256(key)
   147  	hash := md5.Sum(sh[:])
   148  	block, err := aes.NewCipher(hash[:])
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	return block, nil
   153  }