github.com/argoproj/argo-cd/v3@v3.2.1/util/helm/creds_test.go (about) 1 package helm 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 "time" 9 10 "github.com/golang-jwt/jwt/v5" 11 gocache "github.com/patrickmn/go-cache" 12 "github.com/stretchr/testify/assert" 13 "github.com/stretchr/testify/require" 14 15 argoutils "github.com/argoproj/argo-cd/v3/util" 16 "github.com/argoproj/argo-cd/v3/util/workloadidentity" 17 "github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks" 18 ) 19 20 func TestWorkLoadIdentityUserNameShouldBeEmptyGuid(t *testing.T) { 21 workloadIdentityMock := new(mocks.TokenProvider) 22 creds := NewAzureWorkloadIdentityCreds("contoso.azurecr.io/charts", "", nil, nil, false, workloadIdentityMock) 23 username := creds.GetUsername() 24 25 assert.Equal(t, workloadidentity.EmptyGuid, username, "The username for azure workload identity is not empty Guid") 26 } 27 28 func TestGetAccessTokenShouldReturnTokenFromCacheIfPresent(t *testing.T) { 29 workloadIdentityMock := new(mocks.TokenProvider) 30 creds := NewAzureWorkloadIdentityCreds("contoso.azurecr.io/charts", "", nil, nil, false, workloadIdentityMock) 31 32 cacheKey, err := argoutils.GenerateCacheKey("accesstoken-%s", "contoso.azurecr.io") 33 require.NoError(t, err, "Error generating cache key") 34 35 // Store the token in the cache 36 storeAzureToken(cacheKey, "testToken", time.Hour) 37 38 // Retrieve the token from the cache 39 token, err := creds.GetAccessToken() 40 require.NoError(t, err, "Error getting access token") 41 assert.Equal(t, "testToken", token, "The retrieved token should match the stored token") 42 } 43 44 func TestGetPasswordShouldReturnTokenFromCacheIfPresent(t *testing.T) { 45 workloadIdentityMock := new(mocks.TokenProvider) 46 creds := NewAzureWorkloadIdentityCreds("contoso.azurecr.io/charts", "", nil, nil, false, workloadIdentityMock) 47 48 cacheKey, err := argoutils.GenerateCacheKey("accesstoken-%s", "contoso.azurecr.io") 49 require.NoError(t, err, "Error generating cache key") 50 51 // Store the token in the cache 52 storeAzureToken(cacheKey, "testToken", time.Hour) 53 54 // Retrieve the token from the cache 55 token, err := creds.GetPassword() 56 require.NoError(t, err, "Error getting access token") 57 assert.Equal(t, "testToken", token, "The retrieved token should match the stored token") 58 } 59 60 func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) { 61 mockServerURL := "" 62 mockedServerURL := func() string { 63 return mockServerURL 64 } 65 66 // Mock the server to return a successful response 67 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 68 switch r.URL.Path { 69 case "/v2/": 70 w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm=%q,service=%q`, mockedServerURL(), mockedServerURL()[8:])) 71 w.WriteHeader(http.StatusUnauthorized) 72 73 case "/oauth2/exchange": 74 response := `{"refresh_token":"newRefreshToken"}` 75 w.WriteHeader(http.StatusOK) 76 _, err := w.Write([]byte(response)) 77 require.NoError(t, err) 78 } 79 })) 80 mockServerURL = mockServer.URL 81 defer mockServer.Close() 82 83 workloadIdentityMock := new(mocks.TokenProvider) 84 workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) 85 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 86 87 // Retrieve the token from the cache 88 token, err := creds.GetPassword() 89 require.NoError(t, err) 90 assert.Equal(t, "newRefreshToken", token, "The retrieved token should match the stored token") 91 } 92 93 func TestChallengeAzureContainerRegistry(t *testing.T) { 94 // Set up the mock server 95 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 96 assert.Equal(t, "/v2/", r.URL.Path) 97 w.Header().Set("Www-Authenticate", `Bearer realm="https://login.microsoftonline.com/",service="registry.example.com"`) 98 w.WriteHeader(http.StatusUnauthorized) 99 })) 100 defer mockServer.Close() 101 102 workloadIdentityMock := new(mocks.TokenProvider) 103 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 104 105 tokenParams, err := creds.challengeAzureContainerRegistry(creds.repoURL) 106 require.NoError(t, err) 107 108 expectedParams := map[string]string{ 109 "realm": "https://login.microsoftonline.com/", 110 "service": "registry.example.com", 111 } 112 assert.Equal(t, expectedParams, tokenParams) 113 } 114 115 func TestChallengeAzureContainerRegistryNoChallenge(t *testing.T) { 116 // Set up the mock server without Www-Authenticate header 117 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 118 assert.Equal(t, "/v2/", r.URL.Path) 119 w.WriteHeader(http.StatusOK) 120 })) 121 defer mockServer.Close() 122 123 // Replace the real URL with the mock server URL 124 workloadIdentityMock := new(mocks.TokenProvider) 125 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 126 127 _, err := creds.challengeAzureContainerRegistry(creds.repoURL) 128 require.Error(t, err) 129 assert.Contains(t, err.Error(), "did not issue a challenge") 130 } 131 132 func TestChallengeAzureContainerRegistryNonBearer(t *testing.T) { 133 // Set up the mock server with a non-Bearer Www-Authenticate header 134 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 135 assert.Equal(t, "/v2/", r.URL.Path) 136 w.Header().Set("Www-Authenticate", `Basic realm="example"`) 137 w.WriteHeader(http.StatusUnauthorized) 138 })) 139 defer mockServer.Close() 140 141 // Replace the real URL with the mock server URL 142 workloadIdentityMock := new(mocks.TokenProvider) 143 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 144 145 _, err := creds.challengeAzureContainerRegistry(creds.repoURL) 146 assert.ErrorContains(t, err, "does not allow 'Bearer' authentication") 147 } 148 149 func TestChallengeAzureContainerRegistryNoService(t *testing.T) { 150 // Set up the mock server with a non-Bearer Www-Authenticate header 151 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 152 assert.Equal(t, "/v2/", r.URL.Path) 153 w.Header().Set("Www-Authenticate", `Bearer realm="example"`) 154 w.WriteHeader(http.StatusUnauthorized) 155 })) 156 defer mockServer.Close() 157 158 // Replace the real URL with the mock server URL 159 workloadIdentityMock := new(mocks.TokenProvider) 160 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 161 162 _, err := creds.challengeAzureContainerRegistry(creds.repoURL) 163 assert.ErrorContains(t, err, "service parameter not found in challenge") 164 } 165 166 func TestChallengeAzureContainerRegistryNoRealm(t *testing.T) { 167 // Set up the mock server with a non-Bearer Www-Authenticate header 168 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 169 assert.Equal(t, "/v2/", r.URL.Path) 170 w.Header().Set("Www-Authenticate", `Bearer service="example"`) 171 w.WriteHeader(http.StatusUnauthorized) 172 })) 173 defer mockServer.Close() 174 175 // Replace the real URL with the mock server URL 176 workloadIdentityMock := new(mocks.TokenProvider) 177 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 178 179 _, err := creds.challengeAzureContainerRegistry(creds.repoURL) 180 assert.ErrorContains(t, err, "realm parameter not found in challenge") 181 } 182 183 func TestGetAccessTokenAfterChallenge_Success(t *testing.T) { 184 // Mock the server to return a successful response 185 mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 186 assert.Equal(t, "/oauth2/exchange", r.URL.Path) 187 188 response := `{"refresh_token":"newRefreshToken"}` 189 w.WriteHeader(http.StatusOK) 190 _, err := w.Write([]byte(response)) 191 require.NoError(t, err) 192 })) 193 defer mockServer.Close() 194 195 workloadIdentityMock := new(mocks.TokenProvider) 196 workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) 197 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 198 199 tokenParams := map[string]string{ 200 "realm": mockServer.URL, 201 "service": "registry.example.com", 202 } 203 204 refreshToken, err := creds.getAccessTokenAfterChallenge(tokenParams) 205 require.NoError(t, err) 206 assert.Equal(t, "newRefreshToken", refreshToken) 207 } 208 209 func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) { 210 // Mock the server to return an error response 211 mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 212 assert.Equal(t, "/oauth2/exchange", r.URL.Path) 213 w.WriteHeader(http.StatusBadRequest) 214 _, err := w.Write([]byte(`{"error": "invalid_request"}`)) 215 require.NoError(t, err) 216 })) 217 defer mockServer.Close() 218 219 // Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper 220 workloadIdentityMock := new(mocks.TokenProvider) 221 workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) 222 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 223 224 tokenParams := map[string]string{ 225 "realm": mockServer.URL, 226 "service": "registry.example.com", 227 } 228 229 refreshToken, err := creds.getAccessTokenAfterChallenge(tokenParams) 230 require.ErrorContains(t, err, "failed to get refresh token") 231 assert.Empty(t, refreshToken) 232 } 233 234 func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) { 235 // Mock the server to return a malformed JSON response 236 mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 237 assert.Equal(t, "/oauth2/exchange", r.URL.Path) 238 w.WriteHeader(http.StatusOK) 239 _, err := w.Write([]byte(`{"refresh_token":`)) 240 require.NoError(t, err) 241 })) 242 defer mockServer.Close() 243 244 // Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper 245 workloadIdentityMock := new(mocks.TokenProvider) 246 workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) 247 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 248 249 tokenParams := map[string]string{ 250 "realm": mockServer.URL, 251 "service": "registry.example.com", 252 } 253 254 refreshToken, err := creds.getAccessTokenAfterChallenge(tokenParams) 255 require.ErrorContains(t, err, "failed to unmarshal response body") 256 assert.Empty(t, refreshToken) 257 } 258 259 // Helper to generate a mock JWT token with a given expiry time 260 func generateMockJWT(expiry time.Time) (string, error) { 261 claims := jwt.MapClaims{ 262 "exp": expiry.Unix(), 263 } 264 token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) 265 // Use a dummy secret for signing 266 return token.SignedString([]byte("dummy-secret")) 267 } 268 269 func TestGetAccessToken_FetchNewTokenIfExistingIsExpired(t *testing.T) { 270 resetAzureTokenCache() 271 accessToken1, _ := generateMockJWT(time.Now().Add(1 * time.Minute)) 272 accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute)) 273 274 mockServerURL := "" 275 mockedServerURL := func() string { 276 return mockServerURL 277 } 278 279 callCount := 0 280 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 281 switch r.URL.Path { 282 case "/v2/": 283 assert.Equal(t, "/v2/", r.URL.Path) 284 w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm=%q,service=%q`, mockedServerURL(), mockedServerURL()[8:])) 285 w.WriteHeader(http.StatusUnauthorized) 286 case "/oauth2/exchange": 287 assert.Equal(t, "/oauth2/exchange", r.URL.Path) 288 var response string 289 switch callCount { 290 case 0: 291 response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken1) 292 case 1: 293 response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken2) 294 default: 295 response = `{"refresh_token": "defaultToken"}` 296 } 297 callCount++ 298 w.WriteHeader(http.StatusOK) 299 _, err := w.Write([]byte(response)) 300 require.NoError(t, err) 301 default: 302 http.NotFound(w, r) 303 } 304 })) 305 defer mockServer.Close() 306 mockServerURL = mockServer.URL 307 308 workloadIdentityMock := new(mocks.TokenProvider) 309 workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) 310 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 311 312 refreshToken, err := creds.GetAccessToken() 313 require.NoError(t, err) 314 assert.Equal(t, accessToken1, refreshToken) 315 316 time.Sleep(5 * time.Second) // Wait for the token to expire 317 318 refreshToken, err = creds.GetAccessToken() 319 require.NoError(t, err) 320 assert.Equal(t, accessToken2, refreshToken) 321 } 322 323 func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired(t *testing.T) { 324 resetAzureTokenCache() 325 accessToken1, _ := generateMockJWT(time.Now().Add(6 * time.Minute)) 326 accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute)) 327 328 mockServerURL := "" 329 mockedServerURL := func() string { 330 return mockServerURL 331 } 332 333 callCount := 0 334 mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 335 switch r.URL.Path { 336 case "/v2/": 337 assert.Equal(t, "/v2/", r.URL.Path) 338 w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm=%q,service=%q`, mockedServerURL(), mockedServerURL()[8:])) 339 w.WriteHeader(http.StatusUnauthorized) 340 case "/oauth2/exchange": 341 assert.Equal(t, "/oauth2/exchange", r.URL.Path) 342 var response string 343 switch callCount { 344 case 0: 345 response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken1) 346 case 1: 347 response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken2) 348 default: 349 response = `{"refresh_token": "defaultToken"}` 350 } 351 callCount++ 352 w.WriteHeader(http.StatusOK) 353 _, err := w.Write([]byte(response)) 354 require.NoError(t, err) 355 default: 356 http.NotFound(w, r) 357 } 358 })) 359 defer mockServer.Close() 360 mockServerURL = mockServer.URL 361 362 workloadIdentityMock := new(mocks.TokenProvider) 363 workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil) 364 creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock) 365 366 refreshToken, err := creds.GetAccessToken() 367 require.NoError(t, err) 368 assert.Equal(t, accessToken1, refreshToken) 369 370 time.Sleep(5 * time.Second) // Wait for the token to expire 371 372 refreshToken, err = creds.GetAccessToken() 373 require.NoError(t, err) 374 assert.Equal(t, accessToken1, refreshToken) 375 } 376 377 func resetAzureTokenCache() { 378 azureTokenCache = gocache.New(gocache.NoExpiration, 0) 379 }