github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/runtimevar/awsparamstore/awsparamstore.go (about) 1 // Copyright 2018 The Go Cloud Development Kit Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package awsparamstore provides a runtimevar implementation with variables 16 // read from AWS Systems Manager Parameter Store 17 // (https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-paramstore.html). 18 // Use OpenVariable to construct a *runtimevar.Variable. 19 // 20 // # URLs 21 // 22 // For runtimevar.OpenVariable, awsparamstore registers for the scheme "awsparamstore". 23 // The default URL opener will use an AWS session with the default credentials 24 // and configuration; see https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ 25 // for more details. 26 // To customize the URL opener, or for more details on the URL format, 27 // see URLOpener. 28 // See https://gocloud.dev/concepts/urls/ for background information.
29 // 30 // # As 31 // 32 // awsparamstore exposes the following types for As: 33 // - Snapshot: (V1) *ssm.GetParameterOutput, (V2) *ssmv2.GetParameterOutput 34 // - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError 35 package awsparamstore // import "gocloud.dev/runtimevar/awsparamstore" 36 37 import ( 38 "context" 39 "errors" 40 "fmt" 41 "net/url" 42 "path" 43 "strings" 44 "sync" 45 "time" 46 47 awsv2 "github.com/aws/aws-sdk-go-v2/aws" 48 ssmv2 "github.com/aws/aws-sdk-go-v2/service/ssm" 49 "github.com/aws/aws-sdk-go/aws" 50 "github.com/aws/aws-sdk-go/aws/awserr" 51 "github.com/aws/aws-sdk-go/aws/client" 52 "github.com/aws/aws-sdk-go/aws/request" 53 "github.com/aws/aws-sdk-go/service/ssm" 54 "github.com/aws/smithy-go" 55 "github.com/google/wire" 56 gcaws "gocloud.dev/aws" 57 "gocloud.dev/gcerrors" 58 "gocloud.dev/runtimevar" 59 "gocloud.dev/runtimevar/driver" 60 ) 61 62 func init() { 63 runtimevar.DefaultURLMux().RegisterVariable(Scheme, new(lazySessionOpener)) 64 } 65 66 // Set holds Wire providers for this package. 67 var Set = wire.NewSet( 68 wire.Struct(new(URLOpener), "ConfigProvider"), 69 ) 70 71 // URLOpener opens AWS Paramstore URLs like "awsparamstore://myvar". 72 // 73 // Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2, 74 // or anything else to accept the default. 75 // 76 // For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters 77 // for overriding the aws.Session from the URL. 78 // For V2, see gocloud.dev/aws/V2ConfigFromURLParams. 79 // 80 // In addition, the following URL parameters are supported: 81 // - decoder: The decoder to use. Defaults to URLOpener.Decoder, or 82 // runtimevar.BytesDecoder if URLOpener.Decoder is nil. 83 // See runtimevar.DecoderByName for supported values. 84 // - wait: The poll interval, in time.ParseDuration formats. 85 // Defaults to 30s. 
86 type URLOpener struct { 87 // UseV2 indicates whether the AWS SDK V2 should be used. 88 UseV2 bool 89 90 // ConfigProvider must be set to a non-nil value if UseV2 is false. 91 ConfigProvider client.ConfigProvider 92 93 // Decoder specifies the decoder to use if one is not specified in the URL. 94 // Defaults to runtimevar.BytesDecoder. 95 Decoder *runtimevar.Decoder 96 97 // Options specifies the options to pass to New. 98 Options Options 99 } 100 101 // lazySessionOpener obtains the AWS session from the environment on the first 102 // call to OpenVariableURL. 103 type lazySessionOpener struct { 104 init sync.Once 105 opener *URLOpener 106 err error 107 } 108 109 func (o *lazySessionOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) { 110 if gcaws.UseV2(u.Query()) { 111 opener := &URLOpener{UseV2: true} 112 return opener.OpenVariableURL(ctx, u) 113 } 114 o.init.Do(func() { 115 sess, err := gcaws.NewDefaultSession() 116 if err != nil { 117 o.err = err 118 return 119 } 120 o.opener = &URLOpener{ConfigProvider: sess} 121 }) 122 if o.err != nil { 123 return nil, fmt.Errorf("open variable %v: %v", u, o.err) 124 } 125 return o.opener.OpenVariableURL(ctx, u) 126 } 127 128 // Scheme is the URL scheme awsparamstore registers its URLOpener under on runtimevar.DefaultMux. 129 const Scheme = "awsparamstore" 130 131 // OpenVariableURL opens the variable at the URL's path. See the package doc 132 // for more details. 
func (o *URLOpener) OpenVariableURL(ctx context.Context, u *url.URL) (*runtimevar.Variable, error) {
	q := u.Query()

	// Parameters handled by this opener are removed from q so that only
	// AWS configuration parameters are passed to the config parsers below.
	decoderName := q.Get("decoder")
	q.Del("decoder")
	decoder, err := runtimevar.DecoderByName(ctx, decoderName, o.Decoder)
	if err != nil {
		return nil, fmt.Errorf("open variable %v: invalid decoder: %v", u, err)
	}

	opts := o.Options
	// Remove "wait" unconditionally, exactly like "decoder": an empty
	// "wait=" means "keep the configured/default interval" and must not
	// leak through to the AWS config parsers as an extra parameter.
	waitStr := q.Get("wait")
	q.Del("wait")
	if waitStr != "" {
		d, err := time.ParseDuration(waitStr)
		if err != nil {
			return nil, fmt.Errorf("open variable %v: invalid wait %q: %v", u, waitStr, err)
		}
		opts.WaitDuration = d
	}

	// The parameter name is the URL host plus path, e.g. "awsparamstore://a/b" -> "a/b".
	name := path.Join(u.Host, u.Path)

	if o.UseV2 {
		cfg, err := gcaws.V2ConfigFromURLParams(ctx, q)
		if err != nil {
			return nil, fmt.Errorf("open variable %v: %v", u, err)
		}
		return OpenVariableV2(ssmv2.NewFromConfig(cfg), name, decoder, &opts)
	}

	// V1: layer URL-derived overrides on top of the opener's ConfigProvider.
	configProvider := &gcaws.ConfigOverrider{
		Base: o.ConfigProvider,
	}
	overrideCfg, err := gcaws.ConfigFromURLParams(q)
	if err != nil {
		return nil, fmt.Errorf("open variable %v: %v", u, err)
	}
	configProvider.Configs = append(configProvider.Configs, overrideCfg)
	return OpenVariable(configProvider, name, decoder, &opts)
}

// Options sets options.
type Options struct {
	// WaitDuration controls the rate at which Parameter Store is polled.
	// Defaults to 30 seconds.
	WaitDuration time.Duration
}

// OpenVariable constructs a *runtimevar.Variable backed by the variable name in
// AWS Systems Manager Parameter Store.
// Parameter Store returns raw bytes; provide a decoder to decode the raw bytes
// into the appropriate type for runtimevar.Snapshot.Value.
// See the runtimevar package documentation for examples of decoders.
182 func OpenVariable(sess client.ConfigProvider, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) { 183 return runtimevar.New(newWatcher(false, sess, nil, name, decoder, opts)), nil 184 } 185 186 // OpenVariableV2 constructs a *runtimevar.Variable backed by the variable name in 187 // AWS Systems Manager Parameter Store, using AWS SDK V2. 188 // Parameter Store returns raw bytes; provide a decoder to decode the raw bytes 189 // into the appropriate type for runtimevar.Snapshot.Value. 190 // See the runtimevar package documentation for examples of decoders. 191 func OpenVariableV2(client *ssmv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) (*runtimevar.Variable, error) { 192 return runtimevar.New(newWatcher(true, nil, client, name, decoder, opts)), nil 193 } 194 195 func newWatcher(useV2 bool, sess client.ConfigProvider, clientV2 *ssmv2.Client, name string, decoder *runtimevar.Decoder, opts *Options) *watcher { 196 if opts == nil { 197 opts = &Options{} 198 } 199 return &watcher{ 200 useV2: useV2, 201 sess: sess, 202 clientV2: clientV2, 203 name: name, 204 wait: driver.WaitDuration(opts.WaitDuration), 205 decoder: decoder, 206 } 207 } 208 209 // state implements driver.State. 210 type state struct { 211 val interface{} 212 rawGetV1 *ssm.GetParameterOutput 213 rawGetV2 *ssmv2.GetParameterOutput 214 updateTime time.Time 215 version int64 216 err error 217 } 218 219 // Value implements driver.State.Value. 220 func (s *state) Value() (interface{}, error) { 221 return s.val, s.err 222 } 223 224 // UpdateTime implements driver.State.UpdateTime. 225 func (s *state) UpdateTime() time.Time { 226 return s.updateTime 227 } 228 229 // As implements driver.State.As. 
func (s *state) As(i interface{}) bool {
	// Report success only when this state actually carries a response of
	// the requested SDK version. Otherwise a V1 caller on a V2 variable (or
	// vice versa) would get a nil pointer together with a true result.
	switch p := i.(type) {
	case **ssm.GetParameterOutput:
		if s.rawGetV1 == nil {
			return false
		}
		*p = s.rawGetV1
	case **ssmv2.GetParameterOutput:
		if s.rawGetV2 == nil {
			return false
		}
		*p = s.rawGetV2
	default:
		return false
	}
	return true
}

// errorState returns a new State wrapping err, or nil when prevS already
// holds an equivalent error (nil signals "no change" to the caller).
func errorState(err error, prevS driver.State) driver.State {
	// Map the AWS SDK V1 RequestCanceled code to the standard context errors.
	if getErrorCode(err) == request.CanceledErrorCode {
		if strings.Contains(err.Error(), "context deadline exceeded") {
			err = context.DeadlineExceeded
		} else {
			err = context.Canceled
		}
	}
	s := &state{err: err}
	if prevS == nil {
		return s
	}
	prev := prevS.(*state)
	if prev.err == nil {
		// The previous state held a good value, so this is a new error.
		return s
	}
	if equivalentError(err, prev.err) {
		// Same error as last time; report no change.
		return nil
	}
	return s
}

// equivalentError returns true iff err1 and err2 represent an equivalent error,
// i.e. identical values, identical messages, or the same non-empty AWS error code.
func equivalentError(err1, err2 error) bool {
	if err1 == err2 || err1.Error() == err2.Error() {
		return true
	}
	code1 := getErrorCode(err1)
	code2 := getErrorCode(err2)
	return code1 != "" && code1 == code2
}

// watcher polls a single Parameter Store parameter. OpenVariable and
// OpenVariableV2 hand it to runtimevar.New.
type watcher struct {
	// useV2 indicates whether we're using clientV2.
	useV2 bool
	// sess is the AWS session to use to talk to AWS (V1 only).
	sess client.ConfigProvider
	// clientV2 is the client to use when useV2 is true.
	clientV2 *ssmv2.Client
	// name is the parameter to retrieve.
	name string
	// wait is the amount of time to wait between querying AWS.
	wait time.Duration
	// decoder is the decoder that unmarshals the value in the param.
	decoder *runtimevar.Decoder
}

// getParameter fetches name through the SDK V1 client, requesting decryption
// (which AWS ignores for unencrypted parameters). It returns the parameter's
// version, raw value, last-modified time and the raw response.
func getParameter(ctx context.Context, svc *ssm.SSM, name string) (int64, []byte, time.Time, *ssm.GetParameterOutput, error) {
	// Use the context-aware call so that cancelling ctx aborts the request.
	// errorState maps the resulting RequestCanceled error to context errors.
	getResp, err := svc.GetParameterWithContext(ctx, &ssm.GetParameterInput{
		Name: aws.String(name),
		// Ignored if the parameter is not encrypted.
		WithDecryption: aws.Bool(true),
	})
	if err != nil {
		return 0, nil, time.Time{}, nil, err
	}
	if getResp.Parameter == nil {
		return 0, nil, time.Time{}, getResp, fmt.Errorf("unable to get %q parameter", name)
	}
	return aws.Int64Value(getResp.Parameter.Version), []byte(aws.StringValue(getResp.Parameter.Value)), aws.TimeValue(getResp.Parameter.LastModifiedDate), getResp, nil
}

// getParameterV2 is the SDK V2 counterpart of getParameter.
func getParameterV2(ctx context.Context, client *ssmv2.Client, name string) (int64, []byte, time.Time, *ssmv2.GetParameterOutput, error) {
	getResp, err := client.GetParameter(ctx, &ssmv2.GetParameterInput{
		Name: awsv2.String(name),
		// Ignored if the parameter is not encrypted.
		WithDecryption: true,
	})
	if err != nil {
		return 0, nil, time.Time{}, nil, err
	}
	if getResp.Parameter == nil {
		return 0, nil, time.Time{}, getResp, fmt.Errorf("unable to get %q parameter", name)
	}
	return getResp.Parameter.Version, []byte(awsv2.ToString(getResp.Parameter.Value)), awsv2.ToTime(getResp.Parameter.LastModifiedDate), getResp, nil
}

// WatchVariable implements driver.WatchVariable. It fetches the parameter
// once and returns nil (no change) when its version matches prev's version.
func (w *watcher) WatchVariable(ctx context.Context, prev driver.State) (driver.State, time.Duration) {
	lastVersion := int64(-1)
	if prev != nil {
		lastVersion = prev.(*state).version
	}

	// GetParameter from Parameter Store to get the current value and version.
	var (
		newVersion      int64
		newVal          []byte
		newLastModified time.Time
		rawGetV1        *ssm.GetParameterOutput
		rawGetV2        *ssmv2.GetParameterOutput
		err             error
	)
	if w.useV2 {
		newVersion, newVal, newLastModified, rawGetV2, err = getParameterV2(ctx, w.clientV2, w.name)
	} else {
		svc := ssm.New(w.sess)
		newVersion, newVal, newLastModified, rawGetV1, err = getParameter(ctx, svc, w.name)
	}
	if err != nil {
		return errorState(err, prev), w.wait
	}
	if newVersion == lastVersion {
		// Version hasn't changed, so no change; return nil.
		return nil, w.wait
	}

	// New value (or at least, new version). Decode it.
	val, err := w.decoder.Decode(ctx, newVal)
	if err != nil {
		return errorState(err, prev), w.wait
	}
	return &state{
		val:        val,
		rawGetV1:   rawGetV1,
		rawGetV2:   rawGetV2,
		updateTime: newLastModified,
		version:    newVersion,
	}, w.wait
}

// Close implements driver.Close. The watcher holds nothing to release.
func (w *watcher) Close() error {
	return nil
}

// ErrorAs implements driver.ErrorAs. With V2 it defers to errors.As; with V1
// it only supports extracting an awserr.Error.
func (w *watcher) ErrorAs(err error, i interface{}) bool {
	if w.useV2 {
		return errors.As(err, i)
	}
	switch v := err.(type) {
	case awserr.Error:
		if p, ok := i.(*awserr.Error); ok {
			*p = v
			return true
		}
	}
	return false
}

// getErrorCode returns the AWS error code carried by err (V1 awserr.Error or
// V2 smithy.APIError), or "" if there is none.
func getErrorCode(err error) string {
	if awsErr, ok := err.(awserr.Error); ok {
		return awsErr.Code()
	}
	var ae smithy.APIError
	if errors.As(err, &ae) {
		return ae.ErrorCode()
	}
	return ""
}

// ErrorCode implements driver.ErrorCode.
func (w *watcher) ErrorCode(err error) gcerrors.ErrorCode {
	code := getErrorCode(err)
	if code == "ParameterNotFound" {
		return gcerrors.NotFound
	}
	return gcerrors.Unknown
}