github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/cloud/awssession/provider.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package awssession 6 7 import ( 8 "fmt" 9 "time" 10 11 "github.com/aws/aws-sdk-go/aws/credentials" 12 "github.com/Schaudge/grailbase/errors" 13 "github.com/Schaudge/grailbase/security/ticket" 14 "v.io/v23/context" 15 ) 16 17 // Provider implements the aws/credentials.Provider interface using a GRAIL 18 // ticket. 19 type Provider struct { 20 // Ctx contains the Vanadium context used to make the call to ticket-server 21 // in response to calls to Retrieve(). Canceling the context will cause the 22 // calls to Retrieve() to fail and IsExpire() to always return true. 23 Ctx *context.T 24 25 // Timeout indicates what timeout to set for the Vanadium calls. 26 Timeout time.Duration 27 28 // Ticket contains the last GRAIL ticket retrieved by a call to Retrieve(). 29 Ticket ticket.Ticket 30 31 // Expiration indicates when the AWS credentials will expire. 32 Expiration time.Time 33 34 // TicketPath indicates what Vanadium object name to use to retrieve the 35 // ticket. 36 TicketPath string 37 38 // ExpiryWindow allows triggering a refresh before the AWS credentials 39 // actually expire. 40 ExpiryWindow time.Duration 41 42 // Rationale indicates the reason for accessing a ticket 43 Rationale string 44 } 45 46 var _ credentials.Provider = (*Provider)(nil) 47 48 // Retrieve implements the github.com/aws/aws-sdk-go/aws/credentials.Provider 49 // interface. 
//
// It fetches a fresh ticket from the ticket-server at p.TicketPath using
// p.Ctx, bounded by p.Timeout when that is non-zero. When p.Rationale is set
// it is forwarded under the ticket.ControlRationale key. The fetched ticket
// is stored in p.Ticket and p.Expiration is refreshed from it. On any failure
// p.Ticket is left nil so that IsExpired reports true and the next call
// retries.
func (p *Provider) Retrieve() (credentials.Value, error) {
	ctx := p.Ctx
	if p.Timeout != 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, p.Timeout)
		defer cancel()
	}
	client := ticket.TicketServiceClient(p.TicketPath)
	var err error
	if p.Rationale != "" {
		p.Ticket, err = client.GetWithArgs(ctx, map[string]string{
			ticket.ControlRationale.String(): p.Rationale,
		})
	} else {
		p.Ticket, err = client.Get(ctx)
	}
	if err != nil {
		// Never keep a ticket from a failed call: IsExpired relies on a nil
		// Ticket to force another attempt.
		p.Ticket = nil
		return credentials.Value{}, err
	}
	return p.retrieve()
}

// retrieve converts p.Ticket into AWS credentials. It is split out of
// Retrieve so this logic can be tested without contacting a ticket-server.
//
// p.Ticket must hold a ticket.TicketAwsTicket; any other type is an error.
// p.Expiration is set from the ticket's RFC 3339 expiration, or reset to the
// zero time when the ticket carries none. On any error p.Ticket is cleared so
// IsExpired reports true.
func (p *Provider) retrieve() (credentials.Value, error) {
	awsTicket, ok := p.Ticket.(ticket.TicketAwsTicket)
	if !ok {
		// Build the message before clearing p.Ticket, since it reports the
		// offending type.
		err := fmt.Errorf("bad ticket type %T for %q, want %T", p.Ticket, p.TicketPath, awsTicket)
		p.Ticket = nil
		return credentials.Value{}, err
	}

	creds := awsTicket.Value.AwsCredentials
	// Reset first: an expiration left over from a previous ticket must not
	// be applied to a ticket that has none.
	p.Expiration = time.Time{}
	if creds.Expiration != "" {
		exp, err := time.Parse(time.RFC3339, creds.Expiration)
		if err != nil {
			p.Ticket = nil
			return credentials.Value{}, errors.E(err, fmt.Sprintf("%q: error parsing %q", p.TicketPath, creds.Expiration))
		}
		p.Expiration = exp
	}
	return credentials.Value{
		AccessKeyID:     creds.AccessKeyId,
		SecretAccessKey: creds.SecretAccessKey,
		SessionToken:    creds.SessionToken,
		ProviderName:    "ticket",
	}, nil
}

// IsExpired implements the github.com/aws/aws-sdk-go/aws/credentials.Provider
// interface.
//
// It reports true when no ticket is held (p.Ticket is nil). Otherwise it
// reports true only when p.Expiration is set and the current time plus
// p.ExpiryWindow is after it; a zero p.Expiration never expires.
func (p *Provider) IsExpired() bool {
	switch {
	case p.Ticket == nil:
		return true
	case p.Expiration.IsZero():
		return false
	default:
		return time.Now().Add(p.ExpiryWindow).After(p.Expiration)
	}
}