github.com/hashicorp/vault/sdk@v0.13.0/plugin/mock/backend.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package mock
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"os"
    10  	"testing"
    11  
    12  	"github.com/hashicorp/vault/api"
    13  	"github.com/hashicorp/vault/sdk/framework"
    14  	"github.com/hashicorp/vault/sdk/logical"
    15  )
    16  
    17  const (
    18  	MockPluginVersionEnv           = "TESTING_MOCK_VAULT_PLUGIN_VERSION"
    19  	MockPluginDefaultInternalValue = "bar"
    20  )
    21  
    22  // New returns a new backend as an interface. This func
    23  // is only necessary for builtin backend plugins.
    24  func New() (interface{}, error) {
    25  	return Backend(), nil
    26  }
    27  
    28  // Factory returns a new backend as logical.Backend.
    29  func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
    30  	b := Backend()
    31  	if err := b.Setup(ctx, conf); err != nil {
    32  		return nil, err
    33  	}
    34  	return b, nil
    35  }
    36  
    37  // FactoryType is a wrapper func that allows the Factory func to specify
    38  // the backend type for the mock backend plugin instance.
    39  func FactoryType(backendType logical.BackendType) logical.Factory {
    40  	return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
    41  		b := Backend()
    42  		b.BackendType = backendType
    43  		if err := b.Setup(ctx, conf); err != nil {
    44  			return nil, err
    45  		}
    46  		return b, nil
    47  	}
    48  }
    49  
    50  // Backend returns a private embedded struct of framework.Backend.
    51  func Backend() *backend {
    52  	var b backend
    53  	b.Backend = &framework.Backend{
    54  		Help: "",
    55  		Paths: framework.PathAppend(
    56  			errorPaths(&b),
    57  			kvPaths(&b),
    58  			[]*framework.Path{
    59  				pathInternal(&b),
    60  				pathSpecial(&b),
    61  				pathRaw(&b),
    62  				pathEnv(&b),
    63  			},
    64  		),
    65  		PathsSpecial: &logical.Paths{
    66  			Unauthenticated: []string{
    67  				"special",
    68  			},
    69  		},
    70  		Secrets:     []*framework.Secret{},
    71  		Invalidate:  b.invalidate,
    72  		BackendType: logical.TypeLogical,
    73  	}
    74  	b.internal = MockPluginDefaultInternalValue
    75  	b.RunningVersion = "v0.0.0+mock"
    76  	if version := os.Getenv(MockPluginVersionEnv); version != "" {
    77  		b.RunningVersion = version
    78  	}
    79  	return &b
    80  }
    81  
    82  type backend struct {
    83  	*framework.Backend
    84  
    85  	// internal is used to test invalidate and reloads.
    86  	internal string
    87  }
    88  
    89  func (b *backend) invalidate(ctx context.Context, key string) {
    90  	switch key {
    91  	case "internal":
    92  		b.internal = ""
    93  	}
    94  }
    95  
    96  // WriteInternalValue is a helper to set an in-memory value in the plugin,
    97  // allowing tests to later assert that the plugin either has or hasn't been
    98  // restarted.
    99  func WriteInternalValue(t *testing.T, client *api.Client, mountPath, value string) {
   100  	t.Helper()
   101  	resp, err := client.Logical().Write(fmt.Sprintf("%s/internal", mountPath), map[string]interface{}{
   102  		"value": value,
   103  	})
   104  	if err != nil {
   105  		t.Fatalf("err: %v", err)
   106  	}
   107  	if resp != nil {
   108  		t.Fatalf("bad: %v", resp)
   109  	}
   110  }
   111  
   112  // ExpectInternalValue checks the internal in-memory value.
   113  func ExpectInternalValue(t *testing.T, client *api.Client, mountPath, expected string) {
   114  	t.Helper()
   115  	expectInternalValue(t, client, mountPath, expected)
   116  }
   117  
   118  func expectInternalValue(t *testing.T, client *api.Client, mountPath, expected string) {
   119  	t.Helper()
   120  	resp, err := client.Logical().Read(fmt.Sprintf("%s/internal", mountPath))
   121  	if err != nil {
   122  		t.Fatalf("err: %v", err)
   123  	}
   124  	if resp == nil {
   125  		t.Fatalf("bad: response should not be nil")
   126  	}
   127  	if resp.Data["value"].(string) != expected {
   128  		t.Fatalf("expected %q but got %q", expected, resp.Data["value"].(string))
   129  	}
   130  }