github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/pubsub/awssnssqs/awssnssqs.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awssnssqs provides two implementations of pubsub.Topic, one that
    16  // sends messages to AWS SNS (Simple Notification Service), and one that sends
    17  // messages to SQS (Simple Queuing Service). It also provides an implementation
    18  // of pubsub.Subscription that receives messages from SQS.
    19  //
    20  // # URLs
    21  //
    22  // For pubsub.OpenTopic, awssnssqs registers for the scheme "awssns" for
    23  // an SNS topic, and "awssqs" for an SQS topic. For pubsub.OpenSubscription,
    24  // it registers for the scheme "awssqs".
    25  //
    26  // The default URL opener will use an AWS session with the default credentials
    27  // and configuration; see https://docs.aws.amazon.com/sdk-for-go/api/aws/session/
    28  // for more details.
    29  // To customize the URL opener, or for more details on the URL format,
    30  // see URLOpener.
    31  // See https://gocloud.dev/concepts/urls/ for background information.
    32  //
    33  // # Message Delivery Semantics
    34  //
    35  // AWS SQS supports at-least-once semantics; applications must call Message.Ack
    36  // after processing a message, or it will be redelivered.
    37  // See https://godoc.org/gocloud.dev/pubsub#hdr-At_most_once_and_At_least_once_Delivery
    38  // for more background.
    39  //
    40  // # Escaping
    41  //
    42  // Go CDK supports all UTF-8 strings; to make this work with services lacking
    43  // full UTF-8 support, strings must be escaped (during writes) and unescaped
    44  // (during reads). The following escapes are required for awssnssqs:
//   - Metadata keys: Characters other than "a-zA-Z0-9_-.", and additionally "."
    46  //     when it's at the start of the key or the previous character was ".",
    47  //     are escaped using "__0x<hex>__". These characters were determined by
    48  //     experimentation.
    49  //   - Metadata values: Escaped using URL encoding.
    50  //   - Message body: AWS SNS/SQS only supports UTF-8 strings. See the
    51  //     BodyBase64Encoding enum in TopicOptions for strategies on how to send
    52  //     non-UTF-8 message bodies. By default, non-UTF-8 message bodies are base64
    53  //     encoded.
    54  //
    55  // # As
    56  //
    57  // awssnssqs exposes the following types for As:
    58  //   - Topic: (V1) *sns.SNS for OpenSNSTopic, *sqs.SQS for OpenSQSTopic; (V2) *snsv2.Client for OpenSNSTopicV2, *sqsv2.Client for OpenSQSTopicV2
    59  //   - Subscription: (V1) *sqs.SQS; (V2) *sqsv2.Client
    60  //   - Message: (V1) *sqs.Message; (V2) sqstypesv2.Message
    61  //   - Message.BeforeSend: (V1) *sns.PublishInput for OpenSNSTopic, *sqs.SendMessageBatchRequestEntry or *sqs.SendMessageInput(deprecated) for OpenSQSTopic; (V2) *snsv2.PublishInput for OpenSNSTopicV2, sqstypesv2.SendMessageBatchRequestEntry for OpenSQSTopicV2
    62  //   - Message.AfterSend: (V1) *sns.PublishOutput for OpenSNSTopic, *sqs.SendMessageBatchResultEntry for OpenSQSTopic; (V2) *snsv2.PublishOutput for OpenSNSTopicV2, sqstypesv2.SendMessageBatchResultEntry for OpenSQSTopicV2
    63  //   - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError
    64  package awssnssqs // import "gocloud.dev/pubsub/awssnssqs"
    65  
    66  import (
    67  	"context"
    68  	"encoding/base64"
    69  	"encoding/json"
    70  	"errors"
    71  	"fmt"
    72  	"net/url"
    73  	"path"
    74  	"strconv"
    75  	"strings"
    76  	"sync"
    77  	"time"
    78  	"unicode/utf8"
    79  
    80  	snsv2 "github.com/aws/aws-sdk-go-v2/service/sns"
    81  	snstypesv2 "github.com/aws/aws-sdk-go-v2/service/sns/types"
    82  	sqsv2 "github.com/aws/aws-sdk-go-v2/service/sqs"
    83  	sqstypesv2 "github.com/aws/aws-sdk-go-v2/service/sqs/types"
    84  	"github.com/aws/aws-sdk-go/aws"
    85  	"github.com/aws/aws-sdk-go/aws/awserr"
    86  	"github.com/aws/aws-sdk-go/aws/client"
    87  	"github.com/aws/aws-sdk-go/service/sns"
    88  	"github.com/aws/aws-sdk-go/service/sqs"
    89  	"github.com/aws/smithy-go"
    90  	"github.com/google/wire"
    91  	gcaws "gocloud.dev/aws"
    92  	"gocloud.dev/gcerrors"
    93  	"gocloud.dev/internal/escape"
    94  	"gocloud.dev/pubsub"
    95  	"gocloud.dev/pubsub/batcher"
    96  	"gocloud.dev/pubsub/driver"
    97  )
    98  
    99  const (
   100  	// base64EncodedKey is the Message Attribute key used to flag that the
   101  	// message body is base64 encoded.
   102  	base64EncodedKey = "base64encoded"
   103  	// How long ReceiveBatch should wait if no messages are available; controls
   104  	// the poll interval of requests to SQS.
   105  	noMessagesPollDuration = 250 * time.Millisecond
   106  )
   107  
   108  var sendBatcherOptsSNS = &batcher.Options{
   109  	MaxBatchSize: 1,   // SNS SendBatch only supports one message at a time
   110  	MaxHandlers:  100, // max concurrency for sends
   111  }
   112  
   113  var sendBatcherOptsSQS = &batcher.Options{
   114  	MaxBatchSize: 10,  // SQS SendBatch supports 10 messages at a time
   115  	MaxHandlers:  100, // max concurrency for sends
   116  }
   117  
   118  var recvBatcherOpts = &batcher.Options{
   119  	// SQS supports receiving at most 10 messages at a time:
   120  	// https://godoc.org/github.com/aws/aws-sdk-go/service/sqs#SQS.ReceiveMessage
   121  	MaxBatchSize: 10,
   122  	MaxHandlers:  100, // max concurrency for receives
   123  }
   124  
   125  var ackBatcherOpts = &batcher.Options{
   126  	// SQS supports deleting/updating at most 10 messages at a time:
   127  	// https://godoc.org/github.com/aws/aws-sdk-go/service/sqs#SQS.DeleteMessageBatch
   128  	// https://godoc.org/github.com/aws/aws-sdk-go/service/sqs#SQS.ChangeMessageVisibilityBatch
   129  	MaxBatchSize: 10,
   130  	MaxHandlers:  100, // max concurrency for acks
   131  }
   132  
   133  func init() {
   134  	lazy := new(lazySessionOpener)
   135  	pubsub.DefaultURLMux().RegisterTopic(SNSScheme, lazy)
   136  	pubsub.DefaultURLMux().RegisterTopic(SQSScheme, lazy)
   137  	pubsub.DefaultURLMux().RegisterSubscription(SQSScheme, lazy)
   138  }
   139  
   140  // Set holds Wire providers for this package.
   141  var Set = wire.NewSet(
   142  	wire.Struct(new(URLOpener), "ConfigProvider"),
   143  )
   144  
// lazySessionOpener obtains the AWS session from the environment on the first
// call to OpenXXXURL.
//
// It is the opener registered on pubsub's default URL mux. URLs that select
// AWS SDK V2 never trigger session creation; see defaultOpener.
type lazySessionOpener struct {
	init   sync.Once  // guards the one-time creation of opener/err for V1 URLs
	opener *URLOpener // set on successful session creation; reused for every V1 URL
	err    error      // set when session creation fails; returned for every V1 URL
}
   152  
   153  func (o *lazySessionOpener) defaultOpener(u *url.URL) (*URLOpener, error) {
   154  	if gcaws.UseV2(u.Query()) {
   155  		return &URLOpener{UseV2: true}, nil
   156  	}
   157  	o.init.Do(func() {
   158  		sess, err := gcaws.NewDefaultSession()
   159  		if err != nil {
   160  			o.err = err
   161  			return
   162  		}
   163  		o.opener = &URLOpener{
   164  			ConfigProvider: sess,
   165  		}
   166  	})
   167  	return o.opener, o.err
   168  }
   169  
   170  func (o *lazySessionOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic, error) {
   171  	opener, err := o.defaultOpener(u)
   172  	if err != nil {
   173  		return nil, fmt.Errorf("open topic %v: failed to open default session: %v", u, err)
   174  	}
   175  	return opener.OpenTopicURL(ctx, u)
   176  }
   177  
   178  func (o *lazySessionOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsub.Subscription, error) {
   179  	opener, err := o.defaultOpener(u)
   180  	if err != nil {
   181  		return nil, fmt.Errorf("open subscription %v: failed to open default session: %v", u, err)
   182  	}
   183  	return opener.OpenSubscriptionURL(ctx, u)
   184  }
   185  
   186  // SNSScheme is the URL scheme for pubsub.OpenTopic (for an SNS topic) that
   187  // awssnssqs registers its URLOpeners under on pubsub.DefaultMux.
   188  const SNSScheme = "awssns"
   189  
   190  // SQSScheme is the URL scheme for pubsub.OpenTopic (for an SQS topic) and for
   191  // pubsub.OpenSubscription that awssnssqs registers its URLOpeners under on
   192  // pubsub.DefaultMux.
   193  const SQSScheme = "awssqs"
   194  
// URLOpener opens AWS SNS/SQS URLs like "awssns:///sns-topic-arn" for
// SNS topics or "awssqs://sqs-queue-url" for SQS topics and subscriptions.
//
// For SNS topics, the URL's host+path is used as the topic Amazon Resource Name
// (ARN). Since ARNs have ":" in them, and ":" precedes a port in URL
// hostnames, leave the host blank and put the ARN in the path
// (e.g., "awssns:///arn:aws:service:region:accountid:resourceType/resourcePath").
//
// For SQS topics and subscriptions, the URL's host+path is prefixed with
// "https://" to create the queue URL.
//
// Use "awssdk=v1" to force using AWS SDK v1, "awssdk=v2" to force using AWS SDK v2,
// or anything else to accept the default. This parameter is consulted by the
// default opener registered on pubsub.DefaultURLMux; a URLOpener constructed
// directly chooses the SDK from its UseV2 field.
//
// For V1, see gocloud.dev/aws/ConfigFromURLParams for supported query parameters
// for overriding the aws.Session from the URL.
// For V2, see gocloud.dev/aws/V2ConfigFromURLParams.
//
// In addition, the following query parameters are supported:
//
//   - raw (for "awssqs" Subscriptions only): sets SubscriptionOptions.Raw. The
//     value must be parseable by `strconv.ParseBool`.
//   - waittime (for "awssqs" Subscriptions only): sets
//     SubscriptionOptions.WaitTime, in time.ParseDuration formats.
//
// All other query parameters are passed to the V1 or V2 config helper above.
type URLOpener struct {
	// UseV2 indicates whether the AWS SDK V2 should be used.
	UseV2 bool

	// ConfigProvider configures the connection to AWS.
	// It must be set to a non-nil value if UseV2 is false.
	ConfigProvider client.ConfigProvider

	// TopicOptions specifies the options to pass to OpenSNSTopic/OpenSQSTopic
	// (or their V2 variants) when opening a topic.
	TopicOptions TopicOptions
	// SubscriptionOptions specifies the options to pass to OpenSubscription
	// (or OpenSubscriptionV2); "raw" and "waittime" override a copy of it.
	SubscriptionOptions SubscriptionOptions
}
   234  
   235  // OpenTopicURL opens a pubsub.Topic based on u.
   236  func (o *URLOpener) OpenTopicURL(ctx context.Context, u *url.URL) (*pubsub.Topic, error) {
   237  	// Trim leading "/" if host is empty, so that
   238  	// awssns:///arn:aws:service:region:accountid:resourceType/resourcePath
   239  	// gives "arn:..." instead of "/arn:...".
   240  	topicARN := strings.TrimPrefix(path.Join(u.Host, u.Path), "/")
   241  	qURL := "https://" + path.Join(u.Host, u.Path)
   242  	if o.UseV2 {
   243  		cfg, err := gcaws.V2ConfigFromURLParams(ctx, u.Query())
   244  		if err != nil {
   245  			return nil, fmt.Errorf("open topic %v: %v", u, err)
   246  		}
   247  		switch u.Scheme {
   248  		case SNSScheme:
   249  			return OpenSNSTopicV2(ctx, snsv2.NewFromConfig(cfg), topicARN, &o.TopicOptions), nil
   250  		case SQSScheme:
   251  			return OpenSQSTopicV2(ctx, sqsv2.NewFromConfig(cfg), qURL, &o.TopicOptions), nil
   252  		default:
   253  			return nil, fmt.Errorf("open topic %v: unsupported scheme", u)
   254  		}
   255  	}
   256  	configProvider := &gcaws.ConfigOverrider{
   257  		Base: o.ConfigProvider,
   258  	}
   259  	overrideCfg, err := gcaws.ConfigFromURLParams(u.Query())
   260  	if err != nil {
   261  		return nil, fmt.Errorf("open topic %v: %v", u, err)
   262  	}
   263  	configProvider.Configs = append(configProvider.Configs, overrideCfg)
   264  	switch u.Scheme {
   265  	case SNSScheme:
   266  		return OpenSNSTopic(ctx, configProvider, topicARN, &o.TopicOptions), nil
   267  	case SQSScheme:
   268  		return OpenSQSTopic(ctx, configProvider, qURL, &o.TopicOptions), nil
   269  	default:
   270  		return nil, fmt.Errorf("open topic %v: unsupported scheme", u)
   271  	}
   272  }
   273  
   274  // OpenSubscriptionURL opens a pubsub.Subscription based on u.
   275  func (o *URLOpener) OpenSubscriptionURL(ctx context.Context, u *url.URL) (*pubsub.Subscription, error) {
   276  	// Clone the options since we might override Raw.
   277  	opts := o.SubscriptionOptions
   278  	q := u.Query()
   279  	if rawStr := q.Get("raw"); rawStr != "" {
   280  		var err error
   281  		opts.Raw, err = strconv.ParseBool(rawStr)
   282  		if err != nil {
   283  			return nil, fmt.Errorf("invalid value %q for raw: %v", rawStr, err)
   284  		}
   285  		q.Del("raw")
   286  	}
   287  	if waitTimeStr := q.Get("waittime"); waitTimeStr != "" {
   288  		var err error
   289  		opts.WaitTime, err = time.ParseDuration(waitTimeStr)
   290  		if err != nil {
   291  			return nil, fmt.Errorf("invalid value %q for waittime: %v", waitTimeStr, err)
   292  		}
   293  		q.Del("waittime")
   294  	}
   295  	qURL := "https://" + path.Join(u.Host, u.Path)
   296  	if o.UseV2 {
   297  		cfg, err := gcaws.V2ConfigFromURLParams(ctx, q)
   298  		if err != nil {
   299  			return nil, fmt.Errorf("open subscription %v: %v", u, err)
   300  		}
   301  		return OpenSubscriptionV2(ctx, sqsv2.NewFromConfig(cfg), qURL, &opts), nil
   302  	}
   303  	overrideCfg, err := gcaws.ConfigFromURLParams(q)
   304  	if err != nil {
   305  		return nil, fmt.Errorf("open subscription %v: %v", u, err)
   306  	}
   307  	configProvider := &gcaws.ConfigOverrider{
   308  		Base: o.ConfigProvider,
   309  	}
   310  	configProvider.Configs = append(configProvider.Configs, overrideCfg)
   311  	return OpenSubscription(ctx, configProvider, qURL, &opts), nil
   312  }
   313  
// snsTopic is the driver.Topic implementation that publishes to an SNS topic.
// openSNSTopic sets client and openSNSTopicV2 sets clientV2; useV2 selects
// which one is used.
type snsTopic struct {
	useV2    bool          // true if clientV2 (AWS SDK V2) should be used
	client   *sns.SNS      // AWS SDK V1 client; used when useV2 is false
	clientV2 *snsv2.Client // AWS SDK V2 client; used when useV2 is true
	arn      string        // ARN of the target SNS topic
	opts     *TopicOptions // the constructors replace a nil opts with an empty one
}
   321  
   322  // BodyBase64Encoding is an enum of strategies for when to base64 message
   323  // bodies.
   324  type BodyBase64Encoding int
   325  
   326  const (
   327  	// NonUTF8Only means that message bodies that are valid UTF-8 encodings are
   328  	// sent as-is. Invalid UTF-8 message bodies are base64 encoded, and a
   329  	// MessageAttribute with key "base64encoded" is added to the message.
   330  	// When receiving messages, the "base64encoded" attribute is used to determine
   331  	// whether to base64 decode, and is then filtered out.
   332  	NonUTF8Only BodyBase64Encoding = 0
   333  	// Always means that all message bodies are base64 encoded.
   334  	// A MessageAttribute with key "base64encoded" is added to the message.
   335  	// When receiving messages, the "base64encoded" attribute is used to determine
   336  	// whether to base64 decode, and is then filtered out.
   337  	Always BodyBase64Encoding = 1
   338  	// Never means that message bodies are never base64 encoded. Non-UTF-8
   339  	// bytes in message bodies may be modified by SNS/SQS.
   340  	Never BodyBase64Encoding = 2
   341  )
   342  
   343  func (e BodyBase64Encoding) wantEncode(b []byte) bool {
   344  	switch e {
   345  	case Always:
   346  		return true
   347  	case Never:
   348  		return false
   349  	case NonUTF8Only:
   350  		return !utf8.Valid(b)
   351  	}
   352  	panic("unreachable")
   353  }
   354  
// TopicOptions contains configuration options for topics.
type TopicOptions struct {
	// BodyBase64Encoding determines when message bodies are base64 encoded.
	// The default is NonUTF8Only.
	BodyBase64Encoding BodyBase64Encoding

	// BatcherOptions adds constraints to the default batching done for sends.
	// It is passed to NewMergedOptions on the package's send defaults
	// (batches of 1 message for SNS and 10 for SQS, with up to 100 concurrent
	// sends).
	BatcherOptions batcher.Options
}
   364  
   365  // OpenTopic is a shortcut for OpenSNSTopic, provided for backwards compatibility.
   366  func OpenTopic(ctx context.Context, sess client.ConfigProvider, topicARN string, opts *TopicOptions) *pubsub.Topic {
   367  	return OpenSNSTopic(ctx, sess, topicARN, opts)
   368  }
   369  
   370  // OpenSNSTopic opens a topic that sends to the SNS topic with the given Amazon
   371  // Resource Name (ARN).
   372  func OpenSNSTopic(ctx context.Context, sess client.ConfigProvider, topicARN string, opts *TopicOptions) *pubsub.Topic {
   373  	bo := sendBatcherOptsSNS.NewMergedOptions(&opts.BatcherOptions)
   374  	return pubsub.NewTopic(openSNSTopic(ctx, sns.New(sess), topicARN, opts), bo)
   375  }
   376  
   377  // OpenSNSTopicV2 opens a topic that sends to the SNS topic with the given Amazon
   378  // Resource Name (ARN), using AWS SDK V2.
   379  func OpenSNSTopicV2(ctx context.Context, client *snsv2.Client, topicARN string, opts *TopicOptions) *pubsub.Topic {
   380  	bo := sendBatcherOptsSNS.NewMergedOptions(&opts.BatcherOptions)
   381  	return pubsub.NewTopic(openSNSTopicV2(ctx, client, topicARN, opts), bo)
   382  }
   383  
   384  // openSNSTopic returns the driver for OpenSNSTopic. This function exists so the test
   385  // harness can get the driver interface implementation if it needs to.
   386  func openSNSTopic(ctx context.Context, client *sns.SNS, topicARN string, opts *TopicOptions) driver.Topic {
   387  	if opts == nil {
   388  		opts = &TopicOptions{}
   389  	}
   390  	return &snsTopic{
   391  		useV2:  false,
   392  		client: client,
   393  		arn:    topicARN,
   394  		opts:   opts,
   395  	}
   396  }
   397  
   398  // openSNSTopicV2 returns the driver for OpenSNSTopic. This function exists so the test
   399  // harness can get the driver interface implementation if it needs to.
   400  func openSNSTopicV2(ctx context.Context, client *snsv2.Client, topicARN string, opts *TopicOptions) driver.Topic {
   401  	if opts == nil {
   402  		opts = &TopicOptions{}
   403  	}
   404  	return &snsTopic{
   405  		useV2:    true,
   406  		clientV2: client,
   407  		arn:      topicARN,
   408  		opts:     opts,
   409  	}
   410  }
   411  
   412  var stringDataType = aws.String("String")
   413  
   414  // encodeMetadata encodes the keys and values of md as needed.
   415  func encodeMetadata(md map[string]string) map[string]string {
   416  	retval := map[string]string{}
   417  	for k, v := range md {
   418  		// See the package comments for more details on escaping of metadata
   419  		// keys & values.
   420  		k = escape.HexEscape(k, func(runes []rune, i int) bool {
   421  			c := runes[i]
   422  			switch {
   423  			case escape.IsASCIIAlphanumeric(c):
   424  				return false
   425  			case c == '_' || c == '-':
   426  				return false
   427  			case c == '.' && i != 0 && runes[i-1] != '.':
   428  				return false
   429  			}
   430  			return true
   431  		})
   432  		retval[k] = escape.URLEscape(v)
   433  	}
   434  	return retval
   435  }
   436  
   437  // maybeEncodeBody decides whether body should base64-encoded based on opt, and
   438  // returns the (possibly encoded) body as a string, along with a boolean
   439  // indicating whether encoding occurred.
   440  func maybeEncodeBody(body []byte, opt BodyBase64Encoding) (string, bool) {
   441  	if opt.wantEncode(body) {
   442  		return base64.StdEncoding.EncodeToString(body), true
   443  	}
   444  	return string(body), false
   445  }
   446  
   447  // SendBatch implements driver.Topic.SendBatch.
   448  func (t *snsTopic) SendBatch(ctx context.Context, dms []*driver.Message) error {
   449  	if len(dms) != 1 {
   450  		panic("snsTopic.SendBatch should only get one message at a time")
   451  	}
   452  	dm := dms[0]
   453  
   454  	if t.useV2 {
   455  		attrs := map[string]snstypesv2.MessageAttributeValue{}
   456  		for k, v := range encodeMetadata(dm.Metadata) {
   457  			attrs[k] = snstypesv2.MessageAttributeValue{
   458  				DataType:    stringDataType,
   459  				StringValue: aws.String(v),
   460  			}
   461  		}
   462  		body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
   463  		if didEncode {
   464  			attrs[base64EncodedKey] = snstypesv2.MessageAttributeValue{
   465  				DataType:    stringDataType,
   466  				StringValue: aws.String("true"),
   467  			}
   468  		}
   469  		if len(attrs) == 0 {
   470  			attrs = nil
   471  		}
   472  		input := &snsv2.PublishInput{
   473  			Message:           aws.String(body),
   474  			MessageAttributes: attrs,
   475  			TopicArn:          &t.arn,
   476  		}
   477  		if dm.BeforeSend != nil {
   478  			asFunc := func(i interface{}) bool {
   479  				if p, ok := i.(**snsv2.PublishInput); ok {
   480  					*p = input
   481  					return true
   482  				}
   483  				return false
   484  			}
   485  			if err := dm.BeforeSend(asFunc); err != nil {
   486  				return err
   487  			}
   488  		}
   489  		po, err := t.clientV2.Publish(ctx, input)
   490  		if err != nil {
   491  			return err
   492  		}
   493  		if dm.AfterSend != nil {
   494  			asFunc := func(i interface{}) bool {
   495  				if p, ok := i.(**snsv2.PublishOutput); ok {
   496  					*p = po
   497  					return true
   498  				}
   499  				return false
   500  			}
   501  			if err := dm.AfterSend(asFunc); err != nil {
   502  				return err
   503  			}
   504  		}
   505  		return nil
   506  	}
   507  	attrs := map[string]*sns.MessageAttributeValue{}
   508  	for k, v := range encodeMetadata(dm.Metadata) {
   509  		attrs[k] = &sns.MessageAttributeValue{
   510  			DataType:    stringDataType,
   511  			StringValue: aws.String(v),
   512  		}
   513  	}
   514  	body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
   515  	if didEncode {
   516  		attrs[base64EncodedKey] = &sns.MessageAttributeValue{
   517  			DataType:    stringDataType,
   518  			StringValue: aws.String("true"),
   519  		}
   520  	}
   521  	if len(attrs) == 0 {
   522  		attrs = nil
   523  	}
   524  	input := &sns.PublishInput{
   525  		Message:           aws.String(body),
   526  		MessageAttributes: attrs,
   527  		TopicArn:          &t.arn,
   528  	}
   529  	if dm.BeforeSend != nil {
   530  		asFunc := func(i interface{}) bool {
   531  			if p, ok := i.(**sns.PublishInput); ok {
   532  				*p = input
   533  				return true
   534  			}
   535  			return false
   536  		}
   537  		if err := dm.BeforeSend(asFunc); err != nil {
   538  			return err
   539  		}
   540  	}
   541  	po, err := t.client.PublishWithContext(ctx, input)
   542  	if err != nil {
   543  		return err
   544  	}
   545  	if dm.AfterSend != nil {
   546  		asFunc := func(i interface{}) bool {
   547  			if p, ok := i.(**sns.PublishOutput); ok {
   548  				*p = po
   549  				return true
   550  			}
   551  			return false
   552  		}
   553  		if err := dm.AfterSend(asFunc); err != nil {
   554  			return err
   555  		}
   556  	}
   557  	return nil
   558  }
   559  
   560  // IsRetryable implements driver.Topic.IsRetryable.
   561  func (t *snsTopic) IsRetryable(error) bool {
   562  	// The client handles retries.
   563  	return false
   564  }
   565  
   566  // As implements driver.Topic.As.
   567  func (t *snsTopic) As(i interface{}) bool {
   568  	if t.useV2 {
   569  		c, ok := i.(**snsv2.Client)
   570  		if !ok {
   571  			return false
   572  		}
   573  		*c = t.clientV2
   574  		return true
   575  	}
   576  	c, ok := i.(**sns.SNS)
   577  	if !ok {
   578  		return false
   579  	}
   580  	*c = t.client
   581  	return true
   582  }
   583  
   584  // ErrorAs implements driver.Topic.ErrorAs.
   585  func (t *snsTopic) ErrorAs(err error, i interface{}) bool {
   586  	return errorAs(err, t.useV2, i)
   587  }
   588  
   589  // ErrorCode implements driver.Topic.ErrorCode.
   590  func (t *snsTopic) ErrorCode(err error) gcerrors.ErrorCode {
   591  	return errorCode(err)
   592  }
   593  
   594  // Close implements driver.Topic.Close.
   595  func (*snsTopic) Close() error { return nil }
   596  
// sqsTopic is the driver.Topic implementation that sends to an SQS queue.
// openSQSTopic sets client and openSQSTopicV2 sets clientV2; useV2 selects
// which one is used.
type sqsTopic struct {
	useV2    bool          // true if clientV2 (AWS SDK V2) should be used
	client   *sqs.SQS      // AWS SDK V1 client; used when useV2 is false
	clientV2 *sqsv2.Client // AWS SDK V2 client; used when useV2 is true
	qURL     string        // URL of the target SQS queue
	opts     *TopicOptions // the constructors replace a nil opts with an empty one
}
   604  
   605  // OpenSQSTopic opens a topic that sends to the SQS topic with the given SQS
   606  // queue URL.
   607  func OpenSQSTopic(ctx context.Context, sess client.ConfigProvider, qURL string, opts *TopicOptions) *pubsub.Topic {
   608  	bo := sendBatcherOptsSQS.NewMergedOptions(&opts.BatcherOptions)
   609  	return pubsub.NewTopic(openSQSTopic(ctx, sqs.New(sess), qURL, opts), bo)
   610  }
   611  
   612  // OpenSQSTopicV2 opens a topic that sends to the SQS topic with the given SQS
   613  // queue URL, using AWS SDK V2.
   614  func OpenSQSTopicV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *TopicOptions) *pubsub.Topic {
   615  	bo := sendBatcherOptsSQS.NewMergedOptions(&opts.BatcherOptions)
   616  	return pubsub.NewTopic(openSQSTopicV2(ctx, client, qURL, opts), bo)
   617  }
   618  
   619  // openSQSTopic returns the driver for OpenSQSTopic. This function exists so the test
   620  // harness can get the driver interface implementation if it needs to.
   621  func openSQSTopic(ctx context.Context, client *sqs.SQS, qURL string, opts *TopicOptions) driver.Topic {
   622  	if opts == nil {
   623  		opts = &TopicOptions{}
   624  	}
   625  	return &sqsTopic{
   626  		useV2:  false,
   627  		client: client,
   628  		qURL:   qURL,
   629  		opts:   opts,
   630  	}
   631  }
   632  
   633  // openSQSTopicV2 returns the driver for OpenSQSTopic. This function exists so the test
   634  // harness can get the driver interface implementation if it needs to.
   635  func openSQSTopicV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *TopicOptions) driver.Topic {
   636  	if opts == nil {
   637  		opts = &TopicOptions{}
   638  	}
   639  	return &sqsTopic{
   640  		useV2:    true,
   641  		clientV2: client,
   642  		qURL:     qURL,
   643  		opts:     opts,
   644  	}
   645  }
   646  
   647  // SendBatch implements driver.Topic.SendBatch.
   648  func (t *sqsTopic) SendBatch(ctx context.Context, dms []*driver.Message) error {
   649  	if t.useV2 {
   650  		req := &sqsv2.SendMessageBatchInput{
   651  			QueueUrl: aws.String(t.qURL),
   652  		}
   653  		for _, dm := range dms {
   654  			attrs := map[string]sqstypesv2.MessageAttributeValue{}
   655  			for k, v := range encodeMetadata(dm.Metadata) {
   656  				attrs[k] = sqstypesv2.MessageAttributeValue{
   657  					DataType:    stringDataType,
   658  					StringValue: aws.String(v),
   659  				}
   660  			}
   661  			body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
   662  			if didEncode {
   663  				attrs[base64EncodedKey] = sqstypesv2.MessageAttributeValue{
   664  					DataType:    stringDataType,
   665  					StringValue: aws.String("true"),
   666  				}
   667  			}
   668  			if len(attrs) == 0 {
   669  				attrs = nil
   670  			}
   671  			entry := sqstypesv2.SendMessageBatchRequestEntry{
   672  				Id:                aws.String(strconv.Itoa(len(req.Entries))),
   673  				MessageAttributes: attrs,
   674  				MessageBody:       aws.String(body),
   675  			}
   676  			req.Entries = append(req.Entries, entry)
   677  			if dm.BeforeSend != nil {
   678  				asFunc := func(i interface{}) bool {
   679  					if p, ok := i.(*sqstypesv2.SendMessageBatchRequestEntry); ok {
   680  						*p = entry
   681  						return true
   682  					}
   683  					return false
   684  				}
   685  				if err := dm.BeforeSend(asFunc); err != nil {
   686  					return err
   687  				}
   688  			}
   689  		}
   690  		resp, err := t.clientV2.SendMessageBatch(ctx, req)
   691  		if err != nil {
   692  			return err
   693  		}
   694  		if numFailed := len(resp.Failed); numFailed > 0 {
   695  			first := resp.Failed[0]
   696  			return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.SendMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil)
   697  		}
   698  		if len(resp.Successful) == len(dms) {
   699  			for n, dm := range dms {
   700  				if dm.AfterSend != nil {
   701  					asFunc := func(i interface{}) bool {
   702  						if p, ok := i.(*sqstypesv2.SendMessageBatchResultEntry); ok {
   703  							*p = resp.Successful[n]
   704  							return true
   705  						}
   706  						return false
   707  					}
   708  					if err := dm.AfterSend(asFunc); err != nil {
   709  						return err
   710  					}
   711  				}
   712  			}
   713  		}
   714  		return nil
   715  	}
   716  	req := &sqs.SendMessageBatchInput{
   717  		QueueUrl: aws.String(t.qURL),
   718  	}
   719  	for _, dm := range dms {
   720  		attrs := map[string]*sqs.MessageAttributeValue{}
   721  		for k, v := range encodeMetadata(dm.Metadata) {
   722  			attrs[k] = &sqs.MessageAttributeValue{
   723  				DataType:    stringDataType,
   724  				StringValue: aws.String(v),
   725  			}
   726  		}
   727  		body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
   728  		if didEncode {
   729  			attrs[base64EncodedKey] = &sqs.MessageAttributeValue{
   730  				DataType:    stringDataType,
   731  				StringValue: aws.String("true"),
   732  			}
   733  		}
   734  		if len(attrs) == 0 {
   735  			attrs = nil
   736  		}
   737  		entry := &sqs.SendMessageBatchRequestEntry{
   738  			Id:                aws.String(strconv.Itoa(len(req.Entries))),
   739  			MessageAttributes: attrs,
   740  			MessageBody:       aws.String(body),
   741  		}
   742  		req.Entries = append(req.Entries, entry)
   743  		if dm.BeforeSend != nil {
   744  			// A previous revision used the non-batch API SendMessage, which takes
   745  			// a *sqs.SendMessageInput. For backwards compatibility for As, continue
   746  			// to support that type. If it is requested, create a SendMessageInput
   747  			// with the fields from SendMessageBatchRequestEntry that were set, and
   748  			// then copy all of the matching fields back after calling dm.BeforeSend.
   749  			var smi *sqs.SendMessageInput
   750  			asFunc := func(i interface{}) bool {
   751  				if p, ok := i.(**sqs.SendMessageInput); ok {
   752  					smi = &sqs.SendMessageInput{
   753  						// Id does not exist on SendMessageInput.
   754  						MessageAttributes: entry.MessageAttributes,
   755  						MessageBody:       entry.MessageBody,
   756  					}
   757  					*p = smi
   758  					return true
   759  				}
   760  				if p, ok := i.(**sqs.SendMessageBatchRequestEntry); ok {
   761  					*p = entry
   762  					return true
   763  				}
   764  				return false
   765  			}
   766  			if err := dm.BeforeSend(asFunc); err != nil {
   767  				return err
   768  			}
   769  			if smi != nil {
   770  				// Copy all of the fields that may have been modified back to the entry.
   771  				entry.DelaySeconds = smi.DelaySeconds
   772  				entry.MessageAttributes = smi.MessageAttributes
   773  				entry.MessageBody = smi.MessageBody
   774  				entry.MessageDeduplicationId = smi.MessageDeduplicationId
   775  				entry.MessageGroupId = smi.MessageGroupId
   776  			}
   777  		}
   778  	}
   779  	resp, err := t.client.SendMessageBatchWithContext(ctx, req)
   780  	if err != nil {
   781  		return err
   782  	}
   783  	if numFailed := len(resp.Failed); numFailed > 0 {
   784  		first := resp.Failed[0]
   785  		return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.SendMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil)
   786  	}
   787  	if len(resp.Successful) == len(dms) {
   788  		for n, dm := range dms {
   789  			if dm.AfterSend != nil {
   790  				asFunc := func(i interface{}) bool {
   791  					if p, ok := i.(**sqs.SendMessageBatchResultEntry); ok {
   792  						*p = resp.Successful[n]
   793  						return true
   794  					}
   795  					return false
   796  				}
   797  				if err := dm.AfterSend(asFunc); err != nil {
   798  					return err
   799  				}
   800  			}
   801  		}
   802  	}
   803  	return nil
   804  }
   805  
   806  // IsRetryable implements driver.Topic.IsRetryable.
   807  func (t *sqsTopic) IsRetryable(error) bool {
   808  	// The client handles retries.
   809  	return false
   810  }
   811  
   812  // As implements driver.Topic.As.
   813  func (t *sqsTopic) As(i interface{}) bool {
   814  	if t.useV2 {
   815  		c, ok := i.(**sqsv2.Client)
   816  		if !ok {
   817  			return false
   818  		}
   819  		*c = t.clientV2
   820  		return true
   821  	}
   822  	c, ok := i.(**sqs.SQS)
   823  	if !ok {
   824  		return false
   825  	}
   826  	*c = t.client
   827  	return true
   828  }
   829  
   830  // ErrorAs implements driver.Topic.ErrorAs.
   831  func (t *sqsTopic) ErrorAs(err error, i interface{}) bool {
   832  	return errorAs(err, t.useV2, i)
   833  }
   834  
   835  // ErrorCode implements driver.Topic.ErrorCode.
   836  func (t *sqsTopic) ErrorCode(err error) gcerrors.ErrorCode {
   837  	return errorCode(err)
   838  }
   839  
   840  // Close implements driver.Topic.Close.
   841  func (*sqsTopic) Close() error { return nil }
   842  
   843  func errorCode(err error) gcerrors.ErrorCode {
   844  	var code string
   845  	var ae smithy.APIError
   846  	if errors.As(err, &ae) {
   847  		code = ae.ErrorCode()
   848  	} else if ae, ok := err.(awserr.Error); ok {
   849  		code = ae.Code()
   850  	} else {
   851  		return gcerrors.Unknown
   852  	}
   853  	ec, ok := errorCodeMap[code]
   854  	if !ok {
   855  		return gcerrors.Unknown
   856  	}
   857  	return ec
   858  }
   859  
   860  var errorCodeMap = map[string]gcerrors.ErrorCode{
   861  	sns.ErrCodeAuthorizationErrorException:          gcerrors.PermissionDenied,
   862  	sns.ErrCodeKMSAccessDeniedException:             gcerrors.PermissionDenied,
   863  	sns.ErrCodeKMSDisabledException:                 gcerrors.FailedPrecondition,
   864  	sns.ErrCodeKMSInvalidStateException:             gcerrors.FailedPrecondition,
   865  	sns.ErrCodeKMSOptInRequired:                     gcerrors.FailedPrecondition,
   866  	sqs.ErrCodeMessageNotInflight:                   gcerrors.FailedPrecondition,
   867  	sqs.ErrCodePurgeQueueInProgress:                 gcerrors.FailedPrecondition,
   868  	sqs.ErrCodeQueueDeletedRecently:                 gcerrors.FailedPrecondition,
   869  	sqs.ErrCodeQueueNameExists:                      gcerrors.FailedPrecondition,
   870  	sns.ErrCodeInternalErrorException:               gcerrors.Internal,
   871  	sns.ErrCodeInvalidParameterException:            gcerrors.InvalidArgument,
   872  	sns.ErrCodeInvalidParameterValueException:       gcerrors.InvalidArgument,
   873  	sqs.ErrCodeBatchEntryIdsNotDistinct:             gcerrors.InvalidArgument,
   874  	sqs.ErrCodeBatchRequestTooLong:                  gcerrors.InvalidArgument,
   875  	sqs.ErrCodeEmptyBatchRequest:                    gcerrors.InvalidArgument,
   876  	sqs.ErrCodeInvalidAttributeName:                 gcerrors.InvalidArgument,
   877  	sqs.ErrCodeInvalidBatchEntryId:                  gcerrors.InvalidArgument,
   878  	sqs.ErrCodeInvalidIdFormat:                      gcerrors.InvalidArgument,
   879  	sqs.ErrCodeInvalidMessageContents:               gcerrors.InvalidArgument,
   880  	sqs.ErrCodeReceiptHandleIsInvalid:               gcerrors.InvalidArgument,
   881  	sqs.ErrCodeTooManyEntriesInBatchRequest:         gcerrors.InvalidArgument,
   882  	sqs.ErrCodeUnsupportedOperation:                 gcerrors.InvalidArgument,
   883  	sns.ErrCodeInvalidSecurityException:             gcerrors.PermissionDenied,
   884  	sns.ErrCodeKMSNotFoundException:                 gcerrors.NotFound,
   885  	sns.ErrCodeNotFoundException:                    gcerrors.NotFound,
   886  	sqs.ErrCodeQueueDoesNotExist:                    gcerrors.NotFound,
   887  	sns.ErrCodeFilterPolicyLimitExceededException:   gcerrors.ResourceExhausted,
   888  	sns.ErrCodeSubscriptionLimitExceededException:   gcerrors.ResourceExhausted,
   889  	sns.ErrCodeTopicLimitExceededException:          gcerrors.ResourceExhausted,
   890  	sqs.ErrCodeOverLimit:                            gcerrors.ResourceExhausted,
   891  	sns.ErrCodeKMSThrottlingException:               gcerrors.ResourceExhausted,
   892  	sns.ErrCodeThrottledException:                   gcerrors.ResourceExhausted,
   893  	"RequestCanceled":                               gcerrors.Canceled,
   894  	sns.ErrCodeEndpointDisabledException:            gcerrors.Unknown,
   895  	sns.ErrCodePlatformApplicationDisabledException: gcerrors.Unknown,
   896  }
   897  
// subscription implements driver.Subscription on top of an SQS queue.
//
// Fields:
//   - useV2 selects the SDK: if true, clientV2 (AWS SDK V2) is used for all
//     calls; otherwise client (AWS SDK V1) is used.
//   - qURL is the URL of the SQS queue passed to every request.
//   - opts holds the subscription's options; openSubscription and
//     openSubscriptionV2 replace a nil value with an empty SubscriptionOptions.
type subscription struct {
	useV2    bool
	client   *sqs.SQS
	clientV2 *sqsv2.Client
	qURL     string
	opts     *SubscriptionOptions
}
   905  
// SubscriptionOptions contains configuration for subscriptions.
type SubscriptionOptions struct {
	// Raw determines how the Subscription will process message bodies.
	//
	// If the subscription is expected to process messages sent directly to
	// SQS, or messages from SNS topics configured to use "raw" delivery,
	// set this to true. Message bodies will be passed through untouched.
	//
	// If false, the Subscription will use best-effort heuristics to
	// identify whether message bodies are raw or SNS JSON; this may be
	// inefficient for raw messages. A message that has top-level SQS message
	// attributes is always treated as raw. Otherwise, its body is treated as
	// SNS JSON if it parses as JSON with a non-empty TopicArn field.
	//
	// See https://aws.amazon.com/sns/faqs/#Raw_message_delivery.
	Raw bool

	// WaitTime is passed to ReceiveMessage to enable long polling.
	// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-short-and-long-polling.html#sqs-long-polling.
	// Note that a non-zero WaitTime can delay delivery of messages
	// by up to that duration. It is sent to SQS as whole seconds; any
	// fractional part is truncated. A zero WaitTime leaves the request's
	// WaitTimeSeconds unset.
	WaitTime time.Duration

	// ReceiveBatcherOptions adds constraints to the default batching done for receives.
	ReceiveBatcherOptions batcher.Options

	// AckBatcherOptions adds constraints to the default batching done for acks.
	AckBatcherOptions batcher.Options
}
   933  
   934  // OpenSubscription opens a subscription based on AWS SQS for the given SQS
   935  // queue URL. The queue is assumed to be subscribed to some SNS topic, though
   936  // there is no check for this.
   937  func OpenSubscription(ctx context.Context, sess client.ConfigProvider, qURL string, opts *SubscriptionOptions) *pubsub.Subscription {
   938  	rbo := recvBatcherOpts.NewMergedOptions(&opts.ReceiveBatcherOptions)
   939  	abo := ackBatcherOpts.NewMergedOptions(&opts.AckBatcherOptions)
   940  	return pubsub.NewSubscription(openSubscription(ctx, sqs.New(sess), qURL, opts), rbo, abo)
   941  }
   942  
   943  // OpenSubscriptionV2 opens a subscription based on AWS SQS for the given SQS
   944  // queue URL, using AWS SDK V2. The queue is assumed to be subscribed to some SNS topic, though
   945  // there is no check for this.
   946  func OpenSubscriptionV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *SubscriptionOptions) *pubsub.Subscription {
   947  	rbo := recvBatcherOpts.NewMergedOptions(&opts.ReceiveBatcherOptions)
   948  	abo := ackBatcherOpts.NewMergedOptions(&opts.AckBatcherOptions)
   949  	return pubsub.NewSubscription(openSubscriptionV2(ctx, client, qURL, opts), rbo, abo)
   950  }
   951  
   952  // openSubscription returns a driver.Subscription.
   953  func openSubscription(ctx context.Context, client *sqs.SQS, qURL string, opts *SubscriptionOptions) driver.Subscription {
   954  	if opts == nil {
   955  		opts = &SubscriptionOptions{}
   956  	}
   957  	return &subscription{
   958  		useV2:  false,
   959  		client: client,
   960  		qURL:   qURL, opts: opts,
   961  	}
   962  }
   963  
   964  // openSubscriptionV2 returns a driver.Subscription.
   965  func openSubscriptionV2(ctx context.Context, client *sqsv2.Client, qURL string, opts *SubscriptionOptions) driver.Subscription {
   966  	if opts == nil {
   967  		opts = &SubscriptionOptions{}
   968  	}
   969  	return &subscription{
   970  		useV2:    true,
   971  		clientV2: client,
   972  		qURL:     qURL, opts: opts,
   973  	}
   974  }
   975  
   976  // ReceiveBatch implements driver.Subscription.ReceiveBatch.
   977  func (s *subscription) ReceiveBatch(ctx context.Context, maxMessages int) ([]*driver.Message, error) {
   978  	var ms []*driver.Message
   979  	if s.useV2 {
   980  		req := &sqsv2.ReceiveMessageInput{
   981  			QueueUrl:              aws.String(s.qURL),
   982  			MaxNumberOfMessages:   int32(maxMessages),
   983  			MessageAttributeNames: []string{"All"},
   984  			AttributeNames:        []sqstypesv2.QueueAttributeName{"All"},
   985  		}
   986  		if s.opts.WaitTime != 0 {
   987  			req.WaitTimeSeconds = int32(s.opts.WaitTime.Seconds())
   988  		}
   989  		output, err := s.clientV2.ReceiveMessage(ctx, req)
   990  		if err != nil {
   991  			return nil, err
   992  		}
   993  		for _, m := range output.Messages {
   994  			m := m
   995  			bodyStr := aws.StringValue(m.Body)
   996  			rawAttrs := map[string]string{}
   997  			for k, v := range m.MessageAttributes {
   998  				rawAttrs[k] = aws.StringValue(v.StringValue)
   999  			}
  1000  			bodyStr, rawAttrs = extractBody(bodyStr, rawAttrs, s.opts.Raw)
  1001  
  1002  			decodeIt := false
  1003  			attrs := map[string]string{}
  1004  			for k, v := range rawAttrs {
  1005  				// See BodyBase64Encoding for details on when we base64 decode message bodies.
  1006  				if k == base64EncodedKey {
  1007  					decodeIt = true
  1008  					continue
  1009  				}
  1010  				// See the package comments for more details on escaping of metadata
  1011  				// keys & values.
  1012  				attrs[escape.HexUnescape(k)] = escape.URLUnescape(v)
  1013  			}
  1014  
  1015  			var b []byte
  1016  			if decodeIt {
  1017  				var err error
  1018  				b, err = base64.StdEncoding.DecodeString(bodyStr)
  1019  				if err != nil {
  1020  					// Fall back to using the raw message.
  1021  					b = []byte(bodyStr)
  1022  				}
  1023  			} else {
  1024  				b = []byte(bodyStr)
  1025  			}
  1026  
  1027  			m2 := &driver.Message{
  1028  				LoggableID: aws.StringValue(m.MessageId),
  1029  				Body:       b,
  1030  				Metadata:   attrs,
  1031  				AckID:      m.ReceiptHandle,
  1032  				AsFunc: func(i interface{}) bool {
  1033  					p, ok := i.(*sqstypesv2.Message)
  1034  					if !ok {
  1035  						return false
  1036  					}
  1037  					*p = m
  1038  					return true
  1039  				},
  1040  			}
  1041  			ms = append(ms, m2)
  1042  		}
  1043  	} else {
  1044  		req := &sqs.ReceiveMessageInput{
  1045  			QueueUrl:              aws.String(s.qURL),
  1046  			MaxNumberOfMessages:   aws.Int64(int64(maxMessages)),
  1047  			MessageAttributeNames: []*string{aws.String("All")},
  1048  			AttributeNames:        []*string{aws.String("All")},
  1049  		}
  1050  		if s.opts.WaitTime != 0 {
  1051  			req.WaitTimeSeconds = aws.Int64(int64(s.opts.WaitTime.Seconds()))
  1052  		}
  1053  		output, err := s.client.ReceiveMessageWithContext(ctx, req)
  1054  		if err != nil {
  1055  			return nil, err
  1056  		}
  1057  		for _, m := range output.Messages {
  1058  			m := m
  1059  			bodyStr := aws.StringValue(m.Body)
  1060  			rawAttrs := map[string]string{}
  1061  			for k, v := range m.MessageAttributes {
  1062  				rawAttrs[k] = aws.StringValue(v.StringValue)
  1063  			}
  1064  			bodyStr, rawAttrs = extractBody(bodyStr, rawAttrs, s.opts.Raw)
  1065  
  1066  			decodeIt := false
  1067  			attrs := map[string]string{}
  1068  			for k, v := range rawAttrs {
  1069  				// See BodyBase64Encoding for details on when we base64 decode message bodies.
  1070  				if k == base64EncodedKey {
  1071  					decodeIt = true
  1072  					continue
  1073  				}
  1074  				// See the package comments for more details on escaping of metadata
  1075  				// keys & values.
  1076  				attrs[escape.HexUnescape(k)] = escape.URLUnescape(v)
  1077  			}
  1078  
  1079  			var b []byte
  1080  			if decodeIt {
  1081  				var err error
  1082  				b, err = base64.StdEncoding.DecodeString(bodyStr)
  1083  				if err != nil {
  1084  					// Fall back to using the raw message.
  1085  					b = []byte(bodyStr)
  1086  				}
  1087  			} else {
  1088  				b = []byte(bodyStr)
  1089  			}
  1090  
  1091  			m2 := &driver.Message{
  1092  				LoggableID: aws.StringValue(m.MessageId),
  1093  				Body:       b,
  1094  				Metadata:   attrs,
  1095  				AckID:      m.ReceiptHandle,
  1096  				AsFunc: func(i interface{}) bool {
  1097  					p, ok := i.(**sqs.Message)
  1098  					if !ok {
  1099  						return false
  1100  					}
  1101  					*p = m
  1102  					return true
  1103  				},
  1104  			}
  1105  			ms = append(ms, m2)
  1106  		}
  1107  	}
  1108  	if len(ms) == 0 {
  1109  		// When we return no messages and no error, the portable type will call
  1110  		// ReceiveBatch again immediately. Sleep for a bit to avoid hammering SQS
  1111  		// with RPCs.
  1112  		time.Sleep(noMessagesPollDuration)
  1113  	}
  1114  	return ms, nil
  1115  }
  1116  
  1117  func extractBody(bodyStr string, rawAttrs map[string]string, raw bool) (body string, attributes map[string]string) {
  1118  	// If the user told us that message bodies are raw, or if there are
  1119  	// top-level MessageAttributes, then it's raw.
  1120  	// (SNS JSON message can have attributes, but they are encoded in
  1121  	// the JSON instead of being at the top level).
  1122  	raw = raw || len(rawAttrs) > 0
  1123  	if raw {
  1124  		// For raw messages, the attributes are at the top level
  1125  		// and we leave bodyStr alone.
  1126  		return bodyStr, rawAttrs
  1127  	}
  1128  
  1129  	// It might be SNS JSON; try to parse the raw body as such.
  1130  	// https://aws.amazon.com/sns/faqs/#Raw_message_delivery
  1131  	// If it parses as JSON and has a TopicArn field, assume it's SNS JSON.
  1132  	var bodyJSON struct {
  1133  		TopicArn          string
  1134  		Message           string
  1135  		MessageAttributes map[string]struct{ Value string }
  1136  	}
  1137  	if err := json.Unmarshal([]byte(bodyStr), &bodyJSON); err == nil && bodyJSON.TopicArn != "" {
  1138  		// It looks like SNS JSON. Get attributes from the decoded struct,
  1139  		// and update the body to be the JSON Message field.
  1140  		for k, v := range bodyJSON.MessageAttributes {
  1141  			rawAttrs[k] = v.Value
  1142  		}
  1143  		return bodyJSON.Message, rawAttrs
  1144  	}
  1145  	// It doesn't look like SNS JSON, either because it
  1146  	// isn't JSON or because the JSON doesn't have a TopicArn
  1147  	// field. Treat it as raw.
  1148  	//
  1149  	// As above in the other "raw" case, we leave bodyStr
  1150  	// alone. There can't be any top-level attributes (because
  1151  	// then we would have known it was raw earlier).
  1152  	return bodyStr, rawAttrs
  1153  }
  1154  
  1155  // SendAcks implements driver.Subscription.SendAcks.
  1156  func (s *subscription) SendAcks(ctx context.Context, ids []driver.AckID) error {
  1157  	if s.useV2 {
  1158  		req := &sqsv2.DeleteMessageBatchInput{QueueUrl: aws.String(s.qURL)}
  1159  		for _, id := range ids {
  1160  			req.Entries = append(req.Entries, sqstypesv2.DeleteMessageBatchRequestEntry{
  1161  				Id:            aws.String(strconv.Itoa(len(req.Entries))),
  1162  				ReceiptHandle: id.(*string),
  1163  			})
  1164  		}
  1165  		resp, err := s.clientV2.DeleteMessageBatch(ctx, req)
  1166  		if err != nil {
  1167  			return err
  1168  		}
  1169  		// Note: DeleteMessageBatch doesn't return failures when you try
  1170  		// to Delete an id that isn't found.
  1171  		if numFailed := len(resp.Failed); numFailed > 0 {
  1172  			first := resp.Failed[0]
  1173  			return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.DeleteMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil)
  1174  		}
  1175  		return nil
  1176  	}
  1177  	req := &sqs.DeleteMessageBatchInput{QueueUrl: aws.String(s.qURL)}
  1178  	for _, id := range ids {
  1179  		req.Entries = append(req.Entries, &sqs.DeleteMessageBatchRequestEntry{
  1180  			Id:            aws.String(strconv.Itoa(len(req.Entries))),
  1181  			ReceiptHandle: id.(*string),
  1182  		})
  1183  	}
  1184  	resp, err := s.client.DeleteMessageBatchWithContext(ctx, req)
  1185  	if err != nil {
  1186  		return err
  1187  	}
  1188  	// Note: DeleteMessageBatch doesn't return failures when you try
  1189  	// to Delete an id that isn't found.
  1190  	if numFailed := len(resp.Failed); numFailed > 0 {
  1191  		first := resp.Failed[0]
  1192  		return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sqs.DeleteMessageBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil)
  1193  	}
  1194  	return nil
  1195  }
  1196  
  1197  // CanNack implements driver.CanNack.
  1198  func (s *subscription) CanNack() bool { return true }
  1199  
  1200  // SendNacks implements driver.Subscription.SendNacks.
  1201  func (s *subscription) SendNacks(ctx context.Context, ids []driver.AckID) error {
  1202  	if s.useV2 {
  1203  		req := &sqsv2.ChangeMessageVisibilityBatchInput{QueueUrl: aws.String(s.qURL)}
  1204  		for _, id := range ids {
  1205  			req.Entries = append(req.Entries, sqstypesv2.ChangeMessageVisibilityBatchRequestEntry{
  1206  				Id:                aws.String(strconv.Itoa(len(req.Entries))),
  1207  				ReceiptHandle:     id.(*string),
  1208  				VisibilityTimeout: 1,
  1209  			})
  1210  		}
  1211  		resp, err := s.clientV2.ChangeMessageVisibilityBatch(ctx, req)
  1212  		if err != nil {
  1213  			return err
  1214  		}
  1215  		// Note: ChangeMessageVisibilityBatch returns failures when you try to
  1216  		// modify an id that isn't found; drop those.
  1217  		var firstFail sqstypesv2.BatchResultErrorEntry
  1218  		numFailed := 0
  1219  		for _, fail := range resp.Failed {
  1220  			if aws.StringValue(fail.Code) == sqs.ErrCodeReceiptHandleIsInvalid {
  1221  				continue
  1222  			}
  1223  			if numFailed == 0 {
  1224  				firstFail = fail
  1225  			}
  1226  			numFailed++
  1227  		}
  1228  		if numFailed > 0 {
  1229  			return awserr.New(aws.StringValue(firstFail.Code), fmt.Sprintf("sqs.ChangeMessageVisibilityBatch failed for %d message(s): %s", numFailed, aws.StringValue(firstFail.Message)), nil)
  1230  		}
  1231  		return nil
  1232  	}
  1233  	req := &sqs.ChangeMessageVisibilityBatchInput{QueueUrl: aws.String(s.qURL)}
  1234  	for _, id := range ids {
  1235  		req.Entries = append(req.Entries, &sqs.ChangeMessageVisibilityBatchRequestEntry{
  1236  			Id:                aws.String(strconv.Itoa(len(req.Entries))),
  1237  			ReceiptHandle:     id.(*string),
  1238  			VisibilityTimeout: aws.Int64(0),
  1239  		})
  1240  	}
  1241  	resp, err := s.client.ChangeMessageVisibilityBatchWithContext(ctx, req)
  1242  	if err != nil {
  1243  		return err
  1244  	}
  1245  	// Note: ChangeMessageVisibilityBatch returns failures when you try to
  1246  	// modify an id that isn't found; drop those.
  1247  	var firstFail *sqs.BatchResultErrorEntry
  1248  	numFailed := 0
  1249  	for _, fail := range resp.Failed {
  1250  		if aws.StringValue(fail.Code) == sqs.ErrCodeReceiptHandleIsInvalid {
  1251  			continue
  1252  		}
  1253  		if numFailed == 0 {
  1254  			firstFail = fail
  1255  		}
  1256  		numFailed++
  1257  	}
  1258  	if numFailed > 0 {
  1259  		return awserr.New(aws.StringValue(firstFail.Code), fmt.Sprintf("sqs.ChangeMessageVisibilityBatch failed for %d message(s): %s", numFailed, aws.StringValue(firstFail.Message)), nil)
  1260  	}
  1261  	return nil
  1262  }
  1263  
  1264  // IsRetryable implements driver.Subscription.IsRetryable.
  1265  func (*subscription) IsRetryable(error) bool {
  1266  	// The client handles retries.
  1267  	return false
  1268  }
  1269  
  1270  // As implements driver.Subscription.As.
  1271  func (s *subscription) As(i interface{}) bool {
  1272  	if s.useV2 {
  1273  		c, ok := i.(**sqsv2.Client)
  1274  		if !ok {
  1275  			return false
  1276  		}
  1277  		*c = s.clientV2
  1278  		return true
  1279  	}
  1280  	c, ok := i.(**sqs.SQS)
  1281  	if !ok {
  1282  		return false
  1283  	}
  1284  	*c = s.client
  1285  	return true
  1286  }
  1287  
  1288  // ErrorAs implements driver.Subscription.ErrorAs.
  1289  func (s *subscription) ErrorAs(err error, i interface{}) bool {
  1290  	return errorAs(err, s.useV2, i)
  1291  }
  1292  
  1293  // ErrorCode implements driver.Subscription.ErrorCode.
  1294  func (s *subscription) ErrorCode(err error) gcerrors.ErrorCode {
  1295  	return errorCode(err)
  1296  }
  1297  
  1298  func errorAs(err error, useV2 bool, i interface{}) bool {
  1299  	if useV2 {
  1300  		return errors.As(err, i)
  1301  	}
  1302  	e, ok := err.(awserr.Error)
  1303  	if !ok {
  1304  		return false
  1305  	}
  1306  	p, ok := i.(*awserr.Error)
  1307  	if !ok {
  1308  		return false
  1309  	}
  1310  	*p = e
  1311  	return true
  1312  }
  1313  
  1314  // Close implements driver.Subscription.Close.
  1315  func (*subscription) Close() error { return nil }