github.com/myhau/pulumi/pkg/v3@v3.70.2-0.20221116134521-f2775972e587/resource/stack/secrets_test.go (about)

     1  package stack
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/pulumi/pulumi/pkg/v3/secrets"
    13  	"github.com/pulumi/pulumi/sdk/v3/go/common/encoding"
    14  	"github.com/pulumi/pulumi/sdk/v3/go/common/resource"
    15  	"github.com/pulumi/pulumi/sdk/v3/go/common/resource/config"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  )
    19  
    20  type testSecretsManager struct {
    21  	encryptCalls int
    22  	decryptCalls int
    23  }
    24  
    25  func (t *testSecretsManager) Type() string { return "test" }
    26  
    27  func (t *testSecretsManager) State() interface{} { return nil }
    28  
    29  func (t *testSecretsManager) Encrypter() (config.Encrypter, error) {
    30  	return t, nil
    31  }
    32  
    33  func (t *testSecretsManager) Decrypter() (config.Decrypter, error) {
    34  	return t, nil
    35  }
    36  
    37  func (t *testSecretsManager) EncryptValue(
    38  	ctx context.Context, plaintext string) (string, error) {
    39  	t.encryptCalls++
    40  	return fmt.Sprintf("%v:%v", t.encryptCalls, plaintext), nil
    41  }
    42  
    43  func (t *testSecretsManager) DecryptValue(
    44  	ctx context.Context, ciphertext string) (string, error) {
    45  	t.decryptCalls++
    46  	i := strings.Index(ciphertext, ":")
    47  	if i == -1 {
    48  		return "", errors.New("invalid ciphertext format")
    49  	}
    50  	return ciphertext[i+1:], nil
    51  }
    52  
    53  func (t *testSecretsManager) BulkDecrypt(
    54  	ctx context.Context, ciphertexts []string) (map[string]string, error) {
    55  	return config.DefaultBulkDecrypt(ctx, t, ciphertexts)
    56  }
    57  
    58  func deserializeProperty(v interface{}, dec config.Decrypter) (resource.PropertyValue, error) {
    59  	b, err := json.Marshal(v)
    60  	if err != nil {
    61  		return resource.PropertyValue{}, err
    62  	}
    63  	if err := json.Unmarshal(b, &v); err != nil {
    64  		return resource.PropertyValue{}, err
    65  	}
    66  	return DeserializePropertyValue(v, dec, config.NewPanicCrypter())
    67  }
    68  
    69  func TestCachingCrypter(t *testing.T) {
    70  	t.Parallel()
    71  
    72  	sm := &testSecretsManager{}
    73  	csm := NewCachingSecretsManager(sm)
    74  
    75  	foo1 := resource.MakeSecret(resource.NewStringProperty("foo"))
    76  	foo2 := resource.MakeSecret(resource.NewStringProperty("foo"))
    77  	bar := resource.MakeSecret(resource.NewStringProperty("bar"))
    78  
    79  	enc, err := csm.Encrypter()
    80  	assert.NoError(t, err)
    81  
    82  	// Serialize the first copy of "foo". Encrypt should be called once, as this value has not yet been encrypted.
    83  	foo1Ser, err := SerializePropertyValue(foo1, enc, false /* showSecrets */)
    84  	assert.NoError(t, err)
    85  	assert.Equal(t, 1, sm.encryptCalls)
    86  
    87  	// Serialize the second copy of "foo". Because this is a different secret instance, Encrypt should be called
    88  	// a second time even though the plaintext is the same as the last value we encrypted.
    89  	foo2Ser, err := SerializePropertyValue(foo2, enc, false /* showSecrets */)
    90  	assert.NoError(t, err)
    91  	assert.Equal(t, 2, sm.encryptCalls)
    92  	assert.NotEqual(t, foo1Ser, foo2Ser)
    93  
    94  	// Serialize "bar". Encrypt should be called once, as this value has not yet been encrypted.
    95  	barSer, err := SerializePropertyValue(bar, enc, false /* showSecrets */)
    96  	assert.NoError(t, err)
    97  	assert.Equal(t, 3, sm.encryptCalls)
    98  
    99  	// Serialize the first copy of "foo" again. Encrypt should not be called, as this value has already been
   100  	// encrypted.
   101  	foo1Ser2, err := SerializePropertyValue(foo1, enc, false /* showSecrets */)
   102  	assert.NoError(t, err)
   103  	assert.Equal(t, 3, sm.encryptCalls)
   104  	assert.Equal(t, foo1Ser, foo1Ser2)
   105  
   106  	// Serialize the second copy of "foo" again. Encrypt should not be called, as this value has already been
   107  	// encrypted.
   108  	foo2Ser2, err := SerializePropertyValue(foo2, enc, false /* showSecrets */)
   109  	assert.NoError(t, err)
   110  	assert.Equal(t, 3, sm.encryptCalls)
   111  	assert.Equal(t, foo2Ser, foo2Ser2)
   112  
   113  	// Serialize "bar" again. Encrypt should not be called, as this value has already been encrypted.
   114  	barSer2, err := SerializePropertyValue(bar, enc, false /* showSecrets */)
   115  	assert.NoError(t, err)
   116  	assert.Equal(t, 3, sm.encryptCalls)
   117  	assert.Equal(t, barSer, barSer2)
   118  
   119  	dec, err := csm.Decrypter()
   120  	assert.NoError(t, err)
   121  
   122  	// Decrypt foo1Ser. Decrypt should be called.
   123  	foo1Dec, err := deserializeProperty(foo1Ser, dec)
   124  	assert.NoError(t, err)
   125  	assert.True(t, foo1.DeepEquals(foo1Dec))
   126  	assert.Equal(t, 1, sm.decryptCalls)
   127  
   128  	// Decrypt foo2Ser. Decrypt should be called.
   129  	foo2Dec, err := deserializeProperty(foo2Ser, dec)
   130  	assert.NoError(t, err)
   131  	assert.True(t, foo2.DeepEquals(foo2Dec))
   132  	assert.Equal(t, 2, sm.decryptCalls)
   133  
   134  	// Decrypt barSer. Decrypt should be called.
   135  	barDec, err := deserializeProperty(barSer, dec)
   136  	assert.NoError(t, err)
   137  	assert.True(t, bar.DeepEquals(barDec))
   138  	assert.Equal(t, 3, sm.decryptCalls)
   139  
   140  	// Create a new CachingSecretsManager and re-run the decrypts. Each decrypt should insert the plain- and
   141  	// ciphertext into the cache with the associated secret.
   142  	csm = NewCachingSecretsManager(sm)
   143  
   144  	dec, err = csm.Decrypter()
   145  	assert.NoError(t, err)
   146  
   147  	// Decrypt foo1Ser. Decrypt should be called.
   148  	foo1Dec, err = deserializeProperty(foo1Ser, dec)
   149  	assert.NoError(t, err)
   150  	assert.True(t, foo1.DeepEquals(foo1Dec))
   151  	assert.Equal(t, 4, sm.decryptCalls)
   152  
   153  	// Decrypt foo2Ser. Decrypt should be called.
   154  	foo2Dec, err = deserializeProperty(foo2Ser, dec)
   155  	assert.NoError(t, err)
   156  	assert.True(t, foo2.DeepEquals(foo2Dec))
   157  	assert.Equal(t, 5, sm.decryptCalls)
   158  
   159  	// Decrypt barSer. Decrypt should be called.
   160  	barDec, err = deserializeProperty(barSer, dec)
   161  	assert.NoError(t, err)
   162  	assert.True(t, bar.DeepEquals(barDec))
   163  	assert.Equal(t, 6, sm.decryptCalls)
   164  
   165  	enc, err = csm.Encrypter()
   166  	assert.NoError(t, err)
   167  
   168  	// Serialize the first copy of "foo" again. Encrypt should not be called, as this value has already been
   169  	// cached by the earlier calls to Decrypt.
   170  	foo1Ser2, err = SerializePropertyValue(foo1Dec, enc, false /* showSecrets */)
   171  	assert.NoError(t, err)
   172  	assert.Equal(t, 3, sm.encryptCalls)
   173  	assert.Equal(t, foo1Ser, foo1Ser2)
   174  
   175  	// Serialize the second copy of "foo" again. Encrypt should not be called, as this value has already been
   176  	// cached by the earlier calls to Decrypt.
   177  	foo2Ser2, err = SerializePropertyValue(foo2Dec, enc, false /* showSecrets */)
   178  	assert.NoError(t, err)
   179  	assert.Equal(t, 3, sm.encryptCalls)
   180  	assert.Equal(t, foo2Ser, foo2Ser2)
   181  
   182  	// Serialize "bar" again. Encrypt should not be called, as this value has already been cached by the
   183  	// earlier calls to Decrypt.
   184  	barSer2, err = SerializePropertyValue(barDec, enc, false /* showSecrets */)
   185  	assert.NoError(t, err)
   186  	assert.Equal(t, 3, sm.encryptCalls)
   187  	assert.Equal(t, barSer, barSer2)
   188  }
   189  
   190  type mapTestSecretsProvider struct {
   191  	m *mapTestSecretsManager
   192  }
   193  
   194  func (p *mapTestSecretsProvider) OfType(ty string, state json.RawMessage) (secrets.Manager, error) {
   195  	m, err := DefaultSecretsProvider.OfType(ty, state)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	p.m = &mapTestSecretsManager{sm: m}
   200  	return p.m, nil
   201  }
   202  
   203  type mapTestSecretsManager struct {
   204  	sm secrets.Manager
   205  
   206  	d *mapTestDecrypter
   207  }
   208  
   209  func (t *mapTestSecretsManager) Type() string { return t.sm.Type() }
   210  
   211  func (t *mapTestSecretsManager) State() interface{} { return t.sm.State() }
   212  
   213  func (t *mapTestSecretsManager) Encrypter() (config.Encrypter, error) {
   214  	return t.sm.Encrypter()
   215  }
   216  
   217  func (t *mapTestSecretsManager) Decrypter() (config.Decrypter, error) {
   218  	d, err := t.sm.Decrypter()
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	t.d = &mapTestDecrypter{d: d}
   223  	return t.d, nil
   224  }
   225  
   226  type mapTestDecrypter struct {
   227  	d config.Decrypter
   228  
   229  	decryptCalls     int
   230  	bulkDecryptCalls int
   231  }
   232  
   233  func (t *mapTestDecrypter) DecryptValue(
   234  	ctx context.Context, ciphertext string) (string, error) {
   235  	t.decryptCalls++
   236  	return t.d.DecryptValue(ctx, ciphertext)
   237  }
   238  
   239  func (t *mapTestDecrypter) BulkDecrypt(
   240  	ctx context.Context, ciphertexts []string) (map[string]string, error) {
   241  	t.bulkDecryptCalls++
   242  	return config.DefaultBulkDecrypt(ctx, t.d, ciphertexts)
   243  }
   244  
   245  func TestMapCrypter(t *testing.T) {
   246  	t.Parallel()
   247  
   248  	ctx := context.Background()
   249  
   250  	bytes, err := ioutil.ReadFile("testdata/checkpoint-secrets.json")
   251  	require.NoError(t, err)
   252  
   253  	chk, err := UnmarshalVersionedCheckpointToLatestCheckpoint(encoding.JSON, bytes)
   254  	require.NoError(t, err)
   255  
   256  	var prov mapTestSecretsProvider
   257  
   258  	_, err = DeserializeDeploymentV3(ctx, *chk.Latest, &prov)
   259  	require.NoError(t, err)
   260  
   261  	d := prov.m.d
   262  	assert.Equal(t, 1, d.bulkDecryptCalls)
   263  	assert.Equal(t, 0, d.decryptCalls)
   264  }