github.com/argoproj/argo-cd/v3@v3.2.1/server/extension/extension_test.go (about)

     1  package extension_test
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"sync"
    11  	"testing"
    12  
    13  	"github.com/sirupsen/logrus/hooks/test"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/mock"
    16  	"github.com/stretchr/testify/require"
    17  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    18  
    19  	"github.com/argoproj/argo-cd/v3/util/rbac"
    20  
    21  	"github.com/argoproj/argo-cd/v3/pkg/apis/application/v1alpha1"
    22  	"github.com/argoproj/argo-cd/v3/server/extension"
    23  	"github.com/argoproj/argo-cd/v3/server/extension/mocks"
    24  	dbmocks "github.com/argoproj/argo-cd/v3/util/db/mocks"
    25  	"github.com/argoproj/argo-cd/v3/util/settings"
    26  )
    27  
    28  func TestValidateHeaders(t *testing.T) {
    29  	t.Run("will build RequestResources successfully", func(t *testing.T) {
    30  		// given
    31  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
    32  		require.NoError(t, err, "error initializing request")
    33  		r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app-name")
    34  		r.Header.Add(extension.HeaderArgoCDProjectName, "project-name")
    35  
    36  		// when
    37  		rr, err := extension.ValidateHeaders(r)
    38  
    39  		// then
    40  		require.NoError(t, err)
    41  		assert.NotNil(t, rr)
    42  		assert.Equal(t, "namespace", rr.ApplicationNamespace)
    43  		assert.Equal(t, "app-name", rr.ApplicationName)
    44  		assert.Equal(t, "project-name", rr.ProjectName)
    45  	})
    46  	t.Run("will return error if application is malformatted", func(t *testing.T) {
    47  		// given
    48  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
    49  		require.NoError(t, err, "error initializing request")
    50  		r.Header.Add(extension.HeaderArgoCDApplicationName, "no-namespace")
    51  
    52  		// when
    53  		rr, err := extension.ValidateHeaders(r)
    54  
    55  		// then
    56  		require.Error(t, err)
    57  		assert.Nil(t, rr)
    58  	})
    59  	t.Run("will return error if application header is missing", func(t *testing.T) {
    60  		// given
    61  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
    62  		require.NoError(t, err, "error initializing request")
    63  		r.Header.Add(extension.HeaderArgoCDProjectName, "project-name")
    64  
    65  		// when
    66  		rr, err := extension.ValidateHeaders(r)
    67  
    68  		// then
    69  		require.Error(t, err)
    70  		assert.Nil(t, rr)
    71  	})
    72  	t.Run("will return error if project header is missing", func(t *testing.T) {
    73  		// given
    74  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
    75  		require.NoError(t, err, "error initializing request")
    76  		r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app-name")
    77  
    78  		// when
    79  		rr, err := extension.ValidateHeaders(r)
    80  
    81  		// then
    82  		require.Error(t, err)
    83  		assert.Nil(t, rr)
    84  	})
    85  	t.Run("will return error if invalid namespace", func(t *testing.T) {
    86  		// given
    87  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
    88  		require.NoError(t, err, "error initializing request")
    89  		r.Header.Add(extension.HeaderArgoCDApplicationName, "bad%namespace:app-name")
    90  		r.Header.Add(extension.HeaderArgoCDProjectName, "project-name")
    91  
    92  		// when
    93  		rr, err := extension.ValidateHeaders(r)
    94  
    95  		// then
    96  		require.Error(t, err)
    97  		assert.Nil(t, rr)
    98  	})
    99  	t.Run("will return error if invalid app name", func(t *testing.T) {
   100  		// given
   101  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
   102  		require.NoError(t, err, "error initializing request")
   103  		r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:bad@app")
   104  		r.Header.Add(extension.HeaderArgoCDProjectName, "project-name")
   105  
   106  		// when
   107  		rr, err := extension.ValidateHeaders(r)
   108  
   109  		// then
   110  		require.Error(t, err)
   111  		assert.Nil(t, rr)
   112  	})
   113  	t.Run("will return error if invalid project name", func(t *testing.T) {
   114  		// given
   115  		r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody)
   116  		require.NoError(t, err, "error initializing request")
   117  		r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app")
   118  		r.Header.Add(extension.HeaderArgoCDProjectName, "bad^project")
   119  
   120  		// when
   121  		rr, err := extension.ValidateHeaders(r)
   122  
   123  		// then
   124  		require.Error(t, err)
   125  		assert.Nil(t, rr)
   126  	})
   127  }
   128  
   129  func TestRegisterExtensions(t *testing.T) {
   130  	t.Parallel()
   131  
   132  	type fixture struct {
   133  		settingsGetterMock *mocks.SettingsGetter
   134  		manager            *extension.Manager
   135  	}
   136  
   137  	setup := func() *fixture {
   138  		settMock := &mocks.SettingsGetter{}
   139  
   140  		logger, _ := test.NewNullLogger()
   141  		logEntry := logger.WithContext(t.Context())
   142  		m := extension.NewManager(logEntry, "", settMock, nil, nil, nil, nil, nil)
   143  
   144  		return &fixture{
   145  			settingsGetterMock: settMock,
   146  			manager:            m,
   147  		}
   148  	}
   149  	t.Run("will register extensions successfully", func(t *testing.T) {
   150  		// given
   151  		t.Parallel()
   152  		f := setup()
   153  		settings := &settings.ArgoCDSettings{
   154  			ExtensionConfig: map[string]string{
   155  				"":            getExtensionConfigString(),
   156  				"another-ext": getSingleExtensionConfigString(),
   157  			},
   158  		}
   159  		f.settingsGetterMock.On("Get", mock.Anything).Return(settings, nil)
   160  		expectedProxyRegistries := []string{
   161  			"external-backend",
   162  			"some-backend",
   163  			"another-ext",
   164  		}
   165  
   166  		// when
   167  		err := f.manager.RegisterExtensions()
   168  
   169  		// then
   170  		require.NoError(t, err)
   171  		for _, expectedProxyRegistry := range expectedProxyRegistries {
   172  			proxyRegistry, found := f.manager.ProxyRegistry(expectedProxyRegistry)
   173  			assert.True(t, found)
   174  			assert.NotNil(t, proxyRegistry)
   175  		}
   176  	})
   177  	t.Run("will return error if extension config is invalid", func(t *testing.T) {
   178  		// given
   179  		t.Parallel()
   180  		type testCase struct {
   181  			name       string
   182  			configYaml string
   183  		}
   184  		cases := []testCase{
   185  			{
   186  				name:       "no name",
   187  				configYaml: getExtensionConfigNoName(),
   188  			},
   189  			{
   190  				name:       "no service",
   191  				configYaml: getExtensionConfigNoService(),
   192  			},
   193  			{
   194  				name:       "no URL",
   195  				configYaml: getExtensionConfigNoURL(),
   196  			},
   197  			{
   198  				name:       "invalid name",
   199  				configYaml: getExtensionConfigInvalidName(),
   200  			},
   201  			{
   202  				name:       "no header name",
   203  				configYaml: getExtensionConfigNoHeaderName(),
   204  			},
   205  			{
   206  				name:       "no header value",
   207  				configYaml: getExtensionConfigNoHeaderValue(),
   208  			},
   209  		}
   210  
   211  		// when
   212  		for _, tc := range cases {
   213  			tc := tc
   214  			t.Run(tc.name, func(t *testing.T) {
   215  				// given
   216  				t.Parallel()
   217  				f := setup()
   218  				settings := &settings.ArgoCDSettings{
   219  					ExtensionConfig: map[string]string{
   220  						"": tc.configYaml,
   221  					},
   222  				}
   223  				f.settingsGetterMock.On("Get", mock.Anything).Return(settings, nil)
   224  
   225  				// when
   226  				err := f.manager.RegisterExtensions()
   227  
   228  				// then
   229  				require.Error(t, err, "expected error in test %s but got nil", tc.name)
   230  			})
   231  		}
   232  	})
   233  }
   234  
   235  func TestCallExtension(t *testing.T) {
   236  	t.Parallel()
   237  
   238  	type fixture struct {
   239  		mux                *http.ServeMux
   240  		appGetterMock      *mocks.ApplicationGetter
   241  		settingsGetterMock *mocks.SettingsGetter
   242  		rbacMock           *mocks.RbacEnforcer
   243  		projMock           *mocks.ProjectGetter
   244  		metricsMock        *mocks.ExtensionMetricsRegistry
   245  		userMock           *mocks.UserGetter
   246  		manager            *extension.Manager
   247  	}
   248  	defaultServerNamespace := "control-plane-ns"
   249  	defaultProjectName := "project-name"
   250  
   251  	setup := func() *fixture {
   252  		appMock := &mocks.ApplicationGetter{}
   253  		settMock := &mocks.SettingsGetter{}
   254  		rbacMock := &mocks.RbacEnforcer{}
   255  		projMock := &mocks.ProjectGetter{}
   256  		metricsMock := &mocks.ExtensionMetricsRegistry{}
   257  		userMock := &mocks.UserGetter{}
   258  
   259  		dbMock := &dbmocks.ArgoDB{}
   260  		dbMock.On("GetClusterServersByName", mock.Anything, mock.Anything).Return([]string{"cluster1"}, nil)
   261  		dbMock.On("GetCluster", mock.Anything, mock.Anything).Return(&v1alpha1.Cluster{Server: "some-url", Name: "cluster1"}, nil)
   262  
   263  		logger, _ := test.NewNullLogger()
   264  		logEntry := logger.WithContext(t.Context())
   265  		m := extension.NewManager(logEntry, defaultServerNamespace, settMock, appMock, projMock, dbMock, rbacMock, userMock)
   266  		m.AddMetricsRegistry(metricsMock)
   267  
   268  		mux := http.NewServeMux()
   269  		extHandler := http.HandlerFunc(m.CallExtension())
   270  		mux.Handle(extension.URLPrefix+"/", extHandler)
   271  
   272  		return &fixture{
   273  			mux:                mux,
   274  			appGetterMock:      appMock,
   275  			settingsGetterMock: settMock,
   276  			rbacMock:           rbacMock,
   277  			projMock:           projMock,
   278  			metricsMock:        metricsMock,
   279  			userMock:           userMock,
   280  			manager:            m,
   281  		}
   282  	}
   283  
   284  	getApp := func(destName, destServer, projName string) *v1alpha1.Application {
   285  		return &v1alpha1.Application{
   286  			TypeMeta:   metav1.TypeMeta{},
   287  			ObjectMeta: metav1.ObjectMeta{},
   288  			Spec: v1alpha1.ApplicationSpec{
   289  				Destination: v1alpha1.ApplicationDestination{
   290  					Name:   destName,
   291  					Server: destServer,
   292  				},
   293  				Project: projName,
   294  			},
   295  			Status: v1alpha1.ApplicationStatus{
   296  				Resources: []v1alpha1.ResourceStatus{
   297  					{
   298  						Group:     "apps",
   299  						Version:   "v1",
   300  						Kind:      "Pod",
   301  						Namespace: "default",
   302  						Name:      "some-pod",
   303  					},
   304  				},
   305  			},
   306  		}
   307  	}
   308  
   309  	getProjectWithDestinations := func(prjName string, destNames []string, destURLs []string) *v1alpha1.AppProject {
   310  		destinations := []v1alpha1.ApplicationDestination{}
   311  		for _, destName := range destNames {
   312  			destination := v1alpha1.ApplicationDestination{
   313  				Name: destName,
   314  			}
   315  			destinations = append(destinations, destination)
   316  		}
   317  		for _, destURL := range destURLs {
   318  			destination := v1alpha1.ApplicationDestination{
   319  				Server: destURL,
   320  			}
   321  			destinations = append(destinations, destination)
   322  		}
   323  		return &v1alpha1.AppProject{
   324  			ObjectMeta: metav1.ObjectMeta{
   325  				Name: prjName,
   326  			},
   327  			Spec: v1alpha1.AppProjectSpec{
   328  				Destinations: destinations,
   329  			},
   330  		}
   331  	}
   332  
   333  	withProject := func(prj *v1alpha1.AppProject, f *fixture) {
   334  		f.projMock.On("Get", prj.GetName()).Return(prj, nil)
   335  	}
   336  
   337  	withMetrics := func(f *fixture) {
   338  		f.metricsMock.On("IncExtensionRequestCounter", mock.Anything, mock.Anything)
   339  		f.metricsMock.On("ObserveExtensionRequestDuration", mock.Anything, mock.Anything)
   340  	}
   341  
   342  	withRbac := func(f *fixture, allowApp, allowExt bool) {
   343  		var appAccessError error
   344  		var extAccessError error
   345  		if !allowApp {
   346  			appAccessError = errors.New("no app permission")
   347  		}
   348  		if !allowExt {
   349  			extAccessError = errors.New("no extension permission")
   350  		}
   351  		f.rbacMock.On("EnforceErr", mock.Anything, rbac.ResourceApplications, rbac.ActionGet, mock.Anything).Return(appAccessError)
   352  		f.rbacMock.On("EnforceErr", mock.Anything, rbac.ResourceExtensions, rbac.ActionInvoke, mock.Anything).Return(extAccessError)
   353  	}
   354  
   355  	withUser := func(f *fixture, userId string, username string, groups []string) {
   356  		f.userMock.On("GetUserId", mock.Anything).Return(userId)
   357  		f.userMock.On("GetUsername", mock.Anything).Return(username)
   358  		f.userMock.On("GetGroups", mock.Anything).Return(groups)
   359  	}
   360  
   361  	withExtensionConfig := func(configYaml string, f *fixture) {
   362  		secrets := make(map[string]string)
   363  		secrets["extension.auth.header"] = "Bearer some-bearer-token"
   364  		secrets["extension.auth.header2"] = "Bearer another-bearer-token"
   365  
   366  		settings := &settings.ArgoCDSettings{
   367  			ExtensionConfig: map[string]string{
   368  				"ephemeral": "services:\n- url: http://some-server.com",
   369  				"":          configYaml,
   370  			},
   371  			Secrets: secrets,
   372  		}
   373  		f.settingsGetterMock.On("Get", mock.Anything).Return(settings, nil)
   374  	}
   375  
   376  	startTestServer := func(t *testing.T, f *fixture) *httptest.Server {
   377  		t.Helper()
   378  		err := f.manager.RegisterExtensions()
   379  		require.NoError(t, err, "error starting test server")
   380  		return httptest.NewServer(f.mux)
   381  	}
   382  
   383  	startBackendTestSrv := func(response string) *httptest.Server {
   384  		return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   385  			for k, v := range r.Header {
   386  				w.Header().Add(k, strings.Join(v, ","))
   387  			}
   388  			fmt.Fprintln(w, response)
   389  		}))
   390  	}
   391  	newExtensionRequest := func(t *testing.T, method, url string) *http.Request {
   392  		t.Helper()
   393  		r, err := http.NewRequest(method, url, http.NoBody)
   394  		require.NoError(t, err, "error initializing request")
   395  		r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app-name")
   396  		r.Header.Add(extension.HeaderArgoCDProjectName, defaultProjectName)
   397  		return r
   398  	}
   399  
   400  	t.Run("will call extension backend successfully", func(t *testing.T) {
   401  		// given
   402  		t.Parallel()
   403  		f := setup()
   404  		backendResponse := "some data"
   405  		backendEndpoint := "some-backend"
   406  		clusterURL := "some-url"
   407  		backendSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   408  			for k, v := range r.Header {
   409  				w.Header().Add(k, strings.Join(v, ","))
   410  			}
   411  			fmt.Fprintln(w, backendResponse)
   412  		}))
   413  		defer backendSrv.Close()
   414  		withRbac(f, true, true)
   415  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   416  		withExtensionConfig(getExtensionConfig(backendEndpoint, backendSrv.URL), f)
   417  		ts := startTestServer(t, f)
   418  		defer ts.Close()
   419  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, backendEndpoint))
   420  		app := getApp("", clusterURL, defaultProjectName)
   421  		proj := getProjectWithDestinations("project-name", nil, []string{clusterURL})
   422  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(app, nil)
   423  		withProject(proj, f)
   424  		var wg sync.WaitGroup
   425  		wg.Add(2)
   426  		f.metricsMock.
   427  			On("IncExtensionRequestCounter", mock.Anything, mock.Anything).
   428  			Run(func(_ mock.Arguments) {
   429  				wg.Done()
   430  			})
   431  		f.metricsMock.
   432  			On("ObserveExtensionRequestDuration", mock.Anything, mock.Anything).
   433  			Run(func(_ mock.Arguments) {
   434  				wg.Done()
   435  			})
   436  
   437  		// when
   438  		resp, err := http.DefaultClient.Do(r)
   439  
   440  		// then
   441  		require.NoError(t, err)
   442  		require.NotNil(t, resp)
   443  		assert.Equal(t, http.StatusOK, resp.StatusCode)
   444  		body, err := io.ReadAll(resp.Body)
   445  		require.NoError(t, err)
   446  		actual := strings.TrimSuffix(string(body), "\n")
   447  		assert.Equal(t, backendResponse, actual)
   448  		assert.Equal(t, defaultServerNamespace, resp.Header.Get(extension.HeaderArgoCDNamespace))
   449  		assert.Equal(t, clusterURL, resp.Header.Get(extension.HeaderArgoCDTargetClusterURL))
   450  		assert.Equal(t, "Bearer some-bearer-token", resp.Header.Get("Authorization"))
   451  		assert.Equal(t, "some-user", resp.Header.Get(extension.HeaderArgoCDUsername))
   452  		assert.Equal(t, "some-user-id", resp.Header.Get(extension.HeaderArgoCDUserId))
   453  		assert.Equal(t, "group1,group2", resp.Header.Get(extension.HeaderArgoCDGroups))
   454  
   455  		// waitgroup is necessary to make sure assertions aren't executed before
   456  		// the goroutine initiated by extension.CallExtension concludes which would
   457  		// lead to flaky test.
   458  		wg.Wait()
   459  		f.metricsMock.AssertCalled(t, "IncExtensionRequestCounter", backendEndpoint, http.StatusOK)
   460  		f.metricsMock.AssertCalled(t, "ObserveExtensionRequestDuration", backendEndpoint, mock.Anything)
   461  	})
   462  	t.Run("proxy will return 404 if extension endpoint not registered", func(t *testing.T) {
   463  		// given
   464  		t.Parallel()
   465  		f := setup()
   466  		withExtensionConfig(getExtensionConfigString(), f)
   467  		withRbac(f, true, true)
   468  		withMetrics(f)
   469  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   470  		cluster1Name := "cluster1"
   471  		f.appGetterMock.On("Get", "namespace", "app-name").Return(getApp(cluster1Name, "", defaultProjectName), nil)
   472  		withProject(getProjectWithDestinations("project-name", []string{cluster1Name}, []string{"some-url"}), f)
   473  
   474  		ts := startTestServer(t, f)
   475  		defer ts.Close()
   476  		nonRegistered := "non-registered"
   477  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, nonRegistered))
   478  
   479  		// when
   480  		resp, err := http.DefaultClient.Do(r)
   481  
   482  		// then
   483  		require.NoError(t, err)
   484  		require.NotNil(t, resp)
   485  		assert.Equal(t, http.StatusNotFound, resp.StatusCode)
   486  	})
   487  	t.Run("will route requests with 2 backends for the same extension successfully", func(t *testing.T) {
   488  		// given
   489  		t.Parallel()
   490  		f := setup()
   491  		extName := "some-extension"
   492  
   493  		response1 := "response backend 1"
   494  		cluster1Name := "cluster1"
   495  		cluster1URL := "url1"
   496  		beSrv1 := startBackendTestSrv(response1)
   497  		defer beSrv1.Close()
   498  
   499  		response2 := "response backend 2"
   500  		cluster2Name := "cluster2"
   501  		cluster2URL := "url2"
   502  		beSrv2 := startBackendTestSrv(response2)
   503  		defer beSrv2.Close()
   504  
   505  		f.appGetterMock.On("Get", "ns1", "app1").Return(getApp(cluster1Name, "", defaultProjectName), nil)
   506  		f.appGetterMock.On("Get", "ns2", "app2").Return(getApp("", cluster2URL, defaultProjectName), nil)
   507  
   508  		withRbac(f, true, true)
   509  		withExtensionConfig(getExtensionConfigWith2Backends(extName, beSrv1.URL, cluster1Name, cluster1URL, beSrv2.URL, cluster2Name, cluster2URL), f)
   510  		withProject(getProjectWithDestinations("project-name", []string{cluster1Name}, []string{cluster2URL}), f)
   511  		withMetrics(f)
   512  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   513  
   514  		ts := startTestServer(t, f)
   515  		defer ts.Close()
   516  
   517  		url := fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)
   518  		req := newExtensionRequest(t, http.MethodGet, url)
   519  		req.Header.Del(extension.HeaderArgoCDApplicationName)
   520  
   521  		req1 := req.Clone(t.Context())
   522  		req1.Header.Add(extension.HeaderArgoCDApplicationName, "ns1:app1")
   523  		req2 := req.Clone(t.Context())
   524  		req2.Header.Add(extension.HeaderArgoCDApplicationName, "ns2:app2")
   525  
   526  		// when
   527  		resp1, err := http.DefaultClient.Do(req1)
   528  		require.NoError(t, err)
   529  		resp2, err := http.DefaultClient.Do(req2)
   530  		require.NoError(t, err)
   531  
   532  		// then
   533  		require.NotNil(t, resp1)
   534  		assert.Equal(t, http.StatusOK, resp1.StatusCode)
   535  		body, err := io.ReadAll(resp1.Body)
   536  		require.NoError(t, err)
   537  		actual := strings.TrimSuffix(string(body), "\n")
   538  		assert.Equal(t, response1, actual)
   539  		assert.Equal(t, "Bearer some-bearer-token", resp1.Header.Get("Authorization"))
   540  
   541  		require.NotNil(t, resp2)
   542  		assert.Equal(t, http.StatusOK, resp2.StatusCode)
   543  		body, err = io.ReadAll(resp2.Body)
   544  		require.NoError(t, err)
   545  		actual = strings.TrimSuffix(string(body), "\n")
   546  		assert.Equal(t, response2, actual)
   547  		assert.Equal(t, "Bearer another-bearer-token", resp2.Header.Get("Authorization"))
   548  	})
   549  	t.Run("will return 401 if sub has no access to get application", func(t *testing.T) {
   550  		// given
   551  		t.Parallel()
   552  		f := setup()
   553  		allowApp := false
   554  		allowExtension := true
   555  		extName := "some-extension"
   556  		withRbac(f, allowApp, allowExtension)
   557  		withExtensionConfig(getExtensionConfig(extName, "http://fake"), f)
   558  		withMetrics(f)
   559  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   560  		ts := startTestServer(t, f)
   561  		defer ts.Close()
   562  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName))
   563  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil)
   564  
   565  		// when
   566  		resp, err := http.DefaultClient.Do(r)
   567  
   568  		// then
   569  		require.NoError(t, err)
   570  		require.NotNil(t, resp)
   571  		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
   572  	})
   573  	t.Run("will return 401 if sub has no access to invoke extension", func(t *testing.T) {
   574  		// given
   575  		t.Parallel()
   576  		f := setup()
   577  		allowApp := true
   578  		allowExtension := false
   579  		extName := "some-extension"
   580  		withRbac(f, allowApp, allowExtension)
   581  		withExtensionConfig(getExtensionConfig(extName, "http://fake"), f)
   582  		withMetrics(f)
   583  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   584  		ts := startTestServer(t, f)
   585  		defer ts.Close()
   586  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName))
   587  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil)
   588  
   589  		// when
   590  		resp, err := http.DefaultClient.Do(r)
   591  
   592  		// then
   593  		require.NoError(t, err)
   594  		require.NotNil(t, resp)
   595  		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
   596  	})
   597  	t.Run("will return 401 if project has no access to target cluster", func(t *testing.T) {
   598  		// given
   599  		t.Parallel()
   600  		f := setup()
   601  		allowApp := true
   602  		allowExtension := true
   603  		extName := "some-extension"
   604  		noCluster := []string{}
   605  		withRbac(f, allowApp, allowExtension)
   606  		withExtensionConfig(getExtensionConfig(extName, "http://fake"), f)
   607  		withMetrics(f)
   608  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   609  		ts := startTestServer(t, f)
   610  		defer ts.Close()
   611  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName))
   612  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil)
   613  		proj := getProjectWithDestinations("project-name", nil, noCluster)
   614  		withProject(proj, f)
   615  
   616  		// when
   617  		resp, err := http.DefaultClient.Do(r)
   618  
   619  		// then
   620  		require.NoError(t, err)
   621  		require.NotNil(t, resp)
   622  		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
   623  	})
   624  	t.Run("will return 401 if project in application does not exist", func(t *testing.T) {
   625  		// given
   626  		t.Parallel()
   627  		f := setup()
   628  		allowApp := true
   629  		allowExtension := true
   630  		extName := "some-extension"
   631  		withRbac(f, allowApp, allowExtension)
   632  		withExtensionConfig(getExtensionConfig(extName, "http://fake"), f)
   633  		withMetrics(f)
   634  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   635  		ts := startTestServer(t, f)
   636  		defer ts.Close()
   637  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName))
   638  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil)
   639  		f.projMock.On("Get", defaultProjectName).Return(nil, nil)
   640  
   641  		// when
   642  		resp, err := http.DefaultClient.Do(r)
   643  
   644  		// then
   645  		require.NoError(t, err)
   646  		require.NotNil(t, resp)
   647  		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
   648  	})
   649  	t.Run("will return 401 if project in application does not match with header", func(t *testing.T) {
   650  		// given
   651  		t.Parallel()
   652  		f := setup()
   653  		allowApp := true
   654  		allowExtension := true
   655  		extName := "some-extension"
   656  		differentProject := "differentProject"
   657  		withRbac(f, allowApp, allowExtension)
   658  		withExtensionConfig(getExtensionConfig(extName, "http://fake"), f)
   659  		withMetrics(f)
   660  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   661  		ts := startTestServer(t, f)
   662  		defer ts.Close()
   663  		r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName))
   664  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", differentProject), nil)
   665  
   666  		// when
   667  		resp, err := http.DefaultClient.Do(r)
   668  
   669  		// then
   670  		require.NoError(t, err)
   671  		require.NotNil(t, resp)
   672  		assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
   673  	})
   674  	t.Run("will return 401 if application defines name and server destination", func(t *testing.T) {
   675  		// This test is to validate a security risk with malicious application
   676  		// trying to gain access to execute extensions in clusters it doesn't
   677  		// have access.
   678  
   679  		// given
   680  		t.Parallel()
   681  		f := setup()
   682  		extName := "some-extension"
   683  		maliciousName := "srv1"
   684  		destinationServer := "some-valid-server"
   685  
   686  		f.appGetterMock.On("Get", "ns1", "app1").Return(getApp(maliciousName, destinationServer, defaultProjectName), nil)
   687  
   688  		withRbac(f, true, true)
   689  		withExtensionConfig(getExtensionConfigWith2Backends(extName, "url1", "cluster1Name", "cluster1URL", "url2", "cluster2Name", "cluster2URL"), f)
   690  		withProject(getProjectWithDestinations("project-name", nil, []string{"srv1", destinationServer}), f)
   691  		withMetrics(f)
   692  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   693  
   694  		ts := startTestServer(t, f)
   695  		defer ts.Close()
   696  
   697  		url := fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)
   698  		req := newExtensionRequest(t, http.MethodGet, url)
   699  		req.Header.Del(extension.HeaderArgoCDApplicationName)
   700  		req1 := req.Clone(t.Context())
   701  		req1.Header.Add(extension.HeaderArgoCDApplicationName, "ns1:app1")
   702  
   703  		// when
   704  		resp1, err := http.DefaultClient.Do(req1)
   705  		require.NoError(t, err)
   706  
   707  		// then
   708  		require.NotNil(t, resp1)
   709  		assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode)
   710  		body, err := io.ReadAll(resp1.Body)
   711  		require.NoError(t, err)
   712  		actual := strings.TrimSuffix(string(body), "\n")
   713  		assert.Equal(t, "Unauthorized extension request", actual)
   714  	})
   715  	t.Run("will return 400 if no extension name is provided", func(t *testing.T) {
   716  		// given
   717  		t.Parallel()
   718  		f := setup()
   719  		allowApp := true
   720  		allowExtension := true
   721  		extName := "some-extension"
   722  		differentProject := "differentProject"
   723  		withRbac(f, allowApp, allowExtension)
   724  		withExtensionConfig(getExtensionConfig(extName, "http://fake"), f)
   725  		withMetrics(f)
   726  		withUser(f, "some-user-id", "some-user", []string{"group1", "group2"})
   727  		ts := startTestServer(t, f)
   728  		defer ts.Close()
   729  		r := newExtensionRequest(t, "Get", ts.URL+"/extensions/")
   730  		f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", differentProject), nil)
   731  
   732  		// when
   733  		resp, err := http.DefaultClient.Do(r)
   734  
   735  		// then
   736  		require.NoError(t, err)
   737  		require.NotNil(t, resp)
   738  		assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
   739  	})
   740  }
   741  
   742  func getExtensionConfig(name, url string) string {
   743  	cfg := `
   744  extensions:
   745  - name: %s
   746    backend:
   747      services:
   748      - url: %s
   749        headers:
   750        - name: Authorization
   751          value: '$extension.auth.header'
   752  `
   753  	return fmt.Sprintf(cfg, name, url)
   754  }
   755  
   756  func getExtensionConfigWith2Backends(name, url1, clus1Name, clus1URL, url2, clus2Name, clus2URL string) string {
   757  	cfg := `
   758  extensions:
   759  - name: %s
   760    backend:
   761      services:
   762      - url: %s
   763        headers:
   764        - name: Authorization
   765          value: '$extension.auth.header'
   766        cluster:
   767          name: %s
   768          server: %s
   769      - url: %s
   770        headers:
   771        - name: Authorization
   772          value: '$extension.auth.header2'
   773        cluster:
   774          name: %s
   775          server: %s
   776      - url: http://test.com
   777        cluster:
   778          name: cl1
   779      - url: http://test2.com
   780        cluster:
   781          name: cl2
   782  `
   783  	return fmt.Sprintf(cfg, name, url1, clus1Name, clus1URL, url2, clus2Name, clus2URL)
   784  }
   785  
   786  func getExtensionConfigString() string {
   787  	return `
   788  extensions:
   789  - name: external-backend
   790    backend:
   791      connectionTimeout: 10s
   792      keepAlive: 11s
   793      idleConnectionTimeout: 12s
   794      maxIdleConnections: 30
   795      services:
   796      - url: https://httpbin.org
   797        headers:
   798        - name: some-header
   799          value: '$some.secret.ref'
   800  - name: some-backend
   801    backend:
   802      services:
   803      - url: http://localhost:7777
   804  `
   805  }
   806  
   807  func getSingleExtensionConfigString() string {
   808  	return `
   809  connectionTimeout: 10s
   810  keepAlive: 11s
   811  idleConnectionTimeout: 12s
   812  maxIdleConnections: 30
   813  services:
   814  - url: http://localhost:7777
   815  `
   816  }
   817  
   818  func getExtensionConfigNoService() string {
   819  	return `
   820  extensions:
   821  - backend:
   822      connectionTimeout: 2s
   823  `
   824  }
   825  
   826  func getExtensionConfigNoName() string {
   827  	return `
   828  extensions:
   829  - backend:
   830      services:
   831      - url: https://httpbin.org
   832  `
   833  }
   834  
   835  func getExtensionConfigInvalidName() string {
   836  	return `
   837  extensions:
   838  - name: invalid/name
   839    backend:
   840      services:
   841      - url: https://httpbin.org
   842  `
   843  }
   844  
   845  func getExtensionConfigNoURL() string {
   846  	return `
   847  extensions:
   848  - name: some-backend
   849    backend:
   850      services:
   851      - cluster: some-cluster
   852  `
   853  }
   854  
   855  func getExtensionConfigNoHeaderName() string {
   856  	return `
   857  extensions:
   858  - name: some-extension
   859    backend:
   860      services:
   861      - url: https://httpbin.org
   862        headers:
   863        - value: '$some.secret.key'
   864  `
   865  }
   866  
   867  func getExtensionConfigNoHeaderValue() string {
   868  	return `
   869  extensions:
   870  - name: some-extension
   871    backend:
   872      services:
   873      - url: https://httpbin.org
   874        headers:
   875        - name: some-header-name
   876  `
   877  }