k8s.io/apiserver@v0.31.1/pkg/storage/value/encrypt/aes/aes_test.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package aes
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/aes"
    23  	"crypto/cipher"
    24  	"crypto/rand"
    25  	"encoding/binary"
    26  	"encoding/hex"
    27  	"fmt"
    28  	"io"
    29  	"math"
    30  	"reflect"
    31  	"sync"
    32  	"sync/atomic"
    33  	"testing"
    34  
    35  	"k8s.io/apiserver/pkg/storage/value"
    36  )
    37  
    38  func TestGCMDataStable(t *testing.T) {
    39  	block, err := aes.NewCipher([]byte("0123456789abcdef"))
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	aead, err := cipher.NewGCM(block)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	// IMPORTANT: If you must fix this test, then all previously encrypted data from previously compiled versions is broken unless you hardcode the nonce size to 12
    48  	if aead.NonceSize() != 12 {
    49  		t.Errorf("The underlying Golang crypto size has changed, old version of AES on disk will not be readable unless the AES implementation is changed to hardcode nonce size.")
    50  	}
    51  
    52  	transformerCounterNonce, _, err := NewGCMTransformerWithUniqueKeyUnsafe()
    53  	if err != nil {
    54  		t.Fatal(err)
    55  	}
    56  	if nonceSize := transformerCounterNonce.(*gcm).aead.NonceSize(); nonceSize != 12 {
    57  		t.Errorf("counter nonce: backwards incompatible change to nonce size detected: %d", nonceSize)
    58  	}
    59  
    60  	transformerRandomNonce, err := NewGCMTransformer(block)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  	if nonceSize := transformerRandomNonce.(*gcm).aead.NonceSize(); nonceSize != 12 {
    65  		t.Errorf("random nonce: backwards incompatible change to nonce size detected: %d", nonceSize)
    66  	}
    67  }
    68  
    69  func TestGCMUnsafeNonceOverflow(t *testing.T) {
    70  	var msgFatal string
    71  	var count int
    72  
    73  	nonceGen := &nonceGenerator{
    74  		fatal: func(msg string) {
    75  			msgFatal = msg
    76  			count++
    77  		},
    78  	}
    79  
    80  	block, err := aes.NewCipher([]byte("abcdefghijklmnop"))
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	transformer, err := newGCMTransformerWithUniqueKeyUnsafe(block, nonceGen)
    85  	if err != nil {
    86  		t.Fatal(err)
    87  	}
    88  
    89  	assertNonce(t, &nonceGen.nonce, 0)
    90  
    91  	runEncrypt(t, transformer)
    92  
    93  	assertNonce(t, &nonceGen.nonce, 1)
    94  
    95  	runEncrypt(t, transformer)
    96  
    97  	assertNonce(t, &nonceGen.nonce, 2)
    98  
    99  	nonceGen.nonce.Store(math.MaxUint64 - 1) // pretend lots of encryptions occurred
   100  
   101  	runEncrypt(t, transformer)
   102  
   103  	assertNonce(t, &nonceGen.nonce, math.MaxUint64)
   104  
   105  	if count != 0 {
   106  		t.Errorf("fatal should not have been called yet")
   107  	}
   108  
   109  	runEncrypt(t, transformer)
   110  
   111  	assertNonce(t, &nonceGen.nonce, 0)
   112  
   113  	if count != 1 {
   114  		t.Errorf("fatal should have been once, got %d", count)
   115  	}
   116  
   117  	if msgFatal != "aes-gcm detected nonce overflow - cryptographic wear out has occurred" {
   118  		t.Errorf("unexpected message: %s", msgFatal)
   119  	}
   120  }
   121  
   122  func assertNonce(t *testing.T, nonce *atomic.Uint64, want uint64) {
   123  	t.Helper()
   124  
   125  	if got := nonce.Load(); want != got {
   126  		t.Errorf("nonce should equal %d, got %d", want, got)
   127  	}
   128  }
   129  
   130  func runEncrypt(t *testing.T, transformer value.Transformer) {
   131  	t.Helper()
   132  
   133  	ctx := context.Background()
   134  	dataCtx := value.DefaultContext("authenticated_data")
   135  
   136  	_, err := transformer.TransformToStorage(ctx, []byte("firstvalue"), dataCtx)
   137  	if err != nil {
   138  		t.Fatal(err)
   139  	}
   140  }
   141  
   142  // TestGCMUnsafeCompatibility asserts that encryptions performed via
   143  // NewGCMTransformerWithUniqueKeyUnsafe can be decrypted via NewGCMTransformer.
   144  func TestGCMUnsafeCompatibility(t *testing.T) {
   145  	transformerEncrypt, key, err := NewGCMTransformerWithUniqueKeyUnsafe()
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  
   150  	block, err := aes.NewCipher(key)
   151  	if err != nil {
   152  		t.Fatal(err)
   153  	}
   154  
   155  	transformerDecrypt := newGCMTransformer(t, block, nil)
   156  
   157  	ctx := context.Background()
   158  	dataCtx := value.DefaultContext("authenticated_data")
   159  
   160  	plaintext := []byte("firstvalue")
   161  
   162  	ciphertext, err := transformerEncrypt.TransformToStorage(ctx, plaintext, dataCtx)
   163  	if err != nil {
   164  		t.Fatal(err)
   165  	}
   166  
   167  	if bytes.Equal(plaintext, ciphertext) {
   168  		t.Errorf("plaintext %q matches ciphertext %q", string(plaintext), string(ciphertext))
   169  	}
   170  
   171  	plaintextAgain, _, err := transformerDecrypt.TransformFromStorage(ctx, ciphertext, dataCtx)
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  
   176  	if !bytes.Equal(plaintext, plaintextAgain) {
   177  		t.Errorf("expected original plaintext %q, got %q", string(plaintext), string(plaintextAgain))
   178  	}
   179  }
   180  
   181  func TestGCMLegacyDataCompatibility(t *testing.T) {
   182  	block, err := aes.NewCipher([]byte("snorlax_awesomes"))
   183  	if err != nil {
   184  		t.Fatal(err)
   185  	}
   186  
   187  	transformerDecrypt := newGCMTransformer(t, block, nil)
   188  
   189  	// recorded output from NewGCMTransformer at commit 3b1fc60d8010dd8b53e97ba80e4710dbb430beee
   190  	const legacyCiphertext = "\x9f'\xc8\xfc\xea\x8aX\xc4g\xd8\xe47\xdb\xf2\xd8YU\xf9\xb4\xbd\x91/N\xf9g\u05c8\xa0\xcb\ay}\xac\n?\n\bE`\\\xa8Z\xc8V+J\xe1"
   191  
   192  	ctx := context.Background()
   193  	dataCtx := value.DefaultContext("bamboo")
   194  
   195  	plaintext := []byte("pandas are the best")
   196  
   197  	plaintextAgain, _, err := transformerDecrypt.TransformFromStorage(ctx, []byte(legacyCiphertext), dataCtx)
   198  	if err != nil {
   199  		t.Fatal(err)
   200  	}
   201  
   202  	if !bytes.Equal(plaintext, plaintextAgain) {
   203  		t.Errorf("expected original plaintext %q, got %q", string(plaintext), string(plaintextAgain))
   204  	}
   205  }
   206  
   207  func TestExtendedNonceGCMLegacyDataCompatibility(t *testing.T) {
   208  	// recorded output from NewKDFExtendedNonceGCMTransformerWithUniqueSeed from https://github.com/kubernetes/kubernetes/pull/118828
   209  	const (
   210  		legacyKey        = "]@2:\x82\x0f\xf9Uag^;\x95\xe8\x18g\xc5\xfd\xd5a\xd3Z\x88\xa2Ћ\b\xaa\x9dO\xcf\\"
   211  		legacyCiphertext = "$Bu\x9e3\x94_\xba\xd7\t\xdbWz\x0f\x03\x7fا\t\xfcv\x97\x9b\x89B \x9d\xeb\xce˝W\xef\xe3\xd6\xffj\x1e\xf6\xee\x9aP\x03\xb9\x83;0C\xce\xc1\xe4{5\x17[\x15\x11\a\xa8\xd2Ak\x0e)k\xbff\xb5\xd1\x02\xfc\xefߚx\xf2\x93\xd2q"
   212  	)
   213  
   214  	transformerDecrypt := newHKDFExtendedNonceGCMTransformerTest(t, nil, []byte(legacyKey))
   215  
   216  	ctx := context.Background()
   217  	dataCtx := value.DefaultContext("bamboo")
   218  
   219  	plaintext := []byte("pandas are the best")
   220  
   221  	plaintextAgain, _, err := transformerDecrypt.TransformFromStorage(ctx, []byte(legacyCiphertext), dataCtx)
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  
   226  	if !bytes.Equal(plaintext, plaintextAgain) {
   227  		t.Errorf("expected original plaintext %q, got %q", string(plaintext), string(plaintextAgain))
   228  	}
   229  }
   230  
   231  func TestGCMUnsafeNonceGen(t *testing.T) {
   232  	block, err := aes.NewCipher([]byte("abcdefghijklmnop"))
   233  	if err != nil {
   234  		t.Fatal(err)
   235  	}
   236  	transformer := newGCMTransformerWithUniqueKeyUnsafeTest(t, block, nil)
   237  
   238  	ctx := context.Background()
   239  	dataCtx := value.DefaultContext("authenticated_data")
   240  
   241  	const count = 1_000
   242  
   243  	counters := make([]uint64, count)
   244  
   245  	// run a bunch of go routines to make sure we are go routine safe
   246  	// on both the nonce generation and the actual encryption/decryption
   247  	var wg sync.WaitGroup
   248  	for i := 0; i < count; i++ {
   249  		i := i
   250  		wg.Add(1)
   251  		go func() {
   252  			defer wg.Done()
   253  
   254  			plaintext := bytes.Repeat([]byte{byte(i % 8)}, count)
   255  
   256  			out, err := transformer.TransformToStorage(ctx, plaintext, dataCtx)
   257  			if err != nil {
   258  				t.Error(err)
   259  				return
   260  			}
   261  
   262  			nonce := out[:12]
   263  			randomN := nonce[:4]
   264  
   265  			if bytes.Equal(randomN, make([]byte, len(randomN))) {
   266  				t.Error("got all zeros for random four byte nonce")
   267  			}
   268  
   269  			counter := nonce[4:]
   270  			counters[binary.LittleEndian.Uint64(counter)-1]++ // subtract one because the counter starts at 1, not 0
   271  
   272  			plaintextAgain, _, err := transformer.TransformFromStorage(ctx, out, dataCtx)
   273  			if err != nil {
   274  				t.Error(err)
   275  				return
   276  			}
   277  
   278  			if !bytes.Equal(plaintext, plaintextAgain) {
   279  				t.Errorf("expected original plaintext %q, got %q", string(plaintext), string(plaintextAgain))
   280  			}
   281  		}()
   282  	}
   283  	wg.Wait()
   284  
   285  	want := make([]uint64, count)
   286  	for i := range want {
   287  		want[i] = 1
   288  	}
   289  
   290  	if !reflect.DeepEqual(want, counters) {
   291  		t.Error("unexpected counter state")
   292  	}
   293  }
   294  
   295  func TestGCMNonce(t *testing.T) {
   296  	t.Run("gcm", func(t *testing.T) {
   297  		testGCMNonce(t, newGCMTransformer, 0, func(_ int, nonce []byte) {
   298  			if bytes.Equal(nonce, make([]byte, len(nonce))) {
   299  				t.Error("got all zeros for nonce")
   300  			}
   301  		})
   302  	})
   303  
   304  	t.Run("gcm unsafe", func(t *testing.T) {
   305  		testGCMNonce(t, newGCMTransformerWithUniqueKeyUnsafeTest, 0, func(i int, nonce []byte) {
   306  			counter := binary.LittleEndian.Uint64(nonce)
   307  			if uint64(i+1) != counter { // add one because the counter starts at 1, not 0
   308  				t.Errorf("counter nonce is invalid: want %d, got %d", i+1, counter)
   309  			}
   310  		})
   311  	})
   312  
   313  	t.Run("gcm extended nonce", func(t *testing.T) {
   314  		testGCMNonce(t, newHKDFExtendedNonceGCMTransformerTest, infoSizeExtendedNonceGCM, func(_ int, nonce []byte) {
   315  			if bytes.Equal(nonce, make([]byte, len(nonce))) {
   316  				t.Error("got all zeros for nonce")
   317  			}
   318  		})
   319  	})
   320  }
   321  
   322  func testGCMNonce(t *testing.T, f transformerFunc, infoLen int, check func(int, []byte)) {
   323  	key := []byte("abcdefghijklmnopabcdefghijklmnop")
   324  	block, err := aes.NewCipher(key)
   325  	if err != nil {
   326  		t.Fatal(err)
   327  	}
   328  	transformer := f(t, block, key)
   329  
   330  	ctx := context.Background()
   331  	dataCtx := value.DefaultContext("authenticated_data")
   332  
   333  	const count = 1_000
   334  
   335  	for i := 0; i < count; i++ {
   336  		i := i
   337  
   338  		out, err := transformer.TransformToStorage(ctx, bytes.Repeat([]byte{byte(i % 8)}, count), dataCtx)
   339  		if err != nil {
   340  			t.Fatal(err)
   341  		}
   342  
   343  		info := out[:infoLen]
   344  		nonce := out[infoLen : 12+infoLen]
   345  		randomN := nonce[:4]
   346  
   347  		if bytes.Equal(randomN, make([]byte, len(randomN))) {
   348  			t.Error("got all zeros for first four bytes")
   349  		}
   350  
   351  		if infoLen != 0 {
   352  			if bytes.Equal(info, make([]byte, infoLen)) {
   353  				t.Error("got all zeros for info")
   354  			}
   355  		}
   356  
   357  		check(i, nonce[4:])
   358  	}
   359  }
   360  
   361  func TestGCMKeyRotation(t *testing.T) {
   362  	t.Run("gcm", func(t *testing.T) {
   363  		testGCMKeyRotation(t, newGCMTransformer)
   364  	})
   365  
   366  	t.Run("gcm unsafe", func(t *testing.T) {
   367  		testGCMKeyRotation(t, newGCMTransformerWithUniqueKeyUnsafeTest)
   368  	})
   369  
   370  	t.Run("gcm extended", func(t *testing.T) {
   371  		testGCMKeyRotation(t, newHKDFExtendedNonceGCMTransformerTest)
   372  	})
   373  }
   374  
   375  func testGCMKeyRotation(t *testing.T, f transformerFunc) {
   376  	key1 := []byte("abcdefghijklmnopabcdefghijklmnop")
   377  	key2 := []byte("0123456789abcdef0123456789abcdef")
   378  
   379  	testErr := fmt.Errorf("test error")
   380  	block1, err := aes.NewCipher(key1)
   381  	if err != nil {
   382  		t.Fatal(err)
   383  	}
   384  	block2, err := aes.NewCipher(key2)
   385  	if err != nil {
   386  		t.Fatal(err)
   387  	}
   388  
   389  	ctx := context.Background()
   390  	dataCtx := value.DefaultContext("authenticated_data")
   391  
   392  	p := value.NewPrefixTransformers(testErr,
   393  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: f(t, block1, key1)},
   394  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: f(t, block2, key2)},
   395  	)
   396  	out, err := p.TransformToStorage(ctx, []byte("firstvalue"), dataCtx)
   397  	if err != nil {
   398  		t.Fatal(err)
   399  	}
   400  	if !bytes.HasPrefix(out, []byte("first:")) {
   401  		t.Fatalf("unexpected prefix: %q", out)
   402  	}
   403  	from, stale, err := p.TransformFromStorage(ctx, out, dataCtx)
   404  	if err != nil {
   405  		t.Fatal(err)
   406  	}
   407  	if stale || !bytes.Equal([]byte("firstvalue"), from) {
   408  		t.Fatalf("unexpected data: %t %q", stale, from)
   409  	}
   410  
   411  	// verify changing the context fails storage
   412  	_, _, err = p.TransformFromStorage(ctx, out, value.DefaultContext("incorrect_context"))
   413  	if err == nil {
   414  		t.Fatalf("expected unauthenticated data")
   415  	}
   416  
   417  	// reverse the order, use the second key
   418  	p = value.NewPrefixTransformers(testErr,
   419  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: f(t, block2, key2)},
   420  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: f(t, block1, key1)},
   421  	)
   422  	from, stale, err = p.TransformFromStorage(ctx, out, dataCtx)
   423  	if err != nil {
   424  		t.Fatal(err)
   425  	}
   426  	if !stale || !bytes.Equal([]byte("firstvalue"), from) {
   427  		t.Fatalf("unexpected data: %t %q", stale, from)
   428  	}
   429  }
   430  
   431  func TestCBCKeyRotation(t *testing.T) {
   432  	testErr := fmt.Errorf("test error")
   433  	block1, err := aes.NewCipher([]byte("abcdefghijklmnop"))
   434  	if err != nil {
   435  		t.Fatal(err)
   436  	}
   437  	block2, err := aes.NewCipher([]byte("0123456789abcdef"))
   438  	if err != nil {
   439  		t.Fatal(err)
   440  	}
   441  
   442  	ctx := context.Background()
   443  	dataCtx := value.DefaultContext("authenticated_data")
   444  
   445  	p := value.NewPrefixTransformers(testErr,
   446  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
   447  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
   448  	)
   449  	out, err := p.TransformToStorage(ctx, []byte("firstvalue"), dataCtx)
   450  	if err != nil {
   451  		t.Fatal(err)
   452  	}
   453  	if !bytes.HasPrefix(out, []byte("first:")) {
   454  		t.Fatalf("unexpected prefix: %q", out)
   455  	}
   456  	from, stale, err := p.TransformFromStorage(ctx, out, dataCtx)
   457  	if err != nil {
   458  		t.Fatal(err)
   459  	}
   460  	if stale || !bytes.Equal([]byte("firstvalue"), from) {
   461  		t.Fatalf("unexpected data: %t %q", stale, from)
   462  	}
   463  
   464  	// verify changing the context fails storage
   465  	_, _, err = p.TransformFromStorage(ctx, out, value.DefaultContext("incorrect_context"))
   466  	if err != nil {
   467  		t.Fatalf("CBC mode does not support authentication: %v", err)
   468  	}
   469  
   470  	// reverse the order, use the second key
   471  	p = value.NewPrefixTransformers(testErr,
   472  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
   473  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
   474  	)
   475  	from, stale, err = p.TransformFromStorage(ctx, out, dataCtx)
   476  	if err != nil {
   477  		t.Fatal(err)
   478  	}
   479  	if !stale || !bytes.Equal([]byte("firstvalue"), from) {
   480  		t.Fatalf("unexpected data: %t %q", stale, from)
   481  	}
   482  }
   483  
   484  var gcmBenchmarks = []namedTransformerFunc{
   485  	{name: "gcm-random-nonce", f: newGCMTransformer},
   486  	{name: "gcm-counter-nonce", f: newGCMTransformerWithUniqueKeyUnsafeTest},
   487  	{name: "gcm-extended-nonce", f: newHKDFExtendedNonceGCMTransformerTest},
   488  }
   489  
   490  func BenchmarkGCMRead(b *testing.B) {
   491  	tests := []struct {
   492  		keyLength   int
   493  		valueLength int
   494  		expectStale bool
   495  	}{
   496  		{keyLength: 16, valueLength: 1024, expectStale: false},
   497  		{keyLength: 32, valueLength: 1024, expectStale: false},
   498  		{keyLength: 32, valueLength: 16384, expectStale: false},
   499  		{keyLength: 32, valueLength: 16384, expectStale: true},
   500  	}
   501  	for _, t := range tests {
   502  		name := fmt.Sprintf("%vKeyLength/%vValueLength/%vExpectStale", t.keyLength, t.valueLength, t.expectStale)
   503  		b.Run(name, func(b *testing.B) {
   504  			for _, n := range gcmBenchmarks {
   505  				n := n
   506  				if t.keyLength == 16 && n.name == "gcm-extended-nonce" {
   507  					continue // gcm-extended-nonce requires 32 byte keys
   508  				}
   509  				b.Run(n.name, func(b *testing.B) {
   510  					b.ReportAllocs()
   511  					benchmarkGCMRead(b, n.f, t.keyLength, t.valueLength, t.expectStale)
   512  				})
   513  			}
   514  		})
   515  	}
   516  }
   517  
   518  func BenchmarkGCMWrite(b *testing.B) {
   519  	tests := []struct {
   520  		keyLength   int
   521  		valueLength int
   522  	}{
   523  		{keyLength: 16, valueLength: 1024},
   524  		{keyLength: 32, valueLength: 1024},
   525  		{keyLength: 32, valueLength: 16384},
   526  	}
   527  	for _, t := range tests {
   528  		name := fmt.Sprintf("%vKeyLength/%vValueLength", t.keyLength, t.valueLength)
   529  		b.Run(name, func(b *testing.B) {
   530  			for _, n := range gcmBenchmarks {
   531  				n := n
   532  				if t.keyLength == 16 && n.name == "gcm-extended-nonce" {
   533  					continue // gcm-extended-nonce requires 32 byte keys
   534  				}
   535  				b.Run(n.name, func(b *testing.B) {
   536  					b.ReportAllocs()
   537  					benchmarkGCMWrite(b, n.f, t.keyLength, t.valueLength)
   538  				})
   539  			}
   540  		})
   541  	}
   542  }
   543  
   544  func benchmarkGCMRead(b *testing.B, f transformerFunc, keyLength int, valueLength int, expectStale bool) {
   545  	key1 := bytes.Repeat([]byte("a"), keyLength)
   546  	key2 := bytes.Repeat([]byte("b"), keyLength)
   547  
   548  	block1, err := aes.NewCipher(key1)
   549  	if err != nil {
   550  		b.Fatal(err)
   551  	}
   552  	block2, err := aes.NewCipher(key2)
   553  	if err != nil {
   554  		b.Fatal(err)
   555  	}
   556  	p := value.NewPrefixTransformers(nil,
   557  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: f(b, block1, key1)},
   558  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: f(b, block2, key2)},
   559  	)
   560  
   561  	ctx := context.Background()
   562  	dataCtx := value.DefaultContext("authenticated_data")
   563  	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
   564  
   565  	out, err := p.TransformToStorage(ctx, v, dataCtx)
   566  	if err != nil {
   567  		b.Fatal(err)
   568  	}
   569  	// reverse the key order if expecting stale
   570  	if expectStale {
   571  		p = value.NewPrefixTransformers(nil,
   572  			value.PrefixTransformer{Prefix: []byte("second:"), Transformer: f(b, block2, key2)},
   573  			value.PrefixTransformer{Prefix: []byte("first:"), Transformer: f(b, block1, key1)},
   574  		)
   575  	}
   576  
   577  	b.ResetTimer()
   578  	for i := 0; i < b.N; i++ {
   579  		from, stale, err := p.TransformFromStorage(ctx, out, dataCtx)
   580  		if err != nil {
   581  			b.Fatal(err)
   582  		}
   583  		if expectStale != stale {
   584  			b.Fatalf("unexpected data: %q, expect stale %t but got %t", from, expectStale, stale)
   585  		}
   586  	}
   587  	b.StopTimer()
   588  }
   589  
   590  func benchmarkGCMWrite(b *testing.B, f transformerFunc, keyLength int, valueLength int) {
   591  	key1 := bytes.Repeat([]byte("a"), keyLength)
   592  	key2 := bytes.Repeat([]byte("b"), keyLength)
   593  
   594  	block1, err := aes.NewCipher(key1)
   595  	if err != nil {
   596  		b.Fatal(err)
   597  	}
   598  	block2, err := aes.NewCipher(key2)
   599  	if err != nil {
   600  		b.Fatal(err)
   601  	}
   602  	p := value.NewPrefixTransformers(nil,
   603  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: f(b, block1, key1)},
   604  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: f(b, block2, key2)},
   605  	)
   606  
   607  	ctx := context.Background()
   608  	dataCtx := value.DefaultContext("authenticated_data")
   609  	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
   610  
   611  	b.ResetTimer()
   612  	for i := 0; i < b.N; i++ {
   613  		_, err := p.TransformToStorage(ctx, v, dataCtx)
   614  		if err != nil {
   615  			b.Fatal(err)
   616  		}
   617  	}
   618  	b.StopTimer()
   619  }
   620  
   621  func BenchmarkCBCRead(b *testing.B) {
   622  	tests := []struct {
   623  		keyLength   int
   624  		valueLength int
   625  		expectStale bool
   626  	}{
   627  		{keyLength: 32, valueLength: 1024, expectStale: false},
   628  		{keyLength: 32, valueLength: 16384, expectStale: false},
   629  		{keyLength: 32, valueLength: 16384, expectStale: true},
   630  	}
   631  	for _, t := range tests {
   632  		name := fmt.Sprintf("%vKeyLength/%vValueLength/%vExpectStale", t.keyLength, t.valueLength, t.expectStale)
   633  		b.Run(name, func(b *testing.B) {
   634  			benchmarkCBCRead(b, t.keyLength, t.valueLength, t.expectStale)
   635  		})
   636  	}
   637  }
   638  
   639  func BenchmarkCBCWrite(b *testing.B) {
   640  	tests := []struct {
   641  		keyLength   int
   642  		valueLength int
   643  	}{
   644  		{keyLength: 32, valueLength: 1024},
   645  		{keyLength: 32, valueLength: 16384},
   646  	}
   647  	for _, t := range tests {
   648  		name := fmt.Sprintf("%vKeyLength/%vValueLength", t.keyLength, t.valueLength)
   649  		b.Run(name, func(b *testing.B) {
   650  			benchmarkCBCWrite(b, t.keyLength, t.valueLength)
   651  		})
   652  	}
   653  }
   654  
   655  func benchmarkCBCRead(b *testing.B, keyLength int, valueLength int, expectStale bool) {
   656  	block1, err := aes.NewCipher(bytes.Repeat([]byte("a"), keyLength))
   657  	if err != nil {
   658  		b.Fatal(err)
   659  	}
   660  	block2, err := aes.NewCipher(bytes.Repeat([]byte("b"), keyLength))
   661  	if err != nil {
   662  		b.Fatal(err)
   663  	}
   664  	p := value.NewPrefixTransformers(nil,
   665  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
   666  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
   667  	)
   668  
   669  	ctx := context.Background()
   670  	dataCtx := value.DefaultContext("authenticated_data")
   671  	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
   672  
   673  	out, err := p.TransformToStorage(ctx, v, dataCtx)
   674  	if err != nil {
   675  		b.Fatal(err)
   676  	}
   677  	// reverse the key order if expecting stale
   678  	if expectStale {
   679  		p = value.NewPrefixTransformers(nil,
   680  			value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
   681  			value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
   682  		)
   683  	}
   684  
   685  	b.ResetTimer()
   686  	for i := 0; i < b.N; i++ {
   687  		from, stale, err := p.TransformFromStorage(ctx, out, dataCtx)
   688  		if err != nil {
   689  			b.Fatal(err)
   690  		}
   691  		if expectStale != stale {
   692  			b.Fatalf("unexpected data: %q, expect stale %t but got %t", from, expectStale, stale)
   693  		}
   694  	}
   695  	b.StopTimer()
   696  }
   697  
   698  func benchmarkCBCWrite(b *testing.B, keyLength int, valueLength int) {
   699  	block1, err := aes.NewCipher(bytes.Repeat([]byte("a"), keyLength))
   700  	if err != nil {
   701  		b.Fatal(err)
   702  	}
   703  	block2, err := aes.NewCipher(bytes.Repeat([]byte("b"), keyLength))
   704  	if err != nil {
   705  		b.Fatal(err)
   706  	}
   707  	p := value.NewPrefixTransformers(nil,
   708  		value.PrefixTransformer{Prefix: []byte("first:"), Transformer: NewCBCTransformer(block1)},
   709  		value.PrefixTransformer{Prefix: []byte("second:"), Transformer: NewCBCTransformer(block2)},
   710  	)
   711  
   712  	ctx := context.Background()
   713  	dataCtx := value.DefaultContext("authenticated_data")
   714  	v := bytes.Repeat([]byte("0123456789abcdef"), valueLength/16)
   715  
   716  	b.ResetTimer()
   717  	for i := 0; i < b.N; i++ {
   718  		_, err := p.TransformToStorage(ctx, v, dataCtx)
   719  		if err != nil {
   720  			b.Fatal(err)
   721  		}
   722  	}
   723  	b.StopTimer()
   724  }
   725  
   726  func TestRoundTrip(t *testing.T) {
   727  	lengths := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 128, 1024}
   728  
   729  	aes16block, err := aes.NewCipher(bytes.Repeat([]byte("a"), 16))
   730  	if err != nil {
   731  		t.Fatal(err)
   732  	}
   733  	aes24block, err := aes.NewCipher(bytes.Repeat([]byte("b"), 24))
   734  	if err != nil {
   735  		t.Fatal(err)
   736  	}
   737  	key32 := bytes.Repeat([]byte("c"), 32)
   738  	aes32block, err := aes.NewCipher(key32)
   739  	if err != nil {
   740  		t.Fatal(err)
   741  	}
   742  
   743  	ctx := context.Background()
   744  	tests := []struct {
   745  		name string
   746  		t    value.Transformer
   747  	}{
   748  		{name: "GCM 16 byte key", t: newGCMTransformer(t, aes16block, nil)},
   749  		{name: "GCM 24 byte key", t: newGCMTransformer(t, aes24block, nil)},
   750  		{name: "GCM 32 byte key", t: newGCMTransformer(t, aes32block, nil)},
   751  		{name: "GCM 16 byte unsafe key", t: newGCMTransformerWithUniqueKeyUnsafeTest(t, aes16block, nil)},
   752  		{name: "GCM 24 byte unsafe key", t: newGCMTransformerWithUniqueKeyUnsafeTest(t, aes24block, nil)},
   753  		{name: "GCM 32 byte unsafe key", t: newGCMTransformerWithUniqueKeyUnsafeTest(t, aes32block, nil)},
   754  		{name: "GCM 32 byte seed", t: newHKDFExtendedNonceGCMTransformerTest(t, nil, key32)},
   755  		{name: "CBC 32 byte key", t: NewCBCTransformer(aes32block)},
   756  	}
   757  	for _, tt := range tests {
   758  		t.Run(tt.name, func(t *testing.T) {
   759  			dataCtx := value.DefaultContext("/foo/bar")
   760  			for _, l := range lengths {
   761  				data := make([]byte, l)
   762  				if _, err := io.ReadFull(rand.Reader, data); err != nil {
   763  					t.Fatalf("unable to read sufficient random bytes: %v", err)
   764  				}
   765  				original := append([]byte{}, data...)
   766  
   767  				ciphertext, err := tt.t.TransformToStorage(ctx, data, dataCtx)
   768  				if err != nil {
   769  					t.Errorf("TransformToStorage error = %v", err)
   770  					continue
   771  				}
   772  
   773  				result, stale, err := tt.t.TransformFromStorage(ctx, ciphertext, dataCtx)
   774  				if err != nil {
   775  					t.Errorf("TransformFromStorage error = %v", err)
   776  					continue
   777  				}
   778  				if stale {
   779  					t.Errorf("unexpected stale output")
   780  					continue
   781  				}
   782  
   783  				switch {
   784  				case l == 0:
   785  					if len(result) != 0 {
   786  						t.Errorf("Round trip failed len=%d\noriginal:\n%s\nresult:\n%s", l, hex.Dump(original), hex.Dump(result))
   787  					}
   788  				case !reflect.DeepEqual(original, result):
   789  					t.Errorf("Round trip failed len=%d\noriginal:\n%s\nresult:\n%s", l, hex.Dump(original), hex.Dump(result))
   790  				}
   791  			}
   792  		})
   793  	}
   794  }
   795  
   796  type namedTransformerFunc struct {
   797  	name string
   798  	f    transformerFunc
   799  }
   800  
   801  type transformerFunc func(t testing.TB, block cipher.Block, key []byte) value.Transformer
   802  
   803  func newGCMTransformer(t testing.TB, block cipher.Block, _ []byte) value.Transformer {
   804  	t.Helper()
   805  
   806  	transformer, err := NewGCMTransformer(block)
   807  	if err != nil {
   808  		t.Fatal(err)
   809  	}
   810  
   811  	return transformer
   812  }
   813  
   814  func newGCMTransformerWithUniqueKeyUnsafeTest(t testing.TB, block cipher.Block, _ []byte) value.Transformer {
   815  	t.Helper()
   816  
   817  	nonceGen := &nonceGenerator{fatal: die}
   818  	transformer, err := newGCMTransformerWithUniqueKeyUnsafe(block, nonceGen)
   819  	if err != nil {
   820  		t.Fatal(err)
   821  	}
   822  
   823  	return transformer
   824  }
   825  
   826  func newHKDFExtendedNonceGCMTransformerTest(t testing.TB, _ cipher.Block, key []byte) value.Transformer {
   827  	t.Helper()
   828  
   829  	transformer, err := NewHKDFExtendedNonceGCMTransformer(key)
   830  	if err != nil {
   831  		t.Fatal(err)
   832  	}
   833  
   834  	return transformer
   835  }