github.com/hashicorp/vault/sdk@v0.13.0/physical/inmem/inmem.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package inmem
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  
    17  	"github.com/armon/go-radix"
    18  	log "github.com/hashicorp/go-hclog"
    19  	"github.com/hashicorp/vault/sdk/physical"
    20  	uberAtomic "go.uber.org/atomic"
    21  )
    22  
    23  // Verify interfaces are satisfied
    24  var (
    25  	_ physical.Backend                   = (*InmemBackend)(nil)
    26  	_ physical.MountTableLimitingBackend = (*InmemBackend)(nil)
    27  	_ physical.HABackend                 = (*InmemHABackend)(nil)
    28  	_ physical.HABackend                 = (*TransactionalInmemHABackend)(nil)
    29  	_ physical.Lock                      = (*InmemLock)(nil)
    30  	_ physical.Transactional             = (*TransactionalInmemBackend)(nil)
    31  	_ physical.Transactional             = (*TransactionalInmemHABackend)(nil)
    32  	_ physical.TransactionalLimits       = (*TransactionalInmemBackend)(nil)
    33  )
    34  
// Sentinel errors returned while the matching fault-injection flag is set via
// FailPut, FailGet, FailDelete, FailList or FailGetInTxn. The first four are
// returned by the corresponding *Internal operation; GetInTxnDisabledError is
// returned by TransactionalInmemBackend.Transaction when the batch contains a
// get operation.
var (
	PutDisabledError      = errors.New("put operations disabled in inmem backend")
	GetDisabledError      = errors.New("get operations disabled in inmem backend")
	DeleteDisabledError   = errors.New("delete operations disabled in inmem backend")
	ListDisabledError     = errors.New("list operations disabled in inmem backend")
	GetInTxnDisabledError = errors.New("get operations inside transactions are disabled in inmem backend")
)
    42  
    43  // InmemBackend is an in-memory only physical backend. It is useful
    44  // for testing and development situations where the data is not
    45  // expected to be durable.
    46  type InmemBackend struct {
    47  	sync.RWMutex
    48  	root         *radix.Tree
    49  	permitPool   *physical.PermitPool
    50  	logger       log.Logger
    51  	failGet      *uint32
    52  	failPut      *uint32
    53  	failDelete   *uint32
    54  	failList     *uint32
    55  	failGetInTxn *uint32
    56  	logOps       bool
    57  	maxValueSize int
    58  	writeLatency time.Duration
    59  
    60  	mountTablePaths map[string]struct{}
    61  }
    62  
    63  type TransactionalInmemBackend struct {
    64  	InmemBackend
    65  
    66  	// Using Uber atomic because our SemGrep rules don't like the old pointer
    67  	// trick we used above any more even though it's fine. The newer sync/atomic
    68  	// types are almost the same, but lack ways to initialize them cleanly in New*
    69  	// functions so sticking with what SemGrep likes for now.
    70  	maxBatchEntries *uberAtomic.Int32
    71  	maxBatchSize    *uberAtomic.Int32
    72  
    73  	largestBatchLen  *uberAtomic.Uint64
    74  	largestBatchSize *uberAtomic.Uint64
    75  
    76  	transactionCompleteCh chan *txnCommitRequest
    77  }
    78  
    79  // NewInmem constructs a new in-memory backend
    80  func NewInmem(conf map[string]string, logger log.Logger) (physical.Backend, error) {
    81  	maxValueSize := 0
    82  	maxValueSizeStr, ok := conf["max_value_size"]
    83  	if ok {
    84  		var err error
    85  		maxValueSize, err = strconv.Atoi(maxValueSizeStr)
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  	}
    90  
    91  	return &InmemBackend{
    92  		root:         radix.New(),
    93  		permitPool:   physical.NewPermitPool(physical.DefaultParallelOperations),
    94  		logger:       logger,
    95  		failGet:      new(uint32),
    96  		failPut:      new(uint32),
    97  		failDelete:   new(uint32),
    98  		failList:     new(uint32),
    99  		failGetInTxn: new(uint32),
   100  		logOps:       os.Getenv("VAULT_INMEM_LOG_ALL_OPS") != "",
   101  		maxValueSize: maxValueSize,
   102  	}, nil
   103  }
   104  
   105  // Basically for now just creates a permit pool of size 1 so only one operation
   106  // can run at a time
   107  func NewTransactionalInmem(conf map[string]string, logger log.Logger) (physical.Backend, error) {
   108  	maxValueSize := 0
   109  	maxValueSizeStr, ok := conf["max_value_size"]
   110  	if ok {
   111  		var err error
   112  		maxValueSize, err = strconv.Atoi(maxValueSizeStr)
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  	}
   117  
   118  	return &TransactionalInmemBackend{
   119  		InmemBackend: InmemBackend{
   120  			root:         radix.New(),
   121  			permitPool:   physical.NewPermitPool(1),
   122  			logger:       logger,
   123  			failGet:      new(uint32),
   124  			failPut:      new(uint32),
   125  			failDelete:   new(uint32),
   126  			failList:     new(uint32),
   127  			failGetInTxn: new(uint32),
   128  			logOps:       os.Getenv("VAULT_INMEM_LOG_ALL_OPS") != "",
   129  			maxValueSize: maxValueSize,
   130  		},
   131  
   132  		maxBatchEntries:  uberAtomic.NewInt32(64),
   133  		maxBatchSize:     uberAtomic.NewInt32(128 * 1024),
   134  		largestBatchLen:  uberAtomic.NewUint64(0),
   135  		largestBatchSize: uberAtomic.NewUint64(0),
   136  	}, nil
   137  }
   138  
   139  // SetWriteLatency add a sleep to each Put/Delete operation (and each op in a
   140  // transaction for a TransactionalInmemBackend). It's not so much to simulate
   141  // real disk latency as much as to make the go runtime schedule things more like
   142  // a real disk where concurrent write operations are more likely to interleave
   143  // as each one blocks on disk IO. Set to 0 to disable again (the default).
   144  func (i *InmemBackend) SetWriteLatency(latency time.Duration) {
   145  	i.Lock()
   146  	defer i.Unlock()
   147  	i.writeLatency = latency
   148  }
   149  
   150  // Put is used to insert or update an entry
   151  func (i *InmemBackend) Put(ctx context.Context, entry *physical.Entry) error {
   152  	i.permitPool.Acquire()
   153  	defer i.permitPool.Release()
   154  
   155  	i.Lock()
   156  	defer i.Unlock()
   157  
   158  	return i.PutInternal(ctx, entry)
   159  }
   160  
   161  func (i *InmemBackend) PutInternal(ctx context.Context, entry *physical.Entry) error {
   162  	if i.logOps {
   163  		i.logger.Trace("put", "key", entry.Key)
   164  	}
   165  	if atomic.LoadUint32(i.failPut) != 0 {
   166  		return PutDisabledError
   167  	}
   168  
   169  	select {
   170  	case <-ctx.Done():
   171  		return ctx.Err()
   172  	default:
   173  	}
   174  
   175  	if i.maxValueSize > 0 && len(entry.Value) > i.maxValueSize {
   176  		return fmt.Errorf("%s", physical.ErrValueTooLarge)
   177  	}
   178  
   179  	i.root.Insert(entry.Key, entry.Value)
   180  	if i.writeLatency > 0 {
   181  		time.Sleep(i.writeLatency)
   182  	}
   183  	return nil
   184  }
   185  
   186  func (i *InmemBackend) FailPut(fail bool) {
   187  	var val uint32
   188  	if fail {
   189  		val = 1
   190  	}
   191  	atomic.StoreUint32(i.failPut, val)
   192  }
   193  
   194  // Get is used to fetch an entry
   195  func (i *InmemBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
   196  	i.permitPool.Acquire()
   197  	defer i.permitPool.Release()
   198  
   199  	i.RLock()
   200  	defer i.RUnlock()
   201  
   202  	return i.GetInternal(ctx, key)
   203  }
   204  
   205  func (i *InmemBackend) GetInternal(ctx context.Context, key string) (*physical.Entry, error) {
   206  	if i.logOps {
   207  		i.logger.Trace("get", "key", key)
   208  	}
   209  	if atomic.LoadUint32(i.failGet) != 0 {
   210  		return nil, GetDisabledError
   211  	}
   212  
   213  	select {
   214  	case <-ctx.Done():
   215  		return nil, ctx.Err()
   216  	default:
   217  	}
   218  
   219  	if raw, ok := i.root.Get(key); ok {
   220  		return &physical.Entry{
   221  			Key:   key,
   222  			Value: raw.([]byte),
   223  		}, nil
   224  	}
   225  	return nil, nil
   226  }
   227  
   228  func (i *InmemBackend) FailGet(fail bool) {
   229  	var val uint32
   230  	if fail {
   231  		val = 1
   232  	}
   233  	atomic.StoreUint32(i.failGet, val)
   234  }
   235  
   236  func (i *InmemBackend) FailGetInTxn(fail bool) {
   237  	var val uint32
   238  	if fail {
   239  		val = 1
   240  	}
   241  	atomic.StoreUint32(i.failGetInTxn, val)
   242  }
   243  
   244  // Delete is used to permanently delete an entry
   245  func (i *InmemBackend) Delete(ctx context.Context, key string) error {
   246  	i.permitPool.Acquire()
   247  	defer i.permitPool.Release()
   248  
   249  	i.Lock()
   250  	defer i.Unlock()
   251  
   252  	return i.DeleteInternal(ctx, key)
   253  }
   254  
   255  func (i *InmemBackend) DeleteInternal(ctx context.Context, key string) error {
   256  	if i.logOps {
   257  		i.logger.Trace("delete", "key", key)
   258  	}
   259  	if atomic.LoadUint32(i.failDelete) != 0 {
   260  		return DeleteDisabledError
   261  	}
   262  	select {
   263  	case <-ctx.Done():
   264  		return ctx.Err()
   265  	default:
   266  	}
   267  
   268  	i.root.Delete(key)
   269  	if i.writeLatency > 0 {
   270  		time.Sleep(i.writeLatency)
   271  	}
   272  	return nil
   273  }
   274  
   275  func (i *InmemBackend) FailDelete(fail bool) {
   276  	var val uint32
   277  	if fail {
   278  		val = 1
   279  	}
   280  	atomic.StoreUint32(i.failDelete, val)
   281  }
   282  
   283  // List is used to list all the keys under a given
   284  // prefix, up to the next prefix.
   285  func (i *InmemBackend) List(ctx context.Context, prefix string) ([]string, error) {
   286  	i.permitPool.Acquire()
   287  	defer i.permitPool.Release()
   288  
   289  	i.RLock()
   290  	defer i.RUnlock()
   291  
   292  	return i.ListInternal(ctx, prefix)
   293  }
   294  
   295  func (i *InmemBackend) ListInternal(ctx context.Context, prefix string) ([]string, error) {
   296  	if i.logOps {
   297  		i.logger.Trace("list", "prefix", prefix)
   298  	}
   299  	if atomic.LoadUint32(i.failList) != 0 {
   300  		return nil, ListDisabledError
   301  	}
   302  
   303  	var out []string
   304  	seen := make(map[string]interface{})
   305  	walkFn := func(s string, v interface{}) bool {
   306  		trimmed := strings.TrimPrefix(s, prefix)
   307  		sep := strings.Index(trimmed, "/")
   308  		if sep == -1 {
   309  			out = append(out, trimmed)
   310  		} else {
   311  			trimmed = trimmed[:sep+1]
   312  			if _, ok := seen[trimmed]; !ok {
   313  				out = append(out, trimmed)
   314  				seen[trimmed] = struct{}{}
   315  			}
   316  		}
   317  		return false
   318  	}
   319  	i.root.WalkPrefix(prefix, walkFn)
   320  
   321  	select {
   322  	case <-ctx.Done():
   323  		return nil, ctx.Err()
   324  	default:
   325  	}
   326  
   327  	return out, nil
   328  }
   329  
   330  func (i *InmemBackend) FailList(fail bool) {
   331  	var val uint32
   332  	if fail {
   333  		val = 1
   334  	}
   335  	atomic.StoreUint32(i.failList, val)
   336  }
   337  
   338  // RegisterMountTablePath implements physical.MountTableLimitingBackend
   339  func (i *InmemBackend) RegisterMountTablePath(path string) {
   340  	if i.mountTablePaths == nil {
   341  		i.mountTablePaths = make(map[string]struct{})
   342  	}
   343  	i.mountTablePaths[path] = struct{}{}
   344  }
   345  
   346  // GetMountTablePaths returns any paths registered as mount table or namespace
   347  // metadata paths. It's intended for testing.
   348  func (i *InmemBackend) GetMountTablePaths() []string {
   349  	var paths []string
   350  	for path := range i.mountTablePaths {
   351  		paths = append(paths, path)
   352  	}
   353  	return paths
   354  }
   355  
   356  // Transaction implements the transaction interface
   357  func (t *TransactionalInmemBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
   358  	t.permitPool.Acquire()
   359  	defer t.permitPool.Release()
   360  
   361  	t.Lock()
   362  	defer t.Unlock()
   363  
   364  	failGetInTxn := atomic.LoadUint32(t.failGetInTxn)
   365  	size := uint64(0)
   366  	for _, t := range txns {
   367  		// We use 2x key length to match the logic in WALBackend.persistWALs.
   368  		size += uint64(2*len(t.Entry.Key) + len(t.Entry.Value))
   369  		if t.Operation == physical.GetOperation && failGetInTxn != 0 {
   370  			return GetInTxnDisabledError
   371  		}
   372  	}
   373  
   374  	if size > t.largestBatchSize.Load() {
   375  		t.largestBatchSize.Store(size)
   376  	}
   377  	if len(txns) > int(t.largestBatchLen.Load()) {
   378  		t.largestBatchLen.Store(uint64(len(txns)))
   379  	}
   380  
   381  	err := physical.GenericTransactionHandler(ctx, t, txns)
   382  
   383  	// If we have a transactionCompleteCh set, we block on it before returning.
   384  	if t.transactionCompleteCh != nil {
   385  		req := &txnCommitRequest{
   386  			txns: txns,
   387  			ch:   make(chan struct{}),
   388  		}
   389  		t.transactionCompleteCh <- req
   390  		<-req.ch
   391  	}
   392  	return err
   393  }
   394  
// SetMaxBatchEntries sets the maximum number of entries per batch that
// TransactionLimits reports. The value is stored as an int32. Transaction, as
// written in this file, does not itself enforce the limit.
func (t *TransactionalInmemBackend) SetMaxBatchEntries(entries int) {
	t.maxBatchEntries.Store(int32(entries))
}
   398  
// SetMaxBatchSize sets the maximum batch size that TransactionLimits reports.
// The value is stored as an int32. Transaction, as written in this file, does
// not itself enforce the limit.
//
// NOTE(review): the parameter is named entries but is stored as the size
// limit, presumably in bytes given the 128 * 1024 default; confirm.
func (t *TransactionalInmemBackend) SetMaxBatchSize(entries int) {
	t.maxBatchSize.Store(int32(entries))
}
   402  
   403  func (t *TransactionalInmemBackend) TransactionLimits() (int, int) {
   404  	return int(t.maxBatchEntries.Load()), int(t.maxBatchSize.Load())
   405  }
   406  
   407  func (t *TransactionalInmemBackend) BatchStats() (maxEntries uint64, maxSize uint64) {
   408  	return t.largestBatchLen.Load(), t.largestBatchSize.Load()
   409  }
   410  
   411  // TxnCommitChan returns a channel that allows deterministic control of when
   412  // transactions are executed. Each time `Transaction` is called on the backend,
   413  // a txnCommitRequest is sent on the chan returned and then Transaction will
   414  // block until Done is called on that request object. This allows tests to
   415  // deterministically wait until a persist is actually in progress, as well as
   416  // control when the persist completes. The returned chan is buffered with a
   417  // length of 5 which should be enough to ensure that test code doesn't deadlock
   418  // in normal operation since we typically only have one outstanding transaction
   419  // at at time.
   420  func (t *TransactionalInmemBackend) TxnCommitChan() <-chan *txnCommitRequest {
   421  	t.Lock()
   422  	defer t.Unlock()
   423  
   424  	ch := make(chan *txnCommitRequest, 5)
   425  	t.transactionCompleteCh = ch
   426  
   427  	return ch
   428  }
   429  
// txnCommitRequest is what Transaction sends on the channel returned by
// TxnCommitChan. It carries the batch being processed and the means to let
// the waiting Transaction call return.
type txnCommitRequest struct {
	// txns is the batch that was passed to Transaction.
	txns []*physical.TxnEntry
	// ch is closed by Commit; Transaction blocks receiving from it.
	ch chan struct{}
}
   434  
// Commit closes the request's channel, which unblocks the Transaction call
// waiting on this request. It must be called at most once per request, since
// closing an already closed channel panics.
func (r *txnCommitRequest) Commit() {
	close(r.ch)
}