k8s.io/client-go@v0.22.2/plugin/pkg/client/auth/azure/azure_test.go (about) 1 /* 2 Copyright 2017 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package azure 18 19 import ( 20 "encoding/json" 21 "errors" 22 "fmt" 23 "net/http" 24 "strconv" 25 "strings" 26 "sync" 27 "testing" 28 "time" 29 30 "github.com/Azure/go-autorest/autorest/adal" 31 "github.com/Azure/go-autorest/autorest/azure" 32 ) 33 34 func TestAzureAuthProvider(t *testing.T) { 35 t.Run("validate against invalid configurations", func(t *testing.T) { 36 vectors := []struct { 37 cfg map[string]string 38 expectedError string 39 }{ 40 { 41 cfg: map[string]string{ 42 cfgClientID: "foo", 43 cfgApiserverID: "foo", 44 cfgTenantID: "foo", 45 cfgConfigMode: "-1", 46 }, 47 expectedError: "config-mode:-1 is not a valid mode", 48 }, 49 { 50 cfg: map[string]string{ 51 cfgClientID: "foo", 52 cfgApiserverID: "foo", 53 cfgTenantID: "foo", 54 cfgConfigMode: "2", 55 }, 56 expectedError: "config-mode:2 is not a valid mode", 57 }, 58 { 59 cfg: map[string]string{ 60 cfgClientID: "foo", 61 cfgApiserverID: "foo", 62 cfgTenantID: "foo", 63 cfgConfigMode: "foo", 64 }, 65 expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax", 66 }, 67 } 68 69 for _, v := range vectors { 70 persister := &fakePersister{} 71 _, err := newAzureAuthProvider("", v.cfg, persister) 72 if !strings.Contains(err.Error(), v.expectedError) { 73 t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err) 74 } 75 } 76 }) 77 78 t.Run("it should return non-nil provider in happy cases", func(t *testing.T) { 79 vectors := []struct { 80 cfg map[string]string 81 expectedConfigMode configMode 82 }{ 83 { 84 cfg: map[string]string{ 85 cfgClientID: "foo", 86 cfgApiserverID: "foo", 87 cfgTenantID: "foo", 88 }, 89 expectedConfigMode: configModeDefault, 90 }, 91 { 92 cfg: map[string]string{ 93 cfgClientID: "foo", 94 cfgApiserverID: "foo", 95 cfgTenantID: "foo", 96 cfgConfigMode: "0", 97 }, 98 expectedConfigMode: configModeDefault, 99 }, 100 { 101 cfg: map[string]string{ 102 cfgClientID: "foo", 103 cfgApiserverID: "foo", 104 cfgTenantID: "foo", 105 cfgConfigMode: "1", 106 }, 107 expectedConfigMode: configModeOmitSPNPrefix, 108 }, 109 } 110 111 for _, v := range vectors { 112 persister := &fakePersister{} 113 provider, err := newAzureAuthProvider("", v.cfg, persister) 114 if err != nil { 115 t.Errorf("newAzureAuthProvider should not fail with '%s'", err) 116 } 117 if provider == nil { 118 t.Fatalf("newAzureAuthProvider should return non-nil provider") 119 } 120 azureProvider := provider.(*azureAuthProvider) 121 if azureProvider == nil { 122 t.Fatalf("newAzureAuthProvider should return an instance of type azureAuthProvider") 123 } 124 ts := azureProvider.tokenSource.(*azureTokenSource) 125 if ts == nil { 126 t.Fatalf("azureAuthProvider should be an instance of azureTokenSource") 127 } 128 if ts.configMode != v.expectedConfigMode { 129 t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode) 130 } 131 } 132 }) 133 } 134 135 func TestTokenSourceDeviceCode(t *testing.T) { 136 var ( 137 clientID = "clientID" 138 tenantID = "tenantID" 139 apiserverID = "apiserverID" 140 configMode = configModeDefault 141 azureEnv = azure.Environment{} 142 ) 143 t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) { 144 if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil { 145 t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err) 146 } 147 148 if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil { 149 t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err) 150 } 151 152 _, err := newAzureTokenSourceDeviceCode(azureEnv, "", tenantID, apiserverID, configMode) 153 actual := "client-id is empty" 154 if err.Error() != actual { 155 t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) 156 } 157 158 _, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, "", apiserverID, configMode) 159 actual = "tenant-id is empty" 160 if err.Error() != actual { 161 t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) 162 } 163 164 _, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, "", configMode) 165 actual = "apiserver-id is empty" 166 if err.Error() != actual { 167 t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) 168 } 169 }) 170 } 171 func TestAzureTokenSource(t *testing.T) { 172 configModes := []configMode{configModeOmitSPNPrefix, configModeDefault} 173 expectedConfigModes := []string{"1", "0"} 174 175 for i, configMode := range configModes { 176 t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) { 177 const ( 178 serverID = "fakeServerID" 179 clientID = "fakeClientID" 180 tenantID = "fakeTenantID" 181 accessToken = "fakeToken" 182 environment = "fakeEnvironment" 183 refreshToken = "fakeToken" 184 expiresIn = "foo" 185 expiresOn = "foo" 186 ) 187 cfg := map[string]string{ 188 cfgConfigMode: strconv.Itoa(int(configMode)), 189 cfgApiserverID: serverID, 190 cfgClientID: clientID, 191 cfgTenantID: tenantID, 192 cfgEnvironment: environment, 193 cfgAccessToken: accessToken, 194 cfgRefreshToken: refreshToken, 195 cfgExpiresIn: expiresIn, 196 cfgExpiresOn: expiresOn, 197 } 198 fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))} 199 persiter := &fakePersister{cache: make(map[string]string)} 200 tokenCache := newAzureTokenCache() 201 tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) 202 azTokenSource := tokenSource.(*azureTokenSource) 203 token, err := azTokenSource.retrieveTokenFromCfg() 204 if err != nil { 205 t.Errorf("failed to retrieve the token form cfg: %s", err) 206 } 207 if token.apiserverID != serverID { 208 t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID) 209 } 210 if token.clientID != clientID { 211 t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID) 212 } 213 if token.tenantID != tenantID { 214 t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID) 215 } 216 expectedAudience := serverID 217 if configMode == configModeDefault { 218 expectedAudience = fmt.Sprintf("spn:%s", serverID) 219 } 220 if token.token.Resource != expectedAudience { 221 t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource) 222 } 223 }) 224 225 t.Run("validate token against cache", func(t *testing.T) { 226 fakeAccessToken := "fake token 1" 227 fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))} 228 cfg := make(map[string]string) 229 persiter := &fakePersister{cache: make(map[string]string)} 230 tokenCache := newAzureTokenCache() 231 tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) 232 token, err := tokenSource.Token() 233 if err != nil { 234 t.Errorf("failed to retrieve the token form cache: %v", err) 235 } 236 237 wantCacheLen := 1 238 if len(tokenCache.cache) != wantCacheLen { 239 t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen) 240 } 241 242 if token != tokenCache.cache[azureTokenKey] { 243 t.Error("Token() returned token != cached token") 244 } 245 246 wantCfg := token2Cfg(token) 247 wantCfg[cfgConfigMode] = expectedConfigModes[i] 248 persistedCfg := persiter.Cache() 249 250 wantCfgLen := len(wantCfg) 251 persistedCfgLen := len(persistedCfg) 252 if wantCfgLen != persistedCfgLen { 253 t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen) 254 } 255 256 for k, v := range persistedCfg { 257 if strings.Compare(v, wantCfg[k]) != 0 { 258 t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k]) 259 } 260 } 261 262 fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second)) 263 token, err = tokenSource.Token() 264 if err != nil { 265 t.Errorf("failed to retrieve the cached token: %v", err) 266 } 267 268 if token.token.AccessToken != fakeAccessToken { 269 t.Errorf("Token() didn't return the cached token") 270 } 271 }) 272 } 273 } 274 275 func TestAzureTokenSourceScenarios(t *testing.T) { 276 expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second)) 277 extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second)) 278 fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second)) 279 wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second)) 280 tests := []struct { 281 name string 282 sourceToken *azureToken 283 refreshToken *azureToken 284 cachedToken *azureToken 285 configToken *azureToken 286 expectToken *azureToken 287 tokenErr error 288 refreshErr error 289 expectErr string 290 tokenCalls uint 291 refreshCalls uint 292 persistCalls uint 293 }{ 294 { 295 name: "new config", 296 sourceToken: fakeToken, 297 expectToken: fakeToken, 298 tokenCalls: 1, 299 persistCalls: 1, 300 }, 301 { 302 name: "load token from cache", 303 sourceToken: wrongToken, 304 cachedToken: fakeToken, 305 configToken: wrongToken, 306 expectToken: fakeToken, 307 }, 308 { 309 name: "load token from config", 310 sourceToken: wrongToken, 311 configToken: fakeToken, 312 expectToken: fakeToken, 313 }, 314 { 315 name: "cached token timeout, extend success, config token should never load", 316 cachedToken: expiredToken, 317 refreshToken: extendedToken, 318 configToken: wrongToken, 319 expectToken: extendedToken, 320 refreshCalls: 1, 321 persistCalls: 1, 322 }, 323 { 324 name: "config token timeout, extend failure, acquire new token", 325 configToken: expiredToken, 326 refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"}, 327 sourceToken: fakeToken, 328 expectToken: fakeToken, 329 refreshCalls: 1, 330 tokenCalls: 1, 331 persistCalls: 1, 332 }, 333 { 334 name: "extend failure with fmt.Errorf nested tokenRefreshError", 335 configToken: expiredToken, 336 refreshErr: fmt.Errorf("refreshing token: %w", fakeTokenRefreshError{message: "nested FakeError happened when refreshing"}), 337 sourceToken: fakeToken, 338 expectToken: fakeToken, 339 refreshCalls: 1, 340 tokenCalls: 1, 341 persistCalls: 1, 342 }, 343 { 344 name: "unexpected error when extend", 345 configToken: expiredToken, 346 refreshErr: errors.New("unexpected refresh error"), 347 sourceToken: fakeToken, 348 expectErr: "unexpected refresh error", 349 refreshCalls: 1, 350 }, 351 { 352 name: "token error", 353 tokenErr: errors.New("tokenerr"), 354 expectErr: "tokenerr", 355 tokenCalls: 1, 356 }, 357 { 358 name: "Token() got expired token", 359 sourceToken: expiredToken, 360 expectErr: "newly acquired token is expired", 361 tokenCalls: 1, 362 }, 363 { 364 name: "Token() got nil but no error", 365 sourceToken: nil, 366 expectErr: "unable to acquire token", 367 tokenCalls: 1, 368 }, 369 } 370 for _, tc := range tests { 371 configModes := []configMode{configModeOmitSPNPrefix, configModeDefault} 372 373 for _, configMode := range configModes { 374 t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) { 375 persister := newFakePersister() 376 377 cfg := map[string]string{ 378 cfgConfigMode: strconv.Itoa(int(configMode)), 379 } 380 if tc.configToken != nil { 381 cfg = token2Cfg(tc.configToken) 382 } 383 384 tokenCache := newAzureTokenCache() 385 if tc.cachedToken != nil { 386 tokenCache.setToken(azureTokenKey, tc.cachedToken) 387 } 388 389 fakeSource := fakeTokenSource{ 390 token: tc.sourceToken, 391 tokenErr: tc.tokenErr, 392 refreshToken: tc.refreshToken, 393 refreshErr: tc.refreshErr, 394 } 395 396 tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister) 397 token, err := tokenSource.Token() 398 399 if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID { 400 t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID) 401 } 402 if fakeSource.tokenCalls != tc.tokenCalls { 403 t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls) 404 } 405 406 if fakeSource.refreshCalls != tc.refreshCalls { 407 t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls) 408 } 409 410 if persister.calls != tc.persistCalls { 411 t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls) 412 } 413 414 if tc.expectErr != "" { 415 if !strings.Contains(err.Error(), tc.expectErr) { 416 t.Errorf("expecting error %v, got %v", tc.expectErr, err) 417 } 418 if token != nil { 419 t.Errorf("token should be nil in err situation, got %v", token) 420 } 421 } else { 422 if err != nil { 423 t.Fatalf("error should be nil, got %v", err) 424 } 425 if token.token.AccessToken != tc.expectToken.token.AccessToken { 426 t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken) 427 } 428 } 429 }) 430 } 431 } 432 } 433 434 type fakePersister struct { 435 lock sync.Mutex 436 cache map[string]string 437 calls uint 438 } 439 440 func newFakePersister() fakePersister { 441 return fakePersister{cache: make(map[string]string), calls: 0} 442 } 443 444 func (p *fakePersister) Persist(cache map[string]string) error { 445 p.lock.Lock() 446 defer p.lock.Unlock() 447 p.calls++ 448 p.cache = map[string]string{} 449 for k, v := range cache { 450 p.cache[k] = v 451 } 452 return nil 453 } 454 455 func (p *fakePersister) Cache() map[string]string { 456 ret := map[string]string{} 457 p.lock.Lock() 458 defer p.lock.Unlock() 459 for k, v := range p.cache { 460 ret[k] = v 461 } 462 return ret 463 } 464 465 // a simple token source simply always returns the token property 466 type fakeTokenSource struct { 467 token *azureToken 468 tokenCalls uint 469 tokenErr error 470 refreshToken *azureToken 471 refreshCalls uint 472 refreshErr error 473 } 474 475 func (ts *fakeTokenSource) Token() (*azureToken, error) { 476 ts.tokenCalls++ 477 return ts.token, ts.tokenErr 478 } 479 480 func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) { 481 ts.refreshCalls++ 482 return ts.refreshToken, ts.refreshErr 483 } 484 485 func token2Cfg(token *azureToken) map[string]string { 486 cfg := make(map[string]string) 487 cfg[cfgAccessToken] = token.token.AccessToken 488 cfg[cfgRefreshToken] = token.token.RefreshToken 489 cfg[cfgEnvironment] = token.environment 490 cfg[cfgClientID] = token.clientID 491 cfg[cfgTenantID] = token.tenantID 492 cfg[cfgApiserverID] = token.apiserverID 493 cfg[cfgExpiresIn] = string(token.token.ExpiresIn) 494 cfg[cfgExpiresOn] = string(token.token.ExpiresOn) 495 return cfg 496 } 497 498 func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken { 499 return &azureToken{ 500 token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)), 501 environment: "testenv", 502 clientID: "fake", 503 tenantID: "fake", 504 apiserverID: "fake", 505 } 506 } 507 508 func newFakeADALToken(accessToken string, expiresOn string) adal.Token { 509 return adal.Token{ 510 AccessToken: accessToken, 511 RefreshToken: "fake", 512 ExpiresIn: "3600", 513 ExpiresOn: json.Number(expiresOn), 514 NotBefore: json.Number(expiresOn), 515 Resource: "fake", 516 Type: "fake", 517 } 518 } 519 520 // copied from go-autorest/adal 521 type fakeTokenRefreshError struct { 522 message string 523 resp *http.Response 524 } 525 526 // Error implements the error interface which is part of the TokenRefreshError interface. 527 func (tre fakeTokenRefreshError) Error() string { 528 return tre.message 529 } 530 531 // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. 532 func (tre fakeTokenRefreshError) Response() *http.Response { 533 return tre.resp 534 }