github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/pubsub/awssnssqs/awssnssqs.go (about) 1 // Copyright 2018 The Go Cloud Development Kit Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package awssnssqs provides two implementations of pubsub.Topic, one that 16 // sends messages to AWS SNS (Simple Notification Service), and one that sends 17 // messages to SQS (Simple Queuing Service). It also provides an implementation 18 // of pubsub.Subscription that receives messages from SQS. 19 // 20 // # URLs 21 // 22 // For pubsub.OpenTopic, awssnssqs registers for the scheme "awssns" for 23 // an SNS topic, and "awssqs" for an SQS topic. For pubsub.OpenSubscription, 24 // it registers for the scheme "awssqs". 25 // 26 // The default URL opener will use an AWS session with the default credentials 27 // and configuration; see https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ 28 // for more details. 29 // To customize the URL opener, or for more details on the URL format, 30 // see URLOpener. 31 // See https://gocloud.dev/concepts/urls/ for background information. 32 // 33 // # Message Delivery Semantics 34 // 35 // AWS SQS supports at-least-once semantics; applications must call Message.Ack 36 // after processing a message, or it will be redelivered. 37 // See https://godoc.org/gocloud.dev/pubsub#hdr-At_most_once_and_At_least_once_Delivery 38 // for more background. 39 // 40 // # Escaping 41 // 42 // Go CDK supports all UTF-8 strings; to make this work with services lacking 43 // full UTF-8 support, strings must be escaped (during writes) and unescaped 44 // (during reads). The following escapes are required for awssnssqs: 45 // - Metadata keys: Characters other than "a-zA-z0-9_-.", and additionally "." 46 // when it's at the start of the key or the previous character was ".", 47 // are escaped using "__0x<hex>__". These characters were determined by 48 // experimentation. 49 // - Metadata values: Escaped using URL encoding. 50 // - Message body: AWS SNS/SQS only supports UTF-8 strings. See the 51 // BodyBase64Encoding enum in TopicOptions for strategies on how to send 52 // non-UTF-8 message bodies. By default, non-UTF-8 message bodies are base64 53 // encoded. 54 // 55 // # As 56 // 57 // awssnssqs exposes the following types for As: 58 // - Topic: (V1) *sns.SNS for OpenSNSTopic, *sqs.SQS for OpenSQSTopic; (V2) *snsv2.Client for OpenSNSTopicV2, *sqsv2.Client for OpenSQSTopicV2 59 // - Subscription: (V1) *sqs.SQS; (V2) *sqsv2.Client 60 // - Message: (V1) *sqs.Message; (V2) sqstypesv2.Message 61 // - Message.BeforeSend: (V1) *sns.PublishInput for OpenSNSTopic, *sqs.SendMessageBatchRequestEntry or *sqs.SendMessageInput(deprecated) for OpenSQSTopic; (V2) *snsv2.PublishInput for OpenSNSTopicV2, sqstypesv2.SendMessageBatchRequestEntry for OpenSQSTopicV2 62 // - Message.AfterSend: (V1) *sns.PublishOutput for OpenSNSTopic, *sqs.SendMessageBatchResultEntry for OpenSQSTopic; (V2) *snsv2.PublishOutput for OpenSNSTopicV2, sqstypesv2.SendMessageBatchResultEntry for OpenSQSTopicV2 63 // - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError 64 package awssnssqs // import "gocloud.dev/pubsub/awssnssqs" 65 66 import ( 67 "context" 68 "encoding/base64" 69 "encoding/json" 70 "errors" 71 "fmt" 72 "net/url" 73 "path" 74 "strconv" 75 "strings" 76 "sync" 77 "time" 78 "unicode/utf8" 79 80 snsv2 "github.com/aws/aws-sdk-go-v2/service/sns" 81 snstypesv2 "github.com/aws/aws-sdk-go-v2/service/sns/types" 82 sqsv2 "github.com/aws/aws-sdk-go-v2/service/sqs" 83 sqstypesv2 "github.com/aws/aws-sdk-go-v2/service/sqs/types" 84 "github.com/aws/aws-sdk-go/aws" 85 "github.com/aws/aws-sdk-go/aws/awserr" 86 "github.com/aws/aws-sdk-go/aws/client" 87 "github.com/aws/aws-sdk-go/service/sns" 88 "github.com/aws/aws-sdk-go/service/sqs" 89 "github.com/aws/smithy-go" 90 "github.com/google/wire" 91 gcaws "gocloud.dev/aws" 92 "gocloud.dev/gcerrors" 93 "gocloud.dev/internal/escape" 94 "gocloud.dev/pubsub" 95 "gocloud.dev/pubsub/batcher" 96 "gocloud.dev/pubsub/driver" 97 ) 98 99 const ( 100 // base64EncodedKey is the Message Attribute key used to flag that the 101 // message body is base64 encoded. 102 base64EncodedKey = "base64encoded" 103 // How long ReceiveBatch should wait if no messages are available; controls 104 // the poll interval of requests to SQS. 105 noMessagesPollDuration = 250 * time.Millisecond 106 ) 107 108 var sendBatcherOptsSNS = &batcher.Options{ 109 MaxBatchSize: 1, // SNS SendBatch only supports one message at a time 110 MaxHandlers: 100, // max concurrency for sends 111 } 112 113 var sendBatcherOptsSQS = &batcher.Options{ 114 MaxBatchSize: 10, // SQS SendBatch supports 10 messages at a time 115 MaxHandlers: 100, // max concurrency for sends 116 } 117 118 var recvBatcherOpts = &batcher.Options{ 119 // SQS supports receiving at most 10 messages at a time: 120 // https://godoc.org/github.com/aws/aws-sdk-go/service/sqs#SQS.ReceiveMessage 121 MaxBatchSize: 10, 122 MaxHandlers: 100, // max concurrency for receives 123 } 124 125 var ackBatcherOpts = &batcher.Options{ 126 // SQS supports deleting/updating at most 10 messages at a time: 127 // https://godoc.org/github.com/aws/aws-sdk-go/service/sqs#SQS.DeleteMessageBatch 128 // https://godoc.org/github.com/aws/aws-sdk-go/service/sqs#SQS.ChangeMessageVisibilityBatch 129 MaxBatchSize: 10, 130 MaxHandlers: 100, // max concurrency for acks 131 } 132 133 func init() { 134 lazy := new(lazySessionOpener) 135 pubsub.DefaultURLMux().RegisterTopic(SNSScheme, lazy) 136 pubsub.DefaultURLMux().RegisterTopic(SQSScheme, lazy) 137 pubsub.DefaultURLMux().RegisterSubscription(SQSScheme, lazy) 138 } 139 140 // Set holds Wire providers for this package. 141 var Set = wire.NewSet( 142 wire.Struct(new(URLOpener), "ConfigProvider"), 143 ) 144 145 // lazySessionOpener obtains the AWS session from the environment on the first 146 // call to OpenXXXURL. 147 type lazySessionOpener struct { 148 init sync.Once 149 opener *URLOpener 150 err error 151 } 152 153 func (o *lazySessionOpener) defaultOpener(u *url.URL) (*URLOpener, error) { 154 if gcaws.UseV2(u.Query()) { 155 return &URLOpener{UseV2: true}, nil 156 } 157 o.init.Do(func() { 158 sess, err := gcaws.NewDefaultSession() 159 if err != nil { 160 o.err = err 161 return 162 } 163 o.opener = &URLOpener{ 164 ConfigProvider: sess, 165 } 166 }) 167 return o.opener, o.err 168 } 169 170 func (o *lazySessionOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic, error) { 171 opener, err := o.defaultOpener(u) 172 if err != nil { 173 return nil, fmt.Errorf("open topic %v: failed to open default session: %v", u, err) 174 } 175 return opener.OpenTopicURL(ctx, u) 176 } 177 178 func (o *lazySessionOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsub.Subscription, error) { 179 opener, err := o.defaultOpener(u) 180 if err != nil { 181 return nil, fmt.Errorf("open subscription %v: failed to open default session: %v", u, err) 182 } 183 return opener.OpenSubscriptionURL(ctx, u) 184 } 185 186 // SNSScheme is the URL scheme for pubsub.OpenTopic (for an SNS topic) that 187 // awssnssqs registers its URLOpeners under on pubsub.DefaultMux. 188 const SNSScheme = "awssns" 189 190 // SQSScheme is the URL scheme for pubsub.OpenTopic (for an SQS topic) and for 191 // pubsub.OpenSubscription that awssnssqs registers its URLOpeners under on 192 // pubsub.DefaultMux. 193 const SQSScheme = "awssqs" 194 195 // URLOpener opens AWS SNS/SQS URLs like "awssns:///sns-topic-arn" for 196 // SNS topics or "awssqs://sqs-queue-url" for SQS topics and subscriptions. 197 // 198 // For SNS topics, the URL's host+path is used as the topic Amazon Resource Name 199 // (ARN). Since ARNs have ":" in them, and ":" precedes a port in URL 200 // hostnames, leave the host blank and put the ARN in the path 201 // (e.g., "awssns:///arn:aws:service:region:accountid:resourceType/resourcePath"). 202 // 203 // For SQS topics and subscriptions, the URL's host+path is prefixed with 204 // "https://" to create the queue URL. 205 // 206 // Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2, 207 // or anything else to accept the default. 208 // 209 // For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters 210 // for overriding the aws.Session from the URL. 211 // For V2, see gocloud.dev/aws/V2ConfigFromURLParams. 212 // 213 // In addition, the following query parameters are supported: 214 // 215 // - raw (for "awssqs" Subscriptions only): sets SubscriberOptions.Raw. The 216 // value must be parseable by `strconv.ParseBool`. 217 // - waittime: sets SubscriberOptions.WaitTime, in time.ParseDuration formats. 218 // 219 // See gocloud.dev/aws/ConfigFromURLParams for other query parameters 220 // that affect the default AWS session. 221 type URLOpener struct { 222 // UseV2 indicates whether the AWS SDK V2 should be used. 223 UseV2 bool 224 225 // ConfigProvider configures the connection to AWS. 226 // It must be set to a non-nil value if UseV2 is false. 227 ConfigProvider client.ConfigProvider 228 229 // TopicOptions specifies the options to pass to OpenTopic. 230 TopicOptions TopicOptions 231 // SubscriptionOptions specifies the options to pass to OpenSubscription. 232 SubscriptionOptions SubscriptionOptions 233 } 234 235 // OpenTopicURL opens a pubsub.Topic based on u. 236 func (o *URLOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic, error) { 237 // Trim leading "/" if host is empty, so that 238 // awssns:///arn:aws:service:region:accountid:resourceType/resourcePath 239 // gives "arn:..." instead of "/arn:...". 240 topicARN := strings.TrimPrefix(path.Join(u.Host, u.Path), "/") 241 qURL := "https://" + path.Join(u.Host, u.Path) 242 if o.UseV2 { 243 cfg, err := gcaws.V2ConfigFromURLParams(ctx, u.Query()) 244 if err != nil { 245 return nil, fmt.Errorf("open topic %v: %v", u, err) 246 } 247 switch u.Scheme { 248 case SNSScheme: 249 return OpenSNSTopicV2(ctx, snsv2.NewFromConfig(cfg), topicARN, &o.TopicOptions), nil 250 case SQSScheme: 251 return OpenSQSTopicV2(ctx, sqsv2.NewFromConfig(cfg), qURL, &o.TopicOptions), nil 252 default: 253 return nil, fmt.Errorf("open topic %v: unsupported scheme", u) 254 } 255 } 256 configProvider := &gcaws.ConfigOverrider{ 257 Base: o.ConfigProvider, 258 } 259 overrideCfg, err := gcaws.ConfigFromURLParams(u.Query()) 260 if err != nil { 261 return nil, fmt.Errorf("open topic %v: %v", u, err) 262 } 263 configProvider.Configs = append(configProvider.Configs, overrideCfg) 264 switch u.Scheme { 265 case SNSScheme: 266 return OpenSNSTopic(ctx, configProvider, topicARN, &o.TopicOptions), nil 267 case SQSScheme: 268 return OpenSQSTopic(ctx, configProvider, qURL, &o.TopicOptions), nil 269 default: 270 return nil, fmt.Errorf("open topic %v: unsupported scheme", u) 271 } 272 } 273 274 // OpenSubscriptionURL opens a pubsub.Subscription based on u. 275 func (o *URLOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsub.Subscription, error) { 276 // Clone the options since we might override Raw. 277 opts := o.SubscriptionOptions 278 q := u.Query() 279 if rawStr := q.Get("raw"); rawStr != "" { 280 var err error 281 opts.Raw, err = strconv.ParseBool(rawStr) 282 if err != nil { 283 return nil, fmt.Errorf("invalid value %q for raw: %v", rawStr, err) 284 } 285 q.Del("raw") 286 } 287 if waitTimeStr := q.Get("waittime"); waitTimeStr != "" { 288 var err error 289 opts.WaitTime, err = time.ParseDuration(waitTimeStr) 290 if err != nil { 291 return nil, fmt.Errorf("invalid value %q for waittime: %v", waitTimeStr, err) 292 } 293 q.Del("waittime") 294 } 295 qURL := "https://" + path.Join(u.Host, u.Path) 296 if o.UseV2 { 297 cfg, err := gcaws.V2ConfigFromURLParams(ctx, q) 298 if err != nil { 299 return nil, fmt.Errorf("open subscription %v: %v", u, err) 300 } 301 return OpenSubscriptionV2(ctx, sqsv2.NewFromConfig(cfg), qURL, &opts), nil 302 } 303 overrideCfg, err := gcaws.ConfigFromURLParams(q) 304 if err != nil { 305 return nil, fmt.Errorf("open subscription %v: %v", u, err) 306 } 307 configProvider := &gcaws.ConfigOverrider{ 308 Base: o.ConfigProvider, 309 } 310 configProvider.Configs = append(configProvider.Configs, overrideCfg) 311 return OpenSubscription(ctx, configProvider, qURL, &opts), nil 312 } 313 314 type snsTopic struct { 315 useV2 bool 316 client *sns.SNS 317 clientV2 *snsv2.Client 318 arn string 319 opts *TopicOptions 320 } 321 322 // BodyBase64Encoding is an enum of strategies for when to base64 message 323 // bodies. 324 type BodyBase64Encoding int 325 326 const ( 327 // NonUTF8Only means that message bodies that are valid UTF-8 encodings are 328 // sent as-is. Invalid UTF-8 message bodies are base64 encoded, and a 329 // MessageAttribute with key "base64encoded" is added to the message. 330 // When receiving messages, the "base64encoded" attribute is used to determine 331 // whether to base64 decode, and is then filtered out. 332 NonUTF8Only BodyBase64Encoding = 0 333 // Always means that all message bodies are base64 encoded. 334 // A MessageAttribute with key "base64encoded" is added to the message. 335 // When receiving messages, the "base64encoded" attribute is used to determine 336 // whether to base64 decode, and is then filtered out. 337 Always BodyBase64Encoding = 1 338 // Never means that message bodies are never base64 encoded. Non-UTF-8 339 // bytes in message bodies may be modified by SNS/SQS. 340 Never BodyBase64Encoding = 2 341 ) 342 343 func (e BodyBase64Encoding) wantEncode(b []byte) bool { 344 switch e { 345 case Always: 346 return true 347 case Never: 348 return false 349 case NonUTF8Only: 350 return !utf8.Valid(b) 351 } 352 panic("unreachable") 353 } 354 355 // TopicOptions contains configuration options for topics. 356 type TopicOptions struct { 357 // BodyBase64Encoding determines when message bodies are base64 encoded. 358 // The default is NonUTF8Only. 359 BodyBase64Encoding BodyBase64Encoding 360 361 // BatcherOptions adds constraints to the default batching done for sends. 362 BatcherOptions batcher.Options 363 } 364 365 // OpenTopic is a shortcut for OpenSNSTopic, provided for backwards compatibility. 366 func OpenTopic(ctx context.Context, sess client.ConfigProvider, topicARN string, opts *TopicOptions) *pubsub.Topic { 367 return OpenSNSTopic(ctx, sess, topicARN, opts) 368 } 369 370 // OpenSNSTopic opens a topic that sends to the SNS topic with the given Amazon 371 // Resource Name (ARN). 372 func OpenSNSTopic(ctx context.Context, sess client.ConfigProvider, topicARN string, opts *TopicOptions) *pubsub.Topic { 373 bo := sendBatcherOptsSNS.NewMergedOptions(&opts.BatcherOptions) 374 return pubsub.NewTopic(openSNSTopic(ctx, sns.New(sess), topicARN, opts), bo) 375 } 376 377 // OpenSNSTopicV2 opens a topic that sends to the SNS topic with the given Amazon 378 // Resource Name (ARN), using AWS SDK V2. 379 func OpenSNSTopicV2(ctx context.Context, client *snsv2.Client, topicARN string, opts *TopicOptions) *pubsub.Topic { 380 bo := sendBatcherOptsSNS.NewMergedOptions(&opts.BatcherOptions) 381 return pubsub.NewTopic(openSNSTopicV2(ctx, client, topicARN, opts), bo) 382 } 383 384 // openSNSTopic returns the driver for OpenSNSTopic. This function exists so the test 385 // harness can get the driver interface implementation if it needs to. 386 func openSNSTopic(ctx context.Context, client *sns.SNS, topicARN string, opts *TopicOptions) driver.Topic { 387 if opts == nil { 388 opts = &TopicOptions{} 389 } 390 return &snsTopic{ 391 useV2: false, 392 client: client, 393 arn: topicARN, 394 opts: opts, 395 } 396 } 397 398 // openSNSTopicV2 returns the driver for OpenSNSTopic. This function exists so the test 399 // harness can get the driver interface implementation if it needs to. 400 func openSNSTopicV2(ctx context.Context, client *snsv2.Client, topicARN string, opts *TopicOptions) driver.Topic { 401 if opts == nil { 402 opts = &TopicOptions{} 403 } 404 return &snsTopic{ 405 useV2: true, 406 clientV2: client, 407 arn: topicARN, 408 opts: opts, 409 } 410 } 411 412 var stringDataType = aws.String("String") 413 414 // encodeMetadata encodes the keys and values of md as needed. 415 func encodeMetadata(md map[string]string) map[string]string { 416 retval := map[string]string{} 417 for k, v := range md { 418 // See the package comments for more details on escaping of metadata 419 // keys & values. 420 k = escape.HexEscape(k, func(runes []rune, i int) bool { 421 c := runes[i] 422 switch { 423 case escape.IsASCIIAlphanumeric(c): 424 return false 425 case c == '_' || c == '-': 426 return false 427 case c == '.' && i != 0 && runes[i-1] != '.': 428 return false 429 } 430 return true 431 }) 432 retval[k] = escape.URLEscape(v) 433 } 434 return retval 435 } 436 437 // maybeEncodeBody decides whether body should base64-encoded based on opt, and 438 // returns the (possibly encoded) body as a string, along with a boolean 439 // indicating whether encoding occurred. 440 func maybeEncodeBody(body []byte, opt BodyBase64Encoding) (string, bool) { 441 if opt.wantEncode(body) { 442 return base64.StdEncoding.EncodeToString(body), true 443 } 444 return string(body), false 445 } 446 447 // SendBatch implements driver.Topic.SendBatch. 448 func (t *snsTopic) SendBatch(ctx context.Context, dms []*driver.Message) error { 449 if len(dms) != 1 { 450 panic("snsTopic.SendBatch should only get one message at a time") 451 } 452 dm := dms[0] 453 454 if t.useV2 { 455 attrs := map[string]snstypesv2.MessageAttributeValue{} 456 for k, v := range encodeMetadata(dm.Metadata) { 457 attrs[k] = snstypesv2.MessageAttributeValue{ 458 DataType: stringDataType, 459 StringValue: aws.String(v), 460 } 461 } 462 body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding) 463 if didEncode { 464 attrs[base64EncodedKey] = snstypesv2.MessageAttributeValue{ 465 DataType: stringDataType, 466 StringValue: aws.String("true"), 467 } 468 } 469 if len(attrs) == 0 { 470 attrs = nil 471 } 472 input := &snsv2.PublishInput{ 473 Message: aws.String(body), 474 MessageAttributes: attrs, 475 TopicArn: &t.arn, 476 } 477 if dm.BeforeSend != nil { 478 asFunc := func(i interface{}) bool { 479 if p, ok := i.(**snsv2.PublishInput); ok { 480 *p = input 481 return true 482 } 483 return false 484 } 485 if err := dm.BeforeSend(asFunc); err != nil { 486 return err 487 } 488 } 489 po, err := t.clientV2.Publish(ctx, input) 490 if err != nil { 491 return err 492 } 493 if dm.AfterSend != nil { 494 asFunc := func(i interface{}) bool { 495 if p, ok := i.(**snsv2.PublishOutput); ok { 496 *p = po 497 return true 498 } 499 return false 500 } 501 if err := dm.AfterSend(asFunc); err != nil { 502 return err 503 } 504 } 505 return nil 506 } 507 attrs := map[string]*sns.MessageAttributeValue{} 508 for k, v := range encodeMetadata(dm.Metadata) { 509 attrs[k] = &sns.MessageAttributeValue{ 510 DataType: stringDataType, 511 StringValue: aws.String(v), 512 } 513 } 514 body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding) 515 if didEncode { 516 attrs[base64EncodedKey] = &sns.MessageAttributeValue{ 517 DataType: stringDataType, 518 StringValue: aws.String("true"), 519 } 520 } 521 if len(attrs) == 0 { 522 attrs = nil 523 } 524 input := &sns.PublishInput{ 525 Message: aws.String(body), 526 MessageAttributes: attrs, 527 TopicArn: &t.arn, 528 } 529 if dm.BeforeSend != nil { 530 asFunc := func(i interface{}) bool { 531 if p, ok := i.(**sns.PublishInput); ok { 532 *p = input 533 return true 534 } 535 return false 536 } 537 if err := dm.BeforeSend(asFunc); err != nil { 538 return err 539 } 540 } 541 po, err := t.client.PublishWithContext(ctx, input) 542 if err != nil { 543 return err 544 } 545 if dm.AfterSend != nil { 546 asFunc := func(i interface{}) bool { 547 if p, ok := i.(**sns.PublishOutput); ok { 548 *p = po 549 return true 550 } 551 return false 552 } 553 if err := dm.AfterSend(asFunc); err != nil { 554 return err 555 } 556 } 557 return nil 558 } 559 560 // IsRetryable implements driver.Topic.IsRetryable. 561 func (t *snsTopic) IsRetryable(error) bool { 562 // The client handles retries. 563 return false 564 } 565 566 // As implements driver.Topic.As. 567 func (t *snsTopic) As(i interface{}) bool { 568 if t.useV2 { 569 c, ok := i.(**snsv2.Client) 570 if !ok { 571 return false 572 } 573 *c = t.clientV2 574 return true 575 } 576 c, ok := i.(**sns.SNS) 577 if !ok { 578 return false 579 } 580 *c = t.client 581 return true 582 } 583 584 // ErrorAs implements driver.Topic.ErrorAs. 585 func (t *snsTopic) ErrorAs(err error, i interface{}) bool { 586 return errorAs(err, t.useV2, i) 587 } 588 589 // ErrorCode implements driver.Topic.ErrorCode. 590 func (t *snsTopic) ErrorCode(err error) gcerrors.ErrorCode { 591 return errorCode(err) 592 } 593 594 // Close implements driver.Topic.Close. 595 func (*snsTopic) Close() error { return nil } 596 597 type sqsTopic struct { 598 useV2 bool 599 client *sqs.SQS 600 clientV2 *sqsv2.Client 601 qURL string 602 opts *TopicOptions 603 } 604 605 // OpenSQSTopic opens a topic that sends to the SQS topic with the given SQS 606 // queue URL. 607 func OpenSQSTopic(ctx context.Context, sess client.ConfigProvider, qURL string, opts *TopicOptions) *pubsub.Topic { 608 bo := sendBatcherOptsSQS.NewMergedOptions(&opts.BatcherOptions) 609 return pubsub.NewTopic(openSQSTopic(ctx, sqs.New(sess), qURL, opts), bo) 610 } 611 612 // OpenSQSTopicV2 opens a topic that sends to the SQS topic with the given SQS 613 // queue URL, using AWS SDK V2. 614 func OpenSQSTopicV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *TopicOptions) *pubsub.Topic { 615 bo := sendBatcherOptsSQS.NewMergedOptions(&opts.BatcherOptions) 616 return pubsub.NewTopic(openSQSTopicV2(ctx, client, qURL, opts), bo) 617 } 618 619 // openSQSTopic returns the driver for OpenSQSTopic. This function exists so the test 620 // harness can get the driver interface implementation if it needs to. 621 func openSQSTopic(ctx context.Context, client *sqs.SQS, qURL string, opts *TopicOptions) driver.Topic { 622 if opts == nil { 623 opts = &TopicOptions{} 624 } 625 return &sqsTopic{ 626 useV2: false, 627 client: client, 628 qURL: qURL, 629 opts: opts, 630 } 631 } 632 633 // openSQSTopicV2 returns the driver for OpenSQSTopic. This function exists so the test 634 // harness can get the driver interface implementation if it needs to. 635 func openSQSTopicV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *TopicOptions) driver.Topic { 636 if opts == nil { 637 opts = &TopicOptions{} 638 } 639 return &sqsTopic{ 640 useV2: true, 641 clientV2: client, 642 qURL: qURL, 643 opts: opts, 644 } 645 } 646 647 // SendBatch implements driver.Topic.SendBatch. 648 func (t *sqsTopic) SendBatch(ctx context.Context, dms []*driver.Message) error { 649 if t.useV2 { 650 req := &sqsv2.SendMessageBatchInput{ 651 QueueUrl: aws.String(t.qURL), 652 } 653 for _, dm := range dms { 654 attrs := map[string]sqstypesv2.MessageAttributeValue{} 655 for k, v := range encodeMetadata(dm.Metadata) { 656 attrs[k] = sqstypesv2.MessageAttributeValue{ 657 DataType: stringDataType, 658 StringValue: aws.String(v), 659 } 660 } 661 body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding) 662 if didEncode { 663 attrs[base64EncodedKey] = sqstypesv2.MessageAttributeValue{ 664 DataType: stringDataType, 665 StringValue: aws.String("true"), 666 } 667 } 668 if len(attrs) == 0 { 669 attrs = nil 670 } 671 entry := sqstypesv2.SendMessageBatchRequestEntry{ 672 Id: aws.String(strconv.Itoa(len(req.Entries))), 673 MessageAttributes: attrs, 674 MessageBody: aws.String(body), 675 } 676 req.Entries = append(req.Entries, entry) 677 if dm.BeforeSend != nil { 678 asFunc := func(i interface{}) bool { 679 if p, ok := i.(*sqstypesv2.SendMessageBatchRequestEntry); ok { 680 *p = entry 681 return true 682 } 683 return false 684 } 685 if err := dm.BeforeSend(asFunc); err != nil { 686 return err 687 } 688 } 689 } 690 resp, err := t.clientV2.SendMessageBatch(ctx, req) 691 if err != nil { 692 return err 693 } 694 if numFailed := len(resp.Failed); numFailed > 0 { 695 first := resp.Failed[0] 696 return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.SendMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil) 697 } 698 if len(resp.Successful) == len(dms) { 699 for n, dm := range dms { 700 if dm.AfterSend != nil { 701 asFunc := func(i interface{}) bool { 702 if p, ok := i.(*sqstypesv2.SendMessageBatchResultEntry); ok { 703 *p = resp.Successful[n] 704 return true 705 } 706 return false 707 } 708 if err := dm.AfterSend(asFunc); err != nil { 709 return err 710 } 711 } 712 } 713 } 714 return nil 715 } 716 req := &sqs.SendMessageBatchInput{ 717 QueueUrl: aws.String(t.qURL), 718 } 719 for _, dm := range dms { 720 attrs := map[string]*sqs.MessageAttributeValue{} 721 for k, v := range encodeMetadata(dm.Metadata) { 722 attrs[k] = &sqs.MessageAttributeValue{ 723 DataType: stringDataType, 724 StringValue: aws.String(v), 725 } 726 } 727 body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding) 728 if didEncode { 729 attrs[base64EncodedKey] = &sqs.MessageAttributeValue{ 730 DataType: stringDataType, 731 StringValue: aws.String("true"), 732 } 733 } 734 if len(attrs) == 0 { 735 attrs = nil 736 } 737 entry := &sqs.SendMessageBatchRequestEntry{ 738 Id: aws.String(strconv.Itoa(len(req.Entries))), 739 MessageAttributes: attrs, 740 MessageBody: aws.String(body), 741 } 742 req.Entries = append(req.Entries, entry) 743 if dm.BeforeSend != nil { 744 // A previous revision used the non-batch API SendMessage, which takes 745 // a *sqs.SendMessageInput. For backwards compatibility for As, continue 746 // to support that type. If it is requested, create a SendMessageInput 747 // with the fields from SendMessageBatchRequestEntry that were set, and 748 // then copy all of the matching fields back after calling dm.BeforeSend. 749 var smi *sqs.SendMessageInput 750 asFunc := func(i interface{}) bool { 751 if p, ok := i.(**sqs.SendMessageInput); ok { 752 smi = &sqs.SendMessageInput{ 753 // Id does not exist on SendMessageInput. 754 MessageAttributes: entry.MessageAttributes, 755 MessageBody: entry.MessageBody, 756 } 757 *p = smi 758 return true 759 } 760 if p, ok := i.(**sqs.SendMessageBatchRequestEntry); ok { 761 *p = entry 762 return true 763 } 764 return false 765 } 766 if err := dm.BeforeSend(asFunc); err != nil { 767 return err 768 } 769 if smi != nil { 770 // Copy all of the fields that may have been modified back to the entry. 771 entry.DelaySeconds = smi.DelaySeconds 772 entry.MessageAttributes = smi.MessageAttributes 773 entry.MessageBody = smi.MessageBody 774 entry.MessageDeduplicationId = smi.MessageDeduplicationId 775 entry.MessageGroupId = smi.MessageGroupId 776 } 777 } 778 } 779 resp, err := t.client.SendMessageBatchWithContext(ctx, req) 780 if err != nil { 781 return err 782 } 783 if numFailed := len(resp.Failed); numFailed > 0 { 784 first := resp.Failed[0] 785 return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.SendMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil) 786 } 787 if len(resp.Successful) == len(dms) { 788 for n, dm := range dms { 789 if dm.AfterSend != nil { 790 asFunc := func(i interface{}) bool { 791 if p, ok := i.(**sqs.SendMessageBatchResultEntry); ok { 792 *p = resp.Successful[n] 793 return true 794 } 795 return false 796 } 797 if err := dm.AfterSend(asFunc); err != nil { 798 return err 799 } 800 } 801 } 802 } 803 return nil 804 } 805 806 // IsRetryable implements driver.Topic.IsRetryable. 807 func (t *sqsTopic) IsRetryable(error) bool { 808 // The client handles retries. 809 return false 810 } 811 812 // As implements driver.Topic.As. 813 func (t *sqsTopic) As(i interface{}) bool { 814 if t.useV2 { 815 c, ok := i.(**sqsv2.Client) 816 if !ok { 817 return false 818 } 819 *c = t.clientV2 820 return true 821 } 822 c, ok := i.(**sqs.SQS) 823 if !ok { 824 return false 825 } 826 *c = t.client 827 return true 828 } 829 830 // ErrorAs implements driver.Topic.ErrorAs. 831 func (t *sqsTopic) ErrorAs(err error, i interface{}) bool { 832 return errorAs(err, t.useV2, i) 833 } 834 835 // ErrorCode implements driver.Topic.ErrorCode. 836 func (t *sqsTopic) ErrorCode(err error) gcerrors.ErrorCode { 837 return errorCode(err) 838 } 839 840 // Close implements driver.Topic.Close. 841 func (*sqsTopic) Close() error { return nil } 842 843 func errorCode(err error) gcerrors.ErrorCode { 844 var code string 845 var ae smithy.APIError 846 if errors.As(err, &ae) { 847 code = ae.ErrorCode() 848 } else if ae, ok := err.(awserr.Error); ok { 849 code = ae.Code() 850 } else { 851 return gcerrors.Unknown 852 } 853 ec, ok := errorCodeMap[code] 854 if !ok { 855 return gcerrors.Unknown 856 } 857 return ec 858 } 859 860 var errorCodeMap = map[string]gcerrors.ErrorCode{ 861 sns.ErrCodeAuthorizationErrorException: gcerrors.PermissionDenied, 862 sns.ErrCodeKMSAccessDeniedException: gcerrors.PermissionDenied, 863 sns.ErrCodeKMSDisabledException: gcerrors.FailedPrecondition, 864 sns.ErrCodeKMSInvalidStateException: gcerrors.FailedPrecondition, 865 sns.ErrCodeKMSOptInRequired: gcerrors.FailedPrecondition, 866 sqs.ErrCodeMessageNotInflight: gcerrors.FailedPrecondition, 867 sqs.ErrCodePurgeQueueInProgress: gcerrors.FailedPrecondition, 868 sqs.ErrCodeQueueDeletedRecently: gcerrors.FailedPrecondition, 869 sqs.ErrCodeQueueNameExists: gcerrors.FailedPrecondition, 870 sns.ErrCodeInternalErrorException: gcerrors.Internal, 871 sns.ErrCodeInvalidParameterException: gcerrors.InvalidArgument, 872 sns.ErrCodeInvalidParameterValueException: gcerrors.InvalidArgument, 873 sqs.ErrCodeBatchEntryIdsNotDistinct: gcerrors.InvalidArgument, 874 sqs.ErrCodeBatchRequestTooLong: gcerrors.InvalidArgument, 875 sqs.ErrCodeEmptyBatchRequest: gcerrors.InvalidArgument, 876 sqs.ErrCodeInvalidAttributeName: gcerrors.InvalidArgument, 877 sqs.ErrCodeInvalidBatchEntryId: gcerrors.InvalidArgument, 878 sqs.ErrCodeInvalidIdFormat: gcerrors.InvalidArgument, 879 sqs.ErrCodeInvalidMessageContents: gcerrors.InvalidArgument, 880 sqs.ErrCodeReceiptHandleIsInvalid: gcerrors.InvalidArgument, 881 sqs.ErrCodeTooManyEntriesInBatchRequest: gcerrors.InvalidArgument, 882 sqs.ErrCodeUnsupportedOperation: gcerrors.InvalidArgument, 883 sns.ErrCodeInvalidSecurityException: gcerrors.PermissionDenied, 884 sns.ErrCodeKMSNotFoundException: gcerrors.NotFound, 885 sns.ErrCodeNotFoundException: gcerrors.NotFound, 886 sqs.ErrCodeQueueDoesNotExist: gcerrors.NotFound, 887 sns.ErrCodeFilterPolicyLimitExceededException: gcerrors.ResourceExhausted, 888 sns.ErrCodeSubscriptionLimitExceededException: gcerrors.ResourceExhausted, 889 sns.ErrCodeTopicLimitExceededException: gcerrors.ResourceExhausted, 890 sqs.ErrCodeOverLimit: gcerrors.ResourceExhausted, 891 sns.ErrCodeKMSThrottlingException: gcerrors.ResourceExhausted, 892 sns.ErrCodeThrottledException: gcerrors.ResourceExhausted, 893 "RequestCanceled": gcerrors.Canceled, 894 sns.ErrCodeEndpointDisabledException: gcerrors.Unknown, 895 sns.ErrCodePlatformApplicationDisabledException: gcerrors.Unknown, 896 } 897 898 type subscription struct { 899 useV2 bool 900 client *sqs.SQS 901 clientV2 *sqsv2.Client 902 qURL string 903 opts *SubscriptionOptions 904 } 905 906 // SubscriptionOptions will contain configuration for subscriptions. 907 type SubscriptionOptions struct { 908 // Raw determines how the Subscription will process message bodies. 909 // 910 // If the subscription is expected to process messages sent directly to 911 // SQS, or messages from SNS topics configured to use "raw" delivery, 912 // set this to true. Message bodies will be passed through untouched. 913 // 914 // If false, the Subscription will use best-effort heuristics to 915 // identify whether message bodies are raw or SNS JSON; this may be 916 // inefficient for raw messages. 917 // 918 // See https://aws.amazon.com/sns/faqs/#Raw_message_delivery. 919 Raw bool 920 921 // WaitTime passed to ReceiveMessage to enable long polling. 922 // https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-short-and-long-polling.html#sqs-long-polling. 923 // Note that a non-zero WaitTime can delay delivery of messages 924 // by up to that duration. 925 WaitTime time.Duration 926 927 // ReceiveBatcherOptions adds constraints to the default batching done for receives. 928 ReceiveBatcherOptions batcher.Options 929 930 // AckBatcherOptions adds constraints to the default batching done for acks. 931 AckBatcherOptions batcher.Options 932 } 933 934 // OpenSubscription opens a subscription based on AWS SQS for the given SQS 935 // queue URL. The queue is assumed to be subscribed to some SNS topic, though 936 // there is no check for this. 937 func OpenSubscription(ctx context.Context, sess client.ConfigProvider, qURL string, opts *SubscriptionOptions) *pubsub.Subscription { 938 rbo := recvBatcherOpts.NewMergedOptions(&opts.ReceiveBatcherOptions) 939 abo := ackBatcherOpts.NewMergedOptions(&opts.AckBatcherOptions) 940 return pubsub.NewSubscription(openSubscription(ctx, sqs.New(sess), qURL, opts), rbo, abo) 941 } 942 943 // OpenSubscriptionV2 opens a subscription based on AWS SQS for the given SQS 944 // queue URL, using AWS SDK V2. The queue is assumed to be subscribed to some SNS topic, though 945 // there is no check for this. 946 func OpenSubscriptionV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *SubscriptionOptions) *pubsub.Subscription { 947 rbo := recvBatcherOpts.NewMergedOptions(&opts.ReceiveBatcherOptions) 948 abo := ackBatcherOpts.NewMergedOptions(&opts.AckBatcherOptions) 949 return pubsub.NewSubscription(openSubscriptionV2(ctx, client, qURL, opts), rbo, abo) 950 } 951 952 // openSubscription returns a driver.Subscription. 953 func openSubscription(ctx context.Context, client *sqs.SQS, qURL string, opts *SubscriptionOptions) driver.Subscription { 954 if opts == nil { 955 opts = &SubscriptionOptions{} 956 } 957 return &subscription{ 958 useV2: false, 959 client: client, 960 qURL: qURL, opts: opts, 961 } 962 } 963 964 // openSubscriptionV2 returns a driver.Subscription. 965 func openSubscriptionV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *SubscriptionOptions) driver.Subscription { 966 if opts == nil { 967 opts = &SubscriptionOptions{} 968 } 969 return &subscription{ 970 useV2: true, 971 clientV2: client, 972 qURL: qURL, opts: opts, 973 } 974 } 975 976 // ReceiveBatch implements driver.Subscription.ReceiveBatch. 977 func (s *subscription) ReceiveBatch(ctx context.Context, maxMessages int) ([]*driver.Message, error) { 978 var ms []*driver.Message 979 if s.useV2 { 980 req := &sqsv2.ReceiveMessageInput{ 981 QueueUrl: aws.String(s.qURL), 982 MaxNumberOfMessages: int32(maxMessages), 983 MessageAttributeNames: []string{"All"}, 984 AttributeNames: []sqstypesv2.QueueAttributeName{"All"}, 985 } 986 if s.opts.WaitTime != 0 { 987 req.WaitTimeSeconds = int32(s.opts.WaitTime.Seconds()) 988 } 989 output, err := s.clientV2.ReceiveMessage(ctx, req) 990 if err != nil { 991 return nil, err 992 } 993 for _, m := range output.Messages { 994 m := m 995 bodyStr := aws.StringValue(m.Body) 996 rawAttrs := map[string]string{} 997 for k, v := range m.MessageAttributes { 998 rawAttrs[k] = aws.StringValue(v.StringValue) 999 } 1000 bodyStr, rawAttrs = extractBody(bodyStr, rawAttrs, s.opts.Raw) 1001 1002 decodeIt := false 1003 attrs := map[string]string{} 1004 for k, v := range rawAttrs { 1005 // See BodyBase64Encoding for details on when we base64 decode message bodies. 1006 if k == base64EncodedKey { 1007 decodeIt = true 1008 continue 1009 } 1010 // See the package comments for more details on escaping of metadata 1011 // keys & values. 1012 attrs[escape.HexUnescape(k)] = escape.URLUnescape(v) 1013 } 1014 1015 var b []byte 1016 if decodeIt { 1017 var err error 1018 b, err = base64.StdEncoding.DecodeString(bodyStr) 1019 if err != nil { 1020 // Fall back to using the raw message. 1021 b = []byte(bodyStr) 1022 } 1023 } else { 1024 b = []byte(bodyStr) 1025 } 1026 1027 m2 := &driver.Message{ 1028 LoggableID: aws.StringValue(m.MessageId), 1029 Body: b, 1030 Metadata: attrs, 1031 AckID: m.ReceiptHandle, 1032 AsFunc: func(i interface{}) bool { 1033 p, ok := i.(*sqstypesv2.Message) 1034 if !ok { 1035 return false 1036 } 1037 *p = m 1038 return true 1039 }, 1040 } 1041 ms = append(ms, m2) 1042 } 1043 } else { 1044 req := &sqs.ReceiveMessageInput{ 1045 QueueUrl: aws.String(s.qURL), 1046 MaxNumberOfMessages: aws.Int64(int64(maxMessages)), 1047 MessageAttributeNames: []*string{aws.String("All")}, 1048 AttributeNames: []*string{aws.String("All")}, 1049 } 1050 if s.opts.WaitTime != 0 { 1051 req.WaitTimeSeconds = aws.Int64(int64(s.opts.WaitTime.Seconds())) 1052 } 1053 output, err := s.client.ReceiveMessageWithContext(ctx, req) 1054 if err != nil { 1055 return nil, err 1056 } 1057 for _, m := range output.Messages { 1058 m := m 1059 bodyStr := aws.StringValue(m.Body) 1060 rawAttrs := map[string]string{} 1061 for k, v := range m.MessageAttributes { 1062 rawAttrs[k] = aws.StringValue(v.StringValue) 1063 } 1064 bodyStr, rawAttrs = extractBody(bodyStr, rawAttrs, s.opts.Raw) 1065 1066 decodeIt := false 1067 attrs := map[string]string{} 1068 for k, v := range rawAttrs { 1069 // See BodyBase64Encoding for details on when we base64 decode message bodies. 1070 if k == base64EncodedKey { 1071 decodeIt = true 1072 continue 1073 } 1074 // See the package comments for more details on escaping of metadata 1075 // keys & values. 1076 attrs[escape.HexUnescape(k)] = escape.URLUnescape(v) 1077 } 1078 1079 var b []byte 1080 if decodeIt { 1081 var err error 1082 b, err = base64.StdEncoding.DecodeString(bodyStr) 1083 if err != nil { 1084 // Fall back to using the raw message. 1085 b = []byte(bodyStr) 1086 } 1087 } else { 1088 b = []byte(bodyStr) 1089 } 1090 1091 m2 := &driver.Message{ 1092 LoggableID: aws.StringValue(m.MessageId), 1093 Body: b, 1094 Metadata: attrs, 1095 AckID: m.ReceiptHandle, 1096 AsFunc: func(i interface{}) bool { 1097 p, ok := i.(**sqs.Message) 1098 if !ok { 1099 return false 1100 } 1101 *p = m 1102 return true 1103 }, 1104 } 1105 ms = append(ms, m2) 1106 } 1107 } 1108 if len(ms) == 0 { 1109 // When we return no messages and no error, the portable type will call 1110 // ReceiveBatch again immediately. Sleep for a bit to avoid hammering SQS 1111 // with RPCs. 1112 time.Sleep(noMessagesPollDuration) 1113 } 1114 return ms, nil 1115 } 1116 1117 func extractBody(bodyStr string, rawAttrs map[string]string, raw bool) (body string, attributes map[string]string) { 1118 // If the user told us that message bodies are raw, or if there are 1119 // top-level MessageAttributes, then it's raw. 1120 // (SNS JSON message can have attributes, but they are encoded in 1121 // the JSON instead of being at the top level). 1122 raw = raw || len(rawAttrs) > 0 1123 if raw { 1124 // For raw messages, the attributes are at the top level 1125 // and we leave bodyStr alone. 1126 return bodyStr, rawAttrs 1127 } 1128 1129 // It might be SNS JSON; try to parse the raw body as such. 1130 // https://aws.amazon.com/sns/faqs/#Raw_message_delivery 1131 // If it parses as JSON and has a TopicArn field, assume it's SNS JSON. 1132 var bodyJSON struct { 1133 TopicArn string 1134 Message string 1135 MessageAttributes map[string]struct{ Value string } 1136 } 1137 if err := json.Unmarshal([]byte(bodyStr), &bodyJSON); err == nil && bodyJSON.TopicArn != "" { 1138 // It looks like SNS JSON. Get attributes from the decoded struct, 1139 // and update the body to be the JSON Message field. 1140 for k, v := range bodyJSON.MessageAttributes { 1141 rawAttrs[k] = v.Value 1142 } 1143 return bodyJSON.Message, rawAttrs 1144 } 1145 // It doesn't look like SNS JSON, either because it 1146 // isn't JSON or because the JSON doesn't have a TopicArn 1147 // field. Treat it as raw. 1148 // 1149 // As above in the other "raw" case, we leave bodyStr 1150 // alone. There can't be any top-level attributes (because 1151 // then we would have known it was raw earlier). 1152 return bodyStr, rawAttrs 1153 } 1154 1155 // SendAcks implements driver.Subscription.SendAcks. 1156 func (s *subscription) SendAcks(ctx context.Context, ids []driver.AckID) error { 1157 if s.useV2 { 1158 req := &sqsv2.DeleteMessageBatchInput{QueueUrl: aws.String(s.qURL)} 1159 for _, id := range ids { 1160 req.Entries = append(req.Entries, sqstypesv2.DeleteMessageBatchRequestEntry{ 1161 Id: aws.String(strconv.Itoa(len(req.Entries))), 1162 ReceiptHandle: id.(*string), 1163 }) 1164 } 1165 resp, err := s.clientV2.DeleteMessageBatch(ctx, req) 1166 if err != nil { 1167 return err 1168 } 1169 // Note: DeleteMessageBatch doesn't return failures when you try 1170 // to Delete an id that isn't found. 1171 if numFailed := len(resp.Failed); numFailed > 0 { 1172 first := resp.Failed[0] 1173 return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.DeleteMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil) 1174 } 1175 return nil 1176 } 1177 req := &sqs.DeleteMessageBatchInput{QueueUrl: aws.String(s.qURL)} 1178 for _, id := range ids { 1179 req.Entries = append(req.Entries, &sqs.DeleteMessageBatchRequestEntry{ 1180 Id: aws.String(strconv.Itoa(len(req.Entries))), 1181 ReceiptHandle: id.(*string), 1182 }) 1183 } 1184 resp, err := s.client.DeleteMessageBatchWithContext(ctx, req) 1185 if err != nil { 1186 return err 1187 } 1188 // Note: DeleteMessageBatch doesn't return failures when you try 1189 // to Delete an id that isn't found. 1190 if numFailed := len(resp.Failed); numFailed > 0 { 1191 first := resp.Failed[0] 1192 return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.DeleteMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil) 1193 } 1194 return nil 1195 } 1196 1197 // CanNack implements driver.CanNack. 1198 func (s *subscription) CanNack() bool { return true } 1199 1200 // SendNacks implements driver.Subscription.SendNacks. 1201 func (s *subscription) SendNacks(ctx context.Context, ids []driver.AckID) error { 1202 if s.useV2 { 1203 req := &sqsv2.ChangeMessageVisibilityBatchInput{QueueUrl: aws.String(s.qURL)} 1204 for _, id := range ids { 1205 req.Entries = append(req.Entries, sqstypesv2.ChangeMessageVisibilityBatchRequestEntry{ 1206 Id: aws.String(strconv.Itoa(len(req.Entries))), 1207 ReceiptHandle: id.(*string), 1208 VisibilityTimeout: 1, 1209 }) 1210 } 1211 resp, err := s.clientV2.ChangeMessageVisibilityBatch(ctx, req) 1212 if err != nil { 1213 return err 1214 } 1215 // Note: ChangeMessageVisibilityBatch returns failures when you try to 1216 // modify an id that isn't found; drop those. 1217 var firstFail sqstypesv2.BatchResultErrorEntry 1218 numFailed := 0 1219 for _, fail := range resp.Failed { 1220 if aws.StringValue(fail.Code) == sqs.ErrCodeReceiptHandleIsInvalid { 1221 continue 1222 } 1223 if numFailed == 0 { 1224 firstFail = fail 1225 } 1226 numFailed++ 1227 } 1228 if numFailed > 0 { 1229 return awserr.New(aws.StringValue(firstFail.Code), fmt.Sprintf("sqs.ChangeMessageVisibilityBatch failed for %d message(s): %s", numFailed, aws.StringValue(firstFail.Message)), nil) 1230 } 1231 return nil 1232 } 1233 req := &sqs.ChangeMessageVisibilityBatchInput{QueueUrl: aws.String(s.qURL)} 1234 for _, id := range ids { 1235 req.Entries = append(req.Entries, &sqs.ChangeMessageVisibilityBatchRequestEntry{ 1236 Id: aws.String(strconv.Itoa(len(req.Entries))), 1237 ReceiptHandle: id.(*string), 1238 VisibilityTimeout: aws.Int64(0), 1239 }) 1240 } 1241 resp, err := s.client.ChangeMessageVisibilityBatchWithContext(ctx, req) 1242 if err != nil { 1243 return err 1244 } 1245 // Note: ChangeMessageVisibilityBatch returns failures when you try to 1246 // modify an id that isn't found; drop those. 1247 var firstFail *sqs.BatchResultErrorEntry 1248 numFailed := 0 1249 for _, fail := range resp.Failed { 1250 if aws.StringValue(fail.Code) == sqs.ErrCodeReceiptHandleIsInvalid { 1251 continue 1252 } 1253 if numFailed == 0 { 1254 firstFail = fail 1255 } 1256 numFailed++ 1257 } 1258 if numFailed > 0 { 1259 return awserr.New(aws.StringValue(firstFail.Code), fmt.Sprintf("sqs.ChangeMessageVisibilityBatch failed for %d message(s): %s", numFailed, aws.StringValue(firstFail.Message)), nil) 1260 } 1261 return nil 1262 } 1263 1264 // IsRetryable implements driver.Subscription.IsRetryable. 1265 func (*subscription) IsRetryable(error) bool { 1266 // The client handles retries. 1267 return false 1268 } 1269 1270 // As implements driver.Subscription.As. 1271 func (s *subscription) As(i interface{}) bool { 1272 if s.useV2 { 1273 c, ok := i.(**sqsv2.Client) 1274 if !ok { 1275 return false 1276 } 1277 *c = s.clientV2 1278 return true 1279 } 1280 c, ok := i.(**sqs.SQS) 1281 if !ok { 1282 return false 1283 } 1284 *c = s.client 1285 return true 1286 } 1287 1288 // ErrorAs implements driver.Subscription.ErrorAs. 1289 func (s *subscription) ErrorAs(err error, i interface{}) bool { 1290 return errorAs(err, s.useV2, i) 1291 } 1292 1293 // ErrorCode implements driver.Subscription.ErrorCode. 1294 func (s *subscription) ErrorCode(err error) gcerrors.ErrorCode { 1295 return errorCode(err) 1296 } 1297 1298 func errorAs(err error, useV2 bool, i interface{}) bool { 1299 if useV2 { 1300 return errors.As(err, i) 1301 } 1302 e, ok := err.(awserr.Error) 1303 if !ok { 1304 return false 1305 } 1306 p, ok := i.(*awserr.Error) 1307 if !ok { 1308 return false 1309 } 1310 *p = e 1311 return true 1312 } 1313 1314 // Close implements driver.Subscription.Close. 1315 func (*subscription) Close() error { return nil }