github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/runtimevar/awssecretsmanager/awssecretsmanager.go (about)

     1  // Copyright 2020 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awssecretsmanager provides a runtimevar implementation with variables
    16  // read from AWS Secrets Manager (https://aws.amazon.com/secrets-manager)
    17  // Use OpenVariable to construct a *runtimevar.Variable.
    18  //
    19  // # URLs
    20  //
    21  // For runtimevar.OpenVariable, awssecretsmanager registers for the scheme "awssecretsmanager".
    22  // The default URL opener will use an AWS session with the default credentials
    23  // and configuration; see https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
    24  // for more details.
    25  // To customize the URL opener, or for more details on the URL format,
    26  // see URLOpener.
    27  // See https://gocloud.dev/concepts/urls/ for background information.
    28  //
    29  // # As
    30  //
    31  // awssecretsmanager exposes the following types for As:
    32  //   - Snapshot: (V1) *secretsmanager.GetSecretValueOutput, *secretsmanager.DescribeSecretOutput, (V2) *secretsmanagerv2.GetSecretValueOutput, *secretsmanagerv2.DescribeSecretOutput
    33  //   - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError
    34  package awssecretsmanager // import "gocloud.dev/runtimevar/awssecretsmanager"
    35  
    36  import (
    37  	"context"
    38  	"errors"
    39  	"fmt"
    40  	"net/url"
    41  	"path"
    42  	"strings"
    43  	"sync"
    44  	"time"
    45  
    46  	awsv2 "github.com/aws/aws-sdk-go-v2/aws"
    47  	secretsmanagerv2 "github.com/aws/aws-sdk-go-v2/service/secretsmanager"
    48  	"github.com/aws/aws-sdk-go/aws"
    49  	"github.com/aws/aws-sdk-go/aws/awserr"
    50  	"github.com/aws/aws-sdk-go/aws/client"
    51  	"github.com/aws/aws-sdk-go/aws/request"
    52  	"github.com/aws/aws-sdk-go/service/secretsmanager"
    53  	"github.com/aws/smithy-go"
    54  	"github.com/google/wire"
    55  	gcaws "gocloud.dev/aws"
    56  	"gocloud.dev/gcerrors"
    57  	"gocloud.dev/runtimevar"
    58  	"gocloud.dev/runtimevar/driver"
    59  )
    60  
    61  func init() {
    62  	runtimevar.DefaultURLMux().RegisterVariable(Scheme, new(lazySessionOpener))
    63  }
    64  
// Set holds Wire providers for this package.
// It provides a URLOpener struct whose ConfigProvider field is injected by
// Wire.
var Set = wire.NewSet(
	wire.Struct(new(URLOpener), "ConfigProvider"),
)
    69  
// URLOpener opens AWS Secrets Manager URLs like "awssecretsmanager://my-secret-var-name".
// A friendly name of the secret must be specified. You can NOT specify the Amazon Resource Name (ARN).
//
// Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2,
// or anything else to accept the default.
//
// For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters
// for overriding the aws.Session from the URL.
// For V2, see gocloud.dev/aws/V2ConfigFromURLParams.
//
// In addition, the following URL parameters are supported:
//   - decoder: The decoder to use. Defaults to URLOpener.Decoder, or
//     runtimevar.BytesDecoder if URLOpener.Decoder is nil.
//     See runtimevar.DecoderByName for supported values.
//   - wait: The poll interval, in time.ParseDuration formats.
//     Defaults to 30s.
type URLOpener struct {
	// UseV2 indicates whether the AWS SDK V2 should be used.
	UseV2 bool

	// ConfigProvider must be set to a non-nil value if UseV2 is false.
	// It is not consulted when UseV2 is true.
	ConfigProvider client.ConfigProvider

	// Decoder specifies the decoder to use if one is not specified in the URL.
	// Defaults to runtimevar.BytesDecoder.
	Decoder *runtimevar.Decoder

	// Options specifies the options to pass to OpenVariable (V1) or
	// OpenVariableV2 (V2). A "wait" URL parameter overrides
	// Options.WaitDuration for that URL.
	Options Options
}
   100  
// lazySessionOpener obtains the AWS session from the environment on the first
// call to OpenVariableURL that needs AWS SDK V1. URLs that select SDK V2 do
// not use the session.
type lazySessionOpener struct {
	// init guards the one-time creation of the default session.
	init sync.Once
	// opener is the V1 URLOpener built from the default session; it is only
	// set when session creation succeeded.
	opener *URLOpener
	// err records a session creation failure; it is reported on every
	// subsequent V1 call.
	err error
}
   108  
   109  func (o *lazySessionOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
   110  	if gcaws.UseV2(u.Query()) {
   111  		opener := &URLOpener{UseV2: true}
   112  		return opener.OpenVariableURL(ctx, u)
   113  	}
   114  	o.init.Do(func() {
   115  		sess, err := gcaws.NewDefaultSession()
   116  		if err != nil {
   117  			o.err = err
   118  			return
   119  		}
   120  		o.opener = &URLOpener{
   121  			ConfigProvider: sess,
   122  		}
   123  	})
   124  	if o.err != nil {
   125  		return nil, fmt.Errorf("open variable %v: %v", u, o.err)
   126  	}
   127  	return o.opener.OpenVariableURL(ctx, u)
   128  }
   129  
// Scheme is the URL scheme awssecretsmanager registers its URLOpener under on
// runtimevar.DefaultURLMux.
const Scheme = "awssecretsmanager"
   132  
   133  // OpenVariableURL opens the variable at the URL's path. See the package doc
   134  // for more details.
   135  func (o *URLOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
   136  	q := u.Query()
   137  
   138  	decoderName := q.Get("decoder")
   139  	q.Del("decoder")
   140  	decoder, err := runtimevar.DecoderByName(ctx, decoderName, o.Decoder)
   141  	if err != nil {
   142  		return nil, fmt.Errorf("open variable %v: invalid decoder: %v", u, err)
   143  	}
   144  	opts := o.Options
   145  	if s := q.Get("wait"); s != "" {
   146  		q.Del("wait")
   147  		d, err := time.ParseDuration(s)
   148  		if err != nil {
   149  			return nil, fmt.Errorf("open variable %v: invalid wait %q: %v", u, s, err)
   150  		}
   151  		opts.WaitDuration = d
   152  	}
   153  	if o.UseV2 {
   154  		cfg, err := gcaws.V2ConfigFromURLParams(ctx, q)
   155  		if err != nil {
   156  			return nil, fmt.Errorf("open variable %v: %v", u, err)
   157  		}
   158  		return OpenVariableV2(secretsmanagerv2.NewFromConfig(cfg), path.Join(u.Host, u.Path), decoder, &opts)
   159  	}
   160  	configProvider := &gcaws.ConfigOverrider{
   161  		Base: o.ConfigProvider,
   162  	}
   163  	overrideCfg, err := gcaws.ConfigFromURLParams(q)
   164  	if err != nil {
   165  		return nil, fmt.Errorf("open variable %v: %v", u, err)
   166  	}
   167  
   168  	configProvider.Configs = append(configProvider.Configs, overrideCfg)
   169  
   170  	return OpenVariable(configProvider, path.Join(u.Host, u.Path), decoder, &opts)
   171  }
   172  
// Options sets options for OpenVariable and OpenVariableV2.
type Options struct {
	// WaitDuration controls the rate at which AWS Secrets Manager is polled.
	// Defaults to 30 seconds.
	WaitDuration time.Duration
}
   179  
   180  // OpenVariable constructs a *runtimevar.Variable backed by the variable name in AWS Secrets Manager.
   181  // A friendly name of the secret must be specified. You can NOT specify the Amazon Resource Name (ARN).
   182  // Secrets Manager returns raw bytes; provide a decoder to decode the raw bytes
   183  // into the appropriate type for runtimevar.Snapshot.Value.
   184  // See the runtimevar package documentation for examples of decoders.
   185  func OpenVariable(sess client.ConfigProvider, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) {
   186  	return runtimevar.New(newWatcher(false, sess, nil, name, decoder, opts)), nil
   187  }
   188  
   189  // OpenVariableV2 constructs a *runtimevar.Variable backed by the variable name in AWS Secrets Manager,
   190  // using AWS SDK V2.
   191  // A friendly name of the secret must be specified. You can NOT specify the Amazon Resource Name (ARN).
   192  // Secrets Manager returns raw bytes; provide a decoder to decode the raw bytes
   193  // into the appropriate type for runtimevar.Snapshot.Value.
   194  // See the runtimevar package documentation for examples of decoders.
   195  func OpenVariableV2(client *secretsmanagerv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) {
   196  	return runtimevar.New(newWatcher(true, nil, client, name, decoder, opts)), nil
   197  }
   198  
// state implements driver.State.
// A state either carries a decoded value together with the raw service
// responses it was built from, or carries only err.
type state struct {
	// val is the decoded secret value.
	val interface{}
	// rawGetV1 and rawGetV2 hold the raw GetSecretValue response for the SDK
	// version in use; the other one is nil. Both are exposed through As.
	rawGetV1 *secretsmanager.GetSecretValueOutput
	rawGetV2 *secretsmanagerv2.GetSecretValueOutput
	// rawDescV1 and rawDescV2 hold the raw DescribeSecret response for the
	// SDK version in use; the other one is nil. Both are exposed through As.
	rawDescV1 *secretsmanager.DescribeSecretOutput
	rawDescV2 *secretsmanagerv2.DescribeSecretOutput
	// updateTime is the secret's LastChangedDate as reported by DescribeSecret.
	updateTime time.Time
	// versionID is the secret's VersionId; WatchVariable compares it across
	// polls to detect changes.
	versionID string
	// err is non-nil for error states.
	err error
}
   210  
   211  // Value implements driver.State.Value.
   212  func (s *state) Value() (interface{}, error) {
   213  	return s.val, s.err
   214  }
   215  
   216  // UpdateTime implements driver.State.UpdateTime.
   217  func (s *state) UpdateTime() time.Time {
   218  	return s.updateTime
   219  }
   220  
   221  // As implements driver.State.As.
   222  func (s *state) As(i interface{}) bool {
   223  	switch p := i.(type) {
   224  	case **secretsmanager.GetSecretValueOutput:
   225  		*p = s.rawGetV1
   226  	case **secretsmanagerv2.GetSecretValueOutput:
   227  		*p = s.rawGetV2
   228  	case **secretsmanager.DescribeSecretOutput:
   229  		*p = s.rawDescV1
   230  	case **secretsmanagerv2.DescribeSecretOutput:
   231  		*p = s.rawDescV2
   232  	default:
   233  		return false
   234  	}
   235  	return true
   236  }
   237  
   238  // errorState returns a new State with err, unless prevS also represents
   239  // the same error, in which case it returns nil.
   240  func errorState(err error, prevS driver.State) driver.State {
   241  	// Map to the more standard context package error.
   242  	if strings.Contains(err.Error(), "context deadline exceeded") {
   243  		err = context.DeadlineExceeded
   244  	} else if getErrorCode(err) == request.CanceledErrorCode {
   245  		err = context.Canceled
   246  	}
   247  	s := &state{err: err}
   248  	if prevS == nil {
   249  		return s
   250  	}
   251  	prev := prevS.(*state)
   252  	if prev.err == nil {
   253  		// New error.
   254  		return s
   255  	}
   256  	if equivalentError(err, prev.err) {
   257  		// Same error, return nil to indicate no change.
   258  		return nil
   259  	}
   260  	return s
   261  }
   262  
   263  // equivalentError returns true iff err1 and err2 represent an equivalent error;
   264  // i.e., we don't want to return it to the user as a different error.
   265  func equivalentError(err1, err2 error) bool {
   266  	if err1 == err2 || err1.Error() == err2.Error() {
   267  		return true
   268  	}
   269  	code1 := getErrorCode(err1)
   270  	code2 := getErrorCode(err2)
   271  	return code1 != "" && code1 == code2
   272  }
   273  
// watcher polls a single secret in AWS Secrets Manager, using either the
// V1 session or the V2 client depending on useV2.
type watcher struct {
	// useV2 indicates whether we're using clientV2.
	useV2 bool
	// sess is the AWS session to use to talk to AWS when useV2 is false.
	sess client.ConfigProvider
	// clientV2 is the client to use when useV2 is true.
	clientV2 *secretsmanagerv2.Client
	// name is an ID of a secret to retrieve.
	name string
	// wait is the amount of time to wait between querying AWS.
	wait time.Duration
	// decoder is the decoder that unmarshals the raw secret bytes.
	decoder *runtimevar.Decoder
}
   288  
   289  func newWatcher(useV2 bool, sess client.ConfigProvider, clientV2 *secretsmanagerv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) *watcher {
   290  	if opts == nil {
   291  		opts = &Options{}
   292  	}
   293  	return &watcher{
   294  		useV2:    useV2,
   295  		sess:     sess,
   296  		clientV2: clientV2,
   297  		name:     name,
   298  		wait:     driver.WaitDuration(opts.WaitDuration),
   299  		decoder:  decoder,
   300  	}
   301  }
   302  
   303  func getSecretValue(ctx context.Context, svc *secretsmanager.SecretsManager, secretID string) (string, []byte, string, *secretsmanager.GetSecretValueOutput, error) {
   304  	getResp, err := svc.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{
   305  		SecretId: aws.String(secretID),
   306  	})
   307  	if err != nil {
   308  		return "", nil, "", nil, err
   309  	}
   310  	return aws.StringValue(getResp.VersionId), getResp.SecretBinary, aws.StringValue(getResp.SecretString), getResp, nil
   311  }
   312  
   313  func getSecretValueV2(ctx context.Context, client *secretsmanagerv2.Client, secretID string) (string, []byte, string, *secretsmanagerv2.GetSecretValueOutput, error) {
   314  	getResp, err := client.GetSecretValue(ctx, &secretsmanagerv2.GetSecretValueInput{
   315  		SecretId: awsv2.String(secretID),
   316  	})
   317  	if err != nil {
   318  		return "", nil, "", nil, err
   319  	}
   320  	return awsv2.ToString(getResp.VersionId), getResp.SecretBinary, awsv2.ToString(getResp.SecretString), getResp, nil
   321  }
   322  
   323  func describeSecret(ctx context.Context, svc *secretsmanager.SecretsManager, secretID string) (time.Time, *secretsmanager.DescribeSecretOutput, error) {
   324  	descResp, err := svc.DescribeSecretWithContext(ctx, &secretsmanager.DescribeSecretInput{
   325  		SecretId: aws.String(secretID),
   326  	})
   327  	if err != nil {
   328  		return time.Time{}, nil, err
   329  	}
   330  	return aws.TimeValue(descResp.LastChangedDate), descResp, nil
   331  }
   332  
   333  func describeSecretV2(ctx context.Context, client *secretsmanagerv2.Client, secretID string) (time.Time, *secretsmanagerv2.DescribeSecretOutput, error) {
   334  	descResp, err := client.DescribeSecret(ctx, &secretsmanagerv2.DescribeSecretInput{
   335  		SecretId: awsv2.String(secretID),
   336  	})
   337  	if err != nil {
   338  		return time.Time{}, nil, err
   339  	}
   340  	return aws.TimeValue(descResp.LastChangedDate), descResp, nil
   341  }
   342  
   343  // WatchVariable implements driver.WatchVariable.
   344  func (w *watcher) WatchVariable(ctx context.Context, prev driver.State) (driver.State, time.Duration) {
   345  	var lastVersionID string
   346  	if prev != nil {
   347  		lastVersionID = prev.(*state).versionID
   348  	}
   349  	var svc *secretsmanager.SecretsManager
   350  	if !w.useV2 {
   351  		svc = secretsmanager.New(w.sess)
   352  	}
   353  
   354  	// GetParameter from S3 to get the current value and version.
   355  	var newVersionID string
   356  	var newValBinary []byte
   357  	var newValString string
   358  	var rawGetV1 *secretsmanager.GetSecretValueOutput
   359  	var rawGetV2 *secretsmanagerv2.GetSecretValueOutput
   360  	var err error
   361  	if w.useV2 {
   362  		newVersionID, newValBinary, newValString, rawGetV2, err = getSecretValueV2(ctx, w.clientV2, w.name)
   363  	} else {
   364  		newVersionID, newValBinary, newValString, rawGetV1, err = getSecretValue(ctx, svc, w.name)
   365  	}
   366  	if err != nil {
   367  		return errorState(err, prev), w.wait
   368  	}
   369  	if newVersionID == lastVersionID {
   370  		// Version hasn't changed, so no change; return nil.
   371  		return nil, w.wait
   372  	}
   373  	// Both SecretBinary and SecretString fields are not empty
   374  	// which could indicate some internal Secrets Manager issues.
   375  	// Hence, return explicit error instead of choosing one field over another.
   376  	if len(newValBinary) > 0 && newValString != "" {
   377  		err = fmt.Errorf("invalid %q response: both SecretBinary and SecretString are not empty", w.name)
   378  		return errorState(err, prev), w.wait
   379  	}
   380  
   381  	data := newValBinary
   382  	if len(data) == 0 {
   383  		if newValString == "" {
   384  			err = fmt.Errorf("invalid %q response: both SecretBinary and SecretString are empty", w.name)
   385  			return errorState(err, prev), w.wait
   386  		}
   387  		// SecretBinary is empty so use SecretString
   388  		data = []byte(newValString)
   389  	}
   390  
   391  	// DescribeParameters from S3 to get the LastModified date.
   392  	var newLastModified time.Time
   393  	var rawDescV1 *secretsmanager.DescribeSecretOutput
   394  	var rawDescV2 *secretsmanagerv2.DescribeSecretOutput
   395  	if w.useV2 {
   396  		newLastModified, rawDescV2, err = describeSecretV2(ctx, w.clientV2, w.name)
   397  	} else {
   398  		newLastModified, rawDescV1, err = describeSecret(ctx, svc, w.name)
   399  	}
   400  	if err != nil {
   401  		return errorState(err, prev), w.wait
   402  	}
   403  
   404  	// New value (or at least, new version). Decode it.
   405  	val, err := w.decoder.Decode(ctx, data)
   406  	if err != nil {
   407  		return errorState(err, prev), w.wait
   408  	}
   409  
   410  	return &state{
   411  		val:        val,
   412  		rawGetV1:   rawGetV1,
   413  		rawGetV2:   rawGetV2,
   414  		rawDescV1:  rawDescV1,
   415  		rawDescV2:  rawDescV2,
   416  		updateTime: newLastModified,
   417  		versionID:  newVersionID,
   418  	}, w.wait
   419  }
   420  
// Close implements driver.Close.
// It performs no work and always returns nil.
func (w *watcher) Close() error {
	return nil
}
   425  
   426  // ErrorAs implements driver.ErrorAs.
   427  func (w *watcher) ErrorAs(err error, i interface{}) bool {
   428  	if w.useV2 {
   429  		return errors.As(err, i)
   430  	}
   431  	switch v := err.(type) {
   432  	case awserr.Error:
   433  		if p, ok := i.(*awserr.Error); ok {
   434  			*p = v
   435  			return true
   436  		}
   437  	}
   438  	return false
   439  }
   440  
   441  func getErrorCode(err error) string {
   442  	if awsErr, ok := err.(awserr.Error); ok {
   443  		return awsErr.Code()
   444  	}
   445  	var ae smithy.APIError
   446  	if errors.As(err, &ae) {
   447  		return ae.ErrorCode()
   448  	}
   449  	return ""
   450  }
   451  
   452  // ErrorCode implements driver.ErrorCode.
   453  func (w *watcher) ErrorCode(err error) gcerrors.ErrorCode {
   454  	code := getErrorCode(err)
   455  	switch code {
   456  	case secretsmanager.ErrCodeResourceNotFoundException:
   457  		return gcerrors.NotFound
   458  
   459  	case secretsmanager.ErrCodeInvalidParameterException,
   460  		secretsmanager.ErrCodeInvalidRequestException,
   461  		secretsmanager.ErrCodeInvalidNextTokenException:
   462  		return gcerrors.InvalidArgument
   463  
   464  	case secretsmanager.ErrCodeEncryptionFailure,
   465  		secretsmanager.ErrCodeDecryptionFailure,
   466  		secretsmanager.ErrCodeInternalServiceError:
   467  		return gcerrors.Internal
   468  
   469  	case secretsmanager.ErrCodeResourceExistsException:
   470  		return gcerrors.AlreadyExists
   471  
   472  	case secretsmanager.ErrCodePreconditionNotMetException,
   473  		secretsmanager.ErrCodeMalformedPolicyDocumentException:
   474  		return gcerrors.FailedPrecondition
   475  
   476  	case secretsmanager.ErrCodeLimitExceededException:
   477  		return gcerrors.ResourceExhausted
   478  	}
   479  	return gcerrors.Unknown
   480  }