github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/stscreds/web_identity_provider.go (about)

     1  package stscreds
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"strconv"
     7  	"time"
     8  
     9  	"github.com/aavshr/aws-sdk-go/aws"
    10  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    11  	"github.com/aavshr/aws-sdk-go/aws/client"
    12  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    13  	"github.com/aavshr/aws-sdk-go/service/sts"
    14  	"github.com/aavshr/aws-sdk-go/service/sts/stsiface"
    15  )
    16  
    17  const (
    18  	// ErrCodeWebIdentity will be used as an error code when constructing
    19  	// a new error to be returned during session creation or retrieval.
    20  	ErrCodeWebIdentity = "WebIdentityErr"
    21  
    22  	// WebIdentityProviderName is the web identity provider name
    23  	WebIdentityProviderName = "WebIdentityCredentials"
    24  )
    25  
    26  // now is used to return a time.Time object representing
    27  // the current time. This can be used to easily test and
    28  // compare test values.
    29  var now = time.Now
    30  
    31  // TokenFetcher shuold return WebIdentity token bytes or an error
    32  type TokenFetcher interface {
    33  	FetchToken(credentials.Context) ([]byte, error)
    34  }
    35  
    36  // FetchTokenPath is a path to a WebIdentity token file
    37  type FetchTokenPath string
    38  
    39  // FetchToken returns a token by reading from the filesystem
    40  func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
    41  	data, err := ioutil.ReadFile(string(f))
    42  	if err != nil {
    43  		errMsg := fmt.Sprintf("unable to read file at %s", f)
    44  		return nil, awserr.New(ErrCodeWebIdentity, errMsg, err)
    45  	}
    46  	return data, nil
    47  }
    48  
    49  // WebIdentityRoleProvider is used to retrieve credentials using
    50  // an OIDC token.
    51  type WebIdentityRoleProvider struct {
    52  	credentials.Expiry
    53  	PolicyArns []*sts.PolicyDescriptorType
    54  
    55  	// Duration the STS credentials will be valid for. Truncated to seconds.
    56  	// If unset, the assumed role will use AssumeRoleWithWebIdentity's default
    57  	// expiry duration. See
    58  	// https://docs.aws.amazon.com/sdk-for-go/api/service/sts/#STS.AssumeRoleWithWebIdentity
    59  	// for more information.
    60  	Duration time.Duration
    61  
    62  	// The amount of time the credentials will be refreshed before they expire.
    63  	// This is useful refresh credentials before they expire to reduce risk of
    64  	// using credentials as they expire. If unset, will default to no expiry
    65  	// window.
    66  	ExpiryWindow time.Duration
    67  
    68  	client stsiface.STSAPI
    69  
    70  	tokenFetcher    TokenFetcher
    71  	roleARN         string
    72  	roleSessionName string
    73  }
    74  
    75  // NewWebIdentityCredentials will return a new set of credentials with a given
    76  // configuration, role arn, and token file path.
    77  func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
    78  	svc := sts.New(c)
    79  	p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
    80  	return credentials.NewCredentials(p)
    81  }
    82  
    83  // NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
    84  // provided stsiface.STSAPI
    85  func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
    86  	return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path))
    87  }
    88  
    89  // NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
    90  // provided stsiface.STSAPI and a TokenFetcher
    91  func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
    92  	return &WebIdentityRoleProvider{
    93  		client:          svc,
    94  		tokenFetcher:    tokenFetcher,
    95  		roleARN:         roleARN,
    96  		roleSessionName: roleSessionName,
    97  	}
    98  }
    99  
   100  // Retrieve will attempt to assume a role from a token which is located at
   101  // 'WebIdentityTokenFilePath' specified destination and if that is empty an
   102  // error will be returned.
   103  func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
   104  	return p.RetrieveWithContext(aws.BackgroundContext())
   105  }
   106  
   107  // RetrieveWithContext will attempt to assume a role from a token which is located at
   108  // 'WebIdentityTokenFilePath' specified destination and if that is empty an
   109  // error will be returned.
   110  func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
   111  	b, err := p.tokenFetcher.FetchToken(ctx)
   112  	if err != nil {
   113  		return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed fetching WebIdentity token: ", err)
   114  	}
   115  
   116  	sessionName := p.roleSessionName
   117  	if len(sessionName) == 0 {
   118  		// session name is used to uniquely identify a session. This simply
   119  		// uses unix time in nanoseconds to uniquely identify sessions.
   120  		sessionName = strconv.FormatInt(now().UnixNano(), 10)
   121  	}
   122  
   123  	var duration *int64
   124  	if p.Duration != 0 {
   125  		duration = aws.Int64(int64(p.Duration / time.Second))
   126  	}
   127  
   128  	req, resp := p.client.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{
   129  		PolicyArns:       p.PolicyArns,
   130  		RoleArn:          &p.roleARN,
   131  		RoleSessionName:  &sessionName,
   132  		WebIdentityToken: aws.String(string(b)),
   133  		DurationSeconds:  duration,
   134  	})
   135  
   136  	req.SetContext(ctx)
   137  
   138  	// InvalidIdentityToken error is a temporary error that can occur
   139  	// when assuming an Role with a JWT web identity token.
   140  	req.RetryErrorCodes = append(req.RetryErrorCodes, sts.ErrCodeInvalidIdentityTokenException)
   141  	if err := req.Send(); err != nil {
   142  		return credentials.Value{}, awserr.New(ErrCodeWebIdentity, "failed to retrieve credentials", err)
   143  	}
   144  
   145  	p.SetExpiration(aws.TimeValue(resp.Credentials.Expiration), p.ExpiryWindow)
   146  
   147  	value := credentials.Value{
   148  		AccessKeyID:     aws.StringValue(resp.Credentials.AccessKeyId),
   149  		SecretAccessKey: aws.StringValue(resp.Credentials.SecretAccessKey),
   150  		SessionToken:    aws.StringValue(resp.Credentials.SessionToken),
   151  		ProviderName:    WebIdentityProviderName,
   152  	}
   153  	return value, nil
   154  }