github.com/hashicorp/vault/sdk@v0.13.0/physical/inmem/transactions_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package inmem
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"reflect"
    10  	"sort"
    11  	"testing"
    12  
    13  	radix "github.com/armon/go-radix"
    14  	log "github.com/hashicorp/go-hclog"
    15  	"github.com/hashicorp/vault/sdk/helper/logging"
    16  	"github.com/hashicorp/vault/sdk/physical"
    17  )
    18  
// faultyPseudo wraps an InmemBackend and injects failures for selected keys.
// Its GetInternal, PutInternal and DeleteInternal methods return an error for
// any key present in faultyPaths; every other call is handed to underlying.
type faultyPseudo struct {
	// underlying is the wrapped in-memory backend that all calls delegate to.
	underlying  InmemBackend
	// faultyPaths is the set of keys for which the *Internal methods fail.
	faultyPaths map[string]struct{}
}
    23  
    24  func (f *faultyPseudo) Get(ctx context.Context, key string) (*physical.Entry, error) {
    25  	return f.underlying.Get(context.Background(), key)
    26  }
    27  
    28  func (f *faultyPseudo) Put(ctx context.Context, entry *physical.Entry) error {
    29  	return f.underlying.Put(context.Background(), entry)
    30  }
    31  
    32  func (f *faultyPseudo) Delete(ctx context.Context, key string) error {
    33  	return f.underlying.Delete(context.Background(), key)
    34  }
    35  
    36  func (f *faultyPseudo) GetInternal(ctx context.Context, key string) (*physical.Entry, error) {
    37  	if _, ok := f.faultyPaths[key]; ok {
    38  		return nil, fmt.Errorf("fault")
    39  	}
    40  	return f.underlying.GetInternal(context.Background(), key)
    41  }
    42  
    43  func (f *faultyPseudo) PutInternal(ctx context.Context, entry *physical.Entry) error {
    44  	if _, ok := f.faultyPaths[entry.Key]; ok {
    45  		return fmt.Errorf("fault")
    46  	}
    47  	return f.underlying.PutInternal(context.Background(), entry)
    48  }
    49  
    50  func (f *faultyPseudo) DeleteInternal(ctx context.Context, key string) error {
    51  	if _, ok := f.faultyPaths[key]; ok {
    52  		return fmt.Errorf("fault")
    53  	}
    54  	return f.underlying.DeleteInternal(context.Background(), key)
    55  }
    56  
    57  func (f *faultyPseudo) List(ctx context.Context, prefix string) ([]string, error) {
    58  	return f.underlying.List(context.Background(), prefix)
    59  }
    60  
    61  func (f *faultyPseudo) Transaction(ctx context.Context, txns []*physical.TxnEntry) error {
    62  	f.underlying.permitPool.Acquire()
    63  	defer f.underlying.permitPool.Release()
    64  
    65  	f.underlying.Lock()
    66  	defer f.underlying.Unlock()
    67  
    68  	return physical.GenericTransactionHandler(ctx, f, txns)
    69  }
    70  
    71  func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo {
    72  	out := &faultyPseudo{
    73  		underlying: InmemBackend{
    74  			root:       radix.New(),
    75  			permitPool: physical.NewPermitPool(1),
    76  			logger:     logger.Named("storage.inmembackend"),
    77  			failGet:    new(uint32),
    78  			failPut:    new(uint32),
    79  			failDelete: new(uint32),
    80  			failList:   new(uint32),
    81  		},
    82  		faultyPaths: make(map[string]struct{}, len(faultyPaths)),
    83  	}
    84  	for _, v := range faultyPaths {
    85  		out.faultyPaths[v] = struct{}{}
    86  	}
    87  	return out
    88  }
    89  
    90  func TestPseudo_Basic(t *testing.T) {
    91  	logger := logging.NewVaultLogger(log.Debug)
    92  	p := newFaultyPseudo(logger, nil)
    93  	physical.ExerciseBackend(t, p)
    94  	physical.ExerciseBackend_ListPrefix(t, p)
    95  }
    96  
    97  func TestPseudo_SuccessfulTransaction(t *testing.T) {
    98  	logger := logging.NewVaultLogger(log.Debug)
    99  	p := newFaultyPseudo(logger, nil)
   100  
   101  	physical.ExerciseTransactionalBackend(t, p)
   102  }
   103  
   104  func TestPseudo_FailedTransaction(t *testing.T) {
   105  	logger := logging.NewVaultLogger(log.Debug)
   106  	p := newFaultyPseudo(logger, []string{"zip"})
   107  
   108  	txns := physical.SetupTestingTransactions(t, p)
   109  	if err := p.Transaction(context.Background(), txns); err == nil {
   110  		t.Fatal("expected error during transaction")
   111  	}
   112  
   113  	keys, err := p.List(context.Background(), "")
   114  	if err != nil {
   115  		t.Fatal(err)
   116  	}
   117  
   118  	expected := []string{"foo", "zip", "deleteme", "deleteme2"}
   119  
   120  	sort.Strings(keys)
   121  	sort.Strings(expected)
   122  	if !reflect.DeepEqual(keys, expected) {
   123  		t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys)
   124  	}
   125  
   126  	entry, err := p.Get(context.Background(), "foo")
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	if entry == nil {
   131  		t.Fatal("got nil entry")
   132  	}
   133  	if entry.Value == nil {
   134  		t.Fatal("got nil value")
   135  	}
   136  	if string(entry.Value) != "bar" {
   137  		t.Fatal("values did not rollback correctly")
   138  	}
   139  
   140  	entry, err = p.Get(context.Background(), "zip")
   141  	if err != nil {
   142  		t.Fatal(err)
   143  	}
   144  	if entry == nil {
   145  		t.Fatal("got nil entry")
   146  	}
   147  	if entry.Value == nil {
   148  		t.Fatal("got nil value")
   149  	}
   150  	if string(entry.Value) != "zap" {
   151  		t.Fatal("values did not rollback correctly")
   152  	}
   153  }