github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/security/ticket/aws.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package ticket
     6  
     7  import (
     8  	"errors"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/aws/aws-sdk-go/aws"
    13  	"github.com/aws/aws-sdk-go/aws/client"
    14  	"github.com/aws/aws-sdk-go/aws/credentials"
    15  	"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
    16  	"github.com/aws/aws-sdk-go/aws/session"
    17  	"github.com/aws/aws-sdk-go/service/ec2"
    18  	"github.com/aws/aws-sdk-go/service/ecr"
    19  	"github.com/aws/aws-sdk-go/service/sts"
    20  	"github.com/Schaudge/grailbase/cloud/ec2util"
    21  	"github.com/Schaudge/grailbase/common/log"
    22  	"github.com/Schaudge/grailbase/ttlcache"
    23  )
    24  
    25  type cacheKey struct {
    26  	region  string
    27  	role    string
    28  	session string
    29  }
    30  
    31  // cacheTTL is how long the entries in cache will be considered valid.
    32  const cacheTTL = time.Minute
    33  
    34  var cache = ttlcache.New(cacheTTL)
    35  
    36  func (b *AwsAssumeRoleBuilder) newAwsTicket(ctx *TicketContext) (TicketAwsTicket, error) {
    37  	awsCredentials, err := b.genAwsCredentials(ctx)
    38  
    39  	if err != nil {
    40  		return TicketAwsTicket{}, err
    41  	}
    42  
    43  	return TicketAwsTicket{
    44  		Value: AwsTicket{
    45  			AwsCredentials: awsCredentials,
    46  		},
    47  	}, nil
    48  }
    49  
    50  func (b *AwsAssumeRoleBuilder) newS3Ticket(ctx *TicketContext) (TicketS3Ticket, error) {
    51  	awsCredentials, err := b.genAwsCredentials(ctx)
    52  
    53  	if err != nil {
    54  		return TicketS3Ticket{}, err
    55  	}
    56  
    57  	return TicketS3Ticket{
    58  		Value: S3Ticket{
    59  			AwsCredentials: awsCredentials,
    60  		},
    61  	}, nil
    62  }
    63  
    64  func (b *AwsAssumeRoleBuilder) newEcrTicket(ctx *TicketContext) (TicketEcrTicket, error) {
    65  	log.Debug(ctx.ctx, "generating ECR ticket", "AwsAssumeRoleBuilder", b)
    66  	awsCredentials, err := b.genAwsCredentials(ctx)
    67  
    68  	if err != nil {
    69  		return TicketEcrTicket{}, err
    70  	}
    71  	return TicketEcrTicket{
    72  		Value: newEcrTicket(ctx, awsCredentials),
    73  	}, nil
    74  }
    75  
    76  func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredentials, error) {
    77  	log.Debug(ctx.ctx, "generating AWS credentials", "AwsAssumeRoleBuilder", b)
    78  	empty := AwsCredentials{}
    79  
    80  	sessionName := strings.Replace(ctx.remoteBlessings.String(), ":", ",", -1)
    81  	// AWS session names must be 64 characters or less
    82  	if runes := []rune(sessionName); len(runes) > 64 {
    83  		// Some risk with simple truncation - two large IAM role's would overlap
    84  		// for example. This is mitigated by the format which includes instance id
    85  		// as the last component. Ability to determine exactly which instance made
    86  		// the call will be difficult, but likelihood of 2 instances sharing a prefix
    87  		// is low.
    88  		sessionName = string(runes[0:64])
    89  	}
    90  	key := cacheKey{b.Region, b.Role, sessionName}
    91  	if v, ok := cache.Get(key); ok {
    92  		log.Debug(ctx.ctx, "AWS credentials lookup cache hit", "key", key)
    93  		return v.(AwsCredentials), nil
    94  	}
    95  	log.Debug(ctx.ctx, "AWS credentials lookup cache miss", "key", key)
    96  
    97  	s := ctx.session
    98  	if aws.StringValue(s.Config.Region) != b.Region {
    99  		// This mismatch should be very rare.
   100  		var err error
   101  		s, err = session.NewSession(s.Config.WithRegion(b.Region))
   102  		if err != nil {
   103  			log.Error(ctx.ctx, "error creating AWS session", "err", err.Error())
   104  			return empty, err
   105  		}
   106  	}
   107  
   108  	client := sts.New(s)
   109  	assumeRoleInput := &sts.AssumeRoleInput{
   110  		RoleArn: aws.String(b.Role),
   111  		// TODO(razvanm): the role session name is a string of characters consisting
   112  		// of upper- and lower-case alphanumeric characters with no spaces that can
   113  		// include '=,.@-'. Notably, a blessing can include ':' which is not allowed
   114  		// in here.
   115  		//
   116  		// Reference: http://docs.aws.amazon.com/cli/latest/reference/sts/assume-role.html
   117  		RoleSessionName: aws.String(sessionName),
   118  		DurationSeconds: aws.Int64(int64(b.TtlSec)),
   119  	}
   120  
   121  	assumeRoleOutput, err := client.AssumeRole(assumeRoleInput)
   122  	if err != nil {
   123  		log.Error(ctx.ctx, "error in AssumeRole API call", "key", key)
   124  		return empty, err
   125  	}
   126  
   127  	result := AwsCredentials{
   128  		Region:          b.Region,
   129  		AccessKeyId:     aws.StringValue(assumeRoleOutput.Credentials.AccessKeyId),
   130  		SecretAccessKey: aws.StringValue(assumeRoleOutput.Credentials.SecretAccessKey),
   131  		SessionToken:    aws.StringValue(assumeRoleOutput.Credentials.SessionToken),
   132  		Expiration:      assumeRoleOutput.Credentials.Expiration.Format(time.RFC3339Nano),
   133  	}
   134  
   135  	log.Debug(ctx.ctx, "adding AWS credentials to cache", "key", key)
   136  	cache.Set(key, result)
   137  
   138  	return result, nil
   139  }
   140  
   141  func (b *AwsSessionBuilder) newAwsTicket(ctx *TicketContext) (TicketAwsTicket, error) {
   142  	awsCredentials, err := b.genAwsSession(ctx)
   143  
   144  	if err != nil {
   145  		return TicketAwsTicket{}, err
   146  	}
   147  
   148  	return TicketAwsTicket{
   149  		Value: AwsTicket{
   150  			AwsCredentials: awsCredentials,
   151  		},
   152  	}, nil
   153  }
   154  
   155  func (b *AwsSessionBuilder) newS3Ticket(ctx *TicketContext) (TicketS3Ticket, error) {
   156  	awsCredentials, err := b.genAwsSession(ctx)
   157  
   158  	if err != nil {
   159  		return TicketS3Ticket{}, err
   160  	}
   161  
   162  	return TicketS3Ticket{
   163  		Value: S3Ticket{
   164  			AwsCredentials: awsCredentials,
   165  		},
   166  	}, nil
   167  }
   168  
   169  func (b *AwsSessionBuilder) genAwsSession(ctx *TicketContext) (AwsCredentials, error) {
   170  	log.Debug(ctx.ctx, "enerating AWS session", "AwsAssumeRoleBuilder", b.AwsCredentials.AccessKeyId)
   171  	empty := AwsCredentials{}
   172  	awsCredentials := b.AwsCredentials
   173  
   174  	sessionName := strings.Replace(ctx.remoteBlessings.String(), ":", ",", -1)
   175  	// AWS session names must be 64 characters or less
   176  	if runes := []rune(sessionName); len(runes) > 64 {
   177  		// Some risk with simple truncation - two large IAM role's would overlap
   178  		// for example. This is mitigated by the format which includes instance id
   179  		// as the last component. Ability to determine exactly which instance made
   180  		// the call will be difficult, but likelihood of 2 instances sharing a prefix
   181  		// is low.
   182  		sessionName = string(runes[0:64])
   183  	}
   184  	key := cacheKey{awsCredentials.Region, awsCredentials.AccessKeyId, sessionName}
   185  	if v, ok := cache.Get(key); ok {
   186  		log.Debug(ctx.ctx, "AWS session lookup cache hit", "key", key)
   187  		return v.(AwsCredentials), nil
   188  	}
   189  	log.Debug(ctx.ctx, "AWS session lookup cache miss", "key", key)
   190  	s, err := session.NewSession(&aws.Config{
   191  		Region: aws.String(awsCredentials.Region),
   192  		Credentials: credentials.NewStaticCredentials(
   193  			awsCredentials.AccessKeyId,
   194  			awsCredentials.SecretAccessKey,
   195  			awsCredentials.SessionToken),
   196  	})
   197  	if err != nil {
   198  		return empty, err
   199  	}
   200  
   201  	sessionTokenInput := &sts.GetSessionTokenInput{
   202  		DurationSeconds: aws.Int64(int64(b.TtlSec)),
   203  	}
   204  
   205  	client := sts.New(s)
   206  	sessionTokenOutput, err := client.GetSessionToken(sessionTokenInput)
   207  	if err != nil {
   208  		return empty, err
   209  	}
   210  
   211  	result := AwsCredentials{
   212  		Region:          awsCredentials.Region,
   213  		AccessKeyId:     aws.StringValue(sessionTokenOutput.Credentials.AccessKeyId),
   214  		SecretAccessKey: aws.StringValue(sessionTokenOutput.Credentials.SecretAccessKey),
   215  		SessionToken:    aws.StringValue(sessionTokenOutput.Credentials.SessionToken),
   216  		Expiration:      sessionTokenOutput.Credentials.Expiration.Format(time.RFC3339Nano),
   217  	}
   218  
   219  	log.Debug(ctx.ctx, "Adding AWS session to cache", "key", key)
   220  	cache.Set(key, result)
   221  
   222  	return result, nil
   223  }
   224  
   225  func newEcrTicket(ctx *TicketContext, awsCredentials AwsCredentials) EcrTicket {
   226  	empty := EcrTicket{}
   227  	s, err := session.NewSession(&aws.Config{
   228  		Region: aws.String(awsCredentials.Region),
   229  		Credentials: credentials.NewStaticCredentials(
   230  			awsCredentials.AccessKeyId,
   231  			awsCredentials.SecretAccessKey,
   232  			awsCredentials.SessionToken),
   233  	})
   234  	if err != nil {
   235  		log.Error(ctx.ctx, "error creating AWS session", "err", err.Error())
   236  		return empty
   237  	}
   238  	r, err := ecr.New(s).GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{})
   239  	if err != nil {
   240  		log.Error(ctx.ctx, "error fetching ECR authorization token", "err", err.Error())
   241  		return empty
   242  	}
   243  	if len(r.AuthorizationData) == 0 {
   244  		log.Error(ctx.ctx, "no authorization data from ECR")
   245  		return empty
   246  	}
   247  	auth := r.AuthorizationData[0]
   248  	if auth.AuthorizationToken == nil || auth.ProxyEndpoint == nil || auth.ExpiresAt == nil {
   249  		log.Error(ctx.ctx, "bad authorization data from ECR")
   250  		return empty
   251  	}
   252  	return EcrTicket{
   253  		AuthorizationToken: *auth.AuthorizationToken,
   254  		Expiration:         aws.TimeValue(auth.ExpiresAt).Format(time.RFC3339Nano),
   255  		Endpoint:           *auth.ProxyEndpoint,
   256  	}
   257  }
   258  
   259  // Returns a list of Compute Instances that match the filter
   260  func AwsEc2InstanceLookup(ctx *TicketContext, builder *AwsComputeInstancesBuilder) ([]ComputeInstance, error) {
   261  	var instances []ComputeInstance
   262  
   263  	if len(builder.InstanceFilters) == 0 {
   264  		return instances, errors.New("An instance filters is required")
   265  	}
   266  
   267  	// Create the STS session with the provided lookup role
   268  	config := aws.Config{
   269  		Region:      aws.String(builder.Region),
   270  		Credentials: stscreds.NewCredentials(ctx.session, builder.AwsAccountLookupRole),
   271  		Retryer: client.DefaultRetryer{
   272  			NumMaxRetries: 100,
   273  		},
   274  	}
   275  
   276  	s, err := session.NewSession(&config)
   277  	if err != nil {
   278  		log.Error(ctx.ctx, "error creating AWS session", "err", err.Error())
   279  		return instances, err
   280  	}
   281  
   282  	var filters []*ec2.Filter
   283  	filters = append(filters,
   284  		&ec2.Filter{
   285  			Name: aws.String("instance-state-name"),
   286  			Values: []*string{
   287  				aws.String("running"),
   288  			},
   289  		},
   290  	)
   291  
   292  	for _, f := range builder.InstanceFilters {
   293  		filters = append(filters,
   294  			&ec2.Filter{
   295  				Name: aws.String(f.Key),
   296  				Values: []*string{
   297  					aws.String(f.Value),
   298  				},
   299  			},
   300  		)
   301  	}
   302  
   303  	output, err := ec2.New(s, &config).DescribeInstances(&ec2.DescribeInstancesInput{
   304  		Filters: filters,
   305  	})
   306  	if err != nil {
   307  		log.Error(ctx.ctx, "error describing EC2 instance", "err", err.Error())
   308  		return instances, err
   309  	}
   310  
   311  	for _, reservations := range output.Reservations {
   312  		for _, instance := range reservations.Instances {
   313  			var params []Parameter
   314  			publicIp, err := ec2util.GetPublicIPAddress(instance)
   315  			if err != nil {
   316  				log.Error(ctx.ctx, "error fetching EC2 public IP address. Continuing anyways.", "err", err.Error())
   317  				continue // parse error skip
   318  			}
   319  
   320  			privateIp, err := ec2util.GetPrivateIPAddress(instance)
   321  			if err != nil {
   322  				log.Error(ctx.ctx, "error fetching EC2 private IP address. Continuing anyways.", "err", err.Error())
   323  				continue // parse error skip
   324  			}
   325  
   326  			ec2Tags, err := ec2util.GetTags(instance)
   327  			if err != nil {
   328  				log.Error(ctx.ctx, "error fetching EC2 tags. Continuing anyways.", "err", err.Error())
   329  				continue // parse error skip
   330  			}
   331  			for _, tag := range ec2Tags {
   332  				params = append(params,
   333  					Parameter{
   334  						Key:   *tag.Key,
   335  						Value: *tag.Value,
   336  					})
   337  			}
   338  
   339  			instanceId, err := ec2util.GetInstanceId(instance)
   340  			if err != nil {
   341  				log.Error(ctx.ctx, "error fetching EC2 instance ID. Continuing anyways.", "err", err.Error())
   342  				continue // parse error skip
   343  			}
   344  
   345  			instances = append(instances,
   346  				ComputeInstance{
   347  					PublicIp:   publicIp,
   348  					PrivateIp:  privateIp,
   349  					InstanceId: instanceId,
   350  					Tags:       params,
   351  				})
   352  		}
   353  	}
   354  
   355  	log.Debug(ctx.ctx, "AWS EC2 instances", "instances", instances)
   356  	return instances, nil
   357  }