github.com/yorinasub17/go-cloud@v0.27.40/secrets/awskms/kms_test.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package awskms
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"net/url"
    22  	"os"
    23  	"testing"
    24  
    25  	kmsv2 "github.com/aws/aws-sdk-go-v2/service/kms"
    26  	"github.com/aws/aws-sdk-go/aws/awserr"
    27  	"github.com/aws/aws-sdk-go/aws/session"
    28  	"github.com/aws/aws-sdk-go/service/kms"
    29  	"github.com/aws/smithy-go"
    30  	"github.com/google/go-cmp/cmp"
    31  	"gocloud.dev/internal/testing/setup"
    32  	"gocloud.dev/secrets"
    33  	"gocloud.dev/secrets/driver"
    34  	"gocloud.dev/secrets/drivertest"
    35  )
    36  
    37  const (
    38  	keyID1 = "alias/test-secrets"
    39  	keyID2 = "alias/test-secrets2"
    40  	region = "us-east-2"
    41  )
    42  
    43  type harness struct {
    44  	useV2    bool
    45  	client   *kms.KMS
    46  	clientV2 *kmsv2.Client
    47  	close    func()
    48  }
    49  
    50  func (h *harness) MakeDriver(ctx context.Context) (driver.Keeper, driver.Keeper, error) {
    51  	return &keeper{useV2: h.useV2, keyID: keyID1, client: h.client, clientV2: h.clientV2}, &keeper{useV2: h.useV2, keyID: keyID2, client: h.client, clientV2: h.clientV2}, nil
    52  }
    53  
    54  func (h *harness) Close() {
    55  	h.close()
    56  }
    57  
    58  func newHarness(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
    59  	sess, _, done, _ := setup.NewAWSSession(ctx, t, region)
    60  	return &harness{
    61  		useV2:  false,
    62  		client: kms.New(sess),
    63  		close:  done,
    64  	}, nil
    65  }
    66  
    67  func newHarnessV2(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
    68  	cfg, _, done, _ := setup.NewAWSv2Config(ctx, t, region)
    69  	return &harness{
    70  		useV2:    true,
    71  		clientV2: kmsv2.NewFromConfig(cfg),
    72  		close:    done,
    73  	}, nil
    74  }
    75  
    76  func TestConformance(t *testing.T) {
    77  	drivertest.RunConformanceTests(t, newHarness, []drivertest.AsTest{verifyAs{v2: false}})
    78  }
    79  
    80  func TestConformanceV2(t *testing.T) {
    81  	drivertest.RunConformanceTests(t, newHarnessV2, []drivertest.AsTest{verifyAs{v2: true}})
    82  }
    83  
    84  type verifyAs struct {
    85  	v2 bool
    86  }
    87  
    88  func (v verifyAs) Name() string {
    89  	return "verify As function"
    90  }
    91  
    92  func (v verifyAs) ErrorCheck(k *secrets.Keeper, err error) error {
    93  	var code string
    94  	if v.v2 {
    95  		var e smithy.APIError
    96  		if !k.ErrorAs(err, &e) {
    97  			return errors.New("Keeper.ErrorAs failed")
    98  		}
    99  		code = e.ErrorCode()
   100  	} else {
   101  		var e awserr.Error
   102  		if !k.ErrorAs(err, &e) {
   103  			return errors.New("Keeper.ErrorAs failed")
   104  		}
   105  		code = e.Code()
   106  	}
   107  	if code != kms.ErrCodeInvalidCiphertextException {
   108  		return fmt.Errorf("got %q, want %q", code, kms.ErrCodeInvalidCiphertextException)
   109  	}
   110  	return nil
   111  }
   112  
   113  // KMS-specific tests.
   114  
   115  func TestNoSessionProvidedError(t *testing.T) {
   116  	if _, err := Dial(nil); err == nil {
   117  		t.Error("got nil, want no AWS session provided")
   118  	}
   119  }
   120  
   121  func TestNoConnectionError(t *testing.T) {
   122  	prevAccessKey := os.Getenv("AWS_ACCESS_KEY")
   123  	prevSecretKey := os.Getenv("AWS_SECRET_KEY")
   124  	prevRegion := os.Getenv("AWS_REGION")
   125  	os.Setenv("AWS_ACCESS_KEY", "myaccesskey")
   126  	os.Setenv("AWS_SECRET_KEY", "mysecretkey")
   127  	os.Setenv("AWS_REGION", "us-east-1")
   128  	defer func() {
   129  		os.Setenv("AWS_ACCESS_KEY", prevAccessKey)
   130  		os.Setenv("AWS_SECRET_KEY", prevSecretKey)
   131  		os.Setenv("AWS_REGION", prevRegion)
   132  	}()
   133  	sess, err := session.NewSession()
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  
   138  	client, err := Dial(sess)
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  	keeper := OpenKeeper(client, keyID1, nil)
   143  	defer keeper.Close()
   144  
   145  	if _, err := keeper.Encrypt(context.Background(), []byte("test")); err == nil {
   146  		t.Error("got nil, want UnrecognizedClientException")
   147  	}
   148  }
   149  
   150  func TestEncryptionContext(t *testing.T) {
   151  	tests := []struct {
   152  		Existing map[string]string
   153  		URL      string
   154  		WantErr  bool
   155  		Want     map[string]string
   156  	}{
   157  		// None before or after.
   158  		{nil, "http://foo", false, nil},
   159  		// New parameter.
   160  		{nil, "http://foo?context_foo=bar", false, map[string]string{"foo": "bar"}},
   161  		// 2 new parameters.
   162  		{nil, "http://foo?context_foo=bar&context_abc=baz", false, map[string]string{"foo": "bar", "abc": "baz"}},
   163  		// Multiple values.
   164  		{nil, "http://foo?context_foo=bar&context_foo=baz", true, nil},
   165  		// Existing, no new.
   166  		{map[string]string{"foo": "bar"}, "http://foo", false, map[string]string{"foo": "bar"}},
   167  		// No-conflict merge.
   168  		{map[string]string{"foo": "bar"}, "http://foo?context_abc=baz", false, map[string]string{"foo": "bar", "abc": "baz"}},
   169  		// Overwrite merge.
   170  		{map[string]string{"foo": "bar"}, "http://foo?context_foo=baz", false, map[string]string{"foo": "baz"}},
   171  	}
   172  	for _, test := range tests {
   173  		t.Run(fmt.Sprintf("existing %v URL %v", test.Existing, test.URL), func(t *testing.T) {
   174  			opts := KeeperOptions{
   175  				EncryptionContext: test.Existing,
   176  			}
   177  			u, err := url.Parse(test.URL)
   178  			if err != nil {
   179  				t.Fatal(err)
   180  			}
   181  			err = addEncryptionContextFromURLParams(&opts, u.Query())
   182  			if (err != nil) != test.WantErr {
   183  				t.Fatalf("got err %v, want error? %v", err, test.WantErr)
   184  			}
   185  			if diff := cmp.Diff(opts.EncryptionContext, test.Want); diff != "" {
   186  				t.Errorf("diff %v", diff)
   187  			}
   188  		})
   189  	}
   190  }
   191  
   192  func TestOpenKeeper(t *testing.T) {
   193  	tests := []struct {
   194  		URL     string
   195  		WantErr bool
   196  	}{
   197  		// OK, by alias.
   198  		{"awskms://alias/my-key", false},
   199  		// OK, by ARN with empty Host.
   200  		{"awskms:///arn:aws:kms:us-east-1:932528106278:alias/gocloud-test", false},
   201  		// OK, by ARN with empty Host.
   202  		{"awskms:///arn:aws:kms:us-east-1:932528106278:key/8be0dcc5-da0a-4164-a99f-649015e344b5", false},
   203  		// OK, overriding region.
   204  		{"awskms://alias/my-key?region=us-west1", false},
   205  		// OK, using V1.
   206  		{"awskms://alias/my-key?awssdk=v1", false},
   207  		// OK, using V2.
   208  		{"awskms://alias/my-key?awssdk=v2", false},
   209  		// OK, adding EncryptionContext.
   210  		{"awskms://alias/my-key?context_abc=foo&context_def=bar", false},
   211  		// Multiple values for an EncryptionContext.
   212  		{"awskms://alias/my-key?context_abc=foo&context_abc=bar", true},
   213  		// Unknown parameter.
   214  		{"awskms://alias/my-key?param=value", true},
   215  	}
   216  
   217  	ctx := context.Background()
   218  	for _, test := range tests {
   219  		keeper, err := secrets.OpenKeeper(ctx, test.URL)
   220  		if (err != nil) != test.WantErr {
   221  			t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
   222  		}
   223  		if err == nil {
   224  			if err = keeper.Close(); err != nil {
   225  				t.Errorf("%s: got error during close: %v", test.URL, err)
   226  			}
   227  		}
   228  	}
   229  }