github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/conversions_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbplugin 5 6 import ( 7 "fmt" 8 "reflect" 9 "strings" 10 "testing" 11 "time" 12 "unicode" 13 14 "github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto" 15 "google.golang.org/protobuf/types/known/structpb" 16 "google.golang.org/protobuf/types/known/timestamppb" 17 ) 18 19 func TestConversionsHaveAllFields(t *testing.T) { 20 t.Run("initReqToProto", func(t *testing.T) { 21 req := InitializeRequest{ 22 Config: map[string]interface{}{ 23 "foo": map[string]interface{}{ 24 "bar": "baz", 25 }, 26 }, 27 VerifyConnection: true, 28 } 29 30 protoReq, err := initReqToProto(req) 31 if err != nil { 32 t.Fatalf("Failed to convert request to proto request: %s", err) 33 } 34 35 values := getAllGetterValues(protoReq) 36 if len(values) == 0 { 37 // Probably a test failure - the protos used in these tests should have Get functions on them 38 t.Fatalf("No values found from Get functions!") 39 } 40 41 for _, gtr := range values { 42 err := assertAllFieldsSet(fmt.Sprintf("InitializeRequest.%s", gtr.name), gtr.value) 43 if err != nil { 44 t.Fatalf("%s", err) 45 } 46 } 47 }) 48 49 t.Run("newUserReqToProto", func(t *testing.T) { 50 req := NewUserRequest{ 51 UsernameConfig: UsernameMetadata{ 52 DisplayName: "dispName", 53 RoleName: "roleName", 54 }, 55 Statements: Statements{ 56 Commands: []string{ 57 "statement", 58 }, 59 }, 60 RollbackStatements: Statements{ 61 Commands: []string{ 62 "rollback_statement", 63 }, 64 }, 65 CredentialType: CredentialTypeRSAPrivateKey, 66 PublicKey: []byte("-----BEGIN PUBLIC KEY-----"), 67 Password: "password", 68 Subject: "subject", 69 Expiration: time.Now(), 70 } 71 72 protoReq, err := newUserReqToProto(req) 73 if err != nil { 74 t.Fatalf("Failed to convert request to proto request: %s", err) 75 } 76 77 values := getAllGetterValues(protoReq) 78 if len(values) == 0 { 79 // Probably a test failure - the protos used in these tests should have Get functions on them 80 t.Fatalf("No values found from Get functions!") 81 } 82 83 for _, gtr := range values { 84 err := assertAllFieldsSet(fmt.Sprintf("NewUserRequest.%s", gtr.name), gtr.value) 85 if err != nil { 86 t.Fatalf("%s", err) 87 } 88 } 89 }) 90 91 t.Run("updateUserReqToProto", func(t *testing.T) { 92 req := UpdateUserRequest{ 93 Username: "username", 94 CredentialType: CredentialTypeRSAPrivateKey, 95 Password: &ChangePassword{ 96 NewPassword: "newpassword", 97 Statements: Statements{ 98 Commands: []string{ 99 "statement", 100 }, 101 }, 102 }, 103 PublicKey: &ChangePublicKey{ 104 NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"), 105 Statements: Statements{ 106 Commands: []string{ 107 "statement", 108 }, 109 }, 110 }, 111 Expiration: &ChangeExpiration{ 112 NewExpiration: time.Now(), 113 Statements: Statements{ 114 Commands: []string{ 115 "statement", 116 }, 117 }, 118 }, 119 } 120 121 protoReq, err := updateUserReqToProto(req) 122 if err != nil { 123 t.Fatalf("Failed to convert request to proto request: %s", err) 124 } 125 126 values := getAllGetterValues(protoReq) 127 if len(values) == 0 { 128 // Probably a test failure - the protos used in these tests should have Get functions on them 129 t.Fatalf("No values found from Get functions!") 130 } 131 132 for _, gtr := range values { 133 err := assertAllFieldsSet(fmt.Sprintf("UpdateUserRequest.%s", gtr.name), gtr.value) 134 if err != nil { 135 t.Fatalf("%s", err) 136 } 137 } 138 }) 139 140 t.Run("deleteUserReqToProto", func(t *testing.T) { 141 req := DeleteUserRequest{ 142 Username: "username", 143 Statements: Statements{ 144 Commands: []string{ 145 "statement", 146 }, 147 }, 148 } 149 150 protoReq, err := deleteUserReqToProto(req) 151 if err != nil { 152 t.Fatalf("Failed to convert request to proto request: %s", err) 153 } 154 155 values := getAllGetterValues(protoReq) 156 if len(values) == 0 { 157 // Probably a test failure - the protos used in these tests should have Get functions on them 158 t.Fatalf("No values found from Get functions!") 159 } 160 161 for _, gtr := range values { 162 err := assertAllFieldsSet(fmt.Sprintf("DeleteUserRequest.%s", gtr.name), gtr.value) 163 if err != nil { 164 t.Fatalf("%s", err) 165 } 166 } 167 }) 168 169 t.Run("getUpdateUserRequest", func(t *testing.T) { 170 req := &proto.UpdateUserRequest{ 171 Username: "username", 172 CredentialType: int32(CredentialTypeRSAPrivateKey), 173 Password: &proto.ChangePassword{ 174 NewPassword: "newpass", 175 Statements: &proto.Statements{ 176 Commands: []string{ 177 "statement", 178 }, 179 }, 180 }, 181 PublicKey: &proto.ChangePublicKey{ 182 NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"), 183 Statements: &proto.Statements{ 184 Commands: []string{ 185 "statement", 186 }, 187 }, 188 }, 189 Expiration: &proto.ChangeExpiration{ 190 NewExpiration: timestamppb.Now(), 191 Statements: &proto.Statements{ 192 Commands: []string{ 193 "statement", 194 }, 195 }, 196 }, 197 } 198 199 protoReq, err := getUpdateUserRequest(req) 200 if err != nil { 201 t.Fatalf("Failed to convert request to proto request: %s", err) 202 } 203 204 err = assertAllFieldsSet("proto.UpdateUserRequest", protoReq) 205 if err != nil { 206 t.Fatalf("%s", err) 207 } 208 }) 209 } 210 211 type getter struct { 212 name string 213 value interface{} 214 } 215 216 func getAllGetterValues(value interface{}) (values []getter) { 217 typ := reflect.TypeOf(value) 218 val := reflect.ValueOf(value) 219 for i := 0; i < typ.NumMethod(); i++ { 220 method := typ.Method(i) 221 if !strings.HasPrefix(method.Name, "Get") { 222 continue 223 } 224 valMethod := val.Method(i) 225 resp := valMethod.Call(nil) 226 getVal := resp[0].Interface() 227 gtr := getter{ 228 name: strings.TrimPrefix(method.Name, "Get"), 229 value: getVal, 230 } 231 values = append(values, gtr) 232 } 233 return values 234 } 235 236 // Ensures the assertion works properly 237 func TestAssertAllFieldsSet(t *testing.T) { 238 type testCase struct { 239 value interface{} 240 expectErr bool 241 } 242 243 tests := map[string]testCase{ 244 "zero int": { 245 value: 0, 246 expectErr: true, 247 }, 248 "non-zero int": { 249 value: 1, 250 expectErr: false, 251 }, 252 "zero float64": { 253 value: 0.0, 254 expectErr: true, 255 }, 256 "non-zero float64": { 257 value: 1.0, 258 expectErr: false, 259 }, 260 "empty string": { 261 value: "", 262 expectErr: true, 263 }, 264 "true boolean": { 265 value: true, 266 expectErr: false, 267 }, 268 "false boolean": { // False is an exception to the "is zero" rule 269 value: false, 270 expectErr: false, 271 }, 272 "blank struct": { 273 value: struct{}{}, 274 expectErr: true, 275 }, 276 "non-blank but empty struct": { 277 value: struct { 278 str string 279 }{ 280 str: "", 281 }, 282 expectErr: true, 283 }, 284 "non-empty string": { 285 value: "foo", 286 expectErr: false, 287 }, 288 "non-empty struct": { 289 value: struct { 290 str string 291 }{ 292 str: "foo", 293 }, 294 expectErr: false, 295 }, 296 "empty nested struct": { 297 value: struct { 298 Str string 299 Substruct struct { 300 Substr string 301 } 302 }{ 303 Str: "foo", 304 Substruct: struct { 305 Substr string 306 }{}, // Empty sub-field 307 }, 308 expectErr: true, 309 }, 310 "filled nested struct": { 311 value: struct { 312 str string 313 substruct struct { 314 substr string 315 } 316 }{ 317 str: "foo", 318 substruct: struct { 319 substr string 320 }{ 321 substr: "sub-foo", 322 }, 323 }, 324 expectErr: false, 325 }, 326 "nil map": { 327 value: map[string]string(nil), 328 expectErr: true, 329 }, 330 "empty map": { 331 value: map[string]string{}, 332 expectErr: true, 333 }, 334 "filled map": { 335 value: map[string]string{ 336 "foo": "bar", 337 "int": "42", 338 }, 339 expectErr: false, 340 }, 341 "map with empty string value": { 342 value: map[string]string{ 343 "foo": "", 344 }, 345 expectErr: true, 346 }, 347 "nested map with empty string value": { 348 value: map[string]interface{}{ 349 "bar": "baz", 350 "foo": map[string]interface{}{ 351 "subfoo": "", 352 }, 353 }, 354 expectErr: true, 355 }, 356 "nil slice": { 357 value: []string(nil), 358 expectErr: true, 359 }, 360 "empty slice": { 361 value: []string{}, 362 expectErr: true, 363 }, 364 "filled slice": { 365 value: []string{ 366 "foo", 367 }, 368 expectErr: false, 369 }, 370 "slice with empty string value": { 371 value: []string{ 372 "", 373 }, 374 expectErr: true, 375 }, 376 "empty structpb": { 377 value: newStructPb(t, map[string]interface{}{}), 378 expectErr: true, 379 }, 380 "filled structpb": { 381 value: newStructPb(t, map[string]interface{}{ 382 "foo": "bar", 383 "int": 42, 384 }), 385 expectErr: false, 386 }, 387 388 "pointer to zero int": { 389 value: intPtr(0), 390 expectErr: true, 391 }, 392 "pointer to non-zero int": { 393 value: intPtr(1), 394 expectErr: false, 395 }, 396 "pointer to zero float64": { 397 value: float64Ptr(0.0), 398 expectErr: true, 399 }, 400 "pointer to non-zero float64": { 401 value: float64Ptr(1.0), 402 expectErr: false, 403 }, 404 "pointer to nil string": { 405 value: new(string), 406 expectErr: true, 407 }, 408 "pointer to non-nil string": { 409 value: strPtr("foo"), 410 expectErr: false, 411 }, 412 } 413 414 for name, test := range tests { 415 t.Run(name, func(t *testing.T) { 416 err := assertAllFieldsSet("", test.value) 417 if test.expectErr && err == nil { 418 t.Fatalf("err expected, got nil") 419 } 420 if !test.expectErr && err != nil { 421 t.Fatalf("no error expected, got: %s", err) 422 } 423 }) 424 } 425 } 426 427 func assertAllFieldsSet(name string, val interface{}) error { 428 if val == nil { 429 return fmt.Errorf("value is nil") 430 } 431 432 rVal := reflect.ValueOf(val) 433 return assertAllFieldsSetValue(name, rVal) 434 } 435 436 func assertAllFieldsSetValue(name string, rVal reflect.Value) error { 437 // All booleans are allowed - we don't have a way of differentiating between 438 // and intentional false and a missing false 439 if rVal.Kind() == reflect.Bool { 440 return nil 441 } 442 443 // Primitives fall through here 444 if rVal.IsZero() { 445 return fmt.Errorf("%s is zero", name) 446 } 447 448 switch rVal.Kind() { 449 case reflect.Ptr, reflect.Interface: 450 return assertAllFieldsSetValue(name, rVal.Elem()) 451 case reflect.Struct: 452 return assertAllFieldsSetStruct(name, rVal) 453 case reflect.Map: 454 if rVal.Len() == 0 { 455 return fmt.Errorf("%s (map type) is empty", name) 456 } 457 458 iter := rVal.MapRange() 459 for iter.Next() { 460 k := iter.Key() 461 v := iter.Value() 462 463 err := assertAllFieldsSetValue(fmt.Sprintf("%s[%s]", name, k), v) 464 if err != nil { 465 return err 466 } 467 } 468 case reflect.Slice: 469 if rVal.Len() == 0 { 470 return fmt.Errorf("%s (slice type) is empty", name) 471 } 472 for i := 0; i < rVal.Len(); i++ { 473 sliceVal := rVal.Index(i) 474 err := assertAllFieldsSetValue(fmt.Sprintf("%s[%d]", name, i), sliceVal) 475 if err != nil { 476 return err 477 } 478 } 479 } 480 return nil 481 } 482 483 func assertAllFieldsSetStruct(name string, rVal reflect.Value) error { 484 switch rVal.Type() { 485 case reflect.TypeOf(timestamppb.Timestamp{}): 486 ts := rVal.Interface().(timestamppb.Timestamp) 487 if ts.AsTime().IsZero() { 488 return fmt.Errorf("%s is zero", name) 489 } 490 return nil 491 default: 492 for i := 0; i < rVal.NumField(); i++ { 493 field := rVal.Field(i) 494 fieldName := rVal.Type().Field(i) 495 496 // Skip fields that aren't exported 497 if unicode.IsLower([]rune(fieldName.Name)[0]) { 498 continue 499 } 500 501 err := assertAllFieldsSetValue(fmt.Sprintf("%s.%s", name, fieldName.Name), field) 502 if err != nil { 503 return err 504 } 505 } 506 return nil 507 } 508 } 509 510 func intPtr(i int) *int { 511 return &i 512 } 513 514 func float64Ptr(f float64) *float64 { 515 return &f 516 } 517 518 func strPtr(str string) *string { 519 return &str 520 } 521 522 func newStructPb(t *testing.T, m map[string]interface{}) *structpb.Struct { 523 t.Helper() 524 525 s, err := structpb.NewStruct(m) 526 if err != nil { 527 t.Fatalf("Failed to convert map to struct: %s", err) 528 } 529 return s 530 }