github.com/sentienttechnologies/studio-go-runner@v0.0.0-20201118202441-6d21f2ced8ee/internal/runner/sqs.go (about) 1 // Copyright 2018-2020 (c) Cognizant Digital Business, Evolutionary AI. All rights reserved. Issued under the Apache 2.0 License. 2 3 package runner 4 5 // This file contains the implementation of AWS SQS message queues 6 // as they are used by studioML 7 8 import ( 9 "context" 10 "flag" 11 "fmt" 12 "net/url" 13 "regexp" 14 "runtime/debug" 15 "strings" 16 "time" 17 18 "github.com/aws/aws-sdk-go/aws" 19 "github.com/aws/aws-sdk-go/aws/session" 20 "github.com/aws/aws-sdk-go/service/sqs" 21 22 runnerReports "github.com/leaf-ai/studio-go-runner/internal/gen/dev.cognizant_dev.ai/genproto/studio-go-runner/reports/v1" 23 24 "github.com/go-stack/stack" 25 "github.com/jjeffery/kv" // MIT License 26 ) 27 28 var ( 29 sqsTimeoutOpt = flag.Duration("sqs-timeout", time.Duration(15*time.Second), "the period of time for discrete SQS operations to use for timeouts") 30 ) 31 32 // SQS encapsulates an AWS based SQS queue and associated it with a project 33 // 34 type SQS struct { 35 project string // Fully qualified SQS queue reference 36 creds *AWSCred // AWS credentials for access the queue 37 wrapper *Wrapper // Decryption information for messages with encrypted payloads 38 } 39 40 // NewSQS creates an SQS data structure using set set of credentials (creds) for 41 // an sqs queue (sqs) 42 // 43 func NewSQS(project string, creds string, wrapper *Wrapper) (sqs *SQS, err kv.Error) { 44 // Use the creds directory to locate all of the credentials for AWS within 45 // a hierarchy of directories 46 47 awsCreds, err := AWSExtractCreds(strings.Split(creds, ",")) 48 if err != nil { 49 return nil, err 50 } 51 52 return &SQS{ 53 project: project, 54 creds: awsCreds, 55 wrapper: wrapper, 56 }, nil 57 } 58 59 // GetSQSProjects can be used to get a list of the SQS servers and the main URLs that are accessible to them 60 func GetSQSProjects(credFiles []string) (urls map[string]struct{}, err kv.Error) { 61 62 sqs, err := NewSQS("aws_probe", strings.Join(credFiles, ","), nil) 63 if err != nil { 64 return urls, err 65 } 66 found, err := sqs.refresh(nil, nil) 67 if err != nil { 68 return urls, kv.Wrap(err, "failed to refresh sqs").With("stack", stack.Trace().TrimRuntime()) 69 } 70 71 urls = make(map[string]struct{}, len(found)) 72 for _, urlStr := range found { 73 qURL, err := url.Parse(urlStr) 74 if err != nil { 75 continue 76 } 77 segments := strings.Split(qURL.Path, "/") 78 qURL.Path = strings.Join(segments[:len(segments)-1], "/") 79 urls[qURL.String()] = struct{}{} 80 } 81 82 return urls, nil 83 } 84 85 func (sq *SQS) listQueues(qNameMatch *regexp.Regexp, qNameMismatch *regexp.Regexp) (queues *sqs.ListQueuesOutput, err kv.Error) { 86 87 sess, errGo := session.NewSessionWithOptions(session.Options{ 88 Config: aws.Config{ 89 Region: aws.String(sq.creds.Region), 90 Credentials: sq.creds.Creds, 91 CredentialsChainVerboseErrors: aws.Bool(true), 92 }, 93 Profile: "default", 94 }) 95 96 if errGo != nil { 97 return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds) 98 } 99 100 // Create a SQS service client. 101 svc := sqs.New(sess) 102 103 ctx, cancel := context.WithTimeout(context.Background(), *sqsTimeoutOpt) 104 defer cancel() 105 106 listParam := &sqs.ListQueuesInput{} 107 108 qs, errGo := svc.ListQueuesWithContext(ctx, listParam) 109 if errGo != nil { 110 return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds) 111 } 112 113 queues = &sqs.ListQueuesOutput{ 114 QueueUrls: []*string{}, 115 } 116 117 for _, qURL := range qs.QueueUrls { 118 if qURL == nil { 119 continue 120 } 121 fullURL, errGo := url.Parse(*qURL) 122 if errGo != nil { 123 return nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds) 124 } 125 paths := strings.Split(fullURL.Path, "/") 126 if qNameMismatch != nil { 127 if qNameMismatch.MatchString(paths[len(paths)-1]) { 128 fmt.Println("dropped", paths[len(paths)-1], qNameMismatch.String()) 129 continue 130 } 131 } 132 if qNameMatch != nil { 133 if !qNameMatch.MatchString(paths[len(paths)-1]) { 134 fmt.Println("ignored", paths[len(paths)-1], qNameMatch.String()) 135 continue 136 } 137 } 138 queues.QueueUrls = append(queues.QueueUrls, qURL) 139 } 140 return queues, nil 141 } 142 143 func (sq *SQS) refresh(qNameMatch *regexp.Regexp, qNameMismatch *regexp.Regexp) (known []string, err kv.Error) { 144 145 known = []string{} 146 147 result, err := sq.listQueues(qNameMatch, qNameMismatch) 148 if err != nil { 149 return known, err 150 } 151 152 // As these are pointers, printing them out directly would not be useful. 153 for _, url := range result.QueueUrls { 154 // Avoid dereferencing a nil pointer. 155 if url == nil { 156 continue 157 } 158 known = append(known, *url) 159 } 160 return known, nil 161 } 162 163 // Refresh uses a regular expression to obtain matching queues from 164 // the configured SQS server on AWS (sqs). 165 // 166 func (sq *SQS) Refresh(ctx context.Context, qNameMatch *regexp.Regexp, qNameMismatch *regexp.Regexp) (known map[string]interface{}, err kv.Error) { 167 168 found, err := sq.refresh(qNameMatch, qNameMismatch) 169 if err != nil { 170 return known, err 171 } 172 173 known = make(map[string]interface{}, len(found)) 174 for _, urlStr := range found { 175 qURL, err := url.Parse(urlStr) 176 if err != nil { 177 continue 178 } 179 segments := strings.Split(qURL.Path, "/") 180 known[sq.creds.Region+":"+segments[len(segments)-1]] = sq.creds 181 } 182 183 return known, nil 184 } 185 186 // Exists tests for the presence of a subscription, typically a queue name 187 // on the configured sqs server. 188 // 189 func (sq *SQS) Exists(ctx context.Context, subscription string) (exists bool, err kv.Error) { 190 191 queues, err := sq.listQueues(nil, nil) 192 if err != nil { 193 return true, err 194 } 195 196 for _, q := range queues.QueueUrls { 197 if q != nil { 198 if strings.HasSuffix(subscription, *q) { 199 return true, nil 200 } 201 } 202 } 203 return false, nil 204 } 205 206 // Work is invoked by the queue handling software within the runner to get the 207 // specific queue implementation to process potential work that could be 208 // waiting inside the queue. 209 func (sq *SQS) Work(ctx context.Context, qt *QueueTask) (msgProcessed bool, resource *Resource, err kv.Error) { 210 211 regionUrl := strings.SplitN(qt.Subscription, ":", 2) 212 url := sq.project + "/" + regionUrl[1] 213 214 sess, errGo := session.NewSessionWithOptions(session.Options{ 215 Config: aws.Config{ 216 Region: aws.String(sq.creds.Region), 217 Credentials: sq.creds.Creds, 218 CredentialsChainVerboseErrors: aws.Bool(true), 219 }, 220 Profile: "default", 221 }) 222 223 if errGo != nil { 224 return false, nil, kv.Wrap(errGo).With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds) 225 } 226 227 // Create a SQS service client. 228 svc := sqs.New(sess) 229 230 defer func() { 231 defer func() { 232 if r := recover(); r != nil { 233 fmt.Printf("panic in producer %#v, %s\n", r, string(debug.Stack())) 234 } 235 }() 236 }() 237 238 visTimeout := int64(30) 239 waitTimeout := int64(5) 240 msgs, errGo := svc.ReceiveMessageWithContext(ctx, 241 &sqs.ReceiveMessageInput{ 242 QueueUrl: &url, 243 VisibilityTimeout: &visTimeout, 244 WaitTimeSeconds: &waitTimeout, 245 }) 246 if errGo != nil { 247 return false, nil, kv.Wrap(errGo).With("credentials", sq.creds, "url", url).With("stack", stack.Trace().TrimRuntime()) 248 } 249 if len(msgs.Messages) == 0 { 250 return false, nil, nil 251 } 252 253 // Make sure that the main ctx has not been Done with before continuing 254 select { 255 case <-ctx.Done(): 256 return false, nil, kv.NewError("queue worker cancel received").With("stack", stack.Trace().TrimRuntime()).With("credentials", sq.creds) 257 default: 258 } 259 260 // Start a visbility timeout extender that runs until the work is done 261 // Changing the timeout restarts the timer on the SQS side, for more information 262 // see http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html 263 // 264 quitC := make(chan struct{}) 265 go func() { 266 timeout := time.Duration(int(visTimeout / 2)) 267 for { 268 select { 269 case <-time.After(timeout * time.Second): 270 if _, err := svc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{ 271 QueueUrl: &url, 272 ReceiptHandle: msgs.Messages[0].ReceiptHandle, 273 VisibilityTimeout: &visTimeout, 274 }); err != nil { 275 // Once the 1/2 way mark is reached continue to try to change the 276 // visibility at decreasing intervals until we finish the job 277 if timeout.Seconds() > 5.0 { 278 timeout = time.Duration(timeout / 2) 279 } 280 } 281 case <-quitC: 282 return 283 } 284 } 285 }() 286 287 qt.Msg = nil 288 qt.Msg = []byte(*msgs.Messages[0].Body) 289 290 items := strings.Split(url, "/") 291 qt.ShortQName = items[len(items)-1] 292 293 rsc, ack, err := qt.Handler(ctx, qt) 294 close(quitC) 295 296 if ack { 297 // Delete the message 298 svc.DeleteMessage(&sqs.DeleteMessageInput{ 299 QueueUrl: &url, 300 ReceiptHandle: msgs.Messages[0].ReceiptHandle, 301 }) 302 resource = rsc 303 } else { 304 // Set visibility timeout to 0, in otherwords Nack the message 305 visTimeout = 0 306 svc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{ 307 QueueUrl: &url, 308 ReceiptHandle: msgs.Messages[0].ReceiptHandle, 309 VisibilityTimeout: &visTimeout, 310 }) 311 } 312 313 return true, resource, err 314 } 315 316 // HasWork will look at the SQS queue to see if there is any pending work. The function 317 // is called in an attempt to see if there is any point in processing new work without a 318 // lot of overhead. In the case of SQS at the moment we always assume there is work. 319 // 320 func (sq *SQS) HasWork(ctx context.Context, subscription string) (hasWork bool, err kv.Error) { 321 return true, nil 322 } 323 324 // Responder is used to open a connection to an existing response queue if 325 // one was made available and also to provision a channel into which the 326 // runner can place report messages 327 func (sq *SQS) Responder(ctx context.Context, subscription string) (sender chan *runnerReports.Report, err kv.Error) { 328 sender = make(chan *runnerReports.Report, 1) 329 // Open the queue and if this cannot be done exit with the error 330 go func() { 331 for { 332 select { 333 case <-sender: 334 continue 335 case <-ctx.Done(): 336 return 337 } 338 } 339 }() 340 return sender, err 341 }