github.com/Jeffail/benthos/v3@v3.65.0/lib/output/writer/kinesis_test.go (about)

     1  package writer
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/Jeffail/benthos/v3/internal/bloblang"
     9  	"github.com/Jeffail/benthos/v3/lib/log"
    10  	"github.com/Jeffail/benthos/v3/lib/message"
    11  	"github.com/Jeffail/benthos/v3/lib/metrics"
    12  	"github.com/aws/aws-sdk-go/aws"
    13  	"github.com/aws/aws-sdk-go/aws/credentials"
    14  	"github.com/aws/aws-sdk-go/aws/session"
    15  	"github.com/aws/aws-sdk-go/service/kinesis"
    16  	"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
    17  	"github.com/cenkalti/backoff/v4"
    18  )
    19  
    20  var (
    21  	mockStats = metrics.Noop()
    22  )
    23  
    24  var (
    25  	mThrottled       = mockStats.GetCounter("send.throttled")
    26  	mThrottledF      = mockStats.GetCounter("send.throttled")
    27  	mPartsThrottled  = mockStats.GetCounter("parts.send.throttled")
    28  	mPartsThrottledF = mockStats.GetCounter("parts.send.throttled")
    29  )
    30  
    31  type mockKinesis struct {
    32  	kinesisiface.KinesisAPI
    33  	fn func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error)
    34  }
    35  
    36  func (m *mockKinesis) PutRecords(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
    37  	return m.fn(input)
    38  }
    39  
    40  func TestKinesisWriteSinglePartMessage(t *testing.T) {
    41  	k := Kinesis{
    42  		backoffCtor: func() backoff.BackOff {
    43  			return backoff.NewExponentialBackOff()
    44  		},
    45  		session: session.Must(session.NewSession(&aws.Config{
    46  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
    47  		})),
    48  		kinesis: &mockKinesis{
    49  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
    50  				if exp, act := 1, len(input.Records); exp != act {
    51  					return nil, fmt.Errorf("expected input to have records with length %d, got %d", exp, act)
    52  				}
    53  				if exp, act := "123", input.Records[0].PartitionKey; exp != *act {
    54  					return nil, fmt.Errorf("expected record to have partition key %s, got %s", exp, *act)
    55  				}
    56  				return &kinesis.PutRecordsOutput{}, nil
    57  			},
    58  		},
    59  		log: log.Noop(),
    60  	}
    61  
    62  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
    63  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
    64  
    65  	msg := message.New(nil)
    66  	part := message.NewPart([]byte(`{"foo":"bar","id":123}`))
    67  	msg.Append(part)
    68  
    69  	if err := k.Write(msg); err != nil {
    70  		t.Error(err)
    71  	}
    72  }
    73  
    74  func TestKinesisWriteMultiPartMessage(t *testing.T) {
    75  	parts := []struct {
    76  		data []byte
    77  		key  string
    78  	}{
    79  		{[]byte(`{"foo":"bar","id":123}`), "123"},
    80  		{[]byte(`{"foo":"baz","id":456}`), "456"},
    81  	}
    82  	k := Kinesis{
    83  		backoffCtor: func() backoff.BackOff {
    84  			return backoff.NewExponentialBackOff()
    85  		},
    86  		session: session.Must(session.NewSession(&aws.Config{
    87  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
    88  		})),
    89  		kinesis: &mockKinesis{
    90  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
    91  				if exp, act := len(parts), len(input.Records); exp != act {
    92  					return nil, fmt.Errorf("expected input to have records with length %d, got %d", exp, act)
    93  				}
    94  				for i, p := range parts {
    95  					if exp, act := p.key, input.Records[i].PartitionKey; exp != *act {
    96  						return nil, fmt.Errorf("expected record %d to have partition key %s, got %s", i, exp, *act)
    97  					}
    98  				}
    99  				return &kinesis.PutRecordsOutput{}, nil
   100  			},
   101  		},
   102  		log: log.Noop(),
   103  	}
   104  
   105  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
   106  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
   107  
   108  	msg := message.New(nil)
   109  	for _, p := range parts {
   110  		part := message.NewPart(p.data)
   111  		msg.Append(part)
   112  	}
   113  
   114  	if err := k.Write(msg); err != nil {
   115  		t.Error(err)
   116  	}
   117  }
   118  
   119  func TestKinesisWriteChunk(t *testing.T) {
   120  	batchLengths := []int{}
   121  	n := 1200
   122  	k := Kinesis{
   123  		backoffCtor: func() backoff.BackOff {
   124  			return backoff.NewExponentialBackOff()
   125  		},
   126  		session: session.Must(session.NewSession(&aws.Config{
   127  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
   128  		})),
   129  		kinesis: &mockKinesis{
   130  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
   131  				batchLengths = append(batchLengths, len(input.Records))
   132  				return &kinesis.PutRecordsOutput{}, nil
   133  			},
   134  		},
   135  		log: log.Noop(),
   136  	}
   137  
   138  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
   139  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
   140  
   141  	msg := message.New(nil)
   142  	for i := 0; i < n; i++ {
   143  		part := message.NewPart([]byte(`{"foo":"bar","id":123}`))
   144  		msg.Append(part)
   145  	}
   146  
   147  	if err := k.Write(msg); err != nil {
   148  		t.Error(err)
   149  	}
   150  	if exp, act := n/kinesisMaxRecordsCount+1, len(batchLengths); act != exp {
   151  		t.Errorf("Expected kinesis PutRecords to have call count %d, got %d", exp, act)
   152  	}
   153  	for i, act := range batchLengths {
   154  		exp := n
   155  		if exp > kinesisMaxRecordsCount {
   156  			exp = kinesisMaxRecordsCount
   157  			n -= kinesisMaxRecordsCount
   158  		}
   159  		if act != exp {
   160  			t.Errorf("Expected kinesis PutRecords call %d to have batch size %d, got %d", i, exp, act)
   161  		}
   162  	}
   163  }
   164  
   165  func TestKinesisWriteChunkWithThrottling(t *testing.T) {
   166  	t.Parallel()
   167  	batchLengths := []int{}
   168  	n := 1200
   169  	k := Kinesis{
   170  		backoffCtor: func() backoff.BackOff {
   171  			return backoff.NewExponentialBackOff()
   172  		},
   173  		session: session.Must(session.NewSession(&aws.Config{
   174  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
   175  		})),
   176  		kinesis: &mockKinesis{
   177  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
   178  				count := len(input.Records)
   179  				batchLengths = append(batchLengths, count)
   180  				var failed int64
   181  				output := kinesis.PutRecordsOutput{
   182  					Records: make([]*kinesis.PutRecordsResultEntry, count),
   183  				}
   184  				for i := 0; i < count; i++ {
   185  					var entry kinesis.PutRecordsResultEntry
   186  					if i >= 300 {
   187  						failed++
   188  						entry.SetErrorCode(kinesis.ErrCodeProvisionedThroughputExceededException)
   189  					}
   190  					output.Records[i] = &entry
   191  				}
   192  				output.SetFailedRecordCount(failed)
   193  				return &output, nil
   194  			},
   195  		},
   196  		mThrottled:       mThrottled,
   197  		mThrottledF:      mThrottledF,
   198  		mPartsThrottled:  mPartsThrottled,
   199  		mPartsThrottledF: mPartsThrottledF,
   200  		log:              log.Noop(),
   201  	}
   202  
   203  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
   204  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
   205  
   206  	msg := message.New(nil)
   207  	for i := 0; i < n; i++ {
   208  		part := message.NewPart([]byte(`{"foo":"bar","id":123}`))
   209  		msg.Append(part)
   210  	}
   211  
   212  	expectedLengths := []int{
   213  		500, 500, 500, 300,
   214  	}
   215  
   216  	if err := k.Write(msg); err != nil {
   217  		t.Error(err)
   218  	}
   219  	if exp, act := len(expectedLengths), len(batchLengths); act != exp {
   220  		t.Errorf("Expected kinesis PutRecords to have call count %d, got %d", exp, act)
   221  	}
   222  	for i, act := range batchLengths {
   223  		if exp := expectedLengths[i]; act != exp {
   224  			t.Errorf("Expected kinesis PutRecords call %d to have batch size %d, got %d", i, exp, act)
   225  		}
   226  	}
   227  }
   228  
   229  func TestKinesisWriteError(t *testing.T) {
   230  	t.Parallel()
   231  	var calls int
   232  	k := Kinesis{
   233  		backoffCtor: func() backoff.BackOff {
   234  			return backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 2)
   235  		},
   236  		session: session.Must(session.NewSession(&aws.Config{
   237  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
   238  		})),
   239  		kinesis: &mockKinesis{
   240  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
   241  				calls++
   242  				return nil, errors.New("blah")
   243  			},
   244  		},
   245  		log: log.Noop(),
   246  	}
   247  
   248  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
   249  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
   250  
   251  	msg := message.New(nil)
   252  	msg.Append(message.NewPart([]byte(`{"foo":"bar"}`)))
   253  
   254  	if exp, err := "blah", k.Write(msg); err.Error() != exp {
   255  		t.Errorf("Expected err to equal %s, got %v", exp, err)
   256  	}
   257  	if exp, act := 3, calls; act != exp {
   258  		t.Errorf("Expected kinesis.PutRecords to have call count %d, got %d", exp, act)
   259  	}
   260  }
   261  
   262  func TestKinesisWriteMessageThrottling(t *testing.T) {
   263  	t.Parallel()
   264  	var calls [][]*kinesis.PutRecordsRequestEntry
   265  	k := Kinesis{
   266  		backoffCtor: func() backoff.BackOff {
   267  			return backoff.NewExponentialBackOff()
   268  		},
   269  		session: session.Must(session.NewSession(&aws.Config{
   270  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
   271  		})),
   272  		kinesis: &mockKinesis{
   273  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
   274  				records := make([]*kinesis.PutRecordsRequestEntry, len(input.Records))
   275  				copy(records, input.Records)
   276  				calls = append(calls, records)
   277  				var failed int64
   278  				var output kinesis.PutRecordsOutput
   279  				for i := 0; i < len(input.Records); i++ {
   280  					entry := kinesis.PutRecordsResultEntry{}
   281  					if i > 0 {
   282  						failed++
   283  						entry.SetErrorCode(kinesis.ErrCodeProvisionedThroughputExceededException)
   284  					}
   285  					output.Records = append(output.Records, &entry)
   286  				}
   287  				output.SetFailedRecordCount(failed)
   288  				return &output, nil
   289  			},
   290  		},
   291  		mThrottled:       mThrottled,
   292  		mThrottledF:      mThrottledF,
   293  		mPartsThrottled:  mPartsThrottled,
   294  		mPartsThrottledF: mPartsThrottledF,
   295  		log:              log.Noop(),
   296  	}
   297  
   298  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
   299  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
   300  
   301  	msg := message.New(nil)
   302  	msg.Append(message.NewPart([]byte(`{"foo":"bar","id":123}`)))
   303  	msg.Append(message.NewPart([]byte(`{"foo":"baz","id":456}`)))
   304  	msg.Append(message.NewPart([]byte(`{"foo":"qux","id":789}`)))
   305  
   306  	if err := k.Write(msg); err != nil {
   307  		t.Error(err)
   308  	}
   309  	if exp, act := msg.Len(), len(calls); act != exp {
   310  		t.Errorf("Expected kinesis.PutRecords to have call count %d, got %d", exp, act)
   311  	}
   312  	for i, c := range calls {
   313  		if exp, act := msg.Len()-i, len(c); act != exp {
   314  			t.Errorf("Expected kinesis.PutRecords call %d input to have Records with length %d, got %d", i, exp, act)
   315  		}
   316  	}
   317  }
   318  
   319  func TestKinesisWriteBackoffMaxRetriesExceeded(t *testing.T) {
   320  	t.Parallel()
   321  	var calls int
   322  	k := Kinesis{
   323  		backoffCtor: func() backoff.BackOff {
   324  			return backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 2)
   325  		},
   326  		session: session.Must(session.NewSession(&aws.Config{
   327  			Credentials: credentials.NewStaticCredentials("xxxxx", "xxxxx", "xxxxx"),
   328  		})),
   329  		kinesis: &mockKinesis{
   330  			fn: func(input *kinesis.PutRecordsInput) (*kinesis.PutRecordsOutput, error) {
   331  				calls++
   332  				var output kinesis.PutRecordsOutput
   333  				output.FailedRecordCount = aws.Int64(1)
   334  				output.Records = append(output.Records, &kinesis.PutRecordsResultEntry{
   335  					ErrorCode: aws.String(kinesis.ErrCodeProvisionedThroughputExceededException),
   336  				})
   337  				return &output, nil
   338  			},
   339  		},
   340  		mThrottled:       mThrottled,
   341  		mThrottledF:      mThrottledF,
   342  		mPartsThrottled:  mPartsThrottled,
   343  		mPartsThrottledF: mPartsThrottledF,
   344  		log:              log.Noop(),
   345  	}
   346  
   347  	k.partitionKey, _ = bloblang.GlobalEnvironment().NewField("${!json(\"id\")}")
   348  	k.hashKey, _ = bloblang.GlobalEnvironment().NewField("")
   349  
   350  	msg := message.New(nil)
   351  	msg.Append(message.NewPart([]byte(`{"foo":"bar","id":123}`)))
   352  
   353  	if err := k.Write(msg); err == nil {
   354  		t.Error(errors.New("expected kinesis.Write to error"))
   355  	}
   356  	if exp := 3; calls != exp {
   357  		t.Errorf("Expected kinesis.PutRecords to have call count %d, got %d", exp, calls)
   358  	}
   359  }