github.com/hashicorp/vault/sdk@v0.11.0/helper/keysutil/lock_manager.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package keysutil

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"sync"
	"sync/atomic"
	"time"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/helper/jsonutil"
	"github.com/hashicorp/vault/sdk/helper/locksutil"
	"github.com/hashicorp/vault/sdk/logical"
)

const (
	shared                   = false
	exclusive                = true
	currentConvergentVersion = 3
)

var errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")

// PolicyRequest holds values used when requesting a policy. Most values are
// only used during an upsert.
type PolicyRequest struct {
	// The storage to use
	Storage logical.Storage

	// The name of the policy
	Name string

	// The key type
	KeyType KeyType

	// The key size for variable key size algorithms
	KeySize int

	// Whether it should be derived
	Derived bool

	// Whether to enable convergent encryption
	Convergent bool

	// Whether to allow export
	Exportable bool

	// Whether to upsert
	Upsert bool

	// Whether to allow plaintext backup
	AllowPlaintextBackup bool

	// How frequently the key should automatically rotate
	AutoRotatePeriod time.Duration

	// AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault
	AllowImportedKeyRotation bool

	// Indicates whether a private or public key is imported/upserted
	IsPrivateKey bool

	// The UUID of the managed key, if using one
	ManagedKeyUUID string
}
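
// Illustrative sketch (not part of this package): a caller-side PolicyRequest
// for upserting a derived, convergent AES-256 key. The logical.Storage handle
// ("s") is assumed to come from the caller's request context.
//
//	req := keysutil.PolicyRequest{
//		Storage:    s,
//		Name:       "orders-key",
//		KeyType:    keysutil.KeyType_AES256_GCM96,
//		Derived:    true,
//		Convergent: true,
//		Upsert:     true,
//	}
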
type LockManager struct {
	useCache bool
	cache    Cache
	keyLocks []*locksutil.LockEntry
}

func NewLockManager(useCache bool, cacheSize int) (*LockManager, error) {
	// Determine the type of cache to create
	var cache Cache
	switch {
	case !useCache:
	case cacheSize < 0:
		return nil, errors.New("cache size must be greater than or equal to zero")
	case cacheSize == 0:
		cache = NewTransitSyncMap()
	case cacheSize > 0:
		newLRUCache, err := NewTransitLRU(cacheSize)
		if err != nil {
			return nil, errwrap.Wrapf("failed to create cache: {{err}}", err)
		}
		cache = newLRUCache
	}

	lm := &LockManager{
		useCache: useCache,
		cache:    cache,
		keyLocks: locksutil.CreateLocks(),
	}

	return lm, nil
}
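
// Usage sketch (the caller is assumed to handle the error). A cacheSize of 0
// selects the unbounded sync.Map-backed cache, a positive size selects an LRU
// of that size, and useCache=false disables caching entirely (see GetPolicy
// for the locking consequences of running without a cache).
//
//	lm, err := keysutil.NewLockManager(true, 0)
//	if err != nil {
//		return nil, err
//	}
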
func (lm *LockManager) GetCacheSize() int {
	if !lm.useCache {
		return 0
	}
	return lm.cache.Size()
}

func (lm *LockManager) GetUseCache() bool {
	return lm.useCache
}

func (lm *LockManager) InvalidatePolicy(name string) {
	if lm.useCache {
		lm.cache.Delete(name)
	}
}

func (lm *LockManager) InitCache(cacheSize int) error {
	if lm.useCache {
		switch {
		case cacheSize < 0:
			return errors.New("cache size must be greater than or equal to zero")
		case cacheSize == 0:
			lm.cache = NewTransitSyncMap()
		case cacheSize > 0:
			newLRUCache, err := NewTransitLRU(cacheSize)
			if err != nil {
				return errwrap.Wrapf("failed to create cache: {{err}}", err)
			}
			lm.cache = newLRUCache
		}
	}
	return nil
}
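
// Minimal sketch: InitCache can swap the cache implementation after
// construction, for example when a mount's cache configuration changes
// (an assumed use case). The size semantics match NewLockManager
// (0 = unbounded, >0 = LRU).
//
//	if err := lm.InitCache(1000); err != nil {
//		return err
//	}
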
// RestorePolicy acquires an exclusive lock on the policy name and restores the
// given policy along with the archive.
func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
	backupBytes, err := base64.StdEncoding.DecodeString(backup)
	if err != nil {
		return err
	}

	var keyData KeyData
	err = jsonutil.DecodeJSON(backupBytes, &keyData)
	if err != nil {
		return err
	}

	// Set a different name if desired
	if name != "" {
		keyData.Policy.Name = name
	}

	name = keyData.Policy.Name

	// Grab the exclusive lock as we'll be modifying disk
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	var ok bool
	var pRaw interface{}

	// If the policy is in cache and 'force' is not specified, error out. Anywhere
	// that would put it in the cache will also be protected by the mutex above,
	// so we don't need to re-check the cache later.
	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
		if ok && !force {
			return fmt.Errorf("key %q already exists", name)
		}
	}

	// Conditionally look up the policy from storage, depending on the use of
	// 'force' and if the policy was found in cache.
	//
	// - If it was not found in cache and we are not using 'force', look for it
	// in storage. If found, error out.
	//
	// - If it was found in cache and we are using 'force', pRaw will not be nil
	// and we do not look the policy up from storage.
	//
	// - If it was found in cache and we are not using 'force', we would have
	// returned above with an error.
	var p *Policy
	if pRaw == nil {
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return err
		}
		if p != nil && !force {
			return fmt.Errorf("key %q already exists", name)
		}
	}

	// If both pRaw and p above are nil and 'force' is specified, we don't need
	// to grab policy locks as we have ensured it doesn't already exist, so
	// there will be no races as nothing else has this pointer. If 'force' was
	// not used, an error would have been returned by now if the policy already
	// existed.
	if pRaw != nil {
		p = pRaw.(*Policy)
	}
	if p != nil {
		p.l.Lock()
		defer p.l.Unlock()
	}

	// Restore the archived keys
	if keyData.ArchivedKeys != nil {
		err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
		if err != nil {
			return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
		}
	}

	// Mark the policy as a restored key
	keyData.Policy.RestoreInfo = &RestoreInfo{
		Time:    time.Now(),
		Version: keyData.Policy.LatestVersion,
	}

	// Restore the policy. This will also attempt to adjust the archive.
	err = keyData.Policy.Persist(ctx, storage)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
	}

	keyData.Policy.l = new(sync.RWMutex)

	// Update the cache to contain the restored policy
	if lm.useCache {
		lm.cache.Store(name, keyData.Policy)
	}
	return nil
}
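
// Restore sketch (hypothetical names): "backup" is the base64 blob produced
// by BackupPolicy below. Passing a non-empty name restores under that name,
// and force=true proceeds even if a key with that name already exists.
//
//	if err := lm.RestorePolicy(ctx, storage, "orders-key-copy", backup, true); err != nil {
//		return err
//	}
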
func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {
	var p *Policy
	var err error

	// Backup writes information about when the backup took place, so we get an
	// exclusive lock here
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	var ok bool
	var pRaw interface{}

	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
	}
	if ok {
		p = pRaw.(*Policy)
		p.l.Lock()
		defer p.l.Unlock()
	} else {
		// If the policy doesn't exist in storage, error out
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return "", err
		}
		if p == nil {
			return "", fmt.Errorf("key %q not found", name)
		}
	}

	if atomic.LoadUint32(&p.deleted) == 1 {
		return "", fmt.Errorf("key %q not found", name)
	}

	backup, err := p.Backup(ctx, storage)
	if err != nil {
		return "", err
	}

	return backup, nil
}
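
// Round-trip sketch: BackupPolicy emits a base64-encoded JSON KeyData blob
// (the policy plus its archived key versions), which RestorePolicy above
// accepts. ctx and storage are assumed to be in scope.
//
//	backup, err := lm.BackupPolicy(ctx, storage, "orders-key")
//	if err != nil {
//		return err
//	}
//	if err := lm.RestorePolicy(ctx, storage, "orders-key-copy", backup, true); err != nil {
//		return err
//	}
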
// When the function returns, if caching was disabled, the Policy's lock must
// be unlocked when the caller is done (and it should not be re-locked).
func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest, rand io.Reader) (retP *Policy, retUpserted bool, retErr error) {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// Check if it's in our cache. If so, return right away.
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			return nil, false, nil
		}
		return p, false, nil
	}

	// We're not using the cache, or it wasn't found; get an exclusive lock.
	// This ensures that any other process writing the actual storage will be
	// finished before we load from storage.
	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
	lock.Lock()

	// cleanup is called before every return below. If we are using the cache
	// it always unlocks; otherwise we return with the lock still held whenever
	// we hand a policy back to the caller.
	cleanup := func() {
		switch {
		// If using the cache we always unlock; the caller locks the policy
		// themselves
		case lm.useCache:
			lock.Unlock()

		// If not using the cache and we aren't returning a policy, the caller
		// doesn't hold a lock, so we must unlock
		case retP == nil:
			lock.Unlock()
		}
	}

	// Check the cache again
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			cleanup()
			return nil, false, nil
		}
		retP = p
		cleanup()
		return
	}

	// Load it from storage
	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
	if err != nil {
		cleanup()
		return nil, false, err
	}
	// We don't need to lock the policy as there would be no other holders of
	// the pointer

	if p == nil {
		// This is the only place we upsert a new policy, so if upsert is not
		// specified, unlock before returning
		if !req.Upsert {
			cleanup()
			return nil, false, nil
		}

		// We create the policy here, then store it in the cache at the end.
		// The exclusive lock above means nothing else can have created it
		// since our last cache check.

		switch req.KeyType {
		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
			if req.Convergent && !req.Derived {
				cleanup()
				return nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
			}

		case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_ED25519:
			if req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_HMAC:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		case KeyType_MANAGED_KEY:
			if req.Derived || req.Convergent {
				cleanup()
				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
			}

		default:
			cleanup()
			return nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
		}

		p = &Policy{
			l:                    new(sync.RWMutex),
			Name:                 req.Name,
			Type:                 req.KeyType,
			Derived:              req.Derived,
			Exportable:           req.Exportable,
			AllowPlaintextBackup: req.AllowPlaintextBackup,
			AutoRotatePeriod:     req.AutoRotatePeriod,
			KeySize:              req.KeySize,
		}

		if req.Derived {
			p.KDF = Kdf_hkdf_sha256
			if req.Convergent {
				p.ConvergentEncryption = true
				// As of version 3 we store the version within each key, so we
				// set to -1 to indicate that the value in the policy has no
				// meaning. We still, for backwards compatibility, fall back to
				// this value if the key doesn't have one, which means it will
				// only be -1 in the case where every key version is >= 3
				p.ConvergentVersion = -1
			}
		}

		// Performs the actual persist and does setup
		if p.Type == KeyType_MANAGED_KEY {
			err = p.RotateManagedKey(ctx, req.Storage, req.ManagedKeyUUID)
		} else {
			err = p.Rotate(ctx, req.Storage, rand)
		}
		if err != nil {
			cleanup()
			return nil, false, err
		}

		if lm.useCache {
			lm.cache.Store(req.Name, p)
		} else {
			p.l = &lock.RWMutex
			p.writeLocked = true
		}

		// We don't need to worry about upgrading since it will be a new policy
		retP = p
		retUpserted = true
		cleanup()
		return
	}

	if p.NeedsUpgrade() {
		if err := p.Upgrade(ctx, req.Storage, rand); err != nil {
			cleanup()
			return nil, false, err
		}
	}

	if lm.useCache {
		lm.cache.Store(req.Name, p)
	} else {
		p.l = &lock.RWMutex
		p.writeLocked = true
	}

	retP = p
	cleanup()
	return
}
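
// Caller-side sketch of the locking contract, mirroring how a consumer such
// as the transit backend might use this API. It assumes Policy's exported
// Lock/Unlock helpers (defined alongside Policy in this package's policy.go)
// and rand.Reader from crypto/rand. With the cache enabled the caller takes
// the policy lock itself; with the cache disabled, GetPolicy returns with the
// write lock already held and the caller must release it. The second return
// value reports whether the key was newly created.
//
//	p, _, err := lm.GetPolicy(ctx, req, rand.Reader)
//	if err != nil {
//		return err
//	}
//	if p == nil {
//		return errors.New("policy not found")
//	}
//	if lm.GetUseCache() {
//		p.Lock(false) // shared lock; GetPolicy did not lock for us
//	}
//	defer p.Unlock()
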
func (lm *LockManager) ImportPolicy(ctx context.Context, req PolicyRequest, key []byte, rand io.Reader) error {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// Check if it's in our cache
	if lm.useCache {
		pRaw, ok = lm.cache.Load(req.Name)
	}
	if ok {
		p = pRaw.(*Policy)
		if atomic.LoadUint32(&p.deleted) == 1 {
			return nil
		}
	}

	// We're not using the cache, or it wasn't found; get an exclusive lock.
	// This ensures that any other process writing the actual storage will be
	// finished before we load from storage.
	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
	lock.Lock()
	defer lock.Unlock()

	// Load it from storage
	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
	if err != nil {
		return err
	}

	if p == nil {
		p = &Policy{
			l:                        new(sync.RWMutex),
			Name:                     req.Name,
			Type:                     req.KeyType,
			Derived:                  req.Derived,
			Exportable:               req.Exportable,
			AllowPlaintextBackup:     req.AllowPlaintextBackup,
			AutoRotatePeriod:         req.AutoRotatePeriod,
			AllowImportedKeyRotation: req.AllowImportedKeyRotation,
			Imported:                 true,
		}
	}

	err = p.ImportPublicOrPrivate(ctx, req.Storage, key, req.IsPrivateKey, rand)
	if err != nil {
		return fmt.Errorf("error importing key: %w", err)
	}

	if lm.useCache {
		lm.cache.Store(req.Name, p)
	}

	return nil
}
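
// Import sketch (hypothetical key material): keyBytes holds the raw key to
// import, IsPrivateKey reports whether it is private rather than public
// material, and the rand argument is passed through to
// ImportPublicOrPrivate. rand.Reader here is crypto/rand's reader.
//
//	req := keysutil.PolicyRequest{
//		Storage:      storage,
//		Name:         "imported-key",
//		KeyType:      keysutil.KeyType_RSA2048,
//		IsPrivateKey: true,
//	}
//	if err := lm.ImportPolicy(ctx, req, keyBytes, rand.Reader); err != nil {
//		return err
//	}
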
func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage, name string) error {
	var p *Policy
	var err error
	var ok bool
	var pRaw interface{}

	// We may be writing to disk, so grab an exclusive lock. This prevents bad
	// behavior when the cache is turned off. We also lock the shared policy
	// object to make sure no requests are in flight.
	lock := locksutil.LockForKey(lm.keyLocks, name)
	lock.Lock()
	defer lock.Unlock()

	if lm.useCache {
		pRaw, ok = lm.cache.Load(name)
	}
	if ok {
		p = pRaw.(*Policy)
		p.l.Lock()
		defer p.l.Unlock()
	}

	if p == nil {
		p, err = lm.getPolicyFromStorage(ctx, storage, name)
		if err != nil {
			return err
		}
		if p == nil {
			return fmt.Errorf("could not delete key; not found")
		}
	}

	if !p.DeletionAllowed {
		return fmt.Errorf("deletion is not allowed for this key")
	}

	atomic.StoreUint32(&p.deleted, 1)

	if lm.useCache {
		lm.cache.Delete(name)
	}

	err = storage.Delete(ctx, "policy/"+name)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q: {{err}}", name), err)
	}

	err = storage.Delete(ctx, "archive/"+name)
	if err != nil {
		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q archive: {{err}}", name), err)
	}

	return nil
}
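
// Deletion sketch: DeletePolicy fails unless the policy's DeletionAllowed
// flag was previously set (typically via a key-configuration endpoint), and
// it removes both the policy entry and its archive from storage.
//
//	if err := lm.DeletePolicy(ctx, storage, "orders-key"); err != nil {
//		return err
//	}
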
func (lm *LockManager) getPolicyFromStorage(ctx context.Context, storage logical.Storage, name string) (*Policy, error) {
	return LoadPolicy(ctx, storage, "policy/"+name)
}