github.com/argoproj/argo-cd/v3@v3.2.1/util/helm/creds_test.go (about)

     1  package helm
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/golang-jwt/jwt/v5"
    11  	gocache "github.com/patrickmn/go-cache"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	argoutils "github.com/argoproj/argo-cd/v3/util"
    16  	"github.com/argoproj/argo-cd/v3/util/workloadidentity"
    17  	"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
    18  )
    19  
    20  func TestWorkLoadIdentityUserNameShouldBeEmptyGuid(t *testing.T) {
    21  	workloadIdentityMock := new(mocks.TokenProvider)
    22  	creds := NewAzureWorkloadIdentityCreds("contoso.azurecr.io/charts", "", nil, nil, false, workloadIdentityMock)
    23  	username := creds.GetUsername()
    24  
    25  	assert.Equal(t, workloadidentity.EmptyGuid, username, "The username for azure workload identity is not empty Guid")
    26  }
    27  
    28  func TestGetAccessTokenShouldReturnTokenFromCacheIfPresent(t *testing.T) {
    29  	workloadIdentityMock := new(mocks.TokenProvider)
    30  	creds := NewAzureWorkloadIdentityCreds("contoso.azurecr.io/charts", "", nil, nil, false, workloadIdentityMock)
    31  
    32  	cacheKey, err := argoutils.GenerateCacheKey("accesstoken-%s", "contoso.azurecr.io")
    33  	require.NoError(t, err, "Error generating cache key")
    34  
    35  	// Store the token in the cache
    36  	storeAzureToken(cacheKey, "testToken", time.Hour)
    37  
    38  	// Retrieve the token from the cache
    39  	token, err := creds.GetAccessToken()
    40  	require.NoError(t, err, "Error getting access token")
    41  	assert.Equal(t, "testToken", token, "The retrieved token should match the stored token")
    42  }
    43  
    44  func TestGetPasswordShouldReturnTokenFromCacheIfPresent(t *testing.T) {
    45  	workloadIdentityMock := new(mocks.TokenProvider)
    46  	creds := NewAzureWorkloadIdentityCreds("contoso.azurecr.io/charts", "", nil, nil, false, workloadIdentityMock)
    47  
    48  	cacheKey, err := argoutils.GenerateCacheKey("accesstoken-%s", "contoso.azurecr.io")
    49  	require.NoError(t, err, "Error generating cache key")
    50  
    51  	// Store the token in the cache
    52  	storeAzureToken(cacheKey, "testToken", time.Hour)
    53  
    54  	// Retrieve the token from the cache
    55  	token, err := creds.GetPassword()
    56  	require.NoError(t, err, "Error getting access token")
    57  	assert.Equal(t, "testToken", token, "The retrieved token should match the stored token")
    58  }
    59  
    60  func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) {
    61  	mockServerURL := ""
    62  	mockedServerURL := func() string {
    63  		return mockServerURL
    64  	}
    65  
    66  	// Mock the server to return a successful response
    67  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    68  		switch r.URL.Path {
    69  		case "/v2/":
    70  			w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm=%q,service=%q`, mockedServerURL(), mockedServerURL()[8:]))
    71  			w.WriteHeader(http.StatusUnauthorized)
    72  
    73  		case "/oauth2/exchange":
    74  			response := `{"refresh_token":"newRefreshToken"}`
    75  			w.WriteHeader(http.StatusOK)
    76  			_, err := w.Write([]byte(response))
    77  			require.NoError(t, err)
    78  		}
    79  	}))
    80  	mockServerURL = mockServer.URL
    81  	defer mockServer.Close()
    82  
    83  	workloadIdentityMock := new(mocks.TokenProvider)
    84  	workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
    85  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
    86  
    87  	// Retrieve the token from the cache
    88  	token, err := creds.GetPassword()
    89  	require.NoError(t, err)
    90  	assert.Equal(t, "newRefreshToken", token, "The retrieved token should match the stored token")
    91  }
    92  
    93  func TestChallengeAzureContainerRegistry(t *testing.T) {
    94  	// Set up the mock server
    95  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    96  		assert.Equal(t, "/v2/", r.URL.Path)
    97  		w.Header().Set("Www-Authenticate", `Bearer realm="https://login.microsoftonline.com/",service="registry.example.com"`)
    98  		w.WriteHeader(http.StatusUnauthorized)
    99  	}))
   100  	defer mockServer.Close()
   101  
   102  	workloadIdentityMock := new(mocks.TokenProvider)
   103  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   104  
   105  	tokenParams, err := creds.challengeAzureContainerRegistry(creds.repoURL)
   106  	require.NoError(t, err)
   107  
   108  	expectedParams := map[string]string{
   109  		"realm":   "https://login.microsoftonline.com/",
   110  		"service": "registry.example.com",
   111  	}
   112  	assert.Equal(t, expectedParams, tokenParams)
   113  }
   114  
   115  func TestChallengeAzureContainerRegistryNoChallenge(t *testing.T) {
   116  	// Set up the mock server without Www-Authenticate header
   117  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   118  		assert.Equal(t, "/v2/", r.URL.Path)
   119  		w.WriteHeader(http.StatusOK)
   120  	}))
   121  	defer mockServer.Close()
   122  
   123  	// Replace the real URL with the mock server URL
   124  	workloadIdentityMock := new(mocks.TokenProvider)
   125  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   126  
   127  	_, err := creds.challengeAzureContainerRegistry(creds.repoURL)
   128  	require.Error(t, err)
   129  	assert.Contains(t, err.Error(), "did not issue a challenge")
   130  }
   131  
   132  func TestChallengeAzureContainerRegistryNonBearer(t *testing.T) {
   133  	// Set up the mock server with a non-Bearer Www-Authenticate header
   134  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   135  		assert.Equal(t, "/v2/", r.URL.Path)
   136  		w.Header().Set("Www-Authenticate", `Basic realm="example"`)
   137  		w.WriteHeader(http.StatusUnauthorized)
   138  	}))
   139  	defer mockServer.Close()
   140  
   141  	// Replace the real URL with the mock server URL
   142  	workloadIdentityMock := new(mocks.TokenProvider)
   143  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   144  
   145  	_, err := creds.challengeAzureContainerRegistry(creds.repoURL)
   146  	assert.ErrorContains(t, err, "does not allow 'Bearer' authentication")
   147  }
   148  
   149  func TestChallengeAzureContainerRegistryNoService(t *testing.T) {
   150  	// Set up the mock server with a non-Bearer Www-Authenticate header
   151  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   152  		assert.Equal(t, "/v2/", r.URL.Path)
   153  		w.Header().Set("Www-Authenticate", `Bearer realm="example"`)
   154  		w.WriteHeader(http.StatusUnauthorized)
   155  	}))
   156  	defer mockServer.Close()
   157  
   158  	// Replace the real URL with the mock server URL
   159  	workloadIdentityMock := new(mocks.TokenProvider)
   160  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   161  
   162  	_, err := creds.challengeAzureContainerRegistry(creds.repoURL)
   163  	assert.ErrorContains(t, err, "service parameter not found in challenge")
   164  }
   165  
   166  func TestChallengeAzureContainerRegistryNoRealm(t *testing.T) {
   167  	// Set up the mock server with a non-Bearer Www-Authenticate header
   168  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   169  		assert.Equal(t, "/v2/", r.URL.Path)
   170  		w.Header().Set("Www-Authenticate", `Bearer service="example"`)
   171  		w.WriteHeader(http.StatusUnauthorized)
   172  	}))
   173  	defer mockServer.Close()
   174  
   175  	// Replace the real URL with the mock server URL
   176  	workloadIdentityMock := new(mocks.TokenProvider)
   177  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   178  
   179  	_, err := creds.challengeAzureContainerRegistry(creds.repoURL)
   180  	assert.ErrorContains(t, err, "realm parameter not found in challenge")
   181  }
   182  
   183  func TestGetAccessTokenAfterChallenge_Success(t *testing.T) {
   184  	// Mock the server to return a successful response
   185  	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   186  		assert.Equal(t, "/oauth2/exchange", r.URL.Path)
   187  
   188  		response := `{"refresh_token":"newRefreshToken"}`
   189  		w.WriteHeader(http.StatusOK)
   190  		_, err := w.Write([]byte(response))
   191  		require.NoError(t, err)
   192  	}))
   193  	defer mockServer.Close()
   194  
   195  	workloadIdentityMock := new(mocks.TokenProvider)
   196  	workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
   197  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   198  
   199  	tokenParams := map[string]string{
   200  		"realm":   mockServer.URL,
   201  		"service": "registry.example.com",
   202  	}
   203  
   204  	refreshToken, err := creds.getAccessTokenAfterChallenge(tokenParams)
   205  	require.NoError(t, err)
   206  	assert.Equal(t, "newRefreshToken", refreshToken)
   207  }
   208  
   209  func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) {
   210  	// Mock the server to return an error response
   211  	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   212  		assert.Equal(t, "/oauth2/exchange", r.URL.Path)
   213  		w.WriteHeader(http.StatusBadRequest)
   214  		_, err := w.Write([]byte(`{"error": "invalid_request"}`))
   215  		require.NoError(t, err)
   216  	}))
   217  	defer mockServer.Close()
   218  
   219  	// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
   220  	workloadIdentityMock := new(mocks.TokenProvider)
   221  	workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
   222  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   223  
   224  	tokenParams := map[string]string{
   225  		"realm":   mockServer.URL,
   226  		"service": "registry.example.com",
   227  	}
   228  
   229  	refreshToken, err := creds.getAccessTokenAfterChallenge(tokenParams)
   230  	require.ErrorContains(t, err, "failed to get refresh token")
   231  	assert.Empty(t, refreshToken)
   232  }
   233  
   234  func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
   235  	// Mock the server to return a malformed JSON response
   236  	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   237  		assert.Equal(t, "/oauth2/exchange", r.URL.Path)
   238  		w.WriteHeader(http.StatusOK)
   239  		_, err := w.Write([]byte(`{"refresh_token":`))
   240  		require.NoError(t, err)
   241  	}))
   242  	defer mockServer.Close()
   243  
   244  	// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
   245  	workloadIdentityMock := new(mocks.TokenProvider)
   246  	workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
   247  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   248  
   249  	tokenParams := map[string]string{
   250  		"realm":   mockServer.URL,
   251  		"service": "registry.example.com",
   252  	}
   253  
   254  	refreshToken, err := creds.getAccessTokenAfterChallenge(tokenParams)
   255  	require.ErrorContains(t, err, "failed to unmarshal response body")
   256  	assert.Empty(t, refreshToken)
   257  }
   258  
   259  // Helper to generate a mock JWT token with a given expiry time
   260  func generateMockJWT(expiry time.Time) (string, error) {
   261  	claims := jwt.MapClaims{
   262  		"exp": expiry.Unix(),
   263  	}
   264  	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
   265  	// Use a dummy secret for signing
   266  	return token.SignedString([]byte("dummy-secret"))
   267  }
   268  
   269  func TestGetAccessToken_FetchNewTokenIfExistingIsExpired(t *testing.T) {
   270  	resetAzureTokenCache()
   271  	accessToken1, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
   272  	accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
   273  
   274  	mockServerURL := ""
   275  	mockedServerURL := func() string {
   276  		return mockServerURL
   277  	}
   278  
   279  	callCount := 0
   280  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   281  		switch r.URL.Path {
   282  		case "/v2/":
   283  			assert.Equal(t, "/v2/", r.URL.Path)
   284  			w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm=%q,service=%q`, mockedServerURL(), mockedServerURL()[8:]))
   285  			w.WriteHeader(http.StatusUnauthorized)
   286  		case "/oauth2/exchange":
   287  			assert.Equal(t, "/oauth2/exchange", r.URL.Path)
   288  			var response string
   289  			switch callCount {
   290  			case 0:
   291  				response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken1)
   292  			case 1:
   293  				response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken2)
   294  			default:
   295  				response = `{"refresh_token": "defaultToken"}`
   296  			}
   297  			callCount++
   298  			w.WriteHeader(http.StatusOK)
   299  			_, err := w.Write([]byte(response))
   300  			require.NoError(t, err)
   301  		default:
   302  			http.NotFound(w, r)
   303  		}
   304  	}))
   305  	defer mockServer.Close()
   306  	mockServerURL = mockServer.URL
   307  
   308  	workloadIdentityMock := new(mocks.TokenProvider)
   309  	workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
   310  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   311  
   312  	refreshToken, err := creds.GetAccessToken()
   313  	require.NoError(t, err)
   314  	assert.Equal(t, accessToken1, refreshToken)
   315  
   316  	time.Sleep(5 * time.Second) // Wait for the token to expire
   317  
   318  	refreshToken, err = creds.GetAccessToken()
   319  	require.NoError(t, err)
   320  	assert.Equal(t, accessToken2, refreshToken)
   321  }
   322  
   323  func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
   324  	resetAzureTokenCache()
   325  	accessToken1, _ := generateMockJWT(time.Now().Add(6 * time.Minute))
   326  	accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
   327  
   328  	mockServerURL := ""
   329  	mockedServerURL := func() string {
   330  		return mockServerURL
   331  	}
   332  
   333  	callCount := 0
   334  	mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   335  		switch r.URL.Path {
   336  		case "/v2/":
   337  			assert.Equal(t, "/v2/", r.URL.Path)
   338  			w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm=%q,service=%q`, mockedServerURL(), mockedServerURL()[8:]))
   339  			w.WriteHeader(http.StatusUnauthorized)
   340  		case "/oauth2/exchange":
   341  			assert.Equal(t, "/oauth2/exchange", r.URL.Path)
   342  			var response string
   343  			switch callCount {
   344  			case 0:
   345  				response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken1)
   346  			case 1:
   347  				response = fmt.Sprintf(`{"refresh_token": %q}`, accessToken2)
   348  			default:
   349  				response = `{"refresh_token": "defaultToken"}`
   350  			}
   351  			callCount++
   352  			w.WriteHeader(http.StatusOK)
   353  			_, err := w.Write([]byte(response))
   354  			require.NoError(t, err)
   355  		default:
   356  			http.NotFound(w, r)
   357  		}
   358  	}))
   359  	defer mockServer.Close()
   360  	mockServerURL = mockServer.URL
   361  
   362  	workloadIdentityMock := new(mocks.TokenProvider)
   363  	workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
   364  	creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
   365  
   366  	refreshToken, err := creds.GetAccessToken()
   367  	require.NoError(t, err)
   368  	assert.Equal(t, accessToken1, refreshToken)
   369  
   370  	time.Sleep(5 * time.Second) // Wait for the token to expire
   371  
   372  	refreshToken, err = creds.GetAccessToken()
   373  	require.NoError(t, err)
   374  	assert.Equal(t, accessToken1, refreshToken)
   375  }
   376  
   377  func resetAzureTokenCache() {
   378  	azureTokenCache = gocache.New(gocache.NoExpiration, 0)
   379  }