github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/cloud/awssession/provider.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package awssession
     6  
     7  import (
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/aws/aws-sdk-go/aws/credentials"
    12  	"github.com/Schaudge/grailbase/errors"
    13  	"github.com/Schaudge/grailbase/security/ticket"
    14  	"v.io/v23/context"
    15  )
    16  
    17  // Provider implements the aws/credentials.Provider interface using a GRAIL
    18  // ticket.
    19  type Provider struct {
    20  	// Ctx contains the Vanadium context used to make the call to ticket-server
    21  	// in response to calls to Retrieve(). Canceling the context will cause the
    22  	// calls to Retrieve() to fail and IsExpire() to always return true.
    23  	Ctx *context.T
    24  
    25  	// Timeout indicates what timeout to set for the Vanadium calls.
    26  	Timeout time.Duration
    27  
    28  	// Ticket contains the last GRAIL ticket retrieved by a call to Retrieve().
    29  	Ticket ticket.Ticket
    30  
    31  	// Expiration indicates when the AWS credentials will expire.
    32  	Expiration time.Time
    33  
    34  	// TicketPath indicates what Vanadium object name to use to retrieve the
    35  	// ticket.
    36  	TicketPath string
    37  
    38  	// ExpiryWindow allows triggering a refresh before the AWS credentials
    39  	// actually expire.
    40  	ExpiryWindow time.Duration
    41  
    42  	// Rationale indicates the reason for accessing a ticket
    43  	Rationale string
    44  }
    45  
    46  var _ credentials.Provider = (*Provider)(nil)
    47  
    48  // Retrieve implements the github.com/aws/aws-sdk-go/aws/credentials.Provider
    49  // interface.
    50  func (p *Provider) Retrieve() (credentials.Value, error) {
    51  	ctx := p.Ctx
    52  	if p.Timeout != 0 {
    53  		var cancel context.CancelFunc
    54  		ctx, cancel = context.WithTimeout(ctx, p.Timeout)
    55  		defer cancel()
    56  	}
    57  	var err error
    58  	if p.Rationale != "" {
    59  		p.Ticket, err = ticket.TicketServiceClient(p.TicketPath).GetWithArgs(ctx, map[string]string{
    60  			ticket.ControlRationale.String(): p.Rationale,
    61  		})
    62  	} else {
    63  		p.Ticket, err = ticket.TicketServiceClient(p.TicketPath).Get(ctx)
    64  	}
    65  	if err != nil {
    66  		return credentials.Value{}, err
    67  	}
    68  	return p.retrieve()
    69  }
    70  
    71  // retrieve implements some logic that would be harder to test if it's part of
    72  // the Retrieve() function.
    73  func (p *Provider) retrieve() (credentials.Value, error) {
    74  	awsTicket, ok := p.Ticket.(ticket.TicketAwsTicket)
    75  	if !ok {
    76  		return credentials.Value{}, fmt.Errorf("bad ticket type %T for %q, want %T", p.Ticket, p.TicketPath, awsTicket)
    77  	}
    78  
    79  	if awsTicket.Value.AwsCredentials.Expiration != "" {
    80  		var err error
    81  		p.Expiration, err = time.Parse(time.RFC3339, awsTicket.Value.AwsCredentials.Expiration)
    82  		if err != nil {
    83  			p.Ticket = nil
    84  			return credentials.Value{}, errors.E(err, fmt.Sprintf("%q: error parsing %q", p.TicketPath, awsTicket.Value.AwsCredentials.Expiration))
    85  		}
    86  	}
    87  	return credentials.Value{
    88  		AccessKeyID:     awsTicket.Value.AwsCredentials.AccessKeyId,
    89  		SecretAccessKey: awsTicket.Value.AwsCredentials.SecretAccessKey,
    90  		SessionToken:    awsTicket.Value.AwsCredentials.SessionToken,
    91  		ProviderName:    "ticket",
    92  	}, nil
    93  }
    94  
    95  // IsExpired implements the github.com/aws/aws-sdk-go/aws/credentials.Provider
    96  // interface.
    97  func (p *Provider) IsExpired() bool {
    98  	var r bool
    99  	if p.Ticket == nil {
   100  		r = true
   101  	} else if !p.Expiration.IsZero() {
   102  		r = time.Now().Add(p.ExpiryWindow).After(p.Expiration)
   103  	}
   104  	return r
   105  }