github.com/1aal/kubeblocks@v0.0.0-20231107070852-e1c03e598921/pkg/lorry/engines/kafka/sasl_oauthbearer.go (about)

     1  /*
     2  Copyright 2021 The Dapr Authors
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6      http://www.apache.org/licenses/LICENSE-2.0
     7  Unless required by applicable law or agreed to in writing, software
     8  distributed under the License is distributed on an "AS IS" BASIS,
     9  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  See the License for the specific language governing permissions and
    11  limitations under the License.
    12  */
    13  
    14  package kafka
    15  
    16  import (
    17  	ctx "context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"encoding/pem"
    21  	"fmt"
    22  	"net/http"
    23  	"time"
    24  
    25  	"github.com/Shopify/sarama"
    26  	"golang.org/x/oauth2"
    27  	ccred "golang.org/x/oauth2/clientcredentials"
    28  )
    29  
// OAuthTokenSource obtains OAuth2 access tokens using the client-credentials
// grant and returns them as *sarama.AccessToken values; the Token method's
// signature matches sarama's AccessTokenProvider interface. The most recently
// fetched token is cached and reused while it is still valid.
type OAuthTokenSource struct {
	// CachedToken is the last token obtained from TokenEndpoint; Token reuses
	// it as long as CachedToken.Valid() reports true.
	CachedToken oauth2.Token

	// Extensions is placed as-is (not copied) into every sarama.AccessToken
	// returned by Token.
	Extensions map[string]string

	// TokenEndpoint supplies the token URL and auth style; only TokenURL and
	// AuthStyle are read when requesting a token.
	TokenEndpoint oauth2.Endpoint

	// ClientID, ClientSecret and Scopes are the client-credentials parameters.
	// TokenURL, ClientID and ClientSecret must all be non-empty for Token to
	// request a new token.
	ClientID     string
	ClientSecret string
	Scopes       []string

	// httpClient is built lazily by configureClient on the first token
	// request and reused for all later requests.
	httpClient *http.Client

	// trustedCas holds extra CA certificates (registered via addCa) that are
	// added to the root pool used for TLS connections to the token endpoint.
	trustedCas []*x509.Certificate

	// skipCaVerify is copied into tls.Config.InsecureSkipVerify, disabling
	// certificate verification for token requests when true.
	skipCaVerify bool
}
    41  
    42  func newOAuthTokenSource(oidcTokenEndpoint, oidcClientID, oidcClientSecret string, oidcScopes []string) OAuthTokenSource {
    43  	return OAuthTokenSource{TokenEndpoint: oauth2.Endpoint{TokenURL: oidcTokenEndpoint}, ClientID: oidcClientID, ClientSecret: oidcClientSecret, Scopes: oidcScopes}
    44  }
    45  
    46  var tokenRequestTimeout, _ = time.ParseDuration("30s")
    47  
    48  func (ts *OAuthTokenSource) addCa(caPem string) error {
    49  	pemBytes := []byte(caPem)
    50  
    51  	block, _ := pem.Decode(pemBytes)
    52  
    53  	if block == nil || block.Type != "CERTIFICATE" {
    54  		return fmt.Errorf("PEM data not valid or not of a valid type (CERTIFICATE)")
    55  	}
    56  
    57  	caCert, err := x509.ParseCertificate(block.Bytes)
    58  	if err != nil {
    59  		return fmt.Errorf("error parsing PEM certificate: %w", err)
    60  	}
    61  
    62  	if ts.trustedCas == nil {
    63  		ts.trustedCas = make([]*x509.Certificate, 0)
    64  	}
    65  	ts.trustedCas = append(ts.trustedCas, caCert)
    66  
    67  	return nil
    68  }
    69  
    70  func (ts *OAuthTokenSource) configureClient() {
    71  	if ts.httpClient != nil {
    72  		return
    73  	}
    74  
    75  	tlsConfig := &tls.Config{
    76  		MinVersion:         tls.VersionTLS12,
    77  		InsecureSkipVerify: ts.skipCaVerify, //nolint:gosec
    78  	}
    79  
    80  	if ts.trustedCas != nil {
    81  		caPool, err := x509.SystemCertPool()
    82  		if err != nil {
    83  			caPool = x509.NewCertPool()
    84  		}
    85  
    86  		for _, c := range ts.trustedCas {
    87  			caPool.AddCert(c)
    88  		}
    89  		tlsConfig.RootCAs = caPool
    90  	}
    91  
    92  	ts.httpClient = &http.Client{
    93  		Transport: &http.Transport{
    94  			TLSClientConfig: tlsConfig,
    95  		},
    96  	}
    97  }
    98  
    99  func (ts *OAuthTokenSource) Token() (*sarama.AccessToken, error) {
   100  	if ts.CachedToken.Valid() {
   101  		return ts.asSaramaToken(), nil
   102  	}
   103  
   104  	if ts.TokenEndpoint.TokenURL == "" || ts.ClientID == "" || ts.ClientSecret == "" {
   105  		return nil, fmt.Errorf("cannot generate token, OAuthTokenSource not fully configured")
   106  	}
   107  
   108  	oidcCfg := ccred.Config{ClientID: ts.ClientID, ClientSecret: ts.ClientSecret, Scopes: ts.Scopes, TokenURL: ts.TokenEndpoint.TokenURL, AuthStyle: ts.TokenEndpoint.AuthStyle}
   109  
   110  	timeoutCtx, cancel := ctx.WithTimeout(ctx.TODO(), tokenRequestTimeout)
   111  	defer cancel()
   112  
   113  	ts.configureClient()
   114  
   115  	timeoutCtx = ctx.WithValue(timeoutCtx, oauth2.HTTPClient, ts.httpClient)
   116  
   117  	token, err := oidcCfg.Token(timeoutCtx)
   118  	if err != nil {
   119  		return nil, fmt.Errorf("error generating oauth2 token: %w", err)
   120  	}
   121  
   122  	ts.CachedToken = *token
   123  	return ts.asSaramaToken(), nil
   124  }
   125  
   126  func (ts *OAuthTokenSource) asSaramaToken() *sarama.AccessToken {
   127  	return &(sarama.AccessToken{Token: ts.CachedToken.AccessToken, Extensions: ts.Extensions})
   128  }