github.com/grailbio/base@v0.0.11/crypto/encryption/encryption_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package encryption_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/hmac"
    10  	"crypto/rand"
    11  	"crypto/sha512"
    12  	"encoding/json"
    13  	"fmt"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/grailbio/base/crypto/encryption"
    19  	"github.com/grailbio/testutil/assert"
    20  	"github.com/grailbio/testutil/encryptiontest"
    21  	"github.com/grailbio/testutil/expect"
    22  )
    23  
    // randError is a random source whose Read always fails. TestErrors installs
// it with encryption.SetRandSource to check how a failure to obtain random
// data for the IV is reported.
type randError struct{}

// Read reads nothing and returns the error "rand failures".
func (*randError) Read(p []byte) (n int, err error) {
	err = fmt.Errorf("rand failures")
	return n, err
}
    29  
    // shortRandError is a random source that reports a successful 10-byte read
// without writing anything into p, whatever the size of p. TestErrors installs
// it to check that an incomplete IV is detected.
//
// NOTE(review): when len(p) < 10 the reported count exceeds the buffer, which
// the io.Reader contract forbids; that is harmless for the IV buffer used here
// (presumably 16 bytes) — confirm before reusing it elsewhere.
type shortRandError struct{}

// Read ignores p and returns (10, nil).
func (*shortRandError) Read(p []byte) (n int, err error) {
	n = 10
	return n, err
}
    35  
    36  func TestJSON(t *testing.T) {
    37  	out, _ := json.Marshal(&encryption.KeyDescriptor{})
    38  	if got, want := string(out), `{"registry":"","keyid":""}`; got != want {
    39  		t.Errorf("got %v, want %v", got, want)
    40  	}
    41  	out, _ = json.Marshal(&encryption.KeyDescriptor{Registry: "x", ID: encryptiontest.TestID})
    42  	if got, want := string(out), `{"registry":"x","keyid":"30313233343536373839616263646566"}`; got != want {
    43  		t.Errorf("got %v, want %v", got, want)
    44  	}
    45  
    46  	kd := encryption.KeyDescriptor{}
    47  	json.Unmarshal([]byte(`{"keyid":""}`), &kd)
    48  	ekd := encryption.KeyDescriptor{ID: []byte{}}
    49  	if got, want := kd, ekd; !reflect.DeepEqual(got, want) {
    50  		t.Errorf("got %v, want %v", got, want)
    51  	}
    52  	kd = encryption.KeyDescriptor{}
    53  	json.Unmarshal([]byte(`{"keyid":"ffee"}`), &kd)
    54  	ekd = encryption.KeyDescriptor{ID: []byte{0xff, 0xee}}
    55  	if got, want := kd, ekd; !reflect.DeepEqual(got, want) {
    56  		t.Errorf("got %v, want %v", got, want)
    57  	}
    58  	err := json.Unmarshal([]byte(`{"keyid": {} }`), &kd)
    59  	if err == nil || !strings.Contains(err.Error(), "not quoted") {
    60  		t.Errorf("missing or wrong error: %v", err)
    61  	}
    62  }
    63  
    64  func TestErrors(t *testing.T) {
    65  	reg := encryptiontest.NewFakeAESRegistry()
    66  	if err := encryption.Register("aesTE", reg); err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	// Multiple registrations of the same registry.
    70  	reg = encryptiontest.NewFakeAESRegistry()
    71  	if err := encryption.Register("aesTE", reg); err == nil {
    72  		t.Errorf("expected an error")
    73  	}
    74  	// Missing registry.
    75  	kd := encryption.KeyDescriptor{
    76  		"aesxxx",
    77  		[]byte("any-old-id"),
    78  		nil}
    79  	_, err := encryption.NewEncrypter(kd)
    80  	expect.HasSubstr(t, err, "no such registry")
    81  	_, err = encryption.NewDecrypter(kd)
    82  	expect.HasSubstr(t, err, "no such registry")
    83  
    84  	// Buffer too small.
    85  	kd.Registry = "aesTE"
    86  	enc, _ := encryption.NewEncrypter(kd)
    87  	dec, _ := encryption.NewDecrypter(kd)
    88  
    89  	err = enc.Encrypt([]byte("something"), []byte{0x00})
    90  	expect.HasSubstr(t, err, "too small")
    91  	err = enc.EncryptSlices([]byte{0x00}, []byte("anything"))
    92  	expect.HasSubstr(t, err, "too small")
    93  	_, _, err = dec.Decrypt([]byte("anything"), []byte{0x00})
    94  	expect.HasSubstr(t, err, "too small")
    95  
    96  	// Generate key error.
    97  	reg = encryptiontest.NewFakeAESRegistry()
    98  	reg.Key = encryptiontest.FailGenKey
    99  	_, err = reg.GenerateKey()
   100  	expect.HasSubstr(t, err, "generate-key-failed")
   101  
   102  	// NewBlock failure.
   103  	orig := []byte("some-errors")
   104  	src := make([]byte, enc.CiphertextSize(orig))
   105  	enc, _ = encryption.NewEncrypter(encryption.KeyDescriptor{
   106  		Registry: "aesTE", ID: encryptiontest.BadID,
   107  	})
   108  	err = enc.Encrypt(orig, src[:])
   109  	expect.HasSubstr(t, err, "new-block-failed")
   110  
   111  	err = enc.EncryptSlices(src[:], orig)
   112  	expect.HasSubstr(t, err, "new-block-failed")
   113  
   114  	// Failure to generate IV.
   115  	encryption.SetRandSource(&randError{})
   116  	err = enc.Encrypt(orig, src[:])
   117  	expect.HasSubstr(t, err, "failed to read 16 bytes of random data")
   118  
   119  	encryption.SetRandSource(&shortRandError{})
   120  	err = enc.Encrypt(orig, src[:])
   121  	expect.HasSubstr(t, err, "failed to generate complete iv")
   122  	encryption.SetRandSource(rand.Reader)
   123  
   124  	enc, _ = encryption.NewEncrypter(encryption.KeyDescriptor{
   125  		Registry: "aesTE", ID: encryptiontest.TestID,
   126  	})
   127  	if err = enc.Encrypt(orig, src[:]); err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	dec, _ = encryption.NewDecrypter(encryption.KeyDescriptor{
   131  		Registry: "aesTE", ID: encryptiontest.BadID,
   132  	})
   133  	dst := make([]byte, dec.PlaintextSize(src))
   134  	_, _, err = dec.Decrypt(src[:], dst[:])
   135  	expect.HasSubstr(t, err, "new-block-failed")
   136  
   137  	dec, _ = encryption.NewDecrypter(encryption.KeyDescriptor{
   138  		Registry: "aesTE", ID: encryptiontest.TestID,
   139  	})
   140  
   141  	// short IV
   142  	_, _, err = dec.Decrypt(src[:10], dst[:])
   143  	expect.HasSubstr(t, err, "failed to read IV")
   144  
   145  	// short Buffer
   146  	_, _, err = dec.Decrypt(src[:20], dst[:])
   147  	expect.HasSubstr(t, err, "mismatched checksums")
   148  
   149  	// corrupt the checksum
   150  	src[20] = src[20] + 1
   151  	_, _, err = dec.Decrypt(src[:], dst[:])
   152  	expect.HasSubstr(t, err, "mismatched checksums")
   153  }
   154  
   155  var keyDesc encryption.KeyDescriptor
   156  var aesKey []byte
   157  
   158  func init() {
   159  	aesReg := encryptiontest.NewFakeAESRegistry()
   160  	if err := encryption.Register("aes", aesReg); err != nil {
   161  		panic(err)
   162  	}
   163  	reg, err := encryption.Lookup("aes")
   164  	if err != nil {
   165  		panic(err)
   166  	}
   167  	id, err := reg.GenerateKey()
   168  	if err != nil {
   169  		panic(err)
   170  	}
   171  
   172  	keyDesc = encryption.KeyDescriptor{"aes", id, nil}
   173  	aesKey = aesReg.Key
   174  }
   175  
   176  func TestEncryption(t *testing.T) {
   177  	enc, err := encryption.NewEncrypter(keyDesc)
   178  	assert.NoError(t, err)
   179  	dec, err := encryption.NewDecrypter(keyDesc)
   180  	assert.NoError(t, err)
   181  
   182  	for _, tc := range []string{
   183  		"",
   184  		"me",
   185  		"oh hello world",
   186  		"oh hello world and something a little longer, really we should test with more data",
   187  	} {
   188  		orig := []byte(tc)
   189  		ctext := make([]byte, enc.CiphertextSize(orig))
   190  		err = enc.Encrypt(orig, ctext)
   191  		assert.NoError(t, err)
   192  
   193  		dst := make([]byte, dec.PlaintextSize(ctext))
   194  		sum, ptext, err := dec.Decrypt(ctext, dst)
   195  		assert.NoError(t, err)
   196  
   197  		if got, want := ptext, orig; !bytes.Equal(got, want) {
   198  			t.Fatalf("%v: got %v, want %v", orig, got, want)
   199  		}
   200  		hm := hmac.New(sha512.New, aesKey)
   201  		hm.Write(orig)
   202  		if got, want := hm.Sum(nil), sum; !hmac.Equal(got[:], want) {
   203  			t.Fatalf("%v: got %v, want %v", orig, got, want)
   204  		}
   205  	}
   206  
   207  	data := [][]byte{
   208  		[]byte(""),
   209  		[]byte("me"),
   210  		[]byte("oh hello world"),
   211  		[]byte("oh hello world and something a little longer, really we should test with more data"),
   212  	}
   213  	orig := bytes.Join(data, nil)
   214  	ctext := make([]byte, enc.CiphertextSizeSlices(data...))
   215  	err = enc.EncryptSlices(ctext, data...)
   216  	assert.NoError(t, err)
   217  
   218  	dst := make([]byte, dec.PlaintextSize(ctext))
   219  	sum, ptext, err := dec.Decrypt(ctext, dst)
   220  	assert.NoError(t, err)
   221  
   222  	if got, want := ptext, orig; !bytes.Equal(got, want) {
   223  		t.Fatalf("%v: got %v, want %v", orig, got, want)
   224  	}
   225  
   226  	hm := hmac.New(sha512.New, aesKey)
   227  	hm.Write(orig)
   228  	if got, want := hm.Sum(nil), sum; !hmac.Equal(got[:], want) {
   229  		t.Fatalf("%v: got %v, want %v", orig, got, want)
   230  	}
   231  
   232  }
   233  
   234  func TestRandomness(t *testing.T) {
   235  	encryptiontest.RunAtSignificanceLevel(encryptiontest.OnePercent,
   236  		func(s encryptiontest.Significance) bool {
   237  			ptext := make([]byte, 10000)
   238  			enc, err := encryption.NewEncrypter(keyDesc)
   239  			if err != nil {
   240  				return false
   241  			}
   242  			ctext := make([]byte, enc.CiphertextSize(ptext))
   243  			if err := enc.Encrypt(ptext, ctext); err != nil {
   244  				return false
   245  			}
   246  			return encryptiontest.IsRandom(ctext, s)
   247  		})
   248  }