github.com/thiagoyeds/go-cloud@v0.26.0/secrets/awskms/kms_test.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package awskms
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"os"
    22  	"testing"
    23  
    24  	kmsv2 "github.com/aws/aws-sdk-go-v2/service/kms"
    25  	"github.com/aws/aws-sdk-go/aws/awserr"
    26  	"github.com/aws/aws-sdk-go/aws/session"
    27  	"github.com/aws/aws-sdk-go/service/kms"
    28  	"github.com/aws/smithy-go"
    29  	"gocloud.dev/internal/testing/setup"
    30  	"gocloud.dev/secrets"
    31  	"gocloud.dev/secrets/driver"
    32  	"gocloud.dev/secrets/drivertest"
    33  )
    34  
    35  const (
    36  	keyID1 = "alias/test-secrets"
    37  	keyID2 = "alias/test-secrets2"
    38  	region = "us-east-2"
    39  )
    40  
    41  type harness struct {
    42  	useV2    bool
    43  	client   *kms.KMS
    44  	clientV2 *kmsv2.Client
    45  	close    func()
    46  }
    47  
    48  func (h *harness) MakeDriver(ctx context.Context) (driver.Keeper, driver.Keeper, error) {
    49  	return &keeper{useV2: h.useV2, keyID: keyID1, client: h.client, clientV2: h.clientV2}, &keeper{useV2: h.useV2, keyID: keyID2, client: h.client, clientV2: h.clientV2}, nil
    50  }
    51  
    52  func (h *harness) Close() {
    53  	h.close()
    54  }
    55  
    56  func newHarness(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
    57  	sess, _, done, _ := setup.NewAWSSession(ctx, t, region)
    58  	return &harness{
    59  		useV2:  false,
    60  		client: kms.New(sess),
    61  		close:  done,
    62  	}, nil
    63  }
    64  
    65  func newHarnessV2(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
    66  	cfg, _, done, _ := setup.NewAWSv2Config(ctx, t, region)
    67  	return &harness{
    68  		useV2:    true,
    69  		clientV2: kmsv2.NewFromConfig(cfg),
    70  		close:    done,
    71  	}, nil
    72  }
    73  
    74  func TestConformance(t *testing.T) {
    75  	drivertest.RunConformanceTests(t, newHarness, []drivertest.AsTest{verifyAs{v2: false}})
    76  }
    77  
    78  func TestConformanceV2(t *testing.T) {
    79  	drivertest.RunConformanceTests(t, newHarnessV2, []drivertest.AsTest{verifyAs{v2: true}})
    80  }
    81  
    82  type verifyAs struct {
    83  	v2 bool
    84  }
    85  
    86  func (v verifyAs) Name() string {
    87  	return "verify As function"
    88  }
    89  
    90  func (v verifyAs) ErrorCheck(k *secrets.Keeper, err error) error {
    91  	var code string
    92  	if v.v2 {
    93  		var e smithy.APIError
    94  		if !k.ErrorAs(err, &e) {
    95  			return errors.New("Keeper.ErrorAs failed")
    96  		}
    97  		code = e.ErrorCode()
    98  	} else {
    99  		var e awserr.Error
   100  		if !k.ErrorAs(err, &e) {
   101  			return errors.New("Keeper.ErrorAs failed")
   102  		}
   103  		code = e.Code()
   104  	}
   105  	if code != kms.ErrCodeInvalidCiphertextException {
   106  		return fmt.Errorf("got %q, want %q", code, kms.ErrCodeInvalidCiphertextException)
   107  	}
   108  	return nil
   109  }
   110  
   111  // KMS-specific tests.
   112  
   113  func TestNoSessionProvidedError(t *testing.T) {
   114  	if _, err := Dial(nil); err == nil {
   115  		t.Error("got nil, want no AWS session provided")
   116  	}
   117  }
   118  
   119  func TestNoConnectionError(t *testing.T) {
   120  	prevAccessKey := os.Getenv("AWS_ACCESS_KEY")
   121  	prevSecretKey := os.Getenv("AWS_SECRET_KEY")
   122  	prevRegion := os.Getenv("AWS_REGION")
   123  	os.Setenv("AWS_ACCESS_KEY", "myaccesskey")
   124  	os.Setenv("AWS_SECRET_KEY", "mysecretkey")
   125  	os.Setenv("AWS_REGION", "us-east-1")
   126  	defer func() {
   127  		os.Setenv("AWS_ACCESS_KEY", prevAccessKey)
   128  		os.Setenv("AWS_SECRET_KEY", prevSecretKey)
   129  		os.Setenv("AWS_REGION", prevRegion)
   130  	}()
   131  	sess, err := session.NewSession()
   132  	if err != nil {
   133  		t.Fatal(err)
   134  	}
   135  
   136  	client, err := Dial(sess)
   137  	if err != nil {
   138  		t.Fatal(err)
   139  	}
   140  	keeper := OpenKeeper(client, keyID1, nil)
   141  	defer keeper.Close()
   142  
   143  	if _, err := keeper.Encrypt(context.Background(), []byte("test")); err == nil {
   144  		t.Error("got nil, want UnrecognizedClientException")
   145  	}
   146  }
   147  
   148  func TestOpenKeeper(t *testing.T) {
   149  	tests := []struct {
   150  		URL     string
   151  		WantErr bool
   152  	}{
   153  		// OK, by alias.
   154  		{"awskms://alias/my-key", false},
   155  		// OK, by ARN with empty Host.
   156  		{"awskms:///arn:aws:kms:us-east-1:932528106278:alias/gocloud-test", false},
   157  		// OK, by ARN with empty Host.
   158  		{"awskms:///arn:aws:kms:us-east-1:932528106278:key/8be0dcc5-da0a-4164-a99f-649015e344b5", false},
   159  		// OK, overriding region.
   160  		{"awskms://alias/my-key?region=us-west1", false},
   161  		// OK, using V1.
   162  		{"awskms://alias/my-key?awssdk=v1", false},
   163  		// OK, using V2.
   164  		{"awskms://alias/my-key?awssdk=v2", false},
   165  		// Unknown parameter.
   166  		{"awskms://alias/my-key?param=value", true},
   167  	}
   168  
   169  	ctx := context.Background()
   170  	for _, test := range tests {
   171  		keeper, err := secrets.OpenKeeper(ctx, test.URL)
   172  		if (err != nil) != test.WantErr {
   173  			t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
   174  		}
   175  		if err == nil {
   176  			if err = keeper.Close(); err != nil {
   177  				t.Errorf("%s: got error during close: %v", test.URL, err)
   178  			}
   179  		}
   180  	}
   181  }