github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/auth/mtls_token_provider_test.go (about) 1 /* 2 * Copyright 2020 The Compass Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package auth_test 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "errors" 24 "fmt" 25 "io" 26 "net/http" 27 "net/http/httptest" 28 "net/url" 29 "strings" 30 "testing" 31 "time" 32 33 "github.com/kyma-incubator/compass/components/director/pkg/apperrors" 34 "github.com/kyma-incubator/compass/components/director/pkg/auth" 35 "github.com/kyma-incubator/compass/components/director/pkg/auth/automock" 36 "github.com/kyma-incubator/compass/components/director/pkg/oauth" 37 "github.com/stretchr/testify/require" 38 "github.com/stretchr/testify/suite" 39 ) 40 41 const ( 42 fakeTkn = "fake-token" 43 tenant = "tenant42" 44 externalClientCertSecretName = "resource-name" 45 ) 46 47 var oauthCfg = oauth.Config{ 48 ClientID: "client-id", 49 TokenEndpointProtocol: "https", 50 TokenBaseURL: "test.mtls.domain.com", 51 TokenPath: "/cert/token", 52 ScopesClaim: []string{"my-scope"}, 53 TenantHeaderName: "x-tenant", 54 } 55 56 func TestMtlsTokenAuthorizationProviderTestSuite(t *testing.T) { 57 suite.Run(t, new(MtlsTokenAuthorizationProviderTestSuite)) 58 } 59 60 type MtlsTokenAuthorizationProviderTestSuite struct { 61 suite.Suite 62 } 63 64 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_DefaultMtlsClientCreator() { 65 cache := &automock.CertificateCache{} 66 cache.On("Get").Return(map[string]*tls.Certificate{"resource-name": &tls.Certificate{}}, nil).Once() 67 defer cache.AssertExpectations(suite.T()) 68 69 client := auth.DefaultMtlsClientCreator(cache, true, time.Second, "resource-name") 70 71 ts := httptest.NewUnstartedServer(testServerHandlerFunc(suite.T())) 72 ts.TLS = &tls.Config{ 73 ClientAuth: tls.RequestClientCert, 74 } 75 76 ts.StartTLS() 77 defer ts.Close() 78 79 resp, err := client.Get(ts.URL) 80 suite.Require().NoError(err) 81 suite.Require().NotNil(resp) 82 } 83 84 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_New() { 85 provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator) 86 suite.Require().NotNil(provider) 87 } 88 89 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_Name() { 90 provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator) 91 92 name := provider.Name() 93 94 suite.Require().Equal(name, "MtlsTokenAuthorizationProvider") 95 } 96 97 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_Matches() { 98 provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator) 99 100 matches := provider.Matches(auth.SaveToContext(context.Background(), &auth.OAuthMtlsCredentials{})) 101 suite.Require().Equal(matches, true) 102 } 103 104 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_DoesNotMatchWhenBasicCredentialsInContext() { 105 provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator) 106 107 matches := provider.Matches(auth.SaveToContext(context.Background(), &auth.BasicCredentials{})) 108 suite.Require().Equal(matches, false) 109 } 110 111 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_DoesNotMatchNoCredentialsInContext() { 112 provider := auth.NewMtlsTokenAuthorizationProvider(oauth.Config{}, externalClientCertSecretName, &automock.CertificateCache{}, auth.DefaultMtlsClientCreator) 113 114 matches := provider.Matches(context.TODO()) 115 suite.Require().Equal(matches, false) 116 } 117 118 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorization() { 119 provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, false)) 120 121 ctx := auth.SaveToContext(context.Background(), &auth.OAuthMtlsCredentials{ 122 ClientID: oauthCfg.ClientID, 123 TokenURL: oauthCfg.TokenEndpointProtocol + "://" + oauthCfg.TokenBaseURL + oauthCfg.TokenPath, 124 Scopes: strings.Join(oauthCfg.ScopesClaim, " "), 125 AdditionalHeaders: map[string]string{oauthCfg.TenantHeaderName: tenant}, 126 }) 127 authorization, err := provider.GetAuthorization(ctx) 128 129 suite.Require().NoError(err) 130 suite.Require().NotEmpty(authorization) 131 132 suite.Require().Equal("Bearer "+fakeTkn, authorization) 133 } 134 135 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorizationFailsWhenRequestFails() { 136 provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, true)) 137 138 ctx := auth.SaveToContext(context.Background(), &auth.OAuthMtlsCredentials{ 139 ClientID: oauthCfg.ClientID, 140 TokenURL: oauthCfg.TokenEndpointProtocol + "://" + oauthCfg.TokenBaseURL + oauthCfg.TokenPath, 141 Scopes: strings.Join(oauthCfg.ScopesClaim, " "), 142 AdditionalHeaders: map[string]string{oauthCfg.TenantHeaderName: tenant}, 143 }) 144 authorization, err := provider.GetAuthorization(ctx) 145 suite.Require().Error(err) 146 suite.Require().Contains(err.Error(), "error") 147 suite.Require().Empty(authorization) 148 } 149 150 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorizationFailsWhenNoCredentialsInContext() { 151 provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, true)) 152 153 authorization, err := provider.GetAuthorization(context.TODO()) 154 155 suite.Require().Error(err) 156 suite.Require().True(apperrors.IsNotFoundError(err)) 157 suite.Require().Empty(authorization) 158 } 159 160 func (suite *MtlsTokenAuthorizationProviderTestSuite) TestMtlsTokenAuthorizationProvider_GetAuthorizationFailsWhenBasicCredentialsAreInContext() { 161 provider := auth.NewMtlsTokenAuthorizationProvider(oauthCfg, externalClientCertSecretName, nil, getFakeCreator(oauthCfg, suite.Suite, true)) 162 163 authorization, err := provider.GetAuthorization(auth.SaveToContext(context.Background(), &auth.BasicCredentials{})) 164 165 suite.Require().Error(err) 166 suite.Require().Contains(err.Error(), "failed to cast credentials to mtls oauth credentials type") 167 suite.Require().Empty(authorization) 168 } 169 170 func getFakeCreator(oauthCfg oauth.Config, suite suite.Suite, shouldFail bool) auth.MtlsClientCreator { 171 return func(_ auth.CertificateCache, skipSSLValidation bool, timeout time.Duration, secretName string) *http.Client { 172 return &http.Client{ 173 Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { 174 suite.Require().Equal(req.URL.Host, oauthCfg.TokenBaseURL) 175 suite.Require().Equal(req.URL.Scheme, oauthCfg.TokenEndpointProtocol) 176 suite.Require().Equal(req.URL.Path, oauthCfg.TokenPath) 177 suite.Require().Equal(req.Header.Get(oauthCfg.TenantHeaderName), tenant) 178 suite.Require().Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") 179 180 body, err := io.ReadAll(req.Body) 181 suite.Require().NoError(err) 182 suite.Require().NotNil(body) 183 184 form, err := url.ParseQuery(string(body)) 185 suite.Require().NoError(err) 186 suite.Require().NotNil(form) 187 188 suite.Require().Equal(form.Get("grant_type"), "client_credentials") 189 suite.Require().Equal(form.Get("client_id"), oauthCfg.ClientID) 190 suite.Require().Equal(form.Get("scope"), strings.Join(oauthCfg.ScopesClaim, " ")) 191 192 if shouldFail { 193 return nil, errors.New("error") 194 } 195 196 return &http.Response{ 197 StatusCode: http.StatusOK, 198 Body: io.NopCloser(bytes.NewReader([]byte(fmt.Sprintf(`{"access_token": "%s"}`, fakeTkn)))), 199 }, nil 200 }), 201 } 202 } 203 } 204 205 type roundTripFunc func(req *http.Request) (*http.Response, error) 206 207 func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { 208 return f(req) 209 } 210 211 func testServerHandlerFunc(t *testing.T) http.HandlerFunc { 212 return func(w http.ResponseWriter, r *http.Request) { 213 _, err := fmt.Fprintln(w, "Hello, client") 214 require.NoError(t, err) 215 } 216 }