github.com/argoproj/argo-events@v1.9.1/eventsources/sources/awssqs/start.go (about)

     1  /*
     2  Copyright 2018 BlackRock, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  	http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package awssqs
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"fmt"
    23  	"time"
    24  
    25  	"github.com/aws/aws-sdk-go/aws"
    26  	"github.com/aws/aws-sdk-go/aws/awserr"
    27  	"github.com/aws/aws-sdk-go/aws/session"
    28  	sqslib "github.com/aws/aws-sdk-go/service/sqs"
    29  	"go.uber.org/zap"
    30  
    31  	"github.com/argoproj/argo-events/common/logging"
    32  	eventsourcecommon "github.com/argoproj/argo-events/eventsources/common"
    33  	awscommon "github.com/argoproj/argo-events/eventsources/common/aws"
    34  	"github.com/argoproj/argo-events/eventsources/sources"
    35  	metrics "github.com/argoproj/argo-events/metrics"
    36  	apicommon "github.com/argoproj/argo-events/pkg/apis/common"
    37  	"github.com/argoproj/argo-events/pkg/apis/events"
    38  	"github.com/argoproj/argo-events/pkg/apis/eventsource/v1alpha1"
    39  )
    40  
    41  // EventListener implements Eventing for aws sqs event source
    42  type EventListener struct {
    43  	EventSourceName string
    44  	EventName       string
    45  	SQSEventSource  v1alpha1.SQSEventSource
    46  	Metrics         *metrics.Metrics
    47  }
    48  
    49  // GetEventSourceName returns name of event source
    50  func (el *EventListener) GetEventSourceName() string {
    51  	return el.EventSourceName
    52  }
    53  
    54  // GetEventName returns name of event
    55  func (el *EventListener) GetEventName() string {
    56  	return el.EventName
    57  }
    58  
    59  // GetEventSourceType return type of event server
    60  func (el *EventListener) GetEventSourceType() apicommon.EventSourceType {
    61  	return apicommon.SQSEvent
    62  }
    63  
    64  // StartListening starts listening events
    65  func (el *EventListener) StartListening(ctx context.Context, dispatch func([]byte, ...eventsourcecommon.Option) error) error {
    66  	log := logging.FromContext(ctx).
    67  		With(logging.LabelEventSourceType, el.GetEventSourceType(), logging.LabelEventName, el.GetEventName())
    68  	log.Info("started processing the AWS SQS event source...")
    69  	defer sources.Recover(el.GetEventName())
    70  
    71  	sqsEventSource := &el.SQSEventSource
    72  	sqsClient, err := el.createSqsClient()
    73  	if err != nil {
    74  		return err
    75  	}
    76  
    77  	log.Info("fetching queue url...")
    78  	getQueueURLInput := &sqslib.GetQueueUrlInput{
    79  		QueueName: &sqsEventSource.Queue,
    80  	}
    81  	if sqsEventSource.QueueAccountID != "" {
    82  		getQueueURLInput = getQueueURLInput.SetQueueOwnerAWSAccountId(sqsEventSource.QueueAccountID)
    83  	}
    84  
    85  	queueURL, err := sqsClient.GetQueueUrl(getQueueURLInput)
    86  	if err != nil {
    87  		log.Errorw("Error getting SQS Queue URL", zap.Error(err))
    88  		return fmt.Errorf("failed to get the queue url for %s, %w", el.GetEventName(), err)
    89  	}
    90  
    91  	if sqsEventSource.JSONBody {
    92  		log.Info("assuming all events have a json body...")
    93  	}
    94  
    95  	log.Info("listening for messages on the queue...")
    96  	for {
    97  		select {
    98  		case <-ctx.Done():
    99  			log.Info("exiting SQS event listener...")
   100  			return nil
   101  		default:
   102  		}
   103  		messages, err := fetchMessages(ctx, sqsClient, *queueURL.QueueUrl, 10, sqsEventSource.WaitTimeSeconds)
   104  		if err != nil {
   105  			log.Errorw("failed to get messages from SQS", zap.Error(err))
   106  			awsError, ok := err.(awserr.Error)
   107  			if ok && awsError.Code() == "ExpiredToken" && el.SQSEventSource.SessionToken != nil {
   108  				log.Info("credentials expired, reading credentials again")
   109  				newSqsClient, err := el.createSqsClient()
   110  				if err != nil {
   111  					log.Errorw("Error creating SQS client", zap.Error(err))
   112  				} else if newSqsClient != nil {
   113  					sqsClient = newSqsClient
   114  				}
   115  			}
   116  
   117  			time.Sleep(2 * time.Second)
   118  			continue
   119  		}
   120  		for _, m := range messages {
   121  			el.processMessage(m, dispatch, func() {
   122  				_, err = sqsClient.DeleteMessage(&sqslib.DeleteMessageInput{
   123  					QueueUrl:      queueURL.QueueUrl,
   124  					ReceiptHandle: m.ReceiptHandle,
   125  				})
   126  				if err != nil {
   127  					log.Errorw("Failed to delete message", zap.Error(err))
   128  					awsError, ok := err.(awserr.Error)
   129  					if ok && awsError.Code() == "ExpiredToken" && el.SQSEventSource.SessionToken != nil {
   130  						log.Info("credentials expired, reading credentials again")
   131  						newSqsClient, err := el.createSqsClient()
   132  						if err != nil {
   133  							log.Errorw("Error creating SQS client", zap.Error(err))
   134  						} else if newSqsClient != nil {
   135  							sqsClient = newSqsClient
   136  						}
   137  					}
   138  				}
   139  			}, log)
   140  		}
   141  	}
   142  }
   143  
   144  func (el *EventListener) processMessage(message *sqslib.Message, dispatch func([]byte, ...eventsourcecommon.Option) error, ack func(), log *zap.SugaredLogger) {
   145  	defer func(start time.Time) {
   146  		el.Metrics.EventProcessingDuration(el.GetEventSourceName(), el.GetEventName(), float64(time.Since(start)/time.Millisecond))
   147  	}(time.Now())
   148  
   149  	data := &events.SQSEventData{
   150  		MessageId:         *message.MessageId,
   151  		MessageAttributes: message.MessageAttributes,
   152  		Metadata:          el.SQSEventSource.Metadata,
   153  	}
   154  	if el.SQSEventSource.JSONBody {
   155  		body := []byte(*message.Body)
   156  		data.Body = (*json.RawMessage)(&body)
   157  	} else {
   158  		data.Body = []byte(*message.Body)
   159  	}
   160  	eventBytes, err := json.Marshal(data)
   161  	if err != nil {
   162  		log.Errorw("failed to marshal event data, will process next message...", zap.Error(err))
   163  		el.Metrics.EventProcessingFailed(el.GetEventSourceName(), el.GetEventName())
   164  		// Don't ack if a DLQ is configured to allow to forward the message to the DLQ
   165  		if !el.SQSEventSource.DLQ {
   166  			ack()
   167  		}
   168  		return
   169  	}
   170  	if err = dispatch(eventBytes); err != nil {
   171  		log.Errorw("failed to dispatch SQS event", zap.Error(err))
   172  		el.Metrics.EventProcessingFailed(el.GetEventSourceName(), el.GetEventName())
   173  	} else {
   174  		ack()
   175  	}
   176  }
   177  
   178  func fetchMessages(ctx context.Context, q *sqslib.SQS, url string, maxSize, waitSeconds int64) ([]*sqslib.Message, error) {
   179  	if waitSeconds == 0 {
   180  		// Defaults to 3 seconds
   181  		waitSeconds = 3
   182  	}
   183  	result, err := q.ReceiveMessageWithContext(ctx, &sqslib.ReceiveMessageInput{
   184  		AttributeNames: []*string{
   185  			aws.String(sqslib.MessageSystemAttributeNameSentTimestamp),
   186  		},
   187  		MessageAttributeNames: []*string{
   188  			aws.String(sqslib.QueueAttributeNameAll),
   189  		},
   190  		QueueUrl:            &url,
   191  		MaxNumberOfMessages: aws.Int64(maxSize),
   192  		VisibilityTimeout:   aws.Int64(120), // 120 seconds
   193  		WaitTimeSeconds:     aws.Int64(waitSeconds),
   194  	})
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	return result.Messages, nil
   199  }
   200  
   201  func (el *EventListener) createAWSSession() (*session.Session, error) {
   202  	sqsEventSource := &el.SQSEventSource
   203  	awsSession, err := awscommon.CreateAWSSessionWithCredsInVolume(sqsEventSource.Region, sqsEventSource.RoleARN, sqsEventSource.AccessKey, sqsEventSource.SecretKey, sqsEventSource.SessionToken)
   204  	if err != nil {
   205  		return nil, fmt.Errorf("failed to create aws session for %s, %w", el.GetEventName(), err)
   206  	}
   207  	return awsSession, nil
   208  }
   209  
   210  func (el *EventListener) createSqsClient() (*sqslib.SQS, error) {
   211  	awsSession, err := el.createAWSSession()
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  
   216  	var sqsClient *sqslib.SQS
   217  	if el.SQSEventSource.Endpoint == "" {
   218  		sqsClient = sqslib.New(awsSession)
   219  	} else {
   220  		sqsClient = sqslib.New(awsSession, &aws.Config{Endpoint: &el.SQSEventSource.Endpoint, Region: &el.SQSEventSource.Region})
   221  	}
   222  
   223  	return sqsClient, nil
   224  }