github.com/argoproj/argo-cd/v3@v3.2.1/server/extension/extension_test.go (about) 1 package extension_test 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "sync" 11 "testing" 12 13 "github.com/sirupsen/logrus/hooks/test" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/mock" 16 "github.com/stretchr/testify/require" 17 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 18 19 "github.com/argoproj/argo-cd/v3/util/rbac" 20 21 "github.com/argoproj/argo-cd/v3/pkg/apis/application/v1alpha1" 22 "github.com/argoproj/argo-cd/v3/server/extension" 23 "github.com/argoproj/argo-cd/v3/server/extension/mocks" 24 dbmocks "github.com/argoproj/argo-cd/v3/util/db/mocks" 25 "github.com/argoproj/argo-cd/v3/util/settings" 26 ) 27 28 func TestValidateHeaders(t *testing.T) { 29 t.Run("will build RequestResources successfully", func(t *testing.T) { 30 // given 31 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 32 require.NoError(t, err, "error initializing request") 33 r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app-name") 34 r.Header.Add(extension.HeaderArgoCDProjectName, "project-name") 35 36 // when 37 rr, err := extension.ValidateHeaders(r) 38 39 // then 40 require.NoError(t, err) 41 assert.NotNil(t, rr) 42 assert.Equal(t, "namespace", rr.ApplicationNamespace) 43 assert.Equal(t, "app-name", rr.ApplicationName) 44 assert.Equal(t, "project-name", rr.ProjectName) 45 }) 46 t.Run("will return error if application is malformatted", func(t *testing.T) { 47 // given 48 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 49 require.NoError(t, err, "error initializing request") 50 r.Header.Add(extension.HeaderArgoCDApplicationName, "no-namespace") 51 52 // when 53 rr, err := extension.ValidateHeaders(r) 54 55 // then 56 require.Error(t, err) 57 assert.Nil(t, rr) 58 }) 59 t.Run("will return error if application header is missing", func(t *testing.T) { 60 // given 61 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 62 require.NoError(t, err, "error initializing request") 63 r.Header.Add(extension.HeaderArgoCDProjectName, "project-name") 64 65 // when 66 rr, err := extension.ValidateHeaders(r) 67 68 // then 69 require.Error(t, err) 70 assert.Nil(t, rr) 71 }) 72 t.Run("will return error if project header is missing", func(t *testing.T) { 73 // given 74 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 75 require.NoError(t, err, "error initializing request") 76 r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app-name") 77 78 // when 79 rr, err := extension.ValidateHeaders(r) 80 81 // then 82 require.Error(t, err) 83 assert.Nil(t, rr) 84 }) 85 t.Run("will return error if invalid namespace", func(t *testing.T) { 86 // given 87 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 88 require.NoError(t, err, "error initializing request") 89 r.Header.Add(extension.HeaderArgoCDApplicationName, "bad%namespace:app-name") 90 r.Header.Add(extension.HeaderArgoCDProjectName, "project-name") 91 92 // when 93 rr, err := extension.ValidateHeaders(r) 94 95 // then 96 require.Error(t, err) 97 assert.Nil(t, rr) 98 }) 99 t.Run("will return error if invalid app name", func(t *testing.T) { 100 // given 101 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 102 require.NoError(t, err, "error initializing request") 103 r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:bad@app") 104 r.Header.Add(extension.HeaderArgoCDProjectName, "project-name") 105 106 // when 107 rr, err := extension.ValidateHeaders(r) 108 109 // then 110 require.Error(t, err) 111 assert.Nil(t, rr) 112 }) 113 t.Run("will return error if invalid project name", func(t *testing.T) { 114 // given 115 r, err := http.NewRequest(http.MethodGet, "http://null", http.NoBody) 116 require.NoError(t, err, "error initializing request") 117 r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app") 118 r.Header.Add(extension.HeaderArgoCDProjectName, "bad^project") 119 120 // when 121 rr, err := extension.ValidateHeaders(r) 122 123 // then 124 require.Error(t, err) 125 assert.Nil(t, rr) 126 }) 127 } 128 129 func TestRegisterExtensions(t *testing.T) { 130 t.Parallel() 131 132 type fixture struct { 133 settingsGetterMock *mocks.SettingsGetter 134 manager *extension.Manager 135 } 136 137 setup := func() *fixture { 138 settMock := &mocks.SettingsGetter{} 139 140 logger, _ := test.NewNullLogger() 141 logEntry := logger.WithContext(t.Context()) 142 m := extension.NewManager(logEntry, "", settMock, nil, nil, nil, nil, nil) 143 144 return &fixture{ 145 settingsGetterMock: settMock, 146 manager: m, 147 } 148 } 149 t.Run("will register extensions successfully", func(t *testing.T) { 150 // given 151 t.Parallel() 152 f := setup() 153 settings := &settings.ArgoCDSettings{ 154 ExtensionConfig: map[string]string{ 155 "": getExtensionConfigString(), 156 "another-ext": getSingleExtensionConfigString(), 157 }, 158 } 159 f.settingsGetterMock.On("Get", mock.Anything).Return(settings, nil) 160 expectedProxyRegistries := []string{ 161 "external-backend", 162 "some-backend", 163 "another-ext", 164 } 165 166 // when 167 err := f.manager.RegisterExtensions() 168 169 // then 170 require.NoError(t, err) 171 for _, expectedProxyRegistry := range expectedProxyRegistries { 172 proxyRegistry, found := f.manager.ProxyRegistry(expectedProxyRegistry) 173 assert.True(t, found) 174 assert.NotNil(t, proxyRegistry) 175 } 176 }) 177 t.Run("will return error if extension config is invalid", func(t *testing.T) { 178 // given 179 t.Parallel() 180 type testCase struct { 181 name string 182 configYaml string 183 } 184 cases := []testCase{ 185 { 186 name: "no name", 187 configYaml: getExtensionConfigNoName(), 188 }, 189 { 190 name: "no service", 191 configYaml: getExtensionConfigNoService(), 192 }, 193 { 194 name: "no URL", 195 configYaml: getExtensionConfigNoURL(), 196 }, 197 { 198 name: "invalid name", 199 configYaml: getExtensionConfigInvalidName(), 200 }, 201 { 202 name: "no header name", 203 configYaml: getExtensionConfigNoHeaderName(), 204 }, 205 { 206 name: "no header value", 207 configYaml: getExtensionConfigNoHeaderValue(), 208 }, 209 } 210 211 // when 212 for _, tc := range cases { 213 tc := tc 214 t.Run(tc.name, func(t *testing.T) { 215 // given 216 t.Parallel() 217 f := setup() 218 settings := &settings.ArgoCDSettings{ 219 ExtensionConfig: map[string]string{ 220 "": tc.configYaml, 221 }, 222 } 223 f.settingsGetterMock.On("Get", mock.Anything).Return(settings, nil) 224 225 // when 226 err := f.manager.RegisterExtensions() 227 228 // then 229 require.Error(t, err, "expected error in test %s but got nil", tc.name) 230 }) 231 } 232 }) 233 } 234 235 func TestCallExtension(t *testing.T) { 236 t.Parallel() 237 238 type fixture struct { 239 mux *http.ServeMux 240 appGetterMock *mocks.ApplicationGetter 241 settingsGetterMock *mocks.SettingsGetter 242 rbacMock *mocks.RbacEnforcer 243 projMock *mocks.ProjectGetter 244 metricsMock *mocks.ExtensionMetricsRegistry 245 userMock *mocks.UserGetter 246 manager *extension.Manager 247 } 248 defaultServerNamespace := "control-plane-ns" 249 defaultProjectName := "project-name" 250 251 setup := func() *fixture { 252 appMock := &mocks.ApplicationGetter{} 253 settMock := &mocks.SettingsGetter{} 254 rbacMock := &mocks.RbacEnforcer{} 255 projMock := &mocks.ProjectGetter{} 256 metricsMock := &mocks.ExtensionMetricsRegistry{} 257 userMock := &mocks.UserGetter{} 258 259 dbMock := &dbmocks.ArgoDB{} 260 dbMock.On("GetClusterServersByName", mock.Anything, mock.Anything).Return([]string{"cluster1"}, nil) 261 dbMock.On("GetCluster", mock.Anything, mock.Anything).Return(&v1alpha1.Cluster{Server: "some-url", Name: "cluster1"}, nil) 262 263 logger, _ := test.NewNullLogger() 264 logEntry := logger.WithContext(t.Context()) 265 m := extension.NewManager(logEntry, defaultServerNamespace, settMock, appMock, projMock, dbMock, rbacMock, userMock) 266 m.AddMetricsRegistry(metricsMock) 267 268 mux := http.NewServeMux() 269 extHandler := http.HandlerFunc(m.CallExtension()) 270 mux.Handle(extension.URLPrefix+"/", extHandler) 271 272 return &fixture{ 273 mux: mux, 274 appGetterMock: appMock, 275 settingsGetterMock: settMock, 276 rbacMock: rbacMock, 277 projMock: projMock, 278 metricsMock: metricsMock, 279 userMock: userMock, 280 manager: m, 281 } 282 } 283 284 getApp := func(destName, destServer, projName string) *v1alpha1.Application { 285 return &v1alpha1.Application{ 286 TypeMeta: metav1.TypeMeta{}, 287 ObjectMeta: metav1.ObjectMeta{}, 288 Spec: v1alpha1.ApplicationSpec{ 289 Destination: v1alpha1.ApplicationDestination{ 290 Name: destName, 291 Server: destServer, 292 }, 293 Project: projName, 294 }, 295 Status: v1alpha1.ApplicationStatus{ 296 Resources: []v1alpha1.ResourceStatus{ 297 { 298 Group: "apps", 299 Version: "v1", 300 Kind: "Pod", 301 Namespace: "default", 302 Name: "some-pod", 303 }, 304 }, 305 }, 306 } 307 } 308 309 getProjectWithDestinations := func(prjName string, destNames []string, destURLs []string) *v1alpha1.AppProject { 310 destinations := []v1alpha1.ApplicationDestination{} 311 for _, destName := range destNames { 312 destination := v1alpha1.ApplicationDestination{ 313 Name: destName, 314 } 315 destinations = append(destinations, destination) 316 } 317 for _, destURL := range destURLs { 318 destination := v1alpha1.ApplicationDestination{ 319 Server: destURL, 320 } 321 destinations = append(destinations, destination) 322 } 323 return &v1alpha1.AppProject{ 324 ObjectMeta: metav1.ObjectMeta{ 325 Name: prjName, 326 }, 327 Spec: v1alpha1.AppProjectSpec{ 328 Destinations: destinations, 329 }, 330 } 331 } 332 333 withProject := func(prj *v1alpha1.AppProject, f *fixture) { 334 f.projMock.On("Get", prj.GetName()).Return(prj, nil) 335 } 336 337 withMetrics := func(f *fixture) { 338 f.metricsMock.On("IncExtensionRequestCounter", mock.Anything, mock.Anything) 339 f.metricsMock.On("ObserveExtensionRequestDuration", mock.Anything, mock.Anything) 340 } 341 342 withRbac := func(f *fixture, allowApp, allowExt bool) { 343 var appAccessError error 344 var extAccessError error 345 if !allowApp { 346 appAccessError = errors.New("no app permission") 347 } 348 if !allowExt { 349 extAccessError = errors.New("no extension permission") 350 } 351 f.rbacMock.On("EnforceErr", mock.Anything, rbac.ResourceApplications, rbac.ActionGet, mock.Anything).Return(appAccessError) 352 f.rbacMock.On("EnforceErr", mock.Anything, rbac.ResourceExtensions, rbac.ActionInvoke, mock.Anything).Return(extAccessError) 353 } 354 355 withUser := func(f *fixture, userId string, username string, groups []string) { 356 f.userMock.On("GetUserId", mock.Anything).Return(userId) 357 f.userMock.On("GetUsername", mock.Anything).Return(username) 358 f.userMock.On("GetGroups", mock.Anything).Return(groups) 359 } 360 361 withExtensionConfig := func(configYaml string, f *fixture) { 362 secrets := make(map[string]string) 363 secrets["extension.auth.header"] = "Bearer some-bearer-token" 364 secrets["extension.auth.header2"] = "Bearer another-bearer-token" 365 366 settings := &settings.ArgoCDSettings{ 367 ExtensionConfig: map[string]string{ 368 "ephemeral": "services:\n- url: http://some-server.com", 369 "": configYaml, 370 }, 371 Secrets: secrets, 372 } 373 f.settingsGetterMock.On("Get", mock.Anything).Return(settings, nil) 374 } 375 376 startTestServer := func(t *testing.T, f *fixture) *httptest.Server { 377 t.Helper() 378 err := f.manager.RegisterExtensions() 379 require.NoError(t, err, "error starting test server") 380 return httptest.NewServer(f.mux) 381 } 382 383 startBackendTestSrv := func(response string) *httptest.Server { 384 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 385 for k, v := range r.Header { 386 w.Header().Add(k, strings.Join(v, ",")) 387 } 388 fmt.Fprintln(w, response) 389 })) 390 } 391 newExtensionRequest := func(t *testing.T, method, url string) *http.Request { 392 t.Helper() 393 r, err := http.NewRequest(method, url, http.NoBody) 394 require.NoError(t, err, "error initializing request") 395 r.Header.Add(extension.HeaderArgoCDApplicationName, "namespace:app-name") 396 r.Header.Add(extension.HeaderArgoCDProjectName, defaultProjectName) 397 return r 398 } 399 400 t.Run("will call extension backend successfully", func(t *testing.T) { 401 // given 402 t.Parallel() 403 f := setup() 404 backendResponse := "some data" 405 backendEndpoint := "some-backend" 406 clusterURL := "some-url" 407 backendSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 408 for k, v := range r.Header { 409 w.Header().Add(k, strings.Join(v, ",")) 410 } 411 fmt.Fprintln(w, backendResponse) 412 })) 413 defer backendSrv.Close() 414 withRbac(f, true, true) 415 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 416 withExtensionConfig(getExtensionConfig(backendEndpoint, backendSrv.URL), f) 417 ts := startTestServer(t, f) 418 defer ts.Close() 419 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, backendEndpoint)) 420 app := getApp("", clusterURL, defaultProjectName) 421 proj := getProjectWithDestinations("project-name", nil, []string{clusterURL}) 422 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(app, nil) 423 withProject(proj, f) 424 var wg sync.WaitGroup 425 wg.Add(2) 426 f.metricsMock. 427 On("IncExtensionRequestCounter", mock.Anything, mock.Anything). 428 Run(func(_ mock.Arguments) { 429 wg.Done() 430 }) 431 f.metricsMock. 432 On("ObserveExtensionRequestDuration", mock.Anything, mock.Anything). 433 Run(func(_ mock.Arguments) { 434 wg.Done() 435 }) 436 437 // when 438 resp, err := http.DefaultClient.Do(r) 439 440 // then 441 require.NoError(t, err) 442 require.NotNil(t, resp) 443 assert.Equal(t, http.StatusOK, resp.StatusCode) 444 body, err := io.ReadAll(resp.Body) 445 require.NoError(t, err) 446 actual := strings.TrimSuffix(string(body), "\n") 447 assert.Equal(t, backendResponse, actual) 448 assert.Equal(t, defaultServerNamespace, resp.Header.Get(extension.HeaderArgoCDNamespace)) 449 assert.Equal(t, clusterURL, resp.Header.Get(extension.HeaderArgoCDTargetClusterURL)) 450 assert.Equal(t, "Bearer some-bearer-token", resp.Header.Get("Authorization")) 451 assert.Equal(t, "some-user", resp.Header.Get(extension.HeaderArgoCDUsername)) 452 assert.Equal(t, "some-user-id", resp.Header.Get(extension.HeaderArgoCDUserId)) 453 assert.Equal(t, "group1,group2", resp.Header.Get(extension.HeaderArgoCDGroups)) 454 455 // waitgroup is necessary to make sure assertions aren't executed before 456 // the goroutine initiated by extension.CallExtension concludes which would 457 // lead to flaky test. 458 wg.Wait() 459 f.metricsMock.AssertCalled(t, "IncExtensionRequestCounter", backendEndpoint, http.StatusOK) 460 f.metricsMock.AssertCalled(t, "ObserveExtensionRequestDuration", backendEndpoint, mock.Anything) 461 }) 462 t.Run("proxy will return 404 if extension endpoint not registered", func(t *testing.T) { 463 // given 464 t.Parallel() 465 f := setup() 466 withExtensionConfig(getExtensionConfigString(), f) 467 withRbac(f, true, true) 468 withMetrics(f) 469 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 470 cluster1Name := "cluster1" 471 f.appGetterMock.On("Get", "namespace", "app-name").Return(getApp(cluster1Name, "", defaultProjectName), nil) 472 withProject(getProjectWithDestinations("project-name", []string{cluster1Name}, []string{"some-url"}), f) 473 474 ts := startTestServer(t, f) 475 defer ts.Close() 476 nonRegistered := "non-registered" 477 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, nonRegistered)) 478 479 // when 480 resp, err := http.DefaultClient.Do(r) 481 482 // then 483 require.NoError(t, err) 484 require.NotNil(t, resp) 485 assert.Equal(t, http.StatusNotFound, resp.StatusCode) 486 }) 487 t.Run("will route requests with 2 backends for the same extension successfully", func(t *testing.T) { 488 // given 489 t.Parallel() 490 f := setup() 491 extName := "some-extension" 492 493 response1 := "response backend 1" 494 cluster1Name := "cluster1" 495 cluster1URL := "url1" 496 beSrv1 := startBackendTestSrv(response1) 497 defer beSrv1.Close() 498 499 response2 := "response backend 2" 500 cluster2Name := "cluster2" 501 cluster2URL := "url2" 502 beSrv2 := startBackendTestSrv(response2) 503 defer beSrv2.Close() 504 505 f.appGetterMock.On("Get", "ns1", "app1").Return(getApp(cluster1Name, "", defaultProjectName), nil) 506 f.appGetterMock.On("Get", "ns2", "app2").Return(getApp("", cluster2URL, defaultProjectName), nil) 507 508 withRbac(f, true, true) 509 withExtensionConfig(getExtensionConfigWith2Backends(extName, beSrv1.URL, cluster1Name, cluster1URL, beSrv2.URL, cluster2Name, cluster2URL), f) 510 withProject(getProjectWithDestinations("project-name", []string{cluster1Name}, []string{cluster2URL}), f) 511 withMetrics(f) 512 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 513 514 ts := startTestServer(t, f) 515 defer ts.Close() 516 517 url := fmt.Sprintf("%s/extensions/%s/", ts.URL, extName) 518 req := newExtensionRequest(t, http.MethodGet, url) 519 req.Header.Del(extension.HeaderArgoCDApplicationName) 520 521 req1 := req.Clone(t.Context()) 522 req1.Header.Add(extension.HeaderArgoCDApplicationName, "ns1:app1") 523 req2 := req.Clone(t.Context()) 524 req2.Header.Add(extension.HeaderArgoCDApplicationName, "ns2:app2") 525 526 // when 527 resp1, err := http.DefaultClient.Do(req1) 528 require.NoError(t, err) 529 resp2, err := http.DefaultClient.Do(req2) 530 require.NoError(t, err) 531 532 // then 533 require.NotNil(t, resp1) 534 assert.Equal(t, http.StatusOK, resp1.StatusCode) 535 body, err := io.ReadAll(resp1.Body) 536 require.NoError(t, err) 537 actual := strings.TrimSuffix(string(body), "\n") 538 assert.Equal(t, response1, actual) 539 assert.Equal(t, "Bearer some-bearer-token", resp1.Header.Get("Authorization")) 540 541 require.NotNil(t, resp2) 542 assert.Equal(t, http.StatusOK, resp2.StatusCode) 543 body, err = io.ReadAll(resp2.Body) 544 require.NoError(t, err) 545 actual = strings.TrimSuffix(string(body), "\n") 546 assert.Equal(t, response2, actual) 547 assert.Equal(t, "Bearer another-bearer-token", resp2.Header.Get("Authorization")) 548 }) 549 t.Run("will return 401 if sub has no access to get application", func(t *testing.T) { 550 // given 551 t.Parallel() 552 f := setup() 553 allowApp := false 554 allowExtension := true 555 extName := "some-extension" 556 withRbac(f, allowApp, allowExtension) 557 withExtensionConfig(getExtensionConfig(extName, "http://fake"), f) 558 withMetrics(f) 559 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 560 ts := startTestServer(t, f) 561 defer ts.Close() 562 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)) 563 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil) 564 565 // when 566 resp, err := http.DefaultClient.Do(r) 567 568 // then 569 require.NoError(t, err) 570 require.NotNil(t, resp) 571 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 572 }) 573 t.Run("will return 401 if sub has no access to invoke extension", func(t *testing.T) { 574 // given 575 t.Parallel() 576 f := setup() 577 allowApp := true 578 allowExtension := false 579 extName := "some-extension" 580 withRbac(f, allowApp, allowExtension) 581 withExtensionConfig(getExtensionConfig(extName, "http://fake"), f) 582 withMetrics(f) 583 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 584 ts := startTestServer(t, f) 585 defer ts.Close() 586 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)) 587 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil) 588 589 // when 590 resp, err := http.DefaultClient.Do(r) 591 592 // then 593 require.NoError(t, err) 594 require.NotNil(t, resp) 595 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 596 }) 597 t.Run("will return 401 if project has no access to target cluster", func(t *testing.T) { 598 // given 599 t.Parallel() 600 f := setup() 601 allowApp := true 602 allowExtension := true 603 extName := "some-extension" 604 noCluster := []string{} 605 withRbac(f, allowApp, allowExtension) 606 withExtensionConfig(getExtensionConfig(extName, "http://fake"), f) 607 withMetrics(f) 608 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 609 ts := startTestServer(t, f) 610 defer ts.Close() 611 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)) 612 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil) 613 proj := getProjectWithDestinations("project-name", nil, noCluster) 614 withProject(proj, f) 615 616 // when 617 resp, err := http.DefaultClient.Do(r) 618 619 // then 620 require.NoError(t, err) 621 require.NotNil(t, resp) 622 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 623 }) 624 t.Run("will return 401 if project in application does not exist", func(t *testing.T) { 625 // given 626 t.Parallel() 627 f := setup() 628 allowApp := true 629 allowExtension := true 630 extName := "some-extension" 631 withRbac(f, allowApp, allowExtension) 632 withExtensionConfig(getExtensionConfig(extName, "http://fake"), f) 633 withMetrics(f) 634 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 635 ts := startTestServer(t, f) 636 defer ts.Close() 637 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)) 638 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", defaultProjectName), nil) 639 f.projMock.On("Get", defaultProjectName).Return(nil, nil) 640 641 // when 642 resp, err := http.DefaultClient.Do(r) 643 644 // then 645 require.NoError(t, err) 646 require.NotNil(t, resp) 647 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 648 }) 649 t.Run("will return 401 if project in application does not match with header", func(t *testing.T) { 650 // given 651 t.Parallel() 652 f := setup() 653 allowApp := true 654 allowExtension := true 655 extName := "some-extension" 656 differentProject := "differentProject" 657 withRbac(f, allowApp, allowExtension) 658 withExtensionConfig(getExtensionConfig(extName, "http://fake"), f) 659 withMetrics(f) 660 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 661 ts := startTestServer(t, f) 662 defer ts.Close() 663 r := newExtensionRequest(t, "Get", fmt.Sprintf("%s/extensions/%s/", ts.URL, extName)) 664 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", differentProject), nil) 665 666 // when 667 resp, err := http.DefaultClient.Do(r) 668 669 // then 670 require.NoError(t, err) 671 require.NotNil(t, resp) 672 assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) 673 }) 674 t.Run("will return 401 if application defines name and server destination", func(t *testing.T) { 675 // This test is to validate a security risk with malicious application 676 // trying to gain access to execute extensions in clusters it doesn't 677 // have access. 678 679 // given 680 t.Parallel() 681 f := setup() 682 extName := "some-extension" 683 maliciousName := "srv1" 684 destinationServer := "some-valid-server" 685 686 f.appGetterMock.On("Get", "ns1", "app1").Return(getApp(maliciousName, destinationServer, defaultProjectName), nil) 687 688 withRbac(f, true, true) 689 withExtensionConfig(getExtensionConfigWith2Backends(extName, "url1", "cluster1Name", "cluster1URL", "url2", "cluster2Name", "cluster2URL"), f) 690 withProject(getProjectWithDestinations("project-name", nil, []string{"srv1", destinationServer}), f) 691 withMetrics(f) 692 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 693 694 ts := startTestServer(t, f) 695 defer ts.Close() 696 697 url := fmt.Sprintf("%s/extensions/%s/", ts.URL, extName) 698 req := newExtensionRequest(t, http.MethodGet, url) 699 req.Header.Del(extension.HeaderArgoCDApplicationName) 700 req1 := req.Clone(t.Context()) 701 req1.Header.Add(extension.HeaderArgoCDApplicationName, "ns1:app1") 702 703 // when 704 resp1, err := http.DefaultClient.Do(req1) 705 require.NoError(t, err) 706 707 // then 708 require.NotNil(t, resp1) 709 assert.Equal(t, http.StatusUnauthorized, resp1.StatusCode) 710 body, err := io.ReadAll(resp1.Body) 711 require.NoError(t, err) 712 actual := strings.TrimSuffix(string(body), "\n") 713 assert.Equal(t, "Unauthorized extension request", actual) 714 }) 715 t.Run("will return 400 if no extension name is provided", func(t *testing.T) { 716 // given 717 t.Parallel() 718 f := setup() 719 allowApp := true 720 allowExtension := true 721 extName := "some-extension" 722 differentProject := "differentProject" 723 withRbac(f, allowApp, allowExtension) 724 withExtensionConfig(getExtensionConfig(extName, "http://fake"), f) 725 withMetrics(f) 726 withUser(f, "some-user-id", "some-user", []string{"group1", "group2"}) 727 ts := startTestServer(t, f) 728 defer ts.Close() 729 r := newExtensionRequest(t, "Get", ts.URL+"/extensions/") 730 f.appGetterMock.On("Get", mock.Anything, mock.Anything).Return(getApp("", "", differentProject), nil) 731 732 // when 733 resp, err := http.DefaultClient.Do(r) 734 735 // then 736 require.NoError(t, err) 737 require.NotNil(t, resp) 738 assert.Equal(t, http.StatusBadRequest, resp.StatusCode) 739 }) 740 } 741 742 func getExtensionConfig(name, url string) string { 743 cfg := ` 744 extensions: 745 - name: %s 746 backend: 747 services: 748 - url: %s 749 headers: 750 - name: Authorization 751 value: '$extension.auth.header' 752 ` 753 return fmt.Sprintf(cfg, name, url) 754 } 755 756 func getExtensionConfigWith2Backends(name, url1, clus1Name, clus1URL, url2, clus2Name, clus2URL string) string { 757 cfg := ` 758 extensions: 759 - name: %s 760 backend: 761 services: 762 - url: %s 763 headers: 764 - name: Authorization 765 value: '$extension.auth.header' 766 cluster: 767 name: %s 768 server: %s 769 - url: %s 770 headers: 771 - name: Authorization 772 value: '$extension.auth.header2' 773 cluster: 774 name: %s 775 server: %s 776 - url: http://test.com 777 cluster: 778 name: cl1 779 - url: http://test2.com 780 cluster: 781 name: cl2 782 ` 783 return fmt.Sprintf(cfg, name, url1, clus1Name, clus1URL, url2, clus2Name, clus2URL) 784 } 785 786 func getExtensionConfigString() string { 787 return ` 788 extensions: 789 - name: external-backend 790 backend: 791 connectionTimeout: 10s 792 keepAlive: 11s 793 idleConnectionTimeout: 12s 794 maxIdleConnections: 30 795 services: 796 - url: https://httpbin.org 797 headers: 798 - name: some-header 799 value: '$some.secret.ref' 800 - name: some-backend 801 backend: 802 services: 803 - url: http://localhost:7777 804 ` 805 } 806 807 func getSingleExtensionConfigString() string { 808 return ` 809 connectionTimeout: 10s 810 keepAlive: 11s 811 idleConnectionTimeout: 12s 812 maxIdleConnections: 30 813 services: 814 - url: http://localhost:7777 815 ` 816 } 817 818 func getExtensionConfigNoService() string { 819 return ` 820 extensions: 821 - backend: 822 connectionTimeout: 2s 823 ` 824 } 825 826 func getExtensionConfigNoName() string { 827 return ` 828 extensions: 829 - backend: 830 services: 831 - url: https://httpbin.org 832 ` 833 } 834 835 func getExtensionConfigInvalidName() string { 836 return ` 837 extensions: 838 - name: invalid/name 839 backend: 840 services: 841 - url: https://httpbin.org 842 ` 843 } 844 845 func getExtensionConfigNoURL() string { 846 return ` 847 extensions: 848 - name: some-backend 849 backend: 850 services: 851 - cluster: some-cluster 852 ` 853 } 854 855 func getExtensionConfigNoHeaderName() string { 856 return ` 857 extensions: 858 - name: some-extension 859 backend: 860 services: 861 - url: https://httpbin.org 862 headers: 863 - value: '$some.secret.key' 864 ` 865 } 866 867 func getExtensionConfigNoHeaderValue() string { 868 return ` 869 extensions: 870 - name: some-extension 871 backend: 872 services: 873 - url: https://httpbin.org 874 headers: 875 - name: some-header-name 876 ` 877 }