github.com/hashicorp/vault/sdk@v0.11.0/helper/keysutil/policy_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package keysutil 5 6 import ( 7 "bytes" 8 "context" 9 "crypto/ecdsa" 10 "crypto/elliptic" 11 "crypto/rand" 12 "crypto/rsa" 13 "crypto/x509" 14 "errors" 15 "fmt" 16 mathrand "math/rand" 17 "reflect" 18 "strconv" 19 "strings" 20 "sync" 21 "testing" 22 "time" 23 24 "golang.org/x/crypto/ed25519" 25 26 "github.com/hashicorp/vault/sdk/helper/errutil" 27 "github.com/hashicorp/vault/sdk/helper/jsonutil" 28 "github.com/hashicorp/vault/sdk/logical" 29 "github.com/mitchellh/copystructure" 30 ) 31 32 func TestPolicy_KeyEntryMapUpgrade(t *testing.T) { 33 now := time.Now() 34 old := map[int]KeyEntry{ 35 1: { 36 Key: []byte("samplekey"), 37 HMACKey: []byte("samplehmackey"), 38 CreationTime: now, 39 FormattedPublicKey: "sampleformattedpublickey", 40 }, 41 2: { 42 Key: []byte("samplekey2"), 43 HMACKey: []byte("samplehmackey2"), 44 CreationTime: now.Add(10 * time.Second), 45 FormattedPublicKey: "sampleformattedpublickey2", 46 }, 47 } 48 49 oldEncoded, err := jsonutil.EncodeJSON(old) 50 if err != nil { 51 t.Fatal(err) 52 } 53 54 var new keyEntryMap 55 err = jsonutil.DecodeJSON(oldEncoded, &new) 56 if err != nil { 57 t.Fatal(err) 58 } 59 60 newEncoded, err := jsonutil.EncodeJSON(&new) 61 if err != nil { 62 t.Fatal(err) 63 } 64 65 if string(oldEncoded) != string(newEncoded) { 66 t.Fatalf("failed to upgrade key entry map;\nold: %q\nnew: %q", string(oldEncoded), string(newEncoded)) 67 } 68 } 69 70 func Test_KeyUpgrade(t *testing.T) { 71 lockManagerWithCache, _ := NewLockManager(true, 0) 72 lockManagerWithoutCache, _ := NewLockManager(false, 0) 73 testKeyUpgradeCommon(t, lockManagerWithCache) 74 testKeyUpgradeCommon(t, lockManagerWithoutCache) 75 } 76 77 func testKeyUpgradeCommon(t *testing.T, lm *LockManager) { 78 ctx := context.Background() 79 80 storage := &logical.InmemStorage{} 81 p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{ 82 Upsert: true, 83 Storage: storage, 84 KeyType: KeyType_AES256_GCM96, 85 Name: "test", 86 }, rand.Reader) 87 if err != nil { 88 t.Fatal(err) 89 } 90 if p == nil { 91 t.Fatal("nil policy") 92 } 93 if !upserted { 94 t.Fatal("expected an upsert") 95 } 96 if !lm.useCache { 97 p.Unlock() 98 } 99 100 testBytes := make([]byte, len(p.Keys["1"].Key)) 101 copy(testBytes, p.Keys["1"].Key) 102 103 p.Key = p.Keys["1"].Key 104 p.Keys = nil 105 p.MigrateKeyToKeysMap() 106 if p.Key != nil { 107 t.Fatal("policy.Key is not nil") 108 } 109 if len(p.Keys) != 1 { 110 t.Fatal("policy.Keys is the wrong size") 111 } 112 if !reflect.DeepEqual(testBytes, p.Keys["1"].Key) { 113 t.Fatal("key mismatch") 114 } 115 } 116 117 func Test_ArchivingUpgrade(t *testing.T) { 118 lockManagerWithCache, _ := NewLockManager(true, 0) 119 lockManagerWithoutCache, _ := NewLockManager(false, 0) 120 testArchivingUpgradeCommon(t, lockManagerWithCache) 121 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 122 } 123 124 func testArchivingUpgradeCommon(t *testing.T, lm *LockManager) { 125 ctx := context.Background() 126 127 // First, we generate a policy and rotate it a number of times. Each time 128 // we'll ensure that we have the expected number of keys in the archive and 129 // the main keys object, which without changing the min version should be 130 // zero and latest, respectively 131 132 storage := &logical.InmemStorage{} 133 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 134 Upsert: true, 135 Storage: storage, 136 KeyType: KeyType_AES256_GCM96, 137 Name: "test", 138 }, rand.Reader) 139 if err != nil { 140 t.Fatal(err) 141 } 142 if p == nil { 143 t.Fatal("nil policy") 144 } 145 if !lm.useCache { 146 p.Unlock() 147 } 148 149 // Store the initial key in the archive 150 keysArchive := []KeyEntry{{}, p.Keys["1"]} 151 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 152 153 for i := 2; i <= 10; i++ { 154 err = p.Rotate(ctx, storage, rand.Reader) 155 if err != nil { 156 t.Fatal(err) 157 } 158 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 159 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 160 } 161 162 // Now, wipe the archive and set the archive version to zero 163 err = storage.Delete(ctx, "archive/test") 164 if err != nil { 165 t.Fatal(err) 166 } 167 p.ArchiveVersion = 0 168 169 // Store it, but without calling persist, so we don't trigger 170 // handleArchiving() 171 buf, err := p.Serialize() 172 if err != nil { 173 t.Fatal(err) 174 } 175 176 // Write the policy into storage 177 err = storage.Put(ctx, &logical.StorageEntry{ 178 Key: "policy/" + p.Name, 179 Value: buf, 180 }) 181 if err != nil { 182 t.Fatal(err) 183 } 184 185 // If we're caching, expire from the cache since we modified it 186 // under-the-hood 187 if lm.useCache { 188 lm.cache.Delete("test") 189 } 190 191 // Now get the policy again; the upgrade should happen automatically 192 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 193 Storage: storage, 194 Name: "test", 195 }, rand.Reader) 196 if err != nil { 197 t.Fatal(err) 198 } 199 if p == nil { 200 t.Fatal("nil policy") 201 } 202 if !lm.useCache { 203 p.Unlock() 204 } 205 206 checkKeys(t, ctx, p, storage, keysArchive, "upgrade", 10, 10, 10) 207 208 // Let's check some deletion logic while we're at it 209 210 // The policy should be in there 211 if lm.useCache { 212 _, ok := lm.cache.Load("test") 213 if !ok { 214 t.Fatal("nil policy in cache") 215 } 216 } 217 218 // First we'll do this wrong, by not setting the deletion flag 219 err = lm.DeletePolicy(ctx, storage, "test") 220 if err == nil { 221 t.Fatal("got nil error, but should not have been able to delete since we didn't set the deletion flag on the policy") 222 } 223 224 // The policy should still be in there 225 if lm.useCache { 226 _, ok := lm.cache.Load("test") 227 if !ok { 228 t.Fatal("nil policy in cache") 229 } 230 } 231 232 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 233 Storage: storage, 234 Name: "test", 235 }, rand.Reader) 236 if err != nil { 237 t.Fatal(err) 238 } 239 if p == nil { 240 t.Fatal("policy nil after bad delete") 241 } 242 if !lm.useCache { 243 p.Unlock() 244 } 245 246 // Now do it properly 247 p.DeletionAllowed = true 248 err = p.Persist(ctx, storage) 249 if err != nil { 250 t.Fatal(err) 251 } 252 err = lm.DeletePolicy(ctx, storage, "test") 253 if err != nil { 254 t.Fatal(err) 255 } 256 257 // The policy should *not* be in there 258 if lm.useCache { 259 _, ok := lm.cache.Load("test") 260 if ok { 261 t.Fatal("non-nil policy in cache") 262 } 263 } 264 265 p, _, err = lm.GetPolicy(ctx, PolicyRequest{ 266 Storage: storage, 267 Name: "test", 268 }, rand.Reader) 269 if err != nil { 270 t.Fatal(err) 271 } 272 if p != nil { 273 t.Fatal("policy not nil after delete") 274 } 275 } 276 277 func Test_Archiving(t *testing.T) { 278 lockManagerWithCache, _ := NewLockManager(true, 0) 279 lockManagerWithoutCache, _ := NewLockManager(false, 0) 280 testArchivingUpgradeCommon(t, lockManagerWithCache) 281 testArchivingUpgradeCommon(t, lockManagerWithoutCache) 282 } 283 284 func testArchivingCommon(t *testing.T, lm *LockManager) { 285 ctx := context.Background() 286 287 // First, we generate a policy and rotate it a number of times. Each time 288 // we'll ensure that we have the expected number of keys in the archive and 289 // the main keys object, which without changing the min version should be 290 // zero and latest, respectively 291 292 storage := &logical.InmemStorage{} 293 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 294 Upsert: true, 295 Storage: storage, 296 KeyType: KeyType_AES256_GCM96, 297 Name: "test", 298 }, rand.Reader) 299 if err != nil { 300 t.Fatal(err) 301 } 302 if p == nil { 303 t.Fatal("nil policy") 304 } 305 if !lm.useCache { 306 p.Unlock() 307 } 308 309 // Store the initial key in the archive 310 keysArchive := []KeyEntry{{}, p.Keys["1"]} 311 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 312 313 for i := 2; i <= 10; i++ { 314 err = p.Rotate(ctx, storage, rand.Reader) 315 if err != nil { 316 t.Fatal(err) 317 } 318 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 319 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 320 } 321 322 // Move the min decryption version up 323 for i := 1; i <= 10; i++ { 324 p.MinDecryptionVersion = i 325 326 err = p.Persist(ctx, storage) 327 if err != nil { 328 t.Fatal(err) 329 } 330 // We expect to find: 331 // * The keys in archive are the same as the latest version 332 // * The latest version is constant 333 // * The number of keys in the policy itself is from the min 334 // decryption version up to the latest version, so for e.g. 7 and 335 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 336 // decryption version plus 1 (the min decryption version key 337 // itself) 338 checkKeys(t, ctx, p, storage, keysArchive, "minadd", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 339 } 340 341 // Move the min decryption version down 342 for i := 10; i >= 1; i-- { 343 p.MinDecryptionVersion = i 344 345 err = p.Persist(ctx, storage) 346 if err != nil { 347 t.Fatal(err) 348 } 349 // We expect to find: 350 // * The keys in archive are never removed so same as the latest version 351 // * The latest version is constant 352 // * The number of keys in the policy itself is from the min 353 // decryption version up to the latest version, so for e.g. 7 and 354 // 10, you'd need 7, 8, 9, and 10 -- IOW, latest version - min 355 // decryption version plus 1 (the min decryption version key 356 // itself) 357 checkKeys(t, ctx, p, storage, keysArchive, "minsub", 10, 10, p.LatestVersion-p.MinDecryptionVersion+1) 358 } 359 } 360 361 func checkKeys(t *testing.T, 362 ctx context.Context, 363 p *Policy, 364 storage logical.Storage, 365 keysArchive []KeyEntry, 366 action string, 367 archiveVer, latestVer, keysSize int, 368 ) { 369 // Sanity check 370 if len(keysArchive) != latestVer+1 { 371 t.Fatalf("latest expected key version is %d, expected test keys archive size is %d, "+ 372 "but keys archive is of size %d", latestVer, latestVer+1, len(keysArchive)) 373 } 374 375 archive, err := p.LoadArchive(ctx, storage) 376 if err != nil { 377 t.Fatal(err) 378 } 379 380 badArchiveVer := false 381 if archiveVer == 0 { 382 if len(archive.Keys) != 0 || p.ArchiveVersion != 0 { 383 badArchiveVer = true 384 } 385 } else { 386 // We need to subtract one because we have the indexes match key 387 // versions, which start at 1. So for an archive version of 1, we 388 // actually have two entries -- a blank 0 entry, and the key at spot 1 389 if archiveVer != len(archive.Keys)-1 || archiveVer != p.ArchiveVersion { 390 badArchiveVer = true 391 } 392 } 393 if badArchiveVer { 394 t.Fatalf( 395 "expected archive version %d, found length of archive keys %d and policy archive version %d", 396 archiveVer, len(archive.Keys), p.ArchiveVersion, 397 ) 398 } 399 400 if latestVer != p.LatestVersion { 401 t.Fatalf( 402 "expected latest version %d, found %d", 403 latestVer, p.LatestVersion, 404 ) 405 } 406 407 if keysSize != len(p.Keys) { 408 t.Fatalf( 409 "expected keys size %d, found %d, action is %s, policy is \n%#v\n", 410 keysSize, len(p.Keys), action, p, 411 ) 412 } 413 414 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 415 if _, ok := p.Keys[strconv.Itoa(i)]; !ok { 416 t.Fatalf( 417 "expected key %d, did not find it in policy keys", i, 418 ) 419 } 420 } 421 422 for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ { 423 ver := strconv.Itoa(i) 424 if !p.Keys[ver].CreationTime.Equal(keysArchive[i].CreationTime) { 425 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 426 } 427 polKey := p.Keys[ver] 428 polKey.CreationTime = keysArchive[i].CreationTime 429 p.Keys[ver] = polKey 430 if !reflect.DeepEqual(p.Keys[ver], keysArchive[i]) { 431 t.Fatalf("key %d not equivalent between policy keys and test keys archive; policy keys:\n%#v\ntest keys archive:\n%#v\n", i, p.Keys[ver], keysArchive[i]) 432 } 433 } 434 435 for i := 1; i < len(archive.Keys); i++ { 436 if !reflect.DeepEqual(archive.Keys[i].Key, keysArchive[i].Key) { 437 t.Fatalf("key %d not equivalent between policy archive and test keys archive; policy archive:\n%#v\ntest keys archive:\n%#v\n", i, archive.Keys[i].Key, keysArchive[i].Key) 438 } 439 } 440 } 441 442 func Test_StorageErrorSafety(t *testing.T) { 443 ctx := context.Background() 444 lm, _ := NewLockManager(true, 0) 445 446 storage := &logical.InmemStorage{} 447 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 448 Upsert: true, 449 Storage: storage, 450 KeyType: KeyType_AES256_GCM96, 451 Name: "test", 452 }, rand.Reader) 453 if err != nil { 454 t.Fatal(err) 455 } 456 if p == nil { 457 t.Fatal("nil policy") 458 } 459 460 // Store the initial key in the archive 461 keysArchive := []KeyEntry{{}, p.Keys["1"]} 462 checkKeys(t, ctx, p, storage, keysArchive, "initial", 1, 1, 1) 463 464 // We use checkKeys here just for sanity; it doesn't really handle cases of 465 // errors below so we do more targeted testing later 466 for i := 2; i <= 5; i++ { 467 err = p.Rotate(ctx, storage, rand.Reader) 468 if err != nil { 469 t.Fatal(err) 470 } 471 keysArchive = append(keysArchive, p.Keys[strconv.Itoa(i)]) 472 checkKeys(t, ctx, p, storage, keysArchive, "rotate", i, i, i) 473 } 474 475 underlying := storage.Underlying() 476 underlying.FailPut(true) 477 478 priorLen := len(p.Keys) 479 480 err = p.Rotate(ctx, storage, rand.Reader) 481 if err == nil { 482 t.Fatal("expected error") 483 } 484 485 if len(p.Keys) != priorLen { 486 t.Fatal("length of keys should not have changed") 487 } 488 } 489 490 func Test_BadUpgrade(t *testing.T) { 491 ctx := context.Background() 492 lm, _ := NewLockManager(true, 0) 493 storage := &logical.InmemStorage{} 494 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 495 Upsert: true, 496 Storage: storage, 497 KeyType: KeyType_AES256_GCM96, 498 Name: "test", 499 }, rand.Reader) 500 if err != nil { 501 t.Fatal(err) 502 } 503 if p == nil { 504 t.Fatal("nil policy") 505 } 506 507 orig, err := copystructure.Copy(p) 508 if err != nil { 509 t.Fatal(err) 510 } 511 orig.(*Policy).l = p.l 512 513 p.Key = p.Keys["1"].Key 514 p.Keys = nil 515 p.MinDecryptionVersion = 0 516 517 if err := p.Upgrade(ctx, storage, rand.Reader); err != nil { 518 t.Fatal(err) 519 } 520 521 k := p.Keys["1"] 522 o := orig.(*Policy).Keys["1"] 523 k.CreationTime = o.CreationTime 524 k.HMACKey = o.HMACKey 525 p.Keys["1"] = k 526 p.versionPrefixCache = sync.Map{} 527 528 if !reflect.DeepEqual(orig, p) { 529 t.Fatalf("not equal:\n%#v\n%#v", orig, p) 530 } 531 532 // Do it again with a failing storage call 533 underlying := storage.Underlying() 534 underlying.FailPut(true) 535 536 p.Key = p.Keys["1"].Key 537 p.Keys = nil 538 p.MinDecryptionVersion = 0 539 540 if err := p.Upgrade(ctx, storage, rand.Reader); err == nil { 541 t.Fatal("expected error") 542 } 543 544 if p.MinDecryptionVersion == 1 { 545 t.Fatal("min decryption version was changed") 546 } 547 if p.Keys != nil { 548 t.Fatal("found upgraded keys") 549 } 550 if p.Key == nil { 551 t.Fatal("non-upgraded key not found") 552 } 553 } 554 555 func Test_BadArchive(t *testing.T) { 556 ctx := context.Background() 557 lm, _ := NewLockManager(true, 0) 558 storage := &logical.InmemStorage{} 559 p, _, err := lm.GetPolicy(ctx, PolicyRequest{ 560 Upsert: true, 561 Storage: storage, 562 KeyType: KeyType_AES256_GCM96, 563 Name: "test", 564 }, rand.Reader) 565 if err != nil { 566 t.Fatal(err) 567 } 568 if p == nil { 569 t.Fatal("nil policy") 570 } 571 572 for i := 2; i <= 10; i++ { 573 err = p.Rotate(ctx, storage, rand.Reader) 574 if err != nil { 575 t.Fatal(err) 576 } 577 } 578 579 p.MinDecryptionVersion = 5 580 if err := p.Persist(ctx, storage); err != nil { 581 t.Fatal(err) 582 } 583 if p.ArchiveVersion != 10 { 584 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 585 } 586 if len(p.Keys) != 6 { 587 t.Fatalf("unexpected key length %d", len(p.Keys)) 588 } 589 590 // Set back 591 p.MinDecryptionVersion = 1 592 if err := p.Persist(ctx, storage); err != nil { 593 t.Fatal(err) 594 } 595 if p.ArchiveVersion != 10 { 596 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 597 } 598 if len(p.Keys) != 10 { 599 t.Fatalf("unexpected key length %d", len(p.Keys)) 600 } 601 602 // Run it again but we'll turn off storage along the way 603 p.MinDecryptionVersion = 5 604 if err := p.Persist(ctx, storage); err != nil { 605 t.Fatal(err) 606 } 607 if p.ArchiveVersion != 10 { 608 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 609 } 610 if len(p.Keys) != 6 { 611 t.Fatalf("unexpected key length %d", len(p.Keys)) 612 } 613 614 underlying := storage.Underlying() 615 underlying.FailPut(true) 616 617 // Set back, which should cause p.Keys to be changed if the persist works, 618 // but it doesn't 619 p.MinDecryptionVersion = 1 620 if err := p.Persist(ctx, storage); err == nil { 621 t.Fatal("expected error during put") 622 } 623 if p.ArchiveVersion != 10 { 624 t.Fatalf("unexpected archive version %d", p.ArchiveVersion) 625 } 626 // Here's the expected change 627 if len(p.Keys) != 6 { 628 t.Fatalf("unexpected key length %d", len(p.Keys)) 629 } 630 } 631 632 func Test_Import(t *testing.T) { 633 ctx := context.Background() 634 storage := &logical.InmemStorage{} 635 testKeys, err := generateTestKeys() 636 if err != nil { 637 t.Fatalf("error generating test keys: %s", err) 638 } 639 640 tests := map[string]struct { 641 policy Policy 642 key []byte 643 shouldError bool 644 }{ 645 "import AES key": { 646 policy: Policy{ 647 Name: "test-aes-key", 648 Type: KeyType_AES256_GCM96, 649 }, 650 key: testKeys[KeyType_AES256_GCM96], 651 shouldError: false, 652 }, 653 "import RSA key": { 654 policy: Policy{ 655 Name: "test-rsa-key", 656 Type: KeyType_RSA2048, 657 }, 658 key: testKeys[KeyType_RSA2048], 659 shouldError: false, 660 }, 661 "import ECDSA key": { 662 policy: Policy{ 663 Name: "test-ecdsa-key", 664 Type: KeyType_ECDSA_P256, 665 }, 666 key: testKeys[KeyType_ECDSA_P256], 667 shouldError: false, 668 }, 669 "import ED25519 key": { 670 policy: Policy{ 671 Name: "test-ed25519-key", 672 Type: KeyType_ED25519, 673 }, 674 key: testKeys[KeyType_ED25519], 675 shouldError: false, 676 }, 677 "import incorrect key type": { 678 policy: Policy{ 679 Name: "test-ed25519-key", 680 Type: KeyType_ED25519, 681 }, 682 key: testKeys[KeyType_AES256_GCM96], 683 shouldError: true, 684 }, 685 } 686 687 for name, test := range tests { 688 t.Run(name, func(t *testing.T) { 689 if err := test.policy.Import(ctx, storage, test.key, rand.Reader); (err != nil) != test.shouldError { 690 t.Fatalf("error importing key: %s", err) 691 } 692 }) 693 } 694 } 695 696 func generateTestKeys() (map[KeyType][]byte, error) { 697 keyMap := make(map[KeyType][]byte) 698 699 rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) 700 if err != nil { 701 return nil, err 702 } 703 rsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(rsaKey) 704 if err != nil { 705 return nil, err 706 } 707 keyMap[KeyType_RSA2048] = rsaKeyBytes 708 709 rsaKey, err = rsa.GenerateKey(rand.Reader, 3072) 710 if err != nil { 711 return nil, err 712 } 713 rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) 714 if err != nil { 715 return nil, err 716 } 717 keyMap[KeyType_RSA3072] = rsaKeyBytes 718 719 rsaKey, err = rsa.GenerateKey(rand.Reader, 4096) 720 if err != nil { 721 return nil, err 722 } 723 rsaKeyBytes, err = x509.MarshalPKCS8PrivateKey(rsaKey) 724 if err != nil { 725 return nil, err 726 } 727 keyMap[KeyType_RSA4096] = rsaKeyBytes 728 729 ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 730 if err != nil { 731 return nil, err 732 } 733 ecdsaKeyBytes, err := x509.MarshalPKCS8PrivateKey(ecdsaKey) 734 if err != nil { 735 return nil, err 736 } 737 keyMap[KeyType_ECDSA_P256] = ecdsaKeyBytes 738 739 _, ed25519Key, err := ed25519.GenerateKey(rand.Reader) 740 if err != nil { 741 return nil, err 742 } 743 ed25519KeyBytes, err := x509.MarshalPKCS8PrivateKey(ed25519Key) 744 if err != nil { 745 return nil, err 746 } 747 keyMap[KeyType_ED25519] = ed25519KeyBytes 748 749 aesKey := make([]byte, 32) 750 _, err = rand.Read(aesKey) 751 if err != nil { 752 return nil, err 753 } 754 keyMap[KeyType_AES256_GCM96] = aesKey 755 756 return keyMap, nil 757 } 758 759 func BenchmarkSymmetric(b *testing.B) { 760 ctx := context.Background() 761 lm, _ := NewLockManager(true, 0) 762 storage := &logical.InmemStorage{} 763 p, _, _ := lm.GetPolicy(ctx, PolicyRequest{ 764 Upsert: true, 765 Storage: storage, 766 KeyType: KeyType_AES256_GCM96, 767 Name: "test", 768 }, rand.Reader) 769 key, _ := p.GetKey(nil, 1, 32) 770 pt := make([]byte, 10) 771 ad := make([]byte, 10) 772 for i := 0; i < b.N; i++ { 773 ct, _ := p.SymmetricEncryptRaw(1, key, pt, 774 SymmetricOpts{ 775 AdditionalData: ad, 776 }) 777 pt2, _ := p.SymmetricDecryptRaw(key, ct, SymmetricOpts{ 778 AdditionalData: ad, 779 }) 780 if !bytes.Equal(pt, pt2) { 781 b.Fail() 782 } 783 } 784 } 785 786 func saltOptions(options SigningOptions, saltLength int) SigningOptions { 787 return SigningOptions{ 788 HashAlgorithm: options.HashAlgorithm, 789 Marshaling: options.Marshaling, 790 SaltLength: saltLength, 791 SigAlgorithm: options.SigAlgorithm, 792 } 793 } 794 795 func manualVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { 796 tabs := strings.Repeat("\t", depth) 797 t.Log(tabs, "Manually verifying signature with options:", options) 798 799 tabs = strings.Repeat("\t", depth+1) 800 verified, err := p.VerifySignatureWithOptions(nil, input, sig.Signature, &options) 801 if err != nil { 802 t.Fatal(tabs, "❌ Failed to manually verify signature:", err) 803 } 804 if !verified { 805 t.Fatal(tabs, "❌ Failed to manually verify signature") 806 } 807 } 808 809 func autoVerify(depth int, t *testing.T, p *Policy, input []byte, sig *SigningResult, options SigningOptions) { 810 tabs := strings.Repeat("\t", depth) 811 t.Log(tabs, "Automatically verifying signature with options:", options) 812 813 tabs = strings.Repeat("\t", depth+1) 814 verified, err := p.VerifySignature(nil, input, options.HashAlgorithm, options.SigAlgorithm, options.Marshaling, sig.Signature) 815 if err != nil { 816 t.Fatal(tabs, "❌ Failed to automatically verify signature:", err) 817 } 818 if !verified { 819 t.Fatal(tabs, "❌ Failed to automatically verify signature") 820 } 821 } 822 823 func Test_RSA_PSS(t *testing.T) { 824 t.Log("Testing RSA PSS") 825 mathrand.Seed(time.Now().UnixNano()) 826 827 var userError errutil.UserError 828 ctx := context.Background() 829 storage := &logical.InmemStorage{} 830 // https://crypto.stackexchange.com/a/1222 831 input := []byte("the ancients say the longer the salt, the more provable the security") 832 sigAlgorithm := "pss" 833 834 tabs := make(map[int]string) 835 for i := 1; i <= 6; i++ { 836 tabs[i] = strings.Repeat("\t", i) 837 } 838 839 test_RSA_PSS := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, 840 marshalingType MarshalingType, 841 ) { 842 unsaltedOptions := SigningOptions{ 843 HashAlgorithm: hashType, 844 Marshaling: marshalingType, 845 SigAlgorithm: sigAlgorithm, 846 } 847 cryptoHash := CryptoHashMap[hashType] 848 minSaltLength := p.minRSAPSSSaltLength() 849 maxSaltLength := p.maxRSAPSSSaltLength(rsaKey.N.BitLen(), cryptoHash) 850 hash := cryptoHash.New() 851 hash.Write(input) 852 input = hash.Sum(nil) 853 854 // 1. Make an "automatic" signature with the given key size and hash algorithm, 855 // but an automatically chosen salt length. 856 t.Log(tabs[3], "Make an automatic signature") 857 sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) 858 if err != nil { 859 // A bit of a hack but FIPS go does not support some hash types 860 if isUnsupportedGoHashType(hashType, err) { 861 t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type") 862 return 863 } 864 t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) 865 } 866 867 // 1.1 Verify this automatic signature using the *inferred* salt length. 868 autoVerify(4, t, p, input, sig, unsaltedOptions) 869 870 // 1.2. Verify this automatic signature using the *correct, given* salt length. 871 manualVerify(4, t, p, input, sig, saltOptions(unsaltedOptions, maxSaltLength)) 872 873 // 1.3. Try to verify this automatic signature using *incorrect, given* salt lengths. 874 t.Log(tabs[4], "Test incorrect salt lengths") 875 incorrectSaltLengths := []int{minSaltLength, maxSaltLength - 1} 876 for _, saltLength := range incorrectSaltLengths { 877 t.Log(tabs[5], "Salt length:", saltLength) 878 saltedOptions := saltOptions(unsaltedOptions, saltLength) 879 880 verified, _ := p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) 881 if verified { 882 t.Fatal(tabs[6], "❌ Failed to invalidate", verified, "signature using incorrect salt length:", err) 883 } 884 } 885 886 // 2. Rule out boundary, invalid salt lengths. 887 t.Log(tabs[3], "Test invalid salt lengths") 888 invalidSaltLengths := []int{minSaltLength - 1, maxSaltLength + 1} 889 for _, saltLength := range invalidSaltLengths { 890 t.Log(tabs[4], "Salt length:", saltLength) 891 saltedOptions := saltOptions(unsaltedOptions, saltLength) 892 893 // 2.1. Fail to sign. 894 t.Log(tabs[5], "Try to make a manual signature") 895 _, err := p.SignWithOptions(0, nil, input, &saltedOptions) 896 if !errors.As(err, &userError) { 897 t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) 898 } 899 900 // 2.2. Fail to verify. 901 t.Log(tabs[5], "Try to verify an automatic signature using an invalid salt length") 902 _, err = p.VerifySignatureWithOptions(nil, input, sig.Signature, &saltedOptions) 903 if !errors.As(err, &userError) { 904 t.Fatal(tabs[6], "❌ Failed to reject invalid salt length:", err) 905 } 906 } 907 908 // 3. For three possible valid salt lengths... 909 t.Log(tabs[3], "Test three possible valid salt lengths") 910 midSaltLength := mathrand.Intn(maxSaltLength-1) + 1 // [1, maxSaltLength) 911 validSaltLengths := []int{minSaltLength, midSaltLength, maxSaltLength} 912 for _, saltLength := range validSaltLengths { 913 t.Log(tabs[4], "Salt length:", saltLength) 914 saltedOptions := saltOptions(unsaltedOptions, saltLength) 915 916 // 3.1. Make a "manual" signature with the given key size, hash algorithm, and salt length. 917 t.Log(tabs[5], "Make a manual signature") 918 sig, err := p.SignWithOptions(0, nil, input, &saltedOptions) 919 if err != nil { 920 t.Fatal(tabs[6], "❌ Failed to manually sign:", err) 921 } 922 923 // 3.2. Verify this manual signature using the *correct, given* salt length. 924 manualVerify(6, t, p, input, sig, saltedOptions) 925 926 // 3.3. Verify this manual signature using the *inferred* salt length. 927 autoVerify(6, t, p, input, sig, unsaltedOptions) 928 } 929 } 930 931 rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} 932 testKeys, err := generateTestKeys() 933 if err != nil { 934 t.Fatalf("error generating test keys: %s", err) 935 } 936 937 // 1. For each standard RSA key size 2048, 3072, and 4096... 938 for _, rsaKeyType := range rsaKeyTypes { 939 t.Log("Key size: ", rsaKeyType) 940 p := &Policy{ 941 Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size 942 Type: rsaKeyType, 943 } 944 945 rsaKeyBytes := testKeys[rsaKeyType] 946 err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) 947 if err != nil { 948 t.Fatal(tabs[1], "❌ Failed to import key:", err) 949 } 950 rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) 951 if err != nil { 952 t.Fatalf("error parsing test keys: %s", err) 953 } 954 rsaKey := rsaKeyAny.(*rsa.PrivateKey) 955 956 // 2. For each hash algorithm... 957 for hashAlgorithm, hashType := range HashTypeMap { 958 t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) 959 if hashAlgorithm == "none" { 960 continue 961 } 962 963 // 3. For each marshaling type... 964 for marshalingName, marshalingType := range MarshalingTypeMap { 965 t.Log(tabs[2], "Marshaling type:", marshalingName) 966 testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName) 967 t.Run(testName, func(t *testing.T) { test_RSA_PSS(t, p, rsaKey, hashType, marshalingType) }) 968 } 969 } 970 } 971 } 972 973 func Test_RSA_PKCS1(t *testing.T) { 974 t.Log("Testing RSA PKCS#1v1.5") 975 976 ctx := context.Background() 977 storage := &logical.InmemStorage{} 978 // https://crypto.stackexchange.com/a/1222 979 input := []byte("Sphinx of black quartz, judge my vow") 980 sigAlgorithm := "pkcs1v15" 981 982 tabs := make(map[int]string) 983 for i := 1; i <= 6; i++ { 984 tabs[i] = strings.Repeat("\t", i) 985 } 986 987 test_RSA_PKCS1 := func(t *testing.T, p *Policy, rsaKey *rsa.PrivateKey, hashType HashType, 988 marshalingType MarshalingType, 989 ) { 990 unsaltedOptions := SigningOptions{ 991 HashAlgorithm: hashType, 992 Marshaling: marshalingType, 993 SigAlgorithm: sigAlgorithm, 994 } 995 cryptoHash := CryptoHashMap[hashType] 996 997 // PKCS#1v1.5 NoOID uses a direct input and assumes it is pre-hashed. 998 if hashType != 0 { 999 hash := cryptoHash.New() 1000 hash.Write(input) 1001 input = hash.Sum(nil) 1002 } 1003 1004 // 1. Make a signature with the given key size and hash algorithm. 1005 t.Log(tabs[3], "Make an automatic signature") 1006 sig, err := p.Sign(0, nil, input, hashType, sigAlgorithm, marshalingType) 1007 if err != nil { 1008 // A bit of a hack but FIPS go does not support some hash types 1009 if isUnsupportedGoHashType(hashType, err) { 1010 t.Skip(tabs[4], "skipping test as FIPS Go does not support hash type") 1011 return 1012 } 1013 t.Fatal(tabs[4], "❌ Failed to automatically sign:", err) 1014 } 1015 1016 // 1.1 Verify this signature using the *inferred* salt length. 1017 autoVerify(4, t, p, input, sig, unsaltedOptions) 1018 } 1019 1020 rsaKeyTypes := []KeyType{KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096} 1021 testKeys, err := generateTestKeys() 1022 if err != nil { 1023 t.Fatalf("error generating test keys: %s", err) 1024 } 1025 1026 // 1. For each standard RSA key size 2048, 3072, and 4096... 1027 for _, rsaKeyType := range rsaKeyTypes { 1028 t.Log("Key size: ", rsaKeyType) 1029 p := &Policy{ 1030 Name: fmt.Sprint(rsaKeyType), // NOTE: crucial to create a new key per key size 1031 Type: rsaKeyType, 1032 } 1033 1034 rsaKeyBytes := testKeys[rsaKeyType] 1035 err := p.Import(ctx, storage, rsaKeyBytes, rand.Reader) 1036 if err != nil { 1037 t.Fatal(tabs[1], "❌ Failed to import key:", err) 1038 } 1039 rsaKeyAny, err := x509.ParsePKCS8PrivateKey(rsaKeyBytes) 1040 if err != nil { 1041 t.Fatalf("error parsing test keys: %s", err) 1042 } 1043 rsaKey := rsaKeyAny.(*rsa.PrivateKey) 1044 1045 // 2. For each hash algorithm... 1046 for hashAlgorithm, hashType := range HashTypeMap { 1047 t.Log(tabs[1], "Hash algorithm:", hashAlgorithm) 1048 1049 // 3. For each marshaling type... 1050 for marshalingName, marshalingType := range MarshalingTypeMap { 1051 t.Log(tabs[2], "Marshaling type:", marshalingName) 1052 testName := fmt.Sprintf("%s-%s-%s", rsaKeyType, hashAlgorithm, marshalingName) 1053 t.Run(testName, func(t *testing.T) { test_RSA_PKCS1(t, p, rsaKey, hashType, marshalingType) }) 1054 } 1055 } 1056 } 1057 } 1058 1059 // Normal Go builds support all the hash functions for RSA_PSS signatures but the 1060 // FIPS Go build does not support at this time the SHA3 hashes as FIPS 140_2 does 1061 // not accept them. 1062 func isUnsupportedGoHashType(hashType HashType, err error) bool { 1063 switch hashType { 1064 case HashTypeSHA3224, HashTypeSHA3256, HashTypeSHA3384, HashTypeSHA3512: 1065 return strings.Contains(err.Error(), "unsupported hash function") 1066 } 1067 1068 return false 1069 }