github.com/hashicorp/vault/sdk@v0.13.0/helper/keysutil/policy_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package keysutil 5 6 import ( 7 "bytes" 8 "context" 9 "crypto/ecdsa" 10 "crypto/elliptic" 11 "crypto/rand" 12 "crypto/rsa" 13 "crypto/x509" 14 "errors" 15 "fmt" 16 mathrand "math/rand" 17 "reflect" 18 "strconv" 19 "strings" 20 "sync" 21 "testing" 22 "time" 23 24 "github.com/hashicorp/vault/sdk/helper/errutil" 25 "github.com/hashicorp/vault/sdk/helper/jsonutil" 26 "github.com/hashicorp/vault/sdk/logical" 27 "github.com/mitchellh/copystructure" 28 "golang.org/x/crypto/ed25519" 29 ) 30 31 // Ordering of these items needs to match the iota order defined in policy.go. Ordering changes 32 // should never occur, as it would lead to a key type change within existing stored policies. 33 var allTestKeyTypes = []KeyType{ 34 KeyType_AES256_GCM96, KeyType_ECDSA_P256, KeyType_ED25519, KeyType_RSA2048, 35 KeyType_RSA4096, KeyType_ChaCha20_Poly1305, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_AES128_GCM96, 36 KeyType_RSA3072, KeyType_MANAGED_KEY, KeyType_HMAC, KeyType_AES128_CMAC, KeyType_AES256_CMAC, 37 } 38 39 func TestPolicy_KeyTypes(t *testing.T) { 40 // Make sure the iota value never change for key types, as existing storage would be affected 41 for i, keyType := range allTestKeyTypes { 42 if int(keyType) != i { 43 t.Fatalf("iota of keytype %s changed, expected %d got %d", keyType.String(), i, keyType) 44 } 45 } 46 47 // Make sure we have a string presentation for all types 48 for _, keyType := range allTestKeyTypes { 49 if strings.Contains(keyType.String(), "unknown") { 50 t.Fatalf("keytype with iota of %d should not contain 'unknown', missing in String() switch statement", keyType) 51 } 52 } 53 } 54 55 func TestPolicy_HmacCmacSuported(t *testing.T) { 56 // Test HMAC supported feature 57 for _, keyType := range allTestKeyTypes { 58 switch keyType { 59 case KeyType_MANAGED_KEY: 60 if keyType.HMACSupported() { 61 t.Fatalf("hmac should not have been not be supported for keytype %s", keyType.String()) 62 } 63 if keyType.CMACSupported() { 64 t.Fatalf("cmac should not have been be supported for keytype %s", keyType.String()) 65 } 66 case KeyType_AES128_CMAC, KeyType_AES256_CMAC: 67 if keyType.HMACSupported() { 68 t.Fatalf("hmac should have been not be supported for keytype %s", keyType.String()) 69 } 70 if !keyType.CMACSupported() { 71 t.Fatalf("cmac should have been be supported for keytype %s", keyType.String()) 72 } 73 default: 74 if !keyType.HMACSupported() { 75 t.Fatalf("hmac should have been supported for keytype %s", keyType.String()) 76 } 77 if keyType.CMACSupported() { 78 t.Fatalf("cmac should not have been supported for keytype %s", keyType.String()) 79 } 80 } 81 } 82 } 83 84 func TestPolicy_CMACKeyUpgrade(t *testing.T) { 85 ctx := context.Background() 86 lm, _ := NewLockManager(false, 0) 87 storage := &logical.InmemStorage{} 88 p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{ 89 Upsert: true, 90 Storage: storage, 91 KeyType: KeyType_AES256_CMAC, 92 Name: "test", 93 }, rand.Reader) 94 if err != nil { 95 t.Fatalf("failed loading policy: %v", err) 96 } 97 if p == nil { 98 t.Fatal("nil policy") 99 } 100 if !upserted { 101 t.Fatal("expected an upsert") 102 } 103 104 // This verifies we don't have a hmac key 105 _, err = p.HMACKey(1) 106 if err == nil { 107 t.Fatal("cmac key should not return an hmac key but did on initial creation") 108 } 109 110 if p.NeedsUpgrade() { 111 t.Fatal("cmac key should not require an upgrade after initial key creation") 112 } 113 114 err = p.Upgrade(ctx, storage, rand.Reader) 115 if err != nil { 116 t.Fatalf("an error was returned from upgrade method: %v", err) 117 } 118 p.Unlock() 119 120 // Now reload our policy from disk and make sure we still don't have a hmac key 121 p, upserted, err = lm.GetPolicy(ctx, PolicyRequest{ 122 Upsert: true, 123 Storage: storage, 124 KeyType: KeyType_AES256_CMAC, 125 Name: "test", 126 }, rand.Reader) 127 if err != nil { 128 t.Fatalf("failed loading policy: %v", err) 129 } 130 if p == nil { 131 t.Fatal("nil policy") 132 } 133 if upserted { 134 t.Fatal("expected the key to exist but upserted was true") 135 } 136 137 p.Unlock() 138 139 _, err = p.HMACKey(1) 140 if err == nil { 141 t.Fatal("cmac key should not return an hmac key post upgrade") 142 } 143 } 144 145 func TestPolicy_KeyEntryMapUpgrade(t *testing.T) { 146 now := time.Now() 147 old := map[int]KeyEntry{ 148 1: { 149 Key: []byte("samplekey"), 150 HMACKey: []byte("samplehmackey"), 151 CreationTime: now, 152 FormattedPublicKey: "sampleformattedpublickey", 153 }, 154 2: { 155 Key: []byte("samplekey2"), 156 HMACKey: []byte("samplehmackey2"), 157 CreationTime: now.Add(10 * time.Second), 158 FormattedPublicKey: "sampleformattedpublickey2", 159 }, 160 } 161 162 oldEncoded, err := jsonutil.EncodeJSON(old) 163 if err != nil { 164 t.Fatal(err) 165 } 166 167 var new keyEntryMap 168 err = jsonutil.DecodeJSON(oldEncoded, &new) 169 if err != nil { 170 t.Fatal(err) 171 } 172 173 newEncoded, err := jsonutil.EncodeJSON(&new) 174 if err != nil { 175 t.Fatal(err) 176 } 177 178 if string(oldEncoded) != string(newEncoded) { 179 t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded)) 180 } 181 } 182 183 func Test_KeyUpgrade(t *testing.T) { 184 lockManagerWithCache, _ := NewLockManager(true, 0) 185 lockManagerWithoutCache, _ := NewLockManager(false, 0) 186 testKeyUpgradeCommon(t, lockManagerWithCache) 187 testKeyUpgradeCommon(t, lockManagerWithoutCache) 188 } 189 190 func testKeyUpgradeCommon(t *testing.T, lm *LockManager) { 191 ctx := context.Background() 192 193 storage := &logical.InmemStorage{} 194 p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{ 195 Upsert: true, 196 Storage: storage, 197 KeyType: KeyType_AES256_GCM96, 198 Name: "test", 199 }, rand.Reader) 200 if err != nil { 201 t.Fatal(err) 202 } 203 if p == nil { 204 t.Fatal("nil policy") 205 } 206 if !upserted { 207 t.Fatal("expected an upsert") 208 } 209 if !lm.useCache { 210 p.Unlock() 211 } 212 213 testBytes := make([]byte, len(p.Keys["1"].Key)) 214 copy(testBytes, p.Keys["1"].Key) 215 216 p.Key = p.Keys["1"].Key 217 p.Keys = nil 218 p.MigrateKeyToKeysMap() 219 if p.Key != nil { 220 t.Fatal("policy.Key is not nil") 221 } 222 if len(p.Keys) != 1 { 223 t.Fatal("policy.Keys is the wrong size") 224 } 225 if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) { 226 t.Fatal("key mismatch") 227 } 228 } 229 230 func Test_ArchivingUpgrade(t *testing.T) { 231 lockManagerWithCache, _ := NewLockManager(true, 0) 232 lockManagerWithoutCache, _ := NewLockManager(false, 0) 233 testArchivingUpgradeCommon(t, lockManagerWithCache) 234 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 235 } 236 237 func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) { 238 ctx := context.Background() 239 240 // First, we generate a policy and rotate it a number of times. Each time 241 // we'll ensure that we have the expected number of keys in the archive and 242 // the main keys object, which without changing the min version should be 243 // zero and latest, respectively 244 245 storage := &logical.InmemStorage{} 246 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 247 Upsert: true, 248 Storage: storage, 249 KeyType: KeyType_AES256_GCM96, 250 Name: "test", 251 }, rand.Reader) 252 if err != nil { 253 t.Fatal(err) 254 } 255 if p == nil { 256 t.Fatal("nil policy") 257 } 258 if !lm.useCache { 259 p.Unlock() 260 } 261 262 // Store the initial key in the archive 263 keysArchive := []KeyEntry{{}, p.Keys["1"]} 264 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 265 266 for i := 2; i <= 10; i++ { 267 err = p.Rotate(ctx, storage, rand.Reader) 268 if err != nil { 269 t.Fatal(err) 270 } 271 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 272 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 273 } 274 275 // Now, wipe the archive and set the archive version to zero 276 err = storage.Delete(ctx, "archive/test") 277 if err != nil { 278 t.Fatal(err) 279 } 280 p.ArchiveVersion = 0 281 282 // Store it, but without calling persist, so we don't trigger 283 // handleArchiving() 284 buf, err := p.Serialize() 285 if err != nil { 286 t.Fatal(err) 287 } 288 289 // Write the policy into storage 290 err = storage.Put(ctx, &logical.StorageEntry{ 291 Key: "policy/" + p.Name, 292 Value: buf, 293 }) 294 if err != nil { 295 t.Fatal(err) 296 } 297 298 // If we're caching, expire from the cache since we modified it 299 // under-the-hood 300 if lm.useCache { 301 lm.cache.Delete("test") 302 } 303 304 // Now get the policy again; the upgrade should happen automatically 305 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 306 Storage: storage, 307 Name: "test", 308 }, rand.Reader) 309 if err != nil { 310 t.Fatal(err) 311 } 312 if p == nil { 313 t.Fatal("nil policy") 314 } 315 if !lm.useCache { 316 p.Unlock() 317 } 318 319 checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10) 320 321 // Let's check some deletion logic while we're at it 322 323 // The policy should be in there 324 if lm.useCache { 325 _, ok := lm.cache.Load("test") 326 if !ok { 327 t.Fatal("nil policy in cache") 328 } 329 } 330 331 // First we'll do this wrong, by not setting the deletion flag 332 err = lm.DeletePolicy(ctx, storage, "test") 333 if err == nil { 334 t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") 335 } 336 337 // The policy should still be in there 338 if lm.useCache { 339 _, ok := lm.cache.Load("test") 340 if !ok { 341 t.Fatal("nil policy in cache") 342 } 343 } 344 345 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 346 Storage: storage, 347 Name: "test", 348 }, rand.Reader) 349 if err != nil { 350 t.Fatal(err) 351 } 352 if p == nil { 353 t.Fatal("policy nil after bad delete") 354 } 355 if !lm.useCache { 356 p.Unlock() 357 } 358 359 // Now do it properly 360 p.DeletionAllowed = true 361 err = p.Persist(ctx, storage) 362 if err != nil { 363 t.Fatal(err) 364 } 365 err = lm.DeletePolicy(ctx, storage, "test") 366 if err != nil { 367 t.Fatal(err) 368 } 369 370 // The policy should *not* be in there 371 if lm.useCache { 372 _, ok := lm.cache.Load("test") 373 if ok { 374 t.Fatal("non-nil policy in cache") 375 } 376 } 377 378 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 379 Storage: storage, 380 Name: "test", 381 }, rand.Reader) 382 if err != nil { 383 t.Fatal(err) 384 } 385 if p != nil { 386 t.Fatal("policy not nil after delete") 387 } 388 } 389 390 func Test_Archiving(t *testing.T) { 391 lockManagerWithCache, _ := NewLockManager(true, 0) 392 lockManagerWithoutCache, _ := NewLockManager(false, 0) 393 testArchivingUpgradeCommon(t, lockManagerWithCache) 394 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 395 } 396 397 func testArchivingCommon(t *testing.T, lm *LockManager) { 398 ctx := context.Background() 399 400 // First, we generate a policy and rotate it a number of times. Each time 401 // we'll ensure that we have the expected number of keys in the archive and 402 // the main keys object, which without changing the min version should be 403 // zero and latest, respectively 404 405 storage := &logical.InmemStorage{} 406 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 407 Upsert: true, 408 Storage: storage, 409 KeyType: KeyType_AES256_GCM96, 410 Name: "test", 411 }, rand.Reader) 412 if err != nil { 413 t.Fatal(err) 414 } 415 if p == nil { 416 t.Fatal("nil policy") 417 } 418 if !lm.useCache { 419 p.Unlock() 420 } 421 422 // Store the initial key in the archive 423 keysArchive := []KeyEntry{{}, p.Keys["1"]} 424 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 425 426 for i := 2; i <= 10; i++ { 427 err = p.Rotate(ctx, storage, rand.Reader) 428 if err != nil { 429 t.Fatal(err) 430 } 431 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 432 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 433 } 434 435 // Move the min decryption version up 436 for i := 1; i <= 10; i++ { 437 p.MinDecryptionVersion = i 438 439 err = p.Persist(ctx, storage) 440 if err != nil { 441 t.Fatal(err) 442 } 443 // We expect to find: 444 // * The keys in archive are the same as the latest version 445 // * The latest version is constant 446 // * The number of keys in the policy itself is from the min 447 // decryption version up to the latest version, so for e.g. 7 and 448 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 449 // decryption version plus 1 (the min decryption version key 450 // itself) 451 checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 452 } 453 454 // Move the min decryption version down 455 for i := 10; i >= 1; i-- { 456 p.MinDecryptionVersion = i 457 458 err = p.Persist(ctx, storage) 459 if err != nil { 460 t.Fatal(err) 461 } 462 // We expect to find: 463 // * The keys in archive are never removed so same as the latest version 464 // * The latest version is constant 465 // * The number of keys in the policy itself is from the min 466 // decryption version up to the latest version, so for e.g. 7 and 467 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 468 // decryption version plus 1 (the min decryption version key 469 // itself) 470 checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 471 } 472 } 473 474 func checkKeys(t *testing.T, 475 ctx context.Context, 476 p *Policy, 477 storage logical.Storage, 478 keysArchive []KeyEntry, 479 action string, 480 archiveVer, latestVer, keysSize int, 481 ) { 482 // Sanity check 483 if len(keysArchive) != latestVer+1 { 484 t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ 485 "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) 486 } 487 488 archive, err := p.LoadArchive(ctx, storage) 489 if err != nil { 490 t.Fatal(err) 491 } 492 493 badArchiveVer := false 494 if archiveVer == 0 { 495 if len(archive.Keys) != 0 || p.ArchiveVersion != 0 { 496 badArchiveVer = true 497 } 498 } else { 499 // We need to subtract one because we have the indexes match key 500 // versions, which start at 1. So for an archive version of 1, we 501 // actually have two entries -- a blank 0 entry, and the key at spot 1 502 if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion { 503 badArchiveVer = true 504 } 505 } 506 if badArchiveVer { 507 t.Fatalf( 508 "expected archive version %d, found length of archive keys %d and policy archive version %d", 509 archiveVer, len(archive.Keys), p.ArchiveVersion, 510 ) 511 } 512 513 if latestVer != p.LatestVersion { 514 t.Fatalf( 515 "expected latest version %d, found %d", 516 latestVer, p.LatestVersion, 517 ) 518 } 519 520 if keysSize != len(p.Keys) { 521 t.Fatalf( 522 "expected keys size %d, found %d, action is %s, policy is \n%#v\n", 523 keysSize, len(p.Keys), action, p, 524 ) 525 } 526 527 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 528 if _, ok := p.Keys[strconv.Itoa(i)]; !ok { 529 t.Fatalf( 530 "expected key %d, did not find it in policy keys", i, 531 ) 532 } 533 } 534 535 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 536 ver := strconv.Itoa(i) 537 if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) { 538 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 539 } 540 polKey := p.Keys[ver] 541 polKey.CreationTime = keysArchive[i].CreationTime 542 p.Keys[ver] = polKey 543 if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) { 544 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 545 } 546 } 547 548 for i := 1; i < len(archive.Keys); i++ { 549 if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { 550 t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key) 551 } 552 } 553 } 554 555 func Test_StorageErrorSafety(t *testing.T) { 556 ctx := context.Background() 557 lm, _ := NewLockManager(true, 0) 558 559 storage := &logical.InmemStorage{} 560 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 561 Upsert: true, 562 Storage: storage, 563 KeyType: KeyType_AES256_GCM96, 564 Name: "test", 565 }, rand.Reader) 566 if err != nil { 567 t.Fatal(err) 568 } 569 if p == nil { 570 t.Fatal("nil policy") 571 } 572 573 // Store the initial key in the archive 574 keysArchive := []KeyEntry{{}, p.Keys["1"]} 575 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 576 577 // We use checkKeys here just for sanity; it doesn't really handle cases of 578 // errors below so we do more targeted testing later 579 for i := 2; i <= 5; i++ { 580 err = p.Rotate(ctx, storage, rand.Reader) 581 if err != nil { 582 t.Fatal(err) 583 } 584 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 585 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 586 } 587 588 underlying := storage.Underlying() 589 underlying.FailPut(true) 590 591 priorLen := len(p.Keys) 592 593 err = p.Rotate(ctx, storage, rand.Reader) 594 if err == nil { 595 t.Fatal("expected error") 596 } 597 598 if len(p.Keys) != priorLen { 599 t.Fatal("length of keys should not have changed") 600 } 601 } 602 603 func Test_BadUpgrade(t *testing.T) { 604 ctx := context.Background() 605 lm, _ := NewLockManager(true, 0) 606 storage := &logical.InmemStorage{} 607 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 608 Upsert: true, 609 Storage: storage, 610 KeyType: KeyType_AES256_GCM96, 611 Name: "test", 612 }, rand.Reader) 613 if err != nil { 614 t.Fatal(err) 615 } 616 if p == nil { 617 t.Fatal("nil policy") 618 } 619 620 orig, err := copystructure.Copy(p) 621 if err != nil { 622 t.Fatal(err) 623 } 624 orig.(*Policy).l = p.l 625 626 p.Key = p.Keys["1"].Key 627 p.Keys = nil 628 p.MinDecryptionVersion = 0 629 630 if err := p.Upgrade(ctx, storage, rand.Reader); err != nil { 631 t.Fatal(err) 632 } 633 634 k := p.Keys["1"] 635 o := orig.(*Policy).Keys["1"] 636 k.CreationTime = o.CreationTime 637 k.HMACKey = o.HMACKey 638 p.Keys["1"] = k 639 p.versionPrefixCache = sync.Map{} 640 641 if !reflect.DeepEqual(orig, p) { 642 t.Fatalf("not equal:\n%#v\n%#v", orig, p) 643 } 644 645 // Do it again with a failing storage call 646 underlying := storage.Underlying() 647 underlying.FailPut(true) 648 649 p.Key = p.Keys["1"].Key 650 p.Keys = nil 651 p.MinDecryptionVersion = 0 652 653 if err := p.Upgrade(ctx, storage, rand.Reader); err == nil { 654 t.Fatal("expected error") 655 } 656 657 if p.MinDecryptionVersion == 1 { 658 t.Fatal("min decryption version was changed") 659 } 660 if p.Keys != nil { 661 t.Fatal("found upgraded keys") 662 } 663 if p.Key == nil { 664 t.Fatal("non-upgraded key not found") 665 } 666 } 667 668 func Test_BadArchive(t *testing.T) { 669 ctx := context.Background() 670 lm, _ := NewLockManager(true, 0) 671 storage := &logical.InmemStorage{} 672 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 673 Upsert: true, 674 Storage: storage, 675 KeyType: KeyType_AES256_GCM96, 676 Name: "test", 677 }, rand.Reader) 678 if err != nil { 679 t.Fatal(err) 680 } 681 if p == nil { 682 t.Fatal("nil policy") 683 } 684 685 for i := 2; i <= 10; i++ { 686 err = p.Rotate(ctx, storage, rand.Reader) 687 if err != nil { 688 t.Fatal(err) 689 } 690 } 691 692 p.MinDecryptionVersion = 5 693 if err := p.Persist(ctx, storage); err != nil { 694 t.Fatal(err) 695 } 696 if p.ArchiveVersion != 10 { 697 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 698 } 699 if len(p.Keys) != 6 { 700 t.Fatalf("unexpected key length %d", len(p.Keys)) 701 } 702 703 // Set back 704 p.MinDecryptionVersion = 1 705 if err := p.Persist(ctx, storage); err != nil { 706 t.Fatal(err) 707 } 708 if p.ArchiveVersion != 10 { 709 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 710 } 711 if len(p.Keys) != 10 { 712 t.Fatalf("unexpected key length %d", len(p.Keys)) 713 } 714 715 // Run it again but we'll turn off storage along the way 716 p.MinDecryptionVersion = 5 717 if err := p.Persist(ctx, storage); err != nil { 718 t.Fatal(err) 719 } 720 if p.ArchiveVersion != 10 { 721 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 722 } 723 if len(p.Keys) != 6 { 724 t.Fatalf("unexpected key length %d", len(p.Keys)) 725 } 726 727 underlying := storage.Underlying() 728 underlying.FailPut(true) 729 730 // Set back, which should cause p.Keys to be changed if the persist works, 731 // but it doesn't 732 p.MinDecryptionVersion = 1 733 if err := p.Persist(ctx, storage); err == nil { 734 t.Fatal("expected error during put") 735 } 736 if p.ArchiveVersion != 10 { 737 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 738 } 739 // Here's the expected change 740 if len(p.Keys) != 6 { 741 t.Fatalf("unexpected key length %d", len(p.Keys)) 742 } 743 } 744 745 func Test_Import(t *testing.T) { 746 ctx := context.Background() 747 storage := &logical.InmemStorage{} 748 testKeys, err := generateTestKeys() 749 if err != nil { 750 t.Fatalf("error generating test keys: %s", err) 751 } 752 753 tests := map[string]struct { 754 policy Policy 755 key []byte 756 shouldError bool 757 }{ 758 "import AES key": { 759 policy: Policy{ 760 Name: "test-aes-key", 761 Type: KeyType_AES256_GCM96, 762 }, 763 key: testKeys[KeyType_AES256_GCM96], 764 shouldError: false, 765 }, 766 "import RSA key": { 767 policy: Policy{ 768 Name: "test-rsa-key", 769 Type: KeyType_RSA2048, 770 }, 771 key: testKeys[KeyType_RSA2048], 772 shouldError: false, 773 }, 774 "import ECDSA key": { 775 policy: Policy{ 776 Name: "test-ecdsa-key", 777 Type: KeyType_ECDSA_P256, 778 }, 779 key: testKeys[KeyType_ECDSA_P256], 780 shouldError: false, 781 }, 782 "import ED25519 key": { 783 policy: Policy{ 784 Name: "test-ed25519-key", 785 Type: KeyType_ED25519, 786 }, 787 key: testKeys[KeyType_ED25519], 788 shouldError: false, 789 }, 790 "import incorrect key type": { 791 policy: Policy{ 792 Name: "test-ed25519-key", 793 Type: KeyType_ED25519, 794 }, 795 key: testKeys[KeyType_AES256_GCM96], 796 shouldError: true, 797 }, 798 } 799 800 for name, test := range tests { 801 t.Run(name, func(t *testing.T) { 802 if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError { 803 t.Fatalf("error importing key: %s", err) 804 } 805 }) 806 } 807 } 808 809 func generateTestKeys() (map[KeyType][]byte, error) { 810 keyMap := make(map[KeyType][]byte) 811 812 rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) 813 if err != nil { 814 return nil, err 815 } 816 rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey) 817 if err != nil { 818 return nil, err 819 } 820 keyMap[KeyType_RSA2048] = rsaKeyBytes 821 822 rsaKey, err = rsa.GenerateKey(rand.Reader, 3072) 823 if err != nil { 824 return nil, err 825 } 826 rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) 827 if err != nil { 828 return nil, err 829 } 830 keyMap[KeyType_RSA3072] = rsaKeyBytes 831 832 rsaKey, err = rsa.GenerateKey(rand.Reader, 4096) 833 if err != nil { 834 return nil, err 835 } 836 rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) 837 if err != nil { 838 return nil, err 839 } 840 keyMap[KeyType_RSA4096] = rsaKeyBytes 841 842 ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 843 if err != nil { 844 return nil, err 845 } 846 ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey) 847 if err != nil { 848 return nil, err 849 } 850 keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes 851 852 _, ed25519Key, err := ed25519.GenerateKey(rand.Reader) 853 if err != nil { 854 return nil, err 855 } 856 ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key) 857 if err != nil { 858 return nil, err 859 } 860 keyMap[KeyType_ED25519] = ed25519KeyBytes 861 862 aesKey := make([]byte, 32) 863 _, err = rand.Read(aesKey) 864 if err != nil { 865 return nil, err 866 } 867 keyMap[KeyType_AES256_GCM96] = aesKey 868 869 return keyMap, nil 870 } 871 872 func BenchmarkSymmetric(b *testing.B) { 873 ctx := context.Background() 874 lm, _ := NewLockManager(true, 0) 875 storage := &logical.InmemStorage{} 876 p, _, _ := lm.GetPolicy(ctx, PolicyRequest{ 877 Upsert: true, 878 Storage: storage, 879 KeyType: KeyType_AES256_GCM96, 880 Name: "test", 881 }, rand.Reader) 882 key, _ := p.GetKey(nil, 1, 32) 883 pt := make([]byte, 10) 884 ad := make([]byte, 10) 885 for i := 0; i < b.N; i++ { 886 ct, _ := p.SymmetricEncryptRaw(1, key, pt, 887 SymmetricOpts{ 888 AdditionalData: ad, 889 }) 890 pt2, _ := p.SymmetricDecryptRaw(key, ct, SymmetricOpts{ 891 AdditionalData: ad, 892 }) 893 if !bytes.Equal(pt, pt2) { 894 b.Fail() 895 } 896 } 897 } 898 899 func saltOptions(options SigningOptions, saltLength int) SigningOptions { 900 return SigningOptions{ 901 HashAlgorithm: options.HashAlgorithm, 902 Marshaling: options.Marshaling, 903 SaltLength: saltLength, 904 SigAlgorithm: options.SigAlgorithm, 905 } 906 } 907 908 func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { 909 tabs := strings.Repeat("\t", depth) 910 t.Log(tabs, "Manually verifying signature with options:", options) 911 912 tabs = strings.Repeat("\t", depth+1) 913 verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options) 914 if err != nil { 915 t.Fatal(tabs, "❌ Failed to manually verify signature:", err) 916 } 917 if !verified { 918 t.Fatal(tabs, "❌ Failed to manually verify signature") 919 } 920 } 921 922 func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { 923 tabs := strings.Repeat("\t", depth) 924 t.Log(tabs, "Automatically verifying signature with options:", options) 925 926 tabs = strings.Repeat("\t", depth+1) 927 verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature) 928 if err != nil { 929 t.Fatal(tabs, "❌ Failed to automatically verify signature:", err) 930 } 931 if !verified { 932 t.Fatal(tabs, "❌ Failed to automatically verify signature") 933 } 934 } 935 936 func Test_RSA_PSS(t *testing.T) { 937 t.Log("Testing RSA PSS") 938 mathrand.Seed(time.Now().UnixNano()) 939 940 var userError errutil.UserError 941 ctx := context.Background() 942 storage := &logical.InmemStorage{} 943 // https://crypto.stackexchange.com/a/1222 944 input := []byte("the ancients say the longer the salt, the more provable the security") 945 sigAlgorithm := "pss" 946 947 tabs := make(map[int]string) 948 for i := 1; i <= 6; i++ { 949 tabs[i] = strings.Repeat("\t", i) 950 } 951 952 test_RSA_PSS := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, 953 marshalingType MarshalingType, 954 ) { 955 unsaltedOptions := SigningOptions{ 956 HashAlgorithm: hashType, 957 Marshaling: marshalingType, 958 SigAlgorithm: sigAlgorithm, 959 } 960 cryptoHash := CryptoHashMap[hashType] 961 minSaltLength := p.minRSAPSSSaltLength() 962 maxSaltLength := p.maxRSAPSSSaltLength(rsaKey.N.BitLen(), cryptoHash) 963 hash := cryptoHash.New() 964 hash.Write(input) 965 input = hash.Sum(nil) 966 967 // 1. Make an "automatic" signature with the given key size and hash algorithm, 968 // but an automatically chosen salt length. 969 t.Log(tabs[3], "Make an automatic signature") 970 sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) 971 if err != nil { 972 // A bit of a hack but FIPS go does not support some hash types 973 if isUnsupportedGoHashType(hashType, err) { 974 t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type") 975 return 976 } 977 t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) 978 } 979 980 // 1.1 Verify this automatic signature using the *inferred* salt length. 981 autoVerify(4, t, p, input, sig, unsaltedOptions) 982 983 // 1.2. Verify this automatic signature using the *correct, given* salt length. 984 manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength)) 985 986 // 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths. 987 t.Log(tabs[4], "Test incorrect salt lengths") 988 incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1} 989 for _, saltLength := range incorrectSaltLengths { 990 t.Log(tabs[5], "Salt length:", saltLength) 991 saltedOptions := saltOptions(unsaltedOptions, saltLength) 992 993 verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) 994 if verified { 995 t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err) 996 } 997 } 998 999 // 2. Rule out boundary, invalid salt lengths. 1000 t.Log(tabs[3], "Test invalid salt lengths") 1001 invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1} 1002 for _, saltLength := range invalidSaltLengths { 1003 t.Log(tabs[4], "Salt length:", saltLength) 1004 saltedOptions := saltOptions(unsaltedOptions, saltLength) 1005 1006 // 2.1. Fail to sign. 1007 t.Log(tabs[5], "Try to make a manual signature") 1008 _, err := p.SignWithOptions(0, nil, input, &saltedOptions) 1009 if !errors.As(err, &userError) { 1010 t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) 1011 } 1012 1013 // 2.2. Fail to verify. 1014 t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length") 1015 _, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) 1016 if !errors.As(err, &userError) { 1017 t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) 1018 } 1019 } 1020 1021 // 3. For three possible valid salt lengths... 1022 t.Log(tabs[3], "Test three possible valid salt lengths") 1023 midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength) 1024 validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength} 1025 for _, saltLength := range validSaltLengths { 1026 t.Log(tabs[4], "Salt length:", saltLength) 1027 saltedOptions := saltOptions(unsaltedOptions, saltLength) 1028 1029 // 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length. 1030 t.Log(tabs[5], "Make a manual signature") 1031 sig, err := p.SignWithOptions(0, nil, input, &saltedOptions) 1032 if err != nil { 1033 t.Fatal(tabs[6], "❌ Failed to manually sign:", err) 1034 } 1035 1036 // 3.2. Verify this manual signature using the *correct, given* salt length. 1037 manualVerify(6, t, p, input, sig, saltedOptions) 1038 1039 // 3.3. Verify this manual signature using the *inferred* salt length. 1040 autoVerify(6, t, p, input, sig, unsaltedOptions) 1041 } 1042 } 1043 1044 rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} 1045 testKeys, err := generateTestKeys() 1046 if err != nil { 1047 t.Fatalf("error generating test keys: %s", err) 1048 } 1049 1050 // 1. For each standard RSA key size 2048, 3072, and 4096... 1051 for _, rsaKeyType := range rsaKeyTypes { 1052 t.Log("Key size: ", rsaKeyType) 1053 p := &Policy{ 1054 Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size 1055 Type: rsaKeyType, 1056 } 1057 1058 rsaKeyBytes := testKeys[rsaKeyType] 1059 err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) 1060 if err != nil { 1061 t.Fatal(tabs[1], "❌ Failed to import key:", err) 1062 } 1063 rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) 1064 if err != nil { 1065 t.Fatalf("error parsing test keys: %s", err) 1066 } 1067 rsaKey := rsaKeyAny.(*rsa.PrivateKey) 1068 1069 // 2. For each hash algorithm... 1070 for hashAlgorithm, hashType := range HashTypeMap { 1071 t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) 1072 if hashAlgorithm == "none" { 1073 continue 1074 } 1075 1076 // 3. For each marshaling type... 1077 for marshalingName, marshalingType := range MarshalingTypeMap { 1078 t.Log(tabs[2], "Marshaling type:", marshalingName) 1079 testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName) 1080 t.Run(testName, func(t *testing.T) { test_RSA_PSS(t, p, rsaKey, hashType, marshalingType) }) 1081 } 1082 } 1083 } 1084 } 1085 1086 func Test_RSA_PKCS1(t *testing.T) { 1087 t.Log("Testing RSA PKCS#1v1.5") 1088 1089 ctx := context.Background() 1090 storage := &logical.InmemStorage{} 1091 // https://crypto.stackexchange.com/a/1222 1092 input := []byte("Sphinx of black quartz, judge my vow") 1093 sigAlgorithm := "pkcs1v15" 1094 1095 tabs := make(map[int]string) 1096 for i := 1; i <= 6; i++ { 1097 tabs[i] = strings.Repeat("\t", i) 1098 } 1099 1100 test_RSA_PKCS1 := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, 1101 marshalingType MarshalingType, 1102 ) { 1103 unsaltedOptions := SigningOptions{ 1104 HashAlgorithm: hashType, 1105 Marshaling: marshalingType, 1106 SigAlgorithm: sigAlgorithm, 1107 } 1108 cryptoHash := CryptoHashMap[hashType] 1109 1110 // PKCS#1v1.5 NoOID uses a direct input and assumes it is pre-hashed. 1111 if hashType != 0 { 1112 hash := cryptoHash.New() 1113 hash.Write(input) 1114 input = hash.Sum(nil) 1115 } 1116 1117 // 1. Make a signature with the given key size and hash algorithm. 1118 t.Log(tabs[3], "Make an automatic signature") 1119 sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) 1120 if err != nil { 1121 // A bit of a hack but FIPS go does not support some hash types 1122 if isUnsupportedGoHashType(hashType, err) { 1123 t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type") 1124 return 1125 } 1126 t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) 1127 } 1128 1129 // 1.1 Verify this signature using the *inferred* salt length. 1130 autoVerify(4, t, p, input, sig, unsaltedOptions) 1131 } 1132 1133 rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} 1134 testKeys, err := generateTestKeys() 1135 if err != nil { 1136 t.Fatalf("error generating test keys: %s", err) 1137 } 1138 1139 // 1. For each standard RSA key size 2048, 3072, and 4096... 1140 for _, rsaKeyType := range rsaKeyTypes { 1141 t.Log("Key size: ", rsaKeyType) 1142 p := &Policy{ 1143 Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size 1144 Type: rsaKeyType, 1145 } 1146 1147 rsaKeyBytes := testKeys[rsaKeyType] 1148 err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) 1149 if err != nil { 1150 t.Fatal(tabs[1], "❌ Failed to import key:", err) 1151 } 1152 rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) 1153 if err != nil { 1154 t.Fatalf("error parsing test keys: %s", err) 1155 } 1156 rsaKey := rsaKeyAny.(*rsa.PrivateKey) 1157 1158 // 2. For each hash algorithm... 1159 for hashAlgorithm, hashType := range HashTypeMap { 1160 t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) 1161 1162 // 3. For each marshaling type... 1163 for marshalingName, marshalingType := range MarshalingTypeMap { 1164 t.Log(tabs[2], "Marshaling type:", marshalingName) 1165 testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName) 1166 t.Run(testName, func(t *testing.T) { test_RSA_PKCS1(t, p, rsaKey, hashType, marshalingType) }) 1167 } 1168 } 1169 } 1170 } 1171 1172 // Normal Go builds support all the hash functions for RSA_PSS signatures but the 1173 // FIPS Go build does not support at this time the SHA3 hashes as FIPS 140_2 does 1174 // not accept them. 1175 func isUnsupportedGoHashType(hashType HashType, err error) bool { 1176 switch hashType { 1177 case HashTypeSHA3224, HashTypeSHA3256, HashTypeSHA3384, HashTypeSHA3512: 1178 return strings.Contains(err.Error(), "unsupported hash function") 1179 } 1180 1181 return false 1182 }