github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/internal/runner/sqs.go (about)

     1  // Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License.
     2  
     3  package runner
     4  
     5  // This file contains the implementation of AWS SQS message queues
     6  // as they are used by studioML
     7  
     8  import (
     9  	"context"
    10  	"flag"
    11  	"fmt"
    12  	"net/url"
    13  	"regexp"
    14  	"runtime/debug"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/aws/aws-sdk-go/aws"
    19  	"github.com/aws/aws-sdk-go/aws/session"
    20  	"github.com/aws/aws-sdk-go/service/sqs"
    21  
    22  	runnerReports "github.com/leaf-ai/studio-go-runner/internal/gen/dev.cognizant_dev.ai/genproto/studio-go-runner/reports/v1"
    23  
    24  	"github.com/go-stack/stack"
    25  	"github.com/jjeffery/kv" // MIT License
    26  )
    27  
    28  var (
    29  	sqsTimeoutOpt = flag.Duration("sqs-timeout", time.Duration(15*time.Second), "the period of time for discrete SQS operations to use for timeouts")
    30  )
    31  
// SQS encapsulates an AWS based SQS queue and associates it with a project
//
type SQS struct {
	project string   // Queue URL prefix; Work addresses a queue as project + "/" + queue name
	creds   *AWSCred // AWS credentials used to access the queue; their Region also selects the AWS region
	wrapper *Wrapper // Decryption information for messages with encrypted payloads
}
    39  
    40  // NewSQS creates an SQS data structure using set set of credentials (creds) for
    41  // an sqs queue (sqs)
    42  //
    43  func NewSQS(project string, creds string, wrapper *Wrapper) (sqs *SQS, err kv.Error) {
    44  	// Use the creds directory to locate all of the credentials for AWS within
    45  	// a hierarchy of directories
    46  
    47  	awsCreds, err := AWSExtractCreds(strings.Split(creds, ","))
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	return &SQS{
    53  		project: project,
    54  		creds:   awsCreds,
    55  		wrapper: wrapper,
    56  	}, nil
    57  }
    58  
    59  // GetSQSProjects can be used to get a list of the SQS servers and the main URLs that are accessible to them
    60  func GetSQSProjects(credFiles []string) (urls map[string]struct{}, err kv.Error) {
    61  
    62  	sqs, err := NewSQS("aws_probe", strings.Join(credFiles, ","), nil)
    63  	if err != nil {
    64  		return urls, err
    65  	}
    66  	found, err := sqs.refresh(nil, nil)
    67  	if err != nil {
    68  		return urls, kv.Wrap(err, "failed to refresh sqs").With("stack", stack.Trace().TrimRuntime())
    69  	}
    70  
    71  	urls = make(map[string]struct{}, len(found))
    72  	for _, urlStr := range found {
    73  		qURL, err := url.Parse(urlStr)
    74  		if err != nil {
    75  			continue
    76  		}
    77  		segments := strings.Split(qURL.Path, "/")
    78  		qURL.Path = strings.Join(segments[:len(segments)-1], "/")
    79  		urls[qURL.String()] = struct{}{}
    80  	}
    81  
    82  	return urls, nil
    83  }
    84  
    85  func (sq *SQS) listQueues(qNameMatch *regexp.Regexp, qNameMismatch *regexp.Regexp) (queues *sqs.ListQueuesOutput, err kv.Error) {
    86  
    87  	sess, errGo := session.NewSessionWithOptions(session.Options{
    88  		Config: aws.Config{
    89  			Region:                        aws.String(sq.creds.Region),
    90  			Credentials:                   sq.creds.Creds,
    91  			CredentialsChainVerboseErrors: aws.Bool(true),
    92  		},
    93  		Profile: "default",
    94  	})
    95  
    96  	if errGo != nil {
    97  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds)
    98  	}
    99  
   100  	// Create a SQS service client.
   101  	svc := sqs.New(sess)
   102  
   103  	ctx, cancel := context.WithTimeout(context.Background(), *sqsTimeoutOpt)
   104  	defer cancel()
   105  
   106  	listParam := &sqs.ListQueuesInput{}
   107  
   108  	qs, errGo := svc.ListQueuesWithContext(ctx, listParam)
   109  	if errGo != nil {
   110  		return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds)
   111  	}
   112  
   113  	queues = &sqs.ListQueuesOutput{
   114  		QueueUrls: []*string{},
   115  	}
   116  
   117  	for _, qURL := range qs.QueueUrls {
   118  		if qURL == nil {
   119  			continue
   120  		}
   121  		fullURL, errGo := url.Parse(*qURL)
   122  		if errGo != nil {
   123  			return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds)
   124  		}
   125  		paths := strings.Split(fullURL.Path, "/")
   126  		if qNameMismatch != nil {
   127  			if qNameMismatch.MatchString(paths[len(paths)-1]) {
   128  				fmt.Println("dropped", paths[len(paths)-1], qNameMismatch.String())
   129  				continue
   130  			}
   131  		}
   132  		if qNameMatch != nil {
   133  			if !qNameMatch.MatchString(paths[len(paths)-1]) {
   134  				fmt.Println("ignored", paths[len(paths)-1], qNameMatch.String())
   135  				continue
   136  			}
   137  		}
   138  		queues.QueueUrls = append(queues.QueueUrls, qURL)
   139  	}
   140  	return queues, nil
   141  }
   142  
   143  func (sq *SQS) refresh(qNameMatch *regexp.Regexp, qNameMismatch *regexp.Regexp) (known []string, err kv.Error) {
   144  
   145  	known = []string{}
   146  
   147  	result, err := sq.listQueues(qNameMatch, qNameMismatch)
   148  	if err != nil {
   149  		return known, err
   150  	}
   151  
   152  	// As these are pointers, printing them out directly would not be useful.
   153  	for _, url := range result.QueueUrls {
   154  		// Avoid dereferencing a nil pointer.
   155  		if url == nil {
   156  			continue
   157  		}
   158  		known = append(known, *url)
   159  	}
   160  	return known, nil
   161  }
   162  
   163  // Refresh uses a regular expression to obtain matching queues from
   164  // the configured SQS server on AWS (sqs).
   165  //
   166  func (sq *SQS) Refresh(ctx context.Context, qNameMatch *regexp.Regexp, qNameMismatch *regexp.Regexp) (known map[string]interface{}, err kv.Error) {
   167  
   168  	found, err := sq.refresh(qNameMatch, qNameMismatch)
   169  	if err != nil {
   170  		return known, err
   171  	}
   172  
   173  	known = make(map[string]interface{}, len(found))
   174  	for _, urlStr := range found {
   175  		qURL, err := url.Parse(urlStr)
   176  		if err != nil {
   177  			continue
   178  		}
   179  		segments := strings.Split(qURL.Path, "/")
   180  		known[sq.creds.Region+":"+segments[len(segments)-1]] = sq.creds
   181  	}
   182  
   183  	return known, nil
   184  }
   185  
   186  // Exists tests for the presence of a subscription, typically a queue name
   187  // on the configured sqs server.
   188  //
   189  func (sq *SQS) Exists(ctx context.Context, subscription string) (exists bool, err kv.Error) {
   190  
   191  	queues, err := sq.listQueues(nil, nil)
   192  	if err != nil {
   193  		return true, err
   194  	}
   195  
   196  	for _, q := range queues.QueueUrls {
   197  		if q != nil {
   198  			if strings.HasSuffix(subscription, *q) {
   199  				return true, nil
   200  			}
   201  		}
   202  	}
   203  	return false, nil
   204  }
   205  
   206  // Work is invoked by the queue handling software within the runner to get the
   207  // specific queue implementation to process potential work that could be
   208  // waiting inside the queue.
   209  func (sq *SQS) Work(ctx context.Context, qt *QueueTask) (msgProcessed bool, resource *Resource, err kv.Error) {
   210  
   211  	regionUrl := strings.SplitN(qt.Subscription, ":", 2)
   212  	url := sq.project + "/" + regionUrl[1]
   213  
   214  	sess, errGo := session.NewSessionWithOptions(session.Options{
   215  		Config: aws.Config{
   216  			Region:                        aws.String(sq.creds.Region),
   217  			Credentials:                   sq.creds.Creds,
   218  			CredentialsChainVerboseErrors: aws.Bool(true),
   219  		},
   220  		Profile: "default",
   221  	})
   222  
   223  	if errGo != nil {
   224  		return false, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds)
   225  	}
   226  
   227  	// Create a SQS service client.
   228  	svc := sqs.New(sess)
   229  
   230  	defer func() {
   231  		defer func() {
   232  			if r := recover(); r != nil {
   233  				fmt.Printf("panic in producer %#v, %s\n", r, string(debug.Stack()))
   234  			}
   235  		}()
   236  	}()
   237  
   238  	visTimeout := int64(30)
   239  	waitTimeout := int64(5)
   240  	msgs, errGo := svc.ReceiveMessageWithContext(ctx,
   241  		&sqs.ReceiveMessageInput{
   242  			QueueUrl:          &url,
   243  			VisibilityTimeout: &visTimeout,
   244  			WaitTimeSeconds:   &waitTimeout,
   245  		})
   246  	if errGo != nil {
   247  		return false, nil, kv.Wrap(errGo).With("credentials", sq.creds, "url", url).With("stack", stack.Trace().TrimRuntime())
   248  	}
   249  	if len(msgs.Messages) == 0 {
   250  		return false, nil, nil
   251  	}
   252  
   253  	// Make sure that the main ctx has not been Done with before continuing
   254  	select {
   255  	case <-ctx.Done():
   256  		return false, nil, kv.NewError("queue worker cancel received").With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds)
   257  	default:
   258  	}
   259  
   260  	// Start a visbility timeout extender that runs until the work is done
   261  	// Changing the timeout restarts the timer on the SQS side, for more information
   262  	// see http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html
   263  	//
   264  	quitC := make(chan struct{})
   265  	go func() {
   266  		timeout := time.Duration(int(visTimeout / 2))
   267  		for {
   268  			select {
   269  			case <-time.After(timeout * time.Second):
   270  				if _, err := svc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
   271  					QueueUrl:          &url,
   272  					ReceiptHandle:     msgs.Messages[0].ReceiptHandle,
   273  					VisibilityTimeout: &visTimeout,
   274  				}); err != nil {
   275  					// Once the 1/2 way mark is reached continue to try to change the
   276  					// visibility at decreasing intervals until we finish the job
   277  					if timeout.Seconds() > 5.0 {
   278  						timeout = time.Duration(timeout / 2)
   279  					}
   280  				}
   281  			case <-quitC:
   282  				return
   283  			}
   284  		}
   285  	}()
   286  
   287  	qt.Msg = nil
   288  	qt.Msg = []byte(*msgs.Messages[0].Body)
   289  
   290  	items := strings.Split(url, "/")
   291  	qt.ShortQName = items[len(items)-1]
   292  
   293  	rsc, ack, err := qt.Handler(ctx, qt)
   294  	close(quitC)
   295  
   296  	if ack {
   297  		// Delete the message
   298  		svc.DeleteMessage(&sqs.DeleteMessageInput{
   299  			QueueUrl:      &url,
   300  			ReceiptHandle: msgs.Messages[0].ReceiptHandle,
   301  		})
   302  		resource = rsc
   303  	} else {
   304  		// Set visibility timeout to 0, in otherwords Nack the message
   305  		visTimeout = 0
   306  		svc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
   307  			QueueUrl:          &url,
   308  			ReceiptHandle:     msgs.Messages[0].ReceiptHandle,
   309  			VisibilityTimeout: &visTimeout,
   310  		})
   311  	}
   312  
   313  	return true, resource, err
   314  }
   315  
   316  // HasWork will look at the SQS queue to see if there is any pending work.  The function
   317  // is called in an attempt to see if there is any point in processing new work without a
   318  // lot of overhead.  In the case of SQS at the moment we always assume there is work.
   319  //
   320  func (sq *SQS) HasWork(ctx context.Context, subscription string) (hasWork bool, err kv.Error) {
   321  	return true, nil
   322  }
   323  
   324  // Responder is used to open a connection to an existing response queue if
   325  // one was made available and also to provision a channel into which the
   326  // runner can place report messages
   327  func (sq *SQS) Responder(ctx context.Context, subscription string) (sender chan *runnerReports.Report, err kv.Error) {
   328  	sender = make(chan *runnerReports.Report, 1)
   329  	// Open the queue and if this cannot be done exit with the error
   330  	go func() {
   331  		for {
   332  			select {
   333  			case <-sender:
   334  				continue
   335  			case <-ctx.Done():
   336  				return
   337  			}
   338  		}
   339  	}()
   340  	return sender, err
   341  }