github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/middleware_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package dbplugin
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"net/url"
    10  	"reflect"
    11  	"testing"
    12  
    13  	"github.com/hashicorp/go-hclog"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/status"
    16  )
    17  
    18  func TestDatabaseErrorSanitizerMiddleware(t *testing.T) {
    19  	type testCase struct {
    20  		inputErr    error
    21  		secretsFunc func() map[string]string
    22  
    23  		expectedError error
    24  	}
    25  
    26  	tests := map[string]testCase{
    27  		"nil error": {
    28  			inputErr:      nil,
    29  			expectedError: nil,
    30  		},
    31  		"url error": {
    32  			inputErr:      new(url.Error),
    33  			expectedError: errors.New("unable to parse connection url"),
    34  		},
    35  		"nil secrets func": {
    36  			inputErr:      errors.New("here is my password: iofsd9473tg"),
    37  			expectedError: errors.New("here is my password: iofsd9473tg"),
    38  		},
    39  		"secrets with empty string": {
    40  			inputErr:      errors.New("here is my password: iofsd9473tg"),
    41  			secretsFunc:   secretFunc(t, "", ""),
    42  			expectedError: errors.New("here is my password: iofsd9473tg"),
    43  		},
    44  		"secrets that do not match": {
    45  			inputErr:      errors.New("here is my password: iofsd9473tg"),
    46  			secretsFunc:   secretFunc(t, "asdf", "<redacted>"),
    47  			expectedError: errors.New("here is my password: iofsd9473tg"),
    48  		},
    49  		"secrets that do match": {
    50  			inputErr:      errors.New("here is my password: iofsd9473tg"),
    51  			secretsFunc:   secretFunc(t, "iofsd9473tg", "<redacted>"),
    52  			expectedError: errors.New("here is my password: <redacted>"),
    53  		},
    54  		"multiple secrets": {
    55  			inputErr: errors.New("here is my password: iofsd9473tg"),
    56  			secretsFunc: secretFunc(t,
    57  				"iofsd9473tg", "<redacted>",
    58  				"password", "<this was the word password>",
    59  			),
    60  			expectedError: errors.New("here is my <this was the word password>: <redacted>"),
    61  		},
    62  		"gRPC status error": {
    63  			inputErr:      status.Error(codes.InvalidArgument, "an error with a password iofsd9473tg"),
    64  			secretsFunc:   secretFunc(t, "iofsd9473tg", "<redacted>"),
    65  			expectedError: status.Errorf(codes.InvalidArgument, "an error with a password <redacted>"),
    66  		},
    67  	}
    68  
    69  	for name, test := range tests {
    70  		t.Run(name, func(t *testing.T) {
    71  			db := fakeDatabase{}
    72  			mw := NewDatabaseErrorSanitizerMiddleware(db, test.secretsFunc)
    73  
    74  			actualErr := mw.sanitize(test.inputErr)
    75  			if !reflect.DeepEqual(actualErr, test.expectedError) {
    76  				t.Fatalf("Actual error: %s\nExpected error: %s", actualErr, test.expectedError)
    77  			}
    78  		})
    79  	}
    80  
    81  	t.Run("Initialize", func(t *testing.T) {
    82  		db := &recordingDatabase{
    83  			next: fakeDatabase{
    84  				initErr: errors.New("password: iofsd9473tg with some stuff after it"),
    85  			},
    86  		}
    87  		mw := DatabaseErrorSanitizerMiddleware{
    88  			next:      db,
    89  			secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
    90  		}
    91  
    92  		expectedErr := errors.New("password: <redacted> with some stuff after it")
    93  
    94  		_, err := mw.Initialize(context.Background(), InitializeRequest{})
    95  		if !reflect.DeepEqual(err, expectedErr) {
    96  			t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr)
    97  		}
    98  
    99  		assertEquals(t, db.initializeCalls, 1)
   100  		assertEquals(t, db.newUserCalls, 0)
   101  		assertEquals(t, db.updateUserCalls, 0)
   102  		assertEquals(t, db.deleteUserCalls, 0)
   103  		assertEquals(t, db.typeCalls, 0)
   104  		assertEquals(t, db.closeCalls, 0)
   105  	})
   106  
   107  	t.Run("NewUser", func(t *testing.T) {
   108  		db := &recordingDatabase{
   109  			next: fakeDatabase{
   110  				newUserErr: errors.New("password: iofsd9473tg with some stuff after it"),
   111  			},
   112  		}
   113  		mw := DatabaseErrorSanitizerMiddleware{
   114  			next:      db,
   115  			secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
   116  		}
   117  
   118  		expectedErr := errors.New("password: <redacted> with some stuff after it")
   119  
   120  		_, err := mw.NewUser(context.Background(), NewUserRequest{})
   121  		if !reflect.DeepEqual(err, expectedErr) {
   122  			t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr)
   123  		}
   124  
   125  		assertEquals(t, db.initializeCalls, 0)
   126  		assertEquals(t, db.newUserCalls, 1)
   127  		assertEquals(t, db.updateUserCalls, 0)
   128  		assertEquals(t, db.deleteUserCalls, 0)
   129  		assertEquals(t, db.typeCalls, 0)
   130  		assertEquals(t, db.closeCalls, 0)
   131  	})
   132  
   133  	t.Run("UpdateUser", func(t *testing.T) {
   134  		db := &recordingDatabase{
   135  			next: fakeDatabase{
   136  				updateUserErr: errors.New("password: iofsd9473tg with some stuff after it"),
   137  			},
   138  		}
   139  		mw := DatabaseErrorSanitizerMiddleware{
   140  			next:      db,
   141  			secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
   142  		}
   143  
   144  		expectedErr := errors.New("password: <redacted> with some stuff after it")
   145  
   146  		_, err := mw.UpdateUser(context.Background(), UpdateUserRequest{})
   147  		if !reflect.DeepEqual(err, expectedErr) {
   148  			t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr)
   149  		}
   150  
   151  		assertEquals(t, db.initializeCalls, 0)
   152  		assertEquals(t, db.newUserCalls, 0)
   153  		assertEquals(t, db.updateUserCalls, 1)
   154  		assertEquals(t, db.deleteUserCalls, 0)
   155  		assertEquals(t, db.typeCalls, 0)
   156  		assertEquals(t, db.closeCalls, 0)
   157  	})
   158  
   159  	t.Run("DeleteUser", func(t *testing.T) {
   160  		db := &recordingDatabase{
   161  			next: fakeDatabase{
   162  				deleteUserErr: errors.New("password: iofsd9473tg with some stuff after it"),
   163  			},
   164  		}
   165  		mw := DatabaseErrorSanitizerMiddleware{
   166  			next:      db,
   167  			secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
   168  		}
   169  
   170  		expectedErr := errors.New("password: <redacted> with some stuff after it")
   171  
   172  		_, err := mw.DeleteUser(context.Background(), DeleteUserRequest{})
   173  		if !reflect.DeepEqual(err, expectedErr) {
   174  			t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr)
   175  		}
   176  
   177  		assertEquals(t, db.initializeCalls, 0)
   178  		assertEquals(t, db.newUserCalls, 0)
   179  		assertEquals(t, db.updateUserCalls, 0)
   180  		assertEquals(t, db.deleteUserCalls, 1)
   181  		assertEquals(t, db.typeCalls, 0)
   182  		assertEquals(t, db.closeCalls, 0)
   183  	})
   184  
   185  	t.Run("Type", func(t *testing.T) {
   186  		db := &recordingDatabase{
   187  			next: fakeDatabase{
   188  				typeErr: errors.New("password: iofsd9473tg with some stuff after it"),
   189  			},
   190  		}
   191  		mw := DatabaseErrorSanitizerMiddleware{
   192  			next:      db,
   193  			secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
   194  		}
   195  
   196  		expectedErr := errors.New("password: <redacted> with some stuff after it")
   197  
   198  		_, err := mw.Type()
   199  		if !reflect.DeepEqual(err, expectedErr) {
   200  			t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr)
   201  		}
   202  
   203  		assertEquals(t, db.initializeCalls, 0)
   204  		assertEquals(t, db.newUserCalls, 0)
   205  		assertEquals(t, db.updateUserCalls, 0)
   206  		assertEquals(t, db.deleteUserCalls, 0)
   207  		assertEquals(t, db.typeCalls, 1)
   208  		assertEquals(t, db.closeCalls, 0)
   209  	})
   210  
   211  	t.Run("Close", func(t *testing.T) {
   212  		db := &recordingDatabase{
   213  			next: fakeDatabase{
   214  				closeErr: errors.New("password: iofsd9473tg with some stuff after it"),
   215  			},
   216  		}
   217  		mw := DatabaseErrorSanitizerMiddleware{
   218  			next:      db,
   219  			secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"),
   220  		}
   221  
   222  		expectedErr := errors.New("password: <redacted> with some stuff after it")
   223  
   224  		err := mw.Close()
   225  		if !reflect.DeepEqual(err, expectedErr) {
   226  			t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr)
   227  		}
   228  
   229  		assertEquals(t, db.initializeCalls, 0)
   230  		assertEquals(t, db.newUserCalls, 0)
   231  		assertEquals(t, db.updateUserCalls, 0)
   232  		assertEquals(t, db.deleteUserCalls, 0)
   233  		assertEquals(t, db.typeCalls, 0)
   234  		assertEquals(t, db.closeCalls, 1)
   235  	})
   236  }
   237  
   238  func secretFunc(t *testing.T, vals ...string) func() map[string]string {
   239  	t.Helper()
   240  	if len(vals)%2 != 0 {
   241  		t.Fatalf("Test configuration error: secretFunc must be called with an even number of values")
   242  	}
   243  
   244  	m := map[string]string{}
   245  
   246  	for i := 0; i < len(vals); i += 2 {
   247  		key := vals[i]
   248  		m[key] = vals[i+1]
   249  	}
   250  
   251  	return func() map[string]string {
   252  		return m
   253  	}
   254  }
   255  
   256  func TestTracingMiddleware(t *testing.T) {
   257  	t.Run("Initialize", func(t *testing.T) {
   258  		db := &recordingDatabase{}
   259  		logger := hclog.NewNullLogger()
   260  		mw := databaseTracingMiddleware{
   261  			next:   db,
   262  			logger: logger,
   263  		}
   264  		_, err := mw.Initialize(context.Background(), InitializeRequest{})
   265  		if err != nil {
   266  			t.Fatalf("Expected no error, but got: %s", err)
   267  		}
   268  		assertEquals(t, db.initializeCalls, 1)
   269  		assertEquals(t, db.newUserCalls, 0)
   270  		assertEquals(t, db.updateUserCalls, 0)
   271  		assertEquals(t, db.deleteUserCalls, 0)
   272  		assertEquals(t, db.typeCalls, 0)
   273  		assertEquals(t, db.closeCalls, 0)
   274  	})
   275  
   276  	t.Run("NewUser", func(t *testing.T) {
   277  		db := &recordingDatabase{}
   278  		logger := hclog.NewNullLogger()
   279  		mw := databaseTracingMiddleware{
   280  			next:   db,
   281  			logger: logger,
   282  		}
   283  		_, err := mw.NewUser(context.Background(), NewUserRequest{})
   284  		if err != nil {
   285  			t.Fatalf("Expected no error, but got: %s", err)
   286  		}
   287  		assertEquals(t, db.initializeCalls, 0)
   288  		assertEquals(t, db.newUserCalls, 1)
   289  		assertEquals(t, db.updateUserCalls, 0)
   290  		assertEquals(t, db.deleteUserCalls, 0)
   291  		assertEquals(t, db.typeCalls, 0)
   292  		assertEquals(t, db.closeCalls, 0)
   293  	})
   294  
   295  	t.Run("UpdateUser", func(t *testing.T) {
   296  		db := &recordingDatabase{}
   297  		logger := hclog.NewNullLogger()
   298  		mw := databaseTracingMiddleware{
   299  			next:   db,
   300  			logger: logger,
   301  		}
   302  		_, err := mw.UpdateUser(context.Background(), UpdateUserRequest{})
   303  		if err != nil {
   304  			t.Fatalf("Expected no error, but got: %s", err)
   305  		}
   306  		assertEquals(t, db.initializeCalls, 0)
   307  		assertEquals(t, db.newUserCalls, 0)
   308  		assertEquals(t, db.updateUserCalls, 1)
   309  		assertEquals(t, db.deleteUserCalls, 0)
   310  		assertEquals(t, db.typeCalls, 0)
   311  		assertEquals(t, db.closeCalls, 0)
   312  	})
   313  
   314  	t.Run("DeleteUser", func(t *testing.T) {
   315  		db := &recordingDatabase{}
   316  		logger := hclog.NewNullLogger()
   317  		mw := databaseTracingMiddleware{
   318  			next:   db,
   319  			logger: logger,
   320  		}
   321  		_, err := mw.DeleteUser(context.Background(), DeleteUserRequest{})
   322  		if err != nil {
   323  			t.Fatalf("Expected no error, but got: %s", err)
   324  		}
   325  		assertEquals(t, db.initializeCalls, 0)
   326  		assertEquals(t, db.newUserCalls, 0)
   327  		assertEquals(t, db.updateUserCalls, 0)
   328  		assertEquals(t, db.deleteUserCalls, 1)
   329  		assertEquals(t, db.typeCalls, 0)
   330  		assertEquals(t, db.closeCalls, 0)
   331  	})
   332  
   333  	t.Run("Type", func(t *testing.T) {
   334  		db := &recordingDatabase{}
   335  		logger := hclog.NewNullLogger()
   336  		mw := databaseTracingMiddleware{
   337  			next:   db,
   338  			logger: logger,
   339  		}
   340  		_, err := mw.Type()
   341  		if err != nil {
   342  			t.Fatalf("Expected no error, but got: %s", err)
   343  		}
   344  		assertEquals(t, db.initializeCalls, 0)
   345  		assertEquals(t, db.newUserCalls, 0)
   346  		assertEquals(t, db.updateUserCalls, 0)
   347  		assertEquals(t, db.deleteUserCalls, 0)
   348  		assertEquals(t, db.typeCalls, 1)
   349  		assertEquals(t, db.closeCalls, 0)
   350  	})
   351  
   352  	t.Run("Close", func(t *testing.T) {
   353  		db := &recordingDatabase{}
   354  		logger := hclog.NewNullLogger()
   355  		mw := databaseTracingMiddleware{
   356  			next:   db,
   357  			logger: logger,
   358  		}
   359  		err := mw.Close()
   360  		if err != nil {
   361  			t.Fatalf("Expected no error, but got: %s", err)
   362  		}
   363  		assertEquals(t, db.initializeCalls, 0)
   364  		assertEquals(t, db.newUserCalls, 0)
   365  		assertEquals(t, db.updateUserCalls, 0)
   366  		assertEquals(t, db.deleteUserCalls, 0)
   367  		assertEquals(t, db.typeCalls, 0)
   368  		assertEquals(t, db.closeCalls, 1)
   369  	})
   370  }
   371  
   372  func TestMetricsMiddleware(t *testing.T) {
   373  	t.Run("Initialize", func(t *testing.T) {
   374  		db := &recordingDatabase{}
   375  		mw := databaseMetricsMiddleware{
   376  			next:    db,
   377  			typeStr: "metrics",
   378  		}
   379  		_, err := mw.Initialize(context.Background(), InitializeRequest{})
   380  		if err != nil {
   381  			t.Fatalf("Expected no error, but got: %s", err)
   382  		}
   383  		assertEquals(t, db.initializeCalls, 1)
   384  		assertEquals(t, db.newUserCalls, 0)
   385  		assertEquals(t, db.updateUserCalls, 0)
   386  		assertEquals(t, db.deleteUserCalls, 0)
   387  		assertEquals(t, db.typeCalls, 0)
   388  		assertEquals(t, db.closeCalls, 0)
   389  	})
   390  
   391  	t.Run("NewUser", func(t *testing.T) {
   392  		db := &recordingDatabase{}
   393  		mw := databaseMetricsMiddleware{
   394  			next:    db,
   395  			typeStr: "metrics",
   396  		}
   397  		_, err := mw.NewUser(context.Background(), NewUserRequest{})
   398  		if err != nil {
   399  			t.Fatalf("Expected no error, but got: %s", err)
   400  		}
   401  		assertEquals(t, db.initializeCalls, 0)
   402  		assertEquals(t, db.newUserCalls, 1)
   403  		assertEquals(t, db.updateUserCalls, 0)
   404  		assertEquals(t, db.deleteUserCalls, 0)
   405  		assertEquals(t, db.typeCalls, 0)
   406  		assertEquals(t, db.closeCalls, 0)
   407  	})
   408  
   409  	t.Run("UpdateUser", func(t *testing.T) {
   410  		db := &recordingDatabase{}
   411  		mw := databaseMetricsMiddleware{
   412  			next:    db,
   413  			typeStr: "metrics",
   414  		}
   415  		_, err := mw.UpdateUser(context.Background(), UpdateUserRequest{})
   416  		if err != nil {
   417  			t.Fatalf("Expected no error, but got: %s", err)
   418  		}
   419  		assertEquals(t, db.initializeCalls, 0)
   420  		assertEquals(t, db.newUserCalls, 0)
   421  		assertEquals(t, db.updateUserCalls, 1)
   422  		assertEquals(t, db.deleteUserCalls, 0)
   423  		assertEquals(t, db.typeCalls, 0)
   424  		assertEquals(t, db.closeCalls, 0)
   425  	})
   426  
   427  	t.Run("DeleteUser", func(t *testing.T) {
   428  		db := &recordingDatabase{}
   429  		mw := databaseMetricsMiddleware{
   430  			next:    db,
   431  			typeStr: "metrics",
   432  		}
   433  		_, err := mw.DeleteUser(context.Background(), DeleteUserRequest{})
   434  		if err != nil {
   435  			t.Fatalf("Expected no error, but got: %s", err)
   436  		}
   437  		assertEquals(t, db.initializeCalls, 0)
   438  		assertEquals(t, db.newUserCalls, 0)
   439  		assertEquals(t, db.updateUserCalls, 0)
   440  		assertEquals(t, db.deleteUserCalls, 1)
   441  		assertEquals(t, db.typeCalls, 0)
   442  		assertEquals(t, db.closeCalls, 0)
   443  	})
   444  
   445  	t.Run("Type", func(t *testing.T) {
   446  		db := &recordingDatabase{}
   447  		mw := databaseMetricsMiddleware{
   448  			next:    db,
   449  			typeStr: "metrics",
   450  		}
   451  		_, err := mw.Type()
   452  		if err != nil {
   453  			t.Fatalf("Expected no error, but got: %s", err)
   454  		}
   455  		assertEquals(t, db.initializeCalls, 0)
   456  		assertEquals(t, db.newUserCalls, 0)
   457  		assertEquals(t, db.updateUserCalls, 0)
   458  		assertEquals(t, db.deleteUserCalls, 0)
   459  		assertEquals(t, db.typeCalls, 1)
   460  		assertEquals(t, db.closeCalls, 0)
   461  	})
   462  
   463  	t.Run("Close", func(t *testing.T) {
   464  		db := &recordingDatabase{}
   465  		mw := databaseMetricsMiddleware{
   466  			next:    db,
   467  			typeStr: "metrics",
   468  		}
   469  		err := mw.Close()
   470  		if err != nil {
   471  			t.Fatalf("Expected no error, but got: %s", err)
   472  		}
   473  		assertEquals(t, db.initializeCalls, 0)
   474  		assertEquals(t, db.newUserCalls, 0)
   475  		assertEquals(t, db.updateUserCalls, 0)
   476  		assertEquals(t, db.deleteUserCalls, 0)
   477  		assertEquals(t, db.typeCalls, 0)
   478  		assertEquals(t, db.closeCalls, 1)
   479  	})
   480  }
   481  
   482  func assertEquals(t *testing.T, actual, expected int) {
   483  	t.Helper()
   484  	if actual != expected {
   485  		t.Fatalf("Actual: %d Expected: %d", actual, expected)
   486  	}
   487  }