github.com/hashicorp/vault/sdk@v0.11.0/helper/pluginutil/run_config_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package pluginutil 5 6 import ( 7 "context" 8 "encoding/hex" 9 "fmt" 10 "os" 11 "os/exec" 12 "strconv" 13 "testing" 14 "time" 15 16 "github.com/hashicorp/go-hclog" 17 "github.com/hashicorp/go-plugin" 18 "github.com/hashicorp/go-secure-stdlib/plugincontainer" 19 "github.com/hashicorp/vault/sdk/helper/consts" 20 "github.com/hashicorp/vault/sdk/helper/pluginruntimeutil" 21 "github.com/hashicorp/vault/sdk/helper/wrapping" 22 "github.com/stretchr/testify/mock" 23 "github.com/stretchr/testify/require" 24 ) 25 26 func TestMakeConfig(t *testing.T) { 27 type testCase struct { 28 rc runConfig 29 30 responseWrapInfo *wrapping.ResponseWrapInfo 31 responseWrapInfoErr error 32 responseWrapInfoTimes int 33 34 mlockEnabled bool 35 mlockEnabledTimes int 36 37 expectedConfig *plugin.ClientConfig 38 expectTLSConfig bool 39 expectRunnerFunc bool 40 skipSecureConfig bool 41 useLegacyEnvLayering bool 42 } 43 44 tests := map[string]testCase{ 45 "metadata mode, not-AutoMTLS": { 46 rc: runConfig{ 47 command: "echo", 48 args: []string{"foo", "bar"}, 49 sha256: []byte("some_sha256"), 50 env: []string{"initial=true"}, 51 PluginClientConfig: PluginClientConfig{ 52 PluginSets: map[int]plugin.PluginSet{ 53 1: { 54 "bogus": nil, 55 }, 56 }, 57 HandshakeConfig: plugin.HandshakeConfig{ 58 ProtocolVersion: 1, 59 MagicCookieKey: "magic_cookie_key", 60 MagicCookieValue: "magic_cookie_value", 61 }, 62 Logger: hclog.NewNullLogger(), 63 IsMetadataMode: true, 64 AutoMTLS: false, 65 }, 66 }, 67 68 responseWrapInfoTimes: 0, 69 70 mlockEnabled: false, 71 mlockEnabledTimes: 1, 72 useLegacyEnvLayering: true, 73 74 expectedConfig: &plugin.ClientConfig{ 75 HandshakeConfig: plugin.HandshakeConfig{ 76 ProtocolVersion: 1, 77 MagicCookieKey: "magic_cookie_key", 78 MagicCookieValue: "magic_cookie_value", 79 }, 80 VersionedPlugins: map[int]plugin.PluginSet{ 81 1: { 82 "bogus": nil, 83 }, 84 }, 85 Cmd: commandWithEnv( 86 "echo", 87 []string{"foo", "bar"}, 88 append(append([]string{ 89 "initial=true", 90 fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), 91 fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), 92 fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), 93 }, os.Environ()...), PluginUseLegacyEnvLayering+"=true"), 94 ), 95 SecureConfig: &plugin.SecureConfig{ 96 Checksum: []byte("some_sha256"), 97 // Hash is generated 98 }, 99 AllowedProtocols: []plugin.Protocol{ 100 plugin.ProtocolNetRPC, 101 plugin.ProtocolGRPC, 102 }, 103 Logger: hclog.NewNullLogger(), 104 AutoMTLS: false, 105 SkipHostEnv: true, 106 }, 107 expectTLSConfig: false, 108 }, 109 "non-metadata mode, not-AutoMTLS": { 110 rc: runConfig{ 111 command: "echo", 112 args: []string{"foo", "bar"}, 113 sha256: []byte("some_sha256"), 114 env: []string{"initial=true"}, 115 PluginClientConfig: PluginClientConfig{ 116 PluginSets: map[int]plugin.PluginSet{ 117 1: { 118 "bogus": nil, 119 }, 120 }, 121 HandshakeConfig: plugin.HandshakeConfig{ 122 ProtocolVersion: 1, 123 MagicCookieKey: "magic_cookie_key", 124 MagicCookieValue: "magic_cookie_value", 125 }, 126 Logger: hclog.NewNullLogger(), 127 IsMetadataMode: false, 128 AutoMTLS: false, 129 }, 130 }, 131 132 responseWrapInfo: &wrapping.ResponseWrapInfo{ 133 Token: "testtoken", 134 }, 135 responseWrapInfoTimes: 1, 136 137 mlockEnabled: true, 138 mlockEnabledTimes: 1, 139 140 expectedConfig: &plugin.ClientConfig{ 141 HandshakeConfig: plugin.HandshakeConfig{ 142 ProtocolVersion: 1, 143 MagicCookieKey: "magic_cookie_key", 144 MagicCookieValue: "magic_cookie_value", 145 }, 146 VersionedPlugins: map[int]plugin.PluginSet{ 147 1: { 148 "bogus": nil, 149 }, 150 }, 151 Cmd: commandWithEnv( 152 "echo", 153 []string{"foo", "bar"}, 154 append(os.Environ(), []string{ 155 "initial=true", 156 fmt.Sprintf("%s=%t", PluginMlockEnabled, true), 157 fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), 158 fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), 159 fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false), 160 fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"), 161 }...), 162 ), 163 SecureConfig: &plugin.SecureConfig{ 164 Checksum: []byte("some_sha256"), 165 // Hash is generated 166 }, 167 AllowedProtocols: []plugin.Protocol{ 168 plugin.ProtocolNetRPC, 169 plugin.ProtocolGRPC, 170 }, 171 Logger: hclog.NewNullLogger(), 172 AutoMTLS: false, 173 SkipHostEnv: true, 174 }, 175 expectTLSConfig: true, 176 }, 177 "metadata mode, AutoMTLS": { 178 rc: runConfig{ 179 command: "echo", 180 args: []string{"foo", "bar"}, 181 sha256: []byte("some_sha256"), 182 env: []string{"initial=true"}, 183 PluginClientConfig: PluginClientConfig{ 184 PluginSets: map[int]plugin.PluginSet{ 185 1: { 186 "bogus": nil, 187 }, 188 }, 189 HandshakeConfig: plugin.HandshakeConfig{ 190 ProtocolVersion: 1, 191 MagicCookieKey: "magic_cookie_key", 192 MagicCookieValue: "magic_cookie_value", 193 }, 194 Logger: hclog.NewNullLogger(), 195 IsMetadataMode: true, 196 AutoMTLS: true, 197 }, 198 }, 199 200 responseWrapInfoTimes: 0, 201 202 mlockEnabled: false, 203 mlockEnabledTimes: 1, 204 205 expectedConfig: &plugin.ClientConfig{ 206 HandshakeConfig: plugin.HandshakeConfig{ 207 ProtocolVersion: 1, 208 MagicCookieKey: "magic_cookie_key", 209 MagicCookieValue: "magic_cookie_value", 210 }, 211 VersionedPlugins: map[int]plugin.PluginSet{ 212 1: { 213 "bogus": nil, 214 }, 215 }, 216 Cmd: commandWithEnv( 217 "echo", 218 []string{"foo", "bar"}, 219 append(os.Environ(), []string{ 220 "initial=true", 221 fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), 222 fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true), 223 fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), 224 }...), 225 ), 226 SecureConfig: &plugin.SecureConfig{ 227 Checksum: []byte("some_sha256"), 228 // Hash is generated 229 }, 230 AllowedProtocols: []plugin.Protocol{ 231 plugin.ProtocolNetRPC, 232 plugin.ProtocolGRPC, 233 }, 234 Logger: hclog.NewNullLogger(), 235 AutoMTLS: true, 236 SkipHostEnv: true, 237 }, 238 expectTLSConfig: false, 239 }, 240 "not-metadata mode, AutoMTLS": { 241 rc: runConfig{ 242 command: "echo", 243 args: []string{"foo", "bar"}, 244 sha256: []byte("some_sha256"), 245 env: []string{"initial=true"}, 246 PluginClientConfig: PluginClientConfig{ 247 PluginSets: map[int]plugin.PluginSet{ 248 1: { 249 "bogus": nil, 250 }, 251 }, 252 HandshakeConfig: plugin.HandshakeConfig{ 253 ProtocolVersion: 1, 254 MagicCookieKey: "magic_cookie_key", 255 MagicCookieValue: "magic_cookie_value", 256 }, 257 Logger: hclog.NewNullLogger(), 258 IsMetadataMode: false, 259 AutoMTLS: true, 260 }, 261 }, 262 263 responseWrapInfoTimes: 0, 264 265 mlockEnabled: false, 266 mlockEnabledTimes: 1, 267 268 expectedConfig: &plugin.ClientConfig{ 269 HandshakeConfig: plugin.HandshakeConfig{ 270 ProtocolVersion: 1, 271 MagicCookieKey: "magic_cookie_key", 272 MagicCookieValue: "magic_cookie_value", 273 }, 274 VersionedPlugins: map[int]plugin.PluginSet{ 275 1: { 276 "bogus": nil, 277 }, 278 }, 279 Cmd: commandWithEnv( 280 "echo", 281 []string{"foo", "bar"}, 282 append(os.Environ(), []string{ 283 "initial=true", 284 fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), 285 fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), 286 fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), 287 }...), 288 ), 289 SecureConfig: &plugin.SecureConfig{ 290 Checksum: []byte("some_sha256"), 291 // Hash is generated 292 }, 293 AllowedProtocols: []plugin.Protocol{ 294 plugin.ProtocolNetRPC, 295 plugin.ProtocolGRPC, 296 }, 297 Logger: hclog.NewNullLogger(), 298 AutoMTLS: true, 299 SkipHostEnv: true, 300 }, 301 expectTLSConfig: false, 302 }, 303 "image set": { 304 rc: runConfig{ 305 command: "echo", 306 args: []string{"foo", "bar"}, 307 sha256: []byte("some_sha256"), 308 env: []string{"initial=true"}, 309 image: "some-image", 310 imageTag: "0.1.0", 311 PluginClientConfig: PluginClientConfig{ 312 PluginSets: map[int]plugin.PluginSet{ 313 1: { 314 "bogus": nil, 315 }, 316 }, 317 HandshakeConfig: plugin.HandshakeConfig{ 318 ProtocolVersion: 1, 319 MagicCookieKey: "magic_cookie_key", 320 MagicCookieValue: "magic_cookie_value", 321 }, 322 Logger: hclog.NewNullLogger(), 323 IsMetadataMode: false, 324 AutoMTLS: true, 325 }, 326 }, 327 328 responseWrapInfoTimes: 0, 329 330 mlockEnabled: false, 331 mlockEnabledTimes: 2, 332 333 expectedConfig: &plugin.ClientConfig{ 334 HandshakeConfig: plugin.HandshakeConfig{ 335 ProtocolVersion: 1, 336 MagicCookieKey: "magic_cookie_key", 337 MagicCookieValue: "magic_cookie_value", 338 }, 339 VersionedPlugins: map[int]plugin.PluginSet{ 340 1: { 341 "bogus": nil, 342 }, 343 }, 344 Cmd: nil, 345 SecureConfig: nil, 346 AllowedProtocols: []plugin.Protocol{ 347 plugin.ProtocolNetRPC, 348 plugin.ProtocolGRPC, 349 }, 350 Logger: hclog.NewNullLogger(), 351 AutoMTLS: true, 352 SkipHostEnv: true, 353 GRPCBrokerMultiplex: true, 354 UnixSocketConfig: &plugin.UnixSocketConfig{ 355 Group: strconv.Itoa(os.Getgid()), 356 }, 357 }, 358 expectTLSConfig: false, 359 expectRunnerFunc: true, 360 skipSecureConfig: true, 361 }, 362 } 363 364 for name, test := range tests { 365 t.Run(name, func(t *testing.T) { 366 mockWrapper := new(mockRunnerUtil) 367 mockWrapper.On("ResponseWrapData", mock.Anything, mock.Anything, mock.Anything, mock.Anything). 368 Return(test.responseWrapInfo, test.responseWrapInfoErr) 369 mockWrapper.On("MlockEnabled"). 370 Return(test.mlockEnabled) 371 test.rc.Wrapper = mockWrapper 372 defer mockWrapper.AssertNumberOfCalls(t, "ResponseWrapData", test.responseWrapInfoTimes) 373 defer mockWrapper.AssertNumberOfCalls(t, "MlockEnabled", test.mlockEnabledTimes) 374 375 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 376 defer cancel() 377 378 if test.useLegacyEnvLayering { 379 t.Setenv(PluginUseLegacyEnvLayering, "true") 380 } 381 382 config, err := test.rc.makeConfig(ctx) 383 if err != nil { 384 t.Fatalf("no error expected, got: %s", err) 385 } 386 387 // The following fields are generated, so we just need to check for existence, not specific value 388 // The value must be nilled out before performing a DeepEqual check 389 if !test.skipSecureConfig { 390 hsh := config.SecureConfig.Hash 391 if hsh == nil { 392 t.Fatalf("Missing SecureConfig.Hash") 393 } 394 config.SecureConfig.Hash = nil 395 } 396 397 if test.expectTLSConfig && config.TLSConfig == nil { 398 t.Fatalf("TLS config expected, got nil") 399 } 400 if !test.expectTLSConfig && config.TLSConfig != nil { 401 t.Fatalf("no TLS config expected, got: %#v", config.TLSConfig) 402 } 403 config.TLSConfig = nil 404 405 if test.expectRunnerFunc != (config.RunnerFunc != nil) { 406 t.Fatalf("expected RunnerFunc: %v, actual: %v", test.expectRunnerFunc, config.RunnerFunc != nil) 407 } 408 config.RunnerFunc = nil 409 410 require.Equal(t, test.expectedConfig, config) 411 }) 412 } 413 } 414 415 func commandWithEnv(cmd string, args []string, env []string) *exec.Cmd { 416 c := exec.Command(cmd, args...) 417 c.Env = env 418 return c 419 } 420 421 var _ RunnerUtil = &mockRunnerUtil{} 422 423 type mockRunnerUtil struct { 424 mock.Mock 425 } 426 427 func (m *mockRunnerUtil) VaultVersion(ctx context.Context) (string, error) { 428 return "dummyversion", nil 429 } 430 431 func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) { 432 args := m.Called(ctx, config) 433 return args.Get(0).(PluginClient), args.Error(1) 434 } 435 436 func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { 437 args := m.Called(ctx, data, ttl, jwt) 438 return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1) 439 } 440 441 func (m *mockRunnerUtil) MlockEnabled() bool { 442 args := m.Called() 443 return args.Bool(0) 444 } 445 446 func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) { 447 return "1234", nil 448 } 449 450 func TestContainerConfig(t *testing.T) { 451 dummySHA, err := hex.DecodeString("abc123") 452 if err != nil { 453 t.Fatal(err) 454 } 455 myPID := strconv.Itoa(os.Getpid()) 456 for name, tc := range map[string]struct { 457 rc runConfig 458 expected plugincontainer.Config 459 }{ 460 "image set, no runtime": { 461 rc: runConfig{ 462 command: "echo", 463 args: []string{"foo", "bar"}, 464 sha256: dummySHA, 465 env: []string{"initial=true"}, 466 image: "some-image", 467 imageTag: "0.1.0", 468 PluginClientConfig: PluginClientConfig{ 469 PluginSets: map[int]plugin.PluginSet{ 470 1: { 471 "bogus": nil, 472 }, 473 }, 474 HandshakeConfig: plugin.HandshakeConfig{ 475 ProtocolVersion: 1, 476 MagicCookieKey: "magic_cookie_key", 477 MagicCookieValue: "magic_cookie_value", 478 }, 479 Logger: hclog.NewNullLogger(), 480 AutoMTLS: true, 481 Name: "some-plugin", 482 PluginType: consts.PluginTypeCredential, 483 Version: "v0.1.0", 484 }, 485 }, 486 expected: plugincontainer.Config{ 487 Image: "some-image", 488 Tag: "0.1.0", 489 SHA256: "abc123", 490 Entrypoint: []string{"echo"}, 491 Args: []string{"foo", "bar"}, 492 Env: []string{ 493 "initial=true", 494 fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), 495 fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), 496 fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), 497 }, 498 Labels: map[string]string{ 499 labelVaultPID: myPID, 500 labelVaultClusterID: "1234", 501 labelVaultPluginName: "some-plugin", 502 labelVaultPluginType: "auth", 503 labelVaultPluginVersion: "v0.1.0", 504 }, 505 Runtime: consts.DefaultContainerPluginOCIRuntime, 506 GroupAdd: os.Getgid(), 507 }, 508 }, 509 "image set, with runtime": { 510 rc: runConfig{ 511 sha256: dummySHA, 512 image: "some-image", 513 imageTag: "0.1.0", 514 runtimeConfig: &pluginruntimeutil.PluginRuntimeConfig{ 515 OCIRuntime: "some-oci-runtime", 516 CgroupParent: "/cgroup/parent", 517 CPU: 1000, 518 Memory: 2000, 519 }, 520 PluginClientConfig: PluginClientConfig{ 521 PluginSets: map[int]plugin.PluginSet{ 522 1: { 523 "bogus": nil, 524 }, 525 }, 526 HandshakeConfig: plugin.HandshakeConfig{ 527 ProtocolVersion: 1, 528 MagicCookieKey: "magic_cookie_key", 529 MagicCookieValue: "magic_cookie_value", 530 }, 531 Logger: hclog.NewNullLogger(), 532 AutoMTLS: true, 533 Name: "some-plugin", 534 PluginType: consts.PluginTypeCredential, 535 Version: "v0.1.0", 536 }, 537 }, 538 expected: plugincontainer.Config{ 539 Image: "some-image", 540 Tag: "0.1.0", 541 SHA256: "abc123", 542 Env: []string{ 543 fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"), 544 fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false), 545 fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true), 546 }, 547 Labels: map[string]string{ 548 labelVaultPID: myPID, 549 labelVaultClusterID: "1234", 550 labelVaultPluginName: "some-plugin", 551 labelVaultPluginType: "auth", 552 labelVaultPluginVersion: "v0.1.0", 553 }, 554 Runtime: "some-oci-runtime", 555 GroupAdd: os.Getgid(), 556 CgroupParent: "/cgroup/parent", 557 NanoCpus: 1000, 558 Memory: 2000, 559 }, 560 }, 561 } { 562 t.Run(name, func(t *testing.T) { 563 mockWrapper := new(mockRunnerUtil) 564 mockWrapper.On("ResponseWrapData", mock.Anything, mock.Anything, mock.Anything, mock.Anything). 565 Return(nil, nil) 566 mockWrapper.On("MlockEnabled"). 567 Return(false) 568 tc.rc.Wrapper = mockWrapper 569 cmd, _, err := tc.rc.generateCmd(context.Background()) 570 if err != nil { 571 t.Fatal(err) 572 } 573 cfg, err := tc.rc.containerConfig(context.Background(), cmd.Env) 574 require.NoError(t, err) 575 require.Equal(t, tc.expected, *cfg) 576 }) 577 } 578 }