github.com/hashicorp/vault/sdk@v0.11.0/helper/keysutil/policy_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package keysutil
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/ecdsa"
    10  	"crypto/elliptic"
    11  	"crypto/rand"
    12  	"crypto/rsa"
    13  	"crypto/x509"
    14  	"errors"
    15  	"fmt"
    16  	mathrand "math/rand"
    17  	"reflect"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"golang.org/x/crypto/ed25519"
    25  
    26  	"github.com/hashicorp/vault/sdk/helper/errutil"
    27  	"github.com/hashicorp/vault/sdk/helper/jsonutil"
    28  	"github.com/hashicorp/vault/sdk/logical"
    29  	"github.com/mitchellh/copystructure"
    30  )
    31  
    32  func TestPolicy_KeyEntryMapUpgrade(t *testing.T) {
    33  	now := time.Now()
    34  	old := map[int]KeyEntry{
    35  		1: {
    36  			Key:                []byte("samplekey"),
    37  			HMACKey:            []byte("samplehmackey"),
    38  			CreationTime:       now,
    39  			FormattedPublicKey: "sampleformattedpublickey",
    40  		},
    41  		2: {
    42  			Key:                []byte("samplekey2"),
    43  			HMACKey:            []byte("samplehmackey2"),
    44  			CreationTime:       now.Add(10 * time.Second),
    45  			FormattedPublicKey: "sampleformattedpublickey2",
    46  		},
    47  	}
    48  
    49  	oldEncoded, err := jsonutil.EncodeJSON(old)
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  
    54  	var new keyEntryMap
    55  	err = jsonutil.DecodeJSON(oldEncoded, &new)
    56  	if err != nil {
    57  		t.Fatal(err)
    58  	}
    59  
    60  	newEncoded, err := jsonutil.EncodeJSON(&new)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  
    65  	if string(oldEncoded) != string(newEncoded) {
    66  		t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded))
    67  	}
    68  }
    69  
    70  func Test_KeyUpgrade(t *testing.T) {
    71  	lockManagerWithCache, _ := NewLockManager(true, 0)
    72  	lockManagerWithoutCache, _ := NewLockManager(false, 0)
    73  	testKeyUpgradeCommon(t, lockManagerWithCache)
    74  	testKeyUpgradeCommon(t, lockManagerWithoutCache)
    75  }
    76  
    77  func testKeyUpgradeCommon(t *testing.T, lm *LockManager) {
    78  	ctx := context.Background()
    79  
    80  	storage := &logical.InmemStorage{}
    81  	p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{
    82  		Upsert:  true,
    83  		Storage: storage,
    84  		KeyType: KeyType_AES256_GCM96,
    85  		Name:    "test",
    86  	}, rand.Reader)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	if p == nil {
    91  		t.Fatal("nil policy")
    92  	}
    93  	if !upserted {
    94  		t.Fatal("expected an upsert")
    95  	}
    96  	if !lm.useCache {
    97  		p.Unlock()
    98  	}
    99  
   100  	testBytes := make([]byte, len(p.Keys["1"].Key))
   101  	copy(testBytes, p.Keys["1"].Key)
   102  
   103  	p.Key = p.Keys["1"].Key
   104  	p.Keys = nil
   105  	p.MigrateKeyToKeysMap()
   106  	if p.Key != nil {
   107  		t.Fatal("policy.Key is not nil")
   108  	}
   109  	if len(p.Keys) != 1 {
   110  		t.Fatal("policy.Keys is the wrong size")
   111  	}
   112  	if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) {
   113  		t.Fatal("key mismatch")
   114  	}
   115  }
   116  
   117  func Test_ArchivingUpgrade(t *testing.T) {
   118  	lockManagerWithCache, _ := NewLockManager(true, 0)
   119  	lockManagerWithoutCache, _ := NewLockManager(false, 0)
   120  	testArchivingUpgradeCommon(t, lockManagerWithCache)
   121  	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
   122  }
   123  
   124  func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) {
   125  	ctx := context.Background()
   126  
   127  	// First, we generate a policy and rotate it a number of times. Each time
   128  	// we'll ensure that we have the expected number of keys in the archive and
   129  	// the main keys object, which without changing the min version should be
   130  	// zero and latest, respectively
   131  
   132  	storage := &logical.InmemStorage{}
   133  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   134  		Upsert:  true,
   135  		Storage: storage,
   136  		KeyType: KeyType_AES256_GCM96,
   137  		Name:    "test",
   138  	}, rand.Reader)
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  	if p == nil {
   143  		t.Fatal("nil policy")
   144  	}
   145  	if !lm.useCache {
   146  		p.Unlock()
   147  	}
   148  
   149  	// Store the initial key in the archive
   150  	keysArchive := []KeyEntry{{}, p.Keys["1"]}
   151  	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
   152  
   153  	for i := 2; i <= 10; i++ {
   154  		err = p.Rotate(ctx, storage, rand.Reader)
   155  		if err != nil {
   156  			t.Fatal(err)
   157  		}
   158  		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
   159  		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
   160  	}
   161  
   162  	// Now, wipe the archive and set the archive version to zero
   163  	err = storage.Delete(ctx, "archive/test")
   164  	if err != nil {
   165  		t.Fatal(err)
   166  	}
   167  	p.ArchiveVersion = 0
   168  
   169  	// Store it, but without calling persist, so we don't trigger
   170  	// handleArchiving()
   171  	buf, err := p.Serialize()
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  
   176  	// Write the policy into storage
   177  	err = storage.Put(ctx, &logical.StorageEntry{
   178  		Key:   "policy/" + p.Name,
   179  		Value: buf,
   180  	})
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  
   185  	// If we're caching, expire from the cache since we modified it
   186  	// under-the-hood
   187  	if lm.useCache {
   188  		lm.cache.Delete("test")
   189  	}
   190  
   191  	// Now get the policy again; the upgrade should happen automatically
   192  	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
   193  		Storage: storage,
   194  		Name:    "test",
   195  	}, rand.Reader)
   196  	if err != nil {
   197  		t.Fatal(err)
   198  	}
   199  	if p == nil {
   200  		t.Fatal("nil policy")
   201  	}
   202  	if !lm.useCache {
   203  		p.Unlock()
   204  	}
   205  
   206  	checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10)
   207  
   208  	// Let's check some deletion logic while we're at it
   209  
   210  	// The policy should be in there
   211  	if lm.useCache {
   212  		_, ok := lm.cache.Load("test")
   213  		if !ok {
   214  			t.Fatal("nil policy in cache")
   215  		}
   216  	}
   217  
   218  	// First we'll do this wrong, by not setting the deletion flag
   219  	err = lm.DeletePolicy(ctx, storage, "test")
   220  	if err == nil {
   221  		t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy")
   222  	}
   223  
   224  	// The policy should still be in there
   225  	if lm.useCache {
   226  		_, ok := lm.cache.Load("test")
   227  		if !ok {
   228  			t.Fatal("nil policy in cache")
   229  		}
   230  	}
   231  
   232  	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
   233  		Storage: storage,
   234  		Name:    "test",
   235  	}, rand.Reader)
   236  	if err != nil {
   237  		t.Fatal(err)
   238  	}
   239  	if p == nil {
   240  		t.Fatal("policy nil after bad delete")
   241  	}
   242  	if !lm.useCache {
   243  		p.Unlock()
   244  	}
   245  
   246  	// Now do it properly
   247  	p.DeletionAllowed = true
   248  	err = p.Persist(ctx, storage)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	err = lm.DeletePolicy(ctx, storage, "test")
   253  	if err != nil {
   254  		t.Fatal(err)
   255  	}
   256  
   257  	// The policy should *not* be in there
   258  	if lm.useCache {
   259  		_, ok := lm.cache.Load("test")
   260  		if ok {
   261  			t.Fatal("non-nil policy in cache")
   262  		}
   263  	}
   264  
   265  	p, _, err = lm.GetPolicy(ctx, PolicyRequest{
   266  		Storage: storage,
   267  		Name:    "test",
   268  	}, rand.Reader)
   269  	if err != nil {
   270  		t.Fatal(err)
   271  	}
   272  	if p != nil {
   273  		t.Fatal("policy not nil after delete")
   274  	}
   275  }
   276  
   277  func Test_Archiving(t *testing.T) {
   278  	lockManagerWithCache, _ := NewLockManager(true, 0)
   279  	lockManagerWithoutCache, _ := NewLockManager(false, 0)
   280  	testArchivingUpgradeCommon(t, lockManagerWithCache)
   281  	testArchivingUpgradeCommon(t, lockManagerWithoutCache)
   282  }
   283  
   284  func testArchivingCommon(t *testing.T, lm *LockManager) {
   285  	ctx := context.Background()
   286  
   287  	// First, we generate a policy and rotate it a number of times. Each time
   288  	// we'll ensure that we have the expected number of keys in the archive and
   289  	// the main keys object, which without changing the min version should be
   290  	// zero and latest, respectively
   291  
   292  	storage := &logical.InmemStorage{}
   293  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   294  		Upsert:  true,
   295  		Storage: storage,
   296  		KeyType: KeyType_AES256_GCM96,
   297  		Name:    "test",
   298  	}, rand.Reader)
   299  	if err != nil {
   300  		t.Fatal(err)
   301  	}
   302  	if p == nil {
   303  		t.Fatal("nil policy")
   304  	}
   305  	if !lm.useCache {
   306  		p.Unlock()
   307  	}
   308  
   309  	// Store the initial key in the archive
   310  	keysArchive := []KeyEntry{{}, p.Keys["1"]}
   311  	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
   312  
   313  	for i := 2; i <= 10; i++ {
   314  		err = p.Rotate(ctx, storage, rand.Reader)
   315  		if err != nil {
   316  			t.Fatal(err)
   317  		}
   318  		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
   319  		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
   320  	}
   321  
   322  	// Move the min decryption version up
   323  	for i := 1; i <= 10; i++ {
   324  		p.MinDecryptionVersion = i
   325  
   326  		err = p.Persist(ctx, storage)
   327  		if err != nil {
   328  			t.Fatal(err)
   329  		}
   330  		// We expect to find:
   331  		// * The keys in archive are the same as the latest version
   332  		// * The latest version is constant
   333  		// * The number of keys in the policy itself is from the min
   334  		// decryption version up to the latest version, so for e.g. 7 and
   335  		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
   336  		// decryption version plus 1 (the min decryption version key
   337  		// itself)
   338  		checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
   339  	}
   340  
   341  	// Move the min decryption version down
   342  	for i := 10; i >= 1; i-- {
   343  		p.MinDecryptionVersion = i
   344  
   345  		err = p.Persist(ctx, storage)
   346  		if err != nil {
   347  			t.Fatal(err)
   348  		}
   349  		// We expect to find:
   350  		// * The keys in archive are never removed so same as the latest version
   351  		// * The latest version is constant
   352  		// * The number of keys in the policy itself is from the min
   353  		// decryption version up to the latest version, so for e.g. 7 and
   354  		// 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min
   355  		// decryption version plus 1 (the min decryption version key
   356  		// itself)
   357  		checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1)
   358  	}
   359  }
   360  
   361  func checkKeys(t *testing.T,
   362  	ctx context.Context,
   363  	p *Policy,
   364  	storage logical.Storage,
   365  	keysArchive []KeyEntry,
   366  	action string,
   367  	archiveVer, latestVer, keysSize int,
   368  ) {
   369  	// Sanity check
   370  	if len(keysArchive) != latestVer+1 {
   371  		t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+
   372  			"but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive))
   373  	}
   374  
   375  	archive, err := p.LoadArchive(ctx, storage)
   376  	if err != nil {
   377  		t.Fatal(err)
   378  	}
   379  
   380  	badArchiveVer := false
   381  	if archiveVer == 0 {
   382  		if len(archive.Keys) != 0 || p.ArchiveVersion != 0 {
   383  			badArchiveVer = true
   384  		}
   385  	} else {
   386  		// We need to subtract one because we have the indexes match key
   387  		// versions, which start at 1. So for an archive version of 1, we
   388  		// actually have two entries -- a blank 0 entry, and the key at spot 1
   389  		if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion {
   390  			badArchiveVer = true
   391  		}
   392  	}
   393  	if badArchiveVer {
   394  		t.Fatalf(
   395  			"expected archive version %d, found length of archive keys %d and policy archive version %d",
   396  			archiveVer, len(archive.Keys), p.ArchiveVersion,
   397  		)
   398  	}
   399  
   400  	if latestVer != p.LatestVersion {
   401  		t.Fatalf(
   402  			"expected latest version %d, found %d",
   403  			latestVer, p.LatestVersion,
   404  		)
   405  	}
   406  
   407  	if keysSize != len(p.Keys) {
   408  		t.Fatalf(
   409  			"expected keys size %d, found %d, action is %s, policy is \n%#v\n",
   410  			keysSize, len(p.Keys), action, p,
   411  		)
   412  	}
   413  
   414  	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
   415  		if _, ok := p.Keys[strconv.Itoa(i)]; !ok {
   416  			t.Fatalf(
   417  				"expected key %d, did not find it in policy keys", i,
   418  			)
   419  		}
   420  	}
   421  
   422  	for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
   423  		ver := strconv.Itoa(i)
   424  		if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) {
   425  			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
   426  		}
   427  		polKey := p.Keys[ver]
   428  		polKey.CreationTime = keysArchive[i].CreationTime
   429  		p.Keys[ver] = polKey
   430  		if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) {
   431  			t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i])
   432  		}
   433  	}
   434  
   435  	for i := 1; i < len(archive.Keys); i++ {
   436  		if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) {
   437  			t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key)
   438  		}
   439  	}
   440  }
   441  
   442  func Test_StorageErrorSafety(t *testing.T) {
   443  	ctx := context.Background()
   444  	lm, _ := NewLockManager(true, 0)
   445  
   446  	storage := &logical.InmemStorage{}
   447  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   448  		Upsert:  true,
   449  		Storage: storage,
   450  		KeyType: KeyType_AES256_GCM96,
   451  		Name:    "test",
   452  	}, rand.Reader)
   453  	if err != nil {
   454  		t.Fatal(err)
   455  	}
   456  	if p == nil {
   457  		t.Fatal("nil policy")
   458  	}
   459  
   460  	// Store the initial key in the archive
   461  	keysArchive := []KeyEntry{{}, p.Keys["1"]}
   462  	checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1)
   463  
   464  	// We use checkKeys here just for sanity; it doesn't really handle cases of
   465  	// errors below so we do more targeted testing later
   466  	for i := 2; i <= 5; i++ {
   467  		err = p.Rotate(ctx, storage, rand.Reader)
   468  		if err != nil {
   469  			t.Fatal(err)
   470  		}
   471  		keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)])
   472  		checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i)
   473  	}
   474  
   475  	underlying := storage.Underlying()
   476  	underlying.FailPut(true)
   477  
   478  	priorLen := len(p.Keys)
   479  
   480  	err = p.Rotate(ctx, storage, rand.Reader)
   481  	if err == nil {
   482  		t.Fatal("expected error")
   483  	}
   484  
   485  	if len(p.Keys) != priorLen {
   486  		t.Fatal("length of keys should not have changed")
   487  	}
   488  }
   489  
   490  func Test_BadUpgrade(t *testing.T) {
   491  	ctx := context.Background()
   492  	lm, _ := NewLockManager(true, 0)
   493  	storage := &logical.InmemStorage{}
   494  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   495  		Upsert:  true,
   496  		Storage: storage,
   497  		KeyType: KeyType_AES256_GCM96,
   498  		Name:    "test",
   499  	}, rand.Reader)
   500  	if err != nil {
   501  		t.Fatal(err)
   502  	}
   503  	if p == nil {
   504  		t.Fatal("nil policy")
   505  	}
   506  
   507  	orig, err := copystructure.Copy(p)
   508  	if err != nil {
   509  		t.Fatal(err)
   510  	}
   511  	orig.(*Policy).l = p.l
   512  
   513  	p.Key = p.Keys["1"].Key
   514  	p.Keys = nil
   515  	p.MinDecryptionVersion = 0
   516  
   517  	if err := p.Upgrade(ctx, storage, rand.Reader); err != nil {
   518  		t.Fatal(err)
   519  	}
   520  
   521  	k := p.Keys["1"]
   522  	o := orig.(*Policy).Keys["1"]
   523  	k.CreationTime = o.CreationTime
   524  	k.HMACKey = o.HMACKey
   525  	p.Keys["1"] = k
   526  	p.versionPrefixCache = sync.Map{}
   527  
   528  	if !reflect.DeepEqual(orig, p) {
   529  		t.Fatalf("not equal:\n%#v\n%#v", orig, p)
   530  	}
   531  
   532  	// Do it again with a failing storage call
   533  	underlying := storage.Underlying()
   534  	underlying.FailPut(true)
   535  
   536  	p.Key = p.Keys["1"].Key
   537  	p.Keys = nil
   538  	p.MinDecryptionVersion = 0
   539  
   540  	if err := p.Upgrade(ctx, storage, rand.Reader); err == nil {
   541  		t.Fatal("expected error")
   542  	}
   543  
   544  	if p.MinDecryptionVersion == 1 {
   545  		t.Fatal("min decryption version was changed")
   546  	}
   547  	if p.Keys != nil {
   548  		t.Fatal("found upgraded keys")
   549  	}
   550  	if p.Key == nil {
   551  		t.Fatal("non-upgraded key not found")
   552  	}
   553  }
   554  
   555  func Test_BadArchive(t *testing.T) {
   556  	ctx := context.Background()
   557  	lm, _ := NewLockManager(true, 0)
   558  	storage := &logical.InmemStorage{}
   559  	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
   560  		Upsert:  true,
   561  		Storage: storage,
   562  		KeyType: KeyType_AES256_GCM96,
   563  		Name:    "test",
   564  	}, rand.Reader)
   565  	if err != nil {
   566  		t.Fatal(err)
   567  	}
   568  	if p == nil {
   569  		t.Fatal("nil policy")
   570  	}
   571  
   572  	for i := 2; i <= 10; i++ {
   573  		err = p.Rotate(ctx, storage, rand.Reader)
   574  		if err != nil {
   575  			t.Fatal(err)
   576  		}
   577  	}
   578  
   579  	p.MinDecryptionVersion = 5
   580  	if err := p.Persist(ctx, storage); err != nil {
   581  		t.Fatal(err)
   582  	}
   583  	if p.ArchiveVersion != 10 {
   584  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   585  	}
   586  	if len(p.Keys) != 6 {
   587  		t.Fatalf("unexpected key length %d", len(p.Keys))
   588  	}
   589  
   590  	// Set back
   591  	p.MinDecryptionVersion = 1
   592  	if err := p.Persist(ctx, storage); err != nil {
   593  		t.Fatal(err)
   594  	}
   595  	if p.ArchiveVersion != 10 {
   596  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   597  	}
   598  	if len(p.Keys) != 10 {
   599  		t.Fatalf("unexpected key length %d", len(p.Keys))
   600  	}
   601  
   602  	// Run it again but we'll turn off storage along the way
   603  	p.MinDecryptionVersion = 5
   604  	if err := p.Persist(ctx, storage); err != nil {
   605  		t.Fatal(err)
   606  	}
   607  	if p.ArchiveVersion != 10 {
   608  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   609  	}
   610  	if len(p.Keys) != 6 {
   611  		t.Fatalf("unexpected key length %d", len(p.Keys))
   612  	}
   613  
   614  	underlying := storage.Underlying()
   615  	underlying.FailPut(true)
   616  
   617  	// Set back, which should cause p.Keys to be changed if the persist works,
   618  	// but it doesn't
   619  	p.MinDecryptionVersion = 1
   620  	if err := p.Persist(ctx, storage); err == nil {
   621  		t.Fatal("expected error during put")
   622  	}
   623  	if p.ArchiveVersion != 10 {
   624  		t.Fatalf("unexpected archive version %d", p.ArchiveVersion)
   625  	}
   626  	// Here's the expected change
   627  	if len(p.Keys) != 6 {
   628  		t.Fatalf("unexpected key length %d", len(p.Keys))
   629  	}
   630  }
   631  
   632  func Test_Import(t *testing.T) {
   633  	ctx := context.Background()
   634  	storage := &logical.InmemStorage{}
   635  	testKeys, err := generateTestKeys()
   636  	if err != nil {
   637  		t.Fatalf("error generating test keys: %s", err)
   638  	}
   639  
   640  	tests := map[string]struct {
   641  		policy      Policy
   642  		key         []byte
   643  		shouldError bool
   644  	}{
   645  		"import AES key": {
   646  			policy: Policy{
   647  				Name: "test-aes-key",
   648  				Type: KeyType_AES256_GCM96,
   649  			},
   650  			key:         testKeys[KeyType_AES256_GCM96],
   651  			shouldError: false,
   652  		},
   653  		"import RSA key": {
   654  			policy: Policy{
   655  				Name: "test-rsa-key",
   656  				Type: KeyType_RSA2048,
   657  			},
   658  			key:         testKeys[KeyType_RSA2048],
   659  			shouldError: false,
   660  		},
   661  		"import ECDSA key": {
   662  			policy: Policy{
   663  				Name: "test-ecdsa-key",
   664  				Type: KeyType_ECDSA_P256,
   665  			},
   666  			key:         testKeys[KeyType_ECDSA_P256],
   667  			shouldError: false,
   668  		},
   669  		"import ED25519 key": {
   670  			policy: Policy{
   671  				Name: "test-ed25519-key",
   672  				Type: KeyType_ED25519,
   673  			},
   674  			key:         testKeys[KeyType_ED25519],
   675  			shouldError: false,
   676  		},
   677  		"import incorrect key type": {
   678  			policy: Policy{
   679  				Name: "test-ed25519-key",
   680  				Type: KeyType_ED25519,
   681  			},
   682  			key:         testKeys[KeyType_AES256_GCM96],
   683  			shouldError: true,
   684  		},
   685  	}
   686  
   687  	for name, test := range tests {
   688  		t.Run(name, func(t *testing.T) {
   689  			if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError {
   690  				t.Fatalf("error importing key: %s", err)
   691  			}
   692  		})
   693  	}
   694  }
   695  
   696  func generateTestKeys() (map[KeyType][]byte, error) {
   697  	keyMap := make(map[KeyType][]byte)
   698  
   699  	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
   700  	if err != nil {
   701  		return nil, err
   702  	}
   703  	rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey)
   704  	if err != nil {
   705  		return nil, err
   706  	}
   707  	keyMap[KeyType_RSA2048] = rsaKeyBytes
   708  
   709  	rsaKey, err = rsa.GenerateKey(rand.Reader, 3072)
   710  	if err != nil {
   711  		return nil, err
   712  	}
   713  	rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
   714  	if err != nil {
   715  		return nil, err
   716  	}
   717  	keyMap[KeyType_RSA3072] = rsaKeyBytes
   718  
   719  	rsaKey, err = rsa.GenerateKey(rand.Reader, 4096)
   720  	if err != nil {
   721  		return nil, err
   722  	}
   723  	rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey)
   724  	if err != nil {
   725  		return nil, err
   726  	}
   727  	keyMap[KeyType_RSA4096] = rsaKeyBytes
   728  
   729  	ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   730  	if err != nil {
   731  		return nil, err
   732  	}
   733  	ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey)
   734  	if err != nil {
   735  		return nil, err
   736  	}
   737  	keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes
   738  
   739  	_, ed25519Key, err := ed25519.GenerateKey(rand.Reader)
   740  	if err != nil {
   741  		return nil, err
   742  	}
   743  	ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key)
   744  	if err != nil {
   745  		return nil, err
   746  	}
   747  	keyMap[KeyType_ED25519] = ed25519KeyBytes
   748  
   749  	aesKey := make([]byte, 32)
   750  	_, err = rand.Read(aesKey)
   751  	if err != nil {
   752  		return nil, err
   753  	}
   754  	keyMap[KeyType_AES256_GCM96] = aesKey
   755  
   756  	return keyMap, nil
   757  }
   758  
   759  func BenchmarkSymmetric(b *testing.B) {
   760  	ctx := context.Background()
   761  	lm, _ := NewLockManager(true, 0)
   762  	storage := &logical.InmemStorage{}
   763  	p, _, _ := lm.GetPolicy(ctx, PolicyRequest{
   764  		Upsert:  true,
   765  		Storage: storage,
   766  		KeyType: KeyType_AES256_GCM96,
   767  		Name:    "test",
   768  	}, rand.Reader)
   769  	key, _ := p.GetKey(nil, 1, 32)
   770  	pt := make([]byte, 10)
   771  	ad := make([]byte, 10)
   772  	for i := 0; i < b.N; i++ {
   773  		ct, _ := p.SymmetricEncryptRaw(1, key, pt,
   774  			SymmetricOpts{
   775  				AdditionalData: ad,
   776  			})
   777  		pt2, _ := p.SymmetricDecryptRaw(key, ct, SymmetricOpts{
   778  			AdditionalData: ad,
   779  		})
   780  		if !bytes.Equal(pt, pt2) {
   781  			b.Fail()
   782  		}
   783  	}
   784  }
   785  
   786  func saltOptions(options SigningOptions, saltLength int) SigningOptions {
   787  	return SigningOptions{
   788  		HashAlgorithm: options.HashAlgorithm,
   789  		Marshaling:    options.Marshaling,
   790  		SaltLength:    saltLength,
   791  		SigAlgorithm:  options.SigAlgorithm,
   792  	}
   793  }
   794  
   795  func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
   796  	tabs := strings.Repeat("\t", depth)
   797  	t.Log(tabs, "Manually verifying signature with options:", options)
   798  
   799  	tabs = strings.Repeat("\t", depth+1)
   800  	verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options)
   801  	if err != nil {
   802  		t.Fatal(tabs, "❌ Failed to manually verify signature:", err)
   803  	}
   804  	if !verified {
   805  		t.Fatal(tabs, "❌ Failed to manually verify signature")
   806  	}
   807  }
   808  
   809  func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) {
   810  	tabs := strings.Repeat("\t", depth)
   811  	t.Log(tabs, "Automatically verifying signature with options:", options)
   812  
   813  	tabs = strings.Repeat("\t", depth+1)
   814  	verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature)
   815  	if err != nil {
   816  		t.Fatal(tabs, "❌ Failed to automatically verify signature:", err)
   817  	}
   818  	if !verified {
   819  		t.Fatal(tabs, "❌ Failed to automatically verify signature")
   820  	}
   821  }
   822  
   823  func Test_RSA_PSS(t *testing.T) {
   824  	t.Log("Testing RSA PSS")
   825  	mathrand.Seed(time.Now().UnixNano())
   826  
   827  	var userError errutil.UserError
   828  	ctx := context.Background()
   829  	storage := &logical.InmemStorage{}
   830  	// https://crypto.stackexchange.com/a/1222
   831  	input := []byte("the ancients say the longer the salt, the more provable the security")
   832  	sigAlgorithm := "pss"
   833  
   834  	tabs := make(map[int]string)
   835  	for i := 1; i <= 6; i++ {
   836  		tabs[i] = strings.Repeat("\t", i)
   837  	}
   838  
   839  	test_RSA_PSS := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType,
   840  		marshalingType MarshalingType,
   841  	) {
   842  		unsaltedOptions := SigningOptions{
   843  			HashAlgorithm: hashType,
   844  			Marshaling:    marshalingType,
   845  			SigAlgorithm:  sigAlgorithm,
   846  		}
   847  		cryptoHash := CryptoHashMap[hashType]
   848  		minSaltLength := p.minRSAPSSSaltLength()
   849  		maxSaltLength := p.maxRSAPSSSaltLength(rsaKey.N.BitLen(), cryptoHash)
   850  		hash := cryptoHash.New()
   851  		hash.Write(input)
   852  		input = hash.Sum(nil)
   853  
   854  		// 1. Make an "automatic" signature with the given key size and hash algorithm,
   855  		// but an automatically chosen salt length.
   856  		t.Log(tabs[3], "Make an automatic signature")
   857  		sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType)
   858  		if err != nil {
   859  			// A bit of a hack but FIPS go does not support some hash types
   860  			if isUnsupportedGoHashType(hashType, err) {
   861  				t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type")
   862  				return
   863  			}
   864  			t.Fatal(tabs[4], "❌ Failed to automatically sign:", err)
   865  		}
   866  
   867  		// 1.1 Verify this automatic signature using the *inferred* salt length.
   868  		autoVerify(4, t, p, input, sig, unsaltedOptions)
   869  
   870  		// 1.2. Verify this automatic signature using the *correct, given* salt length.
   871  		manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength))
   872  
   873  		// 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths.
   874  		t.Log(tabs[4], "Test incorrect salt lengths")
   875  		incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1}
   876  		for _, saltLength := range incorrectSaltLengths {
   877  			t.Log(tabs[5], "Salt length:", saltLength)
   878  			saltedOptions := saltOptions(unsaltedOptions, saltLength)
   879  
   880  			verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
   881  			if verified {
   882  				t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err)
   883  			}
   884  		}
   885  
   886  		// 2. Rule out boundary, invalid salt lengths.
   887  		t.Log(tabs[3], "Test invalid salt lengths")
   888  		invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1}
   889  		for _, saltLength := range invalidSaltLengths {
   890  			t.Log(tabs[4], "Salt length:", saltLength)
   891  			saltedOptions := saltOptions(unsaltedOptions, saltLength)
   892  
   893  			// 2.1. Fail to sign.
   894  			t.Log(tabs[5], "Try to make a manual signature")
   895  			_, err := p.SignWithOptions(0, nil, input, &saltedOptions)
   896  			if !errors.As(err, &userError) {
   897  				t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
   898  			}
   899  
   900  			// 2.2. Fail to verify.
   901  			t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length")
   902  			_, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions)
   903  			if !errors.As(err, &userError) {
   904  				t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err)
   905  			}
   906  		}
   907  
   908  		// 3. For three possible valid salt lengths...
   909  		t.Log(tabs[3], "Test three possible valid salt lengths")
   910  		midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength)
   911  		validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength}
   912  		for _, saltLength := range validSaltLengths {
   913  			t.Log(tabs[4], "Salt length:", saltLength)
   914  			saltedOptions := saltOptions(unsaltedOptions, saltLength)
   915  
   916  			// 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length.
   917  			t.Log(tabs[5], "Make a manual signature")
   918  			sig, err := p.SignWithOptions(0, nil, input, &saltedOptions)
   919  			if err != nil {
   920  				t.Fatal(tabs[6], "❌ Failed to manually sign:", err)
   921  			}
   922  
   923  			// 3.2. Verify this manual signature using the *correct, given* salt length.
   924  			manualVerify(6, t, p, input, sig, saltedOptions)
   925  
   926  			// 3.3. Verify this manual signature using the *inferred* salt length.
   927  			autoVerify(6, t, p, input, sig, unsaltedOptions)
   928  		}
   929  	}
   930  
   931  	rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096}
   932  	testKeys, err := generateTestKeys()
   933  	if err != nil {
   934  		t.Fatalf("error generating test keys: %s", err)
   935  	}
   936  
   937  	// 1. For each standard RSA key size 2048, 3072, and 4096...
   938  	for _, rsaKeyType := range rsaKeyTypes {
   939  		t.Log("Key size: ", rsaKeyType)
   940  		p := &Policy{
   941  			Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size
   942  			Type: rsaKeyType,
   943  		}
   944  
   945  		rsaKeyBytes := testKeys[rsaKeyType]
   946  		err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader)
   947  		if err != nil {
   948  			t.Fatal(tabs[1], "❌ Failed to import key:", err)
   949  		}
   950  		rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes)
   951  		if err != nil {
   952  			t.Fatalf("error parsing test keys: %s", err)
   953  		}
   954  		rsaKey := rsaKeyAny.(*rsa.PrivateKey)
   955  
   956  		// 2. For each hash algorithm...
   957  		for hashAlgorithm, hashType := range HashTypeMap {
   958  			t.Log(tabs[1], "Hash algorithm:", hashAlgorithm)
   959  			if hashAlgorithm == "none" {
   960  				continue
   961  			}
   962  
   963  			// 3. For each marshaling type...
   964  			for marshalingName, marshalingType := range MarshalingTypeMap {
   965  				t.Log(tabs[2], "Marshaling type:", marshalingName)
   966  				testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName)
   967  				t.Run(testName, func(t *testing.T) { test_RSA_PSS(t, p, rsaKey, hashType, marshalingType) })
   968  			}
   969  		}
   970  	}
   971  }
   972  
   973  func Test_RSA_PKCS1(t *testing.T) {
   974  	t.Log("Testing RSA PKCS#1v1.5")
   975  
   976  	ctx := context.Background()
   977  	storage := &logical.InmemStorage{}
   978  	// https://crypto.stackexchange.com/a/1222
   979  	input := []byte("Sphinx of black quartz, judge my vow")
   980  	sigAlgorithm := "pkcs1v15"
   981  
   982  	tabs := make(map[int]string)
   983  	for i := 1; i <= 6; i++ {
   984  		tabs[i] = strings.Repeat("\t", i)
   985  	}
   986  
   987  	test_RSA_PKCS1 := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType,
   988  		marshalingType MarshalingType,
   989  	) {
   990  		unsaltedOptions := SigningOptions{
   991  			HashAlgorithm: hashType,
   992  			Marshaling:    marshalingType,
   993  			SigAlgorithm:  sigAlgorithm,
   994  		}
   995  		cryptoHash := CryptoHashMap[hashType]
   996  
   997  		// PKCS#1v1.5 NoOID uses a direct input and assumes it is pre-hashed.
   998  		if hashType != 0 {
   999  			hash := cryptoHash.New()
  1000  			hash.Write(input)
  1001  			input = hash.Sum(nil)
  1002  		}
  1003  
  1004  		// 1. Make a signature with the given key size and hash algorithm.
  1005  		t.Log(tabs[3], "Make an automatic signature")
  1006  		sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType)
  1007  		if err != nil {
  1008  			// A bit of a hack but FIPS go does not support some hash types
  1009  			if isUnsupportedGoHashType(hashType, err) {
  1010  				t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type")
  1011  				return
  1012  			}
  1013  			t.Fatal(tabs[4], "❌ Failed to automatically sign:", err)
  1014  		}
  1015  
  1016  		// 1.1 Verify this signature using the *inferred* salt length.
  1017  		autoVerify(4, t, p, input, sig, unsaltedOptions)
  1018  	}
  1019  
  1020  	rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096}
  1021  	testKeys, err := generateTestKeys()
  1022  	if err != nil {
  1023  		t.Fatalf("error generating test keys: %s", err)
  1024  	}
  1025  
  1026  	// 1. For each standard RSA key size 2048, 3072, and 4096...
  1027  	for _, rsaKeyType := range rsaKeyTypes {
  1028  		t.Log("Key size: ", rsaKeyType)
  1029  		p := &Policy{
  1030  			Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size
  1031  			Type: rsaKeyType,
  1032  		}
  1033  
  1034  		rsaKeyBytes := testKeys[rsaKeyType]
  1035  		err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader)
  1036  		if err != nil {
  1037  			t.Fatal(tabs[1], "❌ Failed to import key:", err)
  1038  		}
  1039  		rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes)
  1040  		if err != nil {
  1041  			t.Fatalf("error parsing test keys: %s", err)
  1042  		}
  1043  		rsaKey := rsaKeyAny.(*rsa.PrivateKey)
  1044  
  1045  		// 2. For each hash algorithm...
  1046  		for hashAlgorithm, hashType := range HashTypeMap {
  1047  			t.Log(tabs[1], "Hash algorithm:", hashAlgorithm)
  1048  
  1049  			// 3. For each marshaling type...
  1050  			for marshalingName, marshalingType := range MarshalingTypeMap {
  1051  				t.Log(tabs[2], "Marshaling type:", marshalingName)
  1052  				testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName)
  1053  				t.Run(testName, func(t *testing.T) { test_RSA_PKCS1(t, p, rsaKey, hashType, marshalingType) })
  1054  			}
  1055  		}
  1056  	}
  1057  }
  1058  
  1059  // Normal Go builds support all the hash functions for RSA_PSS signatures but the
  1060  // FIPS Go build does not support at this time the SHA3 hashes as FIPS 140_2 does
  1061  // not accept them.
  1062  func isUnsupportedGoHashType(hashType HashType, err error) bool {
  1063  	switch hashType {
  1064  	case HashTypeSHA3224, HashTypeSHA3256, HashTypeSHA3384, HashTypeSHA3512:
  1065  		return strings.Contains(err.Error(), "unsupported hash function")
  1066  	}
  1067  
  1068  	return false
  1069  }