github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/conversions_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
     6  import (
     7  	"fmt"
     8  	"reflect"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  	"unicode"
    13  
    14  	"github.com/hashicorp/vault/sdk/database/dbplugin/v5/proto"
    15  	"google.golang.org/protobuf/types/known/structpb"
    16  	"google.golang.org/protobuf/types/known/timestamppb"
    17  )
    18  
    19  func TestConversionsHaveAllFields(t *testing.T) {
    20  	t.Run("initReqToProto", func(t *testing.T) {
    21  		req := InitializeRequest{
    22  			Config: map[string]interface{}{
    23  				"foo": map[string]interface{}{
    24  					"bar": "baz",
    25  				},
    26  			},
    27  			VerifyConnection: true,
    28  		}
    29  
    30  		protoReq, err := initReqToProto(req)
    31  		if err != nil {
    32  			t.Fatalf("Failed to convert request to proto request: %s", err)
    33  		}
    34  
    35  		values := getAllGetterValues(protoReq)
    36  		if len(values) == 0 {
    37  			// Probably a test failure - the protos used in these tests should have Get functions on them
    38  			t.Fatalf("No values found from Get functions!")
    39  		}
    40  
    41  		for _, gtr := range values {
    42  			err := assertAllFieldsSet(fmt.Sprintf("InitializeRequest.%s", gtr.name), gtr.value)
    43  			if err != nil {
    44  				t.Fatalf("%s", err)
    45  			}
    46  		}
    47  	})
    48  
    49  	t.Run("newUserReqToProto", func(t *testing.T) {
    50  		req := NewUserRequest{
    51  			UsernameConfig: UsernameMetadata{
    52  				DisplayName: "dispName",
    53  				RoleName:    "roleName",
    54  			},
    55  			Statements: Statements{
    56  				Commands: []string{
    57  					"statement",
    58  				},
    59  			},
    60  			RollbackStatements: Statements{
    61  				Commands: []string{
    62  					"rollback_statement",
    63  				},
    64  			},
    65  			CredentialType: CredentialTypeRSAPrivateKey,
    66  			PublicKey:      []byte("-----BEGIN PUBLIC KEY-----"),
    67  			Password:       "password",
    68  			Subject:        "subject",
    69  			Expiration:     time.Now(),
    70  		}
    71  
    72  		protoReq, err := newUserReqToProto(req)
    73  		if err != nil {
    74  			t.Fatalf("Failed to convert request to proto request: %s", err)
    75  		}
    76  
    77  		values := getAllGetterValues(protoReq)
    78  		if len(values) == 0 {
    79  			// Probably a test failure - the protos used in these tests should have Get functions on them
    80  			t.Fatalf("No values found from Get functions!")
    81  		}
    82  
    83  		for _, gtr := range values {
    84  			err := assertAllFieldsSet(fmt.Sprintf("NewUserRequest.%s", gtr.name), gtr.value)
    85  			if err != nil {
    86  				t.Fatalf("%s", err)
    87  			}
    88  		}
    89  	})
    90  
    91  	t.Run("updateUserReqToProto", func(t *testing.T) {
    92  		req := UpdateUserRequest{
    93  			Username:       "username",
    94  			CredentialType: CredentialTypeRSAPrivateKey,
    95  			Password: &ChangePassword{
    96  				NewPassword: "newpassword",
    97  				Statements: Statements{
    98  					Commands: []string{
    99  						"statement",
   100  					},
   101  				},
   102  			},
   103  			PublicKey: &ChangePublicKey{
   104  				NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
   105  				Statements: Statements{
   106  					Commands: []string{
   107  						"statement",
   108  					},
   109  				},
   110  			},
   111  			Expiration: &ChangeExpiration{
   112  				NewExpiration: time.Now(),
   113  				Statements: Statements{
   114  					Commands: []string{
   115  						"statement",
   116  					},
   117  				},
   118  			},
   119  		}
   120  
   121  		protoReq, err := updateUserReqToProto(req)
   122  		if err != nil {
   123  			t.Fatalf("Failed to convert request to proto request: %s", err)
   124  		}
   125  
   126  		values := getAllGetterValues(protoReq)
   127  		if len(values) == 0 {
   128  			// Probably a test failure - the protos used in these tests should have Get functions on them
   129  			t.Fatalf("No values found from Get functions!")
   130  		}
   131  
   132  		for _, gtr := range values {
   133  			err := assertAllFieldsSet(fmt.Sprintf("UpdateUserRequest.%s", gtr.name), gtr.value)
   134  			if err != nil {
   135  				t.Fatalf("%s", err)
   136  			}
   137  		}
   138  	})
   139  
   140  	t.Run("deleteUserReqToProto", func(t *testing.T) {
   141  		req := DeleteUserRequest{
   142  			Username: "username",
   143  			Statements: Statements{
   144  				Commands: []string{
   145  					"statement",
   146  				},
   147  			},
   148  		}
   149  
   150  		protoReq, err := deleteUserReqToProto(req)
   151  		if err != nil {
   152  			t.Fatalf("Failed to convert request to proto request: %s", err)
   153  		}
   154  
   155  		values := getAllGetterValues(protoReq)
   156  		if len(values) == 0 {
   157  			// Probably a test failure - the protos used in these tests should have Get functions on them
   158  			t.Fatalf("No values found from Get functions!")
   159  		}
   160  
   161  		for _, gtr := range values {
   162  			err := assertAllFieldsSet(fmt.Sprintf("DeleteUserRequest.%s", gtr.name), gtr.value)
   163  			if err != nil {
   164  				t.Fatalf("%s", err)
   165  			}
   166  		}
   167  	})
   168  
   169  	t.Run("getUpdateUserRequest", func(t *testing.T) {
   170  		req := &proto.UpdateUserRequest{
   171  			Username:       "username",
   172  			CredentialType: int32(CredentialTypeRSAPrivateKey),
   173  			Password: &proto.ChangePassword{
   174  				NewPassword: "newpass",
   175  				Statements: &proto.Statements{
   176  					Commands: []string{
   177  						"statement",
   178  					},
   179  				},
   180  			},
   181  			PublicKey: &proto.ChangePublicKey{
   182  				NewPublicKey: []byte("-----BEGIN PUBLIC KEY-----"),
   183  				Statements: &proto.Statements{
   184  					Commands: []string{
   185  						"statement",
   186  					},
   187  				},
   188  			},
   189  			Expiration: &proto.ChangeExpiration{
   190  				NewExpiration: timestamppb.Now(),
   191  				Statements: &proto.Statements{
   192  					Commands: []string{
   193  						"statement",
   194  					},
   195  				},
   196  			},
   197  		}
   198  
   199  		protoReq, err := getUpdateUserRequest(req)
   200  		if err != nil {
   201  			t.Fatalf("Failed to convert request to proto request: %s", err)
   202  		}
   203  
   204  		err = assertAllFieldsSet("proto.UpdateUserRequest", protoReq)
   205  		if err != nil {
   206  			t.Fatalf("%s", err)
   207  		}
   208  	})
   209  }
   210  
   211  type getter struct {
   212  	name  string
   213  	value interface{}
   214  }
   215  
   216  func getAllGetterValues(value interface{}) (values []getter) {
   217  	typ := reflect.TypeOf(value)
   218  	val := reflect.ValueOf(value)
   219  	for i := 0; i < typ.NumMethod(); i++ {
   220  		method := typ.Method(i)
   221  		if !strings.HasPrefix(method.Name, "Get") {
   222  			continue
   223  		}
   224  		valMethod := val.Method(i)
   225  		resp := valMethod.Call(nil)
   226  		getVal := resp[0].Interface()
   227  		gtr := getter{
   228  			name:  strings.TrimPrefix(method.Name, "Get"),
   229  			value: getVal,
   230  		}
   231  		values = append(values, gtr)
   232  	}
   233  	return values
   234  }
   235  
   236  // Ensures the assertion works properly
   237  func TestAssertAllFieldsSet(t *testing.T) {
   238  	type testCase struct {
   239  		value     interface{}
   240  		expectErr bool
   241  	}
   242  
   243  	tests := map[string]testCase{
   244  		"zero int": {
   245  			value:     0,
   246  			expectErr: true,
   247  		},
   248  		"non-zero int": {
   249  			value:     1,
   250  			expectErr: false,
   251  		},
   252  		"zero float64": {
   253  			value:     0.0,
   254  			expectErr: true,
   255  		},
   256  		"non-zero float64": {
   257  			value:     1.0,
   258  			expectErr: false,
   259  		},
   260  		"empty string": {
   261  			value:     "",
   262  			expectErr: true,
   263  		},
   264  		"true boolean": {
   265  			value:     true,
   266  			expectErr: false,
   267  		},
   268  		"false boolean": { // False is an exception to the "is zero" rule
   269  			value:     false,
   270  			expectErr: false,
   271  		},
   272  		"blank struct": {
   273  			value:     struct{}{},
   274  			expectErr: true,
   275  		},
   276  		"non-blank but empty struct": {
   277  			value: struct {
   278  				str string
   279  			}{
   280  				str: "",
   281  			},
   282  			expectErr: true,
   283  		},
   284  		"non-empty string": {
   285  			value:     "foo",
   286  			expectErr: false,
   287  		},
   288  		"non-empty struct": {
   289  			value: struct {
   290  				str string
   291  			}{
   292  				str: "foo",
   293  			},
   294  			expectErr: false,
   295  		},
   296  		"empty nested struct": {
   297  			value: struct {
   298  				Str       string
   299  				Substruct struct {
   300  					Substr string
   301  				}
   302  			}{
   303  				Str: "foo",
   304  				Substruct: struct {
   305  					Substr string
   306  				}{}, // Empty sub-field
   307  			},
   308  			expectErr: true,
   309  		},
   310  		"filled nested struct": {
   311  			value: struct {
   312  				str       string
   313  				substruct struct {
   314  					substr string
   315  				}
   316  			}{
   317  				str: "foo",
   318  				substruct: struct {
   319  					substr string
   320  				}{
   321  					substr: "sub-foo",
   322  				},
   323  			},
   324  			expectErr: false,
   325  		},
   326  		"nil map": {
   327  			value:     map[string]string(nil),
   328  			expectErr: true,
   329  		},
   330  		"empty map": {
   331  			value:     map[string]string{},
   332  			expectErr: true,
   333  		},
   334  		"filled map": {
   335  			value: map[string]string{
   336  				"foo": "bar",
   337  				"int": "42",
   338  			},
   339  			expectErr: false,
   340  		},
   341  		"map with empty string value": {
   342  			value: map[string]string{
   343  				"foo": "",
   344  			},
   345  			expectErr: true,
   346  		},
   347  		"nested map with empty string value": {
   348  			value: map[string]interface{}{
   349  				"bar": "baz",
   350  				"foo": map[string]interface{}{
   351  					"subfoo": "",
   352  				},
   353  			},
   354  			expectErr: true,
   355  		},
   356  		"nil slice": {
   357  			value:     []string(nil),
   358  			expectErr: true,
   359  		},
   360  		"empty slice": {
   361  			value:     []string{},
   362  			expectErr: true,
   363  		},
   364  		"filled slice": {
   365  			value: []string{
   366  				"foo",
   367  			},
   368  			expectErr: false,
   369  		},
   370  		"slice with empty string value": {
   371  			value: []string{
   372  				"",
   373  			},
   374  			expectErr: true,
   375  		},
   376  		"empty structpb": {
   377  			value:     newStructPb(t, map[string]interface{}{}),
   378  			expectErr: true,
   379  		},
   380  		"filled structpb": {
   381  			value: newStructPb(t, map[string]interface{}{
   382  				"foo": "bar",
   383  				"int": 42,
   384  			}),
   385  			expectErr: false,
   386  		},
   387  
   388  		"pointer to zero int": {
   389  			value:     intPtr(0),
   390  			expectErr: true,
   391  		},
   392  		"pointer to non-zero int": {
   393  			value:     intPtr(1),
   394  			expectErr: false,
   395  		},
   396  		"pointer to zero float64": {
   397  			value:     float64Ptr(0.0),
   398  			expectErr: true,
   399  		},
   400  		"pointer to non-zero float64": {
   401  			value:     float64Ptr(1.0),
   402  			expectErr: false,
   403  		},
   404  		"pointer to nil string": {
   405  			value:     new(string),
   406  			expectErr: true,
   407  		},
   408  		"pointer to non-nil string": {
   409  			value:     strPtr("foo"),
   410  			expectErr: false,
   411  		},
   412  	}
   413  
   414  	for name, test := range tests {
   415  		t.Run(name, func(t *testing.T) {
   416  			err := assertAllFieldsSet("", test.value)
   417  			if test.expectErr && err == nil {
   418  				t.Fatalf("err expected, got nil")
   419  			}
   420  			if !test.expectErr && err != nil {
   421  				t.Fatalf("no error expected, got: %s", err)
   422  			}
   423  		})
   424  	}
   425  }
   426  
   427  func assertAllFieldsSet(name string, val interface{}) error {
   428  	if val == nil {
   429  		return fmt.Errorf("value is nil")
   430  	}
   431  
   432  	rVal := reflect.ValueOf(val)
   433  	return assertAllFieldsSetValue(name, rVal)
   434  }
   435  
   436  func assertAllFieldsSetValue(name string, rVal reflect.Value) error {
   437  	// All booleans are allowed - we don't have a way of differentiating between
   438  	// and intentional false and a missing false
   439  	if rVal.Kind() == reflect.Bool {
   440  		return nil
   441  	}
   442  
   443  	// Primitives fall through here
   444  	if rVal.IsZero() {
   445  		return fmt.Errorf("%s is zero", name)
   446  	}
   447  
   448  	switch rVal.Kind() {
   449  	case reflect.Ptr, reflect.Interface:
   450  		return assertAllFieldsSetValue(name, rVal.Elem())
   451  	case reflect.Struct:
   452  		return assertAllFieldsSetStruct(name, rVal)
   453  	case reflect.Map:
   454  		if rVal.Len() == 0 {
   455  			return fmt.Errorf("%s (map type) is empty", name)
   456  		}
   457  
   458  		iter := rVal.MapRange()
   459  		for iter.Next() {
   460  			k := iter.Key()
   461  			v := iter.Value()
   462  
   463  			err := assertAllFieldsSetValue(fmt.Sprintf("%s[%s]", name, k), v)
   464  			if err != nil {
   465  				return err
   466  			}
   467  		}
   468  	case reflect.Slice:
   469  		if rVal.Len() == 0 {
   470  			return fmt.Errorf("%s (slice type) is empty", name)
   471  		}
   472  		for i := 0; i < rVal.Len(); i++ {
   473  			sliceVal := rVal.Index(i)
   474  			err := assertAllFieldsSetValue(fmt.Sprintf("%s[%d]", name, i), sliceVal)
   475  			if err != nil {
   476  				return err
   477  			}
   478  		}
   479  	}
   480  	return nil
   481  }
   482  
   483  func assertAllFieldsSetStruct(name string, rVal reflect.Value) error {
   484  	switch rVal.Type() {
   485  	case reflect.TypeOf(timestamppb.Timestamp{}):
   486  		ts := rVal.Interface().(timestamppb.Timestamp)
   487  		if ts.AsTime().IsZero() {
   488  			return fmt.Errorf("%s is zero", name)
   489  		}
   490  		return nil
   491  	default:
   492  		for i := 0; i < rVal.NumField(); i++ {
   493  			field := rVal.Field(i)
   494  			fieldName := rVal.Type().Field(i)
   495  
   496  			// Skip fields that aren't exported
   497  			if unicode.IsLower([]rune(fieldName.Name)[0]) {
   498  				continue
   499  			}
   500  
   501  			err := assertAllFieldsSetValue(fmt.Sprintf("%s.%s", name, fieldName.Name), field)
   502  			if err != nil {
   503  				return err
   504  			}
   505  		}
   506  		return nil
   507  	}
   508  }
   509  
   510  func intPtr(i int) *int {
   511  	return &i
   512  }
   513  
   514  func float64Ptr(f float64) *float64 {
   515  	return &f
   516  }
   517  
   518  func strPtr(str string) *string {
   519  	return &str
   520  }
   521  
   522  func newStructPb(t *testing.T, m map[string]interface{}) *structpb.Struct {
   523  	t.Helper()
   524  
   525  	s, err := structpb.NewStruct(m)
   526  	if err != nil {
   527  		t.Fatalf("Failed to convert map to struct: %s", err)
   528  	}
   529  	return s
   530  }