// Source: github.com/hashicorp/vault/sdk@v0.11.0/helper/pluginutil/run_config_test.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
     3  
     4  package pluginutil
     5  
     6  import (
     7  	"context"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"os"
    11  	"os/exec"
    12  	"strconv"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/hashicorp/go-hclog"
    17  	"github.com/hashicorp/go-plugin"
    18  	"github.com/hashicorp/go-secure-stdlib/plugincontainer"
    19  	"github.com/hashicorp/vault/sdk/helper/consts"
    20  	"github.com/hashicorp/vault/sdk/helper/pluginruntimeutil"
    21  	"github.com/hashicorp/vault/sdk/helper/wrapping"
    22  	"github.com/stretchr/testify/mock"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestMakeConfig(t *testing.T) {
    27  	type testCase struct {
    28  		rc runConfig
    29  
    30  		responseWrapInfo      *wrapping.ResponseWrapInfo
    31  		responseWrapInfoErr   error
    32  		responseWrapInfoTimes int
    33  
    34  		mlockEnabled      bool
    35  		mlockEnabledTimes int
    36  
    37  		expectedConfig       *plugin.ClientConfig
    38  		expectTLSConfig      bool
    39  		expectRunnerFunc     bool
    40  		skipSecureConfig     bool
    41  		useLegacyEnvLayering bool
    42  	}
    43  
    44  	tests := map[string]testCase{
    45  		"metadata mode, not-AutoMTLS": {
    46  			rc: runConfig{
    47  				command: "echo",
    48  				args:    []string{"foo", "bar"},
    49  				sha256:  []byte("some_sha256"),
    50  				env:     []string{"initial=true"},
    51  				PluginClientConfig: PluginClientConfig{
    52  					PluginSets: map[int]plugin.PluginSet{
    53  						1: {
    54  							"bogus": nil,
    55  						},
    56  					},
    57  					HandshakeConfig: plugin.HandshakeConfig{
    58  						ProtocolVersion:  1,
    59  						MagicCookieKey:   "magic_cookie_key",
    60  						MagicCookieValue: "magic_cookie_value",
    61  					},
    62  					Logger:         hclog.NewNullLogger(),
    63  					IsMetadataMode: true,
    64  					AutoMTLS:       false,
    65  				},
    66  			},
    67  
    68  			responseWrapInfoTimes: 0,
    69  
    70  			mlockEnabled:         false,
    71  			mlockEnabledTimes:    1,
    72  			useLegacyEnvLayering: true,
    73  
    74  			expectedConfig: &plugin.ClientConfig{
    75  				HandshakeConfig: plugin.HandshakeConfig{
    76  					ProtocolVersion:  1,
    77  					MagicCookieKey:   "magic_cookie_key",
    78  					MagicCookieValue: "magic_cookie_value",
    79  				},
    80  				VersionedPlugins: map[int]plugin.PluginSet{
    81  					1: {
    82  						"bogus": nil,
    83  					},
    84  				},
    85  				Cmd: commandWithEnv(
    86  					"echo",
    87  					[]string{"foo", "bar"},
    88  					append(append([]string{
    89  						"initial=true",
    90  						fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"),
    91  						fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true),
    92  						fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false),
    93  					}, os.Environ()...), PluginUseLegacyEnvLayering+"=true"),
    94  				),
    95  				SecureConfig: &plugin.SecureConfig{
    96  					Checksum: []byte("some_sha256"),
    97  					// Hash is generated
    98  				},
    99  				AllowedProtocols: []plugin.Protocol{
   100  					plugin.ProtocolNetRPC,
   101  					plugin.ProtocolGRPC,
   102  				},
   103  				Logger:      hclog.NewNullLogger(),
   104  				AutoMTLS:    false,
   105  				SkipHostEnv: true,
   106  			},
   107  			expectTLSConfig: false,
   108  		},
   109  		"non-metadata mode, not-AutoMTLS": {
   110  			rc: runConfig{
   111  				command: "echo",
   112  				args:    []string{"foo", "bar"},
   113  				sha256:  []byte("some_sha256"),
   114  				env:     []string{"initial=true"},
   115  				PluginClientConfig: PluginClientConfig{
   116  					PluginSets: map[int]plugin.PluginSet{
   117  						1: {
   118  							"bogus": nil,
   119  						},
   120  					},
   121  					HandshakeConfig: plugin.HandshakeConfig{
   122  						ProtocolVersion:  1,
   123  						MagicCookieKey:   "magic_cookie_key",
   124  						MagicCookieValue: "magic_cookie_value",
   125  					},
   126  					Logger:         hclog.NewNullLogger(),
   127  					IsMetadataMode: false,
   128  					AutoMTLS:       false,
   129  				},
   130  			},
   131  
   132  			responseWrapInfo: &wrapping.ResponseWrapInfo{
   133  				Token: "testtoken",
   134  			},
   135  			responseWrapInfoTimes: 1,
   136  
   137  			mlockEnabled:      true,
   138  			mlockEnabledTimes: 1,
   139  
   140  			expectedConfig: &plugin.ClientConfig{
   141  				HandshakeConfig: plugin.HandshakeConfig{
   142  					ProtocolVersion:  1,
   143  					MagicCookieKey:   "magic_cookie_key",
   144  					MagicCookieValue: "magic_cookie_value",
   145  				},
   146  				VersionedPlugins: map[int]plugin.PluginSet{
   147  					1: {
   148  						"bogus": nil,
   149  					},
   150  				},
   151  				Cmd: commandWithEnv(
   152  					"echo",
   153  					[]string{"foo", "bar"},
   154  					append(os.Environ(), []string{
   155  						"initial=true",
   156  						fmt.Sprintf("%s=%t", PluginMlockEnabled, true),
   157  						fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"),
   158  						fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
   159  						fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, false),
   160  						fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, "testtoken"),
   161  					}...),
   162  				),
   163  				SecureConfig: &plugin.SecureConfig{
   164  					Checksum: []byte("some_sha256"),
   165  					// Hash is generated
   166  				},
   167  				AllowedProtocols: []plugin.Protocol{
   168  					plugin.ProtocolNetRPC,
   169  					plugin.ProtocolGRPC,
   170  				},
   171  				Logger:      hclog.NewNullLogger(),
   172  				AutoMTLS:    false,
   173  				SkipHostEnv: true,
   174  			},
   175  			expectTLSConfig: true,
   176  		},
   177  		"metadata mode, AutoMTLS": {
   178  			rc: runConfig{
   179  				command: "echo",
   180  				args:    []string{"foo", "bar"},
   181  				sha256:  []byte("some_sha256"),
   182  				env:     []string{"initial=true"},
   183  				PluginClientConfig: PluginClientConfig{
   184  					PluginSets: map[int]plugin.PluginSet{
   185  						1: {
   186  							"bogus": nil,
   187  						},
   188  					},
   189  					HandshakeConfig: plugin.HandshakeConfig{
   190  						ProtocolVersion:  1,
   191  						MagicCookieKey:   "magic_cookie_key",
   192  						MagicCookieValue: "magic_cookie_value",
   193  					},
   194  					Logger:         hclog.NewNullLogger(),
   195  					IsMetadataMode: true,
   196  					AutoMTLS:       true,
   197  				},
   198  			},
   199  
   200  			responseWrapInfoTimes: 0,
   201  
   202  			mlockEnabled:      false,
   203  			mlockEnabledTimes: 1,
   204  
   205  			expectedConfig: &plugin.ClientConfig{
   206  				HandshakeConfig: plugin.HandshakeConfig{
   207  					ProtocolVersion:  1,
   208  					MagicCookieKey:   "magic_cookie_key",
   209  					MagicCookieValue: "magic_cookie_value",
   210  				},
   211  				VersionedPlugins: map[int]plugin.PluginSet{
   212  					1: {
   213  						"bogus": nil,
   214  					},
   215  				},
   216  				Cmd: commandWithEnv(
   217  					"echo",
   218  					[]string{"foo", "bar"},
   219  					append(os.Environ(), []string{
   220  						"initial=true",
   221  						fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"),
   222  						fmt.Sprintf("%s=%t", PluginMetadataModeEnv, true),
   223  						fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
   224  					}...),
   225  				),
   226  				SecureConfig: &plugin.SecureConfig{
   227  					Checksum: []byte("some_sha256"),
   228  					// Hash is generated
   229  				},
   230  				AllowedProtocols: []plugin.Protocol{
   231  					plugin.ProtocolNetRPC,
   232  					plugin.ProtocolGRPC,
   233  				},
   234  				Logger:      hclog.NewNullLogger(),
   235  				AutoMTLS:    true,
   236  				SkipHostEnv: true,
   237  			},
   238  			expectTLSConfig: false,
   239  		},
   240  		"not-metadata mode, AutoMTLS": {
   241  			rc: runConfig{
   242  				command: "echo",
   243  				args:    []string{"foo", "bar"},
   244  				sha256:  []byte("some_sha256"),
   245  				env:     []string{"initial=true"},
   246  				PluginClientConfig: PluginClientConfig{
   247  					PluginSets: map[int]plugin.PluginSet{
   248  						1: {
   249  							"bogus": nil,
   250  						},
   251  					},
   252  					HandshakeConfig: plugin.HandshakeConfig{
   253  						ProtocolVersion:  1,
   254  						MagicCookieKey:   "magic_cookie_key",
   255  						MagicCookieValue: "magic_cookie_value",
   256  					},
   257  					Logger:         hclog.NewNullLogger(),
   258  					IsMetadataMode: false,
   259  					AutoMTLS:       true,
   260  				},
   261  			},
   262  
   263  			responseWrapInfoTimes: 0,
   264  
   265  			mlockEnabled:      false,
   266  			mlockEnabledTimes: 1,
   267  
   268  			expectedConfig: &plugin.ClientConfig{
   269  				HandshakeConfig: plugin.HandshakeConfig{
   270  					ProtocolVersion:  1,
   271  					MagicCookieKey:   "magic_cookie_key",
   272  					MagicCookieValue: "magic_cookie_value",
   273  				},
   274  				VersionedPlugins: map[int]plugin.PluginSet{
   275  					1: {
   276  						"bogus": nil,
   277  					},
   278  				},
   279  				Cmd: commandWithEnv(
   280  					"echo",
   281  					[]string{"foo", "bar"},
   282  					append(os.Environ(), []string{
   283  						"initial=true",
   284  						fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"),
   285  						fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
   286  						fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
   287  					}...),
   288  				),
   289  				SecureConfig: &plugin.SecureConfig{
   290  					Checksum: []byte("some_sha256"),
   291  					// Hash is generated
   292  				},
   293  				AllowedProtocols: []plugin.Protocol{
   294  					plugin.ProtocolNetRPC,
   295  					plugin.ProtocolGRPC,
   296  				},
   297  				Logger:      hclog.NewNullLogger(),
   298  				AutoMTLS:    true,
   299  				SkipHostEnv: true,
   300  			},
   301  			expectTLSConfig: false,
   302  		},
   303  		"image set": {
   304  			rc: runConfig{
   305  				command:  "echo",
   306  				args:     []string{"foo", "bar"},
   307  				sha256:   []byte("some_sha256"),
   308  				env:      []string{"initial=true"},
   309  				image:    "some-image",
   310  				imageTag: "0.1.0",
   311  				PluginClientConfig: PluginClientConfig{
   312  					PluginSets: map[int]plugin.PluginSet{
   313  						1: {
   314  							"bogus": nil,
   315  						},
   316  					},
   317  					HandshakeConfig: plugin.HandshakeConfig{
   318  						ProtocolVersion:  1,
   319  						MagicCookieKey:   "magic_cookie_key",
   320  						MagicCookieValue: "magic_cookie_value",
   321  					},
   322  					Logger:         hclog.NewNullLogger(),
   323  					IsMetadataMode: false,
   324  					AutoMTLS:       true,
   325  				},
   326  			},
   327  
   328  			responseWrapInfoTimes: 0,
   329  
   330  			mlockEnabled:      false,
   331  			mlockEnabledTimes: 2,
   332  
   333  			expectedConfig: &plugin.ClientConfig{
   334  				HandshakeConfig: plugin.HandshakeConfig{
   335  					ProtocolVersion:  1,
   336  					MagicCookieKey:   "magic_cookie_key",
   337  					MagicCookieValue: "magic_cookie_value",
   338  				},
   339  				VersionedPlugins: map[int]plugin.PluginSet{
   340  					1: {
   341  						"bogus": nil,
   342  					},
   343  				},
   344  				Cmd:          nil,
   345  				SecureConfig: nil,
   346  				AllowedProtocols: []plugin.Protocol{
   347  					plugin.ProtocolNetRPC,
   348  					plugin.ProtocolGRPC,
   349  				},
   350  				Logger:              hclog.NewNullLogger(),
   351  				AutoMTLS:            true,
   352  				SkipHostEnv:         true,
   353  				GRPCBrokerMultiplex: true,
   354  				UnixSocketConfig: &plugin.UnixSocketConfig{
   355  					Group: strconv.Itoa(os.Getgid()),
   356  				},
   357  			},
   358  			expectTLSConfig:  false,
   359  			expectRunnerFunc: true,
   360  			skipSecureConfig: true,
   361  		},
   362  	}
   363  
   364  	for name, test := range tests {
   365  		t.Run(name, func(t *testing.T) {
   366  			mockWrapper := new(mockRunnerUtil)
   367  			mockWrapper.On("ResponseWrapData", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
   368  				Return(test.responseWrapInfo, test.responseWrapInfoErr)
   369  			mockWrapper.On("MlockEnabled").
   370  				Return(test.mlockEnabled)
   371  			test.rc.Wrapper = mockWrapper
   372  			defer mockWrapper.AssertNumberOfCalls(t, "ResponseWrapData", test.responseWrapInfoTimes)
   373  			defer mockWrapper.AssertNumberOfCalls(t, "MlockEnabled", test.mlockEnabledTimes)
   374  
   375  			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   376  			defer cancel()
   377  
   378  			if test.useLegacyEnvLayering {
   379  				t.Setenv(PluginUseLegacyEnvLayering, "true")
   380  			}
   381  
   382  			config, err := test.rc.makeConfig(ctx)
   383  			if err != nil {
   384  				t.Fatalf("no error expected, got: %s", err)
   385  			}
   386  
   387  			// The following fields are generated, so we just need to check for existence, not specific value
   388  			// The value must be nilled out before performing a DeepEqual check
   389  			if !test.skipSecureConfig {
   390  				hsh := config.SecureConfig.Hash
   391  				if hsh == nil {
   392  					t.Fatalf("Missing SecureConfig.Hash")
   393  				}
   394  				config.SecureConfig.Hash = nil
   395  			}
   396  
   397  			if test.expectTLSConfig && config.TLSConfig == nil {
   398  				t.Fatalf("TLS config expected, got nil")
   399  			}
   400  			if !test.expectTLSConfig && config.TLSConfig != nil {
   401  				t.Fatalf("no TLS config expected, got: %#v", config.TLSConfig)
   402  			}
   403  			config.TLSConfig = nil
   404  
   405  			if test.expectRunnerFunc != (config.RunnerFunc != nil) {
   406  				t.Fatalf("expected RunnerFunc: %v, actual: %v", test.expectRunnerFunc, config.RunnerFunc != nil)
   407  			}
   408  			config.RunnerFunc = nil
   409  
   410  			require.Equal(t, test.expectedConfig, config)
   411  		})
   412  	}
   413  }
   414  
   415  func commandWithEnv(cmd string, args []string, env []string) *exec.Cmd {
   416  	c := exec.Command(cmd, args...)
   417  	c.Env = env
   418  	return c
   419  }
   420  
   421  var _ RunnerUtil = &mockRunnerUtil{}
   422  
   423  type mockRunnerUtil struct {
   424  	mock.Mock
   425  }
   426  
   427  func (m *mockRunnerUtil) VaultVersion(ctx context.Context) (string, error) {
   428  	return "dummyversion", nil
   429  }
   430  
   431  func (m *mockRunnerUtil) NewPluginClient(ctx context.Context, config PluginClientConfig) (PluginClient, error) {
   432  	args := m.Called(ctx, config)
   433  	return args.Get(0).(PluginClient), args.Error(1)
   434  }
   435  
   436  func (m *mockRunnerUtil) ResponseWrapData(ctx context.Context, data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) {
   437  	args := m.Called(ctx, data, ttl, jwt)
   438  	return args.Get(0).(*wrapping.ResponseWrapInfo), args.Error(1)
   439  }
   440  
   441  func (m *mockRunnerUtil) MlockEnabled() bool {
   442  	args := m.Called()
   443  	return args.Bool(0)
   444  }
   445  
   446  func (m *mockRunnerUtil) ClusterID(ctx context.Context) (string, error) {
   447  	return "1234", nil
   448  }
   449  
   450  func TestContainerConfig(t *testing.T) {
   451  	dummySHA, err := hex.DecodeString("abc123")
   452  	if err != nil {
   453  		t.Fatal(err)
   454  	}
   455  	myPID := strconv.Itoa(os.Getpid())
   456  	for name, tc := range map[string]struct {
   457  		rc       runConfig
   458  		expected plugincontainer.Config
   459  	}{
   460  		"image set, no runtime": {
   461  			rc: runConfig{
   462  				command:  "echo",
   463  				args:     []string{"foo", "bar"},
   464  				sha256:   dummySHA,
   465  				env:      []string{"initial=true"},
   466  				image:    "some-image",
   467  				imageTag: "0.1.0",
   468  				PluginClientConfig: PluginClientConfig{
   469  					PluginSets: map[int]plugin.PluginSet{
   470  						1: {
   471  							"bogus": nil,
   472  						},
   473  					},
   474  					HandshakeConfig: plugin.HandshakeConfig{
   475  						ProtocolVersion:  1,
   476  						MagicCookieKey:   "magic_cookie_key",
   477  						MagicCookieValue: "magic_cookie_value",
   478  					},
   479  					Logger:     hclog.NewNullLogger(),
   480  					AutoMTLS:   true,
   481  					Name:       "some-plugin",
   482  					PluginType: consts.PluginTypeCredential,
   483  					Version:    "v0.1.0",
   484  				},
   485  			},
   486  			expected: plugincontainer.Config{
   487  				Image:      "some-image",
   488  				Tag:        "0.1.0",
   489  				SHA256:     "abc123",
   490  				Entrypoint: []string{"echo"},
   491  				Args:       []string{"foo", "bar"},
   492  				Env: []string{
   493  					"initial=true",
   494  					fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"),
   495  					fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
   496  					fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
   497  				},
   498  				Labels: map[string]string{
   499  					labelVaultPID:           myPID,
   500  					labelVaultClusterID:     "1234",
   501  					labelVaultPluginName:    "some-plugin",
   502  					labelVaultPluginType:    "auth",
   503  					labelVaultPluginVersion: "v0.1.0",
   504  				},
   505  				Runtime:  consts.DefaultContainerPluginOCIRuntime,
   506  				GroupAdd: os.Getgid(),
   507  			},
   508  		},
   509  		"image set, with runtime": {
   510  			rc: runConfig{
   511  				sha256:   dummySHA,
   512  				image:    "some-image",
   513  				imageTag: "0.1.0",
   514  				runtimeConfig: &pluginruntimeutil.PluginRuntimeConfig{
   515  					OCIRuntime:   "some-oci-runtime",
   516  					CgroupParent: "/cgroup/parent",
   517  					CPU:          1000,
   518  					Memory:       2000,
   519  				},
   520  				PluginClientConfig: PluginClientConfig{
   521  					PluginSets: map[int]plugin.PluginSet{
   522  						1: {
   523  							"bogus": nil,
   524  						},
   525  					},
   526  					HandshakeConfig: plugin.HandshakeConfig{
   527  						ProtocolVersion:  1,
   528  						MagicCookieKey:   "magic_cookie_key",
   529  						MagicCookieValue: "magic_cookie_value",
   530  					},
   531  					Logger:     hclog.NewNullLogger(),
   532  					AutoMTLS:   true,
   533  					Name:       "some-plugin",
   534  					PluginType: consts.PluginTypeCredential,
   535  					Version:    "v0.1.0",
   536  				},
   537  			},
   538  			expected: plugincontainer.Config{
   539  				Image:  "some-image",
   540  				Tag:    "0.1.0",
   541  				SHA256: "abc123",
   542  				Env: []string{
   543  					fmt.Sprintf("%s=%s", PluginVaultVersionEnv, "dummyversion"),
   544  					fmt.Sprintf("%s=%t", PluginMetadataModeEnv, false),
   545  					fmt.Sprintf("%s=%t", PluginAutoMTLSEnv, true),
   546  				},
   547  				Labels: map[string]string{
   548  					labelVaultPID:           myPID,
   549  					labelVaultClusterID:     "1234",
   550  					labelVaultPluginName:    "some-plugin",
   551  					labelVaultPluginType:    "auth",
   552  					labelVaultPluginVersion: "v0.1.0",
   553  				},
   554  				Runtime:      "some-oci-runtime",
   555  				GroupAdd:     os.Getgid(),
   556  				CgroupParent: "/cgroup/parent",
   557  				NanoCpus:     1000,
   558  				Memory:       2000,
   559  			},
   560  		},
   561  	} {
   562  		t.Run(name, func(t *testing.T) {
   563  			mockWrapper := new(mockRunnerUtil)
   564  			mockWrapper.On("ResponseWrapData", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
   565  				Return(nil, nil)
   566  			mockWrapper.On("MlockEnabled").
   567  				Return(false)
   568  			tc.rc.Wrapper = mockWrapper
   569  			cmd, _, err := tc.rc.generateCmd(context.Background())
   570  			if err != nil {
   571  				t.Fatal(err)
   572  			}
   573  			cfg, err := tc.rc.containerConfig(context.Background(), cmd.Env)
   574  			require.NoError(t, err)
   575  			require.Equal(t, tc.expected, *cfg)
   576  		})
   577  	}
   578  }