github.com/defang-io/defang/src@v0.0.0-20240505002154-bdf411911834/pkg/clouds/aws/ecs/logs.go (about)

     1  package ecs
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
    13  	"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
    14  	"github.com/aws/aws-sdk-go-v2/service/ecs"
    15  	ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types"
    16  	"github.com/aws/smithy-go/ptr"
    17  	"github.com/defang-io/defang/src/pkg/clouds/aws"
    18  	"github.com/defang-io/defang/src/pkg/clouds/aws/region"
    19  )
    20  
    21  // Task ARN						arn:aws:ecs:us-west-2:123456789012:task/CLUSTER_NAME/2cba912d5eb14ffd926f6992b054f3bf
    22  // Cluster ARN					arn:aws:ecs:us-west-2:123456789012:cluster/CLUSTER_NAME
    23  // LogGroup ARN					arn:aws:logs:us-west-2:123456789012:log-group:/LOG/GROUP/NAME:*
    24  // LogGroup ID					arn:aws:logs:us-west-2:123456789012:log-group:/LOG/GROUP/NAME
    25  // LogStream ("awslogs")		PREFIX/CONTAINER/2cba912d5eb14ffd926f6992b054f3bf
    26  // LogStream ("awsfirelens")	PREFIX/CONTAINER-firelens-2cba912d5eb14ffd926f6992b054f3bf
    27  
    28  type LogStreamInfo struct {
    29  	Prefix    string
    30  	Container string
    31  	Firelens  bool
    32  	TaskID    string
    33  }
    34  
    35  func GetLogStreamInfo(logStream string) *LogStreamInfo {
    36  	parts := strings.Split(logStream, "/")
    37  	switch len(parts) {
    38  	case 3:
    39  		return &LogStreamInfo{
    40  			Prefix:    parts[0],
    41  			Container: parts[1],
    42  			Firelens:  false,
    43  			TaskID:    parts[2],
    44  		}
    45  	case 2:
    46  		firelensParts := strings.Split(parts[1], "-")
    47  		if len(firelensParts) != 3 || firelensParts[1] != "firelens" {
    48  			return nil
    49  		}
    50  		return &LogStreamInfo{
    51  			Prefix:    parts[0],
    52  			Container: firelensParts[0],
    53  			Firelens:  true,
    54  			TaskID:    firelensParts[2],
    55  		}
    56  	default:
    57  		return nil
    58  	}
    59  }
    60  
    61  func getLogGroupIdentifier(arnOrId string) string {
    62  	return strings.TrimSuffix(arnOrId, ":*")
    63  }
    64  
    65  func TailLogGroups(ctx context.Context, since time.Time, logGroups ...LogGroupInput) (EventStream, error) {
    66  	child, cancel := context.WithCancel(ctx)
    67  	var cs = collectionStream{
    68  		cancel: cancel,
    69  		ch:     make(chan types.StartLiveTailResponseStream, 10), // max number of loggroups to query
    70  		ctx:    child,
    71  		errCh:  make(chan error, 1),
    72  	}
    73  
    74  	type pair struct {
    75  		es  EventStream
    76  		lgi LogGroupInput
    77  	}
    78  	var pairs []pair
    79  	var pendingGroups []LogGroupInput
    80  
    81  	sincePending := since
    82  	if sincePending.IsZero() {
    83  		sincePending = time.Now()
    84  	}
    85  	for _, lgi := range logGroups {
    86  		es, err := TailLogGroup(ctx, lgi)
    87  		if err == nil {
    88  			pairs = append(pairs, pair{es, lgi})
    89  			continue
    90  		}
    91  
    92  		var resourceNotFound *types.ResourceNotFoundException
    93  		if !errors.As(err, &resourceNotFound) {
    94  			return nil, err
    95  		}
    96  		pendingGroups = append(pendingGroups, lgi)
    97  	}
    98  
    99  	// Start goroutines to wait for the log group to be created for the resource not found log groups
   100  	for _, lgi := range pendingGroups {
   101  		cs.wg.Add(1)
   102  		go func(lgi LogGroupInput) {
   103  			defer cs.wg.Done()
   104  			ticker := time.NewTicker(time.Second)
   105  			defer ticker.Stop()
   106  
   107  			for {
   108  				select {
   109  				case <-cs.ctx.Done():
   110  					return
   111  				case <-ticker.C:
   112  					es, err := TailLogGroup(cs.ctx, lgi)
   113  					if err == nil {
   114  						cs.addAndStart(es, sincePending, lgi)
   115  						return
   116  					}
   117  					var resourceNotFound *types.ResourceNotFoundException
   118  					if !errors.As(err, &resourceNotFound) {
   119  						cs.errCh <- err
   120  						return
   121  					}
   122  				}
   123  			}
   124  		}(lgi)
   125  	}
   126  
   127  	// Only add and start watching the streams if there were no errors, prevent lingering goroutines
   128  	for _, s := range pairs {
   129  		cs.addAndStart(s.es, since, s.lgi)
   130  	}
   131  
   132  	return &cs, nil
   133  }
   134  
// LogGroupInput is like cloudwatchlogs.StartLiveTailInput but with only one loggroup and one logstream prefix.
type LogGroupInput struct {
	// LogGroupARN is the ARN or identifier of the log group; a trailing ":*" is
	// stripped before use and the AWS region is derived from it.
	LogGroupARN string
	// LogStreamNames is passed through unchanged as the list of log streams.
	LogStreamNames []string
	// LogStreamNamePrefix restricts results by stream name prefix; ignored when empty.
	LogStreamNamePrefix string
	// LogEventFilterPattern is the live-tail event filter pattern; ignored when empty.
	LogEventFilterPattern string
}
   142  
   143  func TailLogGroup(ctx context.Context, input LogGroupInput) (EventStream, error) {
   144  	var pattern *string
   145  	if input.LogEventFilterPattern != "" {
   146  		pattern = &input.LogEventFilterPattern
   147  	}
   148  	var prefixes []string
   149  	if input.LogStreamNamePrefix != "" {
   150  		prefixes = []string{input.LogStreamNamePrefix}
   151  	}
   152  	return startTail(ctx, &cloudwatchlogs.StartLiveTailInput{
   153  		LogGroupIdentifiers:   []string{getLogGroupIdentifier(input.LogGroupARN)},
   154  		LogStreamNames:        input.LogStreamNames,
   155  		LogStreamNamePrefixes: prefixes,
   156  		LogEventFilterPattern: pattern,
   157  	})
   158  }
   159  
   160  func Query(ctx context.Context, lgi LogGroupInput, start time.Time, end time.Time) ([]LogEvent, error) {
   161  	region := region.FromArn(lgi.LogGroupARN)
   162  	cfg, err := aws.LoadDefaultConfig(ctx, region)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  
   167  	logGroupIdentifier := getLogGroupIdentifier(lgi.LogGroupARN)
   168  	var prefix *string
   169  	if lgi.LogStreamNamePrefix != "" {
   170  		prefix = &lgi.LogStreamNamePrefix
   171  	}
   172  	cw := cloudwatchlogs.NewFromConfig(cfg)
   173  	fleo, err := cw.FilterLogEvents(ctx, &cloudwatchlogs.FilterLogEventsInput{
   174  		StartTime:           ptr.Int64(start.UnixMilli()),
   175  		EndTime:             ptr.Int64(end.UnixMilli()),
   176  		LogGroupIdentifier:  &logGroupIdentifier,
   177  		LogStreamNamePrefix: prefix,
   178  		LogStreamNames:      lgi.LogStreamNames,
   179  	})
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	events := make([]LogEvent, len(fleo.Events))
   184  	for i, e := range fleo.Events {
   185  		events[i] = LogEvent{
   186  			IngestionTime:      e.IngestionTime,
   187  			LogGroupIdentifier: &logGroupIdentifier,
   188  			Message:            e.Message,
   189  			Timestamp:          e.Timestamp,
   190  			LogStreamName:      e.LogStreamName,
   191  		}
   192  	}
   193  	// TODO: handle pagination using NextToken
   194  	return events, nil
   195  }
   196  
   197  func startTail(ctx context.Context, slti *cloudwatchlogs.StartLiveTailInput) (EventStream, error) {
   198  	region := region.FromArn(slti.LogGroupIdentifiers[0]) // must have at least one log group
   199  	cfg, err := aws.LoadDefaultConfig(ctx, region)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	cw := cloudwatchlogs.NewFromConfig(cfg)
   205  	slto, err := cw.StartLiveTail(ctx, slti)
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	// if !since.IsZero() {
   211  	// 	if events, err := Query(ctx, slti.LogGroupIdentifiers[0], since, time.Now()); err == nil {
   212  	// 		slto.Events <- &types.StartLiveTailResponseStreamMemberSessionUpdate{
   213  	// 			Value: types.LiveTailSessionUpdate{
   214  	// 				SessionResults: events,
   215  	// 			},
   216  	// 		}
   217  	// 	}
   218  	// }
   219  
   220  	return slto.GetStream(), nil
   221  }
   222  
   223  func GetTaskStatus(ctx context.Context, taskArn TaskArn) error {
   224  	region := region.FromArn(*taskArn)
   225  	cluster, taskID := SplitClusterTask(taskArn)
   226  	return getTaskStatus(ctx, region, cluster, taskID)
   227  }
   228  
   229  func isTaskTerminalState(status string) bool {
   230  	// From https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-lifecycle-explanation.html
   231  	switch status {
   232  	case "DELETED", "STOPPED", "DEPROVISIONING":
   233  		return true
   234  	default:
   235  		return false // we might still get logs
   236  	}
   237  }
   238  
   239  func getTaskStatus(ctx context.Context, region aws.Region, cluster, taskId string) error {
   240  	cfg, err := aws.LoadDefaultConfig(ctx, region)
   241  	if err != nil {
   242  		return err
   243  	}
   244  	ecsClient := ecs.NewFromConfig(cfg)
   245  
   246  	// Use DescribeTasks API to check if the task is still running (same as ecs.NewTasksStoppedWaiter)
   247  	ti, _ := ecsClient.DescribeTasks(ctx, &ecs.DescribeTasksInput{
   248  		Cluster: &cluster,
   249  		Tasks:   []string{taskId},
   250  	})
   251  	if ti == nil || len(ti.Tasks) == 0 {
   252  		return nil // task doesn't exist yet; TODO: check the actual error from DescribeTasks
   253  	}
   254  	task := ti.Tasks[0]
   255  	if task.LastStatus == nil || !isTaskTerminalState(*task.LastStatus) {
   256  		return nil // still running
   257  	}
   258  
   259  	switch task.StopCode {
   260  	default:
   261  		return taskFailure{string(task.StopCode), *task.StoppedReason}
   262  	case ecsTypes.TaskStopCodeEssentialContainerExited:
   263  		for _, c := range task.Containers {
   264  			if c.ExitCode != nil && *c.ExitCode != 0 {
   265  				reason := fmt.Sprintf("%s with code %d", *task.StoppedReason, *c.ExitCode)
   266  				return taskFailure{string(task.StopCode), reason}
   267  			}
   268  		}
   269  		fallthrough
   270  	case "": // TODO: shouldn't happen
   271  		return io.EOF // Success
   272  	}
   273  }
   274  
   275  func SplitClusterTask(taskArn TaskArn) (string, string) {
   276  	if !strings.HasPrefix(*taskArn, "arn:aws:ecs:") {
   277  		panic("invalid ECS ARN")
   278  	}
   279  	parts := strings.Split(*taskArn, "/")
   280  	if len(parts) != 3 || !strings.HasSuffix(parts[0], ":task") {
   281  		panic("invalid task ARN")
   282  	}
   283  	return parts[1], parts[2]
   284  }
   285  
// LogEvent is an alias for the CloudWatch Logs live tail session log event type.
type LogEvent = types.LiveTailSessionLogEvent
   287  
// EventStream is an interface that represents a stream of events from a call to StartLiveTail
type EventStream interface {
	// Close releases the stream.
	Close() error
	// Events returns the channel on which live tail responses are received.
	Events() <-chan types.StartLiveTailResponseStream
}
   293  
// collectionStream merges the events of several EventStreams into one channel.
type collectionStream struct {
	cancel  context.CancelFunc                     // cancels ctx; called by Close
	ch      chan types.StartLiveTailResponseStream // merged events from all streams
	ctx     context.Context                        // derived from the context passed to TailLogGroups
	errCh   chan error                             // errors reported by background goroutines
	streams []EventStream                          // underlying streams, closed by Close

	lock sync.Mutex     // held while appending to streams in addAndStart
	wg   sync.WaitGroup // tracks background goroutines; Close waits on it
}
   304  
   305  func (c *collectionStream) addAndStart(s EventStream, since time.Time, lgi LogGroupInput) {
   306  	c.lock.Lock()
   307  	defer c.lock.Unlock()
   308  	c.streams = append(c.streams, s)
   309  	c.wg.Add(1)
   310  	go func() {
   311  		defer c.wg.Done()
   312  		if !since.IsZero() {
   313  			// Query the logs between the start time and now
   314  			if events, err := Query(c.ctx, lgi, since, time.Now()); err != nil {
   315  				c.errCh <- err // the caller will likely cancel the context
   316  			} else {
   317  				c.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{
   318  					Value: types.LiveTailSessionUpdate{SessionResults: events},
   319  				}
   320  			}
   321  		}
   322  		for {
   323  			// Double select to make sure context cancellation is not blocked by either the receive or send
   324  			// See: https://stackoverflow.com/questions/60030756/what-does-it-mean-when-one-channel-uses-two-arrows-to-write-to-another-channel
   325  			select {
   326  			case e := <-s.Events(): // blocking
   327  				select {
   328  				case c.ch <- e:
   329  				case <-c.ctx.Done():
   330  					return
   331  				}
   332  			case <-c.ctx.Done(): // blocking
   333  				return
   334  			}
   335  		}
   336  	}()
   337  }
   338  
   339  func (c *collectionStream) Close() error {
   340  	c.cancel()
   341  	c.wg.Wait() // Only close the channels after all goroutines have exited
   342  	close(c.ch)
   343  	close(c.errCh)
   344  
   345  	var errs []error
   346  	for _, s := range c.streams {
   347  		err := s.Close()
   348  		if err != nil {
   349  			errs = append(errs, err)
   350  		}
   351  	}
   352  	return errors.Join(errs...) // nil if no errors
   353  }
   354  
// Events returns the channel on which the merged events of all underlying
// streams are delivered. The channel is closed by Close.
func (c *collectionStream) Events() <-chan types.StartLiveTailResponseStream {
	return c.ch
}
   358  
// Errs returns the channel on which errors from the background goroutines are
// reported. It is not part of the EventStream interface. The channel is closed
// by Close.
func (c *collectionStream) Errs() <-chan error {
	return c.errCh
}
   362  
   363  func GetLogEvents(e types.StartLiveTailResponseStream) ([]LogEvent, error) {
   364  	switch ev := e.(type) {
   365  	case *types.StartLiveTailResponseStreamMemberSessionStart:
   366  		// fmt.Println("session start:", ev.Value.SessionId)
   367  		return nil, nil // ignore start message
   368  	case *types.StartLiveTailResponseStreamMemberSessionUpdate:
   369  		// fmt.Println("session update:", len(ev.Value.SessionResults))
   370  		return ev.Value.SessionResults, nil
   371  	case nil:
   372  		return nil, io.EOF
   373  	default:
   374  		return nil, fmt.Errorf("unexpected event: %T", ev)
   375  	}
   376  }
   377  
   378  func WaitForTask(ctx context.Context, taskArn TaskArn, poll time.Duration) error {
   379  	if taskArn == nil {
   380  		panic("taskArn is nil")
   381  	}
   382  	ticker := time.NewTicker(poll)
   383  	defer ticker.Stop()
   384  	for {
   385  		select {
   386  		case <-ctx.Done():
   387  			// Handle cancellation
   388  			return ctx.Err()
   389  		case <-ticker.C:
   390  			if err := GetTaskStatus(ctx, taskArn); err != nil {
   391  				return err
   392  			}
   393  		}
   394  	}
   395  }