github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/crypto/encryption/encryption.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package encryption
     6  
     7  import (
     8  	"crypto/cipher"
     9  	"crypto/hmac"
    10  	"crypto/rand"
    11  	"encoding/hex"
    12  	"fmt"
    13  	"hash"
    14  	"io"
    15  )
    16  
    17  // KeyID represents the ID used to identify a particular key.
    18  type KeyID []byte
    19  
    20  // MarshalJSON marshals a KeyID as a hex encoded string.
    21  func (id KeyID) MarshalJSON() ([]byte, error) {
    22  	if len(id) == 0 {
    23  		return []byte(`""`), nil
    24  	}
    25  	dst := make([]byte, hex.EncodedLen(len(id))+2)
    26  	hex.Encode(dst[1:], id)
    27  	// need to supply leading/trailing double quotes.
    28  	dst[0], dst[len(dst)-1] = '"', '"'
    29  	return dst, nil
    30  }
    31  
    32  // UnmarshalJSON unmarshals a hex encoded string into a KeyID.
    33  func (id *KeyID) UnmarshalJSON(data []byte) error {
    34  	// need to strip leading and trailing double quotes
    35  	if data[0] != '"' || data[len(data)-1] != '"' {
    36  		return fmt.Errorf("KeyID is not quoted")
    37  	}
    38  	data = data[1 : len(data)-1]
    39  	*id = make([]byte, hex.DecodedLen(len(data)))
    40  	_, err := hex.Decode(*id, data)
    41  	return err
    42  }
    43  
// KeyDescriptor represents a given key and any associated options.
// It is the configuration passed to NewEncrypter and NewDecrypter.
type KeyDescriptor struct {
	// Registry is the name of the key registry; the engine resolves it with Lookup.
	Registry string                 `json:"registry"`
	// ID identifies the key within the registry; it is serialized as hex in JSON.
	ID       KeyID                  `json:"keyid"`
	// Options holds additional key options and is omitted from JSON when empty.
	// NOTE(review): not consulted by the code in this file; presumably
	// interpreted by registries or callers — confirm.
	Options  map[string]interface{} `json:"options,omitempty"`
}
    50  
// Encrypter defines encryption methods.
type Encrypter interface {
	// CiphertextSize returns the size of the ciphertext that
	// will result from the supplied plaintext. It should be used
	// to size the slice supplied to Encrypt.
	CiphertextSize(plaintext []byte) int

	// CiphertextSizeSlices returns the size of the ciphertext that
	// will result from the supplied plaintext slices. It should be used
	// to size the slice supplied to EncryptSlices.
	CiphertextSizeSlices(plaintexts ...[]byte) int

	// Encrypt encrypts the plaintext in src into ciphertext in
	// dst. dst must be at least CiphertextSize(src) bytes large.
	Encrypt(src, dst []byte) error

	// EncryptSlices encrypts the plaintext slices as a single
	// block. It is intended to avoid the need for an external copy
	// to obtain a single buffer for use with Encrypt. The slices
	// will be decrypted as a single block. dst must be at least
	// CiphertextSizeSlices(src...) bytes large. Note that, unlike
	// Encrypt, the destination is the first argument.
	EncryptSlices(dst []byte, src ...[]byte) error

	// TODO: AEAD method.
	//	Seal(header []byte, p []byte) (int, err)
}
    76  
// Decrypter defines decryption methods.
type Decrypter interface {
	// PlaintextSize returns the size of the buffer needed to hold the
	// decrypted output of the supplied ciphertext. Note that this size
	// includes any checksum encrypted along with the original plaintext.
	PlaintextSize(ciphertext []byte) int
	// Decrypt decrypts the ciphertext in src into plaintext stored in dst and
	// returns slices that contain the checksum of the original plaintext
	// and the plaintext. dst should be at least PlaintextSize(src) bytes big.
	Decrypt(src, dst []byte) (sum, plaintext []byte, err error)
}
    89  
// engine implements both Encrypter and Decrypter for the key identified by
// a KeyDescriptor, using the KeyRegistry that descriptor names.
type engine struct {
	// reg reports the block and HMAC sizes and builds the per-key HMAC
	// hash and cipher.Block (via NewBlock).
	reg KeyRegistry
	// kd is the descriptor whose ID selects the key for every operation.
	kd  KeyDescriptor
}
    94  
    95  var randomSource = rand.Reader
    96  
    97  func (e *engine) initIV(b []byte) (iv, buf []byte, err error) {
    98  	bs := e.reg.BlockSize()
    99  	iv, buf = b[:bs], b[bs:]
   100  	n, err := io.ReadFull(randomSource, iv)
   101  	if err != nil {
   102  		err = fmt.Errorf("failed to read %d bytes of random data: %v", len(iv), err)
   103  		return
   104  	}
   105  	if n != len(iv) {
   106  		err = fmt.Errorf("failed to generate complete iv: %d < %d", n, len(b))
   107  	}
   108  	return iv, buf, err
   109  }
   110  
   111  func (e *engine) readIV(b []byte) (iv, buf []byte, err error) {
   112  	bs := e.reg.BlockSize()
   113  	if len(b) < bs {
   114  		return nil, nil, fmt.Errorf("failed to read IV")
   115  	}
   116  	return b[:bs], b[bs:], nil
   117  }
   118  
   119  // CiphertextSize implements Encrypter.
   120  func (e *engine) CiphertextSize(plaintext []byte) int {
   121  	return e.reg.BlockSize() + e.reg.HMACSize() + len(plaintext)
   122  }
   123  
   124  // CiphertextSizeSlices implements Encrypter.
   125  func (e *engine) CiphertextSizeSlices(plaintext ...[]byte) int {
   126  	total := e.reg.BlockSize() + e.reg.HMACSize()
   127  	for _, p := range plaintext {
   128  		total += len(p)
   129  	}
   130  	return total
   131  }
   132  
   133  // PlaintextSize impementes Decrypter.
   134  func (e *engine) PlaintextSize(ciphertext []byte) int {
   135  	return e.reg.HMACSize() + len(ciphertext)
   136  }
   137  
   138  func (e *engine) setup(dst []byte) ([]byte, hash.Hash, cipher.Stream, error) {
   139  	iv, buf, err := e.initIV(dst)
   140  	if err != nil {
   141  		return nil, nil, nil, err
   142  	}
   143  	// Obtain an hmac hash and cipher.Block from the registry using
   144  	// the specified key ID.
   145  	hmacSum, block, err := e.reg.NewBlock(e.kd.ID)
   146  	if err != nil {
   147  		return nil, nil, nil, err
   148  	}
   149  	stream := cipher.NewCFBEncrypter(block, iv)
   150  	// Generate and encrypt the sum.
   151  	hmacSum.Reset()
   152  	return buf, hmacSum, stream, nil
   153  }
   154  
   155  // Encrypt implements Encrypter.
   156  // Encrypt can be used concurrently.
   157  func (e *engine) Encrypt(src, dst []byte) error {
   158  	if len(dst) < e.CiphertextSize(src) {
   159  		return fmt.Errorf("dst is too small, size it using CiphertextSize()")
   160  	}
   161  	buf, hmacSum, stream, err := e.setup(dst)
   162  	if err != nil {
   163  		return err
   164  	}
   165  	hmacSum.Write(src)
   166  	stream.XORKeyStream(buf, hmacSum.Sum(nil))
   167  	// Encrypt plaintext
   168  	stream.XORKeyStream(buf[e.reg.HMACSize():], src)
   169  	return nil
   170  }
   171  
   172  // EncryptSlices implements Encrypter.
   173  // EncryptSlices can be used concurrently.
   174  func (e *engine) EncryptSlices(dst []byte, src ...[]byte) error {
   175  	if len(dst) < e.CiphertextSizeSlices(src...) {
   176  		return fmt.Errorf("dst is too small, size it using CiphertextSizeSlices()")
   177  	}
   178  	buf, hmacSum, stream, err := e.setup(dst)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	for _, s := range src {
   183  		hmacSum.Write(s)
   184  	}
   185  	stream.XORKeyStream(buf, hmacSum.Sum(nil))
   186  	buf = buf[e.reg.HMACSize():]
   187  	for _, s := range src {
   188  		stream.XORKeyStream(buf, s)
   189  		buf = buf[len(s):]
   190  	}
   191  	return nil
   192  }
   193  
   194  func newEngine(kd KeyDescriptor) (*engine, error) {
   195  	reg, err := Lookup(kd.Registry)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	return &engine{
   200  		reg: reg,
   201  		kd:  kd,
   202  	}, nil
   203  }
   204  
   205  // NewEncrypter returns a new encrypter.
   206  // The implementation it returns uses an encrypted HMAC/SHA512 checksum of
   207  // the plaintext to ensure integrity. The format of a block is:
   208  // Initialization Vector (IV)
   209  // encrypted(HMAC(plaintext) + plaintext)
   210  func NewEncrypter(kd KeyDescriptor) (Encrypter, error) {
   211  	return newEngine(kd)
   212  }
   213  
   214  // NewDecrypter returns a new decrypter.
   215  func NewDecrypter(kd KeyDescriptor) (Decrypter, error) {
   216  	return newEngine(kd)
   217  }
   218  
   219  // Decrypt implements Decrypter.
   220  // Decrypt can be used concurrently.
   221  func (e *engine) Decrypt(src, dst []byte) (sum, plaintext []byte, err error) {
   222  	if len(dst) < e.PlaintextSize(src) {
   223  		return nil, nil, fmt.Errorf("dst is too small, size it using PlaintextSize()")
   224  	}
   225  	// Obtain a cipher.Block from the registry using the specified key ID.
   226  	hmacSum, block, err := e.reg.NewBlock(e.kd.ID)
   227  	if err != nil {
   228  		return nil, nil, err
   229  	}
   230  
   231  	iv, buf, err := e.readIV(src)
   232  	if err != nil {
   233  		return nil, nil, err
   234  	}
   235  	stream := cipher.NewCFBDecrypter(block, iv)
   236  
   237  	// Decrypt bytes from ciphertext, size buf to the length of the ciphertext
   238  	// including the checksum. Always make sure there is enough room for the
   239  	// checksum, if the buffer is short then we'll get a checksum error.
   240  	sumSize := e.reg.HMACSize()
   241  	if len(buf) >= sumSize {
   242  		dst = dst[:len(buf)]
   243  	}
   244  	stream.XORKeyStream(dst, buf)
   245  	got := dst[:sumSize]
   246  	hmacSum.Reset()
   247  	hmacSum.Write(dst[sumSize:])
   248  	want := hmacSum.Sum(nil)
   249  	if !hmac.Equal(got[:], want) {
   250  		return nil, nil, fmt.Errorf("mismatched checksums: %v != %v ", got, want)
   251  	}
   252  	sum = dst[:sumSize]
   253  	plaintext = dst[sumSize:]
   254  	return
   255  }