github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/runtimevar/awsparamstore/awsparamstore.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awsparamstore provides a runtimevar implementation with variables
    16  // read from AWS Systems Manager Parameter Store
// (https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-paramstore.html).
    18  // Use OpenVariable to construct a *runtimevar.Variable.
    19  //
    20  // # URLs
    21  //
    22  // For runtimevar.OpenVariable, awsparamstore registers for the scheme "awsparamstore".
    23  // The default URL opener will use an AWS session with the default credentials
    24  // and configuration; see https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
    25  // for more details.
    26  // To customize the URL opener, or for more details on the URL format,
    27  // see URLOpener.
    28  // See https://gocloud.dev/concepts/urls/ for background information.
    29  //
    30  // # As
    31  //
    32  // awsparamstore exposes the following types for As:
    33  //   - Snapshot: (V1) *ssm.GetParameterOutput, (V2) *ssmv2.GetParameterOutput
    34  //   - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError
    35  package awsparamstore // import "gocloud.dev/runtimevar/awsparamstore"
    36  
    37  import (
    38  	"context"
    39  	"errors"
    40  	"fmt"
    41  	"net/url"
    42  	"path"
    43  	"strings"
    44  	"sync"
    45  	"time"
    46  
    47  	awsv2 "github.com/aws/aws-sdk-go-v2/aws"
    48  	ssmv2 "github.com/aws/aws-sdk-go-v2/service/ssm"
    49  	"github.com/aws/aws-sdk-go/aws"
    50  	"github.com/aws/aws-sdk-go/aws/awserr"
    51  	"github.com/aws/aws-sdk-go/aws/client"
    52  	"github.com/aws/aws-sdk-go/aws/request"
    53  	"github.com/aws/aws-sdk-go/service/ssm"
    54  	"github.com/aws/smithy-go"
    55  	"github.com/google/wire"
    56  	gcaws "gocloud.dev/aws"
    57  	"gocloud.dev/gcerrors"
    58  	"gocloud.dev/runtimevar"
    59  	"gocloud.dev/runtimevar/driver"
    60  )
    61  
    62  func init() {
    63  	runtimevar.DefaultURLMux().RegisterVariable(Scheme, new(lazySessionOpener))
    64  }
    65  
// Set holds Wire providers for this package.
//
// Only URLOpener.ConfigProvider is injected; the other URLOpener fields keep
// their zero values, so in particular UseV2 is false and the injected opener
// uses AWS SDK V1.
var Set = wire.NewSet(
	wire.Struct(new(URLOpener), "ConfigProvider"),
)
    70  
// URLOpener opens AWS Paramstore URLs like "awsparamstore://myvar".
//
// The parameter name is the URL's host and path joined with path.Join, so
// "awsparamstore://myvar" reads the parameter "myvar".
//
// Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2,
// or anything else to accept the default.
//
// For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters
// for overriding the aws.Session from the URL.
// For V2, see gocloud.dev/aws/V2ConfigFromURLParams.
//
// In addition, the following URL parameters are supported:
//   - decoder: The decoder to use. Defaults to URLOpener.Decoder, or
//     runtimevar.BytesDecoder if URLOpener.Decoder is nil.
//     See runtimevar.DecoderByName for supported values.
//   - wait: The poll interval, in time.ParseDuration formats.
//     Defaults to 30s.
type URLOpener struct {
	// UseV2 indicates whether the AWS SDK V2 should be used.
	UseV2 bool

	// ConfigProvider must be set to a non-nil value if UseV2 is false.
	// It is the base configuration; AWS settings taken from the URL's
	// query parameters are layered on top of it.
	ConfigProvider client.ConfigProvider

	// Decoder specifies the decoder to use if one is not specified in the URL.
	// Defaults to runtimevar.BytesDecoder.
	Decoder *runtimevar.Decoder

	// Options specifies the options to pass to OpenVariable or OpenVariableV2.
	// A "wait" URL parameter overrides Options.WaitDuration.
	Options Options
}
   100  
// lazySessionOpener obtains the AWS session from the environment on the first
// call to OpenVariableURL.
//
// URLs that select AWS SDK V2 never touch the session. For V1 URLs, the
// session (or the error from creating it) is computed at most once and then
// reused for every later call.
type lazySessionOpener struct {
	init   sync.Once  // guards the one-time creation of opener/err
	opener *URLOpener // V1 opener built from the default session; nil if err is set
	err    error      // failure from creating the default session, if any
}
   108  
   109  func (o *lazySessionOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
   110  	if gcaws.UseV2(u.Query()) {
   111  		opener := &URLOpener{UseV2: true}
   112  		return opener.OpenVariableURL(ctx, u)
   113  	}
   114  	o.init.Do(func() {
   115  		sess, err := gcaws.NewDefaultSession()
   116  		if err != nil {
   117  			o.err = err
   118  			return
   119  		}
   120  		o.opener = &URLOpener{ConfigProvider: sess}
   121  	})
   122  	if o.err != nil {
   123  		return nil, fmt.Errorf("open variable %v: %v", u, o.err)
   124  	}
   125  	return o.opener.OpenVariableURL(ctx, u)
   126  }
   127  
// Scheme is the URL scheme awsparamstore registers its default URL opener
// under on runtimevar.DefaultURLMux().
const Scheme = "awsparamstore"
   130  
   131  // OpenVariableURL opens the variable at the URL's path. See the package doc
   132  // for more details.
   133  func (o *URLOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
   134  	q := u.Query()
   135  
   136  	decoderName := q.Get("decoder")
   137  	q.Del("decoder")
   138  	decoder, err := runtimevar.DecoderByName(ctx, decoderName, o.Decoder)
   139  	if err != nil {
   140  		return nil, fmt.Errorf("open variable %v: invalid decoder: %v", u, err)
   141  	}
   142  	opts := o.Options
   143  	if s := q.Get("wait"); s != "" {
   144  		q.Del("wait")
   145  		d, err := time.ParseDuration(s)
   146  		if err != nil {
   147  			return nil, fmt.Errorf("open variable %v: invalid wait %q: %v", u, s, err)
   148  		}
   149  		opts.WaitDuration = d
   150  	}
   151  
   152  	if o.UseV2 {
   153  		cfg, err := gcaws.V2ConfigFromURLParams(ctx, q)
   154  		if err != nil {
   155  			return nil, fmt.Errorf("open variable %v: %v", u, err)
   156  		}
   157  		return OpenVariableV2(ssmv2.NewFromConfig(cfg), path.Join(u.Host, u.Path), decoder, &opts)
   158  	}
   159  	configProvider := &gcaws.ConfigOverrider{
   160  		Base: o.ConfigProvider,
   161  	}
   162  	overrideCfg, err := gcaws.ConfigFromURLParams(q)
   163  	if err != nil {
   164  		return nil, fmt.Errorf("open variable %v: %v", u, err)
   165  	}
   166  	configProvider.Configs = append(configProvider.Configs, overrideCfg)
   167  	return OpenVariable(configProvider, path.Join(u.Host, u.Path), decoder, &opts)
   168  }
   169  
// Options holds optional settings for OpenVariable and OpenVariableV2.
type Options struct {
	// WaitDuration controls the rate at which Parameter Store is polled.
	// Defaults to 30 seconds when left as the zero value.
	WaitDuration time.Duration
}
   176  
   177  // OpenVariable constructs a *runtimevar.Variable backed by the variable name in
   178  // AWS Systems Manager Parameter Store.
   179  // Parameter Store returns raw bytes; provide a decoder to decode the raw bytes
   180  // into the appropriate type for runtimevar.Snapshot.Value.
   181  // See the runtimevar package documentation for examples of decoders.
   182  func OpenVariable(sess client.ConfigProvider, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) {
   183  	return runtimevar.New(newWatcher(false, sess, nil, name, decoder, opts)), nil
   184  }
   185  
   186  // OpenVariableV2 constructs a *runtimevar.Variable backed by the variable name in
   187  // AWS Systems Manager Parameter Store, using AWS SDK V2.
   188  // Parameter Store returns raw bytes; provide a decoder to decode the raw bytes
   189  // into the appropriate type for runtimevar.Snapshot.Value.
   190  // See the runtimevar package documentation for examples of decoders.
   191  func OpenVariableV2(client *ssmv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) {
   192  	return runtimevar.New(newWatcher(true, nil, client, name, decoder, opts)), nil
   193  }
   194  
   195  func newWatcher(useV2 bool, sess client.ConfigProvider, clientV2 *ssmv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) *watcher {
   196  	if opts == nil {
   197  		opts = &Options{}
   198  	}
   199  	return &watcher{
   200  		useV2:    useV2,
   201  		sess:     sess,
   202  		clientV2: clientV2,
   203  		name:     name,
   204  		wait:     driver.WaitDuration(opts.WaitDuration),
   205  		decoder:  decoder,
   206  	}
   207  }
   208  
// state implements driver.State.
//
// A state is either a success or an error.
//
// A success holds:
//   - val: the decoded value.
//   - rawGetV1 or rawGetV2: the raw GetParameter response from whichever SDK
//     was used. The other field stays nil.
//   - updateTime: the parameter's last-modified time.
//   - version: the parameter version, used to detect changes between polls.
//
// An error state holds only err; every other field keeps its zero value.
type state struct {
	val        interface{}
	rawGetV1   *ssm.GetParameterOutput
	rawGetV2   *ssmv2.GetParameterOutput
	updateTime time.Time
	version    int64
	err        error
}
   218  
   219  // Value implements driver.State.Value.
   220  func (s *state) Value() (interface{}, error) {
   221  	return s.val, s.err
   222  }
   223  
   224  // UpdateTime implements driver.State.UpdateTime.
   225  func (s *state) UpdateTime() time.Time {
   226  	return s.updateTime
   227  }
   228  
   229  // As implements driver.State.As.
   230  func (s *state) As(i interface{}) bool {
   231  	switch p := i.(type) {
   232  	case **ssm.GetParameterOutput:
   233  		*p = s.rawGetV1
   234  	case **ssmv2.GetParameterOutput:
   235  		*p = s.rawGetV2
   236  	default:
   237  		return false
   238  	}
   239  	return true
   240  }
   241  
   242  // errorState returns a new State with err, unless prevS also represents
   243  // the same error, in which case it returns nil.
   244  func errorState(err error, prevS driver.State) driver.State {
   245  	// Map aws.RequestCanceled to the more standard context package errors.
   246  	if getErrorCode(err) == request.CanceledErrorCode {
   247  		msg := err.Error()
   248  		if strings.Contains(msg, "context deadline exceeded") {
   249  			err = context.DeadlineExceeded
   250  		} else {
   251  			err = context.Canceled
   252  		}
   253  	}
   254  	s := &state{err: err}
   255  	if prevS == nil {
   256  		return s
   257  	}
   258  	prev := prevS.(*state)
   259  	if prev.err == nil {
   260  		// New error.
   261  		return s
   262  	}
   263  	if equivalentError(err, prev.err) {
   264  		// Same error, return nil to indicate no change.
   265  		return nil
   266  	}
   267  	return s
   268  }
   269  
   270  // equivalentError returns true iff err1 and err2 represent an equivalent error;
   271  // i.e., we don't want to return it to the user as a different error.
   272  func equivalentError(err1, err2 error) bool {
   273  	if err1 == err2 || err1.Error() == err2.Error() {
   274  		return true
   275  	}
   276  	code1 := getErrorCode(err1)
   277  	code2 := getErrorCode(err2)
   278  	return code1 != "" && code1 == code2
   279  }
   280  
// watcher implements driver.Watcher for a single Parameter Store parameter.
// It fetches the parameter on each WatchVariable call and tells runtimevar to
// call again after wait.
type watcher struct {
	// useV2 indicates whether we're using clientV2 (AWS SDK V2) rather than
	// sess (AWS SDK V1).
	useV2 bool
	// sess is the AWS SDK V1 configuration provider; only used when useV2 is false.
	sess client.ConfigProvider
	// clientV2 is the client to use when useV2 is true.
	clientV2 *ssmv2.Client
	// name is the parameter to retrieve.
	name string
	// wait is the amount of time to wait between querying AWS.
	wait time.Duration
	// decoder is the decoder that unmarshals the value in the param.
	decoder *runtimevar.Decoder
}
   295  
   296  func getParameter(svc *ssm.SSM, name string) (int64, []byte, time.Time, *ssm.GetParameterOutput, error) {
   297  	getResp, err := svc.GetParameter(&ssm.GetParameterInput{
   298  		Name: aws.String(name),
   299  		// Ignored if the parameter is not encrypted.
   300  		WithDecryption: aws.Bool(true),
   301  	})
   302  	if err != nil {
   303  		return 0, nil, time.Time{}, nil, err
   304  	}
   305  	if getResp.Parameter == nil {
   306  		return 0, nil, time.Time{}, getResp, fmt.Errorf("unable to get %q parameter", name)
   307  	}
   308  	return aws.Int64Value(getResp.Parameter.Version), []byte(aws.StringValue(getResp.Parameter.Value)), aws.TimeValue(getResp.Parameter.LastModifiedDate), getResp, nil
   309  }
   310  
   311  func getParameterV2(ctx context.Context, client *ssmv2.Client, name string) (int64, []byte, time.Time, *ssmv2.GetParameterOutput, error) {
   312  	getResp, err := client.GetParameter(ctx, &ssmv2.GetParameterInput{
   313  		Name: aws.String(name),
   314  		// Ignored if the parameter is not encrypted.
   315  		WithDecryption: true,
   316  	})
   317  	if err != nil {
   318  		return 0, nil, time.Time{}, nil, err
   319  	}
   320  	if getResp.Parameter == nil {
   321  		return 0, nil, time.Time{}, getResp, fmt.Errorf("unable to get %q parameter", name)
   322  	}
   323  	return getResp.Parameter.Version, []byte(awsv2.ToString(getResp.Parameter.Value)), awsv2.ToTime(getResp.Parameter.LastModifiedDate), getResp, nil
   324  }
   325  
   326  func (w *watcher) WatchVariable(ctx context.Context, prev driver.State) (driver.State, time.Duration) {
   327  	lastVersion := int64(-1)
   328  	if prev != nil {
   329  		lastVersion = prev.(*state).version
   330  	}
   331  	var svc *ssm.SSM
   332  	if !w.useV2 {
   333  		svc = ssm.New(w.sess)
   334  	}
   335  
   336  	// GetParameter from S3 to get the current value and version.
   337  	var newVersion int64
   338  	var newVal []byte
   339  	var newLastModified time.Time
   340  	var rawGetV1 *ssm.GetParameterOutput
   341  	var rawGetV2 *ssmv2.GetParameterOutput
   342  	var err error
   343  	if w.useV2 {
   344  		newVersion, newVal, newLastModified, rawGetV2, err = getParameterV2(ctx, w.clientV2, w.name)
   345  	} else {
   346  		newVersion, newVal, newLastModified, rawGetV1, err = getParameter(svc, w.name)
   347  	}
   348  	if err != nil {
   349  		return errorState(err, prev), w.wait
   350  	}
   351  	if newVersion == lastVersion {
   352  		// Version hasn't changed, so no change; return nil.
   353  		return nil, w.wait
   354  	}
   355  
   356  	// New value (or at least, new version). Decode it.
   357  	val, err := w.decoder.Decode(ctx, newVal)
   358  	if err != nil {
   359  		return errorState(err, prev), w.wait
   360  	}
   361  	return &state{
   362  		val:        val,
   363  		rawGetV1:   rawGetV1,
   364  		rawGetV2:   rawGetV2,
   365  		updateTime: newLastModified,
   366  		version:    newVersion,
   367  	}, w.wait
   368  }
   369  
   370  // Close implements driver.Close.
   371  func (w *watcher) Close() error {
   372  	return nil
   373  }
   374  
   375  // ErrorAs implements driver.ErrorAs.
   376  func (w *watcher) ErrorAs(err error, i interface{}) bool {
   377  	if w.useV2 {
   378  		return errors.As(err, i)
   379  	}
   380  	switch v := err.(type) {
   381  	case awserr.Error:
   382  		if p, ok := i.(*awserr.Error); ok {
   383  			*p = v
   384  			return true
   385  		}
   386  	}
   387  	return false
   388  }
   389  
   390  func getErrorCode(err error) string {
   391  	if awsErr, ok := err.(awserr.Error); ok {
   392  		return awsErr.Code()
   393  	}
   394  	var ae smithy.APIError
   395  	if errors.As(err, &ae) {
   396  		return ae.ErrorCode()
   397  	}
   398  	return ""
   399  }
   400  
   401  // ErrorCode implements driver.ErrorCode.
   402  func (w *watcher) ErrorCode(err error) gcerrors.ErrorCode {
   403  	code := getErrorCode(err)
   404  	if code == "ParameterNotFound" {
   405  		return gcerrors.NotFound
   406  	}
   407  	return gcerrors.Unknown
   408  }