github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/auth/mtls_token_provider_test.go (about)

     1  /*
     2   * Copyright 2020 The Compass Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package auth_test
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/tls"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/kyma-incubator/compass/components/director/pkg/apperrors"
    34  	"github.com/kyma-incubator/compass/components/director/pkg/auth"
    35  	"github.com/kyma-incubator/compass/components/director/pkg/auth/automock"
    36  	"github.com/kyma-incubator/compass/components/director/pkg/oauth"
    37  	"github.com/stretchr/testify/require"
    38  	"github.com/stretchr/testify/suite"
    39  )
    40  
    41  const (
    42  	fakeTkn                      = "fake-token"
    43  	tenant                       = "tenant42"
    44  	externalClientCertSecretName = "resource-name"
    45  )
    46  
    47  var oauthCfg = oauth.Config{
    48  	ClientID:              "client-id",
    49  	TokenEndpointProtocol: "https",
    50  	TokenBaseURL:          "test.mtls.domain.com",
    51  	TokenPath:             "/cert/token",
    52  	ScopesClaim:           []string{"my-scope"},
    53  	TenantHeaderName:      "x-tenant",
    54  }
    55  
    56  func TestMtlsTokenAuthorizationProviderTestSuite(t *testing.T) {
    57  	suite.Run(t, new(MtlsTokenAuthorizationProviderTestSuite))
    58  }
    59  
    60  type MtlsTokenAuthorizationProviderTestSuite struct {
    61  	suite.Suite
    62  }
    63  
    64  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_DefaultMtlsClientCreator() {
    65  	cache := &automock.CertificateCache{}
    66  	cache.On("Get").Return(map[string]*tls.Certificate{"resource-name": &tls.Certificate{}}, nil).Once()
    67  	defer cache.AssertExpectations(suite.T())
    68  
    69  	client := auth.DefaultMtlsClientCreator(cache, true, time.Second, "resource-name")
    70  
    71  	ts := httptest.NewUnstartedServer(testServerHandlerFunc(suite.T()))
    72  	ts.TLS = &tls.Config{
    73  		ClientAuth: tls.RequestClientCert,
    74  	}
    75  
    76  	ts.StartTLS()
    77  	defer ts.Close()
    78  
    79  	resp, err := client.Get(ts.URL)
    80  	suite.Require().NoError(err)
    81  	suite.Require().NotNil(resp)
    82  }
    83  
    84  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_New() {
    85  	provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator)
    86  	suite.Require().NotNil(provider)
    87  }
    88  
    89  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_Name() {
    90  	provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator)
    91  
    92  	name := provider.Name()
    93  
    94  	suite.Require().Equal(name, "MtlsTokenAuthorizationProvider")
    95  }
    96  
    97  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_Matches() {
    98  	provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator)
    99  
   100  	matches := provider.Matches(auth.SaveToContext(context.Background(), &auth.OAuthMtlsCredentials{}))
   101  	suite.Require().Equal(matches, true)
   102  }
   103  
   104  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_DoesNotMatchWhenBasicCredentialsInContext() {
   105  	provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator)
   106  
   107  	matches := provider.Matches(auth.SaveToContext(context.Background(), &auth.BasicCredentials{}))
   108  	suite.Require().Equal(matches, false)
   109  }
   110  
   111  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_DoesNotMatchNoCredentialsInContext() {
   112  	provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator)
   113  
   114  	matches := provider.Matches(context.TODO())
   115  	suite.Require().Equal(matches, false)
   116  }
   117  
   118  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorization() {
   119  	provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, false))
   120  
   121  	ctx := auth.SaveToContext(context.Background(), &auth.OAuthMtlsCredentials{
   122  		ClientID:          oauthCfg.ClientID,
   123  		TokenURL:          oauthCfg.TokenEndpointProtocol + "://" + oauthCfg.TokenBaseURL + oauthCfg.TokenPath,
   124  		Scopes:            strings.Join(oauthCfg.ScopesClaim, " "),
   125  		AdditionalHeaders: map[string]string{oauthCfg.TenantHeaderName: tenant},
   126  	})
   127  	authorization, err := provider.GetAuthorization(ctx)
   128  
   129  	suite.Require().NoError(err)
   130  	suite.Require().NotEmpty(authorization)
   131  
   132  	suite.Require().Equal("Bearer "+fakeTkn, authorization)
   133  }
   134  
   135  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorizationFailsWhenRequestFails() {
   136  	provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, true))
   137  
   138  	ctx := auth.SaveToContext(context.Background(), &auth.OAuthMtlsCredentials{
   139  		ClientID:          oauthCfg.ClientID,
   140  		TokenURL:          oauthCfg.TokenEndpointProtocol + "://" + oauthCfg.TokenBaseURL + oauthCfg.TokenPath,
   141  		Scopes:            strings.Join(oauthCfg.ScopesClaim, " "),
   142  		AdditionalHeaders: map[string]string{oauthCfg.TenantHeaderName: tenant},
   143  	})
   144  	authorization, err := provider.GetAuthorization(ctx)
   145  	suite.Require().Error(err)
   146  	suite.Require().Contains(err.Error(), "error")
   147  	suite.Require().Empty(authorization)
   148  }
   149  
   150  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorizationFailsWhenNoCredentialsInContext() {
   151  	provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, true))
   152  
   153  	authorization, err := provider.GetAuthorization(context.TODO())
   154  
   155  	suite.Require().Error(err)
   156  	suite.Require().True(apperrors.IsNotFoundError(err))
   157  	suite.Require().Empty(authorization)
   158  }
   159  
   160  func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorizationFailsWhenBasicCredentialsAreInContext() {
   161  	provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, true))
   162  
   163  	authorization, err := provider.GetAuthorization(auth.SaveToContext(context.Background(), &auth.BasicCredentials{}))
   164  
   165  	suite.Require().Error(err)
   166  	suite.Require().Contains(err.Error(), "failed to cast credentials to mtls oauth credentials type")
   167  	suite.Require().Empty(authorization)
   168  }
   169  
   170  func getFakeCreator(oauthCfg oauth.Config, suite suite.Suite, shouldFail bool) auth.MtlsClientCreator {
   171  	return func(_ auth.CertificateCache, skipSSLValidation bool, timeout time.Duration, secretName string) *http.Client {
   172  		return &http.Client{
   173  			Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
   174  				suite.Require().Equal(req.URL.Host, oauthCfg.TokenBaseURL)
   175  				suite.Require().Equal(req.URL.Scheme, oauthCfg.TokenEndpointProtocol)
   176  				suite.Require().Equal(req.URL.Path, oauthCfg.TokenPath)
   177  				suite.Require().Equal(req.Header.Get(oauthCfg.TenantHeaderName), tenant)
   178  				suite.Require().Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded")
   179  
   180  				body, err := io.ReadAll(req.Body)
   181  				suite.Require().NoError(err)
   182  				suite.Require().NotNil(body)
   183  
   184  				form, err := url.ParseQuery(string(body))
   185  				suite.Require().NoError(err)
   186  				suite.Require().NotNil(form)
   187  
   188  				suite.Require().Equal(form.Get("grant_type"), "client_credentials")
   189  				suite.Require().Equal(form.Get("client_id"), oauthCfg.ClientID)
   190  				suite.Require().Equal(form.Get("scope"), strings.Join(oauthCfg.ScopesClaim, " "))
   191  
   192  				if shouldFail {
   193  					return nil, errors.New("error")
   194  				}
   195  
   196  				return &http.Response{
   197  					StatusCode: http.StatusOK,
   198  					Body:       io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"access_token": "%s"}`, fakeTkn)))),
   199  				}, nil
   200  			}),
   201  		}
   202  	}
   203  }
   204  
   205  type roundTripFunc func(req *http.Request) (*http.Response, error)
   206  
   207  func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
   208  	return f(req)
   209  }
   210  
   211  func testServerHandlerFunc(t *testing.T) http.HandlerFunc {
   212  	return func(w http.ResponseWriter, r *http.Request) {
   213  		_, err := fmt.Fprintln(w, "Hello, client")
   214  		require.NoError(t, err)
   215  	}
   216  }