github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/credentials.go (about)

     1  // Package credentials provides credential retrieval and management
     2  //
     3  // The Credentials is the primary method of getting access to and managing
     4  // credentials Values. Using dependency injection, retrieval of the credential
     5  // values is handled by an object which satisfies the Provider interface.
     6  //
     7  // By default the Credentials.Get() will cache the successful result of a
     8  // Provider's Retrieve() until Provider.IsExpired() returns true, at which
     9  // point Credentials will call the Provider's Retrieve() to get a new credentials Value.
    10  //
    11  // The Provider is responsible for determining when the credentials Value has expired.
    12  // It is also important to note that Credentials will always call Retrieve the
    13  // first time Credentials.Get() is called.
    14  //
    15  // Example of using the environment variable credentials.
    16  //
    17  //     creds := credentials.NewEnvCredentials()
    18  //
    19  //     // Retrieve the credentials value
    20  //     credValue, err := creds.Get()
    21  //     if err != nil {
    22  //         // handle error
    23  //     }
    24  //
    25  // Example of forcing credentials to expire and be refreshed on the next Get().
    26  // This may be helpful to proactively expire credentials and refresh them sooner
    27  // than they would naturally expire on their own.
    28  //
    29  //     creds := credentials.NewCredentials(&ec2rolecreds.EC2RoleProvider{})
    30  //     creds.Expire()
    31  //     credsValue, err := creds.Get()
    32  //     // New credentials will be retrieved instead of from cache.
    33  //
    34  //
    35  // Custom Provider
    36  //
    37  // Each Provider built into this package also provides a helper method to generate
    38  // a Credentials pointer set up with the provider. To use a custom Provider just
    39  // create a type which satisfies the Provider interface and pass it to the
    40  // NewCredentials method.
    41  //
    42  //     type MyProvider struct{}
    43  //     func (m *MyProvider) Retrieve() (Value, error) {...}
    44  //     func (m *MyProvider) IsExpired() bool {...}
    45  //
    46  //     creds := credentials.NewCredentials(&MyProvider{})
    47  //     credValue, err := creds.Get()
    48  //
    49  package credentials
    50  
    51  import (
    52  	"fmt"
    53  	"sync"
    54  	"time"
    55  
    56  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    57  	"github.com/aavshr/aws-sdk-go/internal/sync/singleflight"
    58  )
    59  
    60  // AnonymousCredentials is an empty Credential object that can be used as
    61  // dummy placeholder credentials for requests that do not need to be signed.
    62  //
    63  // This Credentials can be used to configure a service to not sign requests
    64  // when making service API calls. For example, when accessing public
    65  // S3 buckets.
    66  //
    67  //     svc := s3.New(session.Must(session.NewSession(&aws.Config{
    68  //       Credentials: credentials.AnonymousCredentials,
    69  //     })))
    70  //     // Access public S3 buckets.
    71  var AnonymousCredentials = NewStaticCredentials("", "", "")
    72  
    73  // A Value is the AWS credentials value for individual credential fields.
    74  type Value struct {
    75  	// AWS Access key ID
    76  	AccessKeyID string
    77  
    78  	// AWS Secret Access Key
    79  	SecretAccessKey string
    80  
    81  	// AWS Session Token
    82  	SessionToken string
    83  
    84  	// Provider used to get credentials
    85  	ProviderName string
    86  }
    87  
    88  // HasKeys returns if the credentials Value has both AccessKeyID and
    89  // SecretAccessKey value set.
    90  func (v Value) HasKeys() bool {
    91  	return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
    92  }
    93  
    94  // A Provider is the interface for any component which will provide credentials
    95  // Value. A provider is required to manage its own Expired state, and what it
    96  // means to be expired.
    97  //
    98  // The Provider should not need to implement its own mutexes, because
    99  // that will be managed by Credentials.
   100  type Provider interface {
   101  	// Retrieve returns the credentials Value and a nil error if the value was retrieved
   102  	// successfully. An error is returned if the value was not obtainable, or empty.
   103  	Retrieve() (Value, error)
   104  
   105  	// IsExpired returns if the credentials are no longer valid, and need
   106  	// to be retrieved.
   107  	IsExpired() bool
   108  }
   109  
   110  // ProviderWithContext is a Provider that can also retrieve credentials given a Context.
   111  type ProviderWithContext interface {
   112  	Provider
   113  	// RetrieveWithContext is called by Credentials in place of Retrieve when the provider implements it.
   114  	RetrieveWithContext(Context) (Value, error)
   115  }
   116  
   117  // An Expirer is an interface that Providers can implement to expose the expiration
   118  // time, if known. If the Provider cannot accurately provide this info,
   119  // it should not implement this interface.
   120  type Expirer interface {
   121  	// ExpiresAt returns the time at which the credentials are no longer valid.
   122  	ExpiresAt() time.Time
   123  }
   124  
   125  // An ErrorProvider is a stub credentials provider that always returns an error
   126  // this is used by the SDK when construction a known provider is not possible
   127  // due to an error.
   128  type ErrorProvider struct {
   129  	// The error to be returned from Retrieve
   130  	Err error
   131  
   132  	// The provider name to set on the Retrieved returned Value
   133  	ProviderName string
   134  }
   135  
   136  // Retrieve will always return the error that the ErrorProvider was created with.
   137  func (p ErrorProvider) Retrieve() (Value, error) {
   138  	return Value{ProviderName: p.ProviderName}, p.Err
   139  }
   140  
   141  // IsExpired will always return not expired.
   142  func (p ErrorProvider) IsExpired() bool {
   143  	return false
   144  }
   145  
   146  // A Expiry provides shared expiration logic to be used by credentials
   147  // providers to implement expiry functionality.
   148  //
   149  // The best method to use this struct is as an anonymous field within the
   150  // provider's struct.
   151  //
   152  // Example:
   153  //     type EC2RoleProvider struct {
   154  //         Expiry
   155  //         ...
   156  //     }
   157  type Expiry struct {
   158  	// The date/time when to expire on
   159  	expiration time.Time
   160  
   161  	// If set will be used by IsExpired to determine the current time.
   162  	// Defaults to time.Now if CurrentTime is not set.  Available for testing
   163  	// to be able to mock out the current time.
   164  	CurrentTime func() time.Time
   165  }
   166  
   167  // SetExpiration sets the expiration IsExpired will check when called.
   168  //
   169  // If window is greater than 0 the expiration time will be reduced by the
   170  // window value.
   171  //
   172  // Using a window is helpful to trigger credentials to expire sooner than
   173  // the expiration time given to ensure no requests are made with expired
   174  // tokens.
   175  func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
   176  	// Passed in expirations should have the monotonic clock values stripped.
   177  	// This ensures time comparisons will be based on wall-time.
   178  	e.expiration = expiration.Round(0)
   179  	if window > 0 {
   180  		e.expiration = e.expiration.Add(-window)
   181  	}
   182  }
   183  
   184  // IsExpired returns if the credentials are expired.
   185  func (e *Expiry) IsExpired() bool {
   186  	curTime := e.CurrentTime
   187  	if curTime == nil {
   188  		curTime = time.Now
   189  	}
   190  	return e.expiration.Before(curTime())
   191  }
   192  
   193  // ExpiresAt returns the expiration time of the credential
   194  func (e *Expiry) ExpiresAt() time.Time {
   195  	return e.expiration
   196  }
   197  
   198  // A Credentials provides concurrency safe retrieval of AWS credentials Value.
   199  // Credentials will cache the credentials value until they expire. Once the value
   200  // expires the next Get will attempt to retrieve valid credentials.
   201  //
   202  // Credentials is safe to use across multiple goroutines and will manage the
   203  // synchronization state so the Providers do not need to implement their own
   204  // synchronization.
   205  //
   206  // The first Credentials.Get() will always call Provider.Retrieve() to get the
   207  // first instance of the credentials Value. All calls to Get() after that
   208  // will return the cached credentials Value until IsExpired() returns true.
   209  type Credentials struct {
   210  	sf singleflight.Group // used by GetWithContext (via DoChan) to run singleRetrieve
   211  
   212  	m        sync.RWMutex // guards creds
   213  	creds    Value        // cached credentials; the zero Value is treated as expired
   214  	provider Provider     // source of credentials and of the provider-side expired state
   215  }
   216  
   217  // NewCredentials returns a pointer to a new Credentials with the provider set.
   218  func NewCredentials(provider Provider) *Credentials {
   219  	c := &Credentials{
   220  		provider: provider,
   221  	}
   222  	return c
   223  }
   224  
   225  // GetWithContext returns the credentials value, or error if the credentials
   226  // Value failed to be retrieved. Will return early if the passed in context is
   227  // canceled.
   228  //
   229  // Will return the cached credentials Value if it has not expired. If the
   230  // credentials Value has expired the Provider's Retrieve() will be called
   231  // to refresh the credentials.
   232  //
   233  // If Credentials.Expire() was called the credentials Value will be force
   234  // expired, and the next call to Get() will cause them to be refreshed.
   235  //
   236  // Passed in Context is equivalent to aws.Context, and context.Context.
   237  func (c *Credentials) GetWithContext(ctx Context) (Value, error) {
   238  	// Check if credentials are cached, and not expired.
   239  	select {
   240  	case curCreds, ok := <-c.asyncIsExpired():
   241  		// ok will only be true, of the credentials were not expired. ok will
   242  		// be false and have no value if the credentials are expired.
   243  		if ok {
   244  			return curCreds, nil
   245  		}
   246  	case <-ctx.Done():
   247  		return Value{}, awserr.New("RequestCanceled",
   248  			"request context canceled", ctx.Err())
   249  	}
   250  
   251  	// Cannot pass context down to the actual retrieve, because the first
   252  	// context would cancel the whole group when there is not direct
   253  	// association of items in the group.
   254  	resCh := c.sf.DoChan("", func() (interface{}, error) {
   255  		return c.singleRetrieve(&suppressedContext{ctx})
   256  	})
   257  	select {
   258  	case res := <-resCh:
   259  		return res.Val.(Value), res.Err
   260  	case <-ctx.Done():
   261  		return Value{}, awserr.New("RequestCanceled",
   262  			"request context canceled", ctx.Err())
   263  	}
   264  }
   265  
   266  func (c *Credentials) singleRetrieve(ctx Context) (interface{}, error) {
   267  	c.m.Lock()
   268  	defer c.m.Unlock()
   269  
   270  	if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
   271  		return curCreds, nil
   272  	}
   273  
   274  	var creds Value
   275  	var err error
   276  	if p, ok := c.provider.(ProviderWithContext); ok {
   277  		creds, err = p.RetrieveWithContext(ctx)
   278  	} else {
   279  		creds, err = c.provider.Retrieve()
   280  	}
   281  	if err == nil {
   282  		c.creds = creds
   283  	}
   284  
   285  	return creds, err
   286  }
   287  
   288  // Get returns the credentials value, or error if the credentials Value failed
   289  // to be retrieved.
   290  //
   291  // Will return the cached credentials Value if it has not expired. If the
   292  // credentials Value has expired the Provider's Retrieve() will be called
   293  // to refresh the credentials.
   294  //
   295  // If Credentials.Expire() was called the credentials Value will be force
   296  // expired, and the next call to Get() will cause them to be refreshed.
   297  func (c *Credentials) Get() (Value, error) {
   298  	return c.GetWithContext(backgroundContext())
   299  }
   300  
   301  // Expire expires the credentials and forces them to be retrieved on the
   302  // next call to Get().
   303  //
   304  // This will override the Provider's expired state, and force Credentials
   305  // to call the Provider's Retrieve().
   306  func (c *Credentials) Expire() {
   307  	c.m.Lock()
   308  	defer c.m.Unlock()
   309  
   310  	c.creds = Value{}
   311  }
   312  
   313  // IsExpired returns if the credentials are no longer valid, and need
   314  // to be retrieved.
   315  //
   316  // If the Credentials were forced to be expired with Expire() this will
   317  // reflect that override.
   318  func (c *Credentials) IsExpired() bool {
   319  	c.m.RLock()
   320  	defer c.m.RUnlock()
   321  
   322  	return c.isExpiredLocked(c.creds)
   323  }
   324  
   325  // asyncIsExpired returns a channel of credentials Value. If the channel is
   326  // closed the credentials are expired and credentials value are not empty.
   327  func (c *Credentials) asyncIsExpired() <-chan Value {
   328  	ch := make(chan Value, 1)
   329  	go func() {
   330  		c.m.RLock()
   331  		defer c.m.RUnlock()
   332  
   333  		if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
   334  			ch <- curCreds
   335  		}
   336  
   337  		close(ch)
   338  	}()
   339  
   340  	return ch
   341  }
   342  
   343  // isExpiredLocked helper method wrapping the definition of expired credentials.
   344  func (c *Credentials) isExpiredLocked(creds interface{}) bool {
   345  	return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
   346  }
   347  
   348  // ExpiresAt provides access to the functionality of the Expirer interface of
   349  // the underlying Provider, if it supports that interface.  Otherwise, it returns
   350  // an error.
   351  func (c *Credentials) ExpiresAt() (time.Time, error) {
   352  	c.m.RLock()
   353  	defer c.m.RUnlock()
   354  
   355  	expirer, ok := c.provider.(Expirer)
   356  	if !ok {
   357  		return time.Time{}, awserr.New("ProviderNotExpirer",
   358  			fmt.Sprintf("provider %s does not support ExpiresAt()",
   359  				c.creds.ProviderName),
   360  			nil)
   361  	}
   362  	if c.creds == (Value{}) {
   363  		// set expiration time to the distant past
   364  		return time.Time{}, nil
   365  	}
   366  	return expirer.ExpiresAt(), nil
   367  }
   368  
   369  type suppressedContext struct {
   370  	Context
   371  }
   372  
   373  func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
   374  	return time.Time{}, false
   375  }
   376  
   377  func (s *suppressedContext) Done() <-chan struct{} {
   378  	return nil
   379  }
   380  
   381  func (s *suppressedContext) Err() error {
   382  	return nil
   383  }