github.com/crowdsecurity/crowdsec@v1.6.1/pkg/acquisition/modules/s3/s3_test.go (about)

     1  package s3acquisition
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/aws/aws-sdk-go/aws"
    12  	"github.com/aws/aws-sdk-go/aws/request"
    13  	"github.com/aws/aws-sdk-go/service/s3"
    14  	"github.com/aws/aws-sdk-go/service/s3/s3iface"
    15  	"github.com/aws/aws-sdk-go/service/sqs"
    16  	"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
    17  	"github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration"
    18  	"github.com/crowdsecurity/crowdsec/pkg/types"
    19  	log "github.com/sirupsen/logrus"
    20  	"github.com/stretchr/testify/assert"
    21  	"gopkg.in/tomb.v2"
    22  )
    23  
    24  func TestBadConfiguration(t *testing.T) {
    25  	tests := []struct {
    26  		name        string
    27  		config      string
    28  		expectedErr string
    29  	}{
    30  		{
    31  			name: "no bucket",
    32  			config: `
    33  source: s3
    34  `,
    35  			expectedErr: "bucket_name is required",
    36  		},
    37  		{
    38  			name: "invalid polling method",
    39  			config: `
    40  source: s3
    41  bucket_name: foobar
    42  polling_method: foobar
    43  `,
    44  			expectedErr: "invalid polling method foobar",
    45  		},
    46  		{
    47  			name: "no sqs name",
    48  			config: `
    49  source: s3
    50  bucket_name: foobar
    51  polling_method: sqs
    52  `,
    53  			expectedErr: "sqs_name is required when using sqs polling method",
    54  		},
    55  		{
    56  			name: "both bucket and sqs",
    57  			config: `
    58  source: s3
    59  bucket_name: foobar
    60  polling_method: sqs
    61  sqs_name: foobar
    62  `,
    63  			expectedErr: "bucket_name and sqs_name are mutually exclusive",
    64  		},
    65  	}
    66  
    67  	for _, test := range tests {
    68  		t.Run(test.name, func(t *testing.T) {
    69  			f := S3Source{}
    70  			err := f.Configure([]byte(test.config), nil, configuration.METRICS_NONE)
    71  			if err == nil {
    72  				t.Fatalf("expected error, got none")
    73  			}
    74  			if err.Error() != test.expectedErr {
    75  				t.Fatalf("expected error %s, got %s", test.expectedErr, err.Error())
    76  			}
    77  		})
    78  	}
    79  }
    80  
    81  func TestGoodConfiguration(t *testing.T) {
    82  	tests := []struct {
    83  		name   string
    84  		config string
    85  	}{
    86  		{
    87  			name: "basic",
    88  			config: `
    89  source: s3
    90  bucket_name: foobar
    91  `,
    92  		},
    93  		{
    94  			name: "polling method",
    95  			config: `
    96  source: s3
    97  polling_method: sqs
    98  sqs_name: foobar
    99  `,
   100  		},
   101  		{
   102  			name: "list method",
   103  			config: `
   104  source: s3
   105  bucket_name: foobar
   106  polling_method: list
   107  `,
   108  		},
   109  	}
   110  
   111  	for _, test := range tests {
   112  		t.Run(test.name, func(t *testing.T) {
   113  			f := S3Source{}
   114  			logger := log.NewEntry(log.New())
   115  			err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE)
   116  			if err != nil {
   117  				t.Fatalf("unexpected error: %s", err.Error())
   118  			}
   119  		})
   120  	}
   121  }
   122  
   123  type mockS3Client struct {
   124  	s3iface.S3API
   125  }
   126  
   127  // We add one hour to trick the listing goroutine into thinking the files are new
   128  var mockListOutput map[string][]*s3.Object = map[string][]*s3.Object{
   129  	"bucket_no_prefix": {
   130  		{
   131  			Key:          aws.String("foo.log"),
   132  			LastModified: aws.Time(time.Now().Add(time.Hour)),
   133  		},
   134  	},
   135  	"bucket_with_prefix": {
   136  		{
   137  			Key:          aws.String("prefix/foo.log"),
   138  			LastModified: aws.Time(time.Now().Add(time.Hour)),
   139  		},
   140  		{
   141  			Key:          aws.String("prefix/bar.log"),
   142  			LastModified: aws.Time(time.Now().Add(time.Hour)),
   143  		},
   144  	},
   145  }
   146  
   147  func (m mockS3Client) ListObjectsV2WithContext(ctx context.Context, input *s3.ListObjectsV2Input, options ...request.Option) (*s3.ListObjectsV2Output, error) {
   148  	log.Infof("returning mock list output for %s, %v", *input.Bucket, mockListOutput[*input.Bucket])
   149  	return &s3.ListObjectsV2Output{
   150  		Contents: mockListOutput[*input.Bucket],
   151  	}, nil
   152  }
   153  
   154  func (m mockS3Client) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, options ...request.Option) (*s3.GetObjectOutput, error) {
   155  	r := strings.NewReader("foo\nbar")
   156  	return &s3.GetObjectOutput{
   157  		Body: aws.ReadSeekCloser(r),
   158  	}, nil
   159  }
   160  
   161  type mockSQSClient struct {
   162  	sqsiface.SQSAPI
   163  	counter *int32
   164  }
   165  
   166  func (msqs mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) {
   167  	if atomic.LoadInt32(msqs.counter) == 1 {
   168  		return &sqs.ReceiveMessageOutput{}, nil
   169  	}
   170  	atomic.AddInt32(msqs.counter, 1)
   171  	return &sqs.ReceiveMessageOutput{
   172  		Messages: []*sqs.Message{
   173  			{
   174  				Body: aws.String(`
   175  {"version":"0","id":"af1ce7ea-bdb4-5bb7-3af2-c6cb32f9aac9","detail-type":"Object Created","source":"aws.s3","account":"1234","time":"2023-03-17T07:45:04Z","region":"eu-west-1","resources":["arn:aws:s3:::my_bucket"],"detail":{"version":"0","bucket":{"name":"my_bucket"},"object":{"key":"foo.log","size":663,"etag":"f2d5268a0776d6cdd6e14fcfba96d1cd","sequencer":"0064141A8022966874"},"request-id":"MBWX2P6FWA3S1YH5","requester":"156460612806","source-ip-address":"42.42.42.42","reason":"PutObject"}}`),
   176  			},
   177  		},
   178  	}, nil
   179  }
   180  
   181  func (msqs mockSQSClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
   182  	return &sqs.DeleteMessageOutput{}, nil
   183  }
   184  
   185  type mockSQSClientNotif struct {
   186  	sqsiface.SQSAPI
   187  	counter *int32
   188  }
   189  
   190  func (msqs mockSQSClientNotif) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) {
   191  	if atomic.LoadInt32(msqs.counter) == 1 {
   192  		return &sqs.ReceiveMessageOutput{}, nil
   193  	}
   194  	atomic.AddInt32(msqs.counter, 1)
   195  	return &sqs.ReceiveMessageOutput{
   196  		Messages: []*sqs.Message{
   197  			{
   198  				Body: aws.String(`
   199  				{"Records":[{"eventVersion":"2.1","eventSource":"aws:s3","awsRegion":"eu-west-1","eventTime":"2023-03-20T19:30:02.536Z","eventName":"ObjectCreated:Put","userIdentity":{"principalId":"AWS:XXXXX"},"requestParameters":{"sourceIPAddress":"42.42.42.42"},"responseElements":{"x-amz-request-id":"FM0TAV2WE5AXXW42","x-amz-id-2":"LCfQt1aSBtD1G5wdXjB5ANdPxLEXJxA89Ev+/rRAsCGFNJGI/1+HMlKI59S92lqvzfViWh7B74leGKWB8/nNbsbKbK7WXKz2"},"s3":{"s3SchemaVersion":"1.0","configurationId":"test-acquis","bucket":{"name":"my_bucket","ownerIdentity":{"principalId":"A1F2PSER1FB8MY"},"arn":"arn:aws:s3:::my_bucket"},"object":{"key":"foo.log","size":3097,"eTag":"ab6889744611c77991cbc6ca12d1ddc7","sequencer":"006418B43A76BC0257"}}}]}`),
   200  			},
   201  		},
   202  	}, nil
   203  }
   204  
   205  func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) {
   206  	return &sqs.DeleteMessageOutput{}, nil
   207  }
   208  
   209  func TestDSNAcquis(t *testing.T) {
   210  	tests := []struct {
   211  		name               string
   212  		dsn                string
   213  		expectedBucketName string
   214  		expectedPrefix     string
   215  		expectedCount      int
   216  	}{
   217  		{
   218  			name:               "basic",
   219  			dsn:                "s3://bucket_no_prefix/foo.log",
   220  			expectedBucketName: "bucket_no_prefix",
   221  			expectedPrefix:     "",
   222  			expectedCount:      2,
   223  		},
   224  		{
   225  			name:               "with prefix",
   226  			dsn:                "s3://bucket_with_prefix/prefix/",
   227  			expectedBucketName: "bucket_with_prefix",
   228  			expectedPrefix:     "prefix/",
   229  			expectedCount:      4,
   230  		},
   231  	}
   232  
   233  	for _, test := range tests {
   234  		t.Run(test.name, func(t *testing.T) {
   235  			linesRead := 0
   236  			f := S3Source{}
   237  			logger := log.NewEntry(log.New())
   238  			err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger, "")
   239  			if err != nil {
   240  				t.Fatalf("unexpected error: %s", err.Error())
   241  			}
   242  			assert.Equal(t, test.expectedBucketName, f.Config.BucketName)
   243  			assert.Equal(t, test.expectedPrefix, f.Config.Prefix)
   244  			out := make(chan types.Event)
   245  
   246  			done := make(chan bool)
   247  
   248  			go func() {
   249  				for {
   250  					select {
   251  					case s := <-out:
   252  						fmt.Printf("got line %s\n", s.Line.Raw)
   253  						linesRead++
   254  					case <-done:
   255  						return
   256  					}
   257  				}
   258  			}()
   259  
   260  			f.s3Client = mockS3Client{}
   261  			tmb := tomb.Tomb{}
   262  			err = f.OneShotAcquisition(out, &tmb)
   263  			if err != nil {
   264  				t.Fatalf("unexpected error: %s", err.Error())
   265  			}
   266  			time.Sleep(2 * time.Second)
   267  			done <- true
   268  			assert.Equal(t, test.expectedCount, linesRead)
   269  
   270  		})
   271  	}
   272  
   273  }
   274  
   275  func TestListPolling(t *testing.T) {
   276  	tests := []struct {
   277  		name          string
   278  		config        string
   279  		expectedCount int
   280  	}{
   281  		{
   282  			name: "basic",
   283  			config: `
   284  source: s3
   285  bucket_name: bucket_no_prefix
   286  polling_method: list
   287  polling_interval: 1
   288  `,
   289  			expectedCount: 2,
   290  		},
   291  		{
   292  			name: "with prefix",
   293  			config: `
   294  source: s3
   295  bucket_name: bucket_with_prefix
   296  polling_method: list
   297  polling_interval: 1
   298  prefix: foo/
   299  `,
   300  			expectedCount: 4,
   301  		},
   302  	}
   303  
   304  	for _, test := range tests {
   305  		t.Run(test.name, func(t *testing.T) {
   306  			linesRead := 0
   307  			f := S3Source{}
   308  			logger := log.NewEntry(log.New())
   309  			logger.Logger.SetLevel(log.TraceLevel)
   310  			err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE)
   311  			if err != nil {
   312  				t.Fatalf("unexpected error: %s", err.Error())
   313  			}
   314  			if f.Config.PollingMethod != PollMethodList {
   315  				t.Fatalf("expected list polling, got %s", f.Config.PollingMethod)
   316  			}
   317  
   318  			f.s3Client = mockS3Client{}
   319  
   320  			out := make(chan types.Event)
   321  			tb := tomb.Tomb{}
   322  
   323  			go func() {
   324  				for {
   325  					select {
   326  					case s := <-out:
   327  						fmt.Printf("got line %s\n", s.Line.Raw)
   328  						linesRead++
   329  					case <-tb.Dying():
   330  						return
   331  					}
   332  				}
   333  			}()
   334  
   335  			err = f.StreamingAcquisition(out, &tb)
   336  
   337  			if err != nil {
   338  				t.Fatalf("unexpected error: %s", err.Error())
   339  			}
   340  
   341  			time.Sleep(2 * time.Second)
   342  			tb.Kill(nil)
   343  			err = tb.Wait()
   344  			if err != nil {
   345  				t.Fatalf("unexpected error: %s", err.Error())
   346  			}
   347  			assert.Equal(t, test.expectedCount, linesRead)
   348  		})
   349  	}
   350  }
   351  
   352  func TestSQSPoll(t *testing.T) {
   353  	tests := []struct {
   354  		name          string
   355  		config        string
   356  		notifType     string
   357  		expectedCount int
   358  	}{
   359  		{
   360  			name: "eventbridge",
   361  			config: `
   362  source: s3
   363  polling_method: sqs
   364  sqs_name: test
   365  `,
   366  			expectedCount: 2,
   367  			notifType:     "eventbridge",
   368  		},
   369  		{
   370  			name: "notification",
   371  			config: `
   372  source: s3
   373  polling_method: sqs
   374  sqs_name: test
   375  `,
   376  			expectedCount: 2,
   377  			notifType:     "notification",
   378  		},
   379  	}
   380  	for _, test := range tests {
   381  		t.Run(test.name, func(t *testing.T) {
   382  			linesRead := 0
   383  			f := S3Source{}
   384  			logger := log.NewEntry(log.New())
   385  			err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE)
   386  			if err != nil {
   387  				t.Fatalf("unexpected error: %s", err.Error())
   388  			}
   389  			if f.Config.PollingMethod != PollMethodSQS {
   390  				t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod)
   391  			}
   392  
   393  			counter := int32(0)
   394  			f.s3Client = mockS3Client{}
   395  			if test.notifType == "eventbridge" {
   396  				f.sqsClient = mockSQSClient{counter: &counter}
   397  			} else {
   398  				f.sqsClient = mockSQSClientNotif{counter: &counter}
   399  			}
   400  
   401  			out := make(chan types.Event)
   402  			tb := tomb.Tomb{}
   403  
   404  			go func() {
   405  				for {
   406  					select {
   407  					case s := <-out:
   408  						fmt.Printf("got line %s\n", s.Line.Raw)
   409  						linesRead++
   410  					case <-tb.Dying():
   411  						return
   412  					}
   413  				}
   414  			}()
   415  
   416  			err = f.StreamingAcquisition(out, &tb)
   417  
   418  			if err != nil {
   419  				t.Fatalf("unexpected error: %s", err.Error())
   420  			}
   421  
   422  			time.Sleep(2 * time.Second)
   423  			tb.Kill(nil)
   424  			err = tb.Wait()
   425  			if err != nil {
   426  				t.Fatalf("unexpected error: %s", err.Error())
   427  			}
   428  			assert.Equal(t, test.expectedCount, linesRead)
   429  		})
   430  	}
   431  }