github.com/hashicorp/vault/sdk@v0.11.0/helper/keysutil/policy.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package keysutil
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto"
    10  	"crypto/aes"
    11  	"crypto/cipher"
    12  	"crypto/ecdsa"
    13  	"crypto/elliptic"
    14  	"crypto/hmac"
    15  	"crypto/rand"
    16  	"crypto/rsa"
    17  	"crypto/sha256"
    18  	"crypto/x509"
    19  	"encoding/asn1"
    20  	"encoding/base64"
    21  	"encoding/json"
    22  	"encoding/pem"
    23  	"errors"
    24  	"fmt"
    25  	"hash"
    26  	"io"
    27  	"math/big"
    28  	"path"
    29  	"strconv"
    30  	"strings"
    31  	"sync"
    32  	"sync/atomic"
    33  	"time"
    34  
    35  	"golang.org/x/crypto/chacha20poly1305"
    36  	"golang.org/x/crypto/ed25519"
    37  	"golang.org/x/crypto/hkdf"
    38  
    39  	"github.com/hashicorp/errwrap"
    40  	"github.com/hashicorp/go-uuid"
    41  	"github.com/hashicorp/vault/sdk/helper/errutil"
    42  	"github.com/hashicorp/vault/sdk/helper/jsonutil"
    43  	"github.com/hashicorp/vault/sdk/helper/kdf"
    44  	"github.com/hashicorp/vault/sdk/logical"
    45  
    46  	"github.com/google/tink/go/kwp/subtle"
    47  )
    48  
    49  // Careful with iota; don't put anything before it in this const block because
    50  // we need the default of zero to be the old-style KDF
    51  const (
    52  	Kdf_hmac_sha256_counter = iota // built-in helper
    53  	Kdf_hkdf_sha256                // golang.org/x/crypto/hkdf
    54  
    55  	HmacMinKeySize = 256 / 8
    56  	HmacMaxKeySize = 4096 / 8
    57  )
    58  
    59  // Or this one...we need the default of zero to be the original AES256-GCM96
    60  const (
    61  	KeyType_AES256_GCM96 = iota
    62  	KeyType_ECDSA_P256
    63  	KeyType_ED25519
    64  	KeyType_RSA2048
    65  	KeyType_RSA4096
    66  	KeyType_ChaCha20_Poly1305
    67  	KeyType_ECDSA_P384
    68  	KeyType_ECDSA_P521
    69  	KeyType_AES128_GCM96
    70  	KeyType_RSA3072
    71  	KeyType_MANAGED_KEY
    72  	KeyType_HMAC
    73  )
    74  
    75  const (
    76  	// ErrTooOld is returned whtn the ciphertext or signatures's key version is
    77  	// too old.
    78  	ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
    79  
    80  	// DefaultVersionTemplate is used when no version template is provided.
    81  	DefaultVersionTemplate = "vault:v{{version}}:"
    82  )
    83  
    84  type AEADFactory interface {
    85  	GetAEAD(iv []byte) (cipher.AEAD, error)
    86  }
    87  
    88  type AssociatedDataFactory interface {
    89  	GetAssociatedData() ([]byte, error)
    90  }
    91  
    92  type ManagedKeyFactory interface {
    93  	GetManagedKeyParameters() ManagedKeyParameters
    94  }
    95  
    96  type RestoreInfo struct {
    97  	Time    time.Time `json:"time"`
    98  	Version int       `json:"version"`
    99  }
   100  
   101  type BackupInfo struct {
   102  	Time    time.Time `json:"time"`
   103  	Version int       `json:"version"`
   104  }
   105  
   106  type SigningOptions struct {
   107  	HashAlgorithm    HashType
   108  	Marshaling       MarshalingType
   109  	SaltLength       int
   110  	SigAlgorithm     string
   111  	ManagedKeyParams ManagedKeyParameters
   112  }
   113  
   114  type SigningResult struct {
   115  	Signature string
   116  	PublicKey []byte
   117  }
   118  
   119  type ecdsaSignature struct {
   120  	R, S *big.Int
   121  }
   122  
   123  type KeyType int
   124  
   125  func (kt KeyType) EncryptionSupported() bool {
   126  	switch kt {
   127  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   128  		return true
   129  	}
   130  	return false
   131  }
   132  
   133  func (kt KeyType) DecryptionSupported() bool {
   134  	switch kt {
   135  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   136  		return true
   137  	}
   138  	return false
   139  }
   140  
   141  func (kt KeyType) SigningSupported() bool {
   142  	switch kt {
   143  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   144  		return true
   145  	}
   146  	return false
   147  }
   148  
   149  func (kt KeyType) HashSignatureInput() bool {
   150  	switch kt {
   151  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_MANAGED_KEY:
   152  		return true
   153  	}
   154  	return false
   155  }
   156  
   157  func (kt KeyType) DerivationSupported() bool {
   158  	switch kt {
   159  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_ED25519:
   160  		return true
   161  	}
   162  	return false
   163  }
   164  
   165  func (kt KeyType) AssociatedDataSupported() bool {
   166  	switch kt {
   167  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_MANAGED_KEY:
   168  		return true
   169  	}
   170  	return false
   171  }
   172  
   173  func (kt KeyType) ImportPublicKeySupported() bool {
   174  	switch kt {
   175  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519:
   176  		return true
   177  	}
   178  	return false
   179  }
   180  
   181  func (kt KeyType) String() string {
   182  	switch kt {
   183  	case KeyType_AES128_GCM96:
   184  		return "aes128-gcm96"
   185  	case KeyType_AES256_GCM96:
   186  		return "aes256-gcm96"
   187  	case KeyType_ChaCha20_Poly1305:
   188  		return "chacha20-poly1305"
   189  	case KeyType_ECDSA_P256:
   190  		return "ecdsa-p256"
   191  	case KeyType_ECDSA_P384:
   192  		return "ecdsa-p384"
   193  	case KeyType_ECDSA_P521:
   194  		return "ecdsa-p521"
   195  	case KeyType_ED25519:
   196  		return "ed25519"
   197  	case KeyType_RSA2048:
   198  		return "rsa-2048"
   199  	case KeyType_RSA3072:
   200  		return "rsa-3072"
   201  	case KeyType_RSA4096:
   202  		return "rsa-4096"
   203  	case KeyType_HMAC:
   204  		return "hmac"
   205  	case KeyType_MANAGED_KEY:
   206  		return "managed_key"
   207  	}
   208  
   209  	return "[unknown]"
   210  }
   211  
   212  type KeyData struct {
   213  	Policy       *Policy       `json:"policy"`
   214  	ArchivedKeys *archivedKeys `json:"archived_keys"`
   215  }
   216  
   217  // KeyEntry stores the key and metadata
   218  type KeyEntry struct {
   219  	// AES or some other kind that is a pure byte slice like ED25519
   220  	Key []byte `json:"key"`
   221  
   222  	// Key used for HMAC functions
   223  	HMACKey []byte `json:"hmac_key"`
   224  
   225  	// Time of creation
   226  	CreationTime time.Time `json:"time"`
   227  
   228  	EC_X *big.Int `json:"ec_x"`
   229  	EC_Y *big.Int `json:"ec_y"`
   230  	EC_D *big.Int `json:"ec_d"`
   231  
   232  	RSAKey       *rsa.PrivateKey `json:"rsa_key"`
   233  	RSAPublicKey *rsa.PublicKey  `json:"rsa_public_key"`
   234  
   235  	// The public key in an appropriate format for the type of key
   236  	FormattedPublicKey string `json:"public_key"`
   237  
   238  	// If convergent is enabled, the version (falling back to what's in the
   239  	// policy)
   240  	ConvergentVersion int `json:"convergent_version"`
   241  
   242  	// This is deprecated (but still filled) in favor of the value above which
   243  	// is more precise
   244  	DeprecatedCreationTime int64 `json:"creation_time"`
   245  
   246  	ManagedKeyUUID string `json:"managed_key_id,omitempty"`
   247  
   248  	// Key entry certificate chain. If set, leaf certificate key matches the
   249  	// KeyEntry key
   250  	CertificateChain [][]byte `json:"certificate_chain"`
   251  }
   252  
   253  func (ke *KeyEntry) IsPrivateKeyMissing() bool {
   254  	if ke.RSAKey != nil || ke.EC_D != nil || len(ke.Key) != 0 || len(ke.ManagedKeyUUID) != 0 {
   255  		return false
   256  	}
   257  
   258  	return true
   259  }
   260  
   261  // deprecatedKeyEntryMap is used to allow JSON marshal/unmarshal
   262  type deprecatedKeyEntryMap map[int]KeyEntry
   263  
   264  // MarshalJSON implements JSON marshaling
   265  func (kem deprecatedKeyEntryMap) MarshalJSON() ([]byte, error) {
   266  	intermediate := map[string]KeyEntry{}
   267  	for k, v := range kem {
   268  		intermediate[strconv.Itoa(k)] = v
   269  	}
   270  	return json.Marshal(&intermediate)
   271  }
   272  
   273  // MarshalJSON implements JSON unmarshalling
   274  func (kem deprecatedKeyEntryMap) UnmarshalJSON(data []byte) error {
   275  	intermediate := map[string]KeyEntry{}
   276  	if err := jsonutil.DecodeJSON(data, &intermediate); err != nil {
   277  		return err
   278  	}
   279  	for k, v := range intermediate {
   280  		keyval, err := strconv.Atoi(k)
   281  		if err != nil {
   282  			return err
   283  		}
   284  		kem[keyval] = v
   285  	}
   286  
   287  	return nil
   288  }
   289  
   290  // keyEntryMap is used to allow JSON marshal/unmarshal
   291  type keyEntryMap map[string]KeyEntry
   292  
   293  // PolicyConfig is used to create a new policy
   294  type PolicyConfig struct {
   295  	// The name of the policy
   296  	Name string `json:"name"`
   297  
   298  	// The type of key
   299  	Type KeyType
   300  
   301  	// Derived keys MUST provide a context and the master underlying key is
   302  	// never used.
   303  	Derived              bool
   304  	KDF                  int
   305  	ConvergentEncryption bool
   306  
   307  	// Whether the key is exportable
   308  	Exportable bool
   309  
   310  	// Whether the key is allowed to be deleted
   311  	DeletionAllowed bool
   312  
   313  	// AllowPlaintextBackup allows taking backup of the policy in plaintext
   314  	AllowPlaintextBackup bool
   315  
   316  	// VersionTemplate is used to prefix the ciphertext with information about
   317  	// the key version. It must inclide {{version}} and a delimiter between the
   318  	// version prefix and the ciphertext.
   319  	VersionTemplate string
   320  
   321  	// StoragePrefix is used to add a prefix when storing and retrieving the
   322  	// policy object.
   323  	StoragePrefix string
   324  }
   325  
   326  // NewPolicy takes a policy config and returns a Policy with those settings.
   327  func NewPolicy(config PolicyConfig) *Policy {
   328  	return &Policy{
   329  		l:                    new(sync.RWMutex),
   330  		Name:                 config.Name,
   331  		Type:                 config.Type,
   332  		Derived:              config.Derived,
   333  		KDF:                  config.KDF,
   334  		ConvergentEncryption: config.ConvergentEncryption,
   335  		ConvergentVersion:    -1,
   336  		Exportable:           config.Exportable,
   337  		DeletionAllowed:      config.DeletionAllowed,
   338  		AllowPlaintextBackup: config.AllowPlaintextBackup,
   339  		VersionTemplate:      config.VersionTemplate,
   340  		StoragePrefix:        config.StoragePrefix,
   341  	}
   342  }
   343  
   344  // LoadPolicy will load a policy from the provided storage path and set the
   345  // necessary un-exported variables. It is particularly useful when accessing a
   346  // policy without the lock manager.
   347  func LoadPolicy(ctx context.Context, s logical.Storage, path string) (*Policy, error) {
   348  	raw, err := s.Get(ctx, path)
   349  	if err != nil {
   350  		return nil, err
   351  	}
   352  	if raw == nil {
   353  		return nil, nil
   354  	}
   355  
   356  	var policy Policy
   357  	err = jsonutil.DecodeJSON(raw.Value, &policy)
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  
   362  	// Migrate RSA private keys to include their private counterpart. This lets
   363  	// us reference RSAPublicKey whenever we need to, without necessarily
   364  	// needing the private key handy, synchronizing the behavior with EC and
   365  	// Ed25519 key pairs.
   366  	switch policy.Type {
   367  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
   368  		for _, entry := range policy.Keys {
   369  			if entry.RSAPublicKey == nil && entry.RSAKey != nil {
   370  				entry.RSAPublicKey = entry.RSAKey.Public().(*rsa.PublicKey)
   371  			}
   372  		}
   373  	}
   374  
   375  	policy.l = new(sync.RWMutex)
   376  
   377  	return &policy, nil
   378  }
   379  
   380  // Policy is the struct used to store metadata
   381  type Policy struct {
   382  	// This is a pointer on purpose: if we are running with cache disabled we
   383  	// need to actually swap in the lock manager's lock for this policy with
   384  	// the local lock.
   385  	l *sync.RWMutex
   386  	// writeLocked allows us to implement Lock() and Unlock()
   387  	writeLocked bool
   388  	// Stores whether it's been deleted. This acts as a guard for operations
   389  	// that may write data, e.g. if one request rotates and that request is
   390  	// served after a delete.
   391  	deleted uint32
   392  
   393  	Name    string      `json:"name"`
   394  	Key     []byte      `json:"key,omitempty"`      // DEPRECATED
   395  	KeySize int         `json:"key_size,omitempty"` // For algorithms with variable key sizes
   396  	Keys    keyEntryMap `json:"keys"`
   397  
   398  	// Derived keys MUST provide a context and the master underlying key is
   399  	// never used. If convergent encryption is true, the context will be used
   400  	// as the nonce as well.
   401  	Derived              bool `json:"derived"`
   402  	KDF                  int  `json:"kdf"`
   403  	ConvergentEncryption bool `json:"convergent_encryption"`
   404  
   405  	// Whether the key is exportable
   406  	Exportable bool `json:"exportable"`
   407  
   408  	// The minimum version of the key allowed to be used for decryption
   409  	MinDecryptionVersion int `json:"min_decryption_version"`
   410  
   411  	// The minimum version of the key allowed to be used for encryption
   412  	MinEncryptionVersion int `json:"min_encryption_version"`
   413  
   414  	// The latest key version in this policy
   415  	LatestVersion int `json:"latest_version"`
   416  
   417  	// The latest key version in the archive. We never delete these, so this is
   418  	// a max.
   419  	ArchiveVersion int `json:"archive_version"`
   420  
   421  	// ArchiveMinVersion is the minimum version of the key in the archive.
   422  	ArchiveMinVersion int `json:"archive_min_version"`
   423  
   424  	// MinAvailableVersion is the minimum version of the key present. All key
   425  	// versions before this would have been deleted.
   426  	MinAvailableVersion int `json:"min_available_version"`
   427  
   428  	// Whether the key is allowed to be deleted
   429  	DeletionAllowed bool `json:"deletion_allowed"`
   430  
   431  	// The version of the convergent nonce to use
   432  	ConvergentVersion int `json:"convergent_version"`
   433  
   434  	// The type of key
   435  	Type KeyType `json:"type"`
   436  
   437  	// BackupInfo indicates the information about the backup action taken on
   438  	// this policy
   439  	BackupInfo *BackupInfo `json:"backup_info"`
   440  
   441  	// RestoreInfo indicates the information about the restore action taken on
   442  	// this policy
   443  	RestoreInfo *RestoreInfo `json:"restore_info"`
   444  
   445  	// AllowPlaintextBackup allows taking backup of the policy in plaintext
   446  	AllowPlaintextBackup bool `json:"allow_plaintext_backup"`
   447  
   448  	// VersionTemplate is used to prefix the ciphertext with information about
   449  	// the key version. It must inclide {{version}} and a delimiter between the
   450  	// version prefix and the ciphertext.
   451  	VersionTemplate string `json:"version_template"`
   452  
   453  	// StoragePrefix is used to add a prefix when storing and retrieving the
   454  	// policy object.
   455  	StoragePrefix string `json:"storage_prefix"`
   456  
   457  	// AutoRotatePeriod defines how frequently the key should automatically
   458  	// rotate. Setting this to zero disables automatic rotation for the key.
   459  	AutoRotatePeriod time.Duration `json:"auto_rotate_period"`
   460  
   461  	// versionPrefixCache stores caches of version prefix strings and the split
   462  	// version template.
   463  	versionPrefixCache sync.Map
   464  
   465  	// Imported indicates whether the key was generated by Vault or imported
   466  	// from an external source
   467  	Imported bool
   468  
   469  	// AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault
   470  	AllowImportedKeyRotation bool
   471  }
   472  
   473  func (p *Policy) Lock(exclusive bool) {
   474  	if exclusive {
   475  		p.l.Lock()
   476  		p.writeLocked = true
   477  	} else {
   478  		p.l.RLock()
   479  	}
   480  }
   481  
   482  func (p *Policy) Unlock() {
   483  	if p.writeLocked {
   484  		p.writeLocked = false
   485  		p.l.Unlock()
   486  	} else {
   487  		p.l.RUnlock()
   488  	}
   489  }
   490  
   491  // ArchivedKeys stores old keys. This is used to keep the key loading time sane
   492  // when there are huge numbers of rotations.
   493  type archivedKeys struct {
   494  	Keys []KeyEntry `json:"keys"`
   495  }
   496  
   497  func (p *Policy) LoadArchive(ctx context.Context, storage logical.Storage) (*archivedKeys, error) {
   498  	archive := &archivedKeys{}
   499  
   500  	raw, err := storage.Get(ctx, path.Join(p.StoragePrefix, "archive", p.Name))
   501  	if err != nil {
   502  		return nil, err
   503  	}
   504  	if raw == nil {
   505  		archive.Keys = make([]KeyEntry, 0)
   506  		return archive, nil
   507  	}
   508  
   509  	if err := jsonutil.DecodeJSON(raw.Value, archive); err != nil {
   510  		return nil, err
   511  	}
   512  
   513  	return archive, nil
   514  }
   515  
   516  func (p *Policy) storeArchive(ctx context.Context, storage logical.Storage, archive *archivedKeys) error {
   517  	// Encode the policy
   518  	buf, err := json.Marshal(archive)
   519  	if err != nil {
   520  		return err
   521  	}
   522  
   523  	// Write the policy into storage
   524  	err = storage.Put(ctx, &logical.StorageEntry{
   525  		Key:   path.Join(p.StoragePrefix, "archive", p.Name),
   526  		Value: buf,
   527  	})
   528  	if err != nil {
   529  		return err
   530  	}
   531  
   532  	return nil
   533  }
   534  
   535  // handleArchiving manages the movement of keys to and from the policy archive.
   536  // This should *ONLY* be called from Persist() since it assumes that the policy
   537  // will be persisted afterwards.
   538  func (p *Policy) handleArchiving(ctx context.Context, storage logical.Storage) error {
   539  	// We need to move keys that are no longer accessible to archivedKeys, and keys
   540  	// that now need to be accessible back here.
   541  	//
   542  	// For safety, because there isn't really a good reason to, we never delete
   543  	// keys from the archive even when we move them back.
   544  
   545  	// Check if we have the latest minimum version in the current set of keys
   546  	_, keysContainsMinimum := p.Keys[strconv.Itoa(p.MinDecryptionVersion)]
   547  
   548  	// Sanity checks
   549  	switch {
   550  	case p.MinDecryptionVersion < 1:
   551  		return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
   552  	case p.LatestVersion < 1:
   553  		return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
   554  	case !keysContainsMinimum && p.ArchiveVersion != p.LatestVersion:
   555  		return fmt.Errorf("need to move keys from archive but archive version not up-to-date")
   556  	case p.ArchiveVersion > p.LatestVersion:
   557  		return fmt.Errorf("archive version of %d is greater than the latest version %d",
   558  			p.ArchiveVersion, p.LatestVersion)
   559  	case p.MinEncryptionVersion > 0 && p.MinEncryptionVersion < p.MinDecryptionVersion:
   560  		return fmt.Errorf("minimum decryption version of %d is greater than minimum encryption version %d",
   561  			p.MinDecryptionVersion, p.MinEncryptionVersion)
   562  	case p.MinDecryptionVersion > p.LatestVersion:
   563  		return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
   564  			p.MinDecryptionVersion, p.LatestVersion)
   565  	}
   566  
   567  	archive, err := p.LoadArchive(ctx, storage)
   568  	if err != nil {
   569  		return err
   570  	}
   571  
   572  	if !keysContainsMinimum {
   573  		// Need to move keys *from* archive
   574  		for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
   575  			p.Keys[strconv.Itoa(i)] = archive.Keys[i-p.MinAvailableVersion]
   576  		}
   577  
   578  		return nil
   579  	}
   580  
   581  	// Need to move keys *to* archive
   582  
   583  	// We need a size that is equivalent to the latest version (number of keys)
   584  	// but adding one since slice numbering starts at 0 and we're indexing by
   585  	// key version
   586  	if len(archive.Keys)+p.MinAvailableVersion < p.LatestVersion+1 {
   587  		// Increase the size of the archive slice
   588  		newKeys := make([]KeyEntry, p.LatestVersion-p.MinAvailableVersion+1)
   589  		copy(newKeys, archive.Keys)
   590  		archive.Keys = newKeys
   591  	}
   592  
   593  	// We are storing all keys in the archive, so we ensure that it is up to
   594  	// date up to p.LatestVersion
   595  	for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
   596  		archive.Keys[i-p.MinAvailableVersion] = p.Keys[strconv.Itoa(i)]
   597  		p.ArchiveVersion = i
   598  	}
   599  
   600  	// Trim the keys if required
   601  	if p.ArchiveMinVersion < p.MinAvailableVersion {
   602  		archive.Keys = archive.Keys[p.MinAvailableVersion-p.ArchiveMinVersion:]
   603  		p.ArchiveMinVersion = p.MinAvailableVersion
   604  	}
   605  
   606  	err = p.storeArchive(ctx, storage, archive)
   607  	if err != nil {
   608  		return err
   609  	}
   610  
   611  	// Perform deletion afterwards so that if there is an error saving we
   612  	// haven't messed with the current policy
   613  	for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
   614  		delete(p.Keys, strconv.Itoa(i))
   615  	}
   616  
   617  	return nil
   618  }
   619  
   620  func (p *Policy) Persist(ctx context.Context, storage logical.Storage) (retErr error) {
   621  	if atomic.LoadUint32(&p.deleted) == 1 {
   622  		return errors.New("key has been deleted, not persisting")
   623  	}
   624  
   625  	// Other functions will take care of restoring other values; this is just
   626  	// responsible for archiving and keys since the archive function can modify
   627  	// keys. At the moment one of the other functions calling persist will also
   628  	// roll back keys, but better safe than sorry and this doesn't happen
   629  	// enough to worry about the speed tradeoff.
   630  	priorArchiveVersion := p.ArchiveVersion
   631  	var priorKeys keyEntryMap
   632  
   633  	if p.Keys != nil {
   634  		priorKeys = keyEntryMap{}
   635  		for k, v := range p.Keys {
   636  			priorKeys[k] = v
   637  		}
   638  	}
   639  
   640  	defer func() {
   641  		if retErr != nil {
   642  			p.ArchiveVersion = priorArchiveVersion
   643  			p.Keys = priorKeys
   644  		}
   645  	}()
   646  
   647  	err := p.handleArchiving(ctx, storage)
   648  	if err != nil {
   649  		return err
   650  	}
   651  
   652  	// Encode the policy
   653  	buf, err := p.Serialize()
   654  	if err != nil {
   655  		return err
   656  	}
   657  
   658  	// Write the policy into storage
   659  	err = storage.Put(ctx, &logical.StorageEntry{
   660  		Key:   path.Join(p.StoragePrefix, "policy", p.Name),
   661  		Value: buf,
   662  	})
   663  	if err != nil {
   664  		return err
   665  	}
   666  
   667  	return nil
   668  }
   669  
   670  func (p *Policy) Serialize() ([]byte, error) {
   671  	return json.Marshal(p)
   672  }
   673  
   674  func (p *Policy) NeedsUpgrade() bool {
   675  	// Ensure we've moved from Key -> Keys
   676  	if p.Key != nil && len(p.Key) > 0 {
   677  		return true
   678  	}
   679  
   680  	// With archiving, past assumptions about the length of the keys map are no
   681  	// longer valid
   682  	if p.LatestVersion == 0 && len(p.Keys) != 0 {
   683  		return true
   684  	}
   685  
   686  	// We disallow setting the version to 0, since they start at 1 since moving
   687  	// to rotate-able keys, so update if it's set to 0
   688  	if p.MinDecryptionVersion == 0 {
   689  		return true
   690  	}
   691  
   692  	// On first load after an upgrade, copy keys to the archive
   693  	if p.ArchiveVersion == 0 {
   694  		return true
   695  	}
   696  
   697  	// Need to write the version if zero; for version 3 on we set this to -1 to
   698  	// ignore it since we store this information in each key entry
   699  	if p.ConvergentEncryption && p.ConvergentVersion == 0 {
   700  		return true
   701  	}
   702  
   703  	if p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey == nil || len(p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey) == 0 {
   704  		return true
   705  	}
   706  
   707  	return false
   708  }
   709  
   710  func (p *Policy) Upgrade(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
   711  	priorKey := p.Key
   712  	priorLatestVersion := p.LatestVersion
   713  	priorMinDecryptionVersion := p.MinDecryptionVersion
   714  	priorConvergentVersion := p.ConvergentVersion
   715  	var priorKeys keyEntryMap
   716  
   717  	if p.Keys != nil {
   718  		priorKeys = keyEntryMap{}
   719  		for k, v := range p.Keys {
   720  			priorKeys[k] = v
   721  		}
   722  	}
   723  
   724  	defer func() {
   725  		if retErr != nil {
   726  			p.Key = priorKey
   727  			p.LatestVersion = priorLatestVersion
   728  			p.MinDecryptionVersion = priorMinDecryptionVersion
   729  			p.ConvergentVersion = priorConvergentVersion
   730  			p.Keys = priorKeys
   731  		}
   732  	}()
   733  
   734  	persistNeeded := false
   735  	// Ensure we've moved from Key -> Keys
   736  	if p.Key != nil && len(p.Key) > 0 {
   737  		p.MigrateKeyToKeysMap()
   738  		persistNeeded = true
   739  	}
   740  
   741  	// With archiving, past assumptions about the length of the keys map are no
   742  	// longer valid
   743  	if p.LatestVersion == 0 && len(p.Keys) != 0 {
   744  		p.LatestVersion = len(p.Keys)
   745  		persistNeeded = true
   746  	}
   747  
   748  	// We disallow setting the version to 0, since they start at 1 since moving
   749  	// to rotate-able keys, so update if it's set to 0
   750  	if p.MinDecryptionVersion == 0 {
   751  		p.MinDecryptionVersion = 1
   752  		persistNeeded = true
   753  	}
   754  
   755  	// On first load after an upgrade, copy keys to the archive
   756  	if p.ArchiveVersion == 0 {
   757  		persistNeeded = true
   758  	}
   759  
   760  	if p.ConvergentEncryption && p.ConvergentVersion == 0 {
   761  		p.ConvergentVersion = 1
   762  		persistNeeded = true
   763  	}
   764  
   765  	if p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey == nil || len(p.Keys[strconv.Itoa(p.LatestVersion)].HMACKey) == 0 {
   766  		entry := p.Keys[strconv.Itoa(p.LatestVersion)]
   767  		hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
   768  		if err != nil {
   769  			return err
   770  		}
   771  		entry.HMACKey = hmacKey
   772  		p.Keys[strconv.Itoa(p.LatestVersion)] = entry
   773  		persistNeeded = true
   774  
   775  		if p.Type == KeyType_HMAC {
   776  			entry.HMACKey = entry.Key
   777  		}
   778  	}
   779  
   780  	if persistNeeded {
   781  		err := p.Persist(ctx, storage)
   782  		if err != nil {
   783  			return err
   784  		}
   785  	}
   786  
   787  	return nil
   788  }
   789  
   790  // GetKey is used to derive the encryption key that should be used depending
   791  // on the policy. If derivation is disabled the raw key is used and no context
   792  // is required, otherwise the KDF mode is used with the context to derive the
   793  // proper key.
   794  func (p *Policy) GetKey(context []byte, ver, numBytes int) ([]byte, error) {
   795  	// Fast-path non-derived keys
   796  	if !p.Derived {
   797  		keyEntry, err := p.safeGetKeyEntry(ver)
   798  		if err != nil {
   799  			return nil, err
   800  		}
   801  
   802  		return keyEntry.Key, nil
   803  	}
   804  
   805  	return p.DeriveKey(context, nil, ver, numBytes)
   806  }
   807  
   808  // DeriveKey is used to derive a symmetric key given a context and salt.  This does not
   809  // check the policies Derived flag, but just implements the derivation logic.  GetKey
   810  // is responsible for switching on the policy config.
   811  func (p *Policy) DeriveKey(context, salt []byte, ver int, numBytes int) ([]byte, error) {
   812  	if !p.Type.DerivationSupported() {
   813  		return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
   814  	}
   815  
   816  	if p.Keys == nil || p.LatestVersion == 0 {
   817  		return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
   818  	}
   819  
   820  	if ver <= 0 || ver > p.LatestVersion {
   821  		return nil, errutil.UserError{Err: "invalid key version"}
   822  	}
   823  
   824  	// Ensure a context is provided
   825  	if len(context) == 0 {
   826  		return nil, errutil.UserError{Err: "missing 'context' for key derivation; the key was created using a derived key, which means additional, per-request information must be included in order to perform operations with the key"}
   827  	}
   828  
   829  	keyEntry, err := p.safeGetKeyEntry(ver)
   830  	if err != nil {
   831  		return nil, err
   832  	}
   833  
   834  	switch p.KDF {
   835  	case Kdf_hmac_sha256_counter:
   836  		prf := kdf.HMACSHA256PRF
   837  		prfLen := kdf.HMACSHA256PRFLen
   838  		return kdf.CounterMode(prf, prfLen, keyEntry.Key, append(context, salt...), 256)
   839  
   840  	case Kdf_hkdf_sha256:
   841  		reader := hkdf.New(sha256.New, keyEntry.Key, salt, context)
   842  		derBytes := bytes.NewBuffer(nil)
   843  		derBytes.Grow(numBytes)
   844  		limReader := &io.LimitedReader{
   845  			R: reader,
   846  			N: int64(numBytes),
   847  		}
   848  
   849  		switch p.Type {
   850  		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
   851  			n, err := derBytes.ReadFrom(limReader)
   852  			if err != nil {
   853  				return nil, errutil.InternalError{Err: fmt.Sprintf("error reading returned derived bytes: %v", err)}
   854  			}
   855  			if n != int64(numBytes) {
   856  				return nil, errutil.InternalError{Err: fmt.Sprintf("unable to read enough derived bytes, needed %d, got %d", numBytes, n)}
   857  			}
   858  			return derBytes.Bytes(), nil
   859  
   860  		case KeyType_ED25519:
   861  			// We use the limited reader containing the derived bytes as the
   862  			// "random" input to the generation function
   863  			_, pri, err := ed25519.GenerateKey(limReader)
   864  			if err != nil {
   865  				return nil, errutil.InternalError{Err: fmt.Sprintf("error generating derived key: %v", err)}
   866  			}
   867  			return pri, nil
   868  
   869  		default:
   870  			return nil, errutil.InternalError{Err: "unsupported key type for derivation"}
   871  		}
   872  
   873  	default:
   874  		return nil, errutil.InternalError{Err: "unsupported key derivation mode"}
   875  	}
   876  }
   877  
   878  func (p *Policy) safeGetKeyEntry(ver int) (KeyEntry, error) {
   879  	keyVerStr := strconv.Itoa(ver)
   880  	keyEntry, ok := p.Keys[keyVerStr]
   881  	if !ok {
   882  		return keyEntry, errutil.UserError{Err: "no such key version"}
   883  	}
   884  	return keyEntry, nil
   885  }
   886  
   887  func (p *Policy) convergentVersion(ver int) int {
   888  	if !p.ConvergentEncryption {
   889  		return 0
   890  	}
   891  
   892  	convergentVersion := p.ConvergentVersion
   893  	if convergentVersion == 0 {
   894  		// For some reason, not upgraded yet
   895  		convergentVersion = 1
   896  	}
   897  	currKey := p.Keys[strconv.Itoa(ver)]
   898  	if currKey.ConvergentVersion != 0 {
   899  		convergentVersion = currKey.ConvergentVersion
   900  	}
   901  
   902  	return convergentVersion
   903  }
   904  
// Encrypt encrypts value with key version ver. It is a convenience wrapper
// around EncryptWithFactory that passes a single nil factory.
func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, error) {
	return p.EncryptWithFactory(ver, context, nonce, value, nil)
}
   908  
   909  func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
   910  	return p.DecryptWithFactory(context, nonce, value, nil)
   911  }
   912  
   913  func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) {
   914  	if !p.Type.DecryptionSupported() {
   915  		return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
   916  	}
   917  
   918  	tplParts, err := p.getTemplateParts()
   919  	if err != nil {
   920  		return "", err
   921  	}
   922  
   923  	// Verify the prefix
   924  	if !strings.HasPrefix(value, tplParts[0]) {
   925  		return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
   926  	}
   927  
   928  	splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, tplParts[0]), tplParts[1], 2)
   929  	if len(splitVerCiphertext) != 2 {
   930  		return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
   931  	}
   932  
   933  	ver, err := strconv.Atoi(splitVerCiphertext[0])
   934  	if err != nil {
   935  		return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
   936  	}
   937  
   938  	if ver == 0 {
   939  		// Compatibility mode with initial implementation, where keys start at
   940  		// zero
   941  		ver = 1
   942  	}
   943  
   944  	if ver > p.LatestVersion {
   945  		return "", errutil.UserError{Err: "invalid ciphertext: version is too new"}
   946  	}
   947  
   948  	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
   949  		return "", errutil.UserError{Err: ErrTooOld}
   950  	}
   951  
   952  	convergentVersion := p.convergentVersion(ver)
   953  	if convergentVersion == 1 && (nonce == nil || len(nonce) == 0) {
   954  		return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
   955  	}
   956  
   957  	// Decode the base64
   958  	decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
   959  	if err != nil {
   960  		return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"}
   961  	}
   962  
   963  	var plain []byte
   964  
   965  	switch p.Type {
   966  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
   967  		numBytes := 32
   968  		if p.Type == KeyType_AES128_GCM96 {
   969  			numBytes = 16
   970  		}
   971  
   972  		encKey, err := p.GetKey(context, ver, numBytes)
   973  		if err != nil {
   974  			return "", err
   975  		}
   976  
   977  		if len(encKey) != numBytes {
   978  			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
   979  		}
   980  
   981  		symopts := SymmetricOpts{
   982  			Convergent:        p.ConvergentEncryption,
   983  			ConvergentVersion: p.ConvergentVersion,
   984  		}
   985  		for index, rawFactory := range factories {
   986  			if rawFactory == nil {
   987  				continue
   988  			}
   989  			switch factory := rawFactory.(type) {
   990  			case AEADFactory:
   991  				symopts.AEADFactory = factory
   992  			case AssociatedDataFactory:
   993  				symopts.AdditionalData, err = factory.GetAssociatedData()
   994  				if err != nil {
   995  					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
   996  				}
   997  			case ManagedKeyFactory:
   998  			default:
   999  				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
  1000  			}
  1001  		}
  1002  
  1003  		plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts)
  1004  		if err != nil {
  1005  			return "", err
  1006  		}
  1007  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1008  		keyEntry, err := p.safeGetKeyEntry(ver)
  1009  		if err != nil {
  1010  			return "", err
  1011  		}
  1012  		key := keyEntry.RSAKey
  1013  		if key == nil {
  1014  			return "", errutil.InternalError{Err: fmt.Sprintf("cannot decrypt ciphertext, key version does not have a private counterpart")}
  1015  		}
  1016  		plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil)
  1017  		if err != nil {
  1018  			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA decrypt the ciphertext: %v", err)}
  1019  		}
  1020  	case KeyType_MANAGED_KEY:
  1021  		keyEntry, err := p.safeGetKeyEntry(ver)
  1022  		if err != nil {
  1023  			return "", err
  1024  		}
  1025  		var aad []byte
  1026  		var managedKeyFactory ManagedKeyFactory
  1027  		for _, f := range factories {
  1028  			switch factory := f.(type) {
  1029  			case AssociatedDataFactory:
  1030  				aad, err = factory.GetAssociatedData()
  1031  				if err != nil {
  1032  					return "", err
  1033  				}
  1034  			case ManagedKeyFactory:
  1035  				managedKeyFactory = factory
  1036  			}
  1037  		}
  1038  
  1039  		if managedKeyFactory == nil {
  1040  			return "", errors.New("key type is managed_key, but managed key parameters were not provided")
  1041  		}
  1042  
  1043  		plain, err = p.decryptWithManagedKey(managedKeyFactory.GetManagedKeyParameters(), keyEntry, decoded, nonce, aad)
  1044  		if err != nil {
  1045  			return "", err
  1046  		}
  1047  
  1048  	default:
  1049  		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
  1050  	}
  1051  
  1052  	return base64.StdEncoding.EncodeToString(plain), nil
  1053  }
  1054  
  1055  func (p *Policy) HMACKey(version int) ([]byte, error) {
  1056  	switch {
  1057  	case version < 0:
  1058  		return nil, fmt.Errorf("key version does not exist (cannot be negative)")
  1059  	case version > p.LatestVersion:
  1060  		return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
  1061  	}
  1062  	keyEntry, err := p.safeGetKeyEntry(version)
  1063  	if err != nil {
  1064  		return nil, err
  1065  	}
  1066  
  1067  	if p.Type == KeyType_HMAC {
  1068  		return keyEntry.Key, nil
  1069  	}
  1070  	if keyEntry.HMACKey == nil {
  1071  		return nil, fmt.Errorf("no HMAC key exists for that key version")
  1072  	}
  1073  	return keyEntry.HMACKey, nil
  1074  }
  1075  
  1076  func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
  1077  	return p.SignWithOptions(ver, context, input, &SigningOptions{
  1078  		HashAlgorithm: hashAlgorithm,
  1079  		Marshaling:    marshaling,
  1080  		SaltLength:    rsa.PSSSaltLengthAuto,
  1081  		SigAlgorithm:  sigAlgorithm,
  1082  	})
  1083  }
  1084  
  1085  func (p *Policy) minRSAPSSSaltLength() int {
  1086  	// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247
  1087  	return rsa.PSSSaltLengthEqualsHash
  1088  }
  1089  
  1090  func (p *Policy) maxRSAPSSSaltLength(keyBitLen int, hash crypto.Hash) int {
  1091  	// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288
  1092  	return (keyBitLen-1+7)/8 - 2 - hash.Size()
  1093  }
  1094  
  1095  func (p *Policy) validRSAPSSSaltLength(keyBitLen int, hash crypto.Hash, saltLength int) bool {
  1096  	return p.minRSAPSSSaltLength() <= saltLength && saltLength <= p.maxRSAPSSSaltLength(keyBitLen, hash)
  1097  }
  1098  
  1099  func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) {
  1100  	if !p.Type.SigningSupported() {
  1101  		return nil, fmt.Errorf("message signing not supported for key type %v", p.Type)
  1102  	}
  1103  
  1104  	switch {
  1105  	case ver == 0:
  1106  		ver = p.LatestVersion
  1107  	case ver < 0:
  1108  		return nil, errutil.UserError{Err: "requested version for signing is negative"}
  1109  	case ver > p.LatestVersion:
  1110  		return nil, errutil.UserError{Err: "requested version for signing is higher than the latest key version"}
  1111  	case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
  1112  		return nil, errutil.UserError{Err: "requested version for signing is less than the minimum encryption key version"}
  1113  	}
  1114  
  1115  	var sig []byte
  1116  	var pubKey []byte
  1117  	var err error
  1118  	keyParams, err := p.safeGetKeyEntry(ver)
  1119  	if err != nil {
  1120  		return nil, err
  1121  	}
  1122  
  1123  	// Before signing, check if key has its private part, if not return error
  1124  	if keyParams.IsPrivateKeyMissing() {
  1125  		return nil, errutil.UserError{Err: "requested version for signing does not contain a private part"}
  1126  	}
  1127  
  1128  	hashAlgorithm := options.HashAlgorithm
  1129  	marshaling := options.Marshaling
  1130  	saltLength := options.SaltLength
  1131  	sigAlgorithm := options.SigAlgorithm
  1132  
  1133  	switch p.Type {
  1134  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  1135  		var curveBits int
  1136  		var curve elliptic.Curve
  1137  		switch p.Type {
  1138  		case KeyType_ECDSA_P384:
  1139  			curveBits = 384
  1140  			curve = elliptic.P384()
  1141  		case KeyType_ECDSA_P521:
  1142  			curveBits = 521
  1143  			curve = elliptic.P521()
  1144  		default:
  1145  			curveBits = 256
  1146  			curve = elliptic.P256()
  1147  		}
  1148  
  1149  		key := &ecdsa.PrivateKey{
  1150  			PublicKey: ecdsa.PublicKey{
  1151  				Curve: curve,
  1152  				X:     keyParams.EC_X,
  1153  				Y:     keyParams.EC_Y,
  1154  			},
  1155  			D: keyParams.EC_D,
  1156  		}
  1157  
  1158  		r, s, err := ecdsa.Sign(rand.Reader, key, input)
  1159  		if err != nil {
  1160  			return nil, err
  1161  		}
  1162  
  1163  		switch marshaling {
  1164  		case MarshalingTypeASN1:
  1165  			// This is used by openssl and X.509
  1166  			sig, err = asn1.Marshal(ecdsaSignature{
  1167  				R: r,
  1168  				S: s,
  1169  			})
  1170  			if err != nil {
  1171  				return nil, err
  1172  			}
  1173  
  1174  		case MarshalingTypeJWS:
  1175  			// This is used by JWS
  1176  
  1177  			// First we have to get the length of the curve in bytes. Although
  1178  			// we only support 256 now, we'll do this in an agnostic way so we
  1179  			// can reuse this marshaling if we support e.g. 521. Getting the
  1180  			// number of bytes without rounding up would be 65.125 so we need
  1181  			// to add one in that case.
  1182  			keyLen := curveBits / 8
  1183  			if curveBits%8 > 0 {
  1184  				keyLen++
  1185  			}
  1186  
  1187  			// Now create the output array
  1188  			sig = make([]byte, keyLen*2)
  1189  			rb := r.Bytes()
  1190  			sb := s.Bytes()
  1191  			copy(sig[keyLen-len(rb):], rb)
  1192  			copy(sig[2*keyLen-len(sb):], sb)
  1193  
  1194  		default:
  1195  			return nil, errutil.UserError{Err: "requested marshaling type is invalid"}
  1196  		}
  1197  
  1198  	case KeyType_ED25519:
  1199  		var key ed25519.PrivateKey
  1200  
  1201  		if p.Derived {
  1202  			// Derive the key that should be used
  1203  			var err error
  1204  			key, err = p.GetKey(context, ver, 32)
  1205  			if err != nil {
  1206  				return nil, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
  1207  			}
  1208  			pubKey = key.Public().(ed25519.PublicKey)
  1209  		} else {
  1210  			key = ed25519.PrivateKey(keyParams.Key)
  1211  		}
  1212  
  1213  		// Per docs, do not pre-hash ed25519; it does two passes and performs
  1214  		// its own hashing
  1215  		sig, err = key.Sign(rand.Reader, input, crypto.Hash(0))
  1216  		if err != nil {
  1217  			return nil, err
  1218  		}
  1219  
  1220  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1221  		key := keyParams.RSAKey
  1222  
  1223  		algo, ok := CryptoHashMap[hashAlgorithm]
  1224  		if !ok {
  1225  			return nil, errutil.InternalError{Err: "unsupported hash algorithm"}
  1226  		}
  1227  
  1228  		if sigAlgorithm == "" {
  1229  			sigAlgorithm = "pss"
  1230  		}
  1231  
  1232  		switch sigAlgorithm {
  1233  		case "pss":
  1234  			if !p.validRSAPSSSaltLength(key.N.BitLen(), algo, saltLength) {
  1235  				return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
  1236  			}
  1237  			sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength})
  1238  			if err != nil {
  1239  				return nil, err
  1240  			}
  1241  		case "pkcs1v15":
  1242  			sig, err = rsa.SignPKCS1v15(rand.Reader, key, algo, input)
  1243  			if err != nil {
  1244  				return nil, err
  1245  			}
  1246  		default:
  1247  			return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
  1248  		}
  1249  
  1250  	case KeyType_MANAGED_KEY:
  1251  		keyEntry, err := p.safeGetKeyEntry(ver)
  1252  		if err != nil {
  1253  			return nil, err
  1254  		}
  1255  
  1256  		sig, err = p.signWithManagedKey(options, keyEntry, input)
  1257  		if err != nil {
  1258  			return nil, err
  1259  		}
  1260  
  1261  	default:
  1262  		return nil, fmt.Errorf("unsupported key type %v", p.Type)
  1263  	}
  1264  
  1265  	// Convert to base64
  1266  	var encoded string
  1267  	switch marshaling {
  1268  	case MarshalingTypeASN1:
  1269  		encoded = base64.StdEncoding.EncodeToString(sig)
  1270  	case MarshalingTypeJWS:
  1271  		encoded = base64.RawURLEncoding.EncodeToString(sig)
  1272  	}
  1273  	res := &SigningResult{
  1274  		Signature: p.getVersionPrefix(ver) + encoded,
  1275  		PublicKey: pubKey,
  1276  	}
  1277  
  1278  	return res, nil
  1279  }
  1280  
  1281  func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) {
  1282  	return p.VerifySignatureWithOptions(context, input, sig, &SigningOptions{
  1283  		HashAlgorithm: hashAlgorithm,
  1284  		Marshaling:    marshaling,
  1285  		SaltLength:    rsa.PSSSaltLengthAuto,
  1286  		SigAlgorithm:  sigAlgorithm,
  1287  	})
  1288  }
  1289  
  1290  func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) {
  1291  	if !p.Type.SigningSupported() {
  1292  		return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
  1293  	}
  1294  
  1295  	tplParts, err := p.getTemplateParts()
  1296  	if err != nil {
  1297  		return false, err
  1298  	}
  1299  
  1300  	// Verify the prefix
  1301  	if !strings.HasPrefix(sig, tplParts[0]) {
  1302  		return false, errutil.UserError{Err: "invalid signature: no prefix"}
  1303  	}
  1304  
  1305  	splitVerSig := strings.SplitN(strings.TrimPrefix(sig, tplParts[0]), tplParts[1], 2)
  1306  	if len(splitVerSig) != 2 {
  1307  		return false, errutil.UserError{Err: "invalid signature: wrong number of fields"}
  1308  	}
  1309  
  1310  	ver, err := strconv.Atoi(splitVerSig[0])
  1311  	if err != nil {
  1312  		return false, errutil.UserError{Err: "invalid signature: version number could not be decoded"}
  1313  	}
  1314  
  1315  	if ver > p.LatestVersion {
  1316  		return false, errutil.UserError{Err: "invalid signature: version is too new"}
  1317  	}
  1318  
  1319  	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
  1320  		return false, errutil.UserError{Err: ErrTooOld}
  1321  	}
  1322  
  1323  	hashAlgorithm := options.HashAlgorithm
  1324  	marshaling := options.Marshaling
  1325  	saltLength := options.SaltLength
  1326  	sigAlgorithm := options.SigAlgorithm
  1327  
  1328  	var sigBytes []byte
  1329  	switch marshaling {
  1330  	case MarshalingTypeASN1:
  1331  		sigBytes, err = base64.StdEncoding.DecodeString(splitVerSig[1])
  1332  	case MarshalingTypeJWS:
  1333  		sigBytes, err = base64.RawURLEncoding.DecodeString(splitVerSig[1])
  1334  	default:
  1335  		return false, errutil.UserError{Err: "requested marshaling type is invalid"}
  1336  	}
  1337  	if err != nil {
  1338  		return false, errutil.UserError{Err: "invalid base64 signature value"}
  1339  	}
  1340  
  1341  	switch p.Type {
  1342  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  1343  		var curve elliptic.Curve
  1344  		switch p.Type {
  1345  		case KeyType_ECDSA_P384:
  1346  			curve = elliptic.P384()
  1347  		case KeyType_ECDSA_P521:
  1348  			curve = elliptic.P521()
  1349  		default:
  1350  			curve = elliptic.P256()
  1351  		}
  1352  
  1353  		var ecdsaSig ecdsaSignature
  1354  
  1355  		switch marshaling {
  1356  		case MarshalingTypeASN1:
  1357  			rest, err := asn1.Unmarshal(sigBytes, &ecdsaSig)
  1358  			if err != nil {
  1359  				return false, errutil.UserError{Err: "supplied signature is invalid"}
  1360  			}
  1361  			if rest != nil && len(rest) != 0 {
  1362  				return false, errutil.UserError{Err: "supplied signature contains extra data"}
  1363  			}
  1364  
  1365  		case MarshalingTypeJWS:
  1366  			paramLen := len(sigBytes) / 2
  1367  			rb := sigBytes[:paramLen]
  1368  			sb := sigBytes[paramLen:]
  1369  			ecdsaSig.R = new(big.Int)
  1370  			ecdsaSig.R.SetBytes(rb)
  1371  			ecdsaSig.S = new(big.Int)
  1372  			ecdsaSig.S.SetBytes(sb)
  1373  		}
  1374  
  1375  		keyParams, err := p.safeGetKeyEntry(ver)
  1376  		if err != nil {
  1377  			return false, err
  1378  		}
  1379  		key := &ecdsa.PublicKey{
  1380  			Curve: curve,
  1381  			X:     keyParams.EC_X,
  1382  			Y:     keyParams.EC_Y,
  1383  		}
  1384  
  1385  		return ecdsa.Verify(key, input, ecdsaSig.R, ecdsaSig.S), nil
  1386  
  1387  	case KeyType_ED25519:
  1388  		var pub ed25519.PublicKey
  1389  
  1390  		if p.Derived {
  1391  			// Derive the key that should be used
  1392  			key, err := p.GetKey(context, ver, 32)
  1393  			if err != nil {
  1394  				return false, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
  1395  			}
  1396  			pub = ed25519.PrivateKey(key).Public().(ed25519.PublicKey)
  1397  		} else {
  1398  			keyEntry, err := p.safeGetKeyEntry(ver)
  1399  			if err != nil {
  1400  				return false, err
  1401  			}
  1402  
  1403  			raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
  1404  			if err != nil {
  1405  				return false, err
  1406  			}
  1407  
  1408  			pub = ed25519.PublicKey(raw)
  1409  		}
  1410  
  1411  		return ed25519.Verify(pub, input, sigBytes), nil
  1412  
  1413  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1414  		keyEntry, err := p.safeGetKeyEntry(ver)
  1415  		if err != nil {
  1416  			return false, err
  1417  		}
  1418  
  1419  		algo, ok := CryptoHashMap[hashAlgorithm]
  1420  		if !ok {
  1421  			return false, errutil.InternalError{Err: "unsupported hash algorithm"}
  1422  		}
  1423  
  1424  		if sigAlgorithm == "" {
  1425  			sigAlgorithm = "pss"
  1426  		}
  1427  
  1428  		switch sigAlgorithm {
  1429  		case "pss":
  1430  			publicKey := keyEntry.RSAPublicKey
  1431  			if !keyEntry.IsPrivateKeyMissing() {
  1432  				publicKey = &keyEntry.RSAKey.PublicKey
  1433  			}
  1434  			if !p.validRSAPSSSaltLength(publicKey.N.BitLen(), algo, saltLength) {
  1435  				return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
  1436  			}
  1437  			err = rsa.VerifyPSS(publicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength})
  1438  		case "pkcs1v15":
  1439  			publicKey := keyEntry.RSAPublicKey
  1440  			if !keyEntry.IsPrivateKeyMissing() {
  1441  				publicKey = &keyEntry.RSAKey.PublicKey
  1442  			}
  1443  			err = rsa.VerifyPKCS1v15(publicKey, algo, input, sigBytes)
  1444  		default:
  1445  			return false, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
  1446  		}
  1447  
  1448  		return err == nil, nil
  1449  
  1450  	case KeyType_MANAGED_KEY:
  1451  		keyEntry, err := p.safeGetKeyEntry(ver)
  1452  		if err != nil {
  1453  			return false, err
  1454  		}
  1455  
  1456  		return p.verifyWithManagedKey(options, keyEntry, input, sigBytes)
  1457  
  1458  	default:
  1459  		return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
  1460  	}
  1461  }
  1462  
  1463  func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte, randReader io.Reader) error {
  1464  	return p.ImportPublicOrPrivate(ctx, storage, key, true, randReader)
  1465  }
  1466  
  1467  func (p *Policy) ImportPublicOrPrivate(ctx context.Context, storage logical.Storage, key []byte, isPrivateKey bool, randReader io.Reader) error {
  1468  	now := time.Now()
  1469  	entry := KeyEntry{
  1470  		CreationTime:           now,
  1471  		DeprecatedCreationTime: now.Unix(),
  1472  	}
  1473  
  1474  	// Before we insert this entry, check if the latest version is incomplete
  1475  	// and this entry matches the current version; if so, return without
  1476  	// updating to the next version.
  1477  	if p.LatestVersion > 0 {
  1478  		latestKey := p.Keys[strconv.Itoa(p.LatestVersion)]
  1479  		if latestKey.IsPrivateKeyMissing() && isPrivateKey {
  1480  			if err := p.ImportPrivateKeyForVersion(ctx, storage, p.LatestVersion, key); err == nil {
  1481  				return nil
  1482  			}
  1483  		}
  1484  	}
  1485  
  1486  	if p.Type != KeyType_HMAC {
  1487  		hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
  1488  		if err != nil {
  1489  			return err
  1490  		}
  1491  		entry.HMACKey = hmacKey
  1492  	}
  1493  
  1494  	if p.Type == KeyType_ED25519 && p.Derived && !isPrivateKey {
  1495  		return fmt.Errorf("unable to import only public key for derived Ed25519 key: imported key should not be an Ed25519 key pair but is instead an HKDF key")
  1496  	}
  1497  
  1498  	if (p.Type == KeyType_AES128_GCM96 && len(key) != 16) ||
  1499  		((p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305) && len(key) != 32) ||
  1500  		(p.Type == KeyType_HMAC && (len(key) < HmacMinKeySize || len(key) > HmacMaxKeySize)) {
  1501  		return fmt.Errorf("invalid key size %d bytes for key type %s", len(key), p.Type)
  1502  	}
  1503  
  1504  	if p.Type == KeyType_AES128_GCM96 || p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305 || p.Type == KeyType_HMAC {
  1505  		entry.Key = key
  1506  		if p.Type == KeyType_HMAC {
  1507  			p.KeySize = len(key)
  1508  			entry.HMACKey = key
  1509  		}
  1510  	} else {
  1511  		var parsedKey any
  1512  		var err error
  1513  		if isPrivateKey {
  1514  			parsedKey, err = x509.ParsePKCS8PrivateKey(key)
  1515  			if err != nil {
  1516  				if strings.Contains(err.Error(), "unknown elliptic curve") {
  1517  					var edErr error
  1518  					parsedKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
  1519  					if edErr != nil {
  1520  						return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
  1521  					}
  1522  
  1523  					// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
  1524  				} else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) {
  1525  					var rsaErr error
  1526  					parsedKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key)
  1527  					if rsaErr != nil {
  1528  						return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err)
  1529  					}
  1530  
  1531  					// Parsing as RSA-PSS in PKCS8 succeeded!
  1532  				} else {
  1533  					return fmt.Errorf("error parsing asymmetric key: %s", err)
  1534  				}
  1535  			}
  1536  		} else {
  1537  			pemBlock, _ := pem.Decode(key)
  1538  			if pemBlock == nil {
  1539  				return fmt.Errorf("error parsing public key: not in PEM format")
  1540  			}
  1541  
  1542  			parsedKey, err = x509.ParsePKIXPublicKey(pemBlock.Bytes)
  1543  			if err != nil {
  1544  				return fmt.Errorf("error parsing public key: %w", err)
  1545  			}
  1546  		}
  1547  
  1548  		err = entry.parseFromKey(p.Type, parsedKey)
  1549  		if err != nil {
  1550  			return err
  1551  		}
  1552  	}
  1553  
  1554  	p.LatestVersion += 1
  1555  
  1556  	if p.Keys == nil {
  1557  		// This is an initial key rotation when generating a new policy. We
  1558  		// don't need to call migrate here because if we've called getPolicy to
  1559  		// get the policy in the first place it will have been run.
  1560  		p.Keys = keyEntryMap{}
  1561  	}
  1562  	p.Keys[strconv.Itoa(p.LatestVersion)] = entry
  1563  
  1564  	// This ensures that with new key creations min decryption version is set
  1565  	// to 1 rather than the int default of 0, since keys start at 1 (either
  1566  	// fresh or after migration to the key map)
  1567  	if p.MinDecryptionVersion == 0 {
  1568  		p.MinDecryptionVersion = 1
  1569  	}
  1570  
  1571  	return p.Persist(ctx, storage)
  1572  }
  1573  
  1574  // Rotate rotates the policy and persists it to storage.
  1575  // If the rotation partially fails, the policy state will be restored.
  1576  func (p *Policy) Rotate(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
  1577  	priorLatestVersion := p.LatestVersion
  1578  	priorMinDecryptionVersion := p.MinDecryptionVersion
  1579  	var priorKeys keyEntryMap
  1580  
  1581  	if p.Imported && !p.AllowImportedKeyRotation {
  1582  		return fmt.Errorf("imported key %s does not allow rotation within Vault", p.Name)
  1583  	}
  1584  
  1585  	if p.Keys != nil {
  1586  		priorKeys = keyEntryMap{}
  1587  		for k, v := range p.Keys {
  1588  			priorKeys[k] = v
  1589  		}
  1590  	}
  1591  
  1592  	defer func() {
  1593  		if retErr != nil {
  1594  			p.LatestVersion = priorLatestVersion
  1595  			p.MinDecryptionVersion = priorMinDecryptionVersion
  1596  			p.Keys = priorKeys
  1597  		}
  1598  	}()
  1599  
  1600  	if err := p.RotateInMemory(randReader); err != nil {
  1601  		return err
  1602  	}
  1603  
  1604  	p.Imported = false
  1605  	return p.Persist(ctx, storage)
  1606  }
  1607  
  1608  // RotateInMemory rotates the policy but does not persist it to storage.
  1609  func (p *Policy) RotateInMemory(randReader io.Reader) (retErr error) {
  1610  	now := time.Now()
  1611  	entry := KeyEntry{
  1612  		CreationTime:           now,
  1613  		DeprecatedCreationTime: now.Unix(),
  1614  	}
  1615  
  1616  	hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
  1617  	if err != nil {
  1618  		return err
  1619  	}
  1620  	entry.HMACKey = hmacKey
  1621  
  1622  	switch p.Type {
  1623  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC:
  1624  		// Default to 256 bit key
  1625  		numBytes := 32
  1626  		if p.Type == KeyType_AES128_GCM96 {
  1627  			numBytes = 16
  1628  		} else if p.Type == KeyType_HMAC {
  1629  			numBytes = p.KeySize
  1630  			if numBytes < HmacMinKeySize || numBytes > HmacMaxKeySize {
  1631  				return fmt.Errorf("invalid key size for HMAC key, must be between %d and %d bytes", HmacMinKeySize, HmacMaxKeySize)
  1632  			}
  1633  		}
  1634  		newKey, err := uuid.GenerateRandomBytesWithReader(numBytes, randReader)
  1635  		if err != nil {
  1636  			return err
  1637  		}
  1638  		entry.Key = newKey
  1639  
  1640  		if p.Type == KeyType_HMAC {
  1641  			// To avoid causing problems, ensure HMACKey = Key.
  1642  			entry.HMACKey = newKey
  1643  		}
  1644  
  1645  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  1646  		var curve elliptic.Curve
  1647  		switch p.Type {
  1648  		case KeyType_ECDSA_P384:
  1649  			curve = elliptic.P384()
  1650  		case KeyType_ECDSA_P521:
  1651  			curve = elliptic.P521()
  1652  		default:
  1653  			curve = elliptic.P256()
  1654  		}
  1655  
  1656  		privKey, err := ecdsa.GenerateKey(curve, rand.Reader)
  1657  		if err != nil {
  1658  			return err
  1659  		}
  1660  		entry.EC_D = privKey.D
  1661  		entry.EC_X = privKey.X
  1662  		entry.EC_Y = privKey.Y
  1663  		derBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
  1664  		if err != nil {
  1665  			return errwrap.Wrapf("error marshaling public key: {{err}}", err)
  1666  		}
  1667  		pemBlock := &pem.Block{
  1668  			Type:  "PUBLIC KEY",
  1669  			Bytes: derBytes,
  1670  		}
  1671  		pemBytes := pem.EncodeToMemory(pemBlock)
  1672  		if pemBytes == nil || len(pemBytes) == 0 {
  1673  			return fmt.Errorf("error PEM-encoding public key")
  1674  		}
  1675  		entry.FormattedPublicKey = string(pemBytes)
  1676  
  1677  	case KeyType_ED25519:
  1678  		// Go uses a 64-byte private key for Ed25519 keys (private+public, each
  1679  		// 32-bytes long). When we do Key derivation, we still generate a 32-byte
  1680  		// random value (and compute the corresponding Ed25519 public key), but
  1681  		// use this entire 64-byte key as if it was an HKDF key. The corresponding
  1682  		// underlying public key is never returned (which is probably good, because
  1683  		// doing so would leak half of our HKDF key...), but means we cannot import
  1684  		// derived-enabled Ed25519 public key components.
  1685  		pub, pri, err := ed25519.GenerateKey(randReader)
  1686  		if err != nil {
  1687  			return err
  1688  		}
  1689  		entry.Key = pri
  1690  		entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub)
  1691  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  1692  		bitSize := 2048
  1693  		if p.Type == KeyType_RSA3072 {
  1694  			bitSize = 3072
  1695  		}
  1696  		if p.Type == KeyType_RSA4096 {
  1697  			bitSize = 4096
  1698  		}
  1699  
  1700  		entry.RSAKey, err = rsa.GenerateKey(randReader, bitSize)
  1701  		if err != nil {
  1702  			return err
  1703  		}
  1704  	}
  1705  
  1706  	if p.ConvergentEncryption {
  1707  		if p.ConvergentVersion == -1 || p.ConvergentVersion > 1 {
  1708  			entry.ConvergentVersion = currentConvergentVersion
  1709  		}
  1710  	}
  1711  
  1712  	p.LatestVersion += 1
  1713  
  1714  	if p.Keys == nil {
  1715  		// This is an initial key rotation when generating a new policy. We
  1716  		// don't need to call migrate here because if we've called getPolicy to
  1717  		// get the policy in the first place it will have been run.
  1718  		p.Keys = keyEntryMap{}
  1719  	}
  1720  	p.Keys[strconv.Itoa(p.LatestVersion)] = entry
  1721  
  1722  	// This ensures that with new key creations min decryption version is set
  1723  	// to 1 rather than the int default of 0, since keys start at 1 (either
  1724  	// fresh or after migration to the key map)
  1725  	if p.MinDecryptionVersion == 0 {
  1726  		p.MinDecryptionVersion = 1
  1727  	}
  1728  
  1729  	return nil
  1730  }
  1731  
  1732  func (p *Policy) MigrateKeyToKeysMap() {
  1733  	now := time.Now()
  1734  	p.Keys = keyEntryMap{
  1735  		"1": KeyEntry{
  1736  			Key:                    p.Key,
  1737  			CreationTime:           now,
  1738  			DeprecatedCreationTime: now.Unix(),
  1739  		},
  1740  	}
  1741  	p.Key = nil
  1742  }
  1743  
  1744  // Backup should be called with an exclusive lock held on the policy
  1745  func (p *Policy) Backup(ctx context.Context, storage logical.Storage) (out string, retErr error) {
  1746  	if !p.Exportable {
  1747  		return "", fmt.Errorf("exporting is disallowed on the policy")
  1748  	}
  1749  
  1750  	if !p.AllowPlaintextBackup {
  1751  		return "", fmt.Errorf("plaintext backup is disallowed on the policy")
  1752  	}
  1753  
  1754  	priorBackupInfo := p.BackupInfo
  1755  
  1756  	defer func() {
  1757  		if retErr != nil {
  1758  			p.BackupInfo = priorBackupInfo
  1759  		}
  1760  	}()
  1761  
  1762  	// Create a record of this backup operation in the policy
  1763  	p.BackupInfo = &BackupInfo{
  1764  		Time:    time.Now(),
  1765  		Version: p.LatestVersion,
  1766  	}
  1767  	err := p.Persist(ctx, storage)
  1768  	if err != nil {
  1769  		return "", errwrap.Wrapf("failed to persist policy with backup info: {{err}}", err)
  1770  	}
  1771  
  1772  	// Load the archive only after persisting the policy as the archive can get
  1773  	// adjusted while persisting the policy
  1774  	archivedKeys, err := p.LoadArchive(ctx, storage)
  1775  	if err != nil {
  1776  		return "", err
  1777  	}
  1778  
  1779  	keyData := &KeyData{
  1780  		Policy:       p,
  1781  		ArchivedKeys: archivedKeys,
  1782  	}
  1783  
  1784  	encodedBackup, err := jsonutil.EncodeJSON(keyData)
  1785  	if err != nil {
  1786  		return "", err
  1787  	}
  1788  
  1789  	return base64.StdEncoding.EncodeToString(encodedBackup), nil
  1790  }
  1791  
  1792  func (p *Policy) getTemplateParts() ([]string, error) {
  1793  	partsRaw, ok := p.versionPrefixCache.Load("template-parts")
  1794  	if ok {
  1795  		return partsRaw.([]string), nil
  1796  	}
  1797  
  1798  	template := p.VersionTemplate
  1799  	if template == "" {
  1800  		template = DefaultVersionTemplate
  1801  	}
  1802  
  1803  	tplParts := strings.Split(template, "{{version}}")
  1804  	if len(tplParts) != 2 {
  1805  		return nil, errutil.InternalError{Err: "error parsing version template"}
  1806  	}
  1807  
  1808  	p.versionPrefixCache.Store("template-parts", tplParts)
  1809  	return tplParts, nil
  1810  }
  1811  
  1812  func (p *Policy) getVersionPrefix(ver int) string {
  1813  	prefixRaw, ok := p.versionPrefixCache.Load(ver)
  1814  	if ok {
  1815  		return prefixRaw.(string)
  1816  	}
  1817  
  1818  	template := p.VersionTemplate
  1819  	if template == "" {
  1820  		template = DefaultVersionTemplate
  1821  	}
  1822  
  1823  	prefix := strings.ReplaceAll(template, "{{version}}", strconv.Itoa(ver))
  1824  	p.versionPrefixCache.Store(ver, prefix)
  1825  
  1826  	return prefix
  1827  }
  1828  
// SymmetricOpts are the arguments to symmetric operations that are "optional",
// i.e. not always used. Bundling them improves the readability of calls to
// SymmetricEncryptRaw and SymmetricDecryptRaw.
type SymmetricOpts struct {
	// Whether to use convergent encryption
	Convergent bool
	// The version of the convergent encryption scheme. Consulted by
	// SymmetricDecryptRaw; SymmetricEncryptRaw instead derives the version
	// from the policy and key version it is given.
	ConvergentVersion int
	// The nonce, if not randomly generated. On encryption it is required for
	// convergent version 1 and rejected for managed keys.
	Nonce []byte
	// Additional data to include in AEAD authentication
	AdditionalData []byte
	// The HMAC key, for generating IVs in convergent encryption (used by
	// SymmetricEncryptRaw for convergent versions 2 and 3)
	HMACKey []byte
	// Allows an external provider of the AEAD, for e.g. managed keys
	AEADFactory AEADFactory
}
  1845  
  1846  // Symmetrically encrypt a plaintext given the convergence configuration and appropriate keys
  1847  func (p *Policy) SymmetricEncryptRaw(ver int, encKey, plaintext []byte, opts SymmetricOpts) ([]byte, error) {
  1848  	var aead cipher.AEAD
  1849  	var err error
  1850  	nonce := opts.Nonce
  1851  
  1852  	switch p.Type {
  1853  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
  1854  		// Setup the cipher
  1855  		aesCipher, err := aes.NewCipher(encKey)
  1856  		if err != nil {
  1857  			return nil, errutil.InternalError{Err: err.Error()}
  1858  		}
  1859  
  1860  		// Setup the GCM AEAD
  1861  		gcm, err := cipher.NewGCM(aesCipher)
  1862  		if err != nil {
  1863  			return nil, errutil.InternalError{Err: err.Error()}
  1864  		}
  1865  
  1866  		aead = gcm
  1867  
  1868  	case KeyType_ChaCha20_Poly1305:
  1869  		cha, err := chacha20poly1305.New(encKey)
  1870  		if err != nil {
  1871  			return nil, errutil.InternalError{Err: err.Error()}
  1872  		}
  1873  
  1874  		aead = cha
  1875  	case KeyType_MANAGED_KEY:
  1876  		if opts.Convergent || len(opts.Nonce) != 0 {
  1877  			return nil, errutil.UserError{Err: "cannot use convergent encryption or provide a nonce to managed-key backed encryption"}
  1878  		}
  1879  		if opts.AEADFactory == nil {
  1880  			return nil, errors.New("expected AEAD factory from managed key, none provided")
  1881  		}
  1882  		aead, err = opts.AEADFactory.GetAEAD(nonce)
  1883  		if err != nil {
  1884  			return nil, err
  1885  		}
  1886  	}
  1887  
  1888  	if opts.Convergent {
  1889  		convergentVersion := p.convergentVersion(ver)
  1890  		switch convergentVersion {
  1891  		case 1:
  1892  			if len(opts.Nonce) != aead.NonceSize() {
  1893  				return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())}
  1894  			}
  1895  		case 2, 3:
  1896  			if len(opts.HMACKey) == 0 {
  1897  				return nil, errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")}
  1898  			}
  1899  			nonceHmac := hmac.New(sha256.New, opts.HMACKey)
  1900  			nonceHmac.Write(plaintext)
  1901  			nonceSum := nonceHmac.Sum(nil)
  1902  			nonce = nonceSum[:aead.NonceSize()]
  1903  		default:
  1904  			return nil, errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)}
  1905  		}
  1906  	} else if len(nonce) == 0 {
  1907  		// Compute random nonce
  1908  		nonce, err = uuid.GenerateRandomBytes(aead.NonceSize())
  1909  		if err != nil {
  1910  			return nil, errutil.InternalError{Err: err.Error()}
  1911  		}
  1912  	} else if len(nonce) != aead.NonceSize() {
  1913  		return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long but given %d bytes", aead.NonceSize(), len(nonce))}
  1914  	}
  1915  
  1916  	// Encrypt and tag with AEAD
  1917  	ciphertext := aead.Seal(nil, nonce, plaintext, opts.AdditionalData)
  1918  
  1919  	// Place the encrypted data after the nonce
  1920  	if !opts.Convergent || p.convergentVersion(ver) > 1 {
  1921  		ciphertext = append(nonce, ciphertext...)
  1922  	}
  1923  	return ciphertext, nil
  1924  }
  1925  
  1926  // Symmetrically decrypt a ciphertext given the convergence configuration and appropriate keys
  1927  func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOpts) ([]byte, error) {
  1928  	var aead cipher.AEAD
  1929  	var err error
  1930  	var nonce []byte
  1931  
  1932  	switch p.Type {
  1933  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
  1934  		// Setup the cipher
  1935  		aesCipher, err := aes.NewCipher(encKey)
  1936  		if err != nil {
  1937  			return nil, errutil.InternalError{Err: err.Error()}
  1938  		}
  1939  
  1940  		// Setup the GCM AEAD
  1941  		gcm, err := cipher.NewGCM(aesCipher)
  1942  		if err != nil {
  1943  			return nil, errutil.InternalError{Err: err.Error()}
  1944  		}
  1945  
  1946  		aead = gcm
  1947  
  1948  	case KeyType_ChaCha20_Poly1305:
  1949  		cha, err := chacha20poly1305.New(encKey)
  1950  		if err != nil {
  1951  			return nil, errutil.InternalError{Err: err.Error()}
  1952  		}
  1953  
  1954  		aead = cha
  1955  	case KeyType_MANAGED_KEY:
  1956  		aead, err = opts.AEADFactory.GetAEAD(nonce)
  1957  		if err != nil {
  1958  			return nil, err
  1959  		}
  1960  	}
  1961  
  1962  	if len(ciphertext) < aead.NonceSize() {
  1963  		return nil, errutil.UserError{Err: "invalid ciphertext length"}
  1964  	}
  1965  
  1966  	// Extract the nonce and ciphertext
  1967  	var trueCT []byte
  1968  	if opts.Convergent && opts.ConvergentVersion == 1 {
  1969  		trueCT = ciphertext
  1970  	} else {
  1971  		nonce = ciphertext[:aead.NonceSize()]
  1972  		trueCT = ciphertext[aead.NonceSize():]
  1973  	}
  1974  
  1975  	// Verify and Decrypt
  1976  	plain, err := aead.Open(nil, nonce, trueCT, opts.AdditionalData)
  1977  	if err != nil {
  1978  		return nil, errutil.UserError{Err: err.Error()}
  1979  	}
  1980  	return plain, nil
  1981  }
  1982  
  1983  func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) {
  1984  	if !p.Type.EncryptionSupported() {
  1985  		return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
  1986  	}
  1987  
  1988  	// Decode the plaintext value
  1989  	plaintext, err := base64.StdEncoding.DecodeString(value)
  1990  	if err != nil {
  1991  		return "", errutil.UserError{Err: err.Error()}
  1992  	}
  1993  
  1994  	switch {
  1995  	case ver == 0:
  1996  		ver = p.LatestVersion
  1997  	case ver < 0:
  1998  		return "", errutil.UserError{Err: "requested version for encryption is negative"}
  1999  	case ver > p.LatestVersion:
  2000  		return "", errutil.UserError{Err: "requested version for encryption is higher than the latest key version"}
  2001  	case ver < p.MinEncryptionVersion:
  2002  		return "", errutil.UserError{Err: "requested version for encryption is less than the minimum encryption key version"}
  2003  	}
  2004  
  2005  	var ciphertext []byte
  2006  
  2007  	switch p.Type {
  2008  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
  2009  		hmacKey := context
  2010  
  2011  		var encKey []byte
  2012  		var deriveHMAC bool
  2013  
  2014  		encBytes := 32
  2015  		hmacBytes := 0
  2016  		convergentVersion := p.convergentVersion(ver)
  2017  		if convergentVersion > 2 {
  2018  			deriveHMAC = true
  2019  			hmacBytes = 32
  2020  			if len(nonce) > 0 {
  2021  				return "", errutil.UserError{Err: "nonce provided when not allowed"}
  2022  			}
  2023  		} else if len(nonce) > 0 && (!p.ConvergentEncryption || convergentVersion != 1) {
  2024  			return "", errutil.UserError{Err: "nonce provided when not allowed"}
  2025  		}
  2026  		if p.Type == KeyType_AES128_GCM96 {
  2027  			encBytes = 16
  2028  		}
  2029  
  2030  		key, err := p.GetKey(context, ver, encBytes+hmacBytes)
  2031  		if err != nil {
  2032  			return "", err
  2033  		}
  2034  
  2035  		if len(key) < encBytes+hmacBytes {
  2036  			return "", errutil.InternalError{Err: "could not derive key, length too small"}
  2037  		}
  2038  
  2039  		encKey = key[:encBytes]
  2040  		if len(encKey) != encBytes {
  2041  			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
  2042  		}
  2043  		if deriveHMAC {
  2044  			hmacKey = key[encBytes:]
  2045  			if len(hmacKey) != hmacBytes {
  2046  				return "", errutil.InternalError{Err: "could not derive hmac key, length not correct"}
  2047  			}
  2048  		}
  2049  
  2050  		symopts := SymmetricOpts{
  2051  			Convergent: p.ConvergentEncryption,
  2052  			HMACKey:    hmacKey,
  2053  			Nonce:      nonce,
  2054  		}
  2055  		for index, rawFactory := range factories {
  2056  			if rawFactory == nil {
  2057  				continue
  2058  			}
  2059  			switch factory := rawFactory.(type) {
  2060  			case AEADFactory:
  2061  				symopts.AEADFactory = factory
  2062  			case AssociatedDataFactory:
  2063  				symopts.AdditionalData, err = factory.GetAssociatedData()
  2064  				if err != nil {
  2065  					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
  2066  				}
  2067  			case ManagedKeyFactory:
  2068  			default:
  2069  				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
  2070  			}
  2071  		}
  2072  
  2073  		ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts)
  2074  		if err != nil {
  2075  			return "", err
  2076  		}
  2077  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  2078  		keyEntry, err := p.safeGetKeyEntry(ver)
  2079  		if err != nil {
  2080  			return "", err
  2081  		}
  2082  		var publicKey *rsa.PublicKey
  2083  		if keyEntry.RSAKey != nil {
  2084  			publicKey = &keyEntry.RSAKey.PublicKey
  2085  		} else {
  2086  			publicKey = keyEntry.RSAPublicKey
  2087  		}
  2088  		ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
  2089  		if err != nil {
  2090  			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA encrypt the plaintext: %v", err)}
  2091  		}
  2092  	case KeyType_MANAGED_KEY:
  2093  		keyEntry, err := p.safeGetKeyEntry(ver)
  2094  		if err != nil {
  2095  			return "", err
  2096  		}
  2097  
  2098  		var aad []byte
  2099  		var managedKeyFactory ManagedKeyFactory
  2100  		for _, f := range factories {
  2101  			switch factory := f.(type) {
  2102  			case AssociatedDataFactory:
  2103  				aad, err = factory.GetAssociatedData()
  2104  				if err != nil {
  2105  					return "", nil
  2106  				}
  2107  			case ManagedKeyFactory:
  2108  				managedKeyFactory = factory
  2109  			}
  2110  		}
  2111  
  2112  		if managedKeyFactory == nil {
  2113  			return "", errors.New("key type is managed_key, but managed key parameters were not provided")
  2114  		}
  2115  
  2116  		ciphertext, err = p.encryptWithManagedKey(managedKeyFactory.GetManagedKeyParameters(), keyEntry, plaintext, nonce, aad)
  2117  		if err != nil {
  2118  			return "", err
  2119  		}
  2120  
  2121  	default:
  2122  		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
  2123  	}
  2124  
  2125  	// Convert to base64
  2126  	encoded := base64.StdEncoding.EncodeToString(ciphertext)
  2127  
  2128  	// Prepend some information
  2129  	encoded = p.getVersionPrefix(ver) + encoded
  2130  
  2131  	return encoded, nil
  2132  }
  2133  
  2134  func (p *Policy) KeyVersionCanBeUpdated(keyVersion int, isPrivateKey bool) error {
  2135  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2136  	if err != nil {
  2137  		return err
  2138  	}
  2139  
  2140  	if !p.Type.ImportPublicKeySupported() {
  2141  		return errors.New("provided type does not support importing key versions")
  2142  	}
  2143  
  2144  	isPrivateKeyMissing := keyEntry.IsPrivateKeyMissing()
  2145  	if isPrivateKeyMissing && !isPrivateKey {
  2146  		return errors.New("cannot add a public key to a key version that already has a public key set")
  2147  	}
  2148  
  2149  	if !isPrivateKeyMissing {
  2150  		return errors.New("private key imported, key version cannot be updated")
  2151  	}
  2152  
  2153  	return nil
  2154  }
  2155  
  2156  func (p *Policy) ImportPrivateKeyForVersion(ctx context.Context, storage logical.Storage, keyVersion int, key []byte) error {
  2157  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2158  	if err != nil {
  2159  		return err
  2160  	}
  2161  
  2162  	// Parse key
  2163  	parsedPrivateKey, err := x509.ParsePKCS8PrivateKey(key)
  2164  	if err != nil {
  2165  		if strings.Contains(err.Error(), "unknown elliptic curve") {
  2166  			var edErr error
  2167  			parsedPrivateKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
  2168  			if edErr != nil {
  2169  				return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
  2170  			}
  2171  
  2172  			// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
  2173  		} else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) {
  2174  			var rsaErr error
  2175  			parsedPrivateKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key)
  2176  			if rsaErr != nil {
  2177  				return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err)
  2178  			}
  2179  
  2180  			// Parsing as RSA-PSS in PKCS8 succeeded!
  2181  		} else {
  2182  			return fmt.Errorf("error parsing asymmetric key: %s", err)
  2183  		}
  2184  	}
  2185  
  2186  	switch parsedPrivateKey.(type) {
  2187  	case *ecdsa.PrivateKey:
  2188  		ecdsaKey := parsedPrivateKey.(*ecdsa.PrivateKey)
  2189  		pemBlock, _ := pem.Decode([]byte(keyEntry.FormattedPublicKey))
  2190  		if pemBlock == nil {
  2191  			return fmt.Errorf("failed to parse key entry public key: invalid PEM blob")
  2192  		}
  2193  		publicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
  2194  		if err != nil || publicKey == nil {
  2195  			return fmt.Errorf("failed to parse key entry public key: %v", err)
  2196  		}
  2197  		if !publicKey.(*ecdsa.PublicKey).Equal(&ecdsaKey.PublicKey) {
  2198  			return fmt.Errorf("cannot import key, key pair does not match")
  2199  		}
  2200  	case *rsa.PrivateKey:
  2201  		rsaKey := parsedPrivateKey.(*rsa.PrivateKey)
  2202  		if !rsaKey.PublicKey.Equal(keyEntry.RSAPublicKey) {
  2203  			return fmt.Errorf("cannot import key, key pair does not match")
  2204  		}
  2205  	case ed25519.PrivateKey:
  2206  		ed25519Key := parsedPrivateKey.(ed25519.PrivateKey)
  2207  		publicKey, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
  2208  		if err != nil {
  2209  			return fmt.Errorf("failed to parse key entry public key: %v", err)
  2210  		}
  2211  		if !ed25519.PublicKey(publicKey).Equal(ed25519Key.Public()) {
  2212  			return fmt.Errorf("cannot import key, key pair does not match")
  2213  		}
  2214  	}
  2215  
  2216  	err = keyEntry.parseFromKey(p.Type, parsedPrivateKey)
  2217  	if err != nil {
  2218  		return err
  2219  	}
  2220  
  2221  	p.Keys[strconv.Itoa(keyVersion)] = keyEntry
  2222  
  2223  	return p.Persist(ctx, storage)
  2224  }
  2225  
  2226  func (ke *KeyEntry) parseFromKey(PolKeyType KeyType, parsedKey any) error {
  2227  	switch parsedKey.(type) {
  2228  	case *ecdsa.PrivateKey, *ecdsa.PublicKey:
  2229  		if PolKeyType != KeyType_ECDSA_P256 && PolKeyType != KeyType_ECDSA_P384 && PolKeyType != KeyType_ECDSA_P521 {
  2230  			return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2231  		}
  2232  
  2233  		curve := elliptic.P256()
  2234  		if PolKeyType == KeyType_ECDSA_P384 {
  2235  			curve = elliptic.P384()
  2236  		} else if PolKeyType == KeyType_ECDSA_P521 {
  2237  			curve = elliptic.P521()
  2238  		}
  2239  
  2240  		var derBytes []byte
  2241  		var err error
  2242  		ecdsaKey, ok := parsedKey.(*ecdsa.PrivateKey)
  2243  		if ok {
  2244  
  2245  			if ecdsaKey.Curve != curve {
  2246  				return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name)
  2247  			}
  2248  
  2249  			ke.EC_D = ecdsaKey.D
  2250  			ke.EC_X = ecdsaKey.X
  2251  			ke.EC_Y = ecdsaKey.Y
  2252  
  2253  			derBytes, err = x509.MarshalPKIXPublicKey(ecdsaKey.Public())
  2254  			if err != nil {
  2255  				return errwrap.Wrapf("error marshaling public key: {{err}}", err)
  2256  			}
  2257  		} else {
  2258  			ecdsaKey := parsedKey.(*ecdsa.PublicKey)
  2259  
  2260  			if ecdsaKey.Curve != curve {
  2261  				return fmt.Errorf("invalid curve: expected %s, got %s", curve.Params().Name, ecdsaKey.Curve.Params().Name)
  2262  			}
  2263  
  2264  			ke.EC_X = ecdsaKey.X
  2265  			ke.EC_Y = ecdsaKey.Y
  2266  
  2267  			derBytes, err = x509.MarshalPKIXPublicKey(ecdsaKey)
  2268  			if err != nil {
  2269  				return errwrap.Wrapf("error marshaling public key: {{err}}", err)
  2270  			}
  2271  		}
  2272  
  2273  		pemBlock := &pem.Block{
  2274  			Type:  "PUBLIC KEY",
  2275  			Bytes: derBytes,
  2276  		}
  2277  		pemBytes := pem.EncodeToMemory(pemBlock)
  2278  		if pemBytes == nil || len(pemBytes) == 0 {
  2279  			return fmt.Errorf("error PEM-encoding public key")
  2280  		}
  2281  		ke.FormattedPublicKey = string(pemBytes)
  2282  	case ed25519.PrivateKey, ed25519.PublicKey:
  2283  		if PolKeyType != KeyType_ED25519 {
  2284  			return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2285  		}
  2286  
  2287  		privateKey, ok := parsedKey.(ed25519.PrivateKey)
  2288  		if ok {
  2289  			ke.Key = privateKey
  2290  			publicKey := privateKey.Public().(ed25519.PublicKey)
  2291  			ke.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey)
  2292  		} else {
  2293  			publicKey := parsedKey.(ed25519.PublicKey)
  2294  			ke.FormattedPublicKey = base64.StdEncoding.EncodeToString(publicKey)
  2295  		}
  2296  	case *rsa.PrivateKey, *rsa.PublicKey:
  2297  		if PolKeyType != KeyType_RSA2048 && PolKeyType != KeyType_RSA3072 && PolKeyType != KeyType_RSA4096 {
  2298  			return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2299  		}
  2300  
  2301  		keyBytes := 256
  2302  		if PolKeyType == KeyType_RSA3072 {
  2303  			keyBytes = 384
  2304  		} else if PolKeyType == KeyType_RSA4096 {
  2305  			keyBytes = 512
  2306  		}
  2307  
  2308  		rsaKey, ok := parsedKey.(*rsa.PrivateKey)
  2309  		if ok {
  2310  			if rsaKey.Size() != keyBytes {
  2311  				return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size())
  2312  			}
  2313  			ke.RSAKey = rsaKey
  2314  			ke.RSAPublicKey = rsaKey.Public().(*rsa.PublicKey)
  2315  		} else {
  2316  			rsaKey := parsedKey.(*rsa.PublicKey)
  2317  			if rsaKey.Size() != keyBytes {
  2318  				return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", keyBytes, rsaKey.Size())
  2319  			}
  2320  			ke.RSAPublicKey = rsaKey
  2321  		}
  2322  	default:
  2323  		return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
  2324  	}
  2325  
  2326  	return nil
  2327  }
  2328  
  2329  func (p *Policy) WrapKey(ver int, targetKey interface{}, targetKeyType KeyType, hash hash.Hash) (string, error) {
  2330  	if !p.Type.SigningSupported() {
  2331  		return "", fmt.Errorf("message signing not supported for key type %v", p.Type)
  2332  	}
  2333  
  2334  	switch {
  2335  	case ver == 0:
  2336  		ver = p.LatestVersion
  2337  	case ver < 0:
  2338  		return "", errutil.UserError{Err: "requested version for key wrapping is negative"}
  2339  	case ver > p.LatestVersion:
  2340  		return "", errutil.UserError{Err: "requested version for key wrapping is higher than the latest key version"}
  2341  	case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
  2342  		return "", errutil.UserError{Err: "requested version for key wrapping is less than the minimum encryption key version"}
  2343  	}
  2344  
  2345  	keyEntry, err := p.safeGetKeyEntry(ver)
  2346  	if err != nil {
  2347  		return "", err
  2348  	}
  2349  
  2350  	return keyEntry.WrapKey(targetKey, targetKeyType, hash)
  2351  }
  2352  
  2353  func (ke *KeyEntry) WrapKey(targetKey interface{}, targetKeyType KeyType, hash hash.Hash) (string, error) {
  2354  	// Presently this method implements a CKM_RSA_AES_KEY_WRAP-compatible
  2355  	// wrapping interface and only works on RSA keyEntries as a result.
  2356  	if ke.RSAPublicKey == nil {
  2357  		return "", fmt.Errorf("unsupported key type in use; must be a rsa key")
  2358  	}
  2359  
  2360  	var preppedTargetKey []byte
  2361  	switch targetKeyType {
  2362  	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC:
  2363  		var ok bool
  2364  		preppedTargetKey, ok = targetKey.([]byte)
  2365  		if !ok {
  2366  			return "", fmt.Errorf("failed to wrap target key for import: symmetric key not provided in byte format (%T)", targetKey)
  2367  		}
  2368  	default:
  2369  		var err error
  2370  		preppedTargetKey, err = x509.MarshalPKCS8PrivateKey(targetKey)
  2371  		if err != nil {
  2372  			return "", fmt.Errorf("failed to wrap target key for import: %w", err)
  2373  		}
  2374  	}
  2375  
  2376  	result, err := wrapTargetPKCS8ForImport(ke.RSAPublicKey, preppedTargetKey, hash)
  2377  	if err != nil {
  2378  		return result, fmt.Errorf("failed to wrap target key for import: %w", err)
  2379  	}
  2380  
  2381  	return result, nil
  2382  }
  2383  
  2384  func wrapTargetPKCS8ForImport(wrappingKey *rsa.PublicKey, preppedTargetKey []byte, hash hash.Hash) (string, error) {
  2385  	// Generate an ephemeral AES-256 key
  2386  	ephKey, err := uuid.GenerateRandomBytes(32)
  2387  	if err != nil {
  2388  		return "", fmt.Errorf("failed to generate an ephemeral AES wrapping key: %w", err)
  2389  	}
  2390  
  2391  	// Wrap ephemeral AES key with public wrapping key
  2392  	ephKeyWrapped, err := rsa.EncryptOAEP(hash, rand.Reader, wrappingKey, ephKey, []byte{} /* label */)
  2393  	if err != nil {
  2394  		return "", fmt.Errorf("failed to encrypt ephemeral wrapping key with public key: %w", err)
  2395  	}
  2396  
  2397  	// Create KWP instance for wrapping target key
  2398  	kwp, err := subtle.NewKWP(ephKey)
  2399  	if err != nil {
  2400  		return "", fmt.Errorf("failed to generate new KWP from AES key: %w", err)
  2401  	}
  2402  
  2403  	// Wrap target key with KWP
  2404  	targetKeyWrapped, err := kwp.Wrap(preppedTargetKey)
  2405  	if err != nil {
  2406  		return "", fmt.Errorf("failed to wrap target key with KWP: %w", err)
  2407  	}
  2408  
  2409  	// Combined wrapped keys into a single blob and base64 encode
  2410  	wrappedKeys := append(ephKeyWrapped, targetKeyWrapped...)
  2411  	return base64.StdEncoding.EncodeToString(wrappedKeys), nil
  2412  }
  2413  
  2414  func (p *Policy) CreateCsr(keyVersion int, csrTemplate *x509.CertificateRequest) ([]byte, error) {
  2415  	if !p.Type.SigningSupported() {
  2416  		return nil, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)}
  2417  	}
  2418  
  2419  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2420  	if err != nil {
  2421  		return nil, err
  2422  	}
  2423  
  2424  	if keyEntry.IsPrivateKeyMissing() {
  2425  		return nil, errutil.UserError{Err: "private key not imported for key version selected"}
  2426  	}
  2427  
  2428  	csrTemplate.Signature = nil
  2429  	csrTemplate.SignatureAlgorithm = x509.UnknownSignatureAlgorithm
  2430  
  2431  	var key crypto.Signer
  2432  	switch p.Type {
  2433  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  2434  		var curve elliptic.Curve
  2435  		switch p.Type {
  2436  		case KeyType_ECDSA_P384:
  2437  			curve = elliptic.P384()
  2438  		case KeyType_ECDSA_P521:
  2439  			curve = elliptic.P521()
  2440  		default:
  2441  			curve = elliptic.P256()
  2442  		}
  2443  
  2444  		key = &ecdsa.PrivateKey{
  2445  			PublicKey: ecdsa.PublicKey{
  2446  				Curve: curve,
  2447  				X:     keyEntry.EC_X,
  2448  				Y:     keyEntry.EC_Y,
  2449  			},
  2450  			D: keyEntry.EC_D,
  2451  		}
  2452  
  2453  	case KeyType_ED25519:
  2454  		if p.Derived {
  2455  			return nil, errutil.UserError{Err: "operation not supported on keys with derivation enabled"}
  2456  		}
  2457  		key = ed25519.PrivateKey(keyEntry.Key)
  2458  
  2459  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  2460  		key = keyEntry.RSAKey
  2461  
  2462  	default:
  2463  		return nil, errutil.InternalError{Err: fmt.Sprintf("selected key type '%s' does not support signing", p.Type.String())}
  2464  	}
  2465  	csrBytes, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
  2466  	if err != nil {
  2467  		return nil, fmt.Errorf("could not create the cerfificate request: %w", err)
  2468  	}
  2469  
  2470  	pemCsr := pem.EncodeToMemory(&pem.Block{
  2471  		Type:  "CERTIFICATE REQUEST",
  2472  		Bytes: csrBytes,
  2473  	})
  2474  
  2475  	return pemCsr, nil
  2476  }
  2477  
  2478  func (p *Policy) ValidateLeafCertKeyMatch(keyVersion int, certPublicKeyAlgorithm x509.PublicKeyAlgorithm, certPublicKey any) (bool, error) {
  2479  	if !p.Type.SigningSupported() {
  2480  		return false, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)}
  2481  	}
  2482  
  2483  	var keyTypeMatches bool
  2484  	switch p.Type {
  2485  	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
  2486  		if certPublicKeyAlgorithm == x509.ECDSA {
  2487  			keyTypeMatches = true
  2488  		}
  2489  	case KeyType_ED25519:
  2490  		if certPublicKeyAlgorithm == x509.Ed25519 {
  2491  			keyTypeMatches = true
  2492  		}
  2493  	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
  2494  		if certPublicKeyAlgorithm == x509.RSA {
  2495  			keyTypeMatches = true
  2496  		}
  2497  	}
  2498  	if !keyTypeMatches {
  2499  		return false, errutil.UserError{Err: fmt.Sprintf("provided leaf certificate public key algorithm '%s' does not match the transit key type '%s'",
  2500  			certPublicKeyAlgorithm, p.Type)}
  2501  	}
  2502  
  2503  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2504  	if err != nil {
  2505  		return false, err
  2506  	}
  2507  
  2508  	switch certPublicKeyAlgorithm {
  2509  	case x509.ECDSA:
  2510  		certPublicKey := certPublicKey.(*ecdsa.PublicKey)
  2511  		var curve elliptic.Curve
  2512  		switch p.Type {
  2513  		case KeyType_ECDSA_P384:
  2514  			curve = elliptic.P384()
  2515  		case KeyType_ECDSA_P521:
  2516  			curve = elliptic.P521()
  2517  		default:
  2518  			curve = elliptic.P256()
  2519  		}
  2520  
  2521  		publicKey := &ecdsa.PublicKey{
  2522  			Curve: curve,
  2523  			X:     keyEntry.EC_X,
  2524  			Y:     keyEntry.EC_Y,
  2525  		}
  2526  
  2527  		return publicKey.Equal(certPublicKey), nil
  2528  
  2529  	case x509.Ed25519:
  2530  		if p.Derived {
  2531  			return false, errutil.UserError{Err: "operation not supported on keys with derivation enabled"}
  2532  		}
  2533  		certPublicKey := certPublicKey.(ed25519.PublicKey)
  2534  
  2535  		raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
  2536  		if err != nil {
  2537  			return false, err
  2538  		}
  2539  		publicKey := ed25519.PublicKey(raw)
  2540  
  2541  		return publicKey.Equal(certPublicKey), nil
  2542  
  2543  	case x509.RSA:
  2544  		certPublicKey := certPublicKey.(*rsa.PublicKey)
  2545  		publicKey := keyEntry.RSAKey.PublicKey
  2546  		return publicKey.Equal(certPublicKey), nil
  2547  
  2548  	case x509.UnknownPublicKeyAlgorithm:
  2549  		return false, errutil.InternalError{Err: fmt.Sprint("certificate signed with an unknown algorithm")}
  2550  	}
  2551  
  2552  	return false, nil
  2553  }
  2554  
  2555  func (p *Policy) ValidateAndPersistCertificateChain(ctx context.Context, keyVersion int, certChain []*x509.Certificate, storage logical.Storage) error {
  2556  	if len(certChain) == 0 {
  2557  		return errutil.UserError{Err: "expected at least one certificate in the parsed certificate chain"}
  2558  	}
  2559  
  2560  	if certChain[0].BasicConstraintsValid && certChain[0].IsCA {
  2561  		return errutil.UserError{Err: "certificate in the first position is not a leaf certificate"}
  2562  	}
  2563  
  2564  	for _, cert := range certChain[1:] {
  2565  		if cert.BasicConstraintsValid && !cert.IsCA {
  2566  			return errutil.UserError{Err: "provided certificate chain contains more than one leaf certificate"}
  2567  		}
  2568  	}
  2569  
  2570  	valid, err := p.ValidateLeafCertKeyMatch(keyVersion, certChain[0].PublicKeyAlgorithm, certChain[0].PublicKey)
  2571  	if err != nil {
  2572  		prefixedErr := fmt.Errorf("could not validate key match between leaf certificate key and key version in transit: %w", err)
  2573  		switch err.(type) {
  2574  		case errutil.UserError:
  2575  			return errutil.UserError{Err: prefixedErr.Error()}
  2576  		default:
  2577  			return prefixedErr
  2578  		}
  2579  	}
  2580  	if !valid {
  2581  		return fmt.Errorf("leaf certificate public key does match the key version selected")
  2582  	}
  2583  
  2584  	keyEntry, err := p.safeGetKeyEntry(keyVersion)
  2585  	if err != nil {
  2586  		return err
  2587  	}
  2588  
  2589  	// Convert the certificate chain to DER format
  2590  	derCertificates := make([][]byte, len(certChain))
  2591  	for i, cert := range certChain {
  2592  		derCertificates[i] = cert.Raw
  2593  	}
  2594  
  2595  	keyEntry.CertificateChain = derCertificates
  2596  
  2597  	p.Keys[strconv.Itoa(keyVersion)] = keyEntry
  2598  	return p.Persist(ctx, storage)
  2599  }