github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/soliton/encrypt/ase_layer.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package encrypt
    15  
import (
	"crypto/aes"
	"crypto/cipher"
	crand "crypto/rand"
	"encoding/binary"
	"errors"
	"io"
	"math/rand"
)
    24  
    25  var errInvalidBlockSize = errors.New("invalid encrypt causet size")
    26  
    27  // defaultEncryptBlockSize indicates the default encrypt causet size in bytes
    28  const defaultEncryptBlockSize = 1024
    29  
    30  // CtrCipher encrypting data using AES in counter mode
    31  type CtrCipher struct {
    32  	nonce uint64
    33  	causet cipher.Block
    34  	// encryptBlockSize indicates the encrypt causet size in bytes.
    35  	encryptBlockSize int64
    36  	// aesBlockCount indicates the total aes blocks in one encrypt causet
    37  	aesBlockCount int64
    38  }
    39  
    40  // NewCtrCipher return a CtrCipher using the default encrypt causet size
    41  func NewCtrCipher() (ctr *CtrCipher, err error) {
    42  	return NewCtrCipherWithBlockSize(defaultEncryptBlockSize)
    43  }
    44  
    45  // NewCtrCipherWithBlockSize return a CtrCipher with the encrypt causet size
    46  func NewCtrCipherWithBlockSize(encryptBlockSize int64) (ctr *CtrCipher, err error) {
    47  	key := make([]byte, aes.BlockSize)
    48  	rand.Read(key)
    49  	causet, err := aes.NewCipher(key)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	if encryptBlockSize%aes.BlockSize != 0 {
    54  		return nil, errInvalidBlockSize
    55  	}
    56  	ctr = new(CtrCipher)
    57  	ctr.causet = causet
    58  	ctr.nonce = rand.Uint64()
    59  	ctr.encryptBlockSize = encryptBlockSize
    60  	ctr.aesBlockCount = encryptBlockSize / aes.BlockSize
    61  	return
    62  }
    63  
    64  // stream returns a cipher.Stream be use to encrypts/decrypts
    65  func (ctr *CtrCipher) stream(counter uint64) cipher.Stream {
    66  	counterBuf := make([]byte, aes.BlockSize)
    67  	binary.BigEndian.PutUint64(counterBuf, ctr.nonce)
    68  	binary.BigEndian.PutUint64(counterBuf[8:], counter)
    69  	return cipher.NewCTR(ctr.causet, counterBuf)
    70  }
    71  
    72  // Writer implements an io.WriteCloser, it encrypt data using AES before writing to the underlying object.
    73  type Writer struct {
    74  	err          error
    75  	w            io.WriteCloser
    76  	n            int
    77  	buf          []byte
    78  	cipherStream cipher.Stream
    79  }
    80  
    81  // NewWriter returns a new Writer which encrypt data using AES before writing to the underlying object.
    82  func NewWriter(w io.WriteCloser, ctrCipher *CtrCipher) *Writer {
    83  	writer := &Writer{w: w}
    84  	writer.buf = make([]byte, ctrCipher.encryptBlockSize)
    85  	writer.cipherStream = ctrCipher.stream(0)
    86  	return writer
    87  }
    88  
    89  // AvailableSize returns how many bytes are unused in the buffer.
    90  func (w *Writer) AvailableSize() int { return len(w.buf) - w.n }
    91  
    92  // Write implements the io.Writer interface.
    93  func (w *Writer) Write(p []byte) (n int, err error) {
    94  	if w.err != nil {
    95  		return n, w.err
    96  	}
    97  	for len(p) > w.AvailableSize() && w.err == nil {
    98  		copiedNum := copy(w.buf[w.n:], p)
    99  		w.n += copiedNum
   100  		err = w.Flush()
   101  		if err != nil {
   102  			return
   103  		}
   104  		n += copiedNum
   105  		p = p[copiedNum:]
   106  	}
   107  	copiedNum := copy(w.buf[w.n:], p)
   108  	w.n += copiedNum
   109  	n += copiedNum
   110  	return
   111  }
   112  
   113  // Buffered returns the number of bytes that have been written into the current buffer.
   114  func (w *Writer) Buffered() int { return w.n }
   115  
   116  // Flush writes all the buffered data to the underlying object.
   117  func (w *Writer) Flush() error {
   118  	if w.err != nil {
   119  		return w.err
   120  	}
   121  	if w.n == 0 {
   122  		return nil
   123  	}
   124  	w.cipherStream.XORKeyStream(w.buf[:w.n], w.buf[:w.n])
   125  	n, err := w.w.Write(w.buf[:w.n])
   126  	if n < w.n && err == nil {
   127  		err = io.ErrShortWrite
   128  	}
   129  	if err != nil {
   130  		w.err = err
   131  		return err
   132  	}
   133  	w.n = 0
   134  	return nil
   135  }
   136  
   137  // Close implements the io.Closer interface.
   138  func (w *Writer) Close() (err error) {
   139  	err = w.Flush()
   140  	if err != nil {
   141  		return
   142  	}
   143  	return w.w.Close()
   144  }
   145  
   146  // Reader implements an io.ReadAt, reading from the input source after decrypting.
   147  type Reader struct {
   148  	r      io.ReaderAt
   149  	cipher *CtrCipher
   150  }
   151  
   152  // NewReader returns a new Reader which can read from the input source after decrypting.
   153  func NewReader(r io.ReaderAt, ctrCipher *CtrCipher) *Reader {
   154  	reader := &Reader{r: r, cipher: ctrCipher}
   155  	return reader
   156  }
   157  
   158  // ReadAt implements the io.ReadAt interface.
   159  func (r *Reader) ReadAt(p []byte, off int64) (n int, err error) {
   160  	if len(p) == 0 {
   161  		return 0, nil
   162  	}
   163  	offset := off % r.cipher.encryptBlockSize
   164  	counter := (off / r.cipher.encryptBlockSize) * r.cipher.aesBlockCount
   165  	cursor := off - offset
   166  
   167  	buf := make([]byte, r.cipher.encryptBlockSize)
   168  	var readNum int
   169  	cipherStream := r.cipher.stream(uint64(counter))
   170  	for len(p) > 0 && err == nil {
   171  		readNum, err = r.r.ReadAt(buf, cursor)
   172  		if err != nil {
   173  			if readNum == 0 || err != io.EOF {
   174  				return n, err
   175  			}
   176  			err = nil
   177  			// continue if n > 0 and r.err is io.EOF
   178  		}
   179  		cursor += int64(readNum)
   180  		cipherStream.XORKeyStream(buf[:readNum], buf[:readNum])
   181  		copiedNum := copy(p, buf[offset:readNum])
   182  		n += copiedNum
   183  		p = p[copiedNum:]
   184  		offset = 0
   185  	}
   186  	return n, err
   187  }