github.com/hashicorp/vault/sdk@v0.13.0/helper/keysutil/lock_manager.go

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package keysutil
     5  
     6  import (
     7  	"context"
     8  	"encoding/base64"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/hashicorp/errwrap"
    17  	"github.com/hashicorp/vault/sdk/helper/jsonutil"
    18  	"github.com/hashicorp/vault/sdk/helper/locksutil"
    19  	"github.com/hashicorp/vault/sdk/logical"
    20  )
    21  
    22  const (
    23  	shared                   = false
    24  	exclusive                = true
    25  	currentConvergentVersion = 3
    26  )
    27  
    28  var errNeedExclusiveLock = errors.New("an exclusive lock is needed for this operation")
    29  
    30  // PolicyRequest holds values used when requesting a policy. Most values are
    31  // only used during an upsert.
    32  type PolicyRequest struct {
    33  	// The storage to use
    34  	Storage logical.Storage
    35  
    36  	// The name of the policy
    37  	Name string
    38  
    39  	// The key type
    40  	KeyType KeyType
    41  
    42  	// The key size for variable key size algorithms
    43  	KeySize int
    44  
    45  	// Whether it should be derived
    46  	Derived bool
    47  
    48  	// Whether to enable convergent encryption
    49  	Convergent bool
    50  
    51  	// Whether to allow export
    52  	Exportable bool
    53  
    54  	// Whether to upsert
    55  	Upsert bool
    56  
    57  	// Whether to allow plaintext backup
    58  	AllowPlaintextBackup bool
    59  
    60  	// How frequently the key should automatically rotate
    61  	AutoRotatePeriod time.Duration
    62  
    63  	// AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault
    64  	AllowImportedKeyRotation bool
    65  
    66  	// Indicates whether a private or public key is imported/upserted
    67  	IsPrivateKey bool
    68  
    69  	// The UUID of the managed key, if using one
    70  	ManagedKeyUUID string
    71  }
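
// Illustrative sketch (not part of the original file): a PolicyRequest for a
// derived, convergent AES-256-GCM96 key, as a caller might pass to GetPolicy
// with Upsert set. Convergent encryption requires Derived to be true (see the
// validation in GetPolicy below); Storage is filled in by the caller.
var exampleConvergentRequest = PolicyRequest{
	Name:       "convergent-key",
	KeyType:    KeyType_AES256_GCM96,
	Derived:    true,
	Convergent: true,
	Upsert:     true,
}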
    72  
    73  type LockManager struct {
    74  	useCache bool
    75  	cache    Cache
    76  	keyLocks []*locksutil.LockEntry
    77  }
    78  
    79  func NewLockManager(useCache bool, cacheSize int) (*LockManager, error) {
    80  	// determine the type of cache to create
    81  	var cache Cache
    82  	switch {
    83  	case !useCache:
    84  	case cacheSize < 0:
     85  		return nil, errors.New("cache size must be greater than or equal to zero")
    86  	case cacheSize == 0:
    87  		cache = NewTransitSyncMap()
    88  	case cacheSize > 0:
    89  		newLRUCache, err := NewTransitLRU(cacheSize)
    90  		if err != nil {
    91  			return nil, errwrap.Wrapf("failed to create cache: {{err}}", err)
    92  		}
    93  		cache = newLRUCache
    94  	}
    95  
    96  	lm := &LockManager{
    97  		useCache: useCache,
    98  		cache:    cache,
    99  		keyLocks: locksutil.CreateLocks(),
   100  	}
   101  
   102  	return lm, nil
   103  }
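
// exampleNewLockManagers is an illustrative sketch (not part of the original
// file) of the three cache modes NewLockManager accepts: caching disabled,
// the unbounded sync.Map-backed cache (cacheSize == 0), and a bounded LRU
// cache (cacheSize > 0). A negative cacheSize is rejected.
func exampleNewLockManagers() error {
	if _, err := NewLockManager(false, 0); err != nil { // no cache at all
		return err
	}
	if _, err := NewLockManager(true, 0); err != nil { // unbounded sync.Map cache
		return err
	}
	if _, err := NewLockManager(true, 10000); err != nil { // LRU capped at 10000 entries
		return err
	}
	return nil
}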
   104  
   105  func (lm *LockManager) GetCacheSize() int {
   106  	if !lm.useCache {
   107  		return 0
   108  	}
   109  	return lm.cache.Size()
   110  }
   111  
   112  func (lm *LockManager) GetUseCache() bool {
   113  	return lm.useCache
   114  }
   115  
   116  func (lm *LockManager) InvalidatePolicy(name string) {
   117  	if lm.useCache {
   118  		lm.cache.Delete(name)
   119  	}
   120  }
   121  
   122  func (lm *LockManager) InitCache(cacheSize int) error {
   123  	if lm.useCache {
   124  		switch {
   125  		case cacheSize < 0:
    126  			return errors.New("cache size must be greater than or equal to zero")
   127  		case cacheSize == 0:
   128  			lm.cache = NewTransitSyncMap()
   129  		case cacheSize > 0:
   130  			newLRUCache, err := NewTransitLRU(cacheSize)
   131  			if err != nil {
   132  				return errwrap.Wrapf("failed to create cache: {{err}}", err)
   133  			}
   134  			lm.cache = newLRUCache
   135  		}
   136  	}
   137  	return nil
   138  }
   139  
   140  // RestorePolicy acquires an exclusive lock on the policy name and restores the
   141  // given policy along with the archive.
   142  func (lm *LockManager) RestorePolicy(ctx context.Context, storage logical.Storage, name, backup string, force bool) error {
   143  	backupBytes, err := base64.StdEncoding.DecodeString(backup)
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	var keyData KeyData
   149  	err = jsonutil.DecodeJSON(backupBytes, &keyData)
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	// Set a different name if desired
   155  	if name != "" {
   156  		keyData.Policy.Name = name
   157  	}
   158  
   159  	name = keyData.Policy.Name
   160  
   161  	// Grab the exclusive lock as we'll be modifying disk
   162  	lock := locksutil.LockForKey(lm.keyLocks, name)
   163  	lock.Lock()
   164  	defer lock.Unlock()
   165  
   166  	var ok bool
   167  	var pRaw interface{}
   168  
   169  	// If the policy is in cache and 'force' is not specified, error out. Anywhere
   170  	// that would put it in the cache will also be protected by the mutex above,
   171  	// so we don't need to re-check the cache later.
   172  	if lm.useCache {
   173  		pRaw, ok = lm.cache.Load(name)
   174  		if ok && !force {
   175  			return fmt.Errorf("key %q already exists", name)
   176  		}
   177  	}
   178  
   179  	// Conditionally look up the policy from storage, depending on the use of
   180  	// 'force' and if the policy was found in cache.
   181  	//
    182  	// - If it was not found in cache and we are not using 'force', look for it in
   183  	// storage. If found, error out.
   184  	//
   185  	// - If it was found in cache and we are using 'force', pRaw will not be nil
   186  	// and we do not look the policy up from storage
   187  	//
   188  	// - If it was found in cache and we are not using 'force', we should have
   189  	// returned above with error
   190  	var p *Policy
   191  	if pRaw == nil {
   192  		p, err = lm.getPolicyFromStorage(ctx, storage, name)
   193  		if err != nil {
   194  			return err
   195  		}
   196  		if p != nil && !force {
   197  			return fmt.Errorf("key %q already exists", name)
   198  		}
   199  	}
   200  
   201  	// If both pRaw and p above are nil and 'force' is specified, we don't need to
   202  	// grab policy locks as we have ensured it doesn't already exist, so there
   203  	// will be no races as nothing else has this pointer. If 'force' was not used,
   204  	// an error would have been returned by now if the policy already existed
   205  	if pRaw != nil {
   206  		p = pRaw.(*Policy)
   207  	}
   208  	if p != nil {
   209  		p.l.Lock()
   210  		defer p.l.Unlock()
   211  	}
   212  
   213  	// Restore the archived keys
   214  	if keyData.ArchivedKeys != nil {
   215  		err = keyData.Policy.storeArchive(ctx, storage, keyData.ArchivedKeys)
   216  		if err != nil {
   217  			return errwrap.Wrapf(fmt.Sprintf("failed to restore archived keys for key %q: {{err}}", name), err)
   218  		}
   219  	}
   220  
   221  	// Mark that policy as a restored key
   222  	keyData.Policy.RestoreInfo = &RestoreInfo{
   223  		Time:    time.Now(),
   224  		Version: keyData.Policy.LatestVersion,
   225  	}
   226  
   227  	// Restore the policy. This will also attempt to adjust the archive.
   228  	err = keyData.Policy.Persist(ctx, storage)
   229  	if err != nil {
   230  		return errwrap.Wrapf(fmt.Sprintf("failed to restore the policy %q: {{err}}", name), err)
   231  	}
   232  
   233  	keyData.Policy.l = new(sync.RWMutex)
   234  
   235  	// Update the cache to contain the restored policy
   236  	if lm.useCache {
   237  		lm.cache.Store(name, keyData.Policy)
   238  	}
   239  	return nil
   240  }
   241  
   242  func (lm *LockManager) BackupPolicy(ctx context.Context, storage logical.Storage, name string) (string, error) {
   243  	var p *Policy
   244  	var err error
   245  
   246  	// Backup writes information about when the backup took place, so we get an
   247  	// exclusive lock here
   248  	lock := locksutil.LockForKey(lm.keyLocks, name)
   249  	lock.Lock()
   250  	defer lock.Unlock()
   251  
   252  	var ok bool
   253  	var pRaw interface{}
   254  
   255  	if lm.useCache {
   256  		pRaw, ok = lm.cache.Load(name)
   257  	}
   258  	if ok {
   259  		p = pRaw.(*Policy)
   260  		p.l.Lock()
   261  		defer p.l.Unlock()
   262  	} else {
    263  		// If the policy doesn't exist in storage, error out
   264  		p, err = lm.getPolicyFromStorage(ctx, storage, name)
   265  		if err != nil {
   266  			return "", err
   267  		}
   268  		if p == nil {
    269  			return "", fmt.Errorf("key %q not found", name)
   270  		}
   271  	}
   272  
   273  	if atomic.LoadUint32(&p.deleted) == 1 {
    274  		return "", fmt.Errorf("key %q not found", name)
   275  	}
   276  
   277  	backup, err := p.Backup(ctx, storage)
   278  	if err != nil {
   279  		return "", err
   280  	}
   281  
   282  	return backup, nil
   283  }
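
// exampleBackupRestore is an illustrative sketch (not part of the original
// file): it backs up an existing key and restores it under a new name. It
// assumes the key was created with Exportable and AllowPlaintextBackup set
// (Policy.Backup refuses otherwise) and that storage is any logical.Storage,
// e.g. &logical.InmemStorage{}.
func exampleBackupRestore(ctx context.Context, lm *LockManager, storage logical.Storage) error {
	// BackupPolicy serializes the policy plus its key archive as base64 JSON.
	backup, err := lm.BackupPolicy(ctx, storage, "my-key")
	if err != nil {
		return err
	}
	// Restore under a different name; with force=false this errors if the
	// target key already exists in cache or storage.
	return lm.RestorePolicy(ctx, storage, "my-key-copy", backup, false)
}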
   284  
    285  // If caching is disabled, the Policy this function returns is write-locked;
    286  // the caller must unlock it when done (and must not re-lock it).
   287  func (lm *LockManager) GetPolicy(ctx context.Context, req PolicyRequest, rand io.Reader) (retP *Policy, retUpserted bool, retErr error) {
   288  	var p *Policy
   289  	var err error
   290  	var ok bool
   291  	var pRaw interface{}
   292  
   293  	// Check if it's in our cache. If so, return right away.
   294  	if lm.useCache {
   295  		pRaw, ok = lm.cache.Load(req.Name)
   296  	}
   297  	if ok {
   298  		p = pRaw.(*Policy)
   299  		if atomic.LoadUint32(&p.deleted) == 1 {
   300  			return nil, false, nil
   301  		}
   302  		return p, false, nil
   303  	}
   304  
   305  	// We're not using the cache, or it wasn't found; get an exclusive lock.
   306  	// This ensures that any other process writing the actual storage will be
   307  	// finished before we load from storage.
   308  	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
   309  	lock.Lock()
   310  
    311  	// If we are using the cache, cleanup always unlocks; otherwise we may
    312  	// return from here with the lock still held for the caller.
   313  	cleanup := func() {
   314  		switch {
    315  		// If using the cache we always unlock; the caller locks the policy
    316  		// themselves
   317  		case lm.useCache:
   318  			lock.Unlock()
   319  
    320  		// If not using the cache and we aren't returning a policy, the
    321  		// caller won't hold a lock, so we must unlock
   322  		case retP == nil:
   323  			lock.Unlock()
   324  		}
   325  	}
   326  
   327  	// Check the cache again
   328  	if lm.useCache {
   329  		pRaw, ok = lm.cache.Load(req.Name)
   330  	}
   331  	if ok {
   332  		p = pRaw.(*Policy)
   333  		if atomic.LoadUint32(&p.deleted) == 1 {
   334  			cleanup()
   335  			return nil, false, nil
   336  		}
   337  		retP = p
   338  		cleanup()
   339  		return
   340  	}
   341  
   342  	// Load it from storage
   343  	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
   344  	if err != nil {
   345  		cleanup()
   346  		return nil, false, err
   347  	}
   348  	// We don't need to lock the policy as there would be no other holders of
   349  	// the pointer
   350  
   351  	if p == nil {
    352  		// This is the only place we upsert a new policy, so if upsert is
    353  		// not specified, unlock before returning
   354  		if !req.Upsert {
   355  			cleanup()
   356  			return nil, false, nil
   357  		}
   358  
   359  		// We create the policy here, then at the end we do a LoadOrStore. If
   360  		// it's been loaded since we last checked the cache, we return an error
   361  		// to the user to let them know that their request can't be satisfied
   362  		// because we don't know if the parameters match.
   363  
   364  		switch req.KeyType {
   365  		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
   366  			if req.Convergent && !req.Derived {
   367  				cleanup()
   368  				return nil, false, fmt.Errorf("convergent encryption requires derivation to be enabled")
   369  			}
   370  
   371  		case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
   372  			if req.Derived || req.Convergent {
   373  				cleanup()
   374  				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
   375  			}
   376  
   377  		case KeyType_ED25519:
   378  			if req.Convergent {
   379  				cleanup()
   380  				return nil, false, fmt.Errorf("convergent encryption not supported for keys of type %v", req.KeyType)
   381  			}
   382  
   383  		case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
   384  			if req.Derived || req.Convergent {
   385  				cleanup()
   386  				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
   387  			}
   388  		case KeyType_HMAC:
   389  			if req.Derived || req.Convergent {
   390  				cleanup()
   391  				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
   392  			}
   393  
   394  		case KeyType_MANAGED_KEY:
   395  			if req.Derived || req.Convergent {
   396  				cleanup()
   397  				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
   398  			}
   399  
   400  		case KeyType_AES128_CMAC, KeyType_AES256_CMAC:
   401  			if req.Derived || req.Convergent {
   402  				cleanup()
   403  				return nil, false, fmt.Errorf("key derivation and convergent encryption not supported for keys of type %v", req.KeyType)
   404  			}
   405  
   406  		default:
   407  			cleanup()
   408  			return nil, false, fmt.Errorf("unsupported key type %v", req.KeyType)
   409  		}
   410  
   411  		p = &Policy{
   412  			l:                    new(sync.RWMutex),
   413  			Name:                 req.Name,
   414  			Type:                 req.KeyType,
   415  			Derived:              req.Derived,
   416  			Exportable:           req.Exportable,
   417  			AllowPlaintextBackup: req.AllowPlaintextBackup,
   418  			AutoRotatePeriod:     req.AutoRotatePeriod,
   419  			KeySize:              req.KeySize,
   420  		}
   421  
   422  		if req.Derived {
   423  			p.KDF = Kdf_hkdf_sha256
   424  			if req.Convergent {
   425  				p.ConvergentEncryption = true
   426  				// As of version 3 we store the version within each key, so we
   427  				// set to -1 to indicate that the value in the policy has no
   428  				// meaning. We still, for backwards compatibility, fall back to
   429  				// this value if the key doesn't have one, which means it will
   430  				// only be -1 in the case where every key version is >= 3
   431  				p.ConvergentVersion = -1
   432  			}
   433  		}
   434  
   435  		// Performs the actual persist and does setup
   436  		if p.Type == KeyType_MANAGED_KEY {
   437  			err = p.RotateManagedKey(ctx, req.Storage, req.ManagedKeyUUID)
   438  		} else {
   439  			err = p.Rotate(ctx, req.Storage, rand)
   440  		}
   441  		if err != nil {
   442  			cleanup()
   443  			return nil, false, err
   444  		}
   445  
   446  		if lm.useCache {
   447  			lm.cache.Store(req.Name, p)
   448  		} else {
   449  			p.l = &lock.RWMutex
   450  			p.writeLocked = true
   451  		}
   452  
   453  		// We don't need to worry about upgrading since it will be a new policy
   454  		retP = p
   455  		retUpserted = true
   456  		cleanup()
   457  		return
   458  	}
   459  
   460  	if p.NeedsUpgrade() {
   461  		if err := p.Upgrade(ctx, req.Storage, rand); err != nil {
   462  			cleanup()
   463  			return nil, false, err
   464  		}
   465  	}
   466  
   467  	if lm.useCache {
   468  		lm.cache.Store(req.Name, p)
   469  	} else {
   470  		p.l = &lock.RWMutex
   471  		p.writeLocked = true
   472  	}
   473  
   474  	retP = p
   475  	cleanup()
   476  	return
   477  }
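
// exampleGetOrCreateKey is an illustrative sketch (not part of the original
// file) of the caller-side locking contract described above, modeled on how
// the transit backend uses GetPolicy. It assumes the Lock/Unlock helpers
// defined on Policy in policy.go.
func exampleGetOrCreateKey(ctx context.Context, lm *LockManager, storage logical.Storage, randReader io.Reader) error {
	p, upserted, err := lm.GetPolicy(ctx, PolicyRequest{
		Storage: storage,
		Name:    "my-key",
		KeyType: KeyType_AES256_GCM96,
		Upsert:  true,
	}, randReader)
	if err != nil {
		return err
	}
	if p == nil {
		return errors.New("error generating key: no policy returned")
	}
	if lm.GetUseCache() {
		// With caching enabled, GetPolicy returns an unlocked policy; take a
		// read lock for as long as the key material is in use.
		p.Lock(false)
	}
	// Without caching, GetPolicy returned with the key's write lock still
	// held, so the deferred Unlock releases that lock instead; the caller
	// must not re-lock it.
	defer p.Unlock()

	_ = upserted // true when this call created the key
	// ... use p here (encrypt, decrypt, sign) while the lock is held ...
	return nil
}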
   478  
   479  func (lm *LockManager) ImportPolicy(ctx context.Context, req PolicyRequest, key []byte, rand io.Reader) error {
   480  	var p *Policy
   481  	var err error
   482  	var ok bool
   483  	var pRaw interface{}
   484  
   485  	// Check if it's in our cache
   486  	if lm.useCache {
   487  		pRaw, ok = lm.cache.Load(req.Name)
   488  	}
   489  	if ok {
   490  		p = pRaw.(*Policy)
   491  		if atomic.LoadUint32(&p.deleted) == 1 {
   492  			return nil
   493  		}
   494  	}
   495  
   496  	// We're not using the cache, or it wasn't found; get an exclusive lock.
   497  	// This ensures that any other process writing the actual storage will be
   498  	// finished before we load from storage.
   499  	lock := locksutil.LockForKey(lm.keyLocks, req.Name)
   500  	lock.Lock()
   501  	defer lock.Unlock()
   502  
   503  	// Load it from storage
   504  	p, err = lm.getPolicyFromStorage(ctx, req.Storage, req.Name)
   505  	if err != nil {
   506  		return err
   507  	}
   508  
   509  	if p == nil {
   510  		p = &Policy{
   511  			l:                        new(sync.RWMutex),
   512  			Name:                     req.Name,
   513  			Type:                     req.KeyType,
   514  			Derived:                  req.Derived,
   515  			Exportable:               req.Exportable,
   516  			AllowPlaintextBackup:     req.AllowPlaintextBackup,
   517  			AutoRotatePeriod:         req.AutoRotatePeriod,
   518  			AllowImportedKeyRotation: req.AllowImportedKeyRotation,
   519  			Imported:                 true,
   520  		}
   521  	}
   522  
   523  	err = p.ImportPublicOrPrivate(ctx, req.Storage, key, req.IsPrivateKey, rand)
   524  	if err != nil {
   525  		return fmt.Errorf("error importing key: %s", err)
   526  	}
   527  
   528  	if lm.useCache {
   529  		lm.cache.Store(req.Name, p)
   530  	}
   531  
   532  	return nil
   533  }
   534  
   535  func (lm *LockManager) DeletePolicy(ctx context.Context, storage logical.Storage, name string) error {
   536  	var p *Policy
   537  	var err error
   538  	var ok bool
   539  	var pRaw interface{}
   540  
   541  	// We may be writing to disk, so grab an exclusive lock. This prevents bad
   542  	// behavior when the cache is turned off. We also lock the shared policy
   543  	// object to make sure no requests are in flight.
   544  	lock := locksutil.LockForKey(lm.keyLocks, name)
   545  	lock.Lock()
   546  	defer lock.Unlock()
   547  
   548  	if lm.useCache {
   549  		pRaw, ok = lm.cache.Load(name)
   550  	}
   551  	if ok {
   552  		p = pRaw.(*Policy)
   553  		p.l.Lock()
   554  		defer p.l.Unlock()
   555  	}
   556  
   557  	if p == nil {
   558  		p, err = lm.getPolicyFromStorage(ctx, storage, name)
   559  		if err != nil {
   560  			return err
   561  		}
   562  		if p == nil {
   563  			return fmt.Errorf("could not delete key; not found")
   564  		}
   565  	}
   566  
   567  	if !p.DeletionAllowed {
   568  		return fmt.Errorf("deletion is not allowed for this key")
   569  	}
   570  
   571  	atomic.StoreUint32(&p.deleted, 1)
   572  
   573  	if lm.useCache {
   574  		lm.cache.Delete(name)
   575  	}
   576  
   577  	err = storage.Delete(ctx, "policy/"+name)
   578  	if err != nil {
   579  		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q: {{err}}", name), err)
   580  	}
   581  
   582  	err = storage.Delete(ctx, "archive/"+name)
   583  	if err != nil {
   584  		return errwrap.Wrapf(fmt.Sprintf("error deleting key %q archive: {{err}}", name), err)
   585  	}
   586  
   587  	return nil
   588  }
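
// exampleDeleteKey is an illustrative sketch (not part of the original file):
// DeletePolicy refuses to delete a key until DeletionAllowed has been set and
// persisted on the policy, mirroring transit's keys/<name>/config workflow.
// It assumes the Lock/Unlock helpers defined on Policy in policy.go.
func exampleDeleteKey(ctx context.Context, lm *LockManager, storage logical.Storage, randReader io.Reader) error {
	p, _, err := lm.GetPolicy(ctx, PolicyRequest{
		Storage: storage,
		Name:    "my-key",
	}, randReader)
	if err != nil {
		return err
	}
	if p == nil {
		return errors.New("key not found")
	}
	if lm.GetUseCache() {
		p.Lock(true) // exclusive: the policy is about to be mutated
	}
	// Enable deletion and persist; without caching, GetPolicy already
	// returned the policy write-locked.
	p.DeletionAllowed = true
	err = p.Persist(ctx, storage)
	p.Unlock()
	if err != nil {
		return err
	}

	return lm.DeletePolicy(ctx, storage, "my-key")
}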
   589  
   590  func (lm *LockManager) getPolicyFromStorage(ctx context.Context, storage logical.Storage, name string) (*Policy, error) {
   591  	return LoadPolicy(ctx, storage, "policy/"+name)
   592  }