github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/auth/mtls_token_provider.go (about)

     1  /*
     2  * Copyright 2020 The Compass Authors
     3  *
     4  * Licensed under the Apache License, Version 2.0 (the "License");
     5  * you may not use this file except in compliance with the License.
     6  * You may obtain a copy of the License at
     7  *
     8  *     http://www.apache.org/licenses/LICENSE-2.0
     9  *
    10  * Unless required by applicable law or agreed to in writing, software
    11  * distributed under the License is distributed on an "AS IS" BASIS,
    12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  * See the License for the specific language governing permissions and
    14  * limitations under the License.
    15   */
    16  
    17  package auth
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"encoding/json"
    23  	"io"
    24  	"net/http"
    25  	"net/url"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/kyma-incubator/compass/components/director/pkg/oauth"
    30  
    31  	httpdirector "github.com/kyma-incubator/compass/components/director/pkg/http"
    32  	"github.com/kyma-incubator/compass/components/director/pkg/log"
    33  	httputils "github.com/kyma-incubator/compass/components/system-broker/pkg/http"
    34  	"github.com/pkg/errors"
    35  )
    36  
    37  // CertificateCache missing godoc
    38  //go:generate mockery --name=CertificateCache --output=automock --outpkg=automock --case=underscore --disable-version-string
    39  type CertificateCache interface {
    40  	Get() map[string]*tls.Certificate
    41  }
    42  
    43  // MtlsClientCreator is a constructor function for http.Clients
    44  type MtlsClientCreator func(cache CertificateCache, skipSSLValidation bool, timeout time.Duration, secretName string) *http.Client
    45  
    46  // DefaultMtlsClientCreator is the default http client creator
    47  func DefaultMtlsClientCreator(cc CertificateCache, skipSSLValidation bool, timeout time.Duration, secretName string) *http.Client {
    48  	httpTransport := httpdirector.NewCorrelationIDTransport(httpdirector.NewHTTPTransportWrapper(&http.Transport{
    49  		TLSClientConfig: &tls.Config{
    50  			InsecureSkipVerify: skipSSLValidation,
    51  			GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
    52  				return cc.Get()[secretName], nil
    53  			},
    54  		},
    55  	}))
    56  
    57  	return &http.Client{
    58  		Transport: httpTransport,
    59  		Timeout:   timeout,
    60  	}
    61  }
    62  
    63  // mtlsTokenAuthorizationProvider presents a AuthorizationProvider implementation which crafts OAuth Bearer token values for the Authorization header using mtls http client
    64  type mtlsTokenAuthorizationProvider struct {
    65  	httpClient *http.Client
    66  }
    67  
    68  // NewMtlsTokenAuthorizationProvider constructs an TokenAuthorizationProvider
    69  func NewMtlsTokenAuthorizationProvider(oauthCfg oauth.Config, externalClientCertSecretName string, cache CertificateCache, creator MtlsClientCreator) *mtlsTokenAuthorizationProvider {
    70  	return &mtlsTokenAuthorizationProvider{
    71  		httpClient: creator(cache, oauthCfg.SkipSSLValidation, oauthCfg.TokenRequestTimeout, externalClientCertSecretName),
    72  	}
    73  }
    74  
    75  // NewMtlsTokenAuthorizationProviderWithClient constructs an TokenAuthorizationProvider using the provided mtls client
    76  func NewMtlsTokenAuthorizationProviderWithClient(client *http.Client) *mtlsTokenAuthorizationProvider {
    77  	return &mtlsTokenAuthorizationProvider{
    78  		httpClient: client,
    79  	}
    80  }
    81  
    82  // Name specifies the name of the AuthorizationProvider
    83  func (p *mtlsTokenAuthorizationProvider) Name() string {
    84  	return "MtlsTokenAuthorizationProvider"
    85  }
    86  
    87  // Matches contains the logic for matching the AuthorizationProvider
    88  func (p *mtlsTokenAuthorizationProvider) Matches(ctx context.Context) bool {
    89  	credentials, err := LoadFromContext(ctx)
    90  	if err != nil {
    91  		return false
    92  	}
    93  
    94  	return credentials.Type() == OAuthMtlsCredentialType
    95  }
    96  
    97  // GetAuthorization crafts an OAuth Bearer token to inject as part of the executing request
    98  func (p *mtlsTokenAuthorizationProvider) GetAuthorization(ctx context.Context) (string, error) {
    99  	log.C(ctx).Debug("Getting new token...")
   100  
   101  	credentials, err := LoadFromContext(ctx)
   102  	if err != nil {
   103  		return "", err
   104  	}
   105  
   106  	mtlsCredentials, ok := credentials.Get().(*OAuthMtlsCredentials)
   107  	if !ok {
   108  		return "", errors.New("failed to cast credentials to mtls oauth credentials type")
   109  	}
   110  
   111  	token, err := p.getToken(ctx, mtlsCredentials)
   112  	if err != nil {
   113  		return "", err
   114  	}
   115  
   116  	return "Bearer " + token.AccessToken, nil
   117  }
   118  
   119  func (p *mtlsTokenAuthorizationProvider) getToken(ctx context.Context, credentials *OAuthMtlsCredentials) (httputils.Token, error) {
   120  	form := url.Values{}
   121  	form.Add("grant_type", "client_credentials")
   122  	form.Add("client_id", credentials.ClientID)
   123  	form.Add("scope", credentials.Scopes)
   124  
   125  	body := strings.NewReader(form.Encode())
   126  	request, err := http.NewRequest(http.MethodPost, credentials.TokenURL, body)
   127  	if err != nil {
   128  		return httputils.Token{}, errors.Wrap(err, "Failed to create authorisation token request")
   129  	}
   130  
   131  	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   132  	if credentials.AdditionalHeaders != nil {
   133  		for headerName, headerValue := range credentials.AdditionalHeaders {
   134  			request.Header.Set(headerName, headerValue)
   135  		}
   136  	}
   137  
   138  	response, err := p.httpClient.Do(request)
   139  	if err != nil {
   140  		return httputils.Token{}, errors.Wrap(err, "while send request to token endpoint")
   141  	}
   142  	defer func() {
   143  		if err := response.Body.Close(); err != nil {
   144  			log.C(ctx).Warn("Cannot close connection body inside oauth client")
   145  		}
   146  	}()
   147  
   148  	respBody, err := io.ReadAll(response.Body)
   149  	if err != nil {
   150  		return httputils.Token{}, errors.Wrapf(err, "while reading token response body from %q", credentials.TokenURL)
   151  	}
   152  
   153  	if response.StatusCode != http.StatusOK {
   154  		return httputils.Token{}, errors.Wrapf(err, "oauth server returned unexpected status code %d and body: %s", response.StatusCode, respBody)
   155  	}
   156  
   157  	tokenResponse := httputils.Token{}
   158  	err = json.Unmarshal(respBody, &tokenResponse)
   159  	if err != nil {
   160  		return httputils.Token{}, errors.Wrap(err, "while unmarshalling token response body")
   161  	}
   162  
   163  	if tokenResponse.AccessToken == "" {
   164  		return httputils.Token{}, errors.New("while fetching token: access token from oauth client is empty")
   165  	}
   166  
   167  	log.C(ctx).Debug("Successfully unmarshal response oauth token")
   168  	return tokenResponse, nil
   169  }