github.com/hashicorp/vault/sdk@v0.13.0/physical/inmem/transactions_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package inmem 5 6 import ( 7 "context" 8 "fmt" 9 "reflect" 10 "sort" 11 "testing" 12 13 radix "github.com/armon/go-radix" 14 log "github.com/hashicorp/go-hclog" 15 "github.com/hashicorp/vault/sdk/helper/logging" 16 "github.com/hashicorp/vault/sdk/physical" 17 ) 18 19 type faultyPseudo struct { 20 underlying InmemBackend 21 faultyPaths map[string]struct{} 22 } 23 24 func (f *faultyPseudo) Get(ctx context.Context, key string) (*physical.Entry, error) { 25 return f.underlying.Get(context.Background(), key) 26 } 27 28 func (f *faultyPseudo) Put(ctx context.Context, entry *physical.Entry) error { 29 return f.underlying.Put(context.Background(), entry) 30 } 31 32 func (f *faultyPseudo) Delete(ctx context.Context, key string) error { 33 return f.underlying.Delete(context.Background(), key) 34 } 35 36 func (f *faultyPseudo) GetInternal(ctx context.Context, key string) (*physical.Entry, error) { 37 if _, ok := f.faultyPaths[key]; ok { 38 return nil, fmt.Errorf("fault") 39 } 40 return f.underlying.GetInternal(context.Background(), key) 41 } 42 43 func (f *faultyPseudo) PutInternal(ctx context.Context, entry *physical.Entry) error { 44 if _, ok := f.faultyPaths[entry.Key]; ok { 45 return fmt.Errorf("fault") 46 } 47 return f.underlying.PutInternal(context.Background(), entry) 48 } 49 50 func (f *faultyPseudo) DeleteInternal(ctx context.Context, key string) error { 51 if _, ok := f.faultyPaths[key]; ok { 52 return fmt.Errorf("fault") 53 } 54 return f.underlying.DeleteInternal(context.Background(), key) 55 } 56 57 func (f *faultyPseudo) List(ctx context.Context, prefix string) ([]string, error) { 58 return f.underlying.List(context.Background(), prefix) 59 } 60 61 func (f *faultyPseudo) Transaction(ctx context.Context, txns []*physical.TxnEntry) error { 62 f.underlying.permitPool.Acquire() 63 defer f.underlying.permitPool.Release() 64 65 f.underlying.Lock() 66 defer f.underlying.Unlock() 67 68 return physical.GenericTransactionHandler(ctx, f, txns) 69 } 70 71 func newFaultyPseudo(logger log.Logger, faultyPaths []string) *faultyPseudo { 72 out := &faultyPseudo{ 73 underlying: InmemBackend{ 74 root: radix.New(), 75 permitPool: physical.NewPermitPool(1), 76 logger: logger.Named("storage.inmembackend"), 77 failGet: new(uint32), 78 failPut: new(uint32), 79 failDelete: new(uint32), 80 failList: new(uint32), 81 }, 82 faultyPaths: make(map[string]struct{}, len(faultyPaths)), 83 } 84 for _, v := range faultyPaths { 85 out.faultyPaths[v] = struct{}{} 86 } 87 return out 88 } 89 90 func TestPseudo_Basic(t *testing.T) { 91 logger := logging.NewVaultLogger(log.Debug) 92 p := newFaultyPseudo(logger, nil) 93 physical.ExerciseBackend(t, p) 94 physical.ExerciseBackend_ListPrefix(t, p) 95 } 96 97 func TestPseudo_SuccessfulTransaction(t *testing.T) { 98 logger := logging.NewVaultLogger(log.Debug) 99 p := newFaultyPseudo(logger, nil) 100 101 physical.ExerciseTransactionalBackend(t, p) 102 } 103 104 func TestPseudo_FailedTransaction(t *testing.T) { 105 logger := logging.NewVaultLogger(log.Debug) 106 p := newFaultyPseudo(logger, []string{"zip"}) 107 108 txns := physical.SetupTestingTransactions(t, p) 109 if err := p.Transaction(context.Background(), txns); err == nil { 110 t.Fatal("expected error during transaction") 111 } 112 113 keys, err := p.List(context.Background(), "") 114 if err != nil { 115 t.Fatal(err) 116 } 117 118 expected := []string{"foo", "zip", "deleteme", "deleteme2"} 119 120 sort.Strings(keys) 121 sort.Strings(expected) 122 if !reflect.DeepEqual(keys, expected) { 123 t.Fatalf("mismatch: expected\n%#v\ngot\n%#v\n", expected, keys) 124 } 125 126 entry, err := p.Get(context.Background(), "foo") 127 if err != nil { 128 t.Fatal(err) 129 } 130 if entry == nil { 131 t.Fatal("got nil entry") 132 } 133 if entry.Value == nil { 134 t.Fatal("got nil value") 135 } 136 if string(entry.Value) != "bar" { 137 t.Fatal("values did not rollback correctly") 138 } 139 140 entry, err = p.Get(context.Background(), "zip") 141 if err != nil { 142 t.Fatal(err) 143 } 144 if entry == nil { 145 t.Fatal("got nil entry") 146 } 147 if entry.Value == nil { 148 t.Fatal("got nil value") 149 } 150 if string(entry.Value) != "zap" { 151 t.Fatal("values did not rollback correctly") 152 } 153 }