github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/plugin_client_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"reflect"
    10  	"testing"
    11  	"time"
    12  
    13  	log "github.com/hashicorp/go-hclog"
    14  	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
    15  	"github.com/hashicorp/vault/sdk/helper/consts"
    16  	"github.com/hashicorp/vault/sdk/helper/pluginutil"
    17  	"github.com/hashicorp/vault/sdk/helper/wrapping"
    18  	"github.com/hashicorp/vault/sdk/logical"
    19  	"github.com/stretchr/testify/mock"
    20  	"google.golang.org/grpc"
    21  )
    22  
    23  func TestNewPluginClient(t *testing.T) {
    24  	type testCase struct {
    25  		config       pluginutil.PluginClientConfig
    26  		pluginClient pluginutil.PluginClient
    27  		expectedResp *DatabasePluginClient
    28  		expectedErr  error
    29  	}
    30  
    31  	tests := map[string]testCase{
    32  		"happy path": {
    33  			config: testPluginClientConfig(),
    34  			pluginClient: &fakePluginClient{
    35  				connResp:     nil,
    36  				dispenseResp: gRPCClient{client: fakeClient{}},
    37  				dispenseErr:  nil,
    38  			},
    39  			expectedResp: &DatabasePluginClient{
    40  				client: &fakePluginClient{
    41  					connResp:     nil,
    42  					dispenseResp: gRPCClient{client: fakeClient{}},
    43  					dispenseErr:  nil,
    44  				},
    45  				Database: gRPCClient{client: proto.NewDatabaseClient(nil), versionClient: logical.NewPluginVersionClient(nil), doneCtx: context.Context(nil)},
    46  			},
    47  			expectedErr: nil,
    48  		},
    49  		"dispense error": {
    50  			config: testPluginClientConfig(),
    51  			pluginClient: &fakePluginClient{
    52  				connResp:     nil,
    53  				dispenseResp: gRPCClient{},
    54  				dispenseErr:  errors.New("dispense error"),
    55  			},
    56  			expectedResp: nil,
    57  			expectedErr:  errors.New("dispense error"),
    58  		},
    59  		"error unsupported client type": {
    60  			config: testPluginClientConfig(),
    61  			pluginClient: &fakePluginClient{
    62  				connResp:     nil,
    63  				dispenseResp: nil,
    64  				dispenseErr:  nil,
    65  			},
    66  			expectedResp: nil,
    67  			expectedErr:  errors.New("unsupported client type"),
    68  		},
    69  	}
    70  
    71  	for name, test := range tests {
    72  		t.Run(name, func(t *testing.T) {
    73  			ctx := context.Background()
    74  
    75  			mockWrapper := new(mockRunnerUtil)
    76  			mockWrapper.On("NewPluginClient", ctx, mock.Anything).
    77  				Return(test.pluginClient, nil)
    78  			defer mockWrapper.AssertNumberOfCalls(t, "NewPluginClient", 1)
    79  
    80  			resp, err := NewPluginClient(ctx, mockWrapper, test.config)
    81  			if test.expectedErr != nil && err == nil {
    82  				t.Fatalf("err expected, got nil")
    83  			}
    84  			if test.expectedErr == nil && err != nil {
    85  				t.Fatalf("no error expected, got: %s", err)
    86  			}
    87  			if test.expectedErr == nil && !reflect.DeepEqual(resp, test.expectedResp) {
    88  				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func testPluginClientConfig() pluginutil.PluginClientConfig {
    95  	return pluginutil.PluginClientConfig{
    96  		Name:            "test-plugin",
    97  		PluginSets:      PluginSets,
    98  		PluginType:      consts.PluginTypeDatabase,
    99  		HandshakeConfig: HandshakeConfig,
   100  		Logger:          log.NewNullLogger(),
   101  		IsMetadataMode:  true,
   102  		AutoMTLS:        true,
   103  	}
   104  }
   105  
   106  var _ pluginutil.PluginClient = &fakePluginClient{}
   107  
   108  type fakePluginClient struct {
   109  	connResp grpc.ClientConnInterface
   110  
   111  	dispenseResp interface{}
   112  	dispenseErr  error
   113  }
   114  
   115  func (f *fakePluginClient) Conn() grpc.ClientConnInterface {
   116  	return nil
   117  }
   118  
   119  func (f *fakePluginClient) Reload() error {
   120  	return nil
   121  }
   122  
   123  func (f *fakePluginClient) Dispense(name string) (interface{}, error) {
   124  	return f.dispenseResp, f.dispenseErr
   125  }
   126  
   127  func (f *fakePluginClient) Ping() error {
   128  	return nil
   129  }
   130  
   131  func (f *fakePluginClient) Close() error {
   132  	return nil
   133  }
   134  
   135  var _ pluginutil.RunnerUtil = &mockRunnerUtil{}
   136  
   137  type mockRunnerUtil struct {
   138  	mock.Mock
   139  }
   140  
   141  func (m *mockRunnerUtil) VaultVersion(ctx context.Context) (string, error) {
   142  	return "dummyversion", nil
   143  }
   144  
   145  func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config pluginutil.PluginClientConfig) (pluginutil.PluginClient, error) {
   146  	args := m.Called(ctx, config)
   147  	return args.Get(0).(pluginutil.PluginClient), args.Error(1)
   148  }
   149  
   150  func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
   151  	args := m.Called(ctx, data, ttl, jwt)
   152  	return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)
   153  }
   154  
   155  func (m *mockRunnerUtil) MlockEnabled() bool {
   156  	args := m.Called()
   157  	return args.Bool(0)
   158  }
   159  
   160  func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) {
   161  	return "clusterid", nil
   162  }