github.com/aavshr/aws-sdk-go@v1.41.3/aws/credentials/credentials_test.go (about) 1 package credentials 2 3 import ( 4 "math/rand" 5 "sync" 6 "testing" 7 "time" 8 9 "github.com/aavshr/aws-sdk-go/aws/awserr" 10 ) 11 12 type stubProvider struct { 13 creds Value 14 retrievedCount int 15 expired bool 16 err error 17 } 18 19 func (s *stubProvider) Retrieve() (Value, error) { 20 s.retrievedCount++ 21 s.expired = false 22 s.creds.ProviderName = "stubProvider" 23 return s.creds, s.err 24 } 25 func (s *stubProvider) IsExpired() bool { 26 return s.expired 27 } 28 29 func TestCredentialsGet(t *testing.T) { 30 c := NewCredentials(&stubProvider{ 31 creds: Value{ 32 AccessKeyID: "AKID", 33 SecretAccessKey: "SECRET", 34 SessionToken: "", 35 }, 36 expired: true, 37 }) 38 39 creds, err := c.Get() 40 if err != nil { 41 t.Errorf("Expected no error, got %v", err) 42 } 43 if e, a := "AKID", creds.AccessKeyID; e != a { 44 t.Errorf("Expect access key ID to match, %v got %v", e, a) 45 } 46 if e, a := "SECRET", creds.SecretAccessKey; e != a { 47 t.Errorf("Expect secret access key to match, %v got %v", e, a) 48 } 49 if v := creds.SessionToken; len(v) != 0 { 50 t.Errorf("Expect session token to be empty, %v", v) 51 } 52 } 53 54 func TestCredentialsGetWithError(t *testing.T) { 55 c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true}) 56 57 _, err := c.Get() 58 if e, a := "provider error", err.(awserr.Error).Code(); e != a { 59 t.Errorf("Expected provider error, %v got %v", e, a) 60 } 61 } 62 63 func TestCredentialsExpire(t *testing.T) { 64 stub := &stubProvider{} 65 c := NewCredentials(stub) 66 67 stub.expired = false 68 if !c.IsExpired() { 69 t.Errorf("Expected to start out expired") 70 } 71 c.Expire() 72 if !c.IsExpired() { 73 t.Errorf("Expected to be expired") 74 } 75 76 c.Get() 77 if c.IsExpired() { 78 t.Errorf("Expected not to be expired") 79 } 80 81 stub.expired = true 82 if !c.IsExpired() { 83 t.Errorf("Expected to be expired") 84 } 85 } 86 87 type MockProvider struct { 88 Expiry 89 } 90 91 func (*MockProvider) Retrieve() (Value, error) { 92 return Value{}, nil 93 } 94 95 func TestCredentialsGetWithProviderName(t *testing.T) { 96 stub := &stubProvider{} 97 98 c := NewCredentials(stub) 99 100 creds, err := c.Get() 101 if err != nil { 102 t.Errorf("Expected no error, got %v", err) 103 } 104 if e, a := creds.ProviderName, "stubProvider"; e != a { 105 t.Errorf("Expected provider name to match, %v got %v", e, a) 106 } 107 } 108 109 func TestCredentialsIsExpired_Race(t *testing.T) { 110 creds := NewChainCredentials([]Provider{&MockProvider{}}) 111 112 starter := make(chan struct{}) 113 var wg sync.WaitGroup 114 wg.Add(10) 115 for i := 0; i < 10; i++ { 116 go func() { 117 defer wg.Done() 118 <-starter 119 for i := 0; i < 100; i++ { 120 creds.IsExpired() 121 } 122 }() 123 } 124 close(starter) 125 126 wg.Wait() 127 } 128 129 func TestCredentialsExpiresAt_NoExpirer(t *testing.T) { 130 stub := &stubProvider{} 131 c := NewCredentials(stub) 132 133 _, err := c.ExpiresAt() 134 if e, a := "ProviderNotExpirer", err.(awserr.Error).Code(); e != a { 135 t.Errorf("Expected provider error, %v got %v", e, a) 136 } 137 } 138 139 type stubProviderExpirer struct { 140 stubProvider 141 expiration time.Time 142 } 143 144 func (s *stubProviderExpirer) ExpiresAt() time.Time { 145 return s.expiration 146 } 147 148 func TestCredentialsExpiresAt_HasExpirer(t *testing.T) { 149 stub := &stubProviderExpirer{} 150 c := NewCredentials(stub) 151 152 // fetch initial credentials so that forceRefresh is set false 153 _, err := c.Get() 154 if err != nil { 155 t.Errorf("Unexpecte error: %v", err) 156 } 157 158 stub.expiration = time.Unix(rand.Int63(), 0) 159 expiration, err := c.ExpiresAt() 160 if err != nil { 161 t.Errorf("Expected no error, got %v", err) 162 } 163 if stub.expiration != expiration { 164 t.Errorf("Expected matching expiration, %v got %v", stub.expiration, expiration) 165 } 166 167 c.Expire() 168 expiration, err = c.ExpiresAt() 169 if err != nil { 170 t.Errorf("Expected no error, got %v", err) 171 } 172 if !expiration.IsZero() { 173 t.Errorf("Expected distant past expiration, got %v", expiration) 174 } 175 } 176 177 type stubProviderConcurrent struct { 178 stubProvider 179 done chan struct{} 180 } 181 182 func (s *stubProviderConcurrent) Retrieve() (Value, error) { 183 <-s.done 184 return s.stubProvider.Retrieve() 185 } 186 187 func TestCredentialsGetConcurrent(t *testing.T) { 188 stub := &stubProviderConcurrent{ 189 done: make(chan struct{}), 190 } 191 192 c := NewCredentials(stub) 193 done := make(chan struct{}) 194 195 for i := 0; i < 2; i++ { 196 go func() { 197 c.Get() 198 done <- struct{}{} 199 }() 200 } 201 202 // Validates that a single call to Retrieve is shared between two calls to Get 203 stub.done <- struct{}{} 204 <-done 205 <-done 206 } 207 208 type stubProviderRefreshable struct { 209 creds Value 210 expired bool 211 hasRetrieved bool 212 } 213 214 func (s *stubProviderRefreshable) Retrieve() (Value, error) { 215 // On first retrieval, return the creds that this provider was created with. 216 // On subsequent retrievals, return new refreshed credentials. 217 if !s.hasRetrieved { 218 s.expired = true 219 s.hasRetrieved = true 220 } else { 221 s.creds = Value{ 222 AccessKeyID: "AKID", 223 SecretAccessKey: "SECRET", 224 SessionToken: "NEW_SESSION", 225 } 226 s.expired = false 227 time.Sleep(10 * time.Millisecond) 228 } 229 return s.creds, nil 230 } 231 232 func (s *stubProviderRefreshable) IsExpired() bool { 233 return s.expired 234 } 235 236 func TestCredentialsGet_RefreshableProviderRace(t *testing.T) { 237 stub := &stubProviderRefreshable{ 238 creds: Value{ 239 AccessKeyID: "AKID", 240 SecretAccessKey: "SECRET", 241 SessionToken: "OLD_SESSION", 242 }, 243 } 244 245 c := NewCredentials(stub) 246 247 // The first Get() causes stubProviderRefreshable to consider its 248 // OLD_SESSION credentials expired on subsequent retrievals. 249 creds, err := c.Get() 250 if err != nil { 251 t.Errorf("Expected no error, got %v", err) 252 } 253 if e, a := "OLD_SESSION", creds.SessionToken; e != a { 254 t.Errorf("Expect session token to match, %v got %v", e, a) 255 } 256 257 // Since stubProviderRefreshable considers its OLD_SESSION credentials 258 // expired, all subsequent calls to Get() should retrieve NEW_SESSION creds. 259 var wg sync.WaitGroup 260 wg.Add(100) 261 for i := 0; i < 100; i++ { 262 go func() { 263 defer wg.Done() 264 creds, err := c.Get() 265 if err != nil { 266 t.Errorf("Expected no error, got %v", err) 267 } 268 269 if c.IsExpired() { 270 t.Errorf("not expect expired") 271 } 272 273 if e, a := "NEW_SESSION", creds.SessionToken; e != a { 274 t.Errorf("Expect session token to match, %v got %v", e, a) 275 } 276 }() 277 } 278 279 wg.Wait() 280 }