github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/security/ticket/aws.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package ticket 6 7 import ( 8 "errors" 9 "strings" 10 "time" 11 12 "github.com/aws/aws-sdk-go/aws" 13 "github.com/aws/aws-sdk-go/aws/client" 14 "github.com/aws/aws-sdk-go/aws/credentials" 15 "github.com/aws/aws-sdk-go/aws/credentials/stscreds" 16 "github.com/aws/aws-sdk-go/aws/session" 17 "github.com/aws/aws-sdk-go/service/ec2" 18 "github.com/aws/aws-sdk-go/service/ecr" 19 "github.com/aws/aws-sdk-go/service/sts" 20 "github.com/Schaudge/grailbase/cloud/ec2util" 21 "github.com/Schaudge/grailbase/common/log" 22 "github.com/Schaudge/grailbase/ttlcache" 23 ) 24 25 type cacheKey struct { 26 region string 27 role string 28 session string 29 } 30 31 // cacheTTL is how long the entries in cache will be considered valid. 32 const cacheTTL = time.Minute 33 34 var cache = ttlcache.New(cacheTTL) 35 36 func (b *AwsAssumeRoleBuilder) newAwsTicket(ctx *TicketContext) (TicketAwsTicket, error) { 37 awsCredentials, err := b.genAwsCredentials(ctx) 38 39 if err != nil { 40 return TicketAwsTicket{}, err 41 } 42 43 return TicketAwsTicket{ 44 Value: AwsTicket{ 45 AwsCredentials: awsCredentials, 46 }, 47 }, nil 48 } 49 50 func (b *AwsAssumeRoleBuilder) newS3Ticket(ctx *TicketContext) (TicketS3Ticket, error) { 51 awsCredentials, err := b.genAwsCredentials(ctx) 52 53 if err != nil { 54 return TicketS3Ticket{}, err 55 } 56 57 return TicketS3Ticket{ 58 Value: S3Ticket{ 59 AwsCredentials: awsCredentials, 60 }, 61 }, nil 62 } 63 64 func (b *AwsAssumeRoleBuilder) newEcrTicket(ctx *TicketContext) (TicketEcrTicket, error) { 65 log.Debug(ctx.ctx, "generating ECR ticket", "AwsAssumeRoleBuilder", b) 66 awsCredentials, err := b.genAwsCredentials(ctx) 67 68 if err != nil { 69 return TicketEcrTicket{}, err 70 } 71 return TicketEcrTicket{ 72 Value: newEcrTicket(ctx, awsCredentials), 73 }, nil 74 } 75 76 func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredentials, error) { 77 log.Debug(ctx.ctx, "generating AWS credentials", "AwsAssumeRoleBuilder", b) 78 empty := AwsCredentials{} 79 80 sessionName := strings.Replace(ctx.remoteBlessings.String(), ":", ",", -1) 81 // AWS session names must be 64 characters or less 82 if runes := []rune(sessionName); len(runes) > 64 { 83 // Some risk with simple truncation - two large IAM role's would overlap 84 // for example. This is mitigated by the format which includes instance id 85 // as the last component. Ability to determine exactly which instance made 86 // the call will be difficult, but likelihood of 2 instances sharing a prefix 87 // is low. 88 sessionName = string(runes[0:64]) 89 } 90 key := cacheKey{b.Region, b.Role, sessionName} 91 if v, ok := cache.Get(key); ok { 92 log.Debug(ctx.ctx, "AWS credentials lookup cache hit", "key", key) 93 return v.(AwsCredentials), nil 94 } 95 log.Debug(ctx.ctx, "AWS credentials lookup cache miss", "key", key) 96 97 s := ctx.session 98 if aws.StringValue(s.Config.Region) != b.Region { 99 // This mismatch should be very rare. 100 var err error 101 s, err = session.NewSession(s.Config.WithRegion(b.Region)) 102 if err != nil { 103 log.Error(ctx.ctx, "error creating AWS session", "err", err.Error()) 104 return empty, err 105 } 106 } 107 108 client := sts.New(s) 109 assumeRoleInput := &sts.AssumeRoleInput{ 110 RoleArn: aws.String(b.Role), 111 // TODO(razvanm): the role session name is a string of characters consisting 112 // of upper- and lower-case alphanumeric characters with no spaces that can 113 // include '=,.@-'. Notably, a blessing can include ':' which is not allowed 114 // in here. 115 // 116 // Reference: http://docs.aws.amazon.com/cli/latest/reference/sts/assume-role.html 117 RoleSessionName: aws.String(sessionName), 118 DurationSeconds: aws.Int64(int64(b.TtlSec)), 119 } 120 121 assumeRoleOutput, err := client.AssumeRole(assumeRoleInput) 122 if err != nil { 123 log.Error(ctx.ctx, "error in AssumeRole API call", "key", key) 124 return empty, err 125 } 126 127 result := AwsCredentials{ 128 Region: b.Region, 129 AccessKeyId: aws.StringValue(assumeRoleOutput.Credentials.AccessKeyId), 130 SecretAccessKey: aws.StringValue(assumeRoleOutput.Credentials.SecretAccessKey), 131 SessionToken: aws.StringValue(assumeRoleOutput.Credentials.SessionToken), 132 Expiration: assumeRoleOutput.Credentials.Expiration.Format(time.RFC3339Nano), 133 } 134 135 log.Debug(ctx.ctx, "adding AWS credentials to cache", "key", key) 136 cache.Set(key, result) 137 138 return result, nil 139 } 140 141 func (b *AwsSessionBuilder) newAwsTicket(ctx *TicketContext) (TicketAwsTicket, error) { 142 awsCredentials, err := b.genAwsSession(ctx) 143 144 if err != nil { 145 return TicketAwsTicket{}, err 146 } 147 148 return TicketAwsTicket{ 149 Value: AwsTicket{ 150 AwsCredentials: awsCredentials, 151 }, 152 }, nil 153 } 154 155 func (b *AwsSessionBuilder) newS3Ticket(ctx *TicketContext) (TicketS3Ticket, error) { 156 awsCredentials, err := b.genAwsSession(ctx) 157 158 if err != nil { 159 return TicketS3Ticket{}, err 160 } 161 162 return TicketS3Ticket{ 163 Value: S3Ticket{ 164 AwsCredentials: awsCredentials, 165 }, 166 }, nil 167 } 168 169 func (b *AwsSessionBuilder) genAwsSession(ctx *TicketContext) (AwsCredentials, error) { 170 log.Debug(ctx.ctx, "enerating AWS session", "AwsAssumeRoleBuilder", b.AwsCredentials.AccessKeyId) 171 empty := AwsCredentials{} 172 awsCredentials := b.AwsCredentials 173 174 sessionName := strings.Replace(ctx.remoteBlessings.String(), ":", ",", -1) 175 // AWS session names must be 64 characters or less 176 if runes := []rune(sessionName); len(runes) > 64 { 177 // Some risk with simple truncation - two large IAM role's would overlap 178 // for example. This is mitigated by the format which includes instance id 179 // as the last component. Ability to determine exactly which instance made 180 // the call will be difficult, but likelihood of 2 instances sharing a prefix 181 // is low. 182 sessionName = string(runes[0:64]) 183 } 184 key := cacheKey{awsCredentials.Region, awsCredentials.AccessKeyId, sessionName} 185 if v, ok := cache.Get(key); ok { 186 log.Debug(ctx.ctx, "AWS session lookup cache hit", "key", key) 187 return v.(AwsCredentials), nil 188 } 189 log.Debug(ctx.ctx, "AWS session lookup cache miss", "key", key) 190 s, err := session.NewSession(&aws.Config{ 191 Region: aws.String(awsCredentials.Region), 192 Credentials: credentials.NewStaticCredentials( 193 awsCredentials.AccessKeyId, 194 awsCredentials.SecretAccessKey, 195 awsCredentials.SessionToken), 196 }) 197 if err != nil { 198 return empty, err 199 } 200 201 sessionTokenInput := &sts.GetSessionTokenInput{ 202 DurationSeconds: aws.Int64(int64(b.TtlSec)), 203 } 204 205 client := sts.New(s) 206 sessionTokenOutput, err := client.GetSessionToken(sessionTokenInput) 207 if err != nil { 208 return empty, err 209 } 210 211 result := AwsCredentials{ 212 Region: awsCredentials.Region, 213 AccessKeyId: aws.StringValue(sessionTokenOutput.Credentials.AccessKeyId), 214 SecretAccessKey: aws.StringValue(sessionTokenOutput.Credentials.SecretAccessKey), 215 SessionToken: aws.StringValue(sessionTokenOutput.Credentials.SessionToken), 216 Expiration: sessionTokenOutput.Credentials.Expiration.Format(time.RFC3339Nano), 217 } 218 219 log.Debug(ctx.ctx, "Adding AWS session to cache", "key", key) 220 cache.Set(key, result) 221 222 return result, nil 223 } 224 225 func newEcrTicket(ctx *TicketContext, awsCredentials AwsCredentials) EcrTicket { 226 empty := EcrTicket{} 227 s, err := session.NewSession(&aws.Config{ 228 Region: aws.String(awsCredentials.Region), 229 Credentials: credentials.NewStaticCredentials( 230 awsCredentials.AccessKeyId, 231 awsCredentials.SecretAccessKey, 232 awsCredentials.SessionToken), 233 }) 234 if err != nil { 235 log.Error(ctx.ctx, "error creating AWS session", "err", err.Error()) 236 return empty 237 } 238 r, err := ecr.New(s).GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{}) 239 if err != nil { 240 log.Error(ctx.ctx, "error fetching ECR authorization token", "err", err.Error()) 241 return empty 242 } 243 if len(r.AuthorizationData) == 0 { 244 log.Error(ctx.ctx, "no authorization data from ECR") 245 return empty 246 } 247 auth := r.AuthorizationData[0] 248 if auth.AuthorizationToken == nil || auth.ProxyEndpoint == nil || auth.ExpiresAt == nil { 249 log.Error(ctx.ctx, "bad authorization data from ECR") 250 return empty 251 } 252 return EcrTicket{ 253 AuthorizationToken: *auth.AuthorizationToken, 254 Expiration: aws.TimeValue(auth.ExpiresAt).Format(time.RFC3339Nano), 255 Endpoint: *auth.ProxyEndpoint, 256 } 257 } 258 259 // Returns a list of Compute Instances that match the filter 260 func AwsEc2InstanceLookup(ctx *TicketContext, builder *AwsComputeInstancesBuilder) ([]ComputeInstance, error) { 261 var instances []ComputeInstance 262 263 if len(builder.InstanceFilters) == 0 { 264 return instances, errors.New("An instance filters is required") 265 } 266 267 // Create the STS session with the provided lookup role 268 config := aws.Config{ 269 Region: aws.String(builder.Region), 270 Credentials: stscreds.NewCredentials(ctx.session, builder.AwsAccountLookupRole), 271 Retryer: client.DefaultRetryer{ 272 NumMaxRetries: 100, 273 }, 274 } 275 276 s, err := session.NewSession(&config) 277 if err != nil { 278 log.Error(ctx.ctx, "error creating AWS session", "err", err.Error()) 279 return instances, err 280 } 281 282 var filters []*ec2.Filter 283 filters = append(filters, 284 &ec2.Filter{ 285 Name: aws.String("instance-state-name"), 286 Values: []*string{ 287 aws.String("running"), 288 }, 289 }, 290 ) 291 292 for _, f := range builder.InstanceFilters { 293 filters = append(filters, 294 &ec2.Filter{ 295 Name: aws.String(f.Key), 296 Values: []*string{ 297 aws.String(f.Value), 298 }, 299 }, 300 ) 301 } 302 303 output, err := ec2.New(s, &config).DescribeInstances(&ec2.DescribeInstancesInput{ 304 Filters: filters, 305 }) 306 if err != nil { 307 log.Error(ctx.ctx, "error describing EC2 instance", "err", err.Error()) 308 return instances, err 309 } 310 311 for _, reservations := range output.Reservations { 312 for _, instance := range reservations.Instances { 313 var params []Parameter 314 publicIp, err := ec2util.GetPublicIPAddress(instance) 315 if err != nil { 316 log.Error(ctx.ctx, "error fetching EC2 public IP address. Continuing anyways.", "err", err.Error()) 317 continue // parse error skip 318 } 319 320 privateIp, err := ec2util.GetPrivateIPAddress(instance) 321 if err != nil { 322 log.Error(ctx.ctx, "error fetching EC2 private IP address. Continuing anyways.", "err", err.Error()) 323 continue // parse error skip 324 } 325 326 ec2Tags, err := ec2util.GetTags(instance) 327 if err != nil { 328 log.Error(ctx.ctx, "error fetching EC2 tags. Continuing anyways.", "err", err.Error()) 329 continue // parse error skip 330 } 331 for _, tag := range ec2Tags { 332 params = append(params, 333 Parameter{ 334 Key: *tag.Key, 335 Value: *tag.Value, 336 }) 337 } 338 339 instanceId, err := ec2util.GetInstanceId(instance) 340 if err != nil { 341 log.Error(ctx.ctx, "error fetching EC2 instance ID. Continuing anyways.", "err", err.Error()) 342 continue // parse error skip 343 } 344 345 instances = append(instances, 346 ComputeInstance{ 347 PublicIp: publicIp, 348 PrivateIp: privateIp, 349 InstanceId: instanceId, 350 Tags: params, 351 }) 352 } 353 } 354 355 log.Debug(ctx.ctx, "AWS EC2 instances", "instances", instances) 356 return instances, nil 357 }