github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/grpc_client_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbplugin 5 6 import ( 7 "context" 8 "encoding/json" 9 "errors" 10 "reflect" 11 "testing" 12 "time" 13 14 "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" 15 "google.golang.org/grpc" 16 ) 17 18 func TestGRPCClient_Initialize(t *testing.T) { 19 type testCase struct { 20 client proto.DatabaseClient 21 req InitializeRequest 22 expectedResp InitializeResponse 23 assertErr errorAssertion 24 } 25 26 tests := map[string]testCase{ 27 "bad config": { 28 client: fakeClient{}, 29 req: InitializeRequest{ 30 Config: map[string]interface{}{ 31 "foo": badJSONValue{}, 32 }, 33 }, 34 assertErr: assertErrNotNil, 35 }, 36 "database error": { 37 client: fakeClient{ 38 initErr: errors.New("initialize error"), 39 }, 40 req: InitializeRequest{ 41 Config: map[string]interface{}{ 42 "foo": "bar", 43 }, 44 }, 45 assertErr: assertErrNotNil, 46 }, 47 "happy path": { 48 client: fakeClient{ 49 initResp: &proto.InitializeResponse{ 50 ConfigData: marshal(t, map[string]interface{}{ 51 "foo": "bar", 52 "baz": "biz", 53 }), 54 }, 55 }, 56 req: InitializeRequest{ 57 Config: map[string]interface{}{ 58 "foo": "bar", 59 }, 60 }, 61 expectedResp: InitializeResponse{ 62 Config: map[string]interface{}{ 63 "foo": "bar", 64 "baz": "biz", 65 }, 66 }, 67 assertErr: assertErrNil, 68 }, 69 "JSON number type in initialize request": { 70 client: fakeClient{ 71 initResp: &proto.InitializeResponse{ 72 ConfigData: marshal(t, map[string]interface{}{ 73 "foo": "bar", 74 "max": "10", 75 }), 76 }, 77 }, 78 req: InitializeRequest{ 79 Config: map[string]interface{}{ 80 "foo": "bar", 81 "max": json.Number("10"), 82 }, 83 }, 84 expectedResp: InitializeResponse{ 85 Config: map[string]interface{}{ 86 "foo": "bar", 87 "max": "10", 88 }, 89 }, 90 assertErr: assertErrNil, 91 }, 92 } 93 94 for name, test := range tests { 95 t.Run(name, func(t *testing.T) { 96 c := gRPCClient{ 97 client: test.client, 98 doneCtx: nil, 99 } 100 101 // Context doesn't need to timeout since this is just passed through 102 ctx := context.Background() 103 104 resp, err := c.Initialize(ctx, test.req) 105 test.assertErr(t, err) 106 107 if !reflect.DeepEqual(resp, test.expectedResp) { 108 t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) 109 } 110 }) 111 } 112 } 113 114 func TestGRPCClient_NewUser(t *testing.T) { 115 runningCtx := context.Background() 116 cancelledCtx, cancel := context.WithCancel(context.Background()) 117 cancel() 118 119 type testCase struct { 120 client proto.DatabaseClient 121 req NewUserRequest 122 doneCtx context.Context 123 expectedResp NewUserResponse 124 assertErr errorAssertion 125 } 126 127 tests := map[string]testCase{ 128 "missing password": { 129 client: fakeClient{}, 130 req: NewUserRequest{ 131 Password: "", 132 Expiration: time.Now(), 133 }, 134 doneCtx: runningCtx, 135 assertErr: assertErrNotNil, 136 }, 137 "bad expiration": { 138 client: fakeClient{}, 139 req: NewUserRequest{ 140 Password: "njkvcb8y934u90grsnkjl", 141 Expiration: invalidExpiration, 142 }, 143 doneCtx: runningCtx, 144 assertErr: assertErrNotNil, 145 }, 146 "database error": { 147 client: fakeClient{ 148 newUserErr: errors.New("new user error"), 149 }, 150 req: NewUserRequest{ 151 Password: "njkvcb8y934u90grsnkjl", 152 Expiration: time.Now(), 153 }, 154 doneCtx: runningCtx, 155 assertErr: assertErrNotNil, 156 }, 157 "plugin shut down": { 158 client: fakeClient{ 159 newUserErr: errors.New("new user error"), 160 }, 161 req: NewUserRequest{ 162 Password: "njkvcb8y934u90grsnkjl", 163 Expiration: time.Now(), 164 }, 165 doneCtx: cancelledCtx, 166 assertErr: assertErrEquals(ErrPluginShutdown), 167 }, 168 "happy path": { 169 client: fakeClient{ 170 newUserResp: &proto.NewUserResponse{ 171 Username: "new_user", 172 }, 173 }, 174 req: NewUserRequest{ 175 Password: "njkvcb8y934u90grsnkjl", 176 Expiration: time.Now(), 177 }, 178 doneCtx: runningCtx, 179 expectedResp: NewUserResponse{ 180 Username: "new_user", 181 }, 182 assertErr: assertErrNil, 183 }, 184 } 185 186 for name, test := range tests { 187 t.Run(name, func(t *testing.T) { 188 c := gRPCClient{ 189 client: test.client, 190 doneCtx: test.doneCtx, 191 } 192 193 ctx := context.Background() 194 195 resp, err := c.NewUser(ctx, test.req) 196 test.assertErr(t, err) 197 198 if !reflect.DeepEqual(resp, test.expectedResp) { 199 t.Fatalf("Actual response: %#v\nExpected response: %#v", resp, test.expectedResp) 200 } 201 }) 202 } 203 } 204 205 func TestGRPCClient_UpdateUser(t *testing.T) { 206 runningCtx := context.Background() 207 cancelledCtx, cancel := context.WithCancel(context.Background()) 208 cancel() 209 210 type testCase struct { 211 client proto.DatabaseClient 212 req UpdateUserRequest 213 doneCtx context.Context 214 assertErr errorAssertion 215 } 216 217 tests := map[string]testCase{ 218 "missing username": { 219 client: fakeClient{}, 220 req: UpdateUserRequest{}, 221 doneCtx: runningCtx, 222 assertErr: assertErrNotNil, 223 }, 224 "missing changes": { 225 client: fakeClient{}, 226 req: UpdateUserRequest{ 227 Username: "user", 228 }, 229 doneCtx: runningCtx, 230 assertErr: assertErrNotNil, 231 }, 232 "empty password": { 233 client: fakeClient{}, 234 req: UpdateUserRequest{ 235 Username: "user", 236 Password: &ChangePassword{ 237 NewPassword: "", 238 }, 239 }, 240 doneCtx: runningCtx, 241 assertErr: assertErrNotNil, 242 }, 243 "zero expiration": { 244 client: fakeClient{}, 245 req: UpdateUserRequest{ 246 Username: "user", 247 Expiration: &ChangeExpiration{ 248 NewExpiration: time.Time{}, 249 }, 250 }, 251 doneCtx: runningCtx, 252 assertErr: assertErrNotNil, 253 }, 254 "bad expiration": { 255 client: fakeClient{}, 256 req: UpdateUserRequest{ 257 Username: "user", 258 Expiration: &ChangeExpiration{ 259 NewExpiration: invalidExpiration, 260 }, 261 }, 262 doneCtx: runningCtx, 263 assertErr: assertErrNotNil, 264 }, 265 "database error": { 266 client: fakeClient{ 267 updateUserErr: errors.New("update user error"), 268 }, 269 req: UpdateUserRequest{ 270 Username: "user", 271 Password: &ChangePassword{ 272 NewPassword: "asdf", 273 }, 274 }, 275 doneCtx: runningCtx, 276 assertErr: assertErrNotNil, 277 }, 278 "plugin shut down": { 279 client: fakeClient{ 280 updateUserErr: errors.New("update user error"), 281 }, 282 req: UpdateUserRequest{ 283 Username: "user", 284 Password: &ChangePassword{ 285 NewPassword: "asdf", 286 }, 287 }, 288 doneCtx: cancelledCtx, 289 assertErr: assertErrEquals(ErrPluginShutdown), 290 }, 291 "happy path - change password": { 292 client: fakeClient{}, 293 req: UpdateUserRequest{ 294 Username: "user", 295 Password: &ChangePassword{ 296 NewPassword: "asdf", 297 }, 298 }, 299 doneCtx: runningCtx, 300 assertErr: assertErrNil, 301 }, 302 "happy path - change expiration": { 303 client: fakeClient{}, 304 req: UpdateUserRequest{ 305 Username: "user", 306 Expiration: &ChangeExpiration{ 307 NewExpiration: time.Now(), 308 }, 309 }, 310 doneCtx: runningCtx, 311 assertErr: assertErrNil, 312 }, 313 } 314 315 for name, test := range tests { 316 t.Run(name, func(t *testing.T) { 317 c := gRPCClient{ 318 client: test.client, 319 doneCtx: test.doneCtx, 320 } 321 322 ctx := context.Background() 323 324 _, err := c.UpdateUser(ctx, test.req) 325 test.assertErr(t, err) 326 }) 327 } 328 } 329 330 func TestGRPCClient_DeleteUser(t *testing.T) { 331 runningCtx := context.Background() 332 cancelledCtx, cancel := context.WithCancel(context.Background()) 333 cancel() 334 335 type testCase struct { 336 client proto.DatabaseClient 337 req DeleteUserRequest 338 doneCtx context.Context 339 assertErr errorAssertion 340 } 341 342 tests := map[string]testCase{ 343 "missing username": { 344 client: fakeClient{}, 345 req: DeleteUserRequest{}, 346 doneCtx: runningCtx, 347 assertErr: assertErrNotNil, 348 }, 349 "database error": { 350 client: fakeClient{ 351 deleteUserErr: errors.New("delete user error'"), 352 }, 353 req: DeleteUserRequest{ 354 Username: "user", 355 }, 356 doneCtx: runningCtx, 357 assertErr: assertErrNotNil, 358 }, 359 "plugin shut down": { 360 client: fakeClient{ 361 deleteUserErr: errors.New("delete user error'"), 362 }, 363 req: DeleteUserRequest{ 364 Username: "user", 365 }, 366 doneCtx: cancelledCtx, 367 assertErr: assertErrEquals(ErrPluginShutdown), 368 }, 369 "happy path": { 370 client: fakeClient{}, 371 req: DeleteUserRequest{ 372 Username: "user", 373 }, 374 doneCtx: runningCtx, 375 assertErr: assertErrNil, 376 }, 377 } 378 379 for name, test := range tests { 380 t.Run(name, func(t *testing.T) { 381 c := gRPCClient{ 382 client: test.client, 383 doneCtx: test.doneCtx, 384 } 385 386 ctx := context.Background() 387 388 _, err := c.DeleteUser(ctx, test.req) 389 test.assertErr(t, err) 390 }) 391 } 392 } 393 394 func TestGRPCClient_Type(t *testing.T) { 395 runningCtx := context.Background() 396 cancelledCtx, cancel := context.WithCancel(context.Background()) 397 cancel() 398 399 type testCase struct { 400 client proto.DatabaseClient 401 doneCtx context.Context 402 expectedType string 403 assertErr errorAssertion 404 } 405 406 tests := map[string]testCase{ 407 "database error": { 408 client: fakeClient{ 409 typeErr: errors.New("type error"), 410 }, 411 doneCtx: runningCtx, 412 assertErr: assertErrNotNil, 413 }, 414 "plugin shut down": { 415 client: fakeClient{ 416 typeErr: errors.New("type error"), 417 }, 418 doneCtx: cancelledCtx, 419 assertErr: assertErrEquals(ErrPluginShutdown), 420 }, 421 "happy path": { 422 client: fakeClient{ 423 typeResp: &proto.TypeResponse{ 424 Type: "test type", 425 }, 426 }, 427 doneCtx: runningCtx, 428 expectedType: "test type", 429 assertErr: assertErrNil, 430 }, 431 } 432 433 for name, test := range tests { 434 t.Run(name, func(t *testing.T) { 435 c := gRPCClient{ 436 client: test.client, 437 doneCtx: test.doneCtx, 438 } 439 440 dbType, err := c.Type() 441 test.assertErr(t, err) 442 443 if dbType != test.expectedType { 444 t.Fatalf("Actual type: %s Expected type: %s", dbType, test.expectedType) 445 } 446 }) 447 } 448 } 449 450 func TestGRPCClient_Close(t *testing.T) { 451 runningCtx := context.Background() 452 cancelledCtx, cancel := context.WithCancel(context.Background()) 453 cancel() 454 455 type testCase struct { 456 client proto.DatabaseClient 457 doneCtx context.Context 458 assertErr errorAssertion 459 } 460 461 tests := map[string]testCase{ 462 "database error": { 463 client: fakeClient{ 464 typeErr: errors.New("type error"), 465 }, 466 doneCtx: runningCtx, 467 assertErr: assertErrNotNil, 468 }, 469 "plugin shut down": { 470 client: fakeClient{ 471 typeErr: errors.New("type error"), 472 }, 473 doneCtx: cancelledCtx, 474 assertErr: assertErrEquals(ErrPluginShutdown), 475 }, 476 "happy path": { 477 client: fakeClient{}, 478 doneCtx: runningCtx, 479 assertErr: assertErrNil, 480 }, 481 } 482 483 for name, test := range tests { 484 t.Run(name, func(t *testing.T) { 485 c := gRPCClient{ 486 client: test.client, 487 doneCtx: test.doneCtx, 488 } 489 490 err := c.Close() 491 test.assertErr(t, err) 492 }) 493 } 494 } 495 496 type errorAssertion func(*testing.T, error) 497 498 func assertErrNotNil(t *testing.T, err error) { 499 t.Helper() 500 if err == nil { 501 t.Fatalf("err expected, got nil") 502 } 503 } 504 505 func assertErrNil(t *testing.T, err error) { 506 t.Helper() 507 if err != nil { 508 t.Fatalf("no error expected, got: %s", err) 509 } 510 } 511 512 func assertErrEquals(expectedErr error) errorAssertion { 513 return func(t *testing.T, err error) { 514 t.Helper() 515 if err != expectedErr { 516 t.Fatalf("Actual err: %#v Expected err: %#v", err, expectedErr) 517 } 518 } 519 } 520 521 var _ proto.DatabaseClient = fakeClient{} 522 523 type fakeClient struct { 524 initResp *proto.InitializeResponse 525 initErr error 526 527 newUserResp *proto.NewUserResponse 528 newUserErr error 529 530 updateUserResp *proto.UpdateUserResponse 531 updateUserErr error 532 533 deleteUserResp *proto.DeleteUserResponse 534 deleteUserErr error 535 536 typeResp *proto.TypeResponse 537 typeErr error 538 539 closeErr error 540 } 541 542 func (f fakeClient) Initialize(context.Context, *proto.InitializeRequest, ...grpc.CallOption) (*proto.InitializeResponse, error) { 543 return f.initResp, f.initErr 544 } 545 546 func (f fakeClient) NewUser(context.Context, *proto.NewUserRequest, ...grpc.CallOption) (*proto.NewUserResponse, error) { 547 return f.newUserResp, f.newUserErr 548 } 549 550 func (f fakeClient) UpdateUser(context.Context, *proto.UpdateUserRequest, ...grpc.CallOption) (*proto.UpdateUserResponse, error) { 551 return f.updateUserResp, f.updateUserErr 552 } 553 554 func (f fakeClient) DeleteUser(context.Context, *proto.DeleteUserRequest, ...grpc.CallOption) (*proto.DeleteUserResponse, error) { 555 return f.deleteUserResp, f.deleteUserErr 556 } 557 558 func (f fakeClient) Type(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.TypeResponse, error) { 559 return f.typeResp, f.typeErr 560 } 561 562 func (f fakeClient) Close(context.Context, *proto.Empty, ...grpc.CallOption) (*proto.Empty, error) { 563 return &proto.Empty{}, f.typeErr 564 }