k8s.io/client-go@v0.22.2/plugin/pkg/client/auth/azure/azure_test.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package azure
    18  
    19  import (
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"net/http"
    24  	"strconv"
    25  	"strings"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/Azure/go-autorest/autorest/adal"
    31  	"github.com/Azure/go-autorest/autorest/azure"
    32  )
    33  
    34  func TestAzureAuthProvider(t *testing.T) {
    35  	t.Run("validate against invalid configurations", func(t *testing.T) {
    36  		vectors := []struct {
    37  			cfg           map[string]string
    38  			expectedError string
    39  		}{
    40  			{
    41  				cfg: map[string]string{
    42  					cfgClientID:    "foo",
    43  					cfgApiserverID: "foo",
    44  					cfgTenantID:    "foo",
    45  					cfgConfigMode:  "-1",
    46  				},
    47  				expectedError: "config-mode:-1 is not a valid mode",
    48  			},
    49  			{
    50  				cfg: map[string]string{
    51  					cfgClientID:    "foo",
    52  					cfgApiserverID: "foo",
    53  					cfgTenantID:    "foo",
    54  					cfgConfigMode:  "2",
    55  				},
    56  				expectedError: "config-mode:2 is not a valid mode",
    57  			},
    58  			{
    59  				cfg: map[string]string{
    60  					cfgClientID:    "foo",
    61  					cfgApiserverID: "foo",
    62  					cfgTenantID:    "foo",
    63  					cfgConfigMode:  "foo",
    64  				},
    65  				expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax",
    66  			},
    67  		}
    68  
    69  		for _, v := range vectors {
    70  			persister := &fakePersister{}
    71  			_, err := newAzureAuthProvider("", v.cfg, persister)
    72  			if !strings.Contains(err.Error(), v.expectedError) {
    73  				t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err)
    74  			}
    75  		}
    76  	})
    77  
    78  	t.Run("it should return non-nil provider in happy cases", func(t *testing.T) {
    79  		vectors := []struct {
    80  			cfg                map[string]string
    81  			expectedConfigMode configMode
    82  		}{
    83  			{
    84  				cfg: map[string]string{
    85  					cfgClientID:    "foo",
    86  					cfgApiserverID: "foo",
    87  					cfgTenantID:    "foo",
    88  				},
    89  				expectedConfigMode: configModeDefault,
    90  			},
    91  			{
    92  				cfg: map[string]string{
    93  					cfgClientID:    "foo",
    94  					cfgApiserverID: "foo",
    95  					cfgTenantID:    "foo",
    96  					cfgConfigMode:  "0",
    97  				},
    98  				expectedConfigMode: configModeDefault,
    99  			},
   100  			{
   101  				cfg: map[string]string{
   102  					cfgClientID:    "foo",
   103  					cfgApiserverID: "foo",
   104  					cfgTenantID:    "foo",
   105  					cfgConfigMode:  "1",
   106  				},
   107  				expectedConfigMode: configModeOmitSPNPrefix,
   108  			},
   109  		}
   110  
   111  		for _, v := range vectors {
   112  			persister := &fakePersister{}
   113  			provider, err := newAzureAuthProvider("", v.cfg, persister)
   114  			if err != nil {
   115  				t.Errorf("newAzureAuthProvider should not fail with '%s'", err)
   116  			}
   117  			if provider == nil {
   118  				t.Fatalf("newAzureAuthProvider should return non-nil provider")
   119  			}
   120  			azureProvider := provider.(*azureAuthProvider)
   121  			if azureProvider == nil {
   122  				t.Fatalf("newAzureAuthProvider should return an instance of type azureAuthProvider")
   123  			}
   124  			ts := azureProvider.tokenSource.(*azureTokenSource)
   125  			if ts == nil {
   126  				t.Fatalf("azureAuthProvider should be an instance of azureTokenSource")
   127  			}
   128  			if ts.configMode != v.expectedConfigMode {
   129  				t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode)
   130  			}
   131  		}
   132  	})
   133  }
   134  
   135  func TestTokenSourceDeviceCode(t *testing.T) {
   136  	var (
   137  		clientID    = "clientID"
   138  		tenantID    = "tenantID"
   139  		apiserverID = "apiserverID"
   140  		configMode  = configModeDefault
   141  		azureEnv    = azure.Environment{}
   142  	)
   143  	t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) {
   144  		if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil {
   145  			t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
   146  		}
   147  
   148  		if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil {
   149  			t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
   150  		}
   151  
   152  		_, err := newAzureTokenSourceDeviceCode(azureEnv, "", tenantID, apiserverID, configMode)
   153  		actual := "client-id is empty"
   154  		if err.Error() != actual {
   155  			t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
   156  		}
   157  
   158  		_, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, "", apiserverID, configMode)
   159  		actual = "tenant-id is empty"
   160  		if err.Error() != actual {
   161  			t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
   162  		}
   163  
   164  		_, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, "", configMode)
   165  		actual = "apiserver-id is empty"
   166  		if err.Error() != actual {
   167  			t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
   168  		}
   169  	})
   170  }
   171  func TestAzureTokenSource(t *testing.T) {
   172  	configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
   173  	expectedConfigModes := []string{"1", "0"}
   174  
   175  	for i, configMode := range configModes {
   176  		t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) {
   177  			const (
   178  				serverID     = "fakeServerID"
   179  				clientID     = "fakeClientID"
   180  				tenantID     = "fakeTenantID"
   181  				accessToken  = "fakeToken"
   182  				environment  = "fakeEnvironment"
   183  				refreshToken = "fakeToken"
   184  				expiresIn    = "foo"
   185  				expiresOn    = "foo"
   186  			)
   187  			cfg := map[string]string{
   188  				cfgConfigMode:   strconv.Itoa(int(configMode)),
   189  				cfgApiserverID:  serverID,
   190  				cfgClientID:     clientID,
   191  				cfgTenantID:     tenantID,
   192  				cfgEnvironment:  environment,
   193  				cfgAccessToken:  accessToken,
   194  				cfgRefreshToken: refreshToken,
   195  				cfgExpiresIn:    expiresIn,
   196  				cfgExpiresOn:    expiresOn,
   197  			}
   198  			fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))}
   199  			persiter := &fakePersister{cache: make(map[string]string)}
   200  			tokenCache := newAzureTokenCache()
   201  			tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
   202  			azTokenSource := tokenSource.(*azureTokenSource)
   203  			token, err := azTokenSource.retrieveTokenFromCfg()
   204  			if err != nil {
   205  				t.Errorf("failed to retrieve the token form cfg: %s", err)
   206  			}
   207  			if token.apiserverID != serverID {
   208  				t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID)
   209  			}
   210  			if token.clientID != clientID {
   211  				t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID)
   212  			}
   213  			if token.tenantID != tenantID {
   214  				t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID)
   215  			}
   216  			expectedAudience := serverID
   217  			if configMode == configModeDefault {
   218  				expectedAudience = fmt.Sprintf("spn:%s", serverID)
   219  			}
   220  			if token.token.Resource != expectedAudience {
   221  				t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource)
   222  			}
   223  		})
   224  
   225  		t.Run("validate token against cache", func(t *testing.T) {
   226  			fakeAccessToken := "fake token 1"
   227  			fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
   228  			cfg := make(map[string]string)
   229  			persiter := &fakePersister{cache: make(map[string]string)}
   230  			tokenCache := newAzureTokenCache()
   231  			tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
   232  			token, err := tokenSource.Token()
   233  			if err != nil {
   234  				t.Errorf("failed to retrieve the token form cache: %v", err)
   235  			}
   236  
   237  			wantCacheLen := 1
   238  			if len(tokenCache.cache) != wantCacheLen {
   239  				t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen)
   240  			}
   241  
   242  			if token != tokenCache.cache[azureTokenKey] {
   243  				t.Error("Token() returned token != cached token")
   244  			}
   245  
   246  			wantCfg := token2Cfg(token)
   247  			wantCfg[cfgConfigMode] = expectedConfigModes[i]
   248  			persistedCfg := persiter.Cache()
   249  
   250  			wantCfgLen := len(wantCfg)
   251  			persistedCfgLen := len(persistedCfg)
   252  			if wantCfgLen != persistedCfgLen {
   253  				t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen)
   254  			}
   255  
   256  			for k, v := range persistedCfg {
   257  				if strings.Compare(v, wantCfg[k]) != 0 {
   258  					t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k])
   259  				}
   260  			}
   261  
   262  			fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
   263  			token, err = tokenSource.Token()
   264  			if err != nil {
   265  				t.Errorf("failed to retrieve the cached token: %v", err)
   266  			}
   267  
   268  			if token.token.AccessToken != fakeAccessToken {
   269  				t.Errorf("Token() didn't return the cached token")
   270  			}
   271  		})
   272  	}
   273  }
   274  
   275  func TestAzureTokenSourceScenarios(t *testing.T) {
   276  	expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
   277  	extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
   278  	fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
   279  	wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))
   280  	tests := []struct {
   281  		name         string
   282  		sourceToken  *azureToken
   283  		refreshToken *azureToken
   284  		cachedToken  *azureToken
   285  		configToken  *azureToken
   286  		expectToken  *azureToken
   287  		tokenErr     error
   288  		refreshErr   error
   289  		expectErr    string
   290  		tokenCalls   uint
   291  		refreshCalls uint
   292  		persistCalls uint
   293  	}{
   294  		{
   295  			name:         "new config",
   296  			sourceToken:  fakeToken,
   297  			expectToken:  fakeToken,
   298  			tokenCalls:   1,
   299  			persistCalls: 1,
   300  		},
   301  		{
   302  			name:        "load token from cache",
   303  			sourceToken: wrongToken,
   304  			cachedToken: fakeToken,
   305  			configToken: wrongToken,
   306  			expectToken: fakeToken,
   307  		},
   308  		{
   309  			name:        "load token from config",
   310  			sourceToken: wrongToken,
   311  			configToken: fakeToken,
   312  			expectToken: fakeToken,
   313  		},
   314  		{
   315  			name:         "cached token timeout, extend success, config token should never load",
   316  			cachedToken:  expiredToken,
   317  			refreshToken: extendedToken,
   318  			configToken:  wrongToken,
   319  			expectToken:  extendedToken,
   320  			refreshCalls: 1,
   321  			persistCalls: 1,
   322  		},
   323  		{
   324  			name:         "config token timeout, extend failure, acquire new token",
   325  			configToken:  expiredToken,
   326  			refreshErr:   fakeTokenRefreshError{message: "FakeError happened when refreshing"},
   327  			sourceToken:  fakeToken,
   328  			expectToken:  fakeToken,
   329  			refreshCalls: 1,
   330  			tokenCalls:   1,
   331  			persistCalls: 1,
   332  		},
   333  		{
   334  			name:         "extend failure with fmt.Errorf nested tokenRefreshError",
   335  			configToken:  expiredToken,
   336  			refreshErr:   fmt.Errorf("refreshing token: %w", fakeTokenRefreshError{message: "nested FakeError happened when refreshing"}),
   337  			sourceToken:  fakeToken,
   338  			expectToken:  fakeToken,
   339  			refreshCalls: 1,
   340  			tokenCalls:   1,
   341  			persistCalls: 1,
   342  		},
   343  		{
   344  			name:         "unexpected error when extend",
   345  			configToken:  expiredToken,
   346  			refreshErr:   errors.New("unexpected refresh error"),
   347  			sourceToken:  fakeToken,
   348  			expectErr:    "unexpected refresh error",
   349  			refreshCalls: 1,
   350  		},
   351  		{
   352  			name:       "token error",
   353  			tokenErr:   errors.New("tokenerr"),
   354  			expectErr:  "tokenerr",
   355  			tokenCalls: 1,
   356  		},
   357  		{
   358  			name:        "Token() got expired token",
   359  			sourceToken: expiredToken,
   360  			expectErr:   "newly acquired token is expired",
   361  			tokenCalls:  1,
   362  		},
   363  		{
   364  			name:        "Token() got nil but no error",
   365  			sourceToken: nil,
   366  			expectErr:   "unable to acquire token",
   367  			tokenCalls:  1,
   368  		},
   369  	}
   370  	for _, tc := range tests {
   371  		configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
   372  
   373  		for _, configMode := range configModes {
   374  			t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) {
   375  				persister := newFakePersister()
   376  
   377  				cfg := map[string]string{
   378  					cfgConfigMode: strconv.Itoa(int(configMode)),
   379  				}
   380  				if tc.configToken != nil {
   381  					cfg = token2Cfg(tc.configToken)
   382  				}
   383  
   384  				tokenCache := newAzureTokenCache()
   385  				if tc.cachedToken != nil {
   386  					tokenCache.setToken(azureTokenKey, tc.cachedToken)
   387  				}
   388  
   389  				fakeSource := fakeTokenSource{
   390  					token:        tc.sourceToken,
   391  					tokenErr:     tc.tokenErr,
   392  					refreshToken: tc.refreshToken,
   393  					refreshErr:   tc.refreshErr,
   394  				}
   395  
   396  				tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
   397  				token, err := tokenSource.Token()
   398  
   399  				if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID {
   400  					t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID)
   401  				}
   402  				if fakeSource.tokenCalls != tc.tokenCalls {
   403  					t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
   404  				}
   405  
   406  				if fakeSource.refreshCalls != tc.refreshCalls {
   407  					t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
   408  				}
   409  
   410  				if persister.calls != tc.persistCalls {
   411  					t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
   412  				}
   413  
   414  				if tc.expectErr != "" {
   415  					if !strings.Contains(err.Error(), tc.expectErr) {
   416  						t.Errorf("expecting error %v, got %v", tc.expectErr, err)
   417  					}
   418  					if token != nil {
   419  						t.Errorf("token should be nil in err situation, got %v", token)
   420  					}
   421  				} else {
   422  					if err != nil {
   423  						t.Fatalf("error should be nil, got %v", err)
   424  					}
   425  					if token.token.AccessToken != tc.expectToken.token.AccessToken {
   426  						t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
   427  					}
   428  				}
   429  			})
   430  		}
   431  	}
   432  }
   433  
   434  type fakePersister struct {
   435  	lock  sync.Mutex
   436  	cache map[string]string
   437  	calls uint
   438  }
   439  
   440  func newFakePersister() fakePersister {
   441  	return fakePersister{cache: make(map[string]string), calls: 0}
   442  }
   443  
   444  func (p *fakePersister) Persist(cache map[string]string) error {
   445  	p.lock.Lock()
   446  	defer p.lock.Unlock()
   447  	p.calls++
   448  	p.cache = map[string]string{}
   449  	for k, v := range cache {
   450  		p.cache[k] = v
   451  	}
   452  	return nil
   453  }
   454  
   455  func (p *fakePersister) Cache() map[string]string {
   456  	ret := map[string]string{}
   457  	p.lock.Lock()
   458  	defer p.lock.Unlock()
   459  	for k, v := range p.cache {
   460  		ret[k] = v
   461  	}
   462  	return ret
   463  }
   464  
   465  // a simple token source simply always returns the token property
   466  type fakeTokenSource struct {
   467  	token        *azureToken
   468  	tokenCalls   uint
   469  	tokenErr     error
   470  	refreshToken *azureToken
   471  	refreshCalls uint
   472  	refreshErr   error
   473  }
   474  
   475  func (ts *fakeTokenSource) Token() (*azureToken, error) {
   476  	ts.tokenCalls++
   477  	return ts.token, ts.tokenErr
   478  }
   479  
   480  func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) {
   481  	ts.refreshCalls++
   482  	return ts.refreshToken, ts.refreshErr
   483  }
   484  
   485  func token2Cfg(token *azureToken) map[string]string {
   486  	cfg := make(map[string]string)
   487  	cfg[cfgAccessToken] = token.token.AccessToken
   488  	cfg[cfgRefreshToken] = token.token.RefreshToken
   489  	cfg[cfgEnvironment] = token.environment
   490  	cfg[cfgClientID] = token.clientID
   491  	cfg[cfgTenantID] = token.tenantID
   492  	cfg[cfgApiserverID] = token.apiserverID
   493  	cfg[cfgExpiresIn] = string(token.token.ExpiresIn)
   494  	cfg[cfgExpiresOn] = string(token.token.ExpiresOn)
   495  	return cfg
   496  }
   497  
   498  func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
   499  	return &azureToken{
   500  		token:       newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)),
   501  		environment: "testenv",
   502  		clientID:    "fake",
   503  		tenantID:    "fake",
   504  		apiserverID: "fake",
   505  	}
   506  }
   507  
   508  func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
   509  	return adal.Token{
   510  		AccessToken:  accessToken,
   511  		RefreshToken: "fake",
   512  		ExpiresIn:    "3600",
   513  		ExpiresOn:    json.Number(expiresOn),
   514  		NotBefore:    json.Number(expiresOn),
   515  		Resource:     "fake",
   516  		Type:         "fake",
   517  	}
   518  }
   519  
   520  // copied from go-autorest/adal
   521  type fakeTokenRefreshError struct {
   522  	message string
   523  	resp    *http.Response
   524  }
   525  
   526  // Error implements the error interface which is part of the TokenRefreshError interface.
   527  func (tre fakeTokenRefreshError) Error() string {
   528  	return tre.message
   529  }
   530  
   531  // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
   532  func (tre fakeTokenRefreshError) Response() *http.Response {
   533  	return tre.resp
   534  }