github.com/marinho/drone@v0.2.1-0.20140504195434-d3ba962e89a7/pkg/database/encrypt/encrypt.go (about)

     1  package encrypt
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"encoding/gob"
     8  	"fmt"
     9  	"io"
    10  )
    11  
    12  // EncryptedField handles encrypted and decryption of
    13  // values to and from database columns.
    14  type EncryptedField struct {
    15  	Cipher cipher.Block
    16  }
    17  
    18  // PreRead is called before a Scan operation. It is given a pointer to
    19  // the raw struct field, and returns the value that will be given to
    20  // the database driver.
    21  func (e *EncryptedField) PreRead(fieldAddr interface{}) (scanTarget interface{}, err error) {
    22  	// give a pointer to a byte buffer to grab the raw data
    23  	return new([]byte), nil
    24  }
    25  
    26  // PostRead is called after a Scan operation. It is given the value returned
    27  // by PreRead and a pointer to the raw struct field. It is expected to fill
    28  // in the struct field if the two are different.
    29  func (e *EncryptedField) PostRead(fieldAddr interface{}, scanTarget interface{}) error {
    30  	ptr := scanTarget.(*[]byte)
    31  	if ptr == nil {
    32  		return fmt.Errorf("encrypter.PostRead: nil pointer")
    33  	}
    34  	raw := *ptr
    35  
    36  	// ignore fields that aren't set at all
    37  	if len(raw) == 0 {
    38  		return nil
    39  	}
    40  
    41  	// decrypt value for gob decoding
    42  	var err error
    43  	raw, err = decrypt(e.Cipher, raw)
    44  	if err != nil {
    45  		return fmt.Errorf("Gob decryption error: %v", err)
    46  	}
    47  
    48  	// decode gob
    49  	gobDecoder := gob.NewDecoder(bytes.NewReader(raw))
    50  	if err := gobDecoder.Decode(fieldAddr); err != nil {
    51  		return fmt.Errorf("Gob decode error: %v", err)
    52  	}
    53  
    54  	return nil
    55  }
    56  
    57  // PreWrite is called before an Insert or Update operation. It is given
    58  // a pointer to the raw struct field, and returns the value that will be
    59  // given to the database driver.
    60  func (e *EncryptedField) PreWrite(field interface{}) (saveValue interface{}, err error) {
    61  	buffer := new(bytes.Buffer)
    62  
    63  	// gob encode
    64  	gobEncoder := gob.NewEncoder(buffer)
    65  	if err := gobEncoder.Encode(field); err != nil {
    66  		return nil, fmt.Errorf("Gob encoding error: %v", err)
    67  	}
    68  	// and then ecrypt
    69  	encrypted, err := encrypt(e.Cipher, buffer.Bytes())
    70  	if err != nil {
    71  		return nil, fmt.Errorf("Gob decryption error: %v", err)
    72  	}
    73  
    74  	return encrypted, nil
    75  }
    76  
    77  // encrypt is a helper function to encrypt a slice
    78  // of bytes using the specified block cipher.
    79  func encrypt(block cipher.Block, v []byte) ([]byte, error) {
    80  	// if no block cipher value exists we'll assume
    81  	// the database is running in non-ecrypted mode.
    82  	if block == nil {
    83  		return v, nil
    84  	}
    85  
    86  	value := make([]byte, len(v))
    87  	copy(value, v)
    88  
    89  	// Generate a random initialization vector
    90  	iv := generateRandomKey(block.BlockSize())
    91  	if len(iv) != block.BlockSize() {
    92  		return nil, fmt.Errorf("Could not generate a valid initialization vector for encryption")
    93  	}
    94  
    95  	// Encrypt it.
    96  	stream := cipher.NewCTR(block, iv)
    97  	stream.XORKeyStream(value, value)
    98  
    99  	// Return iv + ciphertext.
   100  	return append(iv, value...), nil
   101  }
   102  
   103  // decrypt is a helper function to decrypt a slice
   104  // using the specified block cipher.
   105  func decrypt(block cipher.Block, value []byte) ([]byte, error) {
   106  	// if no block cipher value exists we'll assume
   107  	// the database is running in non-ecrypted mode.
   108  	if block == nil {
   109  		return value, nil
   110  	}
   111  
   112  	size := block.BlockSize()
   113  	if len(value) > size {
   114  		// Extract iv.
   115  		iv := value[:size]
   116  		// Extract ciphertext.
   117  		value = value[size:]
   118  		// Decrypt it.
   119  		stream := cipher.NewCTR(block, iv)
   120  		stream.XORKeyStream(value, value)
   121  		return value, nil
   122  	}
   123  	return nil, fmt.Errorf("Could not decrypt the value")
   124  }
   125  
   126  // GenerateRandomKey creates a random key of size length bytes
   127  func generateRandomKey(strength int) []byte {
   128  	k := make([]byte, strength)
   129  	if _, err := io.ReadFull(rand.Reader, k); err != nil {
   130  		return nil
   131  	}
   132  	return k
   133  }