github.com/hashicorp/vault/sdk@v0.13.0/database/dbplugin/v5/middleware_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package dbplugin 5 6 import ( 7 "context" 8 "errors" 9 "net/url" 10 "reflect" 11 "testing" 12 13 "github.com/hashicorp/go-hclog" 14 "google.golang.org/grpc/codes" 15 "google.golang.org/grpc/status" 16 ) 17 18 func TestDatabaseErrorSanitizerMiddleware(t *testing.T) { 19 type testCase struct { 20 inputErr error 21 secretsFunc func() map[string]string 22 23 expectedError error 24 } 25 26 tests := map[string]testCase{ 27 "nil error": { 28 inputErr: nil, 29 expectedError: nil, 30 }, 31 "url error": { 32 inputErr: new(url.Error), 33 expectedError: errors.New("unable to parse connection url"), 34 }, 35 "nil secrets func": { 36 inputErr: errors.New("here is my password: iofsd9473tg"), 37 expectedError: errors.New("here is my password: iofsd9473tg"), 38 }, 39 "secrets with empty string": { 40 inputErr: errors.New("here is my password: iofsd9473tg"), 41 secretsFunc: secretFunc(t, "", ""), 42 expectedError: errors.New("here is my password: iofsd9473tg"), 43 }, 44 "secrets that do not match": { 45 inputErr: errors.New("here is my password: iofsd9473tg"), 46 secretsFunc: secretFunc(t, "asdf", "<redacted>"), 47 expectedError: errors.New("here is my password: iofsd9473tg"), 48 }, 49 "secrets that do match": { 50 inputErr: errors.New("here is my password: iofsd9473tg"), 51 secretsFunc: secretFunc(t, "iofsd9473tg", "<redacted>"), 52 expectedError: errors.New("here is my password: <redacted>"), 53 }, 54 "multiple secrets": { 55 inputErr: errors.New("here is my password: iofsd9473tg"), 56 secretsFunc: secretFunc(t, 57 "iofsd9473tg", "<redacted>", 58 "password", "<this was the word password>", 59 ), 60 expectedError: errors.New("here is my <this was the word password>: <redacted>"), 61 }, 62 "gRPC status error": { 63 inputErr: status.Error(codes.InvalidArgument, "an error with a password iofsd9473tg"), 64 secretsFunc: secretFunc(t, "iofsd9473tg", "<redacted>"), 65 expectedError: status.Errorf(codes.InvalidArgument, "an error with a password <redacted>"), 66 }, 67 } 68 69 for name, test := range tests { 70 t.Run(name, func(t *testing.T) { 71 db := fakeDatabase{} 72 mw := NewDatabaseErrorSanitizerMiddleware(db, test.secretsFunc) 73 74 actualErr := mw.sanitize(test.inputErr) 75 if !reflect.DeepEqual(actualErr, test.expectedError) { 76 t.Fatalf("Actual error: %s\nExpected error: %s", actualErr, test.expectedError) 77 } 78 }) 79 } 80 81 t.Run("Initialize", func(t *testing.T) { 82 db := &recordingDatabase{ 83 next: fakeDatabase{ 84 initErr: errors.New("password: iofsd9473tg with some stuff after it"), 85 }, 86 } 87 mw := DatabaseErrorSanitizerMiddleware{ 88 next: db, 89 secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"), 90 } 91 92 expectedErr := errors.New("password: <redacted> with some stuff after it") 93 94 _, err := mw.Initialize(context.Background(), InitializeRequest{}) 95 if !reflect.DeepEqual(err, expectedErr) { 96 t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr) 97 } 98 99 assertEquals(t, db.initializeCalls, 1) 100 assertEquals(t, db.newUserCalls, 0) 101 assertEquals(t, db.updateUserCalls, 0) 102 assertEquals(t, db.deleteUserCalls, 0) 103 assertEquals(t, db.typeCalls, 0) 104 assertEquals(t, db.closeCalls, 0) 105 }) 106 107 t.Run("NewUser", func(t *testing.T) { 108 db := &recordingDatabase{ 109 next: fakeDatabase{ 110 newUserErr: errors.New("password: iofsd9473tg with some stuff after it"), 111 }, 112 } 113 mw := DatabaseErrorSanitizerMiddleware{ 114 next: db, 115 secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"), 116 } 117 118 expectedErr := errors.New("password: <redacted> with some stuff after it") 119 120 _, err := mw.NewUser(context.Background(), NewUserRequest{}) 121 if !reflect.DeepEqual(err, expectedErr) { 122 t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr) 123 } 124 125 assertEquals(t, db.initializeCalls, 0) 126 assertEquals(t, db.newUserCalls, 1) 127 assertEquals(t, db.updateUserCalls, 0) 128 assertEquals(t, db.deleteUserCalls, 0) 129 assertEquals(t, db.typeCalls, 0) 130 assertEquals(t, db.closeCalls, 0) 131 }) 132 133 t.Run("UpdateUser", func(t *testing.T) { 134 db := &recordingDatabase{ 135 next: fakeDatabase{ 136 updateUserErr: errors.New("password: iofsd9473tg with some stuff after it"), 137 }, 138 } 139 mw := DatabaseErrorSanitizerMiddleware{ 140 next: db, 141 secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"), 142 } 143 144 expectedErr := errors.New("password: <redacted> with some stuff after it") 145 146 _, err := mw.UpdateUser(context.Background(), UpdateUserRequest{}) 147 if !reflect.DeepEqual(err, expectedErr) { 148 t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr) 149 } 150 151 assertEquals(t, db.initializeCalls, 0) 152 assertEquals(t, db.newUserCalls, 0) 153 assertEquals(t, db.updateUserCalls, 1) 154 assertEquals(t, db.deleteUserCalls, 0) 155 assertEquals(t, db.typeCalls, 0) 156 assertEquals(t, db.closeCalls, 0) 157 }) 158 159 t.Run("DeleteUser", func(t *testing.T) { 160 db := &recordingDatabase{ 161 next: fakeDatabase{ 162 deleteUserErr: errors.New("password: iofsd9473tg with some stuff after it"), 163 }, 164 } 165 mw := DatabaseErrorSanitizerMiddleware{ 166 next: db, 167 secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"), 168 } 169 170 expectedErr := errors.New("password: <redacted> with some stuff after it") 171 172 _, err := mw.DeleteUser(context.Background(), DeleteUserRequest{}) 173 if !reflect.DeepEqual(err, expectedErr) { 174 t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr) 175 } 176 177 assertEquals(t, db.initializeCalls, 0) 178 assertEquals(t, db.newUserCalls, 0) 179 assertEquals(t, db.updateUserCalls, 0) 180 assertEquals(t, db.deleteUserCalls, 1) 181 assertEquals(t, db.typeCalls, 0) 182 assertEquals(t, db.closeCalls, 0) 183 }) 184 185 t.Run("Type", func(t *testing.T) { 186 db := &recordingDatabase{ 187 next: fakeDatabase{ 188 typeErr: errors.New("password: iofsd9473tg with some stuff after it"), 189 }, 190 } 191 mw := DatabaseErrorSanitizerMiddleware{ 192 next: db, 193 secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"), 194 } 195 196 expectedErr := errors.New("password: <redacted> with some stuff after it") 197 198 _, err := mw.Type() 199 if !reflect.DeepEqual(err, expectedErr) { 200 t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr) 201 } 202 203 assertEquals(t, db.initializeCalls, 0) 204 assertEquals(t, db.newUserCalls, 0) 205 assertEquals(t, db.updateUserCalls, 0) 206 assertEquals(t, db.deleteUserCalls, 0) 207 assertEquals(t, db.typeCalls, 1) 208 assertEquals(t, db.closeCalls, 0) 209 }) 210 211 t.Run("Close", func(t *testing.T) { 212 db := &recordingDatabase{ 213 next: fakeDatabase{ 214 closeErr: errors.New("password: iofsd9473tg with some stuff after it"), 215 }, 216 } 217 mw := DatabaseErrorSanitizerMiddleware{ 218 next: db, 219 secretsFn: secretFunc(t, "iofsd9473tg", "<redacted>"), 220 } 221 222 expectedErr := errors.New("password: <redacted> with some stuff after it") 223 224 err := mw.Close() 225 if !reflect.DeepEqual(err, expectedErr) { 226 t.Fatalf("Actual err: %s\n Expected err: %s", err, expectedErr) 227 } 228 229 assertEquals(t, db.initializeCalls, 0) 230 assertEquals(t, db.newUserCalls, 0) 231 assertEquals(t, db.updateUserCalls, 0) 232 assertEquals(t, db.deleteUserCalls, 0) 233 assertEquals(t, db.typeCalls, 0) 234 assertEquals(t, db.closeCalls, 1) 235 }) 236 } 237 238 func secretFunc(t *testing.T, vals ...string) func() map[string]string { 239 t.Helper() 240 if len(vals)%2 != 0 { 241 t.Fatalf("Test configuration error: secretFunc must be called with an even number of values") 242 } 243 244 m := map[string]string{} 245 246 for i := 0; i < len(vals); i += 2 { 247 key := vals[i] 248 m[key] = vals[i+1] 249 } 250 251 return func() map[string]string { 252 return m 253 } 254 } 255 256 func TestTracingMiddleware(t *testing.T) { 257 t.Run("Initialize", func(t *testing.T) { 258 db := &recordingDatabase{} 259 logger := hclog.NewNullLogger() 260 mw := databaseTracingMiddleware{ 261 next: db, 262 logger: logger, 263 } 264 _, err := mw.Initialize(context.Background(), InitializeRequest{}) 265 if err != nil { 266 t.Fatalf("Expected no error, but got: %s", err) 267 } 268 assertEquals(t, db.initializeCalls, 1) 269 assertEquals(t, db.newUserCalls, 0) 270 assertEquals(t, db.updateUserCalls, 0) 271 assertEquals(t, db.deleteUserCalls, 0) 272 assertEquals(t, db.typeCalls, 0) 273 assertEquals(t, db.closeCalls, 0) 274 }) 275 276 t.Run("NewUser", func(t *testing.T) { 277 db := &recordingDatabase{} 278 logger := hclog.NewNullLogger() 279 mw := databaseTracingMiddleware{ 280 next: db, 281 logger: logger, 282 } 283 _, err := mw.NewUser(context.Background(), NewUserRequest{}) 284 if err != nil { 285 t.Fatalf("Expected no error, but got: %s", err) 286 } 287 assertEquals(t, db.initializeCalls, 0) 288 assertEquals(t, db.newUserCalls, 1) 289 assertEquals(t, db.updateUserCalls, 0) 290 assertEquals(t, db.deleteUserCalls, 0) 291 assertEquals(t, db.typeCalls, 0) 292 assertEquals(t, db.closeCalls, 0) 293 }) 294 295 t.Run("UpdateUser", func(t *testing.T) { 296 db := &recordingDatabase{} 297 logger := hclog.NewNullLogger() 298 mw := databaseTracingMiddleware{ 299 next: db, 300 logger: logger, 301 } 302 _, err := mw.UpdateUser(context.Background(), UpdateUserRequest{}) 303 if err != nil { 304 t.Fatalf("Expected no error, but got: %s", err) 305 } 306 assertEquals(t, db.initializeCalls, 0) 307 assertEquals(t, db.newUserCalls, 0) 308 assertEquals(t, db.updateUserCalls, 1) 309 assertEquals(t, db.deleteUserCalls, 0) 310 assertEquals(t, db.typeCalls, 0) 311 assertEquals(t, db.closeCalls, 0) 312 }) 313 314 t.Run("DeleteUser", func(t *testing.T) { 315 db := &recordingDatabase{} 316 logger := hclog.NewNullLogger() 317 mw := databaseTracingMiddleware{ 318 next: db, 319 logger: logger, 320 } 321 _, err := mw.DeleteUser(context.Background(), DeleteUserRequest{}) 322 if err != nil { 323 t.Fatalf("Expected no error, but got: %s", err) 324 } 325 assertEquals(t, db.initializeCalls, 0) 326 assertEquals(t, db.newUserCalls, 0) 327 assertEquals(t, db.updateUserCalls, 0) 328 assertEquals(t, db.deleteUserCalls, 1) 329 assertEquals(t, db.typeCalls, 0) 330 assertEquals(t, db.closeCalls, 0) 331 }) 332 333 t.Run("Type", func(t *testing.T) { 334 db := &recordingDatabase{} 335 logger := hclog.NewNullLogger() 336 mw := databaseTracingMiddleware{ 337 next: db, 338 logger: logger, 339 } 340 _, err := mw.Type() 341 if err != nil { 342 t.Fatalf("Expected no error, but got: %s", err) 343 } 344 assertEquals(t, db.initializeCalls, 0) 345 assertEquals(t, db.newUserCalls, 0) 346 assertEquals(t, db.updateUserCalls, 0) 347 assertEquals(t, db.deleteUserCalls, 0) 348 assertEquals(t, db.typeCalls, 1) 349 assertEquals(t, db.closeCalls, 0) 350 }) 351 352 t.Run("Close", func(t *testing.T) { 353 db := &recordingDatabase{} 354 logger := hclog.NewNullLogger() 355 mw := databaseTracingMiddleware{ 356 next: db, 357 logger: logger, 358 } 359 err := mw.Close() 360 if err != nil { 361 t.Fatalf("Expected no error, but got: %s", err) 362 } 363 assertEquals(t, db.initializeCalls, 0) 364 assertEquals(t, db.newUserCalls, 0) 365 assertEquals(t, db.updateUserCalls, 0) 366 assertEquals(t, db.deleteUserCalls, 0) 367 assertEquals(t, db.typeCalls, 0) 368 assertEquals(t, db.closeCalls, 1) 369 }) 370 } 371 372 func TestMetricsMiddleware(t *testing.T) { 373 t.Run("Initialize", func(t *testing.T) { 374 db := &recordingDatabase{} 375 mw := databaseMetricsMiddleware{ 376 next: db, 377 typeStr: "metrics", 378 } 379 _, err := mw.Initialize(context.Background(), InitializeRequest{}) 380 if err != nil { 381 t.Fatalf("Expected no error, but got: %s", err) 382 } 383 assertEquals(t, db.initializeCalls, 1) 384 assertEquals(t, db.newUserCalls, 0) 385 assertEquals(t, db.updateUserCalls, 0) 386 assertEquals(t, db.deleteUserCalls, 0) 387 assertEquals(t, db.typeCalls, 0) 388 assertEquals(t, db.closeCalls, 0) 389 }) 390 391 t.Run("NewUser", func(t *testing.T) { 392 db := &recordingDatabase{} 393 mw := databaseMetricsMiddleware{ 394 next: db, 395 typeStr: "metrics", 396 } 397 _, err := mw.NewUser(context.Background(), NewUserRequest{}) 398 if err != nil { 399 t.Fatalf("Expected no error, but got: %s", err) 400 } 401 assertEquals(t, db.initializeCalls, 0) 402 assertEquals(t, db.newUserCalls, 1) 403 assertEquals(t, db.updateUserCalls, 0) 404 assertEquals(t, db.deleteUserCalls, 0) 405 assertEquals(t, db.typeCalls, 0) 406 assertEquals(t, db.closeCalls, 0) 407 }) 408 409 t.Run("UpdateUser", func(t *testing.T) { 410 db := &recordingDatabase{} 411 mw := databaseMetricsMiddleware{ 412 next: db, 413 typeStr: "metrics", 414 } 415 _, err := mw.UpdateUser(context.Background(), UpdateUserRequest{}) 416 if err != nil { 417 t.Fatalf("Expected no error, but got: %s", err) 418 } 419 assertEquals(t, db.initializeCalls, 0) 420 assertEquals(t, db.newUserCalls, 0) 421 assertEquals(t, db.updateUserCalls, 1) 422 assertEquals(t, db.deleteUserCalls, 0) 423 assertEquals(t, db.typeCalls, 0) 424 assertEquals(t, db.closeCalls, 0) 425 }) 426 427 t.Run("DeleteUser", func(t *testing.T) { 428 db := &recordingDatabase{} 429 mw := databaseMetricsMiddleware{ 430 next: db, 431 typeStr: "metrics", 432 } 433 _, err := mw.DeleteUser(context.Background(), DeleteUserRequest{}) 434 if err != nil { 435 t.Fatalf("Expected no error, but got: %s", err) 436 } 437 assertEquals(t, db.initializeCalls, 0) 438 assertEquals(t, db.newUserCalls, 0) 439 assertEquals(t, db.updateUserCalls, 0) 440 assertEquals(t, db.deleteUserCalls, 1) 441 assertEquals(t, db.typeCalls, 0) 442 assertEquals(t, db.closeCalls, 0) 443 }) 444 445 t.Run("Type", func(t *testing.T) { 446 db := &recordingDatabase{} 447 mw := databaseMetricsMiddleware{ 448 next: db, 449 typeStr: "metrics", 450 } 451 _, err := mw.Type() 452 if err != nil { 453 t.Fatalf("Expected no error, but got: %s", err) 454 } 455 assertEquals(t, db.initializeCalls, 0) 456 assertEquals(t, db.newUserCalls, 0) 457 assertEquals(t, db.updateUserCalls, 0) 458 assertEquals(t, db.deleteUserCalls, 0) 459 assertEquals(t, db.typeCalls, 1) 460 assertEquals(t, db.closeCalls, 0) 461 }) 462 463 t.Run("Close", func(t *testing.T) { 464 db := &recordingDatabase{} 465 mw := databaseMetricsMiddleware{ 466 next: db, 467 typeStr: "metrics", 468 } 469 err := mw.Close() 470 if err != nil { 471 t.Fatalf("Expected no error, but got: %s", err) 472 } 473 assertEquals(t, db.initializeCalls, 0) 474 assertEquals(t, db.newUserCalls, 0) 475 assertEquals(t, db.updateUserCalls, 0) 476 assertEquals(t, db.deleteUserCalls, 0) 477 assertEquals(t, db.typeCalls, 0) 478 assertEquals(t, db.closeCalls, 1) 479 }) 480 } 481 482 func assertEquals(t *testing.T, actual, expected int) { 483 t.Helper() 484 if actual != expected { 485 t.Fatalf("Actual: %d Expected: %d", actual, expected) 486 } 487 }