// Copyright 2020 WHTCORPS INC, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package encrypt

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"io"
)

var errInvalidBlockSize = errors.New("invalid encrypt causet size")

// defaultEncryptBlockSize is the default encryption block size in bytes.
const defaultEncryptBlockSize = 1024

// CtrCipher encrypts data using AES in counter (CTR) mode.
type CtrCipher struct {
	nonce  uint64
	causet cipher.Block
	// encryptBlockSize is the size in bytes of one encryption block, which is
	// also the size of the plaintext buffer used by Writer.
	encryptBlockSize int64
	// aesBlockCount is the number of AES blocks (aes.BlockSize bytes each) that
	// make up one encryption block.
	aesBlockCount int64
}

// NewCtrCipher returns a CtrCipher that uses the default encryption block size
// (defaultEncryptBlockSize) and a freshly generated random key and nonce.
func NewCtrCipher() (ctr *CtrCipher, err error) {
	return NewCtrCipherWithBlockSize(defaultEncryptBlockSize)
}

// NewCtrCipherWithBlockSize returns a CtrCipher whose encryption block size is
// encryptBlockSize bytes. It generates a 16-byte (AES-128) key and a 64-bit
// nonce from crypto/rand. The key is never exposed, so only this CtrCipher
// instance can decrypt data that was encrypted with it.
//
// encryptBlockSize must be a positive multiple of aes.BlockSize; otherwise
// errInvalidBlockSize is returned. An error is also returned if the system's
// secure random source fails.
func NewCtrCipherWithBlockSize(encryptBlockSize int64) (ctr *CtrCipher, err error) {
	// Validate before doing any work. A zero size would make Writer.Write loop
	// forever and Reader.ReadAt divide by zero; a negative size panics in make.
	if encryptBlockSize <= 0 || encryptBlockSize%aes.BlockSize != 0 {
		return nil, errInvalidBlockSize
	}
	// Key material must come from a CSPRNG. math/rand is predictable (and,
	// before Go 1.20, deterministically seeded), which would make the key
	// guessable. The first aes.BlockSize bytes are the key; the last 8 bytes
	// are the nonce.
	var seed [aes.BlockSize + 8]byte
	if _, err = rand.Read(seed[:]); err != nil {
		return nil, err
	}
	block, err := aes.NewCipher(seed[:aes.BlockSize])
	if err != nil {
		return nil, err
	}
	ctr = &CtrCipher{
		nonce:            binary.BigEndian.Uint64(seed[aes.BlockSize:]),
		causet:           block,
		encryptBlockSize: encryptBlockSize,
		aesBlockCount:    encryptBlockSize / aes.BlockSize,
	}
	return ctr, nil
}

// stream returns a CTR-mode cipher.Stream for encrypting or decrypting. Its
// initial counter block holds the nonce in the high 8 bytes and counter in the
// low 8 bytes, both big-endian. Starting at counter c positions the key stream
// c AES blocks in, which is how Reader.ReadAt decrypts from an arbitrary
// encryption block.
func (ctr *CtrCipher) stream(counter uint64) cipher.Stream {
	iv := make([]byte, aes.BlockSize)
	binary.BigEndian.PutUint64(iv, ctr.nonce)
	binary.BigEndian.PutUint64(iv[8:], counter)
	return cipher.NewCTR(ctr.causet, iv)
}

// Writer implements io.WriteCloser. It buffers plaintext, encrypts it with
// AES-CTR, and writes the ciphertext to the underlying io.WriteCloser.
type Writer struct {
	// err is the first error from the underlying writer. Once set, it is
	// returned by every later Write and Flush.
	err error
	w   io.WriteCloser
	// n is the number of plaintext bytes currently buffered at the start of buf.
	n            int
	buf          []byte
	cipherStream cipher.Stream
}

// NewWriter returns a Writer that encrypts data with ctrCipher, starting at
// counter 0, before writing it to w. Its buffer holds one encryption block.
func NewWriter(w io.WriteCloser, ctrCipher *CtrCipher) *Writer {
	return &Writer{
		w:            w,
		buf:          make([]byte, ctrCipher.encryptBlockSize),
		cipherStream: ctrCipher.stream(0),
	}
}

// AvailableSize returns how many bytes are unused in the buffer.
func (w *Writer) AvailableSize() int { return len(w.buf) - w.n }

// Write implements the io.Writer interface. Bytes are buffered. A full buffer
// is encrypted and written to the underlying writer only when more data
// arrives, or when Flush or Close is called.
93 func (w *Writer) Write(p []byte) (n int, err error) { 94 if w.err != nil { 95 return n, w.err 96 } 97 for len(p) > w.AvailableSize() && w.err == nil { 98 copiedNum := copy(w.buf[w.n:], p) 99 w.n += copiedNum 100 err = w.Flush() 101 if err != nil { 102 return 103 } 104 n += copiedNum 105 p = p[copiedNum:] 106 } 107 copiedNum := copy(w.buf[w.n:], p) 108 w.n += copiedNum 109 n += copiedNum 110 return 111 } 112 113 // Buffered returns the number of bytes that have been written into the current buffer. 114 func (w *Writer) Buffered() int { return w.n } 115 116 // Flush writes all the buffered data to the underlying object. 117 func (w *Writer) Flush() error { 118 if w.err != nil { 119 return w.err 120 } 121 if w.n == 0 { 122 return nil 123 } 124 w.cipherStream.XORKeyStream(w.buf[:w.n], w.buf[:w.n]) 125 n, err := w.w.Write(w.buf[:w.n]) 126 if n < w.n && err == nil { 127 err = io.ErrShortWrite 128 } 129 if err != nil { 130 w.err = err 131 return err 132 } 133 w.n = 0 134 return nil 135 } 136 137 // Close implements the io.Closer interface. 138 func (w *Writer) Close() (err error) { 139 err = w.Flush() 140 if err != nil { 141 return 142 } 143 return w.w.Close() 144 } 145 146 // Reader implements an io.ReadAt, reading from the input source after decrypting. 147 type Reader struct { 148 r io.ReaderAt 149 cipher *CtrCipher 150 } 151 152 // NewReader returns a new Reader which can read from the input source after decrypting. 153 func NewReader(r io.ReaderAt, ctrCipher *CtrCipher) *Reader { 154 reader := &Reader{r: r, cipher: ctrCipher} 155 return reader 156 } 157 158 // ReadAt implements the io.ReadAt interface. 
159 func (r *Reader) ReadAt(p []byte, off int64) (n int, err error) { 160 if len(p) == 0 { 161 return 0, nil 162 } 163 offset := off % r.cipher.encryptBlockSize 164 counter := (off / r.cipher.encryptBlockSize) * r.cipher.aesBlockCount 165 cursor := off - offset 166 167 buf := make([]byte, r.cipher.encryptBlockSize) 168 var readNum int 169 cipherStream := r.cipher.stream(uint64(counter)) 170 for len(p) > 0 && err == nil { 171 readNum, err = r.r.ReadAt(buf, cursor) 172 if err != nil { 173 if readNum == 0 || err != io.EOF { 174 return n, err 175 } 176 err = nil 177 // continue if n > 0 and r.err is io.EOF 178 } 179 cursor += int64(readNum) 180 cipherStream.XORKeyStream(buf[:readNum], buf[:readNum]) 181 copiedNum := copy(p, buf[offset:readNum]) 182 n += copiedNum 183 p = p[copiedNum:] 184 offset = 0 185 } 186 return n, err 187 }