github.com/thiagoyeds/go-cloud@v0.26.0/runtimevar/awssecretsmanager/awssecretsmanager.go (about)

     1  // Copyright 2020 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package awssecretsmanager provides a runtimevar implementation with variables
// read from AWS Secrets Manager (https://aws.amazon.com/secrets-manager).
// Use OpenVariable to construct a *runtimevar.Variable.
    18  //
    19  // URLs
    20  //
    21  // For runtimevar.OpenVariable, awssecretsmanager registers for the scheme "awssecretsmanager".
    22  // The default URL opener will use an AWS session with the default credentials
    23  // and configuration; see https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
    24  // for more details.
    25  // To customize the URL opener, or for more details on the URL format,
    26  // see URLOpener.
    27  // See https://gocloud.dev/concepts/urls/ for background information.
    28  //
    29  // As
    30  //
    31  // awssecretsmanager exposes the following types for As:
    32  //  - Snapshot: (V1) *secretsmanager.GetSecretValueOutput, *secretsmanager.DescribeSecretOutput, (V2) *secretsmanagerv2.GetSecretValueOutput, *secretsmanagerv2.DescribeSecretOutput
    33  //  - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError
    34  package awssecretsmanager // import "gocloud.dev/runtimevar/awssecretsmanager"
    35  
    36  import (
    37  	"context"
    38  	"errors"
    39  	"fmt"
    40  	"net/url"
    41  	"path"
    42  	"strings"
    43  	"sync"
    44  	"time"
    45  
    46  	awsv2 "github.com/aws/aws-sdk-go-v2/aws"
    47  	secretsmanagerv2 "github.com/aws/aws-sdk-go-v2/service/secretsmanager"
    48  	"github.com/aws/aws-sdk-go/aws"
    49  	"github.com/aws/aws-sdk-go/aws/awserr"
    50  	"github.com/aws/aws-sdk-go/aws/client"
    51  	"github.com/aws/aws-sdk-go/aws/request"
    52  	"github.com/aws/aws-sdk-go/service/secretsmanager"
    53  	"github.com/aws/smithy-go"
    54  	"github.com/google/wire"
    55  	gcaws "gocloud.dev/aws"
    56  	"gocloud.dev/gcerrors"
    57  	"gocloud.dev/runtimevar"
    58  	"gocloud.dev/runtimevar/driver"
    59  )
    60  
    61  func init() {
    62  	runtimevar.DefaultURLMux().RegisterVariable(Scheme, new(lazySessionOpener))
    63  }
    64  
    65  // Set holds Wire providers for this package.
    66  var Set = wire.NewSet(
    67  	wire.Struct(new(URLOpener), "ConfigProvider"),
    68  )
    69  
    70  // URLOpener opens AWS Secrets Manager URLs like "awssecretsmanager://my-secret-var-name".
    71  // A friendly name of the secret must be specified. You can NOT specify the Amazon Resource Name (ARN).
    72  //
    73  // Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2,
    74  // or anything else to accept the default.
    75  //
    76  // For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters
    77  // for overriding the aws.Session from the URL.
    78  // For V2, see gocloud.dev/aws/V2ConfigFromURLParams.
    79  //
    80  // In addition, the following URL parameters are supported:
    81  //   - decoder: The decoder to use. Defaults to URLOpener.Decoder, or
    82  //       runtimevar.BytesDecoder if URLOpener.Decoder is nil.
    83  //       See runtimevar.DecoderByName for supported values.
    84  type URLOpener struct {
    85  	// UseV2 indicates whether the AWS SDK V2 should be used.
    86  	UseV2 bool
    87  
    88  	// ConfigProvider must be set to a non-nil value if UseV2 is false.
    89  	ConfigProvider client.ConfigProvider
    90  
    91  	// Decoder specifies the decoder to use if one is not specified in the URL.
    92  	// Defaults to runtimevar.BytesDecoder.
    93  	Decoder *runtimevar.Decoder
    94  
    95  	// Options specifies the options to pass to New.
    96  	Options Options
    97  }
    98  
    99  // lazySessionOpener obtains the AWS session from the environment on the first
   100  // call to OpenVariableURL.
   101  type lazySessionOpener struct {
   102  	init   sync.Once
   103  	opener *URLOpener
   104  	err    error
   105  }
   106  
   107  func (o *lazySessionOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
   108  	if gcaws.UseV2(u.Query()) {
   109  		opener := &URLOpener{UseV2: true}
   110  		return opener.OpenVariableURL(ctx, u)
   111  	}
   112  	o.init.Do(func() {
   113  		sess, err := gcaws.NewDefaultSession()
   114  		if err != nil {
   115  			o.err = err
   116  			return
   117  		}
   118  		o.opener = &URLOpener{
   119  			ConfigProvider: sess,
   120  		}
   121  	})
   122  	if o.err != nil {
   123  		return nil, fmt.Errorf("open variable %v: %v", u, o.err)
   124  	}
   125  	return o.opener.OpenVariableURL(ctx, u)
   126  }
   127  
   128  // Scheme is the URL scheme awssecretsmanager registers its URLOpener under on runtimevar.DefaultMux.
   129  const Scheme = "awssecretsmanager"
   130  
   131  // OpenVariableURL opens the variable at the URL's path. See the package doc
   132  // for more details.
   133  func (o *URLOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
   134  	q := u.Query()
   135  
   136  	decoderName := q.Get("decoder")
   137  	q.Del("decoder")
   138  	decoder, err := runtimevar.DecoderByName(ctx, decoderName, o.Decoder)
   139  	if err != nil {
   140  		return nil, fmt.Errorf("open variable %v: invalid decoder: %v", u, err)
   141  	}
   142  
   143  	if o.UseV2 {
   144  		cfg, err := gcaws.V2ConfigFromURLParams(ctx, q)
   145  		if err != nil {
   146  			return nil, fmt.Errorf("open variable %v: %v", u, err)
   147  		}
   148  		return OpenVariableV2(secretsmanagerv2.NewFromConfig(cfg), path.Join(u.Host, u.Path), decoder, &o.Options)
   149  	}
   150  	configProvider := &gcaws.ConfigOverrider{
   151  		Base: o.ConfigProvider,
   152  	}
   153  	overrideCfg, err := gcaws.ConfigFromURLParams(q)
   154  	if err != nil {
   155  		return nil, fmt.Errorf("open variable %v: %v", u, err)
   156  	}
   157  
   158  	configProvider.Configs = append(configProvider.Configs, overrideCfg)
   159  
   160  	return OpenVariable(configProvider, path.Join(u.Host, u.Path), decoder, &o.Options)
   161  }
   162  
   163  // Options sets options.
   164  type Options struct {
   165  	// WaitDuration controls the rate at which AWS Secrets Manager is polled.
   166  	// Defaults to 30 seconds.
   167  	WaitDuration time.Duration
   168  }
   169  
   170  // OpenVariable constructs a *runtimevar.Variable backed by the variable name in AWS Secrets Manager.
   171  // A friendly name of the secret must be specified. You can NOT specify the Amazon Resource Name (ARN).
   172  // Secrets Manager returns raw bytes; provide a decoder to decode the raw bytes
   173  // into the appropriate type for runtimevar.Snapshot.Value.
   174  // See the runtimevar package documentation for examples of decoders.
   175  func OpenVariable(sess client.ConfigProvider, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) {
   176  	return runtimevar.New(newWatcher(false, sess, nil, name, decoder, opts)), nil
   177  }
   178  
   179  // OpenVariableV2 constructs a *runtimevar.Variable backed by the variable name in AWS Secrets Manager,
   180  // using AWS SDK V2.
   181  // A friendly name of the secret must be specified. You can NOT specify the Amazon Resource Name (ARN).
   182  // Secrets Manager returns raw bytes; provide a decoder to decode the raw bytes
   183  // into the appropriate type for runtimevar.Snapshot.Value.
   184  // See the runtimevar package documentation for examples of decoders.
   185  func OpenVariableV2(client *secretsmanagerv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) {
   186  	return runtimevar.New(newWatcher(true, nil, client, name, decoder, opts)), nil
   187  }
   188  
   189  // state implements driver.State.
   190  type state struct {
   191  	val        interface{}
   192  	rawGetV1   *secretsmanager.GetSecretValueOutput
   193  	rawGetV2   *secretsmanagerv2.GetSecretValueOutput
   194  	rawDescV1  *secretsmanager.DescribeSecretOutput
   195  	rawDescV2  *secretsmanagerv2.DescribeSecretOutput
   196  	updateTime time.Time
   197  	versionID  string
   198  	err        error
   199  }
   200  
   201  // Value implements driver.State.Value.
   202  func (s *state) Value() (interface{}, error) {
   203  	return s.val, s.err
   204  }
   205  
   206  // UpdateTime implements driver.State.UpdateTime.
   207  func (s *state) UpdateTime() time.Time {
   208  	return s.updateTime
   209  }
   210  
   211  // As implements driver.State.As.
   212  func (s *state) As(i interface{}) bool {
   213  	switch p := i.(type) {
   214  	case **secretsmanager.GetSecretValueOutput:
   215  		*p = s.rawGetV1
   216  	case **secretsmanagerv2.GetSecretValueOutput:
   217  		*p = s.rawGetV2
   218  	case **secretsmanager.DescribeSecretOutput:
   219  		*p = s.rawDescV1
   220  	case **secretsmanagerv2.DescribeSecretOutput:
   221  		*p = s.rawDescV2
   222  	default:
   223  		return false
   224  	}
   225  	return true
   226  }
   227  
   228  // errorState returns a new State with err, unless prevS also represents
   229  // the same error, in which case it returns nil.
   230  func errorState(err error, prevS driver.State) driver.State {
   231  	// Map to the more standard context package error.
   232  	if strings.Contains(err.Error(), "context deadline exceeded") {
   233  		err = context.DeadlineExceeded
   234  	} else if getErrorCode(err) == request.CanceledErrorCode {
   235  		err = context.Canceled
   236  	}
   237  	s := &state{err: err}
   238  	if prevS == nil {
   239  		return s
   240  	}
   241  	prev := prevS.(*state)
   242  	if prev.err == nil {
   243  		// New error.
   244  		return s
   245  	}
   246  	if equivalentError(err, prev.err) {
   247  		// Same error, return nil to indicate no change.
   248  		return nil
   249  	}
   250  	return s
   251  }
   252  
   253  // equivalentError returns true iff err1 and err2 represent an equivalent error;
   254  // i.e., we don't want to return it to the user as a different error.
   255  func equivalentError(err1, err2 error) bool {
   256  	if err1 == err2 || err1.Error() == err2.Error() {
   257  		return true
   258  	}
   259  	code1 := getErrorCode(err1)
   260  	code2 := getErrorCode(err2)
   261  	return code1 != "" && code1 == code2
   262  }
   263  
   264  type watcher struct {
   265  	// useV2 indicates whether we're using clientV2.
   266  	useV2 bool
   267  	// sess is the AWS session to use to talk to AWS.
   268  	sess client.ConfigProvider
   269  	// clientV2 is the client to use when useV2 is true.
   270  	clientV2 *secretsmanagerv2.Client
   271  	// name is an ID of a secret to retrieve.
   272  	name string
   273  	// wait is the amount of time to wait between querying AWS.
   274  	wait time.Duration
   275  	// decoder is the decoder that unmarshalls the value in the param.
   276  	decoder *runtimevar.Decoder
   277  }
   278  
   279  func newWatcher(useV2 bool, sess client.ConfigProvider, clientV2 *secretsmanagerv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) *watcher {
   280  	if opts == nil {
   281  		opts = &Options{}
   282  	}
   283  	return &watcher{
   284  		useV2:    useV2,
   285  		sess:     sess,
   286  		clientV2: clientV2,
   287  		name:     name,
   288  		wait:     driver.WaitDuration(opts.WaitDuration),
   289  		decoder:  decoder,
   290  	}
   291  }
   292  
   293  func getSecretValue(ctx context.Context, svc *secretsmanager.SecretsManager, secretID string) (string, []byte, string, *secretsmanager.GetSecretValueOutput, error) {
   294  	getResp, err := svc.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{
   295  		SecretId: aws.String(secretID),
   296  	})
   297  	if err != nil {
   298  		return "", nil, "", nil, err
   299  	}
   300  	return aws.StringValue(getResp.VersionId), getResp.SecretBinary, aws.StringValue(getResp.SecretString), getResp, nil
   301  }
   302  
   303  func getSecretValueV2(ctx context.Context, client *secretsmanagerv2.Client, secretID string) (string, []byte, string, *secretsmanagerv2.GetSecretValueOutput, error) {
   304  	getResp, err := client.GetSecretValue(ctx, &secretsmanagerv2.GetSecretValueInput{
   305  		SecretId: awsv2.String(secretID),
   306  	})
   307  	if err != nil {
   308  		return "", nil, "", nil, err
   309  	}
   310  	return awsv2.ToString(getResp.VersionId), getResp.SecretBinary, awsv2.ToString(getResp.SecretString), getResp, nil
   311  }
   312  
   313  func describeSecret(ctx context.Context, svc *secretsmanager.SecretsManager, secretID string) (time.Time, *secretsmanager.DescribeSecretOutput, error) {
   314  	descResp, err := svc.DescribeSecretWithContext(ctx, &secretsmanager.DescribeSecretInput{
   315  		SecretId: aws.String(secretID),
   316  	})
   317  	if err != nil {
   318  		return time.Time{}, nil, err
   319  	}
   320  	return aws.TimeValue(descResp.LastChangedDate), descResp, nil
   321  }
   322  
   323  func describeSecretV2(ctx context.Context, client *secretsmanagerv2.Client, secretID string) (time.Time, *secretsmanagerv2.DescribeSecretOutput, error) {
   324  	descResp, err := client.DescribeSecret(ctx, &secretsmanagerv2.DescribeSecretInput{
   325  		SecretId: awsv2.String(secretID),
   326  	})
   327  	if err != nil {
   328  		return time.Time{}, nil, err
   329  	}
   330  	return aws.TimeValue(descResp.LastChangedDate), descResp, nil
   331  }
   332  
   333  // WatchVariable implements driver.WatchVariable.
   334  func (w *watcher) WatchVariable(ctx context.Context, prev driver.State) (driver.State, time.Duration) {
   335  	var lastVersionID string
   336  	if prev != nil {
   337  		lastVersionID = prev.(*state).versionID
   338  	}
   339  	var svc *secretsmanager.SecretsManager
   340  	if !w.useV2 {
   341  		svc = secretsmanager.New(w.sess)
   342  	}
   343  
   344  	// GetParameter from S3 to get the current value and version.
   345  	var newVersionID string
   346  	var newValBinary []byte
   347  	var newValString string
   348  	var rawGetV1 *secretsmanager.GetSecretValueOutput
   349  	var rawGetV2 *secretsmanagerv2.GetSecretValueOutput
   350  	var err error
   351  	if w.useV2 {
   352  		newVersionID, newValBinary, newValString, rawGetV2, err = getSecretValueV2(ctx, w.clientV2, w.name)
   353  	} else {
   354  		newVersionID, newValBinary, newValString, rawGetV1, err = getSecretValue(ctx, svc, w.name)
   355  	}
   356  	if err != nil {
   357  		return errorState(err, prev), w.wait
   358  	}
   359  	if newVersionID == lastVersionID {
   360  		// Version hasn't changed, so no change; return nil.
   361  		return nil, w.wait
   362  	}
   363  	// Both SecretBinary and SecretString fields are not empty
   364  	// which could indicate some internal Secrets Manager issues.
   365  	// Hence, return explicit error instead of choosing one field over another.
   366  	if len(newValBinary) > 0 && newValString != "" {
   367  		err = fmt.Errorf("invalid %q response: both SecretBinary and SecretString are not empty", w.name)
   368  		return errorState(err, prev), w.wait
   369  	}
   370  
   371  	data := newValBinary
   372  	if len(data) == 0 {
   373  		if newValString == "" {
   374  			err = fmt.Errorf("invalid %q response: both SecretBinary and SecretString are empty", w.name)
   375  			return errorState(err, prev), w.wait
   376  		}
   377  		// SecretBinary is empty so use SecretString
   378  		data = []byte(newValString)
   379  	}
   380  
   381  	// DescribeParameters from S3 to get the LastModified date.
   382  	var newLastModified time.Time
   383  	var rawDescV1 *secretsmanager.DescribeSecretOutput
   384  	var rawDescV2 *secretsmanagerv2.DescribeSecretOutput
   385  	if w.useV2 {
   386  		newLastModified, rawDescV2, err = describeSecretV2(ctx, w.clientV2, w.name)
   387  	} else {
   388  		newLastModified, rawDescV1, err = describeSecret(ctx, svc, w.name)
   389  	}
   390  	if err != nil {
   391  		return errorState(err, prev), w.wait
   392  	}
   393  
   394  	// New value (or at least, new version). Decode it.
   395  	val, err := w.decoder.Decode(ctx, data)
   396  	if err != nil {
   397  		return errorState(err, prev), w.wait
   398  	}
   399  
   400  	return &state{
   401  		val:        val,
   402  		rawGetV1:   rawGetV1,
   403  		rawGetV2:   rawGetV2,
   404  		rawDescV1:  rawDescV1,
   405  		rawDescV2:  rawDescV2,
   406  		updateTime: newLastModified,
   407  		versionID:  newVersionID,
   408  	}, w.wait
   409  }
   410  
   411  // Close implements driver.Close.
   412  func (w *watcher) Close() error {
   413  	return nil
   414  }
   415  
   416  // ErrorAs implements driver.ErrorAs.
   417  func (w *watcher) ErrorAs(err error, i interface{}) bool {
   418  	if w.useV2 {
   419  		return errors.As(err, i)
   420  	}
   421  	switch v := err.(type) {
   422  	case awserr.Error:
   423  		if p, ok := i.(*awserr.Error); ok {
   424  			*p = v
   425  			return true
   426  		}
   427  	}
   428  	return false
   429  }
   430  
   431  func getErrorCode(err error) string {
   432  	if awsErr, ok := err.(awserr.Error); ok {
   433  		return awsErr.Code()
   434  	}
   435  	var ae smithy.APIError
   436  	if errors.As(err, &ae) {
   437  		return ae.ErrorCode()
   438  	}
   439  	return ""
   440  }
   441  
   442  // ErrorCode implements driver.ErrorCode.
   443  func (w *watcher) ErrorCode(err error) gcerrors.ErrorCode {
   444  	code := getErrorCode(err)
   445  	switch code {
   446  	case secretsmanager.ErrCodeResourceNotFoundException:
   447  		return gcerrors.NotFound
   448  
   449  	case secretsmanager.ErrCodeInvalidParameterException,
   450  		secretsmanager.ErrCodeInvalidRequestException,
   451  		secretsmanager.ErrCodeInvalidNextTokenException:
   452  		return gcerrors.InvalidArgument
   453  
   454  	case secretsmanager.ErrCodeEncryptionFailure,
   455  		secretsmanager.ErrCodeDecryptionFailure,
   456  		secretsmanager.ErrCodeInternalServiceError:
   457  		return gcerrors.Internal
   458  
   459  	case secretsmanager.ErrCodeResourceExistsException:
   460  		return gcerrors.AlreadyExists
   461  
   462  	case secretsmanager.ErrCodePreconditionNotMetException,
   463  		secretsmanager.ErrCodeMalformedPolicyDocumentException:
   464  		return gcerrors.FailedPrecondition
   465  
   466  	case secretsmanager.ErrCodeLimitExceededException:
   467  		return gcerrors.ResourceExhausted
   468  	}
   469  	return gcerrors.Unknown
   470  }