github.com/defang-io/defang/src@v0.0.0-20240505002154-bdf411911834/pkg/clouds/aws/ecs/run.go (about)

     1  package ecs
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"os"
     7  	"time"
     8  
     9  	"github.com/aws/aws-sdk-go-v2/aws"
    10  	"github.com/aws/aws-sdk-go-v2/service/ec2"
    11  	ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
    12  	"github.com/aws/aws-sdk-go-v2/service/ecs"
    13  	"github.com/aws/aws-sdk-go-v2/service/ecs/types"
    14  	"github.com/aws/smithy-go/ptr"
    15  )
    16  
    17  const taskCount = 1
    18  
    19  func (a *AwsEcs) PopulateVPCandSubnetID(ctx context.Context, vpcID, subnetID string) error {
    20  	cfg, err := a.LoadConfig(ctx)
    21  	if err != nil {
    22  		return err
    23  	}
    24  
    25  	if vpcID != "" && subnetID == "" {
    26  		subnetID, err = getPublicSubnetId(ctx, cfg, vpcID)
    27  	} else if vpcID == "" && subnetID != "" {
    28  		vpcID, err = getSubnetVPCId(ctx, cfg, subnetID)
    29  	}
    30  
    31  	a.VpcID = vpcID
    32  	a.SubNetID = subnetID
    33  	return err
    34  }
    35  
    36  func (a *AwsEcs) Run(ctx context.Context, env map[string]string, cmd ...string) (TaskArn, error) {
    37  	// a.Refresh(ctx)
    38  
    39  	cfg, err := a.LoadConfig(ctx)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	var pairs []types.KeyValuePair
    45  	for k, v := range env {
    46  		pairs = append(pairs, types.KeyValuePair{
    47  			Name:  ptr.String(k),
    48  			Value: ptr.String(v),
    49  		})
    50  	}
    51  
    52  	// stsClient := sts.NewFromConfig(cfg)
    53  	// cred, err := stsClient.GetCallerIdentity(ctx, nil)
    54  	// if err != nil {
    55  	// 	return nil, err
    56  	// }
    57  
    58  	securityGroups := []string{a.SecurityGroupID} // TODO: only if ports are mapped
    59  	rti := ecs.RunTaskInput{
    60  		Count:          ptr.Int32(taskCount),
    61  		LaunchType:     types.LaunchTypeFargate,
    62  		TaskDefinition: ptr.String(a.TaskDefARN),
    63  		PropagateTags:  types.PropagateTagsTaskDefinition,
    64  		Cluster:        ptr.String(a.ClusterName),
    65  		NetworkConfiguration: &types.NetworkConfiguration{
    66  			AwsvpcConfiguration: &types.AwsVpcConfiguration{
    67  				AssignPublicIp: types.AssignPublicIpEnabled, // only works with public subnets
    68  				Subnets:        []string{a.SubNetID},        // TODO: make configurable; must this match the VPC of the SecGroup?
    69  				SecurityGroups: securityGroups,
    70  			},
    71  		},
    72  		Overrides: &types.TaskOverride{
    73  			// Cpu:   ptr.String("256"),
    74  			// Memory: ptr.String("512"),
    75  			// TaskRoleArn: cred.Arn; TODO: default to caller identity; needs trust + iam:PassRole
    76  			ContainerOverrides: []types.ContainerOverride{
    77  				{
    78  					Name:        ptr.String(ContainerName),
    79  					Command:     cmd,
    80  					Environment: pairs,
    81  					// ResourceRequirements:; TODO: make configurable, support GPUs
    82  					// EnvironmentFiles: ,
    83  				},
    84  			},
    85  		},
    86  		Tags: []types.Tag{ //TODO: add tags to the task
    87  			{
    88  				Key:   ptr.String("StartedAt"),
    89  				Value: ptr.String(time.Now().Format(time.RFC3339)),
    90  			},
    91  			{
    92  				Key:   ptr.String("StartedBy"),
    93  				Value: ptr.String(os.Getenv("USER")),
    94  			},
    95  		},
    96  	}
    97  
    98  	ecsOutput, err := ecs.NewFromConfig(cfg).RunTask(ctx, &rti)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	failures := make([]error, len(ecsOutput.Failures))
   103  	for i, f := range ecsOutput.Failures {
   104  		failures[i] = taskFailure{*f.Reason, *f.Detail}
   105  	}
   106  	if err := errors.Join(failures...); err != nil || len(ecsOutput.Tasks) == 0 {
   107  		return nil, err
   108  	}
   109  	// bytes, _ := json.MarshalIndent(ecsOutput.Tasks, "", "  ")
   110  	// println(string(bytes))
   111  	return TaskArn(ecsOutput.Tasks[0].TaskArn), nil
   112  }
   113  
   114  func getPublicSubnetId(ctx context.Context, cfg aws.Config, vpcId string) (string, error) {
   115  	subnetsOutput, err := ec2.NewFromConfig(cfg).DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{
   116  		Filters: []ec2types.Filter{
   117  			{
   118  				Name:   ptr.String("vpc-id"),
   119  				Values: []string{vpcId},
   120  			},
   121  			{
   122  				Name:   ptr.String("map-public-ip-on-launch"),
   123  				Values: []string{"true"},
   124  			},
   125  		},
   126  	})
   127  	if err != nil {
   128  		return "", err
   129  	}
   130  	return *subnetsOutput.Subnets[0].SubnetId, nil // TODO: make configurable/deterministic
   131  }
   132  
   133  func getSubnetVPCId(ctx context.Context, cfg aws.Config, subnetId string) (string, error) {
   134  	subnetsOutput, err := ec2.NewFromConfig(cfg).DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{
   135  		SubnetIds: []string{subnetId},
   136  	})
   137  	if err != nil {
   138  		return "", err
   139  	}
   140  	return *subnetsOutput.Subnets[0].VpcId, nil // TODO: make configurable/deterministic
   141  }
   142  
   143  type taskFailure struct {
   144  	Reason string
   145  	Detail string
   146  }
   147  
   148  func (t taskFailure) Error() string {
   149  	return t.Reason + ": " + t.Detail
   150  }
   151  
   152  /*
   153  func getAwsEnv() awsEnv {
   154  	creds := getEcsCreds()
   155  	return map[string]string{
   156  		"AWS_ACCESS_KEY_ID":     creds.AccessKeyId,
   157  		"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
   158  		"AWS_SESSION_TOKEN":     creds.Token,
   159  		// "AWS_REGION": "us-west-2", should not be needed because it's in the stack config and/or env
   160  	}
   161  }
   162  
   163  var (
   164  	ecsCredsUrl = "http://169.254.170.2" + os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
   165  )
   166  
   167  type ecsCreds struct {
   168  	AccessKeyId     string
   169  	Expiration      string
   170  	RoleArn         string
   171  	SecretAccessKey string
   172  	Token           string
   173  }
   174  
   175  func getEcsCreds() (creds ecsCreds) {
   176  	// Grab the ECS credentials from the metadata service at AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
   177  	res, err := http.Get(ecsCredsUrl)
   178  	if err != nil {
   179  		log.Panicln(err)
   180  	}
   181  	defer res.Body.Close()
   182  	if err := json.NewDecoder(res.Body).Decode(&creds); err != nil {
   183  		log.Panicln(err)
   184  	}
   185  	return creds
   186  }
   187  */