github.com/hashicorp/vault/sdk@v0.13.0/helper/keysutil/policy.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package keysutil
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto"
    10  	"crypto/aes"
    11  	"crypto/cipher"
    12  	"crypto/ecdsa"
    13  	"crypto/elliptic"
    14  	"crypto/hmac"
    15  	"crypto/rand"
    16  	"crypto/rsa"
    17  	"crypto/sha256"
    18  	"crypto/x509"
    19  	"encoding/asn1"
    20  	"encoding/base64"
    21  	"encoding/json"
    22  	"encoding/pem"
    23  	"errors"
    24  	"fmt"
    25  	"hash"
    26  	"io"
    27  	"math/big"
    28  	"path"
    29  	"strconv"
    30  	"strings"
    31  	"sync"
    32  	"sync/atomic"
    33  	"time"
    34  
    35  	"github.com/google/tink/go/kwp/subtle"
    36  	"github.com/hashicorp/errwrap"
    37  	"github.com/hashicorp/go-uuid"
    38  	"github.com/hashicorp/vault/sdk/helper/errutil"
    39  	"github.com/hashicorp/vault/sdk/helper/jsonutil"
    40  	"github.com/hashicorp/vault/sdk/helper/kdf"
    41  	"github.com/hashicorp/vault/sdk/logical"
    42  	"golang.org/x/crypto/chacha20poly1305"
    43  	"golang.org/x/crypto/ed25519"
    44  	"golang.org/x/crypto/hkdf"
    45  )
    46  
// Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF.
//
// The Kdf_* values identify the key derivation function stored in Policy.KDF
// (serialized as "kdf") and dispatched on in DeriveKey. A policy that never
// recorded a KDF decodes to 0, i.e. Kdf_hmac_sha256_counter.
//
// HmacMinKeySize and HmacMaxKeySize are 256 and 4096 bits expressed in bytes;
// presumably they bound HMAC key sizes (their callers are not visible here).
const (
	Kdf_hmac_sha256_counter = iota // built-in helper
	Kdf_hkdf_sha256                // golang.org/x/crypto/hkdf

	HmacMinKeySize = 256 / 8
	HmacMaxKeySize = 4096 / 8
)
    56  
// Or this one...we need the default of zero to be the original AES256-GCM96.
//
// These are untyped integer constants that convert implicitly to KeyType. A
// policy's key type is serialized as this integer (Policy.Type, json "type"),
// so existing constants must keep their positions; new key types are appended
// at the end.
const (
	KeyType_AES256_GCM96 = iota
	KeyType_ECDSA_P256
	KeyType_ED25519
	KeyType_RSA2048
	KeyType_RSA4096
	KeyType_ChaCha20_Poly1305
	KeyType_ECDSA_P384
	KeyType_ECDSA_P521
	KeyType_AES128_GCM96
	KeyType_RSA3072
	KeyType_MANAGED_KEY
	KeyType_HMAC
	KeyType_AES128_CMAC
	KeyType_AES256_CMAC
	// If adding to this list please update allTestKeyTypes in policy_test.go
)
    75  
    76  const (
    77  	// ErrTooOld is returned whtn the ciphertext or signatures's key version is
    78  	// too old.
    79  	ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
    80  
    81  	// DefaultVersionTemplate is used when no version template is provided.
    82  	DefaultVersionTemplate = "vault:v{{version}}:"
    83  )
    84  
    85  type AEADFactory interface {
    86  	GetAEAD(iv []byte) (cipher.AEAD, error)
    87  }
    88  
    89  type AssociatedDataFactory interface {
    90  	GetAssociatedData() ([]byte, error)
    91  }
    92  
    93  type ManagedKeyFactory interface {
    94  	GetManagedKeyParameters() ManagedKeyParameters
    95  }
    96  
// RestoreInfo records when a restore was performed on a policy and which key
// version it concerned; it is stored in Policy.RestoreInfo.
type RestoreInfo struct {
	Time    time.Time `json:"time"`
	Version int       `json:"version"`
}

// BackupInfo records when a backup was taken of a policy and which key
// version it concerned; it is stored in Policy.BackupInfo.
type BackupInfo struct {
	Time    time.Time `json:"time"`
	Version int       `json:"version"`
}
   106  
// SigningOptions groups the parameters that control how a signature is
// produced. HashType, MarshalingType and ManagedKeyParameters are declared
// elsewhere in this package.
// NOTE(review): SaltLength and SigAlgorithm look RSA-specific (PSS salt
// length / signature scheme selection) — confirm against the signing code.
type SigningOptions struct {
	HashAlgorithm    HashType
	Marshaling       MarshalingType
	SaltLength       int
	SigAlgorithm     string
	ManagedKeyParams ManagedKeyParameters
}

// SigningResult is the output of a signing operation: the encoded signature
// and, presumably, the public key that verifies it (confirm with callers).
type SigningResult struct {
	Signature string
	PublicKey []byte
}
   119  
// ecdsaSignature holds the (R, S) pair of an ECDSA signature. Field order is
// significant if it is encoded with encoding/asn1 (imported by this file);
// presumably it is marshaled as an ASN.1 SEQUENCE of two INTEGERs — the
// encoder is not visible here.
type ecdsaSignature struct {
	R, S *big.Int
}

// KeyType enumerates the key algorithms a Policy can hold; see the
// KeyType_* constants. It is serialized as an integer in a Policy's "type".
type KeyType int
   125  
   126  func (kt KeyType) EncryptionSupported() bool {
   127  	switch kt {
   128  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   129  		return true
   130  	}
   131  	return false
   132  }
   133  
   134  func (kt KeyType) DecryptionSupported() bool {
   135  	switch kt {
   136  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   137  		return true
   138  	}
   139  	return false
   140  }
   141  
   142  func (kt KeyType) SigningSupported() bool {
   143  	switch kt {
   144  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   145  		return true
   146  	}
   147  	return false
   148  }
   149  
   150  func (kt KeyType) HashSignatureInput() bool {
   151  	switch kt {
   152  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   153  		return true
   154  	}
   155  	return false
   156  }
   157  
   158  func (kt KeyType) DerivationSupported() bool {
   159  	switch kt {
   160  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_ED25519:
   161  		return true
   162  	}
   163  	return false
   164  }
   165  
   166  func (kt KeyType) AssociatedDataSupported() bool {
   167  	switch kt {
   168  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_MANAGED_KEY:
   169  		return true
   170  	}
   171  	return false
   172  }
   173  
   174  func (kt KeyType) CMACSupported() bool {
   175  	switch kt {
   176  	case KeyType_AES128_CMAC, KeyType_AES256_CMAC:
   177  		return true
   178  	default:
   179  		return false
   180  	}
   181  }
   182  
   183  func (kt KeyType) HMACSupported() bool {
   184  	switch {
   185  	case kt.CMACSupported():
   186  		return false
   187  	case kt == KeyType_MANAGED_KEY:
   188  		return false
   189  	default:
   190  		return true
   191  	}
   192  }
   193  
   194  func (kt KeyType) ImportPublicKeySupported() bool {
   195  	switch kt {
   196  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519:
   197  		return true
   198  	}
   199  	return false
   200  }
   201  
   202  func (kt KeyType) String() string {
   203  	switch kt {
   204  	case KeyType_AES128_GCM96:
   205  		return "aes128-gcm96"
   206  	case KeyType_AES256_GCM96:
   207  		return "aes256-gcm96"
   208  	case KeyType_ChaCha20_Poly1305:
   209  		return "chacha20-poly1305"
   210  	case KeyType_ECDSA_P256:
   211  		return "ecdsa-p256"
   212  	case KeyType_ECDSA_P384:
   213  		return "ecdsa-p384"
   214  	case KeyType_ECDSA_P521:
   215  		return "ecdsa-p521"
   216  	case KeyType_ED25519:
   217  		return "ed25519"
   218  	case KeyType_RSA2048:
   219  		return "rsa-2048"
   220  	case KeyType_RSA3072:
   221  		return "rsa-3072"
   222  	case KeyType_RSA4096:
   223  		return "rsa-4096"
   224  	case KeyType_HMAC:
   225  		return "hmac"
   226  	case KeyType_MANAGED_KEY:
   227  		return "managed_key"
   228  	case KeyType_AES128_CMAC:
   229  		return "aes128-cmac"
   230  	case KeyType_AES256_CMAC:
   231  		return "aes256-cmac"
   232  	}
   233  
   234  	return "[unknown]"
   235  }
   236  
// KeyData bundles a policy together with its archived key versions, so both
// can be encoded as a single JSON document. Presumably this is the backup /
// restore payload — confirm against callers.
type KeyData struct {
	Policy       *Policy       `json:"policy"`
	ArchivedKeys *archivedKeys `json:"archived_keys"`
}
   241  
// KeyEntry stores the key and metadata for a single key version. Which fields
// are populated depends on the policy's key type.
type KeyEntry struct {
	// AES or some other kind that is a pure byte slice like ED25519
	Key []byte `json:"key"`

	// Key used for HMAC functions
	HMACKey []byte `json:"hmac_key"`

	// Time of creation
	CreationTime time.Time `json:"time"`

	// ECDSA key material. EC_D is the private scalar (IsPrivateKeyMissing
	// treats a non-nil EC_D as private key material); EC_X/EC_Y are presumably
	// the public point coordinates.
	EC_X *big.Int `json:"ec_x"`
	EC_Y *big.Int `json:"ec_y"`
	EC_D *big.Int `json:"ec_d"`

	// RSA private key and its public half. RSAPublicKey may be nil on older
	// entries that only recorded RSAKey.
	RSAKey       *rsa.PrivateKey `json:"rsa_key"`
	RSAPublicKey *rsa.PublicKey  `json:"rsa_public_key"`

	// The public key in an appropriate format for the type of key
	FormattedPublicKey string `json:"public_key"`

	// If convergent is enabled, the version (falling back to what's in the
	// policy)
	ConvergentVersion int `json:"convergent_version"`

	// This is deprecated (but still filled) in favor of CreationTime, which
	// is more precise
	DeprecatedCreationTime int64 `json:"creation_time"`

	// Identifier of the managed key backing this entry, if any; omitted from
	// JSON when empty.
	ManagedKeyUUID string `json:"managed_key_id,omitempty"`

	// Key entry certificate chain. If set, leaf certificate key matches the
	// KeyEntry key
	CertificateChain [][]byte `json:"certificate_chain"`
}
   277  
   278  func (ke *KeyEntry) IsPrivateKeyMissing() bool {
   279  	if ke.RSAKey != nil || ke.EC_D != nil || len(ke.Key) != 0 || len(ke.ManagedKeyUUID) != 0 {
   280  		return false
   281  	}
   282  
   283  	return true
   284  }
   285  
   286  // deprecatedKeyEntryMap is used to allow JSON marshal/unmarshal
   287  type deprecatedKeyEntryMap map[int]KeyEntry
   288  
   289  // MarshalJSON implements JSON marshaling
   290  func (kem deprecatedKeyEntryMap) MarshalJSON() ([]byte, error) {
   291  	intermediate := map[string]KeyEntry{}
   292  	for k, v := range kem {
   293  		intermediate[strconv.Itoa(k)] = v
   294  	}
   295  	return json.Marshal(&intermediate)
   296  }
   297  
   298  // MarshalJSON implements JSON unmarshalling
   299  func (kem deprecatedKeyEntryMap) UnmarshalJSON(data []byte) error {
   300  	intermediate := map[string]KeyEntry{}
   301  	if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
   302  		return err
   303  	}
   304  	for k, v := range intermediate {
   305  		keyval, err := strconv.Atoi(k)
   306  		if err != nil {
   307  			return err
   308  		}
   309  		kem[keyval] = v
   310  	}
   311  
   312  	return nil
   313  }
   314  
// keyEntryMap is used to allow JSON marshal/unmarshal. It maps a key version,
// formatted as a decimal string with strconv.Itoa, to that version's KeyEntry;
// string keys make it encode directly as a JSON object. It is the type of
// Policy.Keys.
type keyEntryMap map[string]KeyEntry
   317  
// PolicyConfig is used to create a new policy; NewPolicy copies these fields
// into the Policy it returns.
type PolicyConfig struct {
	// The name of the policy
	Name string `json:"name"`

	// The type of key
	Type KeyType

	// Derived keys MUST provide a context and the master underlying key is
	// never used.
	Derived              bool
	KDF                  int
	ConvergentEncryption bool

	// Whether the key is exportable
	Exportable bool

	// Whether the key is allowed to be deleted
	DeletionAllowed bool

	// AllowPlaintextBackup allows taking backup of the policy in plaintext
	AllowPlaintextBackup bool

	// VersionTemplate is used to prefix the ciphertext with information about
	// the key version. It must include {{version}} and a delimiter between the
	// version prefix and the ciphertext.
	VersionTemplate string

	// StoragePrefix is used to add a prefix when storing and retrieving the
	// policy object.
	StoragePrefix string
}
   350  
   351  // NewPolicy takes a policy config and returns a Policy with those settings.
   352  func NewPolicy(config PolicyConfig) *Policy {
   353  	return &Policy{
   354  		l:                    new(sync.RWMutex),
   355  		Name:                 config.Name,
   356  		Type:                 config.Type,
   357  		Derived:              config.Derived,
   358  		KDF:                  config.KDF,
   359  		ConvergentEncryption: config.ConvergentEncryption,
   360  		ConvergentVersion:    -1,
   361  		Exportable:           config.Exportable,
   362  		DeletionAllowed:      config.DeletionAllowed,
   363  		AllowPlaintextBackup: config.AllowPlaintextBackup,
   364  		VersionTemplate:      config.VersionTemplate,
   365  		StoragePrefix:        config.StoragePrefix,
   366  	}
   367  }
   368  
   369  // LoadPolicy will load a policy from the provided storage path and set the
   370  // necessary un-exported variables. It is particularly useful when accessing a
   371  // policy without the lock manager.
   372  func LoadPolicy(ctx context.Context, s logical.Storage, path string) (*Policy, error) {
   373  	raw, err := s.Get(ctx, path)
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  	if raw == nil {
   378  		return nil, nil
   379  	}
   380  
   381  	var policy Policy
   382  	err = jsonutil.DecodeJSON(raw.Value, &policy)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  
   387  	// Migrate RSA private keys to include their private counterpart. This lets
   388  	// us reference RSAPublicKey whenever we need to, without necessarily
   389  	// needing the private key handy, synchronizing the behavior with EC and
   390  	// Ed25519 key pairs.
   391  	switch policy.Type {
   392  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
   393  		for _, entry := range policy.Keys {
   394  			if entry.RSAPublicKey == nil && entry.RSAKey != nil {
   395  				entry.RSAPublicKey = entry.RSAKey.Public().(*rsa.PublicKey)
   396  			}
   397  		}
   398  	}
   399  
   400  	policy.l = new(sync.RWMutex)
   401  
   402  	return &policy, nil
   403  }
   404  
// Policy is the struct used to store metadata and key versions for a named
// key. Exported fields are persisted as JSON by Persist/Serialize; unexported
// fields are runtime-only state. Policy embeds a sync.Map
// (versionPrefixCache), so it must not be copied after first use — pass
// *Policy.
type Policy struct {
	// This is a pointer on purpose: if we are running with cache disabled we
	// need to actually swap in the lock manager's lock for this policy with
	// the local lock.
	l *sync.RWMutex
	// writeLocked allows us to implement Lock() and Unlock()
	writeLocked bool
	// Stores whether it's been deleted. This acts as a guard for operations
	// that may write data, e.g. if one request rotates and that request is
	// served after a delete. Read with atomic.LoadUint32 (see Persist).
	deleted uint32

	Name    string      `json:"name"`
	Key     []byte      `json:"key,omitempty"`      // DEPRECATED
	KeySize int         `json:"key_size,omitempty"` // For algorithms with variable key sizes
	Keys    keyEntryMap `json:"keys"`

	// Derived keys MUST provide a context and the master underlying key is
	// never used. If convergent encryption is true, the context will be used
	// as the nonce as well.
	Derived              bool `json:"derived"`
	KDF                  int  `json:"kdf"`
	ConvergentEncryption bool `json:"convergent_encryption"`

	// Whether the key is exportable
	Exportable bool `json:"exportable"`

	// The minimum version of the key allowed to be used for decryption
	MinDecryptionVersion int `json:"min_decryption_version"`

	// The minimum version of the key allowed to be used for encryption
	MinEncryptionVersion int `json:"min_encryption_version"`

	// The latest key version in this policy
	LatestVersion int `json:"latest_version"`

	// The latest key version in the archive. We never delete these, so this is
	// a max.
	ArchiveVersion int `json:"archive_version"`

	// ArchiveMinVersion is the minimum version of the key in the archive.
	ArchiveMinVersion int `json:"archive_min_version"`

	// MinAvailableVersion is the minimum version of the key present. All key
	// versions before this would have been deleted.
	MinAvailableVersion int `json:"min_available_version"`

	// Whether the key is allowed to be deleted
	DeletionAllowed bool `json:"deletion_allowed"`

	// The version of the convergent nonce to use. NewPolicy sets -1, meaning
	// the version is tracked on each key entry instead; 0 marks a policy that
	// still needs upgrading.
	ConvergentVersion int `json:"convergent_version"`

	// The type of key
	Type KeyType `json:"type"`

	// BackupInfo indicates the information about the backup action taken on
	// this policy
	BackupInfo *BackupInfo `json:"backup_info"`

	// RestoreInfo indicates the information about the restore action taken on
	// this policy
	RestoreInfo *RestoreInfo `json:"restore_info"`

	// AllowPlaintextBackup allows taking backup of the policy in plaintext
	AllowPlaintextBackup bool `json:"allow_plaintext_backup"`

	// VersionTemplate is used to prefix the ciphertext with information about
	// the key version. It must include {{version}} and a delimiter between the
	// version prefix and the ciphertext.
	VersionTemplate string `json:"version_template"`

	// StoragePrefix is used to add a prefix when storing and retrieving the
	// policy object.
	StoragePrefix string `json:"storage_prefix"`

	// AutoRotatePeriod defines how frequently the key should automatically
	// rotate. Setting this to zero disables automatic rotation for the key.
	AutoRotatePeriod time.Duration `json:"auto_rotate_period"`

	// versionPrefixCache stores caches of version prefix strings and the split
	// version template.
	versionPrefixCache sync.Map

	// Imported indicates whether the key was generated by Vault or imported
	// from an external source. It has no json tag, so it is serialized under
	// the Go field name "Imported".
	Imported bool

	// AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault.
	// Also untagged: serialized as "AllowImportedKeyRotation".
	AllowImportedKeyRotation bool
}
   497  
   498  func (p *Policy) Lock(exclusive bool) {
   499  	if exclusive {
   500  		p.l.Lock()
   501  		p.writeLocked = true
   502  	} else {
   503  		p.l.RLock()
   504  	}
   505  }
   506  
   507  func (p *Policy) Unlock() {
   508  	if p.writeLocked {
   509  		p.writeLocked = false
   510  		p.l.Unlock()
   511  	} else {
   512  		p.l.RUnlock()
   513  	}
   514  }
   515  
// archivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations. It is stored separately from the
// policy, under <StoragePrefix>/archive/<Name>.
type archivedKeys struct {
	Keys []KeyEntry `json:"keys"`
}
   521  
   522  func (p *Policy) LoadArchive(ctx context.Context, storage logical.Storage) (*archivedKeys, error) {
   523  	archive := &archivedKeys{}
   524  
   525  	raw, err := storage.Get(ctx, path.Join(p.StoragePrefix, "archive", p.Name))
   526  	if err != nil {
   527  		return nil, err
   528  	}
   529  	if raw == nil {
   530  		archive.Keys = make([]KeyEntry, 0)
   531  		return archive, nil
   532  	}
   533  
   534  	if err := jsonutil.DecodeJSON(raw.Value, archive); err != nil {
   535  		return nil, err
   536  	}
   537  
   538  	return archive, nil
   539  }
   540  
   541  func (p *Policy) storeArchive(ctx context.Context, storage logical.Storage, archive *archivedKeys) error {
   542  	// Encode the policy
   543  	buf, err := json.Marshal(archive)
   544  	if err != nil {
   545  		return err
   546  	}
   547  
   548  	// Write the policy into storage
   549  	err = storage.Put(ctx, &logical.StorageEntry{
   550  		Key:   path.Join(p.StoragePrefix, "archive", p.Name),
   551  		Value: buf,
   552  	})
   553  	if err != nil {
   554  		return err
   555  	}
   556  
   557  	return nil
   558  }
   559  
// handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards.
//
// Archive layout: archive.Keys[j] holds key version j+MinAvailableVersion
// (after the trim below, ArchiveMinVersion == MinAvailableVersion). The
// in-memory Keys map is expected to hold the contiguous range
// MinDecryptionVersion..LatestVersion once this returns successfully.
//
// It returns an error, without touching storage, when the version counters
// are inconsistent; otherwise it may mutate Keys, ArchiveVersion and
// ArchiveMinVersion and may write the archive to storage.
func (p *Policy) handleArchiving(ctx context.Context, storage logical.Storage) error {
	// We need to move keys that are no longer accessible to archivedKeys, and keys
	// that now need to be accessible back here.
	//
	// For safety, because there isn't really a good reason to, we never delete
	// keys from the archive even when we move them back.

	// Check if we have the latest minimum version in the current set of keys
	_, keysContainsMinimum := p.Keys[strconv.Itoa(p.MinDecryptionVersion)]

	// Sanity checks
	switch {
	case p.MinDecryptionVersion < 1:
		return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
	case p.LatestVersion < 1:
		return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
	case !keysContainsMinimum && p.ArchiveVersion != p.LatestVersion:
		return fmt.Errorf("need to move keys from archive but archive version not up-to-date")
	case p.ArchiveVersion > p.LatestVersion:
		return fmt.Errorf("archive version of %d is greater than the latest version %d",
			p.ArchiveVersion, p.LatestVersion)
	case p.MinEncryptionVersion > 0 && p.MinEncryptionVersion < p.MinDecryptionVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than minimum encryption version %d",
			p.MinDecryptionVersion, p.MinEncryptionVersion)
	case p.MinDecryptionVersion > p.LatestVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
			p.MinDecryptionVersion, p.LatestVersion)
	}

	archive, err := p.LoadArchive(ctx, storage)
	if err != nil {
		return err
	}

	if !keysContainsMinimum {
		// Need to move keys *from* archive. The sanity checks above guarantee
		// the archive is current (ArchiveVersion == LatestVersion), so every
		// version in MinDecryptionVersion..LatestVersion is present in it.
		for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
			p.Keys[strconv.Itoa(i)] = archive.Keys[i-p.MinAvailableVersion]
		}

		return nil
	}

	// Need to move keys *to* archive

	// We need a size that is equivalent to the latest version (number of keys)
	// but adding one since slice numbering starts at 0 and we're indexing by
	// key version
	// NOTE(review): the size check and the copy loop below index relative to
	// MinAvailableVersion, while the loaded archive is laid out relative to
	// ArchiveMinVersion until the trim further down realigns it. If
	// MinAvailableVersion was raised in the same Persist that also archives
	// new versions, these offsets may disagree — verify before changing.
	if len(archive.Keys)+p.MinAvailableVersion < p.LatestVersion+1 {
		// Increase the size of the archive slice
		newKeys := make([]KeyEntry, p.LatestVersion-p.MinAvailableVersion+1)
		copy(newKeys, archive.Keys)
		archive.Keys = newKeys
	}

	// We are storing all keys in the archive, so we ensure that it is up to
	// date up to p.LatestVersion
	for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
		archive.Keys[i-p.MinAvailableVersion] = p.Keys[strconv.Itoa(i)]
		p.ArchiveVersion = i
	}

	// Trim the keys if required: drop archive entries for versions below
	// MinAvailableVersion so archive.Keys[0] is version MinAvailableVersion.
	if p.ArchiveMinVersion < p.MinAvailableVersion {
		archive.Keys = archive.Keys[p.MinAvailableVersion-p.ArchiveMinVersion:]
		p.ArchiveMinVersion = p.MinAvailableVersion
	}

	err = p.storeArchive(ctx, storage, archive)
	if err != nil {
		return err
	}

	// Perform deletion afterwards so that if there is an error saving we
	// haven't messed with the current policy. This assumes Keys holds a
	// contiguous run of versions ending at LatestVersion, so the lowest
	// version present is LatestVersion-len(Keys)+1.
	for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
		delete(p.Keys, strconv.Itoa(i))
	}

	return nil
}
   644  
   645  func (p *Policy) Persist(ctx context.Context, storage logical.Storage) (retErr error) {
   646  	if atomic.LoadUint32(&p.deleted) == 1 {
   647  		return errors.New("key has been deleted, not persisting")
   648  	}
   649  
   650  	// Other functions will take care of restoring other values; this is just
   651  	// responsible for archiving and keys since the archive function can modify
   652  	// keys. At the moment one of the other functions calling persist will also
   653  	// roll back keys, but better safe than sorry and this doesn't happen
   654  	// enough to worry about the speed tradeoff.
   655  	priorArchiveVersion := p.ArchiveVersion
   656  	var priorKeys keyEntryMap
   657  
   658  	if p.Keys != nil {
   659  		priorKeys = keyEntryMap{}
   660  		for k, v := range p.Keys {
   661  			priorKeys[k] = v
   662  		}
   663  	}
   664  
   665  	defer func() {
   666  		if retErr != nil {
   667  			p.ArchiveVersion = priorArchiveVersion
   668  			p.Keys = priorKeys
   669  		}
   670  	}()
   671  
   672  	err := p.handleArchiving(ctx, storage)
   673  	if err != nil {
   674  		return err
   675  	}
   676  
   677  	// Encode the policy
   678  	buf, err := p.Serialize()
   679  	if err != nil {
   680  		return err
   681  	}
   682  
   683  	// Write the policy into storage
   684  	err = storage.Put(ctx, &logical.StorageEntry{
   685  		Key:   path.Join(p.StoragePrefix, "policy", p.Name),
   686  		Value: buf,
   687  	})
   688  	if err != nil {
   689  		return err
   690  	}
   691  
   692  	return nil
   693  }
   694  
   695  func (p *Policy) Serialize() ([]byte, error) {
   696  	return json.Marshal(p)
   697  }
   698  
   699  func (p *Policy) NeedsUpgrade() bool {
   700  	// Ensure we've moved from Key -> Keys
   701  	if p.Key != nil && len(p.Key) > 0 {
   702  		return true
   703  	}
   704  
   705  	// With archiving, past assumptions about the length of the keys map are no
   706  	// longer valid
   707  	if p.LatestVersion == 0 && len(p.Keys) != 0 {
   708  		return true
   709  	}
   710  
   711  	// We disallow setting the version to 0, since they start at 1 since moving
   712  	// to rotate-able keys, so update if it's set to 0
   713  	if p.MinDecryptionVersion == 0 {
   714  		return true
   715  	}
   716  
   717  	// On first load after an upgrade, copy keys to the archive
   718  	if p.ArchiveVersion == 0 {
   719  		return true
   720  	}
   721  
   722  	// Need to write the version if zero; for version 3 on we set this to -1 to
   723  	// ignore it since we store this information in each key entry
   724  	if p.ConvergentEncryption && p.ConvergentVersion == 0 {
   725  		return true
   726  	}
   727  
   728  	if p.Type.HMACSupported() {
   729  		if p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey == nil || len(p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey) == 0 {
   730  			return true
   731  		}
   732  	}
   733  
   734  	return false
   735  }
   736  
   737  func (p *Policy) Upgrade(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
   738  	priorKey := p.Key
   739  	priorLatestVersion := p.LatestVersion
   740  	priorMinDecryptionVersion := p.MinDecryptionVersion
   741  	priorConvergentVersion := p.ConvergentVersion
   742  	var priorKeys keyEntryMap
   743  
   744  	if p.Keys != nil {
   745  		priorKeys = keyEntryMap{}
   746  		for k, v := range p.Keys {
   747  			priorKeys[k] = v
   748  		}
   749  	}
   750  
   751  	defer func() {
   752  		if retErr != nil {
   753  			p.Key = priorKey
   754  			p.LatestVersion = priorLatestVersion
   755  			p.MinDecryptionVersion = priorMinDecryptionVersion
   756  			p.ConvergentVersion = priorConvergentVersion
   757  			p.Keys = priorKeys
   758  		}
   759  	}()
   760  
   761  	persistNeeded := false
   762  	// Ensure we've moved from Key -> Keys
   763  	if p.Key != nil && len(p.Key) > 0 {
   764  		p.MigrateKeyToKeysMap()
   765  		persistNeeded = true
   766  	}
   767  
   768  	// With archiving, past assumptions about the length of the keys map are no
   769  	// longer valid
   770  	if p.LatestVersion == 0 && len(p.Keys) != 0 {
   771  		p.LatestVersion = len(p.Keys)
   772  		persistNeeded = true
   773  	}
   774  
   775  	// We disallow setting the version to 0, since they start at 1 since moving
   776  	// to rotate-able keys, so update if it's set to 0
   777  	if p.MinDecryptionVersion == 0 {
   778  		p.MinDecryptionVersion = 1
   779  		persistNeeded = true
   780  	}
   781  
   782  	// On first load after an upgrade, copy keys to the archive
   783  	if p.ArchiveVersion == 0 {
   784  		persistNeeded = true
   785  	}
   786  
   787  	if p.ConvergentEncryption && p.ConvergentVersion == 0 {
   788  		p.ConvergentVersion = 1
   789  		persistNeeded = true
   790  	}
   791  
   792  	if p.Type.HMACSupported() {
   793  		if p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey == nil || len(p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey) == 0 {
   794  			entry := p.Keys[strconv.Itoa(p.LatestVersion)]
   795  			hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
   796  			if err != nil {
   797  				return err
   798  			}
   799  			entry.HMACKey = hmacKey
   800  			p.Keys[strconv.Itoa(p.LatestVersion)] = entry
   801  			persistNeeded = true
   802  
   803  			if p.Type == KeyType_HMAC {
   804  				entry.HMACKey = entry.Key
   805  			}
   806  		}
   807  	}
   808  
   809  	if persistNeeded {
   810  		err := p.Persist(ctx, storage)
   811  		if err != nil {
   812  			return err
   813  		}
   814  	}
   815  
   816  	return nil
   817  }
   818  
   819  // GetKey is used to derive the encryption key that should be used depending
   820  // on the policy. If derivation is disabled the raw key is used and no context
   821  // is required, otherwise the KDF mode is used with the context to derive the
   822  // proper key.
   823  func (p *Policy) GetKey(context []byte, ver, numBytes int) ([]byte, error) {
   824  	// Fast-path non-derived keys
   825  	if !p.Derived {
   826  		keyEntry, err := p.safeGetKeyEntry(ver)
   827  		if err != nil {
   828  			return nil, err
   829  		}
   830  
   831  		return keyEntry.Key, nil
   832  	}
   833  
   834  	return p.DeriveKey(context, nil, ver, numBytes)
   835  }
   836  
   837  // DeriveKey is used to derive a symmetric key given a context and salt.  This does not
   838  // check the policies Derived flag, but just implements the derivation logic.  GetKey
   839  // is responsible for switching on the policy config.
   840  func (p *Policy) DeriveKey(context, salt []byte, ver int, numBytes int) ([]byte, error) {
   841  	if !p.Type.DerivationSupported() {
   842  		return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
   843  	}
   844  
   845  	if p.Keys == nil || p.LatestVersion == 0 {
   846  		return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
   847  	}
   848  
   849  	if ver <= 0 || ver > p.LatestVersion {
   850  		return nil, errutil.UserError{Err: "invalid key version"}
   851  	}
   852  
   853  	// Ensure a context is provided
   854  	if len(context) == 0 {
   855  		return nil, errutil.UserError{Err: "missing 'context' for key derivation; the key was created using a derived key, which means additional, per-request information must be included in order to perform operations with the key"}
   856  	}
   857  
   858  	keyEntry, err := p.safeGetKeyEntry(ver)
   859  	if err != nil {
   860  		return nil, err
   861  	}
   862  
   863  	switch p.KDF {
   864  	case Kdf_hmac_sha256_counter:
   865  		prf := kdf.HMACSHA256PRF
   866  		prfLen := kdf.HMACSHA256PRFLen
   867  		return kdf.CounterMode(prf, prfLen, keyEntry.Key, append(context, salt...), 256)
   868  
   869  	case Kdf_hkdf_sha256:
   870  		reader := hkdf.New(sha256.New, keyEntry.Key, salt, context)
   871  		derBytes := bytes.NewBuffer(nil)
   872  		derBytes.Grow(numBytes)
   873  		limReader := &io.LimitedReader{
   874  			R: reader,
   875  			N: int64(numBytes),
   876  		}
   877  
   878  		switch p.Type {
   879  		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
   880  			n, err := derBytes.ReadFrom(limReader)
   881  			if err != nil {
   882  				return nil, errutil.InternalError{Err: fmt.Sprintf("error reading returned derived bytes: %v", err)}
   883  			}
   884  			if n != int64(numBytes) {
   885  				return nil, errutil.InternalError{Err: fmt.Sprintf("unable to read enough derived bytes, needed %d, got %d", numBytes, n)}
   886  			}
   887  			return derBytes.Bytes(), nil
   888  
   889  		case KeyType_ED25519:
   890  			// We use the limited reader containing the derived bytes as the
   891  			// "random" input to the generation function
   892  			_, pri, err := ed25519.GenerateKey(limReader)
   893  			if err != nil {
   894  				return nil, errutil.InternalError{Err: fmt.Sprintf("error generating derived key: %v", err)}
   895  			}
   896  			return pri, nil
   897  
   898  		default:
   899  			return nil, errutil.InternalError{Err: "unsupported key type for derivation"}
   900  		}
   901  
   902  	default:
   903  		return nil, errutil.InternalError{Err: "unsupported key derivation mode"}
   904  	}
   905  }
   906  
   907  func (p *Policy) safeGetKeyEntry(ver int) (KeyEntry, error) {
   908  	keyVerStr := strconv.Itoa(ver)
   909  	keyEntry, ok := p.Keys[keyVerStr]
   910  	if !ok {
   911  		return keyEntry, errutil.UserError{Err: "no such key version"}
   912  	}
   913  	return keyEntry, nil
   914  }
   915  
   916  func (p *Policy) convergentVersion(ver int) int {
   917  	if !p.ConvergentEncryption {
   918  		return 0
   919  	}
   920  
   921  	convergentVersion := p.ConvergentVersion
   922  	if convergentVersion == 0 {
   923  		// For some reason, not upgraded yet
   924  		convergentVersion = 1
   925  	}
   926  	currKey := p.Keys[strconv.Itoa(ver)]
   927  	if currKey.ConvergentVersion != 0 {
   928  		convergentVersion = currKey.ConvergentVersion
   929  	}
   930  
   931  	return convergentVersion
   932  }
   933  
   934  func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, error) {
   935  	return p.EncryptWithFactory(ver, context, nonce, value, nil)
   936  }
   937  
   938  func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
   939  	return p.DecryptWithFactory(context, nonce, value, nil)
   940  }
   941  
   942  func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) {
   943  	if !p.Type.DecryptionSupported() {
   944  		return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
   945  	}
   946  
   947  	tplParts, err := p.getTemplateParts()
   948  	if err != nil {
   949  		return "", err
   950  	}
   951  
   952  	// Verify the prefix
   953  	if !strings.HasPrefix(value, tplParts[0]) {
   954  		return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
   955  	}
   956  
   957  	splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, tplParts[0]), tplParts[1], 2)
   958  	if len(splitVerCiphertext) != 2 {
   959  		return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
   960  	}
   961  
   962  	ver, err := strconv.Atoi(splitVerCiphertext[0])
   963  	if err != nil {
   964  		return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
   965  	}
   966  
   967  	if ver == 0 {
   968  		// Compatibility mode with initial implementation, where keys start at
   969  		// zero
   970  		ver = 1
   971  	}
   972  
   973  	if ver > p.LatestVersion {
   974  		return "", errutil.UserError{Err: "invalid ciphertext: version is too new"}
   975  	}
   976  
   977  	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
   978  		return "", errutil.UserError{Err: ErrTooOld}
   979  	}
   980  
   981  	convergentVersion := p.convergentVersion(ver)
   982  	if convergentVersion == 1 && (nonce == nil || len(nonce) == 0) {
   983  		return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
   984  	}
   985  
   986  	// Decode the base64
   987  	decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
   988  	if err != nil {
   989  		return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"}
   990  	}
   991  
   992  	var plain []byte
   993  
   994  	switch p.Type {
   995  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
   996  		numBytes := 32
   997  		if p.Type == KeyType_AES128_GCM96 {
   998  			numBytes = 16
   999  		}
  1000  
  1001  		encKey, err := p.GetKey(context, ver, numBytes)
  1002  		if err != nil {
  1003  			return "", err
  1004  		}
  1005  
  1006  		if len(encKey) != numBytes {
  1007  			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
  1008  		}
  1009  
  1010  		symopts := SymmetricOpts{
  1011  			Convergent:        p.ConvergentEncryption,
  1012  			ConvergentVersion: p.ConvergentVersion,
  1013  		}
  1014  		for index, rawFactory := range factories {
  1015  			if rawFactory == nil {
  1016  				continue
  1017  			}
  1018  			switch factory := rawFactory.(type) {
  1019  			case AEADFactory:
  1020  				symopts.AEADFactory = factory
  1021  			case AssociatedDataFactory:
  1022  				symopts.AdditionalData, err = factory.GetAssociatedData()
  1023  				if err != nil {
  1024  					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
  1025  				}
  1026  			case ManagedKeyFactory:
  1027  			default:
  1028  				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
  1029  			}
  1030  		}
  1031  
  1032  		plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts)
  1033  		if err != nil {
  1034  			return "", err
  1035  		}
  1036  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1037  		keyEntry, err := p.safeGetKeyEntry(ver)
  1038  		if err != nil {
  1039  			return "", err
  1040  		}
  1041  		key := keyEntry.RSAKey
  1042  		if key == nil {
  1043  			return "", errutil.InternalError{Err: fmt.Sprintf("cannot decrypt ciphertext, key version does not have a private counterpart")}
  1044  		}
  1045  		plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil)
  1046  		if err != nil {
  1047  			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA decrypt the ciphertext: %v", err)}
  1048  		}
  1049  	case KeyType_MANAGED_KEY:
  1050  		keyEntry, err := p.safeGetKeyEntry(ver)
  1051  		if err != nil {
  1052  			return "", err
  1053  		}
  1054  		var aad []byte
  1055  		var managedKeyFactory ManagedKeyFactory
  1056  		for _, f := range factories {
  1057  			switch factory := f.(type) {
  1058  			case AssociatedDataFactory:
  1059  				aad, err = factory.GetAssociatedData()
  1060  				if err != nil {
  1061  					return "", err
  1062  				}
  1063  			case ManagedKeyFactory:
  1064  				managedKeyFactory = factory
  1065  			}
  1066  		}
  1067  
  1068  		if managedKeyFactory == nil {
  1069  			return "", errors.New("key type is managed_key, but managed key parameters were not provided")
  1070  		}
  1071  
  1072  		plain, err = p.decryptWithManagedKey(managedKeyFactory.GetManagedKeyParameters(), keyEntry, decoded, nonce, aad)
  1073  		if err != nil {
  1074  			return "", err
  1075  		}
  1076  
  1077  	default:
  1078  		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
  1079  	}
  1080  
  1081  	return base64.StdEncoding.EncodeToString(plain), nil
  1082  }
  1083  
  1084  func (p *Policy) HMACKey(version int) ([]byte, error) {
  1085  	switch {
  1086  	case version < 0:
  1087  		return nil, fmt.Errorf("key version does not exist (cannot be negative)")
  1088  	case version > p.LatestVersion:
  1089  		return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
  1090  	}
  1091  	keyEntry, err := p.safeGetKeyEntry(version)
  1092  	if err != nil {
  1093  		return nil, err
  1094  	}
  1095  
  1096  	if p.Type == KeyType_HMAC {
  1097  		return keyEntry.Key, nil
  1098  	}
  1099  	if keyEntry.HMACKey == nil {
  1100  		return nil, fmt.Errorf("no HMAC key exists for that key version")
  1101  	}
  1102  	return keyEntry.HMACKey, nil
  1103  }
  1104  
  1105  func (p *Policy) CMACKey(version int) ([]byte, error) {
  1106  	switch {
  1107  	case version < 0:
  1108  		return nil, fmt.Errorf("key version does not exist (cannot be negative)")
  1109  	case version > p.LatestVersion:
  1110  		return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
  1111  	}
  1112  	keyEntry, err := p.safeGetKeyEntry(version)
  1113  	if err != nil {
  1114  		return nil, err
  1115  	}
  1116  
  1117  	if p.Type.CMACSupported() {
  1118  		return keyEntry.Key, nil
  1119  	}
  1120  
  1121  	return nil, fmt.Errorf("key type %s does not support CMAC operations", p.Type)
  1122  }
  1123  
  1124  func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
  1125  	return p.SignWithOptions(ver, context, input, &SigningOptions{
  1126  		HashAlgorithm: hashAlgorithm,
  1127  		Marshaling:    marshaling,
  1128  		SaltLength:    rsa.PSSSaltLengthAuto,
  1129  		SigAlgorithm:  sigAlgorithm,
  1130  	})
  1131  }
  1132  
  1133  func (p *Policy) minRSAPSSSaltLength() int {
  1134  	// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247
  1135  	return rsa.PSSSaltLengthEqualsHash
  1136  }
  1137  
  1138  func (p *Policy) maxRSAPSSSaltLength(keyBitLen int, hash crypto.Hash) int {
  1139  	// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288
  1140  	return (keyBitLen-1+7)/8 - 2 - hash.Size()
  1141  }
  1142  
  1143  func (p *Policy) validRSAPSSSaltLength(keyBitLen int, hash crypto.Hash, saltLength int) bool {
  1144  	return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(keyBitLen, hash)
  1145  }
  1146  
  1147  func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) {
  1148  	if !p.Type.SigningSupported() {
  1149  		return nil, fmt.Errorf("message signing not supported for key type %v", p.Type)
  1150  	}
  1151  
  1152  	switch {
  1153  	case ver == 0:
  1154  		ver = p.LatestVersion
  1155  	case ver < 0:
  1156  		return nil, errutil.UserError{Err: "requested version for signing is negative"}
  1157  	case ver > p.LatestVersion:
  1158  		return nil, errutil.UserError{Err: "requested version for signing is higher than the latest key version"}
  1159  	case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
  1160  		return nil, errutil.UserError{Err: "requested version for signing is less than the minimum encryption key version"}
  1161  	}
  1162  
  1163  	var sig []byte
  1164  	var pubKey []byte
  1165  	var err error
  1166  	keyParams, err := p.safeGetKeyEntry(ver)
  1167  	if err != nil {
  1168  		return nil, err
  1169  	}
  1170  
  1171  	// Before signing, check if key has its private part, if not return error
  1172  	if keyParams.IsPrivateKeyMissing() {
  1173  		return nil, errutil.UserError{Err: "requested version for signing does not contain a private part"}
  1174  	}
  1175  
  1176  	hashAlgorithm := options.HashAlgorithm
  1177  	marshaling := options.Marshaling
  1178  	saltLength := options.SaltLength
  1179  	sigAlgorithm := options.SigAlgorithm
  1180  
  1181  	switch p.Type {
  1182  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  1183  		var curveBits int
  1184  		var curve elliptic.Curve
  1185  		switch p.Type {
  1186  		case KeyType_ECDSA_P384:
  1187  			curveBits = 384
  1188  			curve = elliptic.P384()
  1189  		case KeyType_ECDSA_P521:
  1190  			curveBits = 521
  1191  			curve = elliptic.P521()
  1192  		default:
  1193  			curveBits = 256
  1194  			curve = elliptic.P256()
  1195  		}
  1196  
  1197  		key := &ecdsa.PrivateKey{
  1198  			PublicKey: ecdsa.PublicKey{
  1199  				Curve: curve,
  1200  				X:     keyParams.EC_X,
  1201  				Y:     keyParams.EC_Y,
  1202  			},
  1203  			D: keyParams.EC_D,
  1204  		}
  1205  
  1206  		r, s, err := ecdsa.Sign(rand.Reader, key, input)
  1207  		if err != nil {
  1208  			return nil, err
  1209  		}
  1210  
  1211  		switch marshaling {
  1212  		case MarshalingTypeASN1:
  1213  			// This is used by openssl and X.509
  1214  			sig, err = asn1.Marshal(ecdsaSignature{
  1215  				R: r,
  1216  				S: s,
  1217  			})
  1218  			if err != nil {
  1219  				return nil, err
  1220  			}
  1221  
  1222  		case MarshalingTypeJWS:
  1223  			// This is used by JWS
  1224  
  1225  			// First we have to get the length of the curve in bytes. Although
  1226  			// we only support 256 now, we'll do this in an agnostic way so we
  1227  			// can reuse this marshaling if we support e.g. 521. Getting the
  1228  			// number of bytes without rounding up would be 65.125 so we need
  1229  			// to add one in that case.
  1230  			keyLen := curveBits / 8
  1231  			if curveBits%8 > 0 {
  1232  				keyLen++
  1233  			}
  1234  
  1235  			// Now create the output array
  1236  			sig = make([]byte, keyLen*2)
  1237  			rb := r.Bytes()
  1238  			sb := s.Bytes()
  1239  			copy(sig[keyLen-len(rb):], rb)
  1240  			copy(sig[2*keyLen-len(sb):], sb)
  1241  
  1242  		default:
  1243  			return nil, errutil.UserError{Err: "requested marshaling type is invalid"}
  1244  		}
  1245  
  1246  	case KeyType_ED25519:
  1247  		var key ed25519.PrivateKey
  1248  
  1249  		if p.Derived {
  1250  			// Derive the key that should be used
  1251  			var err error
  1252  			key, err = p.GetKey(context, ver, 32)
  1253  			if err != nil {
  1254  				return nil, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
  1255  			}
  1256  			pubKey = key.Public().(ed25519.PublicKey)
  1257  		} else {
  1258  			key = ed25519.PrivateKey(keyParams.Key)
  1259  		}
  1260  
  1261  		// Per docs, do not pre-hash ed25519; it does two passes and performs
  1262  		// its own hashing
  1263  		sig, err = key.Sign(rand.Reader, input, crypto.Hash(0))
  1264  		if err != nil {
  1265  			return nil, err
  1266  		}
  1267  
  1268  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1269  		key := keyParams.RSAKey
  1270  
  1271  		algo, ok := CryptoHashMap[hashAlgorithm]
  1272  		if !ok {
  1273  			return nil, errutil.InternalError{Err: "unsupported hash algorithm"}
  1274  		}
  1275  
  1276  		if sigAlgorithm == "" {
  1277  			sigAlgorithm = "pss"
  1278  		}
  1279  
  1280  		switch sigAlgorithm {
  1281  		case "pss":
  1282  			if !p.validRSAPSSSaltLength(key.N.BitLen(), algo, saltLength) {
  1283  				return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
  1284  			}
  1285  			sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength})
  1286  			if err != nil {
  1287  				return nil, err
  1288  			}
  1289  		case "pkcs1v15":
  1290  			sig, err = rsa.SignPKCS1v15(rand.Reader, key, algo, input)
  1291  			if err != nil {
  1292  				return nil, err
  1293  			}
  1294  		default:
  1295  			return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
  1296  		}
  1297  
  1298  	case KeyType_MANAGED_KEY:
  1299  		keyEntry, err := p.safeGetKeyEntry(ver)
  1300  		if err != nil {
  1301  			return nil, err
  1302  		}
  1303  
  1304  		sig, err = p.signWithManagedKey(options, keyEntry, input)
  1305  		if err != nil {
  1306  			return nil, err
  1307  		}
  1308  
  1309  	default:
  1310  		return nil, fmt.Errorf("unsupported key type %v", p.Type)
  1311  	}
  1312  
  1313  	// Convert to base64
  1314  	var encoded string
  1315  	switch marshaling {
  1316  	case MarshalingTypeASN1:
  1317  		encoded = base64.StdEncoding.EncodeToString(sig)
  1318  	case MarshalingTypeJWS:
  1319  		encoded = base64.RawURLEncoding.EncodeToString(sig)
  1320  	}
  1321  	res := &SigningResult{
  1322  		Signature: p.getVersionPrefix(ver) + encoded,
  1323  		PublicKey: pubKey,
  1324  	}
  1325  
  1326  	return res, nil
  1327  }
  1328  
  1329  func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) {
  1330  	return p.VerifySignatureWithOptions(context, input, sig, &SigningOptions{
  1331  		HashAlgorithm: hashAlgorithm,
  1332  		Marshaling:    marshaling,
  1333  		SaltLength:    rsa.PSSSaltLengthAuto,
  1334  		SigAlgorithm:  sigAlgorithm,
  1335  	})
  1336  }
  1337  
  1338  func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) {
  1339  	if !p.Type.SigningSupported() {
  1340  		return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
  1341  	}
  1342  
  1343  	tplParts, err := p.getTemplateParts()
  1344  	if err != nil {
  1345  		return false, err
  1346  	}
  1347  
  1348  	// Verify the prefix
  1349  	if !strings.HasPrefix(sig, tplParts[0]) {
  1350  		return false, errutil.UserError{Err: "invalid signature: no prefix"}
  1351  	}
  1352  
  1353  	splitVerSig := strings.SplitN(strings.TrimPrefix(sig, tplParts[0]), tplParts[1], 2)
  1354  	if len(splitVerSig) != 2 {
  1355  		return false, errutil.UserError{Err: "invalid signature: wrong number of fields"}
  1356  	}
  1357  
  1358  	ver, err := strconv.Atoi(splitVerSig[0])
  1359  	if err != nil {
  1360  		return false, errutil.UserError{Err: "invalid signature: version number could not be decoded"}
  1361  	}
  1362  
  1363  	if ver > p.LatestVersion {
  1364  		return false, errutil.UserError{Err: "invalid signature: version is too new"}
  1365  	}
  1366  
  1367  	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
  1368  		return false, errutil.UserError{Err: ErrTooOld}
  1369  	}
  1370  
  1371  	hashAlgorithm := options.HashAlgorithm
  1372  	marshaling := options.Marshaling
  1373  	saltLength := options.SaltLength
  1374  	sigAlgorithm := options.SigAlgorithm
  1375  
  1376  	var sigBytes []byte
  1377  	switch marshaling {
  1378  	case MarshalingTypeASN1:
  1379  		sigBytes, err = base64.StdEncoding.DecodeString(splitVerSig[1])
  1380  	case MarshalingTypeJWS:
  1381  		sigBytes, err = base64.RawURLEncoding.DecodeString(splitVerSig[1])
  1382  	default:
  1383  		return false, errutil.UserError{Err: "requested marshaling type is invalid"}
  1384  	}
  1385  	if err != nil {
  1386  		return false, errutil.UserError{Err: "invalid base64 signature value"}
  1387  	}
  1388  
  1389  	switch p.Type {
  1390  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  1391  		var curve elliptic.Curve
  1392  		switch p.Type {
  1393  		case KeyType_ECDSA_P384:
  1394  			curve = elliptic.P384()
  1395  		case KeyType_ECDSA_P521:
  1396  			curve = elliptic.P521()
  1397  		default:
  1398  			curve = elliptic.P256()
  1399  		}
  1400  
  1401  		var ecdsaSig ecdsaSignature
  1402  
  1403  		switch marshaling {
  1404  		case MarshalingTypeASN1:
  1405  			rest, err := asn1.Unmarshal(sigBytes, &ecdsaSig)
  1406  			if err != nil {
  1407  				return false, errutil.UserError{Err: "supplied signature is invalid"}
  1408  			}
  1409  			if rest != nil && len(rest) != 0 {
  1410  				return false, errutil.UserError{Err: "supplied signature contains extra data"}
  1411  			}
  1412  
  1413  		case MarshalingTypeJWS:
  1414  			paramLen := len(sigBytes) / 2
  1415  			rb := sigBytes[:paramLen]
  1416  			sb := sigBytes[paramLen:]
  1417  			ecdsaSig.R = new(big.Int)
  1418  			ecdsaSig.R.SetBytes(rb)
  1419  			ecdsaSig.S = new(big.Int)
  1420  			ecdsaSig.S.SetBytes(sb)
  1421  		}
  1422  
  1423  		keyParams, err := p.safeGetKeyEntry(ver)
  1424  		if err != nil {
  1425  			return false, err
  1426  		}
  1427  		key := &ecdsa.PublicKey{
  1428  			Curve: curve,
  1429  			X:     keyParams.EC_X,
  1430  			Y:     keyParams.EC_Y,
  1431  		}
  1432  
  1433  		return ecdsa.Verify(key, input, ecdsaSig.R, ecdsaSig.S), nil
  1434  
  1435  	case KeyType_ED25519:
  1436  		var pub ed25519.PublicKey
  1437  
  1438  		if p.Derived {
  1439  			// Derive the key that should be used
  1440  			key, err := p.GetKey(context, ver, 32)
  1441  			if err != nil {
  1442  				return false, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
  1443  			}
  1444  			pub = ed25519.PrivateKey(key).Public().(ed25519.PublicKey)
  1445  		} else {
  1446  			keyEntry, err := p.safeGetKeyEntry(ver)
  1447  			if err != nil {
  1448  				return false, err
  1449  			}
  1450  
  1451  			raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
  1452  			if err != nil {
  1453  				return false, err
  1454  			}
  1455  
  1456  			pub = ed25519.PublicKey(raw)
  1457  		}
  1458  
  1459  		return ed25519.Verify(pub, input, sigBytes), nil
  1460  
  1461  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1462  		keyEntry, err := p.safeGetKeyEntry(ver)
  1463  		if err != nil {
  1464  			return false, err
  1465  		}
  1466  
  1467  		algo, ok := CryptoHashMap[hashAlgorithm]
  1468  		if !ok {
  1469  			return false, errutil.InternalError{Err: "unsupported hash algorithm"}
  1470  		}
  1471  
  1472  		if sigAlgorithm == "" {
  1473  			sigAlgorithm = "pss"
  1474  		}
  1475  
  1476  		switch sigAlgorithm {
  1477  		case "pss":
  1478  			publicKey := keyEntry.RSAPublicKey
  1479  			if !keyEntry.IsPrivateKeyMissing() {
  1480  				publicKey = &keyEntry.RSAKey.PublicKey
  1481  			}
  1482  			if !p.validRSAPSSSaltLength(publicKey.N.BitLen(), algo, saltLength) {
  1483  				return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
  1484  			}
  1485  			err = rsa.VerifyPSS(publicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength})
  1486  		case "pkcs1v15":
  1487  			publicKey := keyEntry.RSAPublicKey
  1488  			if !keyEntry.IsPrivateKeyMissing() {
  1489  				publicKey = &keyEntry.RSAKey.PublicKey
  1490  			}
  1491  			err = rsa.VerifyPKCS1v15(publicKey, algo, input, sigBytes)
  1492  		default:
  1493  			return false, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
  1494  		}
  1495  
  1496  		return err == nil, nil
  1497  
  1498  	case KeyType_MANAGED_KEY:
  1499  		keyEntry, err := p.safeGetKeyEntry(ver)
  1500  		if err != nil {
  1501  			return false, err
  1502  		}
  1503  
  1504  		return p.verifyWithManagedKey(options, keyEntry, input, sigBytes)
  1505  
  1506  	default:
  1507  		return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
  1508  	}
  1509  }
  1510  
  1511  func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte, randReader io.Reader) error {
  1512  	return p.ImportPublicOrPrivate(ctx, storage, key, true, randReader)
  1513  }
  1514  
  1515  func (p *Policy) ImportPublicOrPrivate(ctx context.Context, storage logical.Storage, key []byte, isPrivateKey bool, randReader io.Reader) error {
  1516  	now := time.Now()
  1517  	entry := KeyEntry{
  1518  		CreationTime:           now,
  1519  		DeprecatedCreationTime: now.Unix(),
  1520  	}
  1521  
  1522  	// Before we insert this entry, check if the latest version is incomplete
  1523  	// and this entry matches the current version; if so, return without
  1524  	// updating to the next version.
  1525  	if p.LatestVersion > 0 {
  1526  		latestKey := p.Keys[strconv.Itoa(p.LatestVersion)]
  1527  		if latestKey.IsPrivateKeyMissing() && isPrivateKey {
  1528  			if err := p.ImportPrivateKeyForVersion(ctx, storage, p.LatestVersion, key); err == nil {
  1529  				return nil
  1530  			}
  1531  		}
  1532  	}
  1533  
  1534  	if p.Type != KeyType_HMAC {
  1535  		hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
  1536  		if err != nil {
  1537  			return err
  1538  		}
  1539  		entry.HMACKey = hmacKey
  1540  	}
  1541  
  1542  	if p.Type == KeyType_ED25519 && p.Derived && !isPrivateKey {
  1543  		return fmt.Errorf("unable to import only public key for derived Ed25519 key: imported key should not be an Ed25519 key pair but is instead an HKDF key")
  1544  	}
  1545  
  1546  	if ((p.Type == KeyType_AES128_GCM96 || p.Type == KeyType_AES128_CMAC) && len(key) != 16) ||
  1547  		((p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305 || p.Type == KeyType_AES256_CMAC) && len(key) != 32) ||
  1548  		(p.Type == KeyType_HMAC && (len(key) < HmacMinKeySize || len(key) > HmacMaxKeySize)) {
  1549  		return fmt.Errorf("invalid key size %d bytes for key type %s", len(key), p.Type)
  1550  	}
  1551  
  1552  	if p.Type == KeyType_AES128_GCM96 || p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305 || p.Type == KeyType_HMAC || p.Type == KeyType_AES128_CMAC || p.Type == KeyType_AES256_CMAC {
  1553  		entry.Key = key
  1554  		if p.Type == KeyType_HMAC {
  1555  			p.KeySize = len(key)
  1556  			entry.HMACKey = key
  1557  		}
  1558  	} else {
  1559  		var parsedKey any
  1560  		var err error
  1561  		if isPrivateKey {
  1562  			parsedKey, err = x509.ParsePKCS8PrivateKey(key)
  1563  			if err != nil {
  1564  				if strings.Contains(err.Error(), "unknown elliptic curve") {
  1565  					var edErr error
  1566  					parsedKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
  1567  					if edErr != nil {
  1568  						return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
  1569  					}
  1570  
  1571  					// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
  1572  				} else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) {
  1573  					var rsaErr error
  1574  					parsedKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key)
  1575  					if rsaErr != nil {
  1576  						return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err)
  1577  					}
  1578  
  1579  					// Parsing as RSA-PSS in PKCS8 succeeded!
  1580  				} else {
  1581  					return fmt.Errorf("error parsing asymmetric key: %s", err)
  1582  				}
  1583  			}
  1584  		} else {
  1585  			pemBlock, _ := pem.Decode(key)
  1586  			if pemBlock == nil {
  1587  				return fmt.Errorf("error parsing public key: not in PEM format")
  1588  			}
  1589  
  1590  			parsedKey, err = x509.ParsePKIXPublicKey(pemBlock.Bytes)
  1591  			if err != nil {
  1592  				return fmt.Errorf("error parsing public key: %w", err)
  1593  			}
  1594  		}
  1595  
  1596  		err = entry.parseFromKey(p.Type, parsedKey)
  1597  		if err != nil {
  1598  			return err
  1599  		}
  1600  	}
  1601  
  1602  	p.LatestVersion += 1
  1603  
  1604  	if p.Keys == nil {
  1605  		// This is an initial key rotation when generating a new policy. We
  1606  		// don't need to call migrate here because if we've called getPolicy to
  1607  		// get the policy in the first place it will have been run.
  1608  		p.Keys = keyEntryMap{}
  1609  	}
  1610  	p.Keys[strconv.Itoa(p.LatestVersion)] = entry
  1611  
  1612  	// This ensures that with new key creations min decryption version is set
  1613  	// to 1 rather than the int default of 0, since keys start at 1 (either
  1614  	// fresh or after migration to the key map)
  1615  	if p.MinDecryptionVersion == 0 {
  1616  		p.MinDecryptionVersion = 1
  1617  	}
  1618  
  1619  	return p.Persist(ctx, storage)
  1620  }
  1621  
  1622  // Rotate rotates the policy and persists it to storage.
  1623  // If the rotation partially fails, the policy state will be restored.
  1624  func (p *Policy) Rotate(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
  1625  	priorLatestVersion := p.LatestVersion
  1626  	priorMinDecryptionVersion := p.MinDecryptionVersion
  1627  	var priorKeys keyEntryMap
  1628  
  1629  	if p.Imported && !p.AllowImportedKeyRotation {
  1630  		return fmt.Errorf("imported key %s does not allow rotation within Vault", p.Name)
  1631  	}
  1632  
  1633  	if p.Keys != nil {
  1634  		priorKeys = keyEntryMap{}
  1635  		for k, v := range p.Keys {
  1636  			priorKeys[k] = v
  1637  		}
  1638  	}
  1639  
  1640  	defer func() {
  1641  		if retErr != nil {
  1642  			p.LatestVersion = priorLatestVersion
  1643  			p.MinDecryptionVersion = priorMinDecryptionVersion
  1644  			p.Keys = priorKeys
  1645  		}
  1646  	}()
  1647  
  1648  	if err := p.RotateInMemory(randReader); err != nil {
  1649  		return err
  1650  	}
  1651  
  1652  	p.Imported = false
  1653  	return p.Persist(ctx, storage)
  1654  }
  1655  
  1656  // RotateInMemory rotates the policy but does not persist it to storage.
  1657  func (p *Policy) RotateInMemory(randReader io.Reader) (retErr error) {
  1658  	now := time.Now()
  1659  	entry := KeyEntry{
  1660  		CreationTime:           now,
  1661  		DeprecatedCreationTime: now.Unix(),
  1662  	}
  1663  
  1664  	if p.Type != KeyType_AES128_CMAC && p.Type != KeyType_AES256_CMAC && p.Type != KeyType_HMAC {
  1665  		hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
  1666  		if err != nil {
  1667  			return err
  1668  		}
  1669  		entry.HMACKey = hmacKey
  1670  	}
  1671  
  1672  	var err error
  1673  	switch p.Type {
  1674  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC, KeyType_AES128_CMAC, KeyType_AES256_CMAC:
  1675  		// Default to 256 bit key
  1676  		numBytes := 32
  1677  		if p.Type == KeyType_AES128_GCM96 || p.Type == KeyType_AES128_CMAC {
  1678  			numBytes = 16
  1679  		} else if p.Type == KeyType_HMAC {
  1680  			numBytes = p.KeySize
  1681  			if numBytes < HmacMinKeySize || numBytes > HmacMaxKeySize {
  1682  				return fmt.Errorf("invalid key size for HMAC key, must be between %d and %d bytes", HmacMinKeySize, HmacMaxKeySize)
  1683  			}
  1684  		}
  1685  		newKey, err := uuid.GenerateRandomBytesWithReader(numBytes, randReader)
  1686  		if err != nil {
  1687  			return err
  1688  		}
  1689  		entry.Key = newKey
  1690  
  1691  		if p.Type == KeyType_HMAC {
  1692  			// To avoid causing problems, ensure HMACKey = Key.
  1693  			entry.HMACKey = newKey
  1694  		}
  1695  
  1696  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  1697  		var curve elliptic.Curve
  1698  		switch p.Type {
  1699  		case KeyType_ECDSA_P384:
  1700  			curve = elliptic.P384()
  1701  		case KeyType_ECDSA_P521:
  1702  			curve = elliptic.P521()
  1703  		default:
  1704  			curve = elliptic.P256()
  1705  		}
  1706  
  1707  		privKey, err := ecdsa.GenerateKey(curve, rand.Reader)
  1708  		if err != nil {
  1709  			return err
  1710  		}
  1711  		entry.EC_D = privKey.D
  1712  		entry.EC_X = privKey.X
  1713  		entry.EC_Y = privKey.Y
  1714  		derBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
  1715  		if err != nil {
  1716  			return errwrap.Wrapf("error marshaling public key: {{err}}", err)
  1717  		}
  1718  		pemBlock := &pem.Block{
  1719  			Type:  "PUBLIC KEY",
  1720  			Bytes: derBytes,
  1721  		}
  1722  		pemBytes := pem.EncodeToMemory(pemBlock)
  1723  		if pemBytes == nil || len(pemBytes) == 0 {
  1724  			return fmt.Errorf("error PEM-encoding public key")
  1725  		}
  1726  		entry.FormattedPublicKey = string(pemBytes)
  1727  
  1728  	case KeyType_ED25519:
  1729  		// Go uses a 64-byte private key for Ed25519 keys (private+public, each
  1730  		// 32-bytes long). When we do Key derivation, we still generate a 32-byte
  1731  		// random value (and compute the corresponding Ed25519 public key), but
  1732  		// use this entire 64-byte key as if it was an HKDF key. The corresponding
  1733  		// underlying public key is never returned (which is probably good, because
  1734  		// doing so would leak half of our HKDF key...), but means we cannot import
  1735  		// derived-enabled Ed25519 public key components.
  1736  		pub, pri, err := ed25519.GenerateKey(randReader)
  1737  		if err != nil {
  1738  			return err
  1739  		}
  1740  		entry.Key = pri
  1741  		entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub)
  1742  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1743  		bitSize := 2048
  1744  		if p.Type == KeyType_RSA3072 {
  1745  			bitSize = 3072
  1746  		}
  1747  		if p.Type == KeyType_RSA4096 {
  1748  			bitSize = 4096
  1749  		}
  1750  
  1751  		entry.RSAKey, err = rsa.GenerateKey(randReader, bitSize)
  1752  		if err != nil {
  1753  			return err
  1754  		}
  1755  
  1756  		entry.RSAPublicKey = entry.RSAKey.Public().(*rsa.PublicKey)
  1757  	}
  1758  
  1759  	if p.ConvergentEncryption {
  1760  		if p.ConvergentVersion == -1 || p.ConvergentVersion > 1 {
  1761  			entry.ConvergentVersion = currentConvergentVersion
  1762  		}
  1763  	}
  1764  
  1765  	p.LatestVersion += 1
  1766  
  1767  	if p.Keys == nil {
  1768  		// This is an initial key rotation when generating a new policy. We
  1769  		// don't need to call migrate here because if we've called getPolicy to
  1770  		// get the policy in the first place it will have been run.
  1771  		p.Keys = keyEntryMap{}
  1772  	}
  1773  	p.Keys[strconv.Itoa(p.LatestVersion)] = entry
  1774  
  1775  	// This ensures that with new key creations min decryption version is set
  1776  	// to 1 rather than the int default of 0, since keys start at 1 (either
  1777  	// fresh or after migration to the key map)
  1778  	if p.MinDecryptionVersion == 0 {
  1779  		p.MinDecryptionVersion = 1
  1780  	}
  1781  
  1782  	return nil
  1783  }
  1784  
  1785  func (p *Policy) MigrateKeyToKeysMap() {
  1786  	now := time.Now()
  1787  	p.Keys = keyEntryMap{
  1788  		"1": KeyEntry{
  1789  			Key:                    p.Key,
  1790  			CreationTime:           now,
  1791  			DeprecatedCreationTime: now.Unix(),
  1792  		},
  1793  	}
  1794  	p.Key = nil
  1795  }
  1796  
  1797  // Backup should be called with an exclusive lock held on the policy
  1798  func (p *Policy) Backup(ctx context.Context, storage logical.Storage) (out string, retErr error) {
  1799  	if !p.Exportable {
  1800  		return "", fmt.Errorf("exporting is disallowed on the policy")
  1801  	}
  1802  
  1803  	if !p.AllowPlaintextBackup {
  1804  		return "", fmt.Errorf("plaintext backup is disallowed on the policy")
  1805  	}
  1806  
  1807  	priorBackupInfo := p.BackupInfo
  1808  
  1809  	defer func() {
  1810  		if retErr != nil {
  1811  			p.BackupInfo = priorBackupInfo
  1812  		}
  1813  	}()
  1814  
  1815  	// Create a record of this backup operation in the policy
  1816  	p.BackupInfo = &BackupInfo{
  1817  		Time:    time.Now(),
  1818  		Version: p.LatestVersion,
  1819  	}
  1820  	err := p.Persist(ctx, storage)
  1821  	if err != nil {
  1822  		return "", errwrap.Wrapf("failed to persist policy with backup info: {{err}}", err)
  1823  	}
  1824  
  1825  	// Load the archive only after persisting the policy as the archive can get
  1826  	// adjusted while persisting the policy
  1827  	archivedKeys, err := p.LoadArchive(ctx, storage)
  1828  	if err != nil {
  1829  		return "", err
  1830  	}
  1831  
  1832  	keyData := &KeyData{
  1833  		Policy:       p,
  1834  		ArchivedKeys: archivedKeys,
  1835  	}
  1836  
  1837  	encodedBackup, err := jsonutil.EncodeJSON(keyData)
  1838  	if err != nil {
  1839  		return "", err
  1840  	}
  1841  
  1842  	return base64.StdEncoding.EncodeToString(encodedBackup), nil
  1843  }
  1844  
  1845  func (p *Policy) getTemplateParts() ([]string, error) {
  1846  	partsRaw, ok := p.versionPrefixCache.Load("template-parts")
  1847  	if ok {
  1848  		return partsRaw.([]string), nil
  1849  	}
  1850  
  1851  	template := p.VersionTemplate
  1852  	if template == "" {
  1853  		template = DefaultVersionTemplate
  1854  	}
  1855  
  1856  	tplParts := strings.Split(template, "{{version}}")
  1857  	if len(tplParts) != 2 {
  1858  		return nil, errutil.InternalError{Err: "error parsing version template"}
  1859  	}
  1860  
  1861  	p.versionPrefixCache.Store("template-parts", tplParts)
  1862  	return tplParts, nil
  1863  }
  1864  
  1865  func (p *Policy) getVersionPrefix(ver int) string {
  1866  	prefixRaw, ok := p.versionPrefixCache.Load(ver)
  1867  	if ok {
  1868  		return prefixRaw.(string)
  1869  	}
  1870  
  1871  	template := p.VersionTemplate
  1872  	if template == "" {
  1873  		template = DefaultVersionTemplate
  1874  	}
  1875  
  1876  	prefix := strings.ReplaceAll(template, "{{version}}", strconv.Itoa(ver))
  1877  	p.versionPrefixCache.Store(ver, prefix)
  1878  
  1879  	return prefix
  1880  }
  1881  
// SymmetricOpts carries the optional arguments to SymmetricEncryptRaw and
// SymmetricDecryptRaw. Grouping them in one struct keeps call sites readable,
// since most callers set only a few of the fields.
type SymmetricOpts struct {
	// Convergent selects convergent (deterministic) encryption.
	Convergent bool
	// ConvergentVersion is the convergent encryption scheme version. It is
	// read when decrypting (version 1 ciphertexts are stored without their
	// nonce); encryption derives the version from the policy instead.
	ConvergentVersion int
	// Nonce is used instead of a randomly generated nonce. It is required
	// (and must be the AEAD nonce size) for convergent version 1, and must
	// be empty for managed-key backed encryption.
	Nonce []byte
	// AdditionalData is the additional authenticated data passed to the
	// AEAD Seal/Open calls.
	AdditionalData []byte
	// HMACKey keys the HMAC-SHA256 used to derive the nonce from the
	// plaintext in convergent versions 2 and 3.
	HMACKey []byte
	// AEADFactory supplies the AEAD from an external provider, e.g. for
	// managed keys (KeyType_MANAGED_KEY).
	AEADFactory AEADFactory
}
  1898  
  1899  // Symmetrically encrypt a plaintext given the convergence configuration and appropriate keys
  1900  func (p *Policy) SymmetricEncryptRaw(ver int, encKey, plaintext []byte, opts SymmetricOpts) ([]byte, error) {
  1901  	var aead cipher.AEAD
  1902  	var err error
  1903  	nonce := opts.Nonce
  1904  
  1905  	switch p.Type {
  1906  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
  1907  		// Setup the cipher
  1908  		aesCipher, err := aes.NewCipher(encKey)
  1909  		if err != nil {
  1910  			return nil, errutil.InternalError{Err: err.Error()}
  1911  		}
  1912  
  1913  		// Setup the GCM AEAD
  1914  		gcm, err := cipher.NewGCM(aesCipher)
  1915  		if err != nil {
  1916  			return nil, errutil.InternalError{Err: err.Error()}
  1917  		}
  1918  
  1919  		aead = gcm
  1920  
  1921  	case KeyType_ChaCha20_Poly1305:
  1922  		cha, err := chacha20poly1305.New(encKey)
  1923  		if err != nil {
  1924  			return nil, errutil.InternalError{Err: err.Error()}
  1925  		}
  1926  
  1927  		aead = cha
  1928  	case KeyType_MANAGED_KEY:
  1929  		if opts.Convergent || len(opts.Nonce) != 0 {
  1930  			return nil, errutil.UserError{Err: "cannot use convergent encryption or provide a nonce to managed-key backed encryption"}
  1931  		}
  1932  		if opts.AEADFactory == nil {
  1933  			return nil, errors.New("expected AEAD factory from managed key, none provided")
  1934  		}
  1935  		aead, err = opts.AEADFactory.GetAEAD(nonce)
  1936  		if err != nil {
  1937  			return nil, err
  1938  		}
  1939  	}
  1940  
  1941  	if opts.Convergent {
  1942  		convergentVersion := p.convergentVersion(ver)
  1943  		switch convergentVersion {
  1944  		case 1:
  1945  			if len(opts.Nonce) != aead.NonceSize() {
  1946  				return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())}
  1947  			}
  1948  		case 2, 3:
  1949  			if len(opts.HMACKey) == 0 {
  1950  				return nil, errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")}
  1951  			}
  1952  			nonceHmac := hmac.New(sha256.New, opts.HMACKey)
  1953  			nonceHmac.Write(plaintext)
  1954  			nonceSum := nonceHmac.Sum(nil)
  1955  			nonce = nonceSum[:aead.NonceSize()]
  1956  		default:
  1957  			return nil, errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)}
  1958  		}
  1959  	} else if len(nonce) == 0 {
  1960  		// Compute random nonce
  1961  		nonce, err = uuid.GenerateRandomBytes(aead.NonceSize())
  1962  		if err != nil {
  1963  			return nil, errutil.InternalError{Err: err.Error()}
  1964  		}
  1965  	} else if len(nonce) != aead.NonceSize() {
  1966  		return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long but given %d bytes", aead.NonceSize(), len(nonce))}
  1967  	}
  1968  
  1969  	// Encrypt and tag with AEAD
  1970  	ciphertext := aead.Seal(nil, nonce, plaintext, opts.AdditionalData)
  1971  
  1972  	// Place the encrypted data after the nonce
  1973  	if !opts.Convergent || p.convergentVersion(ver) > 1 {
  1974  		ciphertext = append(nonce, ciphertext...)
  1975  	}
  1976  	return ciphertext, nil
  1977  }
  1978  
  1979  // Symmetrically decrypt a ciphertext given the convergence configuration and appropriate keys
  1980  func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOpts) ([]byte, error) {
  1981  	var aead cipher.AEAD
  1982  	var err error
  1983  	var nonce []byte
  1984  
  1985  	switch p.Type {
  1986  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
  1987  		// Setup the cipher
  1988  		aesCipher, err := aes.NewCipher(encKey)
  1989  		if err != nil {
  1990  			return nil, errutil.InternalError{Err: err.Error()}
  1991  		}
  1992  
  1993  		// Setup the GCM AEAD
  1994  		gcm, err := cipher.NewGCM(aesCipher)
  1995  		if err != nil {
  1996  			return nil, errutil.InternalError{Err: err.Error()}
  1997  		}
  1998  
  1999  		aead = gcm
  2000  
  2001  	case KeyType_ChaCha20_Poly1305:
  2002  		cha, err := chacha20poly1305.New(encKey)
  2003  		if err != nil {
  2004  			return nil, errutil.InternalError{Err: err.Error()}
  2005  		}
  2006  
  2007  		aead = cha
  2008  	case KeyType_MANAGED_KEY:
  2009  		aead, err = opts.AEADFactory.GetAEAD(nonce)
  2010  		if err != nil {
  2011  			return nil, err
  2012  		}
  2013  	}
  2014  
  2015  	if len(ciphertext) < aead.NonceSize() {
  2016  		return nil, errutil.UserError{Err: "invalid ciphertext length"}
  2017  	}
  2018  
  2019  	// Extract the nonce and ciphertext
  2020  	var trueCT []byte
  2021  	if opts.Convergent && opts.ConvergentVersion == 1 {
  2022  		trueCT = ciphertext
  2023  	} else {
  2024  		nonce = ciphertext[:aead.NonceSize()]
  2025  		trueCT = ciphertext[aead.NonceSize():]
  2026  	}
  2027  
  2028  	// Verify and Decrypt
  2029  	plain, err := aead.Open(nil, nonce, trueCT, opts.AdditionalData)
  2030  	if err != nil {
  2031  		return nil, errutil.UserError{Err: err.Error()}
  2032  	}
  2033  	return plain, nil
  2034  }
  2035  
  2036  func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) {
  2037  	if !p.Type.EncryptionSupported() {
  2038  		return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
  2039  	}
  2040  
  2041  	// Decode the plaintext value
  2042  	plaintext, err := base64.StdEncoding.DecodeString(value)
  2043  	if err != nil {
  2044  		return "", errutil.UserError{Err: err.Error()}
  2045  	}
  2046  
  2047  	switch {
  2048  	case ver == 0:
  2049  		ver = p.LatestVersion
  2050  	case ver < 0:
  2051  		return "", errutil.UserError{Err: "requested version for encryption is negative"}
  2052  	case ver > p.LatestVersion:
  2053  		return "", errutil.UserError{Err: "requested version for encryption is higher than the latest key version"}
  2054  	case ver < p.MinEncryptionVersion:
  2055  		return "", errutil.UserError{Err: "requested version for encryption is less than the minimum encryption key version"}
  2056  	}
  2057  
  2058  	var ciphertext []byte
  2059  
  2060  	switch p.Type {
  2061  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
  2062  		hmacKey := context
  2063  
  2064  		var encKey []byte
  2065  		var deriveHMAC bool
  2066  
  2067  		encBytes := 32
  2068  		hmacBytes := 0
  2069  		convergentVersion := p.convergentVersion(ver)
  2070  		if convergentVersion > 2 {
  2071  			deriveHMAC = true
  2072  			hmacBytes = 32
  2073  			if len(nonce) > 0 {
  2074  				return "", errutil.UserError{Err: "nonce provided when not allowed"}
  2075  			}
  2076  		} else if len(nonce) > 0 && (!p.ConvergentEncryption || convergentVersion != 1) {
  2077  			return "", errutil.UserError{Err: "nonce provided when not allowed"}
  2078  		}
  2079  		if p.Type == KeyType_AES128_GCM96 {
  2080  			encBytes = 16
  2081  		}
  2082  
  2083  		key, err := p.GetKey(context, ver, encBytes+hmacBytes)
  2084  		if err != nil {
  2085  			return "", err
  2086  		}
  2087  
  2088  		if len(key) < encBytes+hmacBytes {
  2089  			return "", errutil.InternalError{Err: "could not derive key, length too small"}
  2090  		}
  2091  
  2092  		encKey = key[:encBytes]
  2093  		if len(encKey) != encBytes {
  2094  			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
  2095  		}
  2096  		if deriveHMAC {
  2097  			hmacKey = key[encBytes:]
  2098  			if len(hmacKey) != hmacBytes {
  2099  				return "", errutil.InternalError{Err: "could not derive hmac key, length not correct"}
  2100  			}
  2101  		}
  2102  
  2103  		symopts := SymmetricOpts{
  2104  			Convergent: p.ConvergentEncryption,
  2105  			HMACKey:    hmacKey,
  2106  			Nonce:      nonce,
  2107  		}
  2108  		for index, rawFactory := range factories {
  2109  			if rawFactory == nil {
  2110  				continue
  2111  			}
  2112  			switch factory := rawFactory.(type) {
  2113  			case AEADFactory:
  2114  				symopts.AEADFactory = factory
  2115  			case AssociatedDataFactory:
  2116  				symopts.AdditionalData, err = factory.GetAssociatedData()
  2117  				if err != nil {
  2118  					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
  2119  				}
  2120  			case ManagedKeyFactory:
  2121  			default:
  2122  				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
  2123  			}
  2124  		}
  2125  
  2126  		ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts)
  2127  		if err != nil {
  2128  			return "", err
  2129  		}
  2130  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  2131  		keyEntry, err := p.safeGetKeyEntry(ver)
  2132  		if err != nil {
  2133  			return "", err
  2134  		}
  2135  		var publicKey *rsa.PublicKey
  2136  		if keyEntry.RSAKey != nil {
  2137  			publicKey = &keyEntry.RSAKey.PublicKey
  2138  		} else {
  2139  			publicKey = keyEntry.RSAPublicKey
  2140  		}
  2141  		ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
  2142  		if err != nil {
  2143  			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA encrypt the plaintext: %v", err)}
  2144  		}
  2145  	case KeyType_MANAGED_KEY:
  2146  		keyEntry, err := p.safeGetKeyEntry(ver)
  2147  		if err != nil {
  2148  			return "", err
  2149  		}
  2150  
  2151  		var aad []byte
  2152  		var managedKeyFactory ManagedKeyFactory
  2153  		for _, f := range factories {
  2154  			switch factory := f.(type) {
  2155  			case AssociatedDataFactory:
  2156  				aad, err = factory.GetAssociatedData()
  2157  				if err != nil {
  2158  					return "", nil
  2159  				}
  2160  			case ManagedKeyFactory:
  2161  				managedKeyFactory = factory
  2162  			}
  2163  		}
  2164  
  2165  		if managedKeyFactory == nil {
  2166  			return "", errors.New("key type is managed_key, but managed key parameters were not provided")
  2167  		}
  2168  
  2169  		ciphertext, err = p.encryptWithManagedKey(managedKeyFactory.GetManagedKeyParameters(), keyEntry, plaintext, nonce, aad)
  2170  		if err != nil {
  2171  			return "", err
  2172  		}
  2173  
  2174  	default:
  2175  		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
  2176  	}
  2177  
  2178  	// Convert to base64
  2179  	encoded := base64.StdEncoding.EncodeToString(ciphertext)
  2180  
  2181  	// Prepend some information
  2182  	encoded = p.getVersionPrefix(ver) + encoded
  2183  
  2184  	return encoded, nil
  2185  }
  2186  
  2187  func (p *Policy) KeyVersionCanBeUpdated(keyVersion int, isPrivateKey bool) error {
  2188  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2189  	if err != nil {
  2190  		return err
  2191  	}
  2192  
  2193  	if !p.Type.ImportPublicKeySupported() {
  2194  		return errors.New("provided type does not support importing key versions")
  2195  	}
  2196  
  2197  	isPrivateKeyMissing := keyEntry.IsPrivateKeyMissing()
  2198  	if isPrivateKeyMissing && !isPrivateKey {
  2199  		return errors.New("cannot add a public key to a key version that already has a public key set")
  2200  	}
  2201  
  2202  	if !isPrivateKeyMissing {
  2203  		return errors.New("private key imported, key version cannot be updated")
  2204  	}
  2205  
  2206  	return nil
  2207  }
  2208  
  2209  func (p *Policy) ImportPrivateKeyForVersion(ctx context.Context, storage logical.Storage, keyVersion int, key []byte) error {
  2210  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2211  	if err != nil {
  2212  		return err
  2213  	}
  2214  
  2215  	// Parse key
  2216  	parsedPrivateKey, err := x509.ParsePKCS8PrivateKey(key)
  2217  	if err != nil {
  2218  		if strings.Contains(err.Error(), "unknown elliptic curve") {
  2219  			var edErr error
  2220  			parsedPrivateKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
  2221  			if edErr != nil {
  2222  				return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
  2223  			}
  2224  
  2225  			// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
  2226  		} else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) {
  2227  			var rsaErr error
  2228  			parsedPrivateKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key)
  2229  			if rsaErr != nil {
  2230  				return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err)
  2231  			}
  2232  
  2233  			// Parsing as RSA-PSS in PKCS8 succeeded!
  2234  		} else {
  2235  			return fmt.Errorf("error parsing asymmetric key: %s", err)
  2236  		}
  2237  	}
  2238  
  2239  	switch parsedPrivateKey.(type) {
  2240  	case *ecdsa.PrivateKey:
  2241  		ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey)
  2242  		pemBlock, _ := pem.Decode([]byte(keyEntry.FormattedPublicKey))
  2243  		if pemBlock == nil {
  2244  			return fmt.Errorf("failed to parse key entry public key: invalid PEM blob")
  2245  		}
  2246  		publicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
  2247  		if err != nil || publicKey == nil {
  2248  			return fmt.Errorf("failed to parse key entry public key: %v", err)
  2249  		}
  2250  		if !publicKey.(*ecdsa.PublicKey).Equal(&ecdsaKey.PublicKey) {
  2251  			return fmt.Errorf("cannot import key, key pair does not match")
  2252  		}
  2253  	case *rsa.PrivateKey:
  2254  		rsaKey := parsedPrivateKey.(*rsa.PrivateKey)
  2255  		if !rsaKey.PublicKey.Equal(keyEntry.RSAPublicKey) {
  2256  			return fmt.Errorf("cannot import key, key pair does not match")
  2257  		}
  2258  	case ed25519.PrivateKey:
  2259  		ed25519Key := parsedPrivateKey.(ed25519.PrivateKey)
  2260  		publicKey, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
  2261  		if err != nil {
  2262  			return fmt.Errorf("failed to parse key entry public key: %v", err)
  2263  		}
  2264  		if !ed25519.PublicKey(publicKey).Equal(ed25519Key.Public()) {
  2265  			return fmt.Errorf("cannot import key, key pair does not match")
  2266  		}
  2267  	}
  2268  
  2269  	err = keyEntry.parseFromKey(p.Type, parsedPrivateKey)
  2270  	if err != nil {
  2271  		return err
  2272  	}
  2273  
  2274  	p.Keys[strconv.Itoa(keyVersion)] = keyEntry
  2275  
  2276  	return p.Persist(ctx, storage)
  2277  }
  2278  
  2279  func (ke *KeyEntry) parseFromKey(PolKeyType KeyType, parsedKey any) error {
  2280  	switch parsedKey.(type) {
  2281  	case *ecdsa.PrivateKey, *ecdsa.PublicKey:
  2282  		if PolKeyType != KeyType_ECDSA_P256 && PolKeyType != KeyType_ECDSA_P384 && PolKeyType != KeyType_ECDSA_P521 {
  2283  			return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2284  		}
  2285  
  2286  		curve := elliptic.P256()
  2287  		if PolKeyType == KeyType_ECDSA_P384 {
  2288  			curve = elliptic.P384()
  2289  		} else if PolKeyType == KeyType_ECDSA_P521 {
  2290  			curve = elliptic.P521()
  2291  		}
  2292  
  2293  		var derBytes []byte
  2294  		var err error
  2295  		ecdsaKey, ok := parsedKey.(*ecdsa.PrivateKey)
  2296  		if ok {
  2297  
  2298  			if ecdsaKey.Curve != curve {
  2299  				return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name)
  2300  			}
  2301  
  2302  			ke.EC_D = ecdsaKey.D
  2303  			ke.EC_X = ecdsaKey.X
  2304  			ke.EC_Y = ecdsaKey.Y
  2305  
  2306  			derBytes, err = x509.MarshalPKIXPublicKey(ecdsaKey.Public())
  2307  			if err != nil {
  2308  				return errwrap.Wrapf("error marshaling public key: {{err}}", err)
  2309  			}
  2310  		} else {
  2311  			ecdsaKey := parsedKey.(*ecdsa.PublicKey)
  2312  
  2313  			if ecdsaKey.Curve != curve {
  2314  				return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name)
  2315  			}
  2316  
  2317  			ke.EC_X = ecdsaKey.X
  2318  			ke.EC_Y = ecdsaKey.Y
  2319  
  2320  			derBytes, err = x509.MarshalPKIXPublicKey(ecdsaKey)
  2321  			if err != nil {
  2322  				return errwrap.Wrapf("error marshaling public key: {{err}}", err)
  2323  			}
  2324  		}
  2325  
  2326  		pemBlock := &pem.Block{
  2327  			Type:  "PUBLIC KEY",
  2328  			Bytes: derBytes,
  2329  		}
  2330  		pemBytes := pem.EncodeToMemory(pemBlock)
  2331  		if pemBytes == nil || len(pemBytes) == 0 {
  2332  			return fmt.Errorf("error PEM-encoding public key")
  2333  		}
  2334  		ke.FormattedPublicKey = string(pemBytes)
  2335  	case ed25519.PrivateKey, ed25519.PublicKey:
  2336  		if PolKeyType != KeyType_ED25519 {
  2337  			return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2338  		}
  2339  
  2340  		privateKey, ok := parsedKey.(ed25519.PrivateKey)
  2341  		if ok {
  2342  			ke.Key = privateKey
  2343  			publicKey := privateKey.Public().(ed25519.PublicKey)
  2344  			ke.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey)
  2345  		} else {
  2346  			publicKey := parsedKey.(ed25519.PublicKey)
  2347  			ke.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey)
  2348  		}
  2349  	case *rsa.PrivateKey, *rsa.PublicKey:
  2350  		if PolKeyType != KeyType_RSA2048 && PolKeyType != KeyType_RSA3072 && PolKeyType != KeyType_RSA4096 {
  2351  			return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2352  		}
  2353  
  2354  		keyBytes := 256
  2355  		if PolKeyType == KeyType_RSA3072 {
  2356  			keyBytes = 384
  2357  		} else if PolKeyType == KeyType_RSA4096 {
  2358  			keyBytes = 512
  2359  		}
  2360  
  2361  		rsaKey, ok := parsedKey.(*rsa.PrivateKey)
  2362  		if ok {
  2363  			if rsaKey.Size() != keyBytes {
  2364  				return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size())
  2365  			}
  2366  			ke.RSAKey = rsaKey
  2367  			ke.RSAPublicKey = rsaKey.Public().(*rsa.PublicKey)
  2368  		} else {
  2369  			rsaKey := parsedKey.(*rsa.PublicKey)
  2370  			if rsaKey.Size() != keyBytes {
  2371  				return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size())
  2372  			}
  2373  			ke.RSAPublicKey = rsaKey
  2374  		}
  2375  	default:
  2376  		return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2377  	}
  2378  
  2379  	return nil
  2380  }
  2381  
  2382  func (p *Policy) WrapKey(ver int, targetKey interface{}, targetKeyType KeyType, hash hash.Hash) (string, error) {
  2383  	if !p.Type.SigningSupported() {
  2384  		return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
  2385  	}
  2386  
  2387  	switch {
  2388  	case ver == 0:
  2389  		ver = p.LatestVersion
  2390  	case ver < 0:
  2391  		return "", errutil.UserError{Err: "requested version for key wrapping is negative"}
  2392  	case ver > p.LatestVersion:
  2393  		return "", errutil.UserError{Err: "requested version for key wrapping is higher than the latest key version"}
  2394  	case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
  2395  		return "", errutil.UserError{Err: "requested version for key wrapping is less than the minimum encryption key version"}
  2396  	}
  2397  
  2398  	keyEntry, err := p.safeGetKeyEntry(ver)
  2399  	if err != nil {
  2400  		return "", err
  2401  	}
  2402  
  2403  	return keyEntry.WrapKey(targetKey, targetKeyType, hash)
  2404  }
  2405  
  2406  func (ke *KeyEntry) WrapKey(targetKey interface{}, targetKeyType KeyType, hash hash.Hash) (string, error) {
  2407  	// Presently this method implements a CKM_RSA_AES_KEY_WRAP-compatible
  2408  	// wrapping interface and only works on RSA keyEntries as a result.
  2409  	if ke.RSAPublicKey == nil {
  2410  		return "", fmt.Errorf("unsupported key type in use; must be a rsa key")
  2411  	}
  2412  
  2413  	var preppedTargetKey []byte
  2414  	switch targetKeyType {
  2415  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC, KeyType_AES128_CMAC, KeyType_AES256_CMAC:
  2416  		var ok bool
  2417  		preppedTargetKey, ok = targetKey.([]byte)
  2418  		if !ok {
  2419  			return "", fmt.Errorf("failed to wrap target key for import: symmetric key not provided in byte format (%T)", targetKey)
  2420  		}
  2421  	default:
  2422  		var err error
  2423  		preppedTargetKey, err = x509.MarshalPKCS8PrivateKey(targetKey)
  2424  		if err != nil {
  2425  			return "", fmt.Errorf("failed to wrap target key for import: %w", err)
  2426  		}
  2427  	}
  2428  
  2429  	result, err := wrapTargetPKCS8ForImport(ke.RSAPublicKey, preppedTargetKey, hash)
  2430  	if err != nil {
  2431  		return result, fmt.Errorf("failed to wrap target key for import: %w", err)
  2432  	}
  2433  
  2434  	return result, nil
  2435  }
  2436  
  2437  func wrapTargetPKCS8ForImport(wrappingKey *rsa.PublicKey, preppedTargetKey []byte, hash hash.Hash) (string, error) {
  2438  	// Generate an ephemeral AES-256 key
  2439  	ephKey, err := uuid.GenerateRandomBytes(32)
  2440  	if err != nil {
  2441  		return "", fmt.Errorf("failed to generate an ephemeral AES wrapping key: %w", err)
  2442  	}
  2443  
  2444  	// Wrap ephemeral AES key with public wrapping key
  2445  	ephKeyWrapped, err := rsa.EncryptOAEP(hash, rand.Reader, wrappingKey, ephKey, []byte{} /* label */)
  2446  	if err != nil {
  2447  		return "", fmt.Errorf("failed to encrypt ephemeral wrapping key with public key: %w", err)
  2448  	}
  2449  
  2450  	// Create KWP instance for wrapping target key
  2451  	kwp, err := subtle.NewKWP(ephKey)
  2452  	if err != nil {
  2453  		return "", fmt.Errorf("failed to generate new KWP from AES key: %w", err)
  2454  	}
  2455  
  2456  	// Wrap target key with KWP
  2457  	targetKeyWrapped, err := kwp.Wrap(preppedTargetKey)
  2458  	if err != nil {
  2459  		return "", fmt.Errorf("failed to wrap target key with KWP: %w", err)
  2460  	}
  2461  
  2462  	// Combined wrapped keys into a single blob and base64 encode
  2463  	wrappedKeys := append(ephKeyWrapped, targetKeyWrapped...)
  2464  	return base64.StdEncoding.EncodeToString(wrappedKeys), nil
  2465  }
  2466  
  2467  func (p *Policy) CreateCsr(keyVersion int, csrTemplate *x509.CertificateRequest) ([]byte, error) {
  2468  	if !p.Type.SigningSupported() {
  2469  		return nil, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)}
  2470  	}
  2471  
  2472  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2473  	if err != nil {
  2474  		return nil, err
  2475  	}
  2476  
  2477  	if keyEntry.IsPrivateKeyMissing() {
  2478  		return nil, errutil.UserError{Err: "private key not imported for key version selected"}
  2479  	}
  2480  
  2481  	csrTemplate.Signature = nil
  2482  	csrTemplate.SignatureAlgorithm = x509.UnknownSignatureAlgorithm
  2483  
  2484  	var key crypto.Signer
  2485  	switch p.Type {
  2486  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  2487  		var curve elliptic.Curve
  2488  		switch p.Type {
  2489  		case KeyType_ECDSA_P384:
  2490  			curve = elliptic.P384()
  2491  		case KeyType_ECDSA_P521:
  2492  			curve = elliptic.P521()
  2493  		default:
  2494  			curve = elliptic.P256()
  2495  		}
  2496  
  2497  		key = &ecdsa.PrivateKey{
  2498  			PublicKey: ecdsa.PublicKey{
  2499  				Curve: curve,
  2500  				X:     keyEntry.EC_X,
  2501  				Y:     keyEntry.EC_Y,
  2502  			},
  2503  			D: keyEntry.EC_D,
  2504  		}
  2505  
  2506  	case KeyType_ED25519:
  2507  		if p.Derived {
  2508  			return nil, errutil.UserError{Err: "operation not supported on keys with derivation enabled"}
  2509  		}
  2510  		key = ed25519.PrivateKey(keyEntry.Key)
  2511  
  2512  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  2513  		key = keyEntry.RSAKey
  2514  
  2515  	default:
  2516  		return nil, errutil.InternalError{Err: fmt.Sprintf("selected key type '%s' does not support signing", p.Type.String())}
  2517  	}
  2518  	csrBytes, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
  2519  	if err != nil {
  2520  		return nil, fmt.Errorf("could not create the cerfificate request: %w", err)
  2521  	}
  2522  
  2523  	pemCsr := pem.EncodeToMemory(&pem.Block{
  2524  		Type:  "CERTIFICATE REQUEST",
  2525  		Bytes: csrBytes,
  2526  	})
  2527  
  2528  	return pemCsr, nil
  2529  }
  2530  
  2531  func (p *Policy) ValidateLeafCertKeyMatch(keyVersion int, certPublicKeyAlgorithm x509.PublicKeyAlgorithm, certPublicKey any) (bool, error) {
  2532  	if !p.Type.SigningSupported() {
  2533  		return false, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)}
  2534  	}
  2535  
  2536  	var keyTypeMatches bool
  2537  	switch p.Type {
  2538  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  2539  		if certPublicKeyAlgorithm == x509.ECDSA {
  2540  			keyTypeMatches = true
  2541  		}
  2542  	case KeyType_ED25519:
  2543  		if certPublicKeyAlgorithm == x509.Ed25519 {
  2544  			keyTypeMatches = true
  2545  		}
  2546  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  2547  		if certPublicKeyAlgorithm == x509.RSA {
  2548  			keyTypeMatches = true
  2549  		}
  2550  	}
  2551  	if !keyTypeMatches {
  2552  		return false, errutil.UserError{Err: fmt.Sprintf("provided leaf certificate public key algorithm '%s' does not match the transit key type '%s'",
  2553  			certPublicKeyAlgorithm, p.Type)}
  2554  	}
  2555  
  2556  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2557  	if err != nil {
  2558  		return false, err
  2559  	}
  2560  
  2561  	switch certPublicKeyAlgorithm {
  2562  	case x509.ECDSA:
  2563  		certPublicKey := certPublicKey.(*ecdsa.PublicKey)
  2564  		var curve elliptic.Curve
  2565  		switch p.Type {
  2566  		case KeyType_ECDSA_P384:
  2567  			curve = elliptic.P384()
  2568  		case KeyType_ECDSA_P521:
  2569  			curve = elliptic.P521()
  2570  		default:
  2571  			curve = elliptic.P256()
  2572  		}
  2573  
  2574  		publicKey := &ecdsa.PublicKey{
  2575  			Curve: curve,
  2576  			X:     keyEntry.EC_X,
  2577  			Y:     keyEntry.EC_Y,
  2578  		}
  2579  
  2580  		return publicKey.Equal(certPublicKey), nil
  2581  
  2582  	case x509.Ed25519:
  2583  		if p.Derived {
  2584  			return false, errutil.UserError{Err: "operation not supported on keys with derivation enabled"}
  2585  		}
  2586  		certPublicKey := certPublicKey.(ed25519.PublicKey)
  2587  
  2588  		raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
  2589  		if err != nil {
  2590  			return false, err
  2591  		}
  2592  		publicKey := ed25519.PublicKey(raw)
  2593  
  2594  		return publicKey.Equal(certPublicKey), nil
  2595  
  2596  	case x509.RSA:
  2597  		certPublicKey := certPublicKey.(*rsa.PublicKey)
  2598  		publicKey := keyEntry.RSAKey.PublicKey
  2599  		return publicKey.Equal(certPublicKey), nil
  2600  
  2601  	case x509.UnknownPublicKeyAlgorithm:
  2602  		return false, errutil.InternalError{Err: fmt.Sprint("certificate signed with an unknown algorithm")}
  2603  	}
  2604  
  2605  	return false, nil
  2606  }
  2607  
  2608  func (p *Policy) ValidateAndPersistCertificateChain(ctx context.Context, keyVersion int, certChain []*x509.Certificate, storage logical.Storage) error {
  2609  	if len(certChain) == 0 {
  2610  		return errutil.UserError{Err: "expected at least one certificate in the parsed certificate chain"}
  2611  	}
  2612  
  2613  	if certChain[0].BasicConstraintsValid && certChain[0].IsCA {
  2614  		return errutil.UserError{Err: "certificate in the first position is not a leaf certificate"}
  2615  	}
  2616  
  2617  	for _, cert := range certChain[1:] {
  2618  		if cert.BasicConstraintsValid && !cert.IsCA {
  2619  			return errutil.UserError{Err: "provided certificate chain contains more than one leaf certificate"}
  2620  		}
  2621  	}
  2622  
  2623  	valid, err := p.ValidateLeafCertKeyMatch(keyVersion, certChain[0].PublicKeyAlgorithm, certChain[0].PublicKey)
  2624  	if err != nil {
  2625  		prefixedErr := fmt.Errorf("could not validate key match between leaf certificate key and key version in transit: %w", err)
  2626  		switch err.(type) {
  2627  		case errutil.UserError:
  2628  			return errutil.UserError{Err: prefixedErr.Error()}
  2629  		default:
  2630  			return prefixedErr
  2631  		}
  2632  	}
  2633  	if !valid {
  2634  		return fmt.Errorf("leaf certificate public key does match the key version selected")
  2635  	}
  2636  
  2637  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2638  	if err != nil {
  2639  		return err
  2640  	}
  2641  
  2642  	// Convert the certificate chain to DER format
  2643  	derCertificates := make([][]byte, len(certChain))
  2644  	for i, cert := range certChain {
  2645  		derCertificates[i] = cert.Raw
  2646  	}
  2647  
  2648  	keyEntry.CertificateChain = derCertificates
  2649  
  2650  	p.Keys[strconv.Itoa(keyVersion)] = keyEntry
  2651  	return p.Persist(ctx, storage)
  2652  }