github.com/spi-ca/misc@v1.0.1/crypto/stream.go (about)

     1  package crypto
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"errors"
     7  	"io"
     8  )
     9  
    10  // ErrShortCiphertext is an error that has text too short.
    11  var ErrShortCiphertext = errors.New("input too short to be valid ciphertext")
    12  
    13  /*
    14  NewCryptoWriter returns block cipher writer.
    15  from https://github.com/acasajus/dkeyczar/blob/master/streams.go
    16  */
    17  func NewCryptoWriter(bm cipher.BlockMode, sink io.WriteCloser) io.WriteCloser {
    18  	return &cryptoWriter{
    19  		bm:     bm,
    20  		sink:   sink,
    21  		buffer: bytes.NewBuffer(nil),
    22  	}
    23  }
    24  
    25  type cryptoWriter struct {
    26  	bm     cipher.BlockMode
    27  	buffer *bytes.Buffer
    28  	sink   io.WriteCloser
    29  	count  int
    30  }
    31  
    32  func (c *cryptoWriter) Write(data []byte) (int, error) {
    33  	if _, err := c.buffer.Write(data); err != nil {
    34  		return 0, err
    35  	}
    36  	bL := c.buffer.Len() - c.buffer.Len()%c.bm.BlockSize()
    37  	tmp := c.buffer.Next(bL)
    38  	c.bm.CryptBlocks(tmp, tmp)
    39  	wL := 0
    40  	for wL < len(tmp) {
    41  		n, err := c.sink.Write(tmp[wL:])
    42  		if err != nil {
    43  			return 0, err
    44  		}
    45  		wL += n
    46  	}
    47  	c.count += wL
    48  	return len(data), nil
    49  }
    50  
    51  func (c *cryptoWriter) Close() error {
    52  	tmp := Pkcs5pad(c.buffer.Next(c.buffer.Len()), c.bm.BlockSize())
    53  	c.bm.CryptBlocks(tmp, tmp)
    54  	wL := 0
    55  	for wL < len(tmp) {
    56  		n, err := c.sink.Write(tmp[wL:])
    57  		if err != nil {
    58  			return err
    59  		}
    60  		wL += n
    61  	}
    62  	c.count += wL
    63  	return c.sink.Close()
    64  }
    65  
    66  func NewCryptoReader(bm cipher.BlockMode, source io.ReadCloser) io.ReadCloser {
    67  	return &cryptoReader{
    68  		bm:     bm,
    69  		source: source,
    70  		outBuf: bytes.NewBuffer(nil),
    71  		inBuf:  bytes.NewBuffer(nil),
    72  		eof:    false,
    73  	}
    74  }
    75  
    76  type cryptoReader struct {
    77  	bm     cipher.BlockMode
    78  	outBuf *bytes.Buffer
    79  	inBuf  *bytes.Buffer
    80  	source io.ReadCloser
    81  	eof    bool
    82  }
    83  
    84  func (cr *cryptoReader) Read(data []byte) (int, error) {
    85  	missing := len(data) - cr.outBuf.Len()
    86  	for !cr.eof && missing > 0 {
    87  		toRead := missing + cr.bm.BlockSize() + 1 //Always go beyond the required data to be able to unpad when eof'ed
    88  		if off := toRead % cr.bm.BlockSize(); off > 0 {
    89  			toRead += cr.bm.BlockSize() - off //Make sure we read in multiples of blocksize
    90  		}
    91  		cr.inBuf.Grow(toRead)
    92  		n, err := io.CopyN(cr.inBuf, cr.source, int64(toRead))
    93  		if err == io.EOF {
    94  			cr.eof = true
    95  		} else if err != nil {
    96  			return 0, err
    97  		}
    98  		readBytes := int(n)
    99  		if readBytes%cr.bm.BlockSize() > 0 && cr.eof {
   100  			return 0, ErrShortCiphertext
   101  		}
   102  		bytesToDec := readBytes - readBytes%cr.bm.BlockSize()
   103  		tmpdata := cr.inBuf.Next(bytesToDec)
   104  		cr.bm.CryptBlocks(tmpdata, tmpdata)
   105  		if _, err := cr.outBuf.Write(tmpdata); err != nil {
   106  			return 0, err
   107  		}
   108  		if cr.eof {
   109  			pad := cr.outBuf.Bytes()[cr.outBuf.Len()-1]
   110  			cr.outBuf.Truncate(cr.outBuf.Len() - int(pad))
   111  		}
   112  		missing = len(data) - cr.outBuf.Len()
   113  	}
   114  	return cr.outBuf.Read(data)
   115  }
   116  
   117  func (cr *cryptoReader) Close() error {
   118  	cr.eof = true
   119  	return cr.source.Close()
   120  }