github.com/crowdsecurity/crowdsec@v1.6.1/pkg/acquisition/modules/s3/s3_test.go (about) 1 package s3acquisition 2 3 import ( 4 "context" 5 "fmt" 6 "strings" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/aws/aws-sdk-go/aws" 12 "github.com/aws/aws-sdk-go/aws/request" 13 "github.com/aws/aws-sdk-go/service/s3" 14 "github.com/aws/aws-sdk-go/service/s3/s3iface" 15 "github.com/aws/aws-sdk-go/service/sqs" 16 "github.com/aws/aws-sdk-go/service/sqs/sqsiface" 17 "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" 18 "github.com/crowdsecurity/crowdsec/pkg/types" 19 log "github.com/sirupsen/logrus" 20 "github.com/stretchr/testify/assert" 21 "gopkg.in/tomb.v2" 22 ) 23 24 func TestBadConfiguration(t *testing.T) { 25 tests := []struct { 26 name string 27 config string 28 expectedErr string 29 }{ 30 { 31 name: "no bucket", 32 config: ` 33 source: s3 34 `, 35 expectedErr: "bucket_name is required", 36 }, 37 { 38 name: "invalid polling method", 39 config: ` 40 source: s3 41 bucket_name: foobar 42 polling_method: foobar 43 `, 44 expectedErr: "invalid polling method foobar", 45 }, 46 { 47 name: "no sqs name", 48 config: ` 49 source: s3 50 bucket_name: foobar 51 polling_method: sqs 52 `, 53 expectedErr: "sqs_name is required when using sqs polling method", 54 }, 55 { 56 name: "both bucket and sqs", 57 config: ` 58 source: s3 59 bucket_name: foobar 60 polling_method: sqs 61 sqs_name: foobar 62 `, 63 expectedErr: "bucket_name and sqs_name are mutually exclusive", 64 }, 65 } 66 67 for _, test := range tests { 68 t.Run(test.name, func(t *testing.T) { 69 f := S3Source{} 70 err := f.Configure([]byte(test.config), nil, configuration.METRICS_NONE) 71 if err == nil { 72 t.Fatalf("expected error, got none") 73 } 74 if err.Error() != test.expectedErr { 75 t.Fatalf("expected error %s, got %s", test.expectedErr, err.Error()) 76 } 77 }) 78 } 79 } 80 81 func TestGoodConfiguration(t *testing.T) { 82 tests := []struct { 83 name string 84 config string 85 }{ 86 { 87 name: "basic", 88 config: ` 89 source: s3 90 bucket_name: foobar 91 `, 92 }, 93 { 94 name: "polling method", 95 config: ` 96 source: s3 97 polling_method: sqs 98 sqs_name: foobar 99 `, 100 }, 101 { 102 name: "list method", 103 config: ` 104 source: s3 105 bucket_name: foobar 106 polling_method: list 107 `, 108 }, 109 } 110 111 for _, test := range tests { 112 t.Run(test.name, func(t *testing.T) { 113 f := S3Source{} 114 logger := log.NewEntry(log.New()) 115 err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) 116 if err != nil { 117 t.Fatalf("unexpected error: %s", err.Error()) 118 } 119 }) 120 } 121 } 122 123 type mockS3Client struct { 124 s3iface.S3API 125 } 126 127 // We add one hour to trick the listing goroutine into thinking the files are new 128 var mockListOutput map[string][]*s3.Object = map[string][]*s3.Object{ 129 "bucket_no_prefix": { 130 { 131 Key: aws.String("foo.log"), 132 LastModified: aws.Time(time.Now().Add(time.Hour)), 133 }, 134 }, 135 "bucket_with_prefix": { 136 { 137 Key: aws.String("prefix/foo.log"), 138 LastModified: aws.Time(time.Now().Add(time.Hour)), 139 }, 140 { 141 Key: aws.String("prefix/bar.log"), 142 LastModified: aws.Time(time.Now().Add(time.Hour)), 143 }, 144 }, 145 } 146 147 func (m mockS3Client) ListObjectsV2WithContext(ctx context.Context, input *s3.ListObjectsV2Input, options ...request.Option) (*s3.ListObjectsV2Output, error) { 148 log.Infof("returning mock list output for %s, %v", *input.Bucket, mockListOutput[*input.Bucket]) 149 return &s3.ListObjectsV2Output{ 150 Contents: mockListOutput[*input.Bucket], 151 }, nil 152 } 153 154 func (m mockS3Client) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, options ...request.Option) (*s3.GetObjectOutput, error) { 155 r := strings.NewReader("foo\nbar") 156 return &s3.GetObjectOutput{ 157 Body: aws.ReadSeekCloser(r), 158 }, nil 159 } 160 161 type mockSQSClient struct { 162 sqsiface.SQSAPI 163 counter *int32 164 } 165 166 func (msqs mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) { 167 if atomic.LoadInt32(msqs.counter) == 1 { 168 return &sqs.ReceiveMessageOutput{}, nil 169 } 170 atomic.AddInt32(msqs.counter, 1) 171 return &sqs.ReceiveMessageOutput{ 172 Messages: []*sqs.Message{ 173 { 174 Body: aws.String(` 175 {"version":"0","id":"af1ce7ea-bdb4-5bb7-3af2-c6cb32f9aac9","detail-type":"Object Created","source":"aws.s3","account":"1234","time":"2023-03-17T07:45:04Z","region":"eu-west-1","resources":["arn:aws:s3:::my_bucket"],"detail":{"version":"0","bucket":{"name":"my_bucket"},"object":{"key":"foo.log","size":663,"etag":"f2d5268a0776d6cdd6e14fcfba96d1cd","sequencer":"0064141A8022966874"},"request-id":"MBWX2P6FWA3S1YH5","requester":"156460612806","source-ip-address":"42.42.42.42","reason":"PutObject"}}`), 176 }, 177 }, 178 }, nil 179 } 180 181 func (msqs mockSQSClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { 182 return &sqs.DeleteMessageOutput{}, nil 183 } 184 185 type mockSQSClientNotif struct { 186 sqsiface.SQSAPI 187 counter *int32 188 } 189 190 func (msqs mockSQSClientNotif) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) { 191 if atomic.LoadInt32(msqs.counter) == 1 { 192 return &sqs.ReceiveMessageOutput{}, nil 193 } 194 atomic.AddInt32(msqs.counter, 1) 195 return &sqs.ReceiveMessageOutput{ 196 Messages: []*sqs.Message{ 197 { 198 Body: aws.String(` 199 {"Records":[{"eventVersion":"2.1","eventSource":"aws:s3","awsRegion":"eu-west-1","eventTime":"2023-03-20T19:30:02.536Z","eventName":"ObjectCreated:Put","userIdentity":{"principalId":"AWS:XXXXX"},"requestParameters":{"sourceIPAddress":"42.42.42.42"},"responseElements":{"x-amz-request-id":"FM0TAV2WE5AXXW42","x-amz-id-2":"LCfQt1aSBtD1G5wdXjB5ANdPxLEXJxA89Ev+/rRAsCGFNJGI/1+HMlKI59S92lqvzfViWh7B74leGKWB8/nNbsbKbK7WXKz2"},"s3":{"s3SchemaVersion":"1.0","configurationId":"test-acquis","bucket":{"name":"my_bucket","ownerIdentity":{"principalId":"A1F2PSER1FB8MY"},"arn":"arn:aws:s3:::my_bucket"},"object":{"key":"foo.log","size":3097,"eTag":"ab6889744611c77991cbc6ca12d1ddc7","sequencer":"006418B43A76BC0257"}}}]}`), 200 }, 201 }, 202 }, nil 203 } 204 205 func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { 206 return &sqs.DeleteMessageOutput{}, nil 207 } 208 209 func TestDSNAcquis(t *testing.T) { 210 tests := []struct { 211 name string 212 dsn string 213 expectedBucketName string 214 expectedPrefix string 215 expectedCount int 216 }{ 217 { 218 name: "basic", 219 dsn: "s3://bucket_no_prefix/foo.log", 220 expectedBucketName: "bucket_no_prefix", 221 expectedPrefix: "", 222 expectedCount: 2, 223 }, 224 { 225 name: "with prefix", 226 dsn: "s3://bucket_with_prefix/prefix/", 227 expectedBucketName: "bucket_with_prefix", 228 expectedPrefix: "prefix/", 229 expectedCount: 4, 230 }, 231 } 232 233 for _, test := range tests { 234 t.Run(test.name, func(t *testing.T) { 235 linesRead := 0 236 f := S3Source{} 237 logger := log.NewEntry(log.New()) 238 err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger, "") 239 if err != nil { 240 t.Fatalf("unexpected error: %s", err.Error()) 241 } 242 assert.Equal(t, test.expectedBucketName, f.Config.BucketName) 243 assert.Equal(t, test.expectedPrefix, f.Config.Prefix) 244 out := make(chan types.Event) 245 246 done := make(chan bool) 247 248 go func() { 249 for { 250 select { 251 case s := <-out: 252 fmt.Printf("got line %s\n", s.Line.Raw) 253 linesRead++ 254 case <-done: 255 return 256 } 257 } 258 }() 259 260 f.s3Client = mockS3Client{} 261 tmb := tomb.Tomb{} 262 err = f.OneShotAcquisition(out, &tmb) 263 if err != nil { 264 t.Fatalf("unexpected error: %s", err.Error()) 265 } 266 time.Sleep(2 * time.Second) 267 done <- true 268 assert.Equal(t, test.expectedCount, linesRead) 269 270 }) 271 } 272 273 } 274 275 func TestListPolling(t *testing.T) { 276 tests := []struct { 277 name string 278 config string 279 expectedCount int 280 }{ 281 { 282 name: "basic", 283 config: ` 284 source: s3 285 bucket_name: bucket_no_prefix 286 polling_method: list 287 polling_interval: 1 288 `, 289 expectedCount: 2, 290 }, 291 { 292 name: "with prefix", 293 config: ` 294 source: s3 295 bucket_name: bucket_with_prefix 296 polling_method: list 297 polling_interval: 1 298 prefix: foo/ 299 `, 300 expectedCount: 4, 301 }, 302 } 303 304 for _, test := range tests { 305 t.Run(test.name, func(t *testing.T) { 306 linesRead := 0 307 f := S3Source{} 308 logger := log.NewEntry(log.New()) 309 logger.Logger.SetLevel(log.TraceLevel) 310 err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) 311 if err != nil { 312 t.Fatalf("unexpected error: %s", err.Error()) 313 } 314 if f.Config.PollingMethod != PollMethodList { 315 t.Fatalf("expected list polling, got %s", f.Config.PollingMethod) 316 } 317 318 f.s3Client = mockS3Client{} 319 320 out := make(chan types.Event) 321 tb := tomb.Tomb{} 322 323 go func() { 324 for { 325 select { 326 case s := <-out: 327 fmt.Printf("got line %s\n", s.Line.Raw) 328 linesRead++ 329 case <-tb.Dying(): 330 return 331 } 332 } 333 }() 334 335 err = f.StreamingAcquisition(out, &tb) 336 337 if err != nil { 338 t.Fatalf("unexpected error: %s", err.Error()) 339 } 340 341 time.Sleep(2 * time.Second) 342 tb.Kill(nil) 343 err = tb.Wait() 344 if err != nil { 345 t.Fatalf("unexpected error: %s", err.Error()) 346 } 347 assert.Equal(t, test.expectedCount, linesRead) 348 }) 349 } 350 } 351 352 func TestSQSPoll(t *testing.T) { 353 tests := []struct { 354 name string 355 config string 356 notifType string 357 expectedCount int 358 }{ 359 { 360 name: "eventbridge", 361 config: ` 362 source: s3 363 polling_method: sqs 364 sqs_name: test 365 `, 366 expectedCount: 2, 367 notifType: "eventbridge", 368 }, 369 { 370 name: "notification", 371 config: ` 372 source: s3 373 polling_method: sqs 374 sqs_name: test 375 `, 376 expectedCount: 2, 377 notifType: "notification", 378 }, 379 } 380 for _, test := range tests { 381 t.Run(test.name, func(t *testing.T) { 382 linesRead := 0 383 f := S3Source{} 384 logger := log.NewEntry(log.New()) 385 err := f.Configure([]byte(test.config), logger, configuration.METRICS_NONE) 386 if err != nil { 387 t.Fatalf("unexpected error: %s", err.Error()) 388 } 389 if f.Config.PollingMethod != PollMethodSQS { 390 t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod) 391 } 392 393 counter := int32(0) 394 f.s3Client = mockS3Client{} 395 if test.notifType == "eventbridge" { 396 f.sqsClient = mockSQSClient{counter: &counter} 397 } else { 398 f.sqsClient = mockSQSClientNotif{counter: &counter} 399 } 400 401 out := make(chan types.Event) 402 tb := tomb.Tomb{} 403 404 go func() { 405 for { 406 select { 407 case s := <-out: 408 fmt.Printf("got line %s\n", s.Line.Raw) 409 linesRead++ 410 case <-tb.Dying(): 411 return 412 } 413 } 414 }() 415 416 err = f.StreamingAcquisition(out, &tb) 417 418 if err != nil { 419 t.Fatalf("unexpected error: %s", err.Error()) 420 } 421 422 time.Sleep(2 * time.Second) 423 tb.Kill(nil) 424 err = tb.Wait() 425 if err != nil { 426 t.Fatalf("unexpected error: %s", err.Error()) 427 } 428 assert.Equal(t, test.expectedCount, linesRead) 429 }) 430 } 431 }