github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/credentials_test.go (about)

     1  package credentials
     2  
     3  import (
     4  	"math/rand"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    10  )
    11  
    12  type stubProvider struct {
    13  	creds          Value
    14  	retrievedCount int
    15  	expired        bool
    16  	err            error
    17  }
    18  
    19  func (s *stubProvider) Retrieve() (Value, error) {
    20  	s.retrievedCount++
    21  	s.expired = false
    22  	s.creds.ProviderName = "stubProvider"
    23  	return s.creds, s.err
    24  }
    25  func (s *stubProvider) IsExpired() bool {
    26  	return s.expired
    27  }
    28  
    29  func TestCredentialsGet(t *testing.T) {
    30  	c := NewCredentials(&stubProvider{
    31  		creds: Value{
    32  			AccessKeyID:     "AKID",
    33  			SecretAccessKey: "SECRET",
    34  			SessionToken:    "",
    35  		},
    36  		expired: true,
    37  	})
    38  
    39  	creds, err := c.Get()
    40  	if err != nil {
    41  		t.Errorf("Expected no error, got %v", err)
    42  	}
    43  	if e, a := "AKID", creds.AccessKeyID; e != a {
    44  		t.Errorf("Expect access key ID to match, %v got %v", e, a)
    45  	}
    46  	if e, a := "SECRET", creds.SecretAccessKey; e != a {
    47  		t.Errorf("Expect secret access key to match, %v got %v", e, a)
    48  	}
    49  	if v := creds.SessionToken; len(v) != 0 {
    50  		t.Errorf("Expect session token to be empty, %v", v)
    51  	}
    52  }
    53  
    54  func TestCredentialsGetWithError(t *testing.T) {
    55  	c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
    56  
    57  	_, err := c.Get()
    58  	if e, a := "provider error", err.(awserr.Error).Code(); e != a {
    59  		t.Errorf("Expected provider error, %v got %v", e, a)
    60  	}
    61  }
    62  
    63  func TestCredentialsExpire(t *testing.T) {
    64  	stub := &stubProvider{}
    65  	c := NewCredentials(stub)
    66  
    67  	stub.expired = false
    68  	if !c.IsExpired() {
    69  		t.Errorf("Expected to start out expired")
    70  	}
    71  	c.Expire()
    72  	if !c.IsExpired() {
    73  		t.Errorf("Expected to be expired")
    74  	}
    75  
    76  	c.Get()
    77  	if c.IsExpired() {
    78  		t.Errorf("Expected not to be expired")
    79  	}
    80  
    81  	stub.expired = true
    82  	if !c.IsExpired() {
    83  		t.Errorf("Expected to be expired")
    84  	}
    85  }
    86  
    87  type MockProvider struct {
    88  	Expiry
    89  }
    90  
    91  func (*MockProvider) Retrieve() (Value, error) {
    92  	return Value{}, nil
    93  }
    94  
    95  func TestCredentialsGetWithProviderName(t *testing.T) {
    96  	stub := &stubProvider{}
    97  
    98  	c := NewCredentials(stub)
    99  
   100  	creds, err := c.Get()
   101  	if err != nil {
   102  		t.Errorf("Expected no error, got %v", err)
   103  	}
   104  	if e, a := creds.ProviderName, "stubProvider"; e != a {
   105  		t.Errorf("Expected provider name to match, %v got %v", e, a)
   106  	}
   107  }
   108  
   109  func TestCredentialsIsExpired_Race(t *testing.T) {
   110  	creds := NewChainCredentials([]Provider{&MockProvider{}})
   111  
   112  	starter := make(chan struct{})
   113  	var wg sync.WaitGroup
   114  	wg.Add(10)
   115  	for i := 0; i < 10; i++ {
   116  		go func() {
   117  			defer wg.Done()
   118  			<-starter
   119  			for i := 0; i < 100; i++ {
   120  				creds.IsExpired()
   121  			}
   122  		}()
   123  	}
   124  	close(starter)
   125  
   126  	wg.Wait()
   127  }
   128  
   129  func TestCredentialsExpiresAt_NoExpirer(t *testing.T) {
   130  	stub := &stubProvider{}
   131  	c := NewCredentials(stub)
   132  
   133  	_, err := c.ExpiresAt()
   134  	if e, a := "ProviderNotExpirer", err.(awserr.Error).Code(); e != a {
   135  		t.Errorf("Expected provider error, %v got %v", e, a)
   136  	}
   137  }
   138  
   139  type stubProviderExpirer struct {
   140  	stubProvider
   141  	expiration time.Time
   142  }
   143  
   144  func (s *stubProviderExpirer) ExpiresAt() time.Time {
   145  	return s.expiration
   146  }
   147  
   148  func TestCredentialsExpiresAt_HasExpirer(t *testing.T) {
   149  	stub := &stubProviderExpirer{}
   150  	c := NewCredentials(stub)
   151  
   152  	// fetch initial credentials so that forceRefresh is set false
   153  	_, err := c.Get()
   154  	if err != nil {
   155  		t.Errorf("Unexpecte error: %v", err)
   156  	}
   157  
   158  	stub.expiration = time.Unix(rand.Int63(), 0)
   159  	expiration, err := c.ExpiresAt()
   160  	if err != nil {
   161  		t.Errorf("Expected no error, got %v", err)
   162  	}
   163  	if stub.expiration != expiration {
   164  		t.Errorf("Expected matching expiration, %v got %v", stub.expiration, expiration)
   165  	}
   166  
   167  	c.Expire()
   168  	expiration, err = c.ExpiresAt()
   169  	if err != nil {
   170  		t.Errorf("Expected no error, got %v", err)
   171  	}
   172  	if !expiration.IsZero() {
   173  		t.Errorf("Expected distant past expiration, got %v", expiration)
   174  	}
   175  }
   176  
   177  type stubProviderConcurrent struct {
   178  	stubProvider
   179  	done chan struct{}
   180  }
   181  
   182  func (s *stubProviderConcurrent) Retrieve() (Value, error) {
   183  	<-s.done
   184  	return s.stubProvider.Retrieve()
   185  }
   186  
   187  func TestCredentialsGetConcurrent(t *testing.T) {
   188  	stub := &stubProviderConcurrent{
   189  		done: make(chan struct{}),
   190  	}
   191  
   192  	c := NewCredentials(stub)
   193  	done := make(chan struct{})
   194  
   195  	for i := 0; i < 2; i++ {
   196  		go func() {
   197  			c.Get()
   198  			done <- struct{}{}
   199  		}()
   200  	}
   201  
   202  	// Validates that a single call to Retrieve is shared between two calls to Get
   203  	stub.done <- struct{}{}
   204  	<-done
   205  	<-done
   206  }
   207  
   208  type stubProviderRefreshable struct {
   209  	creds        Value
   210  	expired      bool
   211  	hasRetrieved bool
   212  }
   213  
   214  func (s *stubProviderRefreshable) Retrieve() (Value, error) {
   215  	// On first retrieval, return the creds that this provider was created with.
   216  	// On subsequent retrievals, return new refreshed credentials.
   217  	if !s.hasRetrieved {
   218  		s.expired = true
   219  		s.hasRetrieved = true
   220  	} else {
   221  		s.creds = Value{
   222  			AccessKeyID:     "AKID",
   223  			SecretAccessKey: "SECRET",
   224  			SessionToken:    "NEW_SESSION",
   225  		}
   226  		s.expired = false
   227  		time.Sleep(10 * time.Millisecond)
   228  	}
   229  	return s.creds, nil
   230  }
   231  
   232  func (s *stubProviderRefreshable) IsExpired() bool {
   233  	return s.expired
   234  }
   235  
   236  func TestCredentialsGet_RefreshableProviderRace(t *testing.T) {
   237  	stub := &stubProviderRefreshable{
   238  		creds: Value{
   239  			AccessKeyID:     "AKID",
   240  			SecretAccessKey: "SECRET",
   241  			SessionToken:    "OLD_SESSION",
   242  		},
   243  	}
   244  
   245  	c := NewCredentials(stub)
   246  
   247  	// The first Get() causes stubProviderRefreshable to consider its
   248  	// OLD_SESSION credentials expired on subsequent retrievals.
   249  	creds, err := c.Get()
   250  	if err != nil {
   251  		t.Errorf("Expected no error, got %v", err)
   252  	}
   253  	if e, a := "OLD_SESSION", creds.SessionToken; e != a {
   254  		t.Errorf("Expect session token to match, %v got %v", e, a)
   255  	}
   256  
   257  	// Since stubProviderRefreshable considers its OLD_SESSION credentials
   258  	// expired, all subsequent calls to Get() should retrieve NEW_SESSION creds.
   259  	var wg sync.WaitGroup
   260  	wg.Add(100)
   261  	for i := 0; i < 100; i++ {
   262  		go func() {
   263  			defer wg.Done()
   264  			creds, err := c.Get()
   265  			if err != nil {
   266  				t.Errorf("Expected no error, got %v", err)
   267  			}
   268  
   269  			if c.IsExpired() {
   270  				t.Errorf("not expect expired")
   271  			}
   272  
   273  			if e, a := "NEW_SESSION", creds.SessionToken; e != a {
   274  				t.Errorf("Expect session token to match, %v got %v", e, a)
   275  			}
   276  		}()
   277  	}
   278  
   279  	wg.Wait()
   280  }