github.com/crowdsecurity/crowdsec@v1.6.1/pkg/acquisition/modules/kinesis/kinesis_test.go (about)

     1  package kinesisacquisition
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net"
     9  	"os"
    10  	"runtime"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/crowdsecurity/go-cs-lib/cstest"
    16  
    17  	"github.com/aws/aws-sdk-go/aws"
    18  	"github.com/aws/aws-sdk-go/aws/session"
    19  	"github.com/aws/aws-sdk-go/service/kinesis"
    20  	"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
    21  	"github.com/crowdsecurity/crowdsec/pkg/types"
    22  	log "github.com/sirupsen/logrus"
    23  	"github.com/stretchr/testify/assert"
    24  	"gopkg.in/tomb.v2"
    25  )
    26  
    27  func getLocalStackEndpoint() (string, error) {
    28  	endpoint := "http://localhost:4566"
    29  	if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" {
    30  		v = strings.TrimPrefix(v, "http://")
    31  		_, err := net.Dial("tcp", v)
    32  		if err != nil {
    33  			return "", fmt.Errorf("while dialing %s : %s : aws endpoint isn't available", v, err)
    34  		}
    35  	}
    36  	return endpoint, nil
    37  }
    38  
    39  func GenSubObject(i int) []byte {
    40  	r := CloudWatchSubscriptionRecord{
    41  		MessageType:         "subscription",
    42  		Owner:               "test",
    43  		LogGroup:            "test",
    44  		LogStream:           "test",
    45  		SubscriptionFilters: []string{"filter1"},
    46  		LogEvents: []CloudwatchSubscriptionLogEvent{
    47  			{
    48  				ID:        "testid",
    49  				Message:   fmt.Sprintf("%d", i),
    50  				Timestamp: time.Now().UTC().Unix(),
    51  			},
    52  		},
    53  	}
    54  	body, err := json.Marshal(r)
    55  	if err != nil {
    56  		log.Fatal(err)
    57  	}
    58  	var b bytes.Buffer
    59  	gz := gzip.NewWriter(&b)
    60  	gz.Write(body)
    61  	gz.Close()
    62  	//AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point
    63  	//localstack does not do it, so let's just write a raw gzipped stream
    64  	return b.Bytes()
    65  }
    66  
    67  func WriteToStream(streamName string, count int, shards int, sub bool) {
    68  	endpoint, err := getLocalStackEndpoint()
    69  	if err != nil {
    70  		log.Fatal(err)
    71  	}
    72  	sess := session.Must(session.NewSession())
    73  	kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1"))
    74  	for i := 0; i < count; i++ {
    75  		partition := "partition"
    76  		if shards != 1 {
    77  			partition = fmt.Sprintf("partition-%d", i%shards)
    78  		}
    79  		var data []byte
    80  		if sub {
    81  			data = GenSubObject(i)
    82  		} else {
    83  			data = []byte(fmt.Sprintf("%d", i))
    84  		}
    85  		_, err = kinesisClient.PutRecord(&kinesis.PutRecordInput{
    86  			Data:         data,
    87  			PartitionKey: aws.String(partition),
    88  			StreamName:   aws.String(streamName),
    89  		})
    90  		if err != nil {
    91  			fmt.Printf("Error writing to stream: %s\n", err)
    92  			log.Fatal(err)
    93  		}
    94  	}
    95  }
    96  
    97  func TestMain(m *testing.M) {
    98  	os.Setenv("AWS_ACCESS_KEY_ID", "foobar")
    99  	os.Setenv("AWS_SECRET_ACCESS_KEY", "foobar")
   100  
   101  	//delete_streams()
   102  	//create_streams()
   103  	code := m.Run()
   104  	//delete_streams()
   105  	os.Exit(code)
   106  }
   107  
   108  func TestBadConfiguration(t *testing.T) {
   109  	if runtime.GOOS == "windows" {
   110  		t.Skip("Skipping test on windows")
   111  	}
   112  	tests := []struct {
   113  		config      string
   114  		expectedErr string
   115  	}{
   116  		{
   117  			config:      `source: kinesis`,
   118  			expectedErr: "stream_name is mandatory when use_enhanced_fanout is false",
   119  		},
   120  		{
   121  			config: `
   122  source: kinesis
   123  use_enhanced_fanout: true`,
   124  			expectedErr: "stream_arn is mandatory when use_enhanced_fanout is true",
   125  		},
   126  		{
   127  			config: `
   128  source: kinesis
   129  use_enhanced_fanout: true
   130  stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`,
   131  			expectedErr: "consumer_name is mandatory when use_enhanced_fanout is true",
   132  		},
   133  		{
   134  			config: `
   135  source: kinesis
   136  stream_name: foobar
   137  stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`,
   138  			expectedErr: "stream_arn and stream_name are mutually exclusive",
   139  		},
   140  	}
   141  
   142  	subLogger := log.WithFields(log.Fields{
   143  		"type": "kinesis",
   144  	})
   145  	for _, test := range tests {
   146  		f := KinesisSource{}
   147  		err := f.Configure([]byte(test.config), subLogger, configuration.METRICS_NONE)
   148  		cstest.AssertErrorContains(t, err, test.expectedErr)
   149  	}
   150  }
   151  
   152  func TestReadFromStream(t *testing.T) {
   153  	if runtime.GOOS == "windows" {
   154  		t.Skip("Skipping test on windows")
   155  	}
   156  	tests := []struct {
   157  		config string
   158  		count  int
   159  		shards int
   160  	}{
   161  		{
   162  			config: `source: kinesis
   163  aws_endpoint: %s
   164  aws_region: us-east-1
   165  stream_name: stream-1-shard`,
   166  			count:  10,
   167  			shards: 1,
   168  		},
   169  	}
   170  	endpoint, _ := getLocalStackEndpoint()
   171  	for _, test := range tests {
   172  		f := KinesisSource{}
   173  		config := fmt.Sprintf(test.config, endpoint)
   174  		err := f.Configure([]byte(config), log.WithFields(log.Fields{
   175  			"type": "kinesis",
   176  		}), configuration.METRICS_NONE)
   177  		if err != nil {
   178  			t.Fatalf("Error configuring source: %s", err)
   179  		}
   180  		tomb := &tomb.Tomb{}
   181  		out := make(chan types.Event)
   182  		err = f.StreamingAcquisition(out, tomb)
   183  		if err != nil {
   184  			t.Fatalf("Error starting source: %s", err)
   185  		}
   186  		//Allow the datasource to start listening to the stream
   187  		time.Sleep(4 * time.Second)
   188  		WriteToStream(f.Config.StreamName, test.count, test.shards, false)
   189  		for i := 0; i < test.count; i++ {
   190  			e := <-out
   191  			assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw)
   192  		}
   193  		tomb.Kill(nil)
   194  		tomb.Wait()
   195  	}
   196  }
   197  
   198  func TestReadFromMultipleShards(t *testing.T) {
   199  	if runtime.GOOS == "windows" {
   200  		t.Skip("Skipping test on windows")
   201  	}
   202  	tests := []struct {
   203  		config string
   204  		count  int
   205  		shards int
   206  	}{
   207  		{
   208  			config: `source: kinesis
   209  aws_endpoint: %s
   210  aws_region: us-east-1
   211  stream_name: stream-2-shards`,
   212  			count:  10,
   213  			shards: 2,
   214  		},
   215  	}
   216  	endpoint, _ := getLocalStackEndpoint()
   217  	for _, test := range tests {
   218  		f := KinesisSource{}
   219  		config := fmt.Sprintf(test.config, endpoint)
   220  		err := f.Configure([]byte(config), log.WithFields(log.Fields{
   221  			"type": "kinesis",
   222  		}), configuration.METRICS_NONE)
   223  		if err != nil {
   224  			t.Fatalf("Error configuring source: %s", err)
   225  		}
   226  		tomb := &tomb.Tomb{}
   227  		out := make(chan types.Event)
   228  		err = f.StreamingAcquisition(out, tomb)
   229  		if err != nil {
   230  			t.Fatalf("Error starting source: %s", err)
   231  		}
   232  		//Allow the datasource to start listening to the stream
   233  		time.Sleep(4 * time.Second)
   234  		WriteToStream(f.Config.StreamName, test.count, test.shards, false)
   235  		c := 0
   236  		for i := 0; i < test.count; i++ {
   237  			<-out
   238  			c += 1
   239  		}
   240  		assert.Equal(t, test.count, c)
   241  		tomb.Kill(nil)
   242  		tomb.Wait()
   243  	}
   244  }
   245  
   246  func TestFromSubscription(t *testing.T) {
   247  	if runtime.GOOS == "windows" {
   248  		t.Skip("Skipping test on windows")
   249  	}
   250  	tests := []struct {
   251  		config string
   252  		count  int
   253  		shards int
   254  	}{
   255  		{
   256  			config: `source: kinesis
   257  aws_endpoint: %s
   258  aws_region: us-east-1
   259  stream_name: stream-1-shard
   260  from_subscription: true`,
   261  			count:  10,
   262  			shards: 1,
   263  		},
   264  	}
   265  	endpoint, _ := getLocalStackEndpoint()
   266  	for _, test := range tests {
   267  		f := KinesisSource{}
   268  		config := fmt.Sprintf(test.config, endpoint)
   269  		err := f.Configure([]byte(config), log.WithFields(log.Fields{
   270  			"type": "kinesis",
   271  		}), configuration.METRICS_NONE)
   272  		if err != nil {
   273  			t.Fatalf("Error configuring source: %s", err)
   274  		}
   275  		tomb := &tomb.Tomb{}
   276  		out := make(chan types.Event)
   277  		err = f.StreamingAcquisition(out, tomb)
   278  		if err != nil {
   279  			t.Fatalf("Error starting source: %s", err)
   280  		}
   281  		//Allow the datasource to start listening to the stream
   282  		time.Sleep(4 * time.Second)
   283  		WriteToStream(f.Config.StreamName, test.count, test.shards, true)
   284  		for i := 0; i < test.count; i++ {
   285  			e := <-out
   286  			assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw)
   287  		}
   288  		tomb.Kill(nil)
   289  		tomb.Wait()
   290  	}
   291  }
   292  
   293  /*
   294  func TestSubscribeToStream(t *testing.T) {
   295  	tests := []struct {
   296  		config string
   297  		count  int
   298  		shards int
   299  	}{
   300  		{
   301  			config: `source: kinesis
   302  aws_endpoint: %s
   303  aws_region: us-east-1
   304  stream_arn: arn:aws:kinesis:us-east-1:000000000000:stream/stream-1-shard
   305  consumer_name: consumer-1
   306  use_enhanced_fanout: true`,
   307  			count:  10,
   308  			shards: 1,
   309  		},
   310  	}
   311  	endpoint, _ := getLocalStackEndpoint()
   312  	for _, test := range tests {
   313  		f := KinesisSource{}
   314  		config := fmt.Sprintf(test.config, endpoint)
   315  		err := f.Configure([]byte(config), log.WithFields(log.Fields{
   316  			"type": "kinesis",
   317  		}))
   318  		if err != nil {
   319  			t.Fatalf("Error configuring source: %s", err)
   320  		}
   321  		tomb := &tomb.Tomb{}
   322  		out := make(chan types.Event)
   323  		err = f.StreamingAcquisition(out, tomb)
   324  		if err != nil {
   325  			t.Fatalf("Error starting source: %s", err)
   326  		}
   327  		//Allow the datasource to start listening to the stream
   328  		time.Sleep(10 * time.Second)
   329  		WriteToStream("stream-1-shard", test.count, test.shards)
   330  		for i := 0; i < test.count; i++ {
   331  			e := <-out
   332  			assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw)
   333  		}
   334  	}
   335  }
   336  */