github.com/thiagoyeds/go-cloud@v0.26.0/pubsub/awssnssqs/awssnssqs_test.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package awssnssqs
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"net/http"
    22  	"strings"
    23  	"sync/atomic"
    24  	"testing"
    25  
    26  	snsv2 "github.com/aws/aws-sdk-go-v2/service/sns"
    27  	sqsv2 "github.com/aws/aws-sdk-go-v2/service/sqs"
    28  	sqstypesv2 "github.com/aws/aws-sdk-go-v2/service/sqs/types"
    29  	"github.com/aws/aws-sdk-go/aws"
    30  	"github.com/aws/aws-sdk-go/aws/awserr"
    31  	"github.com/aws/aws-sdk-go/aws/session"
    32  	"github.com/aws/aws-sdk-go/service/sns"
    33  	"github.com/aws/aws-sdk-go/service/sqs"
    34  	"github.com/aws/smithy-go"
    35  	"gocloud.dev/internal/testing/setup"
    36  	"gocloud.dev/pubsub"
    37  	"gocloud.dev/pubsub/driver"
    38  	"gocloud.dev/pubsub/drivertest"
    39  )
    40  
    41  const (
    42  	region        = "us-east-2"
    43  	accountNumber = "462380225722"
    44  )
    45  
    46  // We run conformance tests against multiple kinds of topics; this enum
    47  // represents which one we're doing.
    48  type topicKind string
    49  
    50  const (
    51  	topicKindSNS    = topicKind("SNS")    // send through an SNS topic
    52  	topicKindSNSRaw = topicKind("SNSRaw") // send through an SNS topic using RawMessageDelivery=true
    53  	topicKindSQS    = topicKind("SQS")    // send directly to an SQS queue
    54  )
    55  
    56  func newSession() (*session.Session, error) {
    57  	return session.NewSession(&aws.Config{
    58  		HTTPClient: &http.Client{},
    59  		Region:     aws.String(region),
    60  		MaxRetries: aws.Int(0),
    61  	})
    62  }
    63  
    64  type harness struct {
    65  	useV2       bool
    66  	sess        *session.Session
    67  	snsClientV2 *snsv2.Client
    68  	sqsClientV2 *sqsv2.Client
    69  	topicKind   topicKind
    70  	rt          http.RoundTripper
    71  	closer      func()
    72  	numTopics   uint32
    73  	numSubs     uint32
    74  }
    75  
    76  func newHarness(ctx context.Context, t *testing.T, topicKind topicKind) (drivertest.Harness, error) {
    77  	sess, rt, closer, _ := setup.NewAWSSession(ctx, t, region)
    78  	return &harness{useV2: false, sess: sess, rt: rt, topicKind: topicKind, closer: closer}, nil
    79  }
    80  
    81  func newHarnessV2(ctx context.Context, t *testing.T, topicKind topicKind) (drivertest.Harness, error) {
    82  	cfg, rt, closer, _ := setup.NewAWSv2Config(context.Background(), t, region)
    83  	return &harness{useV2: true, snsClientV2: snsv2.NewFromConfig(cfg), sqsClientV2: sqsv2.NewFromConfig(cfg), rt: rt, topicKind: topicKind, closer: closer}, nil
    84  }
    85  
    86  func (h *harness) CreateTopic(ctx context.Context, testName string) (dt driver.Topic, cleanup func(), err error) {
    87  	topicName := sanitize(fmt.Sprintf("%s-top-%d", testName, atomic.AddUint32(&h.numTopics, 1)))
    88  	return createTopic(ctx, topicName, h.useV2, h.sess, h.snsClientV2, h.sqsClientV2, h.topicKind)
    89  }
    90  
    91  func createTopic(ctx context.Context, topicName string, useV2 bool, sess *session.Session, snsClientV2 *snsv2.Client, sqsClientV2 *sqsv2.Client, topicKind topicKind) (dt driver.Topic, cleanup func(), err error) {
    92  	switch topicKind {
    93  	case topicKindSNS, topicKindSNSRaw:
    94  		// Create an SNS topic.
    95  		if useV2 {
    96  			out, err := snsClientV2.CreateTopic(ctx, &snsv2.CreateTopicInput{Name: aws.String(topicName)})
    97  			if err != nil {
    98  				return nil, nil, fmt.Errorf("creating SNS topic %q: %v", topicName, err)
    99  			}
   100  			dt = openSNSTopicV2(ctx, snsClientV2, *out.TopicArn, nil)
   101  			cleanup = func() {
   102  				snsClientV2.DeleteTopic(ctx, &snsv2.DeleteTopicInput{TopicArn: out.TopicArn})
   103  			}
   104  		} else {
   105  			client := sns.New(sess)
   106  			out, err := client.CreateTopicWithContext(ctx, &sns.CreateTopicInput{Name: aws.String(topicName)})
   107  			if err != nil {
   108  				return nil, nil, fmt.Errorf("creating SNS topic %q: %v", topicName, err)
   109  			}
   110  			dt = openSNSTopic(ctx, client, *out.TopicArn, nil)
   111  			cleanup = func() {
   112  				client.DeleteTopicWithContext(ctx, &sns.DeleteTopicInput{TopicArn: out.TopicArn})
   113  			}
   114  		}
   115  		return dt, cleanup, nil
   116  	case topicKindSQS:
   117  		// Create an SQS queue.
   118  		if useV2 {
   119  			qURL, _, err := createSQSQueue(ctx, true, nil, sqsClientV2, topicName)
   120  			if err != nil {
   121  				return nil, nil, fmt.Errorf("creating SQS queue %q: %v", topicName, err)
   122  			}
   123  			dt = openSQSTopicV2(ctx, sqsClientV2, qURL, nil)
   124  			cleanup = func() {
   125  				sqsClientV2.DeleteQueue(ctx, &sqsv2.DeleteQueueInput{QueueUrl: aws.String(qURL)})
   126  			}
   127  		} else {
   128  			sqsClient := sqs.New(sess)
   129  			qURL, _, err := createSQSQueue(ctx, false, sqsClient, nil, topicName)
   130  			if err != nil {
   131  				return nil, nil, fmt.Errorf("creating SQS queue %q: %v", topicName, err)
   132  			}
   133  			dt = openSQSTopic(ctx, sqsClient, qURL, nil)
   134  			cleanup = func() {
   135  				sqsClient.DeleteQueueWithContext(ctx, &sqs.DeleteQueueInput{QueueUrl: aws.String(qURL)})
   136  			}
   137  		}
   138  		return dt, cleanup, nil
   139  	default:
   140  		panic("unreachable")
   141  	}
   142  }
   143  
   144  func (h *harness) MakeNonexistentTopic(ctx context.Context) (driver.Topic, error) {
   145  	switch h.topicKind {
   146  	case topicKindSNS, topicKindSNSRaw:
   147  		const fakeTopicARN = "arn:aws:sns:" + region + ":" + accountNumber + ":nonexistenttopic"
   148  		if h.useV2 {
   149  			return openSNSTopicV2(ctx, h.snsClientV2, fakeTopicARN, nil), nil
   150  		} else {
   151  		}
   152  		return openSNSTopic(ctx, sns.New(h.sess), fakeTopicARN, nil), nil
   153  	case topicKindSQS:
   154  		const fakeQueueURL = "https://" + region + ".amazonaws.com/" + accountNumber + "/nonexistent-queue"
   155  		if h.useV2 {
   156  			return openSQSTopicV2(ctx, h.sqsClientV2, fakeQueueURL, nil), nil
   157  		} else {
   158  		}
   159  		return openSQSTopic(ctx, sqs.New(h.sess), fakeQueueURL, nil), nil
   160  	default:
   161  		panic("unreachable")
   162  	}
   163  }
   164  
   165  func (h *harness) CreateSubscription(ctx context.Context, dt driver.Topic, testName string) (ds driver.Subscription, cleanup func(), err error) {
   166  	subName := sanitize(fmt.Sprintf("%s-sub-%d", testName, atomic.AddUint32(&h.numSubs, 1)))
   167  	return createSubscription(ctx, dt, subName, h.useV2, h.sess, h.snsClientV2, h.sqsClientV2, h.topicKind)
   168  }
   169  
   170  func createSubscription(ctx context.Context, dt driver.Topic, subName string, useV2 bool, sess *session.Session, snsClientV2 *snsv2.Client, sqsClientV2 *sqsv2.Client, topicKind topicKind) (ds driver.Subscription, cleanup func(), err error) {
   171  	switch topicKind {
   172  	case topicKindSNS, topicKindSNSRaw:
   173  		// Create an SQS queue, and subscribe it to the SNS topic.
   174  		var qURL, qARN string
   175  		var err error
   176  		if useV2 {
   177  			qURL, qARN, err = createSQSQueue(ctx, true, nil, sqsClientV2, subName)
   178  			if err != nil {
   179  				return nil, nil, fmt.Errorf("creating SQS queue %q: %v", subName, err)
   180  			}
   181  			ds = openSubscriptionV2(ctx, sqsClientV2, qURL, nil)
   182  		} else {
   183  			sqsClient := sqs.New(sess)
   184  			qURL, qARN, err = createSQSQueue(ctx, false, sqsClient, nil, subName)
   185  			if err != nil {
   186  				return nil, nil, fmt.Errorf("creating SQS queue %q: %v", subName, err)
   187  			}
   188  			ds = openSubscription(ctx, sqsClient, qURL, nil)
   189  		}
   190  
   191  		snsTopicARN := dt.(*snsTopic).arn
   192  		var cleanup func()
   193  		if useV2 {
   194  			req := &snsv2.SubscribeInput{
   195  				TopicArn: aws.String(snsTopicARN),
   196  				Endpoint: aws.String(qARN),
   197  				Protocol: aws.String("sqs"),
   198  			}
   199  			// Enable RawMessageDelivery on the subscription if needed.
   200  			if topicKind == topicKindSNSRaw {
   201  				req.Attributes = map[string]string{"RawMessageDelivery": "true"}
   202  			}
   203  			out, err := snsClientV2.Subscribe(ctx, req)
   204  			if err != nil {
   205  				return nil, nil, fmt.Errorf("subscribing: %v", err)
   206  			}
   207  			cleanup = func() {
   208  				snsClientV2.Unsubscribe(ctx, &snsv2.UnsubscribeInput{SubscriptionArn: out.SubscriptionArn})
   209  				sqsClientV2.DeleteQueue(ctx, &sqsv2.DeleteQueueInput{QueueUrl: aws.String(qURL)})
   210  			}
   211  		} else {
   212  			snsClient := sns.New(sess)
   213  			req := &sns.SubscribeInput{
   214  				TopicArn: aws.String(snsTopicARN),
   215  				Endpoint: aws.String(qARN),
   216  				Protocol: aws.String("sqs"),
   217  			}
   218  			// Enable RawMessageDelivery on the subscription if needed.
   219  			if topicKind == topicKindSNSRaw {
   220  				req.Attributes = map[string]*string{"RawMessageDelivery": aws.String("true")}
   221  			}
   222  			out, err := snsClient.SubscribeWithContext(ctx, req)
   223  			if err != nil {
   224  				return nil, nil, fmt.Errorf("subscribing: %v", err)
   225  			}
   226  			cleanup = func() {
   227  				sqsClient := sqs.New(sess)
   228  				snsClient.UnsubscribeWithContext(ctx, &sns.UnsubscribeInput{SubscriptionArn: out.SubscriptionArn})
   229  				sqsClient.DeleteQueueWithContext(ctx, &sqs.DeleteQueueInput{QueueUrl: aws.String(qURL)})
   230  			}
   231  		}
   232  		return ds, cleanup, nil
   233  	case topicKindSQS:
   234  		// The SQS queue already exists; we created it for the topic. Re-use it
   235  		// for the subscription.
   236  		qURL := dt.(*sqsTopic).qURL
   237  		if useV2 {
   238  			return openSubscriptionV2(ctx, sqsClientV2, qURL, nil), func() {}, nil
   239  		} else {
   240  			return openSubscription(ctx, sqs.New(sess), qURL, nil), func() {}, nil
   241  		}
   242  	default:
   243  		panic("unreachable")
   244  	}
   245  }
   246  
   247  func createSQSQueue(ctx context.Context, useV2 bool, sqsClient *sqs.SQS, sqsClientV2 *sqsv2.Client, topicName string) (string, string, error) {
   248  	var qURL string
   249  	if useV2 {
   250  		out, err := sqsClientV2.CreateQueue(ctx, &sqsv2.CreateQueueInput{QueueName: aws.String(topicName)})
   251  		if err != nil {
   252  			return "", "", fmt.Errorf("creating SQS queue %q: %v", topicName, err)
   253  		}
   254  		qURL = aws.StringValue(out.QueueUrl)
   255  	} else {
   256  		out, err := sqsClient.CreateQueueWithContext(ctx, &sqs.CreateQueueInput{QueueName: aws.String(topicName)})
   257  		if err != nil {
   258  			return "", "", fmt.Errorf("creating SQS queue %q: %v", topicName, err)
   259  		}
   260  		qURL = aws.StringValue(out.QueueUrl)
   261  	}
   262  
   263  	// Get the ARN.
   264  	var qARN string
   265  	if useV2 {
   266  		out2, err := sqsClientV2.GetQueueAttributes(ctx, &sqsv2.GetQueueAttributesInput{
   267  			QueueUrl:       aws.String(qURL),
   268  			AttributeNames: []sqstypesv2.QueueAttributeName{"QueueArn"},
   269  		})
   270  		if err != nil {
   271  			return "", "", fmt.Errorf("getting queue ARN for %s: %v", qURL, err)
   272  		}
   273  		qARN = out2.Attributes["QueueArn"]
   274  	} else {
   275  		out2, err := sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{
   276  			QueueUrl:       aws.String(qURL),
   277  			AttributeNames: []*string{aws.String("QueueArn")},
   278  		})
   279  		if err != nil {
   280  			return "", "", fmt.Errorf("getting queue ARN for %s: %v", qURL, err)
   281  		}
   282  		qARN = aws.StringValue(out2.Attributes["QueueArn"])
   283  	}
   284  
   285  	queuePolicy := `{
   286  		"Version": "2012-10-17",
   287  		"Id": "AllowQueue",
   288  		"Statement": [
   289  		{
   290  		"Sid": "MySQSPolicy001",
   291  		"Effect": "Allow",
   292  		"Principal": {
   293  		"AWS": "*"
   294  		},
   295  		"Action": "sqs:SendMessage",
   296  		"Resource": "` + qARN + `"
   297  		}
   298  		]
   299  		}`
   300  	var err error
   301  	if useV2 {
   302  		_, err = sqsClientV2.SetQueueAttributes(ctx, &sqsv2.SetQueueAttributesInput{
   303  			Attributes: map[string]string{"Policy": queuePolicy},
   304  			QueueUrl:   aws.String(qURL),
   305  		})
   306  	} else {
   307  		_, err = sqsClient.SetQueueAttributesWithContext(ctx, &sqs.SetQueueAttributesInput{
   308  			Attributes: map[string]*string{"Policy": &queuePolicy},
   309  			QueueUrl:   aws.String(qURL),
   310  		})
   311  	}
   312  	if err != nil {
   313  		return "", "", fmt.Errorf("setting policy: %v", err)
   314  	}
   315  	return qURL, qARN, nil
   316  }
   317  
   318  func (h *harness) MakeNonexistentSubscription(ctx context.Context) (driver.Subscription, func(), error) {
   319  	const fakeSubscriptionQueueURL = "https://" + region + ".amazonaws.com/" + accountNumber + "/nonexistent-subscription"
   320  	if h.useV2 {
   321  		return openSubscriptionV2(ctx, h.sqsClientV2, fakeSubscriptionQueueURL, nil), func() {}, nil
   322  	} else {
   323  		return openSubscription(ctx, sqs.New(h.sess), fakeSubscriptionQueueURL, nil), func() {}, nil
   324  	}
   325  }
   326  
   327  func (h *harness) Close() {
   328  	h.closer()
   329  }
   330  
   331  func (h *harness) MaxBatchSizes() (int, int) {
   332  	if h.topicKind == topicKindSQS {
   333  		return sendBatcherOptsSQS.MaxBatchSize, ackBatcherOpts.MaxBatchSize
   334  	}
   335  	return sendBatcherOptsSNS.MaxBatchSize, ackBatcherOpts.MaxBatchSize
   336  }
   337  
   338  func (h *harness) SupportsMultipleSubscriptions() bool {
   339  	// If we're publishing to an SQS topic, we're reading from the same topic,
   340  	// so there's no way to get multiple subscriptions.
   341  	return h.topicKind != topicKindSQS
   342  }
   343  
   344  func TestConformanceSNSTopic(t *testing.T) {
   345  	asTests := []drivertest.AsTest{awsAsTest{useV2: false, topicKind: topicKindSNS}}
   346  	newSNSHarness := func(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
   347  		return newHarness(ctx, t, topicKindSNS)
   348  	}
   349  	drivertest.RunConformanceTests(t, newSNSHarness, asTests)
   350  }
   351  
   352  func TestConformanceSNSTopicV2(t *testing.T) {
   353  	asTests := []drivertest.AsTest{awsAsTest{useV2: true, topicKind: topicKindSNS}}
   354  	newSNSHarness := func(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
   355  		return newHarnessV2(ctx, t, topicKindSNS)
   356  	}
   357  	drivertest.RunConformanceTests(t, newSNSHarness, asTests)
   358  }
   359  
   360  func TestConformanceSNSTopicRaw(t *testing.T) {
   361  	asTests := []drivertest.AsTest{awsAsTest{useV2: false, topicKind: topicKindSNSRaw}}
   362  	newSNSHarness := func(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
   363  		return newHarness(ctx, t, topicKindSNSRaw)
   364  	}
   365  	drivertest.RunConformanceTests(t, newSNSHarness, asTests)
   366  }
   367  
   368  func TestConformanceSNSTopicRawV2(t *testing.T) {
   369  	asTests := []drivertest.AsTest{awsAsTest{useV2: true, topicKind: topicKindSNSRaw}}
   370  	newSNSHarness := func(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
   371  		return newHarnessV2(ctx, t, topicKindSNSRaw)
   372  	}
   373  	drivertest.RunConformanceTests(t, newSNSHarness, asTests)
   374  }
   375  
   376  func TestConformanceSQSTopic(t *testing.T) {
   377  	asTests := []drivertest.AsTest{awsAsTest{useV2: false, topicKind: topicKindSQS}}
   378  	newSQSHarness := func(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
   379  		return newHarness(ctx, t, topicKindSQS)
   380  	}
   381  	drivertest.RunConformanceTests(t, newSQSHarness, asTests)
   382  }
   383  
   384  func TestConformanceSQSTopicV2(t *testing.T) {
   385  	asTests := []drivertest.AsTest{awsAsTest{useV2: true, topicKind: topicKindSQS}}
   386  	newSQSHarness := func(ctx context.Context, t *testing.T) (drivertest.Harness, error) {
   387  		return newHarnessV2(ctx, t, topicKindSQS)
   388  	}
   389  	drivertest.RunConformanceTests(t, newSQSHarness, asTests)
   390  }
   391  
   392  type awsAsTest struct {
   393  	useV2     bool
   394  	topicKind topicKind
   395  }
   396  
   397  func (awsAsTest) Name() string {
   398  	return "aws test"
   399  }
   400  
   401  func (t awsAsTest) TopicCheck(topic *pubsub.Topic) error {
   402  	switch t.topicKind {
   403  	case topicKindSNS, topicKindSNSRaw:
   404  		if t.useV2 {
   405  			var s *snsv2.Client
   406  			if !topic.As(&s) {
   407  				return fmt.Errorf("cast failed for %T", s)
   408  			}
   409  		} else {
   410  			var s *sns.SNS
   411  			if !topic.As(&s) {
   412  				return fmt.Errorf("cast failed for %T", s)
   413  			}
   414  		}
   415  	case topicKindSQS:
   416  		if t.useV2 {
   417  			var s *sqsv2.Client
   418  			if !topic.As(&s) {
   419  				return fmt.Errorf("cast failed for %T", s)
   420  			}
   421  		} else {
   422  			var s *sqs.SQS
   423  			if !topic.As(&s) {
   424  				return fmt.Errorf("cast failed for %T", s)
   425  			}
   426  		}
   427  	default:
   428  		panic("unreachable")
   429  	}
   430  	return nil
   431  }
   432  
   433  func (t awsAsTest) SubscriptionCheck(sub *pubsub.Subscription) error {
   434  	if t.useV2 {
   435  		var s *sqsv2.Client
   436  		if !sub.As(&s) {
   437  			return fmt.Errorf("cast failed for %T", s)
   438  		}
   439  	} else {
   440  		var s *sqs.SQS
   441  		if !sub.As(&s) {
   442  			return fmt.Errorf("cast failed for %T", s)
   443  		}
   444  	}
   445  	return nil
   446  }
   447  
   448  func (t awsAsTest) TopicErrorCheck(topic *pubsub.Topic, err error) error {
   449  	if t.useV2 {
   450  		var e smithy.APIError
   451  		if !topic.ErrorAs(err, &e) {
   452  			return errors.New("Topic.ErrorAs failed")
   453  		}
   454  		switch t.topicKind {
   455  		case topicKindSNS, topicKindSNSRaw:
   456  			if got, want := e.ErrorCode(), sns.ErrCodeNotFoundException; got != want {
   457  				return fmt.Errorf("got %q, want %q", got, want)
   458  			}
   459  		case topicKindSQS:
   460  			if got, want := e.ErrorCode(), sqs.ErrCodeQueueDoesNotExist; got != want {
   461  				return fmt.Errorf("got %q, want %q", got, want)
   462  			}
   463  		default:
   464  			panic("unreachable")
   465  		}
   466  		return nil
   467  	}
   468  	var ae awserr.Error
   469  	if !topic.ErrorAs(err, &ae) {
   470  		return fmt.Errorf("failed to convert %v (%T) to an awserr.Error", err, err)
   471  	}
   472  	switch t.topicKind {
   473  	case topicKindSNS, topicKindSNSRaw:
   474  		if got, want := ae.Code(), sns.ErrCodeNotFoundException; got != want {
   475  			return fmt.Errorf("got %q, want %q", got, want)
   476  		}
   477  	case topicKindSQS:
   478  		if got, want := ae.Code(), sqs.ErrCodeQueueDoesNotExist; got != want {
   479  			return fmt.Errorf("got %q, want %q", got, want)
   480  		}
   481  	default:
   482  		panic("unreachable")
   483  	}
   484  	return nil
   485  }
   486  
   487  func (t awsAsTest) SubscriptionErrorCheck(s *pubsub.Subscription, err error) error {
   488  	if t.useV2 {
   489  		var e smithy.APIError
   490  		if !s.ErrorAs(err, &e) {
   491  			return errors.New("Subscription.ErrorAs failed")
   492  		}
   493  		if got, want := e.ErrorCode(), sqs.ErrCodeQueueDoesNotExist; got != want {
   494  			return fmt.Errorf("got %q, want %q", got, want)
   495  		}
   496  		return nil
   497  	}
   498  	var ae awserr.Error
   499  	if !s.ErrorAs(err, &ae) {
   500  		return fmt.Errorf("failed to convert %v (%T) to an awserr.Error", err, err)
   501  	}
   502  	if got, want := ae.Code(), sqs.ErrCodeQueueDoesNotExist; got != want {
   503  		return fmt.Errorf("got %q, want %q", got, want)
   504  	}
   505  	return nil
   506  }
   507  
   508  func (t awsAsTest) MessageCheck(m *pubsub.Message) error {
   509  	if t.useV2 {
   510  		var sm sqstypesv2.Message
   511  		if !m.As(&sm) {
   512  			return fmt.Errorf("cast failed for %T", &sm)
   513  		}
   514  	} else {
   515  		var sm sqs.Message
   516  		if m.As(&sm) {
   517  			return fmt.Errorf("cast succeeded for %T, want failure", &sm)
   518  		}
   519  		var psm *sqs.Message
   520  		if !m.As(&psm) {
   521  			return fmt.Errorf("cast failed for %T", &psm)
   522  		}
   523  	}
   524  	return nil
   525  }
   526  
   527  func (t awsAsTest) BeforeSend(as func(interface{}) bool) error {
   528  	switch t.topicKind {
   529  	case topicKindSNS, topicKindSNSRaw:
   530  		if t.useV2 {
   531  			var pub *snsv2.PublishInput
   532  			if !as(&pub) {
   533  				return fmt.Errorf("cast failed for %T", &pub)
   534  			}
   535  		} else {
   536  			var pub *sns.PublishInput
   537  			if !as(&pub) {
   538  				return fmt.Errorf("cast failed for %T", &pub)
   539  			}
   540  		}
   541  	case topicKindSQS:
   542  		if t.useV2 {
   543  			var entry sqstypesv2.SendMessageBatchRequestEntry
   544  			if !as(&entry) {
   545  				return fmt.Errorf("cast failed for %T", &entry)
   546  			}
   547  		} else {
   548  			var smi *sqs.SendMessageInput
   549  			if !as(&smi) {
   550  				return fmt.Errorf("cast failed for %T", &smi)
   551  			}
   552  			var entry *sqs.SendMessageBatchRequestEntry
   553  			if !as(&entry) {
   554  				return fmt.Errorf("cast failed for %T", &entry)
   555  			}
   556  		}
   557  	default:
   558  		panic("unreachable")
   559  	}
   560  	return nil
   561  }
   562  
   563  func (t awsAsTest) AfterSend(as func(interface{}) bool) error {
   564  	switch t.topicKind {
   565  	case topicKindSNS, topicKindSNSRaw:
   566  		if t.useV2 {
   567  			var pub *snsv2.PublishOutput
   568  			if !as(&pub) {
   569  				return fmt.Errorf("cast failed for %T", &pub)
   570  			}
   571  		} else {
   572  			var pub *sns.PublishOutput
   573  			if !as(&pub) {
   574  				return fmt.Errorf("cast failed for %T", &pub)
   575  			}
   576  		}
   577  	case topicKindSQS:
   578  		if t.useV2 {
   579  			var entry sqstypesv2.SendMessageBatchResultEntry
   580  			if !as(&entry) {
   581  				return fmt.Errorf("cast failed for %T", &entry)
   582  			}
   583  		} else {
   584  			var entry *sqs.SendMessageBatchResultEntry
   585  			if !as(&entry) {
   586  				return fmt.Errorf("cast failed for %T", &entry)
   587  			}
   588  		}
   589  	default:
   590  		panic("unreachable")
   591  	}
   592  	return nil
   593  }
   594  
   595  func sanitize(s string) string {
   596  	// AWS doesn't like names that are too long; trim some not-so-useful stuff.
   597  	const maxNameLen = 80
   598  	s = strings.Replace(s, "TestConformance", "", 1)
   599  	s = strings.Replace(s, "/Test", "", 1)
   600  	s = strings.Replace(s, "/", "_", -1)
   601  	if len(s) > maxNameLen {
   602  		// Drop prefix, not suffix, because suffix includes something to make
   603  		// entities unique within a test.
   604  		s = s[len(s)-maxNameLen:]
   605  	}
   606  	return s
   607  }
   608  func BenchmarkSNSSQS(b *testing.B) {
   609  	benchmark(b, topicKindSNS)
   610  }
   611  
   612  func BenchmarkSQS(b *testing.B) {
   613  	benchmark(b, topicKindSQS)
   614  }
   615  
   616  func benchmark(b *testing.B, topicKind topicKind) {
   617  	ctx := context.Background()
   618  	sess, err := session.NewSession(&aws.Config{
   619  		HTTPClient: &http.Client{},
   620  		Region:     aws.String(region),
   621  		MaxRetries: aws.Int(0),
   622  	})
   623  	if err != nil {
   624  		b.Fatal(err)
   625  	}
   626  	topicName := fmt.Sprintf("%s-topic", b.Name())
   627  	dt, cleanup1, err := createTopic(ctx, topicName, false, sess, nil, nil, topicKind)
   628  	if err != nil {
   629  		b.Fatal(err)
   630  	}
   631  	defer cleanup1()
   632  	sendBatcherOpts := sendBatcherOptsSNS
   633  	if topicKind == topicKindSQS {
   634  		sendBatcherOpts = sendBatcherOptsSQS
   635  	}
   636  	topic := pubsub.NewTopic(dt, sendBatcherOpts)
   637  	defer topic.Shutdown(ctx)
   638  	subName := fmt.Sprintf("%s-subscription", b.Name())
   639  	ds, cleanup2, err := createSubscription(ctx, dt, subName, false, sess, nil, nil, topicKind)
   640  	if err != nil {
   641  		b.Fatal(err)
   642  	}
   643  	defer cleanup2()
   644  	sub := pubsub.NewSubscription(ds, recvBatcherOpts, ackBatcherOpts)
   645  	defer sub.Shutdown(ctx)
   646  	drivertest.RunBenchmarks(b, topic, sub)
   647  }
   648  
   649  func TestOpenTopicFromURL(t *testing.T) {
   650  	tests := []struct {
   651  		URL     string
   652  		WantErr bool
   653  	}{
   654  		// SNS...
   655  
   656  		// OK.
   657  		{"awssns:///arn:aws:service:region:accountid:resourceType/resourcePath", false},
   658  		// OK, setting region.
   659  		{"awssns:///arn:aws:service:region:accountid:resourceType/resourcePath?region=us-east-2", false},
   660  		// OK, setting usev2.
   661  		{"awssns:///arn:aws:service:region:accountid:resourceType/resourcePath?awssdk=v2", false},
   662  		// Invalid parameter.
   663  		{"awssns:///arn:aws:service:region:accountid:resourceType/resourcePath?param=value", true},
   664  
   665  		// SQS...
   666  		// OK.
   667  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue", false},
   668  		// OK, setting region.
   669  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?region=us-east-2", false},
   670  		// OK, setting usev2.
   671  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?awssdk=v2", false},
   672  		// Invalid parameter.
   673  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?param=value", true},
   674  	}
   675  
   676  	ctx := context.Background()
   677  	for _, test := range tests {
   678  		topic, err := pubsub.OpenTopic(ctx, test.URL)
   679  		if (err != nil) != test.WantErr {
   680  			t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
   681  		}
   682  		if topic != nil {
   683  			topic.Shutdown(ctx)
   684  		}
   685  	}
   686  }
   687  
   688  func TestOpenSubscriptionFromURL(t *testing.T) {
   689  	tests := []struct {
   690  		URL     string
   691  		WantErr bool
   692  	}{
   693  		// OK.
   694  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue", false},
   695  		// OK, setting region.
   696  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?region=us-east-2", false},
   697  		// OK, setting raw.
   698  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?raw=true", false},
   699  		// OK, setting raw.
   700  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?raw=1", false},
   701  		// Invalid raw.
   702  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?raw=foo", true},
   703  		// OK, setting waittime.
   704  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?waittime=5s", false},
   705  		// OK, setting usev2.
   706  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?awssdk=v2", false},
   707  		// Invalid waittime.
   708  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?waittime=foo", true},
   709  		// Invalid parameter.
   710  		{"awssqs://sqs.us-east-2.amazonaws.com/99999/my-queue?param=value", true},
   711  	}
   712  
   713  	ctx := context.Background()
   714  	for _, test := range tests {
   715  		sub, err := pubsub.OpenSubscription(ctx, test.URL)
   716  		if (err != nil) != test.WantErr {
   717  			t.Errorf("%s: got error %v, want error %v", test.URL, err, test.WantErr)
   718  		}
   719  		if sub != nil {
   720  			sub.Shutdown(ctx)
   721  		}
   722  	}
   723  }