github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/ssocreds/provider.go (about)

     1  package ssocreds
     2  
     3  import (
     4  	"crypto/sha1"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"path/filepath"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws"
    14  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    15  	"github.com/aavshr/aws-sdk-go/aws/client"
    16  	"github.com/aavshr/aws-sdk-go/aws/credentials"
    17  	"github.com/aavshr/aws-sdk-go/service/sso"
    18  	"github.com/aavshr/aws-sdk-go/service/sso/ssoiface"
    19  )
    20  
    21  // ErrCodeSSOProviderInvalidToken is the code type that is returned if loaded token has expired or is otherwise invalid.
    22  // To refresh the SSO session run aws sso login with the corresponding profile.
    23  const ErrCodeSSOProviderInvalidToken = "SSOProviderInvalidToken"
    24  
    25  const invalidTokenMessage = "the SSO session has expired or is invalid"
    26  
    27  func init() {
    28  	nowTime = time.Now
    29  	defaultCacheLocation = defaultCacheLocationImpl
    30  }
    31  
    32  var nowTime func() time.Time
    33  
    34  // ProviderName is the name of the provider used to specify the source of credentials.
    35  const ProviderName = "SSOProvider"
    36  
    37  var defaultCacheLocation func() string
    38  
    39  func defaultCacheLocationImpl() string {
    40  	return filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")
    41  }
    42  
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
//
// The embedded credentials.Expiry is updated by RetrieveWithContext (via
// SetExpiration) with the expiration reported in the GetRoleCredentials response.
type Provider struct {
	credentials.Expiry

	// The Client which is configured for the AWS Region where the AWS SSO user portal is located.
	// RetrieveWithContext calls its GetRoleCredentialsWithContext method.
	Client ssoiface.SSOAPI

	// The AWS account that is assigned to the user.
	// Sent as AccountId in the GetRoleCredentials request.
	AccountID string

	// The role name that is assigned to the user.
	// Sent as RoleName in the GetRoleCredentials request.
	RoleName string

	// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
	// The cached token file is named after the lowercase hex SHA-1 of this value plus ".json".
	StartURL string
}
    59  
    60  // NewCredentials returns a new AWS Single Sign-On (AWS SSO) credential provider. The ConfigProvider is expected to be configured
    61  // for the AWS Region where the AWS SSO user portal is located.
    62  func NewCredentials(configProvider client.ConfigProvider, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
    63  	return NewCredentialsWithClient(sso.New(configProvider), accountID, roleName, startURL, optFns...)
    64  }
    65  
    66  // NewCredentialsWithClient returns a new AWS Single Sign-On (AWS SSO) credential provider. The provided client is expected to be configured
    67  // for the AWS Region where the AWS SSO user portal is located.
    68  func NewCredentialsWithClient(client ssoiface.SSOAPI, accountID, roleName, startURL string, optFns ...func(provider *Provider)) *credentials.Credentials {
    69  	p := &Provider{
    70  		Client:    client,
    71  		AccountID: accountID,
    72  		RoleName:  roleName,
    73  		StartURL:  startURL,
    74  	}
    75  
    76  	for _, fn := range optFns {
    77  		fn(p)
    78  	}
    79  
    80  	return credentials.NewCredentials(p)
    81  }
    82  
    83  // Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
    84  // by exchanging the accessToken present in ~/.aws/sso/cache.
    85  func (p *Provider) Retrieve() (credentials.Value, error) {
    86  	return p.RetrieveWithContext(aws.BackgroundContext())
    87  }
    88  
    89  // RetrieveWithContext retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
    90  // by exchanging the accessToken present in ~/.aws/sso/cache.
    91  func (p *Provider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
    92  	tokenFile, err := loadTokenFile(p.StartURL)
    93  	if err != nil {
    94  		return credentials.Value{}, err
    95  	}
    96  
    97  	output, err := p.Client.GetRoleCredentialsWithContext(ctx, &sso.GetRoleCredentialsInput{
    98  		AccessToken: &tokenFile.AccessToken,
    99  		AccountId:   &p.AccountID,
   100  		RoleName:    &p.RoleName,
   101  	})
   102  	if err != nil {
   103  		return credentials.Value{}, err
   104  	}
   105  
   106  	expireTime := time.Unix(0, aws.Int64Value(output.RoleCredentials.Expiration)*int64(time.Millisecond)).UTC()
   107  	p.SetExpiration(expireTime, 0)
   108  
   109  	return credentials.Value{
   110  		AccessKeyID:     aws.StringValue(output.RoleCredentials.AccessKeyId),
   111  		SecretAccessKey: aws.StringValue(output.RoleCredentials.SecretAccessKey),
   112  		SessionToken:    aws.StringValue(output.RoleCredentials.SessionToken),
   113  		ProviderName:    ProviderName,
   114  	}, nil
   115  }
   116  
   117  func getCacheFileName(url string) (string, error) {
   118  	hash := sha1.New()
   119  	_, err := hash.Write([]byte(url))
   120  	if err != nil {
   121  		return "", err
   122  	}
   123  	return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
   124  }
   125  
   126  type rfc3339 time.Time
   127  
   128  func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
   129  	var value string
   130  
   131  	if err := json.Unmarshal(bytes, &value); err != nil {
   132  		return err
   133  	}
   134  
   135  	parse, err := time.Parse(time.RFC3339, value)
   136  	if err != nil {
   137  		return fmt.Errorf("expected RFC3339 timestamp: %v", err)
   138  	}
   139  
   140  	*r = rfc3339(parse)
   141  
   142  	return nil
   143  }
   144  
// token is the JSON document stored in the SSO token cache file read by
// loadTokenFile. Only AccessToken and ExpiresAt are read by the code in this
// file; Region and StartURL are decoded but not consulted here.
type token struct {
	// AccessToken is sent as the AccessToken of the GetRoleCredentials request.
	AccessToken string  `json:"accessToken"`
	// ExpiresAt must be an RFC 3339 string (see rfc3339.UnmarshalJSON); if the
	// key is absent it stays the zero time, so Expired reports true.
	ExpiresAt   rfc3339 `json:"expiresAt"`
	Region      string  `json:"region,omitempty"`
	StartURL    string  `json:"startUrl,omitempty"`
}
   151  
   152  func (t token) Expired() bool {
   153  	return nowTime().Round(0).After(time.Time(t.ExpiresAt))
   154  }
   155  
   156  func loadTokenFile(startURL string) (t token, err error) {
   157  	key, err := getCacheFileName(startURL)
   158  	if err != nil {
   159  		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
   160  	}
   161  
   162  	fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation(), key))
   163  	if err != nil {
   164  		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
   165  	}
   166  
   167  	if err := json.Unmarshal(fileBytes, &t); err != nil {
   168  		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, err)
   169  	}
   170  
   171  	if len(t.AccessToken) == 0 {
   172  		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
   173  	}
   174  
   175  	if t.Expired() {
   176  		return token{}, awserr.New(ErrCodeSSOProviderInvalidToken, invalidTokenMessage, nil)
   177  	}
   178  
   179  	return t, nil
   180  }