github.com/hashicorp/vault/sdk@v0.13.0/helper/keysutil/policy_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package keysutil
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/ecdsa"
    10  	"crypto/elliptic"
    11  	"crypto/rand"
    12  	"crypto/rsa"
    13  	"crypto/x509"
    14  	"errors"
    15  	"fmt"
    16  	mathrand "math/rand"
    17  	"reflect"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/hashicorp/vault/sdk/helper/errutil"
    25  	"github.com/hashicorp/vault/sdk/helper/jsonutil"
    26  	"github.com/hashicorp/vault/sdk/logical"
    27  	"github.com/mitchellh/copystructure"
    28  	"golang.org/x/crypto/ed25519"
    29  )
    30  
    31  // Ordering of these items needs to match the iota order defined in policy.go. Ordering changes
    32  // should never occur, as it would lead to a key type change within existing stored policies.
    33  var allTestKeyTypes = []KeyType{
    34  	KeyType_AES256_GCM96, KeyType_ECDSA_P256, KeyType_ED25519, KeyType_RSA2048,
    35  	KeyType_RSA4096, KeyType_ChaCha20_Poly1305, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_AES128_GCM96,
    36  	KeyType_RSA3072, KeyType_MANAGED_KEY, KeyType_HMAC, KeyType_AES128_CMAC, KeyType_AES256_CMAC,
    37  }
    38  
    39  func TestPolicy_KeyTypes(t *testing.T) {
    40  	// Make sure the iota value never change for key types, as existing storage would be affected
    41  	for i, keyType := range allTestKeyTypes {
    42  		if int(keyType) != i {
    43  			t.Fatalf("iota of keytype %s changed, expected %d got %d", keyType.String(), i, keyType)
    44  		}
    45  	}
    46  
    47  	// Make sure we have a string presentation for all types
    48  	for _, keyType := range allTestKeyTypes {
    49  		if strings.Contains(keyType.String(), "unknown") {
    50  			t.Fatalf("keytype with iota of %d should not contain 'unknown', missing in String() switch statement", keyType)
    51  		}
    52  	}
    53  }
    54  
    55  func TestPolicy_HmacCmacSuported(t *testing.T) {
    56  	// Test HMAC supported feature
    57  	for _, keyType := range allTestKeyTypes {
    58  		switch keyType {
    59  		case KeyType_MANAGED_KEY:
    60  			if keyType.HMACSupported() {
    61  				t.Fatalf("hmac should not have been not be supported for keytype %s", keyType.String())
    62  			}
    63  			if keyType.CMACSupported() {
    64  				t.Fatalf("cmac should not have been be supported for keytype %s", keyType.String())
    65  			}
    66  		case KeyType_AES128_CMAC, KeyType_AES256_CMAC:
    67  			if keyType.HMACSupported() {
    68  				t.Fatalf("hmac should have been not be supported for keytype %s", keyType.String())
    69  			}
    70  			if !keyType.CMACSupported() {
    71  				t.Fatalf("cmac should have been be supported for keytype %s", keyType.String())
    72  			}
    73  		default:
    74  			if !keyType.HMACSupported() {
    75  				t.Fatalf("hmac should have been supported for keytype %s", keyType.String())
    76  			}
    77  			if keyType.CMACSupported() {
    78  				t.Fatalf("cmac should not have been supported for keytype %s", keyType.String())
    79  			}
    80  		}
    81  	}
    82  }
    83  
    84  func TestPolicy_CMACKeyUpgrade(t *testing.T) {
    85  	ctx := context.Background()
    86  	lm, _ := NewLockManager(false, 0)
    87  	storage := &logical.InmemStorage{}
    88  	p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{
    89  		Upsert:  true,
    90  		Storage: storage,
    91  		KeyType: KeyType_AES256_CMAC,
    92  		Name:    "test",
    93  	}, rand.Reader)
    94  	if err != nil {
    95  		t.Fatalf("failed loading policy: %v", err)
    96  	}
    97  	if p == nil {
    98  		t.Fatal("nil policy")
    99  	}
   100  	if !upserted {
   101  		t.Fatal("expected an upsert")
   102  	}
   103  
   104  	// This verifies we don't have a hmac key
   105  	_, err = p.HMACKey(1)
   106  	if err == nil {
   107  		t.Fatal("cmac key should not return an hmac key but did on initial creation")
   108  	}
   109  
   110  	if p.NeedsUpgrade() {
   111  		t.Fatal("cmac key should not require an upgrade after initial key creation")
   112  	}
   113  
   114  	err = p.Upgrade(ctx, storage, rand.Reader)
   115  	if err != nil {
   116  		t.Fatalf("an error was returned from upgrade method: %v", err)
   117  	}
   118  	p.Unlock()
   119  
   120  	// Now reload our policy from disk and make sure we still don't have a hmac key
   121  	p, upserted, err = lm.GetPolicy(ctx, PolicyRequest{
   122  		Upsert:  true,
   123  		Storage: storage,
   124  		KeyType: KeyType_AES256_CMAC,
   125  		Name:    "test",
   126  	}, rand.Reader)
   127  	if err != nil {
   128  		t.Fatalf("failed loading policy: %v", err)
   129  	}
   130  	if p == nil {
   131  		t.Fatal("nil policy")
   132  	}
   133  	if upserted {
   134  		t.Fatal("expected the key to exist but upserted was true")
   135  	}
   136  
   137  	p.Unlock()
   138  
   139  	_, err = p.HMACKey(1)
   140  	if err == nil {
   141  		t.Fatal("cmac key should not return an hmac key post upgrade")
   142  	}
   143  }
   144  
   145  func TestPolicy_KeyEntryMapUpgrade(t *testing.T) {
   146  	now := time.Now()
   147  	old := map[int]KeyEntry{
   148  		1: {
   149  			Key:                []byte("samplekey"),
   150  			HMACKey:            []byte("samplehmackey"),
   151  			CreationTime:       now,
   152  			FormattedPublicKey: "sampleformattedpublickey",
   153  		},
   154  		2: {
   155  			Key:                []byte("samplekey2"),
   156  			HMACKey:            []byte("samplehmackey2"),
   157  			CreationTime:       now.Add(10 * time.Second),
   158  			FormattedPublicKey: "sampleformattedpublickey2",
   159  		},
   160  	}
   161  
   162  	oldEncoded, err := jsonutil.EncodeJSON(old)
   163  	if err != nil {
   164  		t.Fatal(err)
   165  	}
   166  
   167  	var new keyEntryMap
   168  	err = jsonutil.DecodeJSON(oldEncoded, &new)
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  
   173  	newEncoded, err := jsonutil.EncodeJSON(&new)
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  
   178  	if string(oldEncoded) != string(newEncoded) {
   179  		t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded))
   180  	}
   181  }
   182  
   183  func Test_KeyUpgrade(t *testing.T) {
   184  	lockManagerWithCache, _ := NewLockManager(true, 0)
   185  	lockManagerWithoutCache, _ := NewLockManager(false, 0)
   186  	testKeyUpgradeCommon(t, lockManagerWithCache)
   187  	testKeyUpgradeCommon(t, lockManagerWithoutCache)
   188  }
   189  
   190  func testKeyUpgradeCommon(t *testing.T, lm *LockManager) {
   191  	ctx := context.Background()
   192  
   193  	storage := &logical.InmemStorage{}
   194  	p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{
   195  		Upsert:  true,
   196  		Storage: storage,
   197  		KeyType: KeyType_AES256_GCM96,
   198  		Name:    "test",
   199  	}, rand.Reader)
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	if p == nil {
   204  		t.Fatal("nil policy")
   205  	}
   206  	if !upserted {
   207  		t.Fatal("expected an upsert")
   208  	}
   209  	if !lm.useCache {
   210  		p.Unlock()
   211  	}
   212  
   213  	testBytes := make([]byte, len(p.Keys["1"].Key))
   214  	copy(testBytes, p.Keys["1"].Key)
   215  
   216  	p.Key = p.Keys["1"].Key
   217  	p.Keys = nil
   218  	p.MigrateKeyToKeysMap()
   219  	if p.Key != nil {
   220  		t.Fatal("policy.Key is not nil")
   221  	}
   222  	if len(p.Keys) != 1 {
   223  		t.Fatal("policy.Keys is the wrong size")
   224  	}
   225  	if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) {
   226  		t.Fatal("key mismatch")
   227  	}
   228  }
   229  
   230  func Test_ArchivingUpgrade(t *testing.T) {
   231  	lockManagerWithCache, _ := NewLockManager(true, 0)
   232  	lockManagerWithoutCache, _ := NewLockManager(false, 0)
   233  	testArchivingUpgradeCommon(t, lockManagerWithCache)
   234  	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
   235  }
   236  
   237  func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) {
   238  	ctx := context.Background()
   239  
   240  	// First, we generate a policy and rotate it a number of times. Each time
   241  	// we'll ensure that we have the expected number of keys in the archive and
   242  	// the main keys object, which without changing the min version should be
   243  	// zero and latest, respectively
   244  
   245  	storage := &logical.InmemStorage{}
   246  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   247  		Upsert:  true,
   248  		Storage: storage,
   249  		KeyType: KeyType_AES256_GCM96,
   250  		Name:    "test",
   251  	}, rand.Reader)
   252  	if err != nil {
   253  		t.Fatal(err)
   254  	}
   255  	if p == nil {
   256  		t.Fatal("nil policy")
   257  	}
   258  	if !lm.useCache {
   259  		p.Unlock()
   260  	}
   261  
   262  	// Store the initial key in the archive
   263  	keysArchive := []KeyEntry{{}, p.Keys["1"]}
   264  	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
   265  
   266  	for i := 2; i <= 10; i++ {
   267  		err = p.Rotate(ctx, storage, rand.Reader)
   268  		if err != nil {
   269  			t.Fatal(err)
   270  		}
   271  		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
   272  		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
   273  	}
   274  
   275  	// Now, wipe the archive and set the archive version to zero
   276  	err = storage.Delete(ctx, "archive/test")
   277  	if err != nil {
   278  		t.Fatal(err)
   279  	}
   280  	p.ArchiveVersion = 0
   281  
   282  	// Store it, but without calling persist, so we don't trigger
   283  	// handleArchiving()
   284  	buf, err := p.Serialize()
   285  	if err != nil {
   286  		t.Fatal(err)
   287  	}
   288  
   289  	// Write the policy into storage
   290  	err = storage.Put(ctx, &logical.StorageEntry{
   291  		Key:   "policy/" + p.Name,
   292  		Value: buf,
   293  	})
   294  	if err != nil {
   295  		t.Fatal(err)
   296  	}
   297  
   298  	// If we're caching, expire from the cache since we modified it
   299  	// under-the-hood
   300  	if lm.useCache {
   301  		lm.cache.Delete("test")
   302  	}
   303  
   304  	// Now get the policy again; the upgrade should happen automatically
   305  	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
   306  		Storage: storage,
   307  		Name:    "test",
   308  	}, rand.Reader)
   309  	if err != nil {
   310  		t.Fatal(err)
   311  	}
   312  	if p == nil {
   313  		t.Fatal("nil policy")
   314  	}
   315  	if !lm.useCache {
   316  		p.Unlock()
   317  	}
   318  
   319  	checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10)
   320  
   321  	// Let's check some deletion logic while we're at it
   322  
   323  	// The policy should be in there
   324  	if lm.useCache {
   325  		_, ok := lm.cache.Load("test")
   326  		if !ok {
   327  			t.Fatal("nil policy in cache")
   328  		}
   329  	}
   330  
   331  	// First we'll do this wrong, by not setting the deletion flag
   332  	err = lm.DeletePolicy(ctx, storage, "test")
   333  	if err == nil {
   334  		t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy")
   335  	}
   336  
   337  	// The policy should still be in there
   338  	if lm.useCache {
   339  		_, ok := lm.cache.Load("test")
   340  		if !ok {
   341  			t.Fatal("nil policy in cache")
   342  		}
   343  	}
   344  
   345  	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
   346  		Storage: storage,
   347  		Name:    "test",
   348  	}, rand.Reader)
   349  	if err != nil {
   350  		t.Fatal(err)
   351  	}
   352  	if p == nil {
   353  		t.Fatal("policy nil after bad delete")
   354  	}
   355  	if !lm.useCache {
   356  		p.Unlock()
   357  	}
   358  
   359  	// Now do it properly
   360  	p.DeletionAllowed = true
   361  	err = p.Persist(ctx, storage)
   362  	if err != nil {
   363  		t.Fatal(err)
   364  	}
   365  	err = lm.DeletePolicy(ctx, storage, "test")
   366  	if err != nil {
   367  		t.Fatal(err)
   368  	}
   369  
   370  	// The policy should *not* be in there
   371  	if lm.useCache {
   372  		_, ok := lm.cache.Load("test")
   373  		if ok {
   374  			t.Fatal("non-nil policy in cache")
   375  		}
   376  	}
   377  
   378  	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
   379  		Storage: storage,
   380  		Name:    "test",
   381  	}, rand.Reader)
   382  	if err != nil {
   383  		t.Fatal(err)
   384  	}
   385  	if p != nil {
   386  		t.Fatal("policy not nil after delete")
   387  	}
   388  }
   389  
   390  func Test_Archiving(t *testing.T) {
   391  	lockManagerWithCache, _ := NewLockManager(true, 0)
   392  	lockManagerWithoutCache, _ := NewLockManager(false, 0)
   393  	testArchivingUpgradeCommon(t, lockManagerWithCache)
   394  	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
   395  }
   396  
   397  func testArchivingCommon(t *testing.T, lm *LockManager) {
   398  	ctx := context.Background()
   399  
   400  	// First, we generate a policy and rotate it a number of times. Each time
   401  	// we'll ensure that we have the expected number of keys in the archive and
   402  	// the main keys object, which without changing the min version should be
   403  	// zero and latest, respectively
   404  
   405  	storage := &logical.InmemStorage{}
   406  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   407  		Upsert:  true,
   408  		Storage: storage,
   409  		KeyType: KeyType_AES256_GCM96,
   410  		Name:    "test",
   411  	}, rand.Reader)
   412  	if err != nil {
   413  		t.Fatal(err)
   414  	}
   415  	if p == nil {
   416  		t.Fatal("nil policy")
   417  	}
   418  	if !lm.useCache {
   419  		p.Unlock()
   420  	}
   421  
   422  	// Store the initial key in the archive
   423  	keysArchive := []KeyEntry{{}, p.Keys["1"]}
   424  	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
   425  
   426  	for i := 2; i <= 10; i++ {
   427  		err = p.Rotate(ctx, storage, rand.Reader)
   428  		if err != nil {
   429  			t.Fatal(err)
   430  		}
   431  		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
   432  		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
   433  	}
   434  
   435  	// Move the min decryption version up
   436  	for i := 1; i <= 10; i++ {
   437  		p.MinDecryptionVersion = i
   438  
   439  		err = p.Persist(ctx, storage)
   440  		if err != nil {
   441  			t.Fatal(err)
   442  		}
   443  		// We expect to find:
   444  		// * The keys in archive are the same as the latest version
   445  		// * The latest version is constant
   446  		// * The number of keys in the policy itself is from the min
   447  		// decryption version up to the latest version, so for e.g. 7 and
   448  		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
   449  		// decryption version plus 1 (the min decryption version key
   450  		// itself)
   451  		checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
   452  	}
   453  
   454  	// Move the min decryption version down
   455  	for i := 10; i >= 1; i-- {
   456  		p.MinDecryptionVersion = i
   457  
   458  		err = p.Persist(ctx, storage)
   459  		if err != nil {
   460  			t.Fatal(err)
   461  		}
   462  		// We expect to find:
   463  		// * The keys in archive are never removed so same as the latest version
   464  		// * The latest version is constant
   465  		// * The number of keys in the policy itself is from the min
   466  		// decryption version up to the latest version, so for e.g. 7 and
   467  		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
   468  		// decryption version plus 1 (the min decryption version key
   469  		// itself)
   470  		checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
   471  	}
   472  }
   473  
   474  func checkKeys(t *testing.T,
   475  	ctx context.Context,
   476  	p *Policy,
   477  	storage logical.Storage,
   478  	keysArchive []KeyEntry,
   479  	action string,
   480  	archiveVer, latestVer, keysSize int,
   481  ) {
   482  	// Sanity check
   483  	if len(keysArchive) != latestVer+1 {
   484  		t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
   485  			"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
   486  	}
   487  
   488  	archive, err := p.LoadArchive(ctx, storage)
   489  	if err != nil {
   490  		t.Fatal(err)
   491  	}
   492  
   493  	badArchiveVer := false
   494  	if archiveVer == 0 {
   495  		if len(archive.Keys) != 0 || p.ArchiveVersion != 0 {
   496  			badArchiveVer = true
   497  		}
   498  	} else {
   499  		// We need to subtract one because we have the indexes match key
   500  		// versions, which start at 1. So for an archive version of 1, we
   501  		// actually have two entries -- a blank 0 entry, and the key at spot 1
   502  		if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion {
   503  			badArchiveVer = true
   504  		}
   505  	}
   506  	if badArchiveVer {
   507  		t.Fatalf(
   508  			"expected archive version %d, found length of archive keys %d and policy archive version %d",
   509  			archiveVer, len(archive.Keys), p.ArchiveVersion,
   510  		)
   511  	}
   512  
   513  	if latestVer != p.LatestVersion {
   514  		t.Fatalf(
   515  			"expected latest version %d, found %d",
   516  			latestVer, p.LatestVersion,
   517  		)
   518  	}
   519  
   520  	if keysSize != len(p.Keys) {
   521  		t.Fatalf(
   522  			"expected keys size %d, found %d, action is %s, policy is \n%#v\n",
   523  			keysSize, len(p.Keys), action, p,
   524  		)
   525  	}
   526  
   527  	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
   528  		if _, ok := p.Keys[strconv.Itoa(i)]; !ok {
   529  			t.Fatalf(
   530  				"expected key %d, did not find it in policy keys", i,
   531  			)
   532  		}
   533  	}
   534  
   535  	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
   536  		ver := strconv.Itoa(i)
   537  		if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) {
   538  			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
   539  		}
   540  		polKey := p.Keys[ver]
   541  		polKey.CreationTime = keysArchive[i].CreationTime
   542  		p.Keys[ver] = polKey
   543  		if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) {
   544  			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
   545  		}
   546  	}
   547  
   548  	for i := 1; i < len(archive.Keys); i++ {
   549  		if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
   550  			t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key)
   551  		}
   552  	}
   553  }
   554  
   555  func Test_StorageErrorSafety(t *testing.T) {
   556  	ctx := context.Background()
   557  	lm, _ := NewLockManager(true, 0)
   558  
   559  	storage := &logical.InmemStorage{}
   560  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   561  		Upsert:  true,
   562  		Storage: storage,
   563  		KeyType: KeyType_AES256_GCM96,
   564  		Name:    "test",
   565  	}, rand.Reader)
   566  	if err != nil {
   567  		t.Fatal(err)
   568  	}
   569  	if p == nil {
   570  		t.Fatal("nil policy")
   571  	}
   572  
   573  	// Store the initial key in the archive
   574  	keysArchive := []KeyEntry{{}, p.Keys["1"]}
   575  	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
   576  
   577  	// We use checkKeys here just for sanity; it doesn't really handle cases of
   578  	// errors below so we do more targeted testing later
   579  	for i := 2; i <= 5; i++ {
   580  		err = p.Rotate(ctx, storage, rand.Reader)
   581  		if err != nil {
   582  			t.Fatal(err)
   583  		}
   584  		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
   585  		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
   586  	}
   587  
   588  	underlying := storage.Underlying()
   589  	underlying.FailPut(true)
   590  
   591  	priorLen := len(p.Keys)
   592  
   593  	err = p.Rotate(ctx, storage, rand.Reader)
   594  	if err == nil {
   595  		t.Fatal("expected error")
   596  	}
   597  
   598  	if len(p.Keys) != priorLen {
   599  		t.Fatal("length of keys should not have changed")
   600  	}
   601  }
   602  
   603  func Test_BadUpgrade(t *testing.T) {
   604  	ctx := context.Background()
   605  	lm, _ := NewLockManager(true, 0)
   606  	storage := &logical.InmemStorage{}
   607  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   608  		Upsert:  true,
   609  		Storage: storage,
   610  		KeyType: KeyType_AES256_GCM96,
   611  		Name:    "test",
   612  	}, rand.Reader)
   613  	if err != nil {
   614  		t.Fatal(err)
   615  	}
   616  	if p == nil {
   617  		t.Fatal("nil policy")
   618  	}
   619  
   620  	orig, err := copystructure.Copy(p)
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	orig.(*Policy).l = p.l
   625  
   626  	p.Key = p.Keys["1"].Key
   627  	p.Keys = nil
   628  	p.MinDecryptionVersion = 0
   629  
   630  	if err := p.Upgrade(ctx, storage, rand.Reader); err != nil {
   631  		t.Fatal(err)
   632  	}
   633  
   634  	k := p.Keys["1"]
   635  	o := orig.(*Policy).Keys["1"]
   636  	k.CreationTime = o.CreationTime
   637  	k.HMACKey = o.HMACKey
   638  	p.Keys["1"] = k
   639  	p.versionPrefixCache = sync.Map{}
   640  
   641  	if !reflect.DeepEqual(orig, p) {
   642  		t.Fatalf("not equal:\n%#v\n%#v", orig, p)
   643  	}
   644  
   645  	// Do it again with a failing storage call
   646  	underlying := storage.Underlying()
   647  	underlying.FailPut(true)
   648  
   649  	p.Key = p.Keys["1"].Key
   650  	p.Keys = nil
   651  	p.MinDecryptionVersion = 0
   652  
   653  	if err := p.Upgrade(ctx, storage, rand.Reader); err == nil {
   654  		t.Fatal("expected error")
   655  	}
   656  
   657  	if p.MinDecryptionVersion == 1 {
   658  		t.Fatal("min decryption version was changed")
   659  	}
   660  	if p.Keys != nil {
   661  		t.Fatal("found upgraded keys")
   662  	}
   663  	if p.Key == nil {
   664  		t.Fatal("non-upgraded key not found")
   665  	}
   666  }
   667  
   668  func Test_BadArchive(t *testing.T) {
   669  	ctx := context.Background()
   670  	lm, _ := NewLockManager(true, 0)
   671  	storage := &logical.InmemStorage{}
   672  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   673  		Upsert:  true,
   674  		Storage: storage,
   675  		KeyType: KeyType_AES256_GCM96,
   676  		Name:    "test",
   677  	}, rand.Reader)
   678  	if err != nil {
   679  		t.Fatal(err)
   680  	}
   681  	if p == nil {
   682  		t.Fatal("nil policy")
   683  	}
   684  
   685  	for i := 2; i <= 10; i++ {
   686  		err = p.Rotate(ctx, storage, rand.Reader)
   687  		if err != nil {
   688  			t.Fatal(err)
   689  		}
   690  	}
   691  
   692  	p.MinDecryptionVersion = 5
   693  	if err := p.Persist(ctx, storage); err != nil {
   694  		t.Fatal(err)
   695  	}
   696  	if p.ArchiveVersion != 10 {
   697  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   698  	}
   699  	if len(p.Keys) != 6 {
   700  		t.Fatalf("unexpected key length %d", len(p.Keys))
   701  	}
   702  
   703  	// Set back
   704  	p.MinDecryptionVersion = 1
   705  	if err := p.Persist(ctx, storage); err != nil {
   706  		t.Fatal(err)
   707  	}
   708  	if p.ArchiveVersion != 10 {
   709  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   710  	}
   711  	if len(p.Keys) != 10 {
   712  		t.Fatalf("unexpected key length %d", len(p.Keys))
   713  	}
   714  
   715  	// Run it again but we'll turn off storage along the way
   716  	p.MinDecryptionVersion = 5
   717  	if err := p.Persist(ctx, storage); err != nil {
   718  		t.Fatal(err)
   719  	}
   720  	if p.ArchiveVersion != 10 {
   721  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   722  	}
   723  	if len(p.Keys) != 6 {
   724  		t.Fatalf("unexpected key length %d", len(p.Keys))
   725  	}
   726  
   727  	underlying := storage.Underlying()
   728  	underlying.FailPut(true)
   729  
   730  	// Set back, which should cause p.Keys to be changed if the persist works,
   731  	// but it doesn't
   732  	p.MinDecryptionVersion = 1
   733  	if err := p.Persist(ctx, storage); err == nil {
   734  		t.Fatal("expected error during put")
   735  	}
   736  	if p.ArchiveVersion != 10 {
   737  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   738  	}
   739  	// Here's the expected change
   740  	if len(p.Keys) != 6 {
   741  		t.Fatalf("unexpected key length %d", len(p.Keys))
   742  	}
   743  }
   744  
   745  func Test_Import(t *testing.T) {
   746  	ctx := context.Background()
   747  	storage := &logical.InmemStorage{}
   748  	testKeys, err := generateTestKeys()
   749  	if err != nil {
   750  		t.Fatalf("error generating test keys: %s", err)
   751  	}
   752  
   753  	tests := map[string]struct {
   754  		policy      Policy
   755  		key         []byte
   756  		shouldError bool
   757  	}{
   758  		"import AES key": {
   759  			policy: Policy{
   760  				Name: "test-aes-key",
   761  				Type: KeyType_AES256_GCM96,
   762  			},
   763  			key:         testKeys[KeyType_AES256_GCM96],
   764  			shouldError: false,
   765  		},
   766  		"import RSA key": {
   767  			policy: Policy{
   768  				Name: "test-rsa-key",
   769  				Type: KeyType_RSA2048,
   770  			},
   771  			key:         testKeys[KeyType_RSA2048],
   772  			shouldError: false,
   773  		},
   774  		"import ECDSA key": {
   775  			policy: Policy{
   776  				Name: "test-ecdsa-key",
   777  				Type: KeyType_ECDSA_P256,
   778  			},
   779  			key:         testKeys[KeyType_ECDSA_P256],
   780  			shouldError: false,
   781  		},
   782  		"import ED25519 key": {
   783  			policy: Policy{
   784  				Name: "test-ed25519-key",
   785  				Type: KeyType_ED25519,
   786  			},
   787  			key:         testKeys[KeyType_ED25519],
   788  			shouldError: false,
   789  		},
   790  		"import incorrect key type": {
   791  			policy: Policy{
   792  				Name: "test-ed25519-key",
   793  				Type: KeyType_ED25519,
   794  			},
   795  			key:         testKeys[KeyType_AES256_GCM96],
   796  			shouldError: true,
   797  		},
   798  	}
   799  
   800  	for name, test := range tests {
   801  		t.Run(name, func(t *testing.T) {
   802  			if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError {
   803  				t.Fatalf("error importing key: %s", err)
   804  			}
   805  		})
   806  	}
   807  }
   808  
   809  func generateTestKeys() (map[KeyType][]byte, error) {
   810  	keyMap := make(map[KeyType][]byte)
   811  
   812  	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
   813  	if err != nil {
   814  		return nil, err
   815  	}
   816  	rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey)
   817  	if err != nil {
   818  		return nil, err
   819  	}
   820  	keyMap[KeyType_RSA2048] = rsaKeyBytes
   821  
   822  	rsaKey, err = rsa.GenerateKey(rand.Reader, 3072)
   823  	if err != nil {
   824  		return nil, err
   825  	}
   826  	rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
   827  	if err != nil {
   828  		return nil, err
   829  	}
   830  	keyMap[KeyType_RSA3072] = rsaKeyBytes
   831  
   832  	rsaKey, err = rsa.GenerateKey(rand.Reader, 4096)
   833  	if err != nil {
   834  		return nil, err
   835  	}
   836  	rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
   837  	if err != nil {
   838  		return nil, err
   839  	}
   840  	keyMap[KeyType_RSA4096] = rsaKeyBytes
   841  
   842  	ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   843  	if err != nil {
   844  		return nil, err
   845  	}
   846  	ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey)
   847  	if err != nil {
   848  		return nil, err
   849  	}
   850  	keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes
   851  
   852  	_, ed25519Key, err := ed25519.GenerateKey(rand.Reader)
   853  	if err != nil {
   854  		return nil, err
   855  	}
   856  	ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key)
   857  	if err != nil {
   858  		return nil, err
   859  	}
   860  	keyMap[KeyType_ED25519] = ed25519KeyBytes
   861  
   862  	aesKey := make([]byte, 32)
   863  	_, err = rand.Read(aesKey)
   864  	if err != nil {
   865  		return nil, err
   866  	}
   867  	keyMap[KeyType_AES256_GCM96] = aesKey
   868  
   869  	return keyMap, nil
   870  }
   871  
   872  func BenchmarkSymmetric(b *testing.B) {
   873  	ctx := context.Background()
   874  	lm, _ := NewLockManager(true, 0)
   875  	storage := &logical.InmemStorage{}
   876  	p, _, _ := lm.GetPolicy(ctx, PolicyRequest{
   877  		Upsert:  true,
   878  		Storage: storage,
   879  		KeyType: KeyType_AES256_GCM96,
   880  		Name:    "test",
   881  	}, rand.Reader)
   882  	key, _ := p.GetKey(nil, 1, 32)
   883  	pt := make([]byte, 10)
   884  	ad := make([]byte, 10)
   885  	for i := 0; i < b.N; i++ {
   886  		ct, _ := p.SymmetricEncryptRaw(1, key, pt,
   887  			SymmetricOpts{
   888  				AdditionalData: ad,
   889  			})
   890  		pt2, _ := p.SymmetricDecryptRaw(key, ct, SymmetricOpts{
   891  			AdditionalData: ad,
   892  		})
   893  		if !bytes.Equal(pt, pt2) {
   894  			b.Fail()
   895  		}
   896  	}
   897  }
   898  
   899  func saltOptions(options SigningOptions, saltLength int) SigningOptions {
   900  	return SigningOptions{
   901  		HashAlgorithm: options.HashAlgorithm,
   902  		Marshaling:    options.Marshaling,
   903  		SaltLength:    saltLength,
   904  		SigAlgorithm:  options.SigAlgorithm,
   905  	}
   906  }
   907  
   908  func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
   909  	tabs := strings.Repeat("\t", depth)
   910  	t.Log(tabs, "Manually verifying signature with options:", options)
   911  
   912  	tabs = strings.Repeat("\t", depth+1)
   913  	verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options)
   914  	if err != nil {
   915  		t.Fatal(tabs, "❌ Failed to manually verify signature:", err)
   916  	}
   917  	if !verified {
   918  		t.Fatal(tabs, "❌ Failed to manually verify signature")
   919  	}
   920  }
   921  
   922  func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
   923  	tabs := strings.Repeat("\t", depth)
   924  	t.Log(tabs, "Automatically verifying signature with options:", options)
   925  
   926  	tabs = strings.Repeat("\t", depth+1)
   927  	verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature)
   928  	if err != nil {
   929  		t.Fatal(tabs, "❌ Failed to automatically verify signature:", err)
   930  	}
   931  	if !verified {
   932  		t.Fatal(tabs, "❌ Failed to automatically verify signature")
   933  	}
   934  }
   935  
   936  func Test_RSA_PSS(t *testing.T) {
   937  	t.Log("Testing RSA PSS")
   938  	mathrand.Seed(time.Now().UnixNano())
   939  
   940  	var userError errutil.UserError
   941  	ctx := context.Background()
   942  	storage := &logical.InmemStorage{}
   943  	// https://crypto.stackexchange.com/a/1222
   944  	input := []byte("the ancients say the longer the salt, the more provable the security")
   945  	sigAlgorithm := "pss"
   946  
   947  	tabs := make(map[int]string)
   948  	for i := 1; i <= 6; i++ {
   949  		tabs[i] = strings.Repeat("\t", i)
   950  	}
   951  
   952  	test_RSA_PSS := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType,
   953  		marshalingType MarshalingType,
   954  	) {
   955  		unsaltedOptions := SigningOptions{
   956  			HashAlgorithm: hashType,
   957  			Marshaling:    marshalingType,
   958  			SigAlgorithm:  sigAlgorithm,
   959  		}
   960  		cryptoHash := CryptoHashMap[hashType]
   961  		minSaltLength := p.minRSAPSSSaltLength()
   962  		maxSaltLength := p.maxRSAPSSSaltLength(rsaKey.N.BitLen(), cryptoHash)
   963  		hash := cryptoHash.New()
   964  		hash.Write(input)
   965  		input = hash.Sum(nil)
   966  
   967  		// 1. Make an "automatic" signature with the given key size and hash algorithm,
   968  		// but an automatically chosen salt length.
   969  		t.Log(tabs[3], "Make an automatic signature")
   970  		sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType)
   971  		if err != nil {
   972  			// A bit of a hack but FIPS go does not support some hash types
   973  			if isUnsupportedGoHashType(hashType, err) {
   974  				t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type")
   975  				return
   976  			}
   977  			t.Fatal(tabs[4], "❌ Failed to automatically sign:", err)
   978  		}
   979  
   980  		// 1.1 Verify this automatic signature using the *inferred* salt length.
   981  		autoVerify(4, t, p, input, sig, unsaltedOptions)
   982  
   983  		// 1.2. Verify this automatic signature using the *correct, given* salt length.
   984  		manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength))
   985  
   986  		// 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths.
   987  		t.Log(tabs[4], "Test incorrect salt lengths")
   988  		incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1}
   989  		for _, saltLength := range incorrectSaltLengths {
   990  			t.Log(tabs[5], "Salt length:", saltLength)
   991  			saltedOptions := saltOptions(unsaltedOptions, saltLength)
   992  
   993  			verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
   994  			if verified {
   995  				t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err)
   996  			}
   997  		}
   998  
   999  		// 2. Rule out boundary, invalid salt lengths.
  1000  		t.Log(tabs[3], "Test invalid salt lengths")
  1001  		invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1}
  1002  		for _, saltLength := range invalidSaltLengths {
  1003  			t.Log(tabs[4], "Salt length:", saltLength)
  1004  			saltedOptions := saltOptions(unsaltedOptions, saltLength)
  1005  
  1006  			// 2.1. Fail to sign.
  1007  			t.Log(tabs[5], "Try to make a manual signature")
  1008  			_, err := p.SignWithOptions(0, nil, input, &saltedOptions)
  1009  			if !errors.As(err, &userError) {
  1010  				t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
  1011  			}
  1012  
  1013  			// 2.2. Fail to verify.
  1014  			t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length")
  1015  			_, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
  1016  			if !errors.As(err, &userError) {
  1017  				t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
  1018  			}
  1019  		}
  1020  
  1021  		// 3. For three possible valid salt lengths...
  1022  		t.Log(tabs[3], "Test three possible valid salt lengths")
  1023  		midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength)
  1024  		validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength}
  1025  		for _, saltLength := range validSaltLengths {
  1026  			t.Log(tabs[4], "Salt length:", saltLength)
  1027  			saltedOptions := saltOptions(unsaltedOptions, saltLength)
  1028  
  1029  			// 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length.
  1030  			t.Log(tabs[5], "Make a manual signature")
  1031  			sig, err := p.SignWithOptions(0, nil, input, &saltedOptions)
  1032  			if err != nil {
  1033  				t.Fatal(tabs[6], "❌ Failed to manually sign:", err)
  1034  			}
  1035  
  1036  			// 3.2. Verify this manual signature using the *correct, given* salt length.
  1037  			manualVerify(6, t, p, input, sig, saltedOptions)
  1038  
  1039  			// 3.3. Verify this manual signature using the *inferred* salt length.
  1040  			autoVerify(6, t, p, input, sig, unsaltedOptions)
  1041  		}
  1042  	}
  1043  
  1044  	rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096}
  1045  	testKeys, err := generateTestKeys()
  1046  	if err != nil {
  1047  		t.Fatalf("error generating test keys: %s", err)
  1048  	}
  1049  
  1050  	// 1. For each standard RSA key size 2048, 3072, and 4096...
  1051  	for _, rsaKeyType := range rsaKeyTypes {
  1052  		t.Log("Key size: ", rsaKeyType)
  1053  		p := &Policy{
  1054  			Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size
  1055  			Type: rsaKeyType,
  1056  		}
  1057  
  1058  		rsaKeyBytes := testKeys[rsaKeyType]
  1059  		err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader)
  1060  		if err != nil {
  1061  			t.Fatal(tabs[1], "❌ Failed to import key:", err)
  1062  		}
  1063  		rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes)
  1064  		if err != nil {
  1065  			t.Fatalf("error parsing test keys: %s", err)
  1066  		}
  1067  		rsaKey := rsaKeyAny.(*rsa.PrivateKey)
  1068  
  1069  		// 2. For each hash algorithm...
  1070  		for hashAlgorithm, hashType := range HashTypeMap {
  1071  			t.Log(tabs[1], "Hash algorithm:", hashAlgorithm)
  1072  			if hashAlgorithm == "none" {
  1073  				continue
  1074  			}
  1075  
  1076  			// 3. For each marshaling type...
  1077  			for marshalingName, marshalingType := range MarshalingTypeMap {
  1078  				t.Log(tabs[2], "Marshaling type:", marshalingName)
  1079  				testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName)
  1080  				t.Run(testName, func(t *testing.T) { test_RSA_PSS(t, p, rsaKey, hashType, marshalingType) })
  1081  			}
  1082  		}
  1083  	}
  1084  }
  1085  
  1086  func Test_RSA_PKCS1(t *testing.T) {
  1087  	t.Log("Testing RSA PKCS#1v1.5")
  1088  
  1089  	ctx := context.Background()
  1090  	storage := &logical.InmemStorage{}
  1091  	// https://crypto.stackexchange.com/a/1222
  1092  	input := []byte("Sphinx of black quartz, judge my vow")
  1093  	sigAlgorithm := "pkcs1v15"
  1094  
  1095  	tabs := make(map[int]string)
  1096  	for i := 1; i <= 6; i++ {
  1097  		tabs[i] = strings.Repeat("\t", i)
  1098  	}
  1099  
  1100  	test_RSA_PKCS1 := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType,
  1101  		marshalingType MarshalingType,
  1102  	) {
  1103  		unsaltedOptions := SigningOptions{
  1104  			HashAlgorithm: hashType,
  1105  			Marshaling:    marshalingType,
  1106  			SigAlgorithm:  sigAlgorithm,
  1107  		}
  1108  		cryptoHash := CryptoHashMap[hashType]
  1109  
  1110  		// PKCS#1v1.5 NoOID uses a direct input and assumes it is pre-hashed.
  1111  		if hashType != 0 {
  1112  			hash := cryptoHash.New()
  1113  			hash.Write(input)
  1114  			input = hash.Sum(nil)
  1115  		}
  1116  
  1117  		// 1. Make a signature with the given key size and hash algorithm.
  1118  		t.Log(tabs[3], "Make an automatic signature")
  1119  		sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType)
  1120  		if err != nil {
  1121  			// A bit of a hack but FIPS go does not support some hash types
  1122  			if isUnsupportedGoHashType(hashType, err) {
  1123  				t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type")
  1124  				return
  1125  			}
  1126  			t.Fatal(tabs[4], "❌ Failed to automatically sign:", err)
  1127  		}
  1128  
  1129  		// 1.1 Verify this signature using the *inferred* salt length.
  1130  		autoVerify(4, t, p, input, sig, unsaltedOptions)
  1131  	}
  1132  
  1133  	rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096}
  1134  	testKeys, err := generateTestKeys()
  1135  	if err != nil {
  1136  		t.Fatalf("error generating test keys: %s", err)
  1137  	}
  1138  
  1139  	// 1. For each standard RSA key size 2048, 3072, and 4096...
  1140  	for _, rsaKeyType := range rsaKeyTypes {
  1141  		t.Log("Key size: ", rsaKeyType)
  1142  		p := &Policy{
  1143  			Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size
  1144  			Type: rsaKeyType,
  1145  		}
  1146  
  1147  		rsaKeyBytes := testKeys[rsaKeyType]
  1148  		err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader)
  1149  		if err != nil {
  1150  			t.Fatal(tabs[1], "❌ Failed to import key:", err)
  1151  		}
  1152  		rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes)
  1153  		if err != nil {
  1154  			t.Fatalf("error parsing test keys: %s", err)
  1155  		}
  1156  		rsaKey := rsaKeyAny.(*rsa.PrivateKey)
  1157  
  1158  		// 2. For each hash algorithm...
  1159  		for hashAlgorithm, hashType := range HashTypeMap {
  1160  			t.Log(tabs[1], "Hash algorithm:", hashAlgorithm)
  1161  
  1162  			// 3. For each marshaling type...
  1163  			for marshalingName, marshalingType := range MarshalingTypeMap {
  1164  				t.Log(tabs[2], "Marshaling type:", marshalingName)
  1165  				testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName)
  1166  				t.Run(testName, func(t *testing.T) { test_RSA_PKCS1(t, p, rsaKey, hashType, marshalingType) })
  1167  			}
  1168  		}
  1169  	}
  1170  }
  1171  
  1172  // Normal Go builds support all the hash functions for RSA_PSS signatures but the
  1173  // FIPS Go build does not support at this time the SHA3 hashes as FIPS 140_2 does
  1174  // not accept them.
  1175  func isUnsupportedGoHashType(hashType HashType, err error) bool {
  1176  	switch hashType {
  1177  	case HashTypeSHA3224, HashTypeSHA3256, HashTypeSHA3384, HashTypeSHA3512:
  1178  		return strings.Contains(err.Error(), "unsupported hash function")
  1179  	}
  1180  
  1181  	return false
  1182  }