github.com/defang-io/defang/src@v0.0.0-20240505002154-bdf411911834/pkg/clouds/aws/ecs/logs.go (about) 1 package ecs 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" 13 "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" 14 "github.com/aws/aws-sdk-go-v2/service/ecs" 15 ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" 16 "github.com/aws/smithy-go/ptr" 17 "github.com/defang-io/defang/src/pkg/clouds/aws" 18 "github.com/defang-io/defang/src/pkg/clouds/aws/region" 19 ) 20 21 // Task ARN arn:aws:ecs:us-west-2:123456789012:task/CLUSTER_NAME/2cba912d5eb14ffd926f6992b054f3bf 22 // Cluster ARN arn:aws:ecs:us-west-2:123456789012:cluster/CLUSTER_NAME 23 // LogGroup ARN arn:aws:logs:us-west-2:123456789012:log-group:/LOG/GROUP/NAME:* 24 // LogGroup ID arn:aws:logs:us-west-2:123456789012:log-group:/LOG/GROUP/NAME 25 // LogStream ("awslogs") PREFIX/CONTAINER/2cba912d5eb14ffd926f6992b054f3bf 26 // LogStream ("awsfirelens") PREFIX/CONTAINER-firelens-2cba912d5eb14ffd926f6992b054f3bf 27 28 type LogStreamInfo struct { 29 Prefix string 30 Container string 31 Firelens bool 32 TaskID string 33 } 34 35 func GetLogStreamInfo(logStream string) *LogStreamInfo { 36 parts := strings.Split(logStream, "/") 37 switch len(parts) { 38 case 3: 39 return &LogStreamInfo{ 40 Prefix: parts[0], 41 Container: parts[1], 42 Firelens: false, 43 TaskID: parts[2], 44 } 45 case 2: 46 firelensParts := strings.Split(parts[1], "-") 47 if len(firelensParts) != 3 || firelensParts[1] != "firelens" { 48 return nil 49 } 50 return &LogStreamInfo{ 51 Prefix: parts[0], 52 Container: firelensParts[0], 53 Firelens: true, 54 TaskID: firelensParts[2], 55 } 56 default: 57 return nil 58 } 59 } 60 61 func getLogGroupIdentifier(arnOrId string) string { 62 return strings.TrimSuffix(arnOrId, ":*") 63 } 64 65 func TailLogGroups(ctx context.Context, since time.Time, logGroups ...LogGroupInput) (EventStream, error) { 66 child, cancel := context.WithCancel(ctx) 67 var cs = collectionStream{ 68 cancel: cancel, 69 ch: make(chan types.StartLiveTailResponseStream, 10), // max number of loggroups to query 70 ctx: child, 71 errCh: make(chan error, 1), 72 } 73 74 type pair struct { 75 es EventStream 76 lgi LogGroupInput 77 } 78 var pairs []pair 79 var pendingGroups []LogGroupInput 80 81 sincePending := since 82 if sincePending.IsZero() { 83 sincePending = time.Now() 84 } 85 for _, lgi := range logGroups { 86 es, err := TailLogGroup(ctx, lgi) 87 if err == nil { 88 pairs = append(pairs, pair{es, lgi}) 89 continue 90 } 91 92 var resourceNotFound *types.ResourceNotFoundException 93 if !errors.As(err, &resourceNotFound) { 94 return nil, err 95 } 96 pendingGroups = append(pendingGroups, lgi) 97 } 98 99 // Start goroutines to wait for the log group to be created for the resource not found log groups 100 for _, lgi := range pendingGroups { 101 cs.wg.Add(1) 102 go func(lgi LogGroupInput) { 103 defer cs.wg.Done() 104 ticker := time.NewTicker(time.Second) 105 defer ticker.Stop() 106 107 for { 108 select { 109 case <-cs.ctx.Done(): 110 return 111 case <-ticker.C: 112 es, err := TailLogGroup(cs.ctx, lgi) 113 if err == nil { 114 cs.addAndStart(es, sincePending, lgi) 115 return 116 } 117 var resourceNotFound *types.ResourceNotFoundException 118 if !errors.As(err, &resourceNotFound) { 119 cs.errCh <- err 120 return 121 } 122 } 123 } 124 }(lgi) 125 } 126 127 // Only add and start watching the streams if there were no errors, prevent lingering goroutines 128 for _, s := range pairs { 129 cs.addAndStart(s.es, since, s.lgi) 130 } 131 132 return &cs, nil 133 } 134 135 // LogGroupInput is like cloudwatchlogs.StartLiveTailInput but with only one loggroup and one logstream prefix. 136 type LogGroupInput struct { 137 LogGroupARN string 138 LogStreamNames []string 139 LogStreamNamePrefix string 140 LogEventFilterPattern string 141 } 142 143 func TailLogGroup(ctx context.Context, input LogGroupInput) (EventStream, error) { 144 var pattern *string 145 if input.LogEventFilterPattern != "" { 146 pattern = &input.LogEventFilterPattern 147 } 148 var prefixes []string 149 if input.LogStreamNamePrefix != "" { 150 prefixes = []string{input.LogStreamNamePrefix} 151 } 152 return startTail(ctx, &cloudwatchlogs.StartLiveTailInput{ 153 LogGroupIdentifiers: []string{getLogGroupIdentifier(input.LogGroupARN)}, 154 LogStreamNames: input.LogStreamNames, 155 LogStreamNamePrefixes: prefixes, 156 LogEventFilterPattern: pattern, 157 }) 158 } 159 160 func Query(ctx context.Context, lgi LogGroupInput, start time.Time, end time.Time) ([]LogEvent, error) { 161 region := region.FromArn(lgi.LogGroupARN) 162 cfg, err := aws.LoadDefaultConfig(ctx, region) 163 if err != nil { 164 return nil, err 165 } 166 167 logGroupIdentifier := getLogGroupIdentifier(lgi.LogGroupARN) 168 var prefix *string 169 if lgi.LogStreamNamePrefix != "" { 170 prefix = &lgi.LogStreamNamePrefix 171 } 172 cw := cloudwatchlogs.NewFromConfig(cfg) 173 fleo, err := cw.FilterLogEvents(ctx, &cloudwatchlogs.FilterLogEventsInput{ 174 StartTime: ptr.Int64(start.UnixMilli()), 175 EndTime: ptr.Int64(end.UnixMilli()), 176 LogGroupIdentifier: &logGroupIdentifier, 177 LogStreamNamePrefix: prefix, 178 LogStreamNames: lgi.LogStreamNames, 179 }) 180 if err != nil { 181 return nil, err 182 } 183 events := make([]LogEvent, len(fleo.Events)) 184 for i, e := range fleo.Events { 185 events[i] = LogEvent{ 186 IngestionTime: e.IngestionTime, 187 LogGroupIdentifier: &logGroupIdentifier, 188 Message: e.Message, 189 Timestamp: e.Timestamp, 190 LogStreamName: e.LogStreamName, 191 } 192 } 193 // TODO: handle pagination using NextToken 194 return events, nil 195 } 196 197 func startTail(ctx context.Context, slti *cloudwatchlogs.StartLiveTailInput) (EventStream, error) { 198 region := region.FromArn(slti.LogGroupIdentifiers[0]) // must have at least one log group 199 cfg, err := aws.LoadDefaultConfig(ctx, region) 200 if err != nil { 201 return nil, err 202 } 203 204 cw := cloudwatchlogs.NewFromConfig(cfg) 205 slto, err := cw.StartLiveTail(ctx, slti) 206 if err != nil { 207 return nil, err 208 } 209 210 // if !since.IsZero() { 211 // if events, err := Query(ctx, slti.LogGroupIdentifiers[0], since, time.Now()); err == nil { 212 // slto.Events <- &types.StartLiveTailResponseStreamMemberSessionUpdate{ 213 // Value: types.LiveTailSessionUpdate{ 214 // SessionResults: events, 215 // }, 216 // } 217 // } 218 // } 219 220 return slto.GetStream(), nil 221 } 222 223 func GetTaskStatus(ctx context.Context, taskArn TaskArn) error { 224 region := region.FromArn(*taskArn) 225 cluster, taskID := SplitClusterTask(taskArn) 226 return getTaskStatus(ctx, region, cluster, taskID) 227 } 228 229 func isTaskTerminalState(status string) bool { 230 // From https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-lifecycle-explanation.html 231 switch status { 232 case "DELETED", "STOPPED", "DEPROVISIONING": 233 return true 234 default: 235 return false // we might still get logs 236 } 237 } 238 239 func getTaskStatus(ctx context.Context, region aws.Region, cluster, taskId string) error { 240 cfg, err := aws.LoadDefaultConfig(ctx, region) 241 if err != nil { 242 return err 243 } 244 ecsClient := ecs.NewFromConfig(cfg) 245 246 // Use DescribeTasks API to check if the task is still running (same as ecs.NewTasksStoppedWaiter) 247 ti, _ := ecsClient.DescribeTasks(ctx, &ecs.DescribeTasksInput{ 248 Cluster: &cluster, 249 Tasks: []string{taskId}, 250 }) 251 if ti == nil || len(ti.Tasks) == 0 { 252 return nil // task doesn't exist yet; TODO: check the actual error from DescribeTasks 253 } 254 task := ti.Tasks[0] 255 if task.LastStatus == nil || !isTaskTerminalState(*task.LastStatus) { 256 return nil // still running 257 } 258 259 switch task.StopCode { 260 default: 261 return taskFailure{string(task.StopCode), *task.StoppedReason} 262 case ecsTypes.TaskStopCodeEssentialContainerExited: 263 for _, c := range task.Containers { 264 if c.ExitCode != nil && *c.ExitCode != 0 { 265 reason := fmt.Sprintf("%s with code %d", *task.StoppedReason, *c.ExitCode) 266 return taskFailure{string(task.StopCode), reason} 267 } 268 } 269 fallthrough 270 case "": // TODO: shouldn't happen 271 return io.EOF // Success 272 } 273 } 274 275 func SplitClusterTask(taskArn TaskArn) (string, string) { 276 if !strings.HasPrefix(*taskArn, "arn:aws:ecs:") { 277 panic("invalid ECS ARN") 278 } 279 parts := strings.Split(*taskArn, "/") 280 if len(parts) != 3 || !strings.HasSuffix(parts[0], ":task") { 281 panic("invalid task ARN") 282 } 283 return parts[1], parts[2] 284 } 285 286 type LogEvent = types.LiveTailSessionLogEvent 287 288 // EventStream is an interface that represents a stream of events from a call to StartLiveTail 289 type EventStream interface { 290 Close() error 291 Events() <-chan types.StartLiveTailResponseStream 292 } 293 294 type collectionStream struct { 295 cancel context.CancelFunc 296 ch chan types.StartLiveTailResponseStream 297 ctx context.Context // derived from the context passed to TailLogGroups 298 errCh chan error 299 streams []EventStream 300 301 lock sync.Mutex 302 wg sync.WaitGroup 303 } 304 305 func (c *collectionStream) addAndStart(s EventStream, since time.Time, lgi LogGroupInput) { 306 c.lock.Lock() 307 defer c.lock.Unlock() 308 c.streams = append(c.streams, s) 309 c.wg.Add(1) 310 go func() { 311 defer c.wg.Done() 312 if !since.IsZero() { 313 // Query the logs between the start time and now 314 if events, err := Query(c.ctx, lgi, since, time.Now()); err != nil { 315 c.errCh <- err // the caller will likely cancel the context 316 } else { 317 c.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{ 318 Value: types.LiveTailSessionUpdate{SessionResults: events}, 319 } 320 } 321 } 322 for { 323 // Double select to make sure context cancellation is not blocked by either the receive or send 324 // See: https://stackoverflow.com/questions/60030756/what-does-it-mean-when-one-channel-uses-two-arrows-to-write-to-another-channel 325 select { 326 case e := <-s.Events(): // blocking 327 select { 328 case c.ch <- e: 329 case <-c.ctx.Done(): 330 return 331 } 332 case <-c.ctx.Done(): // blocking 333 return 334 } 335 } 336 }() 337 } 338 339 func (c *collectionStream) Close() error { 340 c.cancel() 341 c.wg.Wait() // Only close the channels after all goroutines have exited 342 close(c.ch) 343 close(c.errCh) 344 345 var errs []error 346 for _, s := range c.streams { 347 err := s.Close() 348 if err != nil { 349 errs = append(errs, err) 350 } 351 } 352 return errors.Join(errs...) // nil if no errors 353 } 354 355 func (c *collectionStream) Events() <-chan types.StartLiveTailResponseStream { 356 return c.ch 357 } 358 359 func (c *collectionStream) Errs() <-chan error { 360 return c.errCh 361 } 362 363 func GetLogEvents(e types.StartLiveTailResponseStream) ([]LogEvent, error) { 364 switch ev := e.(type) { 365 case *types.StartLiveTailResponseStreamMemberSessionStart: 366 // fmt.Println("session start:", ev.Value.SessionId) 367 return nil, nil // ignore start message 368 case *types.StartLiveTailResponseStreamMemberSessionUpdate: 369 // fmt.Println("session update:", len(ev.Value.SessionResults)) 370 return ev.Value.SessionResults, nil 371 case nil: 372 return nil, io.EOF 373 default: 374 return nil, fmt.Errorf("unexpected event: %T", ev) 375 } 376 } 377 378 func WaitForTask(ctx context.Context, taskArn TaskArn, poll time.Duration) error { 379 if taskArn == nil { 380 panic("taskArn is nil") 381 } 382 ticker := time.NewTicker(poll) 383 defer ticker.Stop() 384 for { 385 select { 386 case <-ctx.Done(): 387 // Handle cancellation 388 return ctx.Err() 389 case <-ticker.C: 390 if err := GetTaskStatus(ctx, taskArn); err != nil { 391 return err 392 } 393 } 394 } 395 }