github.com/aavshr/aws-sdk-go@v1.41.3/service/kinesis/cust_integ_shared_test.go (about)

     1  //go:build integration && go1.15
     2  // +build integration,go1.15
     3  
     4  package kinesis_test
     5  
     6  import (
     7  	crand "crypto/rand"
     8  	"crypto/tls"
     9  	"flag"
    10  	"fmt"
    11  	"io"
    12  	"math/rand"
    13  	"net/http"
    14  	"os"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/aavshr/aws-sdk-go/aws"
    19  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    20  	"github.com/aavshr/aws-sdk-go/awstesting/integration"
    21  	"github.com/aavshr/aws-sdk-go/service/kinesis"
    22  	"golang.org/x/net/http2"
    23  )
    24  
    25  var (
    26  	skipTLSVerify    bool
    27  	hUsage           string
    28  	endpoint         string
    29  	streamName       string
    30  	consumerName     string
    31  	numRecords       int
    32  	recordSize       int
    33  	debugEventStream bool
    34  	mode             string
    35  
    36  	svc     *kinesis.Kinesis
    37  	records []*kinesis.PutRecordsRequestEntry
    38  
    39  	startingTimestamp time.Time
    40  )
    41  
    42  func init() {
    43  	flag.StringVar(
    44  		&mode, "mode", "all",
    45  		"Sets the mode to run in, (test,create,cleanup,all).",
    46  	)
    47  	flag.BoolVar(
    48  		&skipTLSVerify, "skip-verify", false,
    49  		"Skips verification of TLS certificate.",
    50  	)
    51  	flag.StringVar(
    52  		&hUsage, "http", "default",
    53  		"The HTTP `version` to use for the connection. (default,1,2)",
    54  	)
    55  	flag.StringVar(
    56  		&endpoint, "endpoint", "",
    57  		"Overrides SDK `URL` endpoint for tests.",
    58  	)
    59  	flag.StringVar(
    60  		&streamName, "stream", fmt.Sprintf("awsdkgo-s%v", UniqueID()),
    61  		"The `name` of the stream to test against.",
    62  	)
    63  	flag.StringVar(
    64  		&consumerName, "consumer", fmt.Sprintf("awsdkgo-c%v", UniqueID()),
    65  		"The `name` of the stream to test against.",
    66  	)
    67  	flag.IntVar(
    68  		&numRecords, "records", 20,
    69  		"The `number` of records per PutRecords to test with.",
    70  	)
    71  	flag.IntVar(
    72  		&recordSize, "record-size", 500,
    73  		"The size in `bytes` of each record.",
    74  	)
    75  	flag.BoolVar(
    76  		&debugEventStream, "debug-eventstream", false,
    77  		"Enables debugging of the EventStream messages",
    78  	)
    79  }
    80  
    81  func TestMain(m *testing.M) {
    82  	flag.Parse()
    83  
    84  	svc = createClient()
    85  
    86  	startingTimestamp = time.Now().Add(-time.Minute)
    87  
    88  	switch mode {
    89  	case "create", "all":
    90  		if err := createStream(streamName); err != nil {
    91  			panic(err)
    92  		}
    93  		if err := createStreamConsumer(streamName, consumerName); err != nil {
    94  			panic(err)
    95  		}
    96  		fmt.Println("Stream Ready:", streamName, consumerName)
    97  
    98  		if mode != "all" {
    99  			break
   100  		}
   101  		fallthrough
   102  	case "test":
   103  		records = createRecords(numRecords, recordSize)
   104  		if err := putRecords(streamName, records, svc); err != nil {
   105  			panic(err)
   106  		}
   107  		time.Sleep(time.Second)
   108  
   109  		var exitCode int
   110  		defer func() {
   111  			os.Exit(exitCode)
   112  		}()
   113  
   114  		exitCode = m.Run()
   115  
   116  		if mode != "all" {
   117  			break
   118  		}
   119  		fallthrough
   120  	case "cleanup":
   121  		if err := cleanupStreamConsumer(streamName, consumerName); err != nil {
   122  			panic(err)
   123  		}
   124  		if err := cleanupStream(streamName); err != nil {
   125  			panic(err)
   126  		}
   127  	default:
   128  		fmt.Fprintf(os.Stderr, "unknown mode, %v", mode)
   129  		os.Exit(1)
   130  	}
   131  }
   132  
   133  func createClient() *kinesis.Kinesis {
   134  	ts := &http.Transport{}
   135  
   136  	if skipTLSVerify {
   137  		ts.TLSClientConfig = &tls.Config{
   138  			InsecureSkipVerify: true,
   139  		}
   140  	}
   141  
   142  	http2.ConfigureTransport(ts)
   143  	switch hUsage {
   144  	case "default":
   145  		// Restore H2 optional support since the Transport/TLSConfig was
   146  		// modified.
   147  		http2.ConfigureTransport(ts)
   148  	case "1":
   149  		// Do nothing. Without usign ConfigureTransport h2 won't be available.
   150  		ts.TLSClientConfig.NextProtos = []string{"http/1.1"}
   151  	case "2":
   152  		// Force the TLS ALPN (NextProto) to H2 only.
   153  		ts.TLSClientConfig.NextProtos = []string{http2.NextProtoTLS}
   154  	default:
   155  		panic("unknown h usage, " + hUsage)
   156  	}
   157  
   158  	sess := integration.SessionWithDefaultRegion("us-west-2")
   159  	cfg := &aws.Config{
   160  		HTTPClient: &http.Client{
   161  			Transport: ts,
   162  		},
   163  	}
   164  	if debugEventStream {
   165  		cfg.LogLevel = aws.LogLevel(
   166  			sess.Config.LogLevel.Value() | aws.LogDebugWithEventStreamBody)
   167  	}
   168  
   169  	return kinesis.New(sess, cfg)
   170  }
   171  
   172  func createStream(name string) error {
   173  	descParams := &kinesis.DescribeStreamInput{
   174  		StreamName: &name,
   175  	}
   176  
   177  	_, err := svc.DescribeStream(descParams)
   178  	if aerr, ok := err.(awserr.Error); ok && aerr.Code() == kinesis.ErrCodeResourceNotFoundException {
   179  		_, err := svc.CreateStream(&kinesis.CreateStreamInput{
   180  			ShardCount: aws.Int64(100),
   181  			StreamName: &name,
   182  		})
   183  		if err != nil {
   184  			return fmt.Errorf("failed to create stream, %v", err)
   185  		}
   186  	} else if err != nil {
   187  		return fmt.Errorf("failed to describe stream, %v", err)
   188  	}
   189  
   190  	if err := svc.WaitUntilStreamExists(descParams); err != nil {
   191  		return fmt.Errorf("failed to wait for stream to exist, %v", err)
   192  	}
   193  
   194  	return nil
   195  }
   196  
   197  func cleanupStream(name string) error {
   198  	_, err := svc.DeleteStream(&kinesis.DeleteStreamInput{
   199  		StreamName:              &name,
   200  		EnforceConsumerDeletion: aws.Bool(true),
   201  	})
   202  	if err != nil {
   203  		return fmt.Errorf("failed to delete stream, %v", err)
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  func createStreamConsumer(streamName, consumerName string) error {
   210  	desc, err := svc.DescribeStream(&kinesis.DescribeStreamInput{
   211  		StreamName: &streamName,
   212  	})
   213  	if err != nil {
   214  		return fmt.Errorf("failed to describe stream, %s, %v", streamName, err)
   215  	}
   216  
   217  	descParams := &kinesis.DescribeStreamConsumerInput{
   218  		StreamARN:    desc.StreamDescription.StreamARN,
   219  		ConsumerName: &consumerName,
   220  	}
   221  	_, err = svc.DescribeStreamConsumer(descParams)
   222  	if aerr, ok := err.(awserr.Error); ok && aerr.Code() == kinesis.ErrCodeResourceNotFoundException {
   223  		_, err := svc.RegisterStreamConsumer(
   224  			&kinesis.RegisterStreamConsumerInput{
   225  				ConsumerName: aws.String(consumerName),
   226  				StreamARN:    desc.StreamDescription.StreamARN,
   227  			},
   228  		)
   229  		if err != nil {
   230  			return fmt.Errorf("failed to create stream consumer %s, %v",
   231  				consumerName, err)
   232  		}
   233  	} else if err != nil {
   234  		return fmt.Errorf("failed to describe stream consumer %s, %v",
   235  			consumerName, err)
   236  	}
   237  
   238  	for i := 0; i < 10; i++ {
   239  		resp, err := svc.DescribeStreamConsumer(descParams)
   240  		if err != nil || aws.StringValue(resp.ConsumerDescription.ConsumerStatus) != kinesis.ConsumerStatusActive {
   241  			time.Sleep(time.Second * 30)
   242  			continue
   243  		}
   244  		return nil
   245  	}
   246  
   247  	return fmt.Errorf("failed to wait for consumer to exist, %v, %v",
   248  		*descParams.StreamARN, *descParams.ConsumerName)
   249  }
   250  
   251  func cleanupStreamConsumer(streamName, consumerName string) error {
   252  	desc, err := svc.DescribeStream(&kinesis.DescribeStreamInput{
   253  		StreamName: &streamName,
   254  	})
   255  	if err != nil {
   256  		return fmt.Errorf("failed to describe stream, %s, %v",
   257  			streamName, err)
   258  	}
   259  
   260  	descCons, err := svc.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{
   261  		StreamARN:    desc.StreamDescription.StreamARN,
   262  		ConsumerName: &consumerName,
   263  	})
   264  	if err != nil {
   265  		return fmt.Errorf("failed to describe stream consumer, %s, %v",
   266  			consumerName, err)
   267  	}
   268  
   269  	_, err = svc.DeregisterStreamConsumer(
   270  		&kinesis.DeregisterStreamConsumerInput{
   271  			ConsumerName: descCons.ConsumerDescription.ConsumerName,
   272  			ConsumerARN:  descCons.ConsumerDescription.ConsumerARN,
   273  			StreamARN:    desc.StreamDescription.StreamARN,
   274  		},
   275  	)
   276  	if err != nil {
   277  		return fmt.Errorf("failed to delete stream consumer, %s %v",
   278  			consumerName, err)
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  func createRecords(num, size int) []*kinesis.PutRecordsRequestEntry {
   285  	var err error
   286  	data, err := loadRandomData(num, size)
   287  	if err != nil {
   288  		fmt.Fprintf(os.Stderr, "unable to read random data, %v", err)
   289  		os.Exit(1)
   290  	}
   291  
   292  	records := make([]*kinesis.PutRecordsRequestEntry, len(data))
   293  	for i, td := range data {
   294  		records[i] = &kinesis.PutRecordsRequestEntry{
   295  			Data:         td,
   296  			PartitionKey: aws.String(UniqueID()),
   297  		}
   298  	}
   299  
   300  	return records
   301  }
   302  
   303  func putRecords(stream string, records []*kinesis.PutRecordsRequestEntry, svc *kinesis.Kinesis) error {
   304  	resp, err := svc.PutRecords(&kinesis.PutRecordsInput{
   305  		StreamName: &stream,
   306  		Records:    records,
   307  	})
   308  	if err != nil {
   309  		return fmt.Errorf("failed to put records to stream %s, %v", stream, err)
   310  	}
   311  
   312  	if v := aws.Int64Value(resp.FailedRecordCount); v != 0 {
   313  		return fmt.Errorf("failed to put records to stream %s, %d failed",
   314  			stream, v)
   315  	}
   316  
   317  	return nil
   318  }
   319  
   320  func loadRandomData(m, n int) ([][]byte, error) {
   321  	data := make([]byte, m*n)
   322  
   323  	_, err := rand.Read(data)
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  
   328  	parts := make([][]byte, m)
   329  
   330  	for i := 0; i < m; i++ {
   331  		mod := (i % m)
   332  		parts[i] = data[mod*n : (mod+1)*n]
   333  	}
   334  
   335  	return parts, nil
   336  }
   337  
   338  // UniqueID returns a unique UUID-like identifier for use in generating
   339  // resources for integration tests.
   340  func UniqueID() string {
   341  	uuid := make([]byte, 16)
   342  	io.ReadFull(crand.Reader, uuid)
   343  	return fmt.Sprintf("%x", uuid)
   344  }