github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/plugin_client_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbplugin 5 6 import ( 7 "context" 8 "errors" 9 "reflect" 10 "testing" 11 "time" 12 13 log "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" 15 "github.com/hashicorp/vault/sdk/helper/consts" 16 "github.com/hashicorp/vault/sdk/helper/pluginutil" 17 "github.com/hashicorp/vault/sdk/helper/wrapping" 18 "github.com/hashicorp/vault/sdk/logical" 19 "github.com/stretchr/testify/mock" 20 "google.golang.org/grpc" 21 ) 22 23 func TestNewPluginClient(t *testing.T) { 24 type testCase struct { 25 config pluginutil.PluginClientConfig 26 pluginClient pluginutil.PluginClient 27 expectedResp *DatabasePluginClient 28 expectedErr error 29 } 30 31 tests := map[string]testCase{ 32 "happy path": { 33 config: testPluginClientConfig(), 34 pluginClient: &fakePluginClient{ 35 connResp: nil, 36 dispenseResp: gRPCClient{client: fakeClient{}}, 37 dispenseErr: nil, 38 }, 39 expectedResp: &DatabasePluginClient{ 40 client: &fakePluginClient{ 41 connResp: nil, 42 dispenseResp: gRPCClient{client: fakeClient{}}, 43 dispenseErr: nil, 44 }, 45 Database: gRPCClient{client: proto.NewDatabaseClient(nil), versionClient: logical.NewPluginVersionClient(nil), doneCtx: context.Context(nil)}, 46 }, 47 expectedErr: nil, 48 }, 49 "dispense error": { 50 config: testPluginClientConfig(), 51 pluginClient: &fakePluginClient{ 52 connResp: nil, 53 dispenseResp: gRPCClient{}, 54 dispenseErr: errors.New("dispense error"), 55 }, 56 expectedResp: nil, 57 expectedErr: errors.New("dispense error"), 58 }, 59 "error unsupported client type": { 60 config: testPluginClientConfig(), 61 pluginClient: &fakePluginClient{ 62 connResp: nil, 63 dispenseResp: nil, 64 dispenseErr: nil, 65 }, 66 expectedResp: nil, 67 expectedErr: errors.New("unsupported client type"), 68 }, 69 } 70 71 for name, test := range tests { 72 t.Run(name, func(t *testing.T) { 73 ctx := context.Background() 74 75 mockWrapper := new(mockRunnerUtil) 76 mockWrapper.On("NewPluginClient", ctx, mock.Anything). 77 Return(test.pluginClient, nil) 78 defer mockWrapper.AssertNumberOfCalls(t, "NewPluginClient", 1) 79 80 resp, err := NewPluginClient(ctx, mockWrapper, test.config) 81 if test.expectedErr != nil && err == nil { 82 t.Fatalf("err expected, got nil") 83 } 84 if test.expectedErr == nil && err != nil { 85 t.Fatalf("no error expected, got: %s", err) 86 } 87 if test.expectedErr == nil && !reflect.DeepEqual(resp, test.expectedResp) { 88 t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) 89 } 90 }) 91 } 92 } 93 94 func testPluginClientConfig() pluginutil.PluginClientConfig { 95 return pluginutil.PluginClientConfig{ 96 Name: "test-plugin", 97 PluginSets: PluginSets, 98 PluginType: consts.PluginTypeDatabase, 99 HandshakeConfig: HandshakeConfig, 100 Logger: log.NewNullLogger(), 101 IsMetadataMode: true, 102 AutoMTLS: true, 103 } 104 } 105 106 var _ pluginutil.PluginClient = &fakePluginClient{} 107 108 type fakePluginClient struct { 109 connResp grpc.ClientConnInterface 110 111 dispenseResp interface{} 112 dispenseErr error 113 } 114 115 func (f *fakePluginClient) Conn() grpc.ClientConnInterface { 116 return nil 117 } 118 119 func (f *fakePluginClient) Reload() error { 120 return nil 121 } 122 123 func (f *fakePluginClient) Dispense(name string) (interface{}, error) { 124 return f.dispenseResp, f.dispenseErr 125 } 126 127 func (f *fakePluginClient) Ping() error { 128 return nil 129 } 130 131 func (f *fakePluginClient) Close() error { 132 return nil 133 } 134 135 var _ pluginutil.RunnerUtil = &mockRunnerUtil{} 136 137 type mockRunnerUtil struct { 138 mock.Mock 139 } 140 141 func (m *mockRunnerUtil) VaultVersion(ctx context.Context) (string, error) { 142 return "dummyversion", nil 143 } 144 145 func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) { 146 args := m.Called(ctx, config) 147 return args.Get(0).(pluginutil.PluginClient), args.Error(1) 148 } 149 150 func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { 151 args := m.Called(ctx, data, ttl, jwt) 152 return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1) 153 } 154 155 func (m *mockRunnerUtil) MlockEnabled() bool { 156 args := m.Called() 157 return args.Bool(0) 158 } 159 160 func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) { 161 return "clusterid", nil 162 }