github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/grpc_client_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
     6  import (
     7  	"context"
     8  	"encoding/json"
     9  	"errors"
    10  	"reflect"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
    15  	"google.golang.org/grpc"
    16  )
    17  
    18  func TestGRPCClient_Initialize(t *testing.T) {
    19  	type testCase struct {
    20  		client       proto.DatabaseClient
    21  		req          InitializeRequest
    22  		expectedResp InitializeResponse
    23  		assertErr    errorAssertion
    24  	}
    25  
    26  	tests := map[string]testCase{
    27  		"bad config": {
    28  			client: fakeClient{},
    29  			req: InitializeRequest{
    30  				Config: map[string]interface{}{
    31  					"foo": badJSONValue{},
    32  				},
    33  			},
    34  			assertErr: assertErrNotNil,
    35  		},
    36  		"database error": {
    37  			client: fakeClient{
    38  				initErr: errors.New("initialize error"),
    39  			},
    40  			req: InitializeRequest{
    41  				Config: map[string]interface{}{
    42  					"foo": "bar",
    43  				},
    44  			},
    45  			assertErr: assertErrNotNil,
    46  		},
    47  		"happy path": {
    48  			client: fakeClient{
    49  				initResp: &proto.InitializeResponse{
    50  					ConfigData: marshal(t, map[string]interface{}{
    51  						"foo": "bar",
    52  						"baz": "biz",
    53  					}),
    54  				},
    55  			},
    56  			req: InitializeRequest{
    57  				Config: map[string]interface{}{
    58  					"foo": "bar",
    59  				},
    60  			},
    61  			expectedResp: InitializeResponse{
    62  				Config: map[string]interface{}{
    63  					"foo": "bar",
    64  					"baz": "biz",
    65  				},
    66  			},
    67  			assertErr: assertErrNil,
    68  		},
    69  		"JSON number type in initialize request": {
    70  			client: fakeClient{
    71  				initResp: &proto.InitializeResponse{
    72  					ConfigData: marshal(t, map[string]interface{}{
    73  						"foo": "bar",
    74  						"max": "10",
    75  					}),
    76  				},
    77  			},
    78  			req: InitializeRequest{
    79  				Config: map[string]interface{}{
    80  					"foo": "bar",
    81  					"max": json.Number("10"),
    82  				},
    83  			},
    84  			expectedResp: InitializeResponse{
    85  				Config: map[string]interface{}{
    86  					"foo": "bar",
    87  					"max": "10",
    88  				},
    89  			},
    90  			assertErr: assertErrNil,
    91  		},
    92  	}
    93  
    94  	for name, test := range tests {
    95  		t.Run(name, func(t *testing.T) {
    96  			c := gRPCClient{
    97  				client:  test.client,
    98  				doneCtx: nil,
    99  			}
   100  
   101  			// Context doesn't need to timeout since this is just passed through
   102  			ctx := context.Background()
   103  
   104  			resp, err := c.Initialize(ctx, test.req)
   105  			test.assertErr(t, err)
   106  
   107  			if !reflect.DeepEqual(resp, test.expectedResp) {
   108  				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
   109  			}
   110  		})
   111  	}
   112  }
   113  
   114  func TestGRPCClient_NewUser(t *testing.T) {
   115  	runningCtx := context.Background()
   116  	cancelledCtx, cancel := context.WithCancel(context.Background())
   117  	cancel()
   118  
   119  	type testCase struct {
   120  		client       proto.DatabaseClient
   121  		req          NewUserRequest
   122  		doneCtx      context.Context
   123  		expectedResp NewUserResponse
   124  		assertErr    errorAssertion
   125  	}
   126  
   127  	tests := map[string]testCase{
   128  		"missing password": {
   129  			client: fakeClient{},
   130  			req: NewUserRequest{
   131  				Password:   "",
   132  				Expiration: time.Now(),
   133  			},
   134  			doneCtx:   runningCtx,
   135  			assertErr: assertErrNotNil,
   136  		},
   137  		"bad expiration": {
   138  			client: fakeClient{},
   139  			req: NewUserRequest{
   140  				Password:   "njkvcb8y934u90grsnkjl",
   141  				Expiration: invalidExpiration,
   142  			},
   143  			doneCtx:   runningCtx,
   144  			assertErr: assertErrNotNil,
   145  		},
   146  		"database error": {
   147  			client: fakeClient{
   148  				newUserErr: errors.New("new user error"),
   149  			},
   150  			req: NewUserRequest{
   151  				Password:   "njkvcb8y934u90grsnkjl",
   152  				Expiration: time.Now(),
   153  			},
   154  			doneCtx:   runningCtx,
   155  			assertErr: assertErrNotNil,
   156  		},
   157  		"plugin shut down": {
   158  			client: fakeClient{
   159  				newUserErr: errors.New("new user error"),
   160  			},
   161  			req: NewUserRequest{
   162  				Password:   "njkvcb8y934u90grsnkjl",
   163  				Expiration: time.Now(),
   164  			},
   165  			doneCtx:   cancelledCtx,
   166  			assertErr: assertErrEquals(ErrPluginShutdown),
   167  		},
   168  		"happy path": {
   169  			client: fakeClient{
   170  				newUserResp: &proto.NewUserResponse{
   171  					Username: "new_user",
   172  				},
   173  			},
   174  			req: NewUserRequest{
   175  				Password:   "njkvcb8y934u90grsnkjl",
   176  				Expiration: time.Now(),
   177  			},
   178  			doneCtx: runningCtx,
   179  			expectedResp: NewUserResponse{
   180  				Username: "new_user",
   181  			},
   182  			assertErr: assertErrNil,
   183  		},
   184  	}
   185  
   186  	for name, test := range tests {
   187  		t.Run(name, func(t *testing.T) {
   188  			c := gRPCClient{
   189  				client:  test.client,
   190  				doneCtx: test.doneCtx,
   191  			}
   192  
   193  			ctx := context.Background()
   194  
   195  			resp, err := c.NewUser(ctx, test.req)
   196  			test.assertErr(t, err)
   197  
   198  			if !reflect.DeepEqual(resp, test.expectedResp) {
   199  				t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp)
   200  			}
   201  		})
   202  	}
   203  }
   204  
   205  func TestGRPCClient_UpdateUser(t *testing.T) {
   206  	runningCtx := context.Background()
   207  	cancelledCtx, cancel := context.WithCancel(context.Background())
   208  	cancel()
   209  
   210  	type testCase struct {
   211  		client    proto.DatabaseClient
   212  		req       UpdateUserRequest
   213  		doneCtx   context.Context
   214  		assertErr errorAssertion
   215  	}
   216  
   217  	tests := map[string]testCase{
   218  		"missing username": {
   219  			client:    fakeClient{},
   220  			req:       UpdateUserRequest{},
   221  			doneCtx:   runningCtx,
   222  			assertErr: assertErrNotNil,
   223  		},
   224  		"missing changes": {
   225  			client: fakeClient{},
   226  			req: UpdateUserRequest{
   227  				Username: "user",
   228  			},
   229  			doneCtx:   runningCtx,
   230  			assertErr: assertErrNotNil,
   231  		},
   232  		"empty password": {
   233  			client: fakeClient{},
   234  			req: UpdateUserRequest{
   235  				Username: "user",
   236  				Password: &ChangePassword{
   237  					NewPassword: "",
   238  				},
   239  			},
   240  			doneCtx:   runningCtx,
   241  			assertErr: assertErrNotNil,
   242  		},
   243  		"zero expiration": {
   244  			client: fakeClient{},
   245  			req: UpdateUserRequest{
   246  				Username: "user",
   247  				Expiration: &ChangeExpiration{
   248  					NewExpiration: time.Time{},
   249  				},
   250  			},
   251  			doneCtx:   runningCtx,
   252  			assertErr: assertErrNotNil,
   253  		},
   254  		"bad expiration": {
   255  			client: fakeClient{},
   256  			req: UpdateUserRequest{
   257  				Username: "user",
   258  				Expiration: &ChangeExpiration{
   259  					NewExpiration: invalidExpiration,
   260  				},
   261  			},
   262  			doneCtx:   runningCtx,
   263  			assertErr: assertErrNotNil,
   264  		},
   265  		"database error": {
   266  			client: fakeClient{
   267  				updateUserErr: errors.New("update user error"),
   268  			},
   269  			req: UpdateUserRequest{
   270  				Username: "user",
   271  				Password: &ChangePassword{
   272  					NewPassword: "asdf",
   273  				},
   274  			},
   275  			doneCtx:   runningCtx,
   276  			assertErr: assertErrNotNil,
   277  		},
   278  		"plugin shut down": {
   279  			client: fakeClient{
   280  				updateUserErr: errors.New("update user error"),
   281  			},
   282  			req: UpdateUserRequest{
   283  				Username: "user",
   284  				Password: &ChangePassword{
   285  					NewPassword: "asdf",
   286  				},
   287  			},
   288  			doneCtx:   cancelledCtx,
   289  			assertErr: assertErrEquals(ErrPluginShutdown),
   290  		},
   291  		"happy path - change password": {
   292  			client: fakeClient{},
   293  			req: UpdateUserRequest{
   294  				Username: "user",
   295  				Password: &ChangePassword{
   296  					NewPassword: "asdf",
   297  				},
   298  			},
   299  			doneCtx:   runningCtx,
   300  			assertErr: assertErrNil,
   301  		},
   302  		"happy path - change expiration": {
   303  			client: fakeClient{},
   304  			req: UpdateUserRequest{
   305  				Username: "user",
   306  				Expiration: &ChangeExpiration{
   307  					NewExpiration: time.Now(),
   308  				},
   309  			},
   310  			doneCtx:   runningCtx,
   311  			assertErr: assertErrNil,
   312  		},
   313  	}
   314  
   315  	for name, test := range tests {
   316  		t.Run(name, func(t *testing.T) {
   317  			c := gRPCClient{
   318  				client:  test.client,
   319  				doneCtx: test.doneCtx,
   320  			}
   321  
   322  			ctx := context.Background()
   323  
   324  			_, err := c.UpdateUser(ctx, test.req)
   325  			test.assertErr(t, err)
   326  		})
   327  	}
   328  }
   329  
   330  func TestGRPCClient_DeleteUser(t *testing.T) {
   331  	runningCtx := context.Background()
   332  	cancelledCtx, cancel := context.WithCancel(context.Background())
   333  	cancel()
   334  
   335  	type testCase struct {
   336  		client    proto.DatabaseClient
   337  		req       DeleteUserRequest
   338  		doneCtx   context.Context
   339  		assertErr errorAssertion
   340  	}
   341  
   342  	tests := map[string]testCase{
   343  		"missing username": {
   344  			client:    fakeClient{},
   345  			req:       DeleteUserRequest{},
   346  			doneCtx:   runningCtx,
   347  			assertErr: assertErrNotNil,
   348  		},
   349  		"database error": {
   350  			client: fakeClient{
   351  				deleteUserErr: errors.New("delete user error'"),
   352  			},
   353  			req: DeleteUserRequest{
   354  				Username: "user",
   355  			},
   356  			doneCtx:   runningCtx,
   357  			assertErr: assertErrNotNil,
   358  		},
   359  		"plugin shut down": {
   360  			client: fakeClient{
   361  				deleteUserErr: errors.New("delete user error'"),
   362  			},
   363  			req: DeleteUserRequest{
   364  				Username: "user",
   365  			},
   366  			doneCtx:   cancelledCtx,
   367  			assertErr: assertErrEquals(ErrPluginShutdown),
   368  		},
   369  		"happy path": {
   370  			client: fakeClient{},
   371  			req: DeleteUserRequest{
   372  				Username: "user",
   373  			},
   374  			doneCtx:   runningCtx,
   375  			assertErr: assertErrNil,
   376  		},
   377  	}
   378  
   379  	for name, test := range tests {
   380  		t.Run(name, func(t *testing.T) {
   381  			c := gRPCClient{
   382  				client:  test.client,
   383  				doneCtx: test.doneCtx,
   384  			}
   385  
   386  			ctx := context.Background()
   387  
   388  			_, err := c.DeleteUser(ctx, test.req)
   389  			test.assertErr(t, err)
   390  		})
   391  	}
   392  }
   393  
   394  func TestGRPCClient_Type(t *testing.T) {
   395  	runningCtx := context.Background()
   396  	cancelledCtx, cancel := context.WithCancel(context.Background())
   397  	cancel()
   398  
   399  	type testCase struct {
   400  		client       proto.DatabaseClient
   401  		doneCtx      context.Context
   402  		expectedType string
   403  		assertErr    errorAssertion
   404  	}
   405  
   406  	tests := map[string]testCase{
   407  		"database error": {
   408  			client: fakeClient{
   409  				typeErr: errors.New("type error"),
   410  			},
   411  			doneCtx:   runningCtx,
   412  			assertErr: assertErrNotNil,
   413  		},
   414  		"plugin shut down": {
   415  			client: fakeClient{
   416  				typeErr: errors.New("type error"),
   417  			},
   418  			doneCtx:   cancelledCtx,
   419  			assertErr: assertErrEquals(ErrPluginShutdown),
   420  		},
   421  		"happy path": {
   422  			client: fakeClient{
   423  				typeResp: &proto.TypeResponse{
   424  					Type: "test type",
   425  				},
   426  			},
   427  			doneCtx:      runningCtx,
   428  			expectedType: "test type",
   429  			assertErr:    assertErrNil,
   430  		},
   431  	}
   432  
   433  	for name, test := range tests {
   434  		t.Run(name, func(t *testing.T) {
   435  			c := gRPCClient{
   436  				client:  test.client,
   437  				doneCtx: test.doneCtx,
   438  			}
   439  
   440  			dbType, err := c.Type()
   441  			test.assertErr(t, err)
   442  
   443  			if dbType != test.expectedType {
   444  				t.Fatalf("Actual type: %s Expected type: %s", dbType, test.expectedType)
   445  			}
   446  		})
   447  	}
   448  }
   449  
   450  func TestGRPCClient_Close(t *testing.T) {
   451  	runningCtx := context.Background()
   452  	cancelledCtx, cancel := context.WithCancel(context.Background())
   453  	cancel()
   454  
   455  	type testCase struct {
   456  		client    proto.DatabaseClient
   457  		doneCtx   context.Context
   458  		assertErr errorAssertion
   459  	}
   460  
   461  	tests := map[string]testCase{
   462  		"database error": {
   463  			client: fakeClient{
   464  				typeErr: errors.New("type error"),
   465  			},
   466  			doneCtx:   runningCtx,
   467  			assertErr: assertErrNotNil,
   468  		},
   469  		"plugin shut down": {
   470  			client: fakeClient{
   471  				typeErr: errors.New("type error"),
   472  			},
   473  			doneCtx:   cancelledCtx,
   474  			assertErr: assertErrEquals(ErrPluginShutdown),
   475  		},
   476  		"happy path": {
   477  			client:    fakeClient{},
   478  			doneCtx:   runningCtx,
   479  			assertErr: assertErrNil,
   480  		},
   481  	}
   482  
   483  	for name, test := range tests {
   484  		t.Run(name, func(t *testing.T) {
   485  			c := gRPCClient{
   486  				client:  test.client,
   487  				doneCtx: test.doneCtx,
   488  			}
   489  
   490  			err := c.Close()
   491  			test.assertErr(t, err)
   492  		})
   493  	}
   494  }
   495  
   496  type errorAssertion func(*testing.T, error)
   497  
   498  func assertErrNotNil(t *testing.T, err error) {
   499  	t.Helper()
   500  	if err == nil {
   501  		t.Fatalf("err expected, got nil")
   502  	}
   503  }
   504  
   505  func assertErrNil(t *testing.T, err error) {
   506  	t.Helper()
   507  	if err != nil {
   508  		t.Fatalf("no error expected, got: %s", err)
   509  	}
   510  }
   511  
   512  func assertErrEquals(expectedErr error) errorAssertion {
   513  	return func(t *testing.T, err error) {
   514  		t.Helper()
   515  		if err != expectedErr {
   516  			t.Fatalf("Actual err: %#v Expected err: %#v", err, expectedErr)
   517  		}
   518  	}
   519  }
   520  
   521  var _ proto.DatabaseClient = fakeClient{}
   522  
   523  type fakeClient struct {
   524  	initResp *proto.InitializeResponse
   525  	initErr  error
   526  
   527  	newUserResp *proto.NewUserResponse
   528  	newUserErr  error
   529  
   530  	updateUserResp *proto.UpdateUserResponse
   531  	updateUserErr  error
   532  
   533  	deleteUserResp *proto.DeleteUserResponse
   534  	deleteUserErr  error
   535  
   536  	typeResp *proto.TypeResponse
   537  	typeErr  error
   538  
   539  	closeErr error
   540  }
   541  
   542  func (f fakeClient) Initialize(context.Context, *proto.InitializeRequest, ...grpc.CallOption) (*proto.InitializeResponse, error) {
   543  	return f.initResp, f.initErr
   544  }
   545  
   546  func (f fakeClient) NewUser(context.Context, *proto.NewUserRequest, ...grpc.CallOption) (*proto.NewUserResponse, error) {
   547  	return f.newUserResp, f.newUserErr
   548  }
   549  
   550  func (f fakeClient) UpdateUser(context.Context, *proto.UpdateUserRequest, ...grpc.CallOption) (*proto.UpdateUserResponse, error) {
   551  	return f.updateUserResp, f.updateUserErr
   552  }
   553  
   554  func (f fakeClient) DeleteUser(context.Context, *proto.DeleteUserRequest, ...grpc.CallOption) (*proto.DeleteUserResponse, error) {
   555  	return f.deleteUserResp, f.deleteUserErr
   556  }
   557  
   558  func (f fakeClient) Type(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.TypeResponse, error) {
   559  	return f.typeResp, f.typeErr
   560  }
   561  
   562  func (f fakeClient) Close(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.Empty, error) {
   563  	return &proto.Empty{}, f.typeErr
   564  }