k8s.io/client-go@v0.31.1/plugin/pkg/client/auth/oidc/oidc.go (about)

     1  /*
     2  Copyright 2016 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package oidc
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	"golang.org/x/oauth2"
    32  	"k8s.io/apimachinery/pkg/util/net"
    33  	restclient "k8s.io/client-go/rest"
    34  	"k8s.io/klog/v2"
    35  )
    36  
    37  const (
    38  	cfgIssuerURL                = "idp-issuer-url"
    39  	cfgClientID                 = "client-id"
    40  	cfgClientSecret             = "client-secret"
    41  	cfgCertificateAuthority     = "idp-certificate-authority"
    42  	cfgCertificateAuthorityData = "idp-certificate-authority-data"
    43  	cfgIDToken                  = "id-token"
    44  	cfgRefreshToken             = "refresh-token"
    45  
    46  	// Unused. Scopes aren't sent during refreshing.
    47  	cfgExtraScopes = "extra-scopes"
    48  )
    49  
    50  func init() {
    51  	if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil {
    52  		klog.Fatalf("Failed to register oidc auth plugin: %v", err)
    53  	}
    54  }
    55  
    56  // expiryDelta determines how earlier a token should be considered
    57  // expired than its actual expiration time. It is used to avoid late
    58  // expirations due to client-server time mismatches.
    59  //
    60  // NOTE(ericchiang): this is take from golang.org/x/oauth2
    61  const expiryDelta = 10 * time.Second
    62  
    63  var cache = newClientCache()
    64  
    65  // Like TLS transports, keep a cache of OIDC clients indexed by issuer URL. This ensures
    66  // current requests from different clients don't concurrently attempt to refresh the same
    67  // set of credentials.
    68  type clientCache struct {
    69  	mu sync.RWMutex
    70  
    71  	cache map[cacheKey]*oidcAuthProvider
    72  }
    73  
    74  func newClientCache() *clientCache {
    75  	return &clientCache{cache: make(map[cacheKey]*oidcAuthProvider)}
    76  }
    77  
    78  type cacheKey struct {
    79  	clusterAddress string
    80  	// Canonical issuer URL string of the provider.
    81  	issuerURL string
    82  	clientID  string
    83  }
    84  
    85  func (c *clientCache) getClient(clusterAddress, issuer, clientID string) (*oidcAuthProvider, bool) {
    86  	c.mu.RLock()
    87  	defer c.mu.RUnlock()
    88  	client, ok := c.cache[cacheKey{clusterAddress: clusterAddress, issuerURL: issuer, clientID: clientID}]
    89  	return client, ok
    90  }
    91  
    92  // setClient attempts to put the client in the cache but may return any clients
    93  // with the same keys set before. This is so there's only ever one client for a provider.
    94  func (c *clientCache) setClient(clusterAddress, issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider {
    95  	c.mu.Lock()
    96  	defer c.mu.Unlock()
    97  	key := cacheKey{clusterAddress: clusterAddress, issuerURL: issuer, clientID: clientID}
    98  
    99  	// If another client has already initialized a client for the given provider we want
   100  	// to use that client instead of the one we're trying to set. This is so all transports
   101  	// share a client and can coordinate around the same mutex when refreshing and writing
   102  	// to the kubeconfig.
   103  	if oldClient, ok := c.cache[key]; ok {
   104  		return oldClient
   105  	}
   106  
   107  	c.cache[key] = client
   108  	return client
   109  }
   110  
   111  func newOIDCAuthProvider(clusterAddress string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
   112  	issuer := cfg[cfgIssuerURL]
   113  	if issuer == "" {
   114  		return nil, fmt.Errorf("Must provide %s", cfgIssuerURL)
   115  	}
   116  
   117  	clientID := cfg[cfgClientID]
   118  	if clientID == "" {
   119  		return nil, fmt.Errorf("Must provide %s", cfgClientID)
   120  	}
   121  
   122  	// Check cache for existing provider.
   123  	if provider, ok := cache.getClient(clusterAddress, issuer, clientID); ok {
   124  		return provider, nil
   125  	}
   126  
   127  	if len(cfg[cfgExtraScopes]) > 0 {
   128  		klog.V(2).Infof("%s auth provider field depricated, refresh request don't send scopes",
   129  			cfgExtraScopes)
   130  	}
   131  
   132  	var certAuthData []byte
   133  	var err error
   134  	if cfg[cfgCertificateAuthorityData] != "" {
   135  		certAuthData, err = base64.StdEncoding.DecodeString(cfg[cfgCertificateAuthorityData])
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  
   141  	clientConfig := restclient.Config{
   142  		TLSClientConfig: restclient.TLSClientConfig{
   143  			CAFile: cfg[cfgCertificateAuthority],
   144  			CAData: certAuthData,
   145  		},
   146  	}
   147  
   148  	trans, err := restclient.TransportFor(&clientConfig)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	hc := &http.Client{Transport: trans}
   153  
   154  	provider := &oidcAuthProvider{
   155  		client:    hc,
   156  		now:       time.Now,
   157  		cfg:       cfg,
   158  		persister: persister,
   159  	}
   160  
   161  	return cache.setClient(clusterAddress, issuer, clientID, provider), nil
   162  }
   163  
   164  type oidcAuthProvider struct {
   165  	client *http.Client
   166  
   167  	// Method for determining the current time.
   168  	now func() time.Time
   169  
   170  	// Mutex guards persisting to the kubeconfig file and allows synchronized
   171  	// updates to the in-memory config. It also ensures concurrent calls to
   172  	// the RoundTripper only trigger a single refresh request.
   173  	mu        sync.Mutex
   174  	cfg       map[string]string
   175  	persister restclient.AuthProviderConfigPersister
   176  }
   177  
   178  func (p *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
   179  	return &roundTripper{
   180  		wrapped:  rt,
   181  		provider: p,
   182  	}
   183  }
   184  
   185  func (p *oidcAuthProvider) Login() error {
   186  	return errors.New("not yet implemented")
   187  }
   188  
   189  type roundTripper struct {
   190  	provider *oidcAuthProvider
   191  	wrapped  http.RoundTripper
   192  }
   193  
   194  var _ net.RoundTripperWrapper = &roundTripper{}
   195  
   196  func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   197  	if len(req.Header.Get("Authorization")) != 0 {
   198  		return r.wrapped.RoundTrip(req)
   199  	}
   200  	token, err := r.provider.idToken()
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	// shallow copy of the struct
   206  	r2 := new(http.Request)
   207  	*r2 = *req
   208  	// deep copy of the Header so we don't modify the original
   209  	// request's Header (as per RoundTripper contract).
   210  	r2.Header = make(http.Header)
   211  	for k, s := range req.Header {
   212  		r2.Header[k] = s
   213  	}
   214  	r2.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
   215  
   216  	return r.wrapped.RoundTrip(r2)
   217  }
   218  
   219  func (r *roundTripper) WrappedRoundTripper() http.RoundTripper { return r.wrapped }
   220  
   221  func (p *oidcAuthProvider) idToken() (string, error) {
   222  	p.mu.Lock()
   223  	defer p.mu.Unlock()
   224  
   225  	if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
   226  		valid, err := idTokenExpired(p.now, idToken)
   227  		if err != nil {
   228  			return "", err
   229  		}
   230  		if valid {
   231  			// If the cached id token is still valid use it.
   232  			return idToken, nil
   233  		}
   234  	}
   235  
   236  	// Try to request a new token using the refresh token.
   237  	rt, ok := p.cfg[cfgRefreshToken]
   238  	if !ok || len(rt) == 0 {
   239  		return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
   240  	}
   241  
   242  	// Determine provider's OAuth2 token endpoint.
   243  	tokenURL, err := tokenEndpoint(p.client, p.cfg[cfgIssuerURL])
   244  	if err != nil {
   245  		return "", err
   246  	}
   247  
   248  	config := oauth2.Config{
   249  		ClientID:     p.cfg[cfgClientID],
   250  		ClientSecret: p.cfg[cfgClientSecret],
   251  		Endpoint:     oauth2.Endpoint{TokenURL: tokenURL},
   252  	}
   253  
   254  	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, p.client)
   255  	token, err := config.TokenSource(ctx, &oauth2.Token{RefreshToken: rt}).Token()
   256  	if err != nil {
   257  		return "", fmt.Errorf("failed to refresh token: %v", err)
   258  	}
   259  
   260  	idToken, ok := token.Extra("id_token").(string)
   261  	if !ok {
   262  		// id_token isn't a required part of a refresh token response, so some
   263  		// providers (Okta) don't return this value.
   264  		//
   265  		// See https://github.com/kubernetes/kubernetes/issues/36847
   266  		return "", fmt.Errorf("token response did not contain an id_token, either the scope \"openid\" wasn't requested upon login, or the provider doesn't support id_tokens as part of the refresh response")
   267  	}
   268  
   269  	// Create a new config to persist.
   270  	newCfg := make(map[string]string)
   271  	for key, val := range p.cfg {
   272  		newCfg[key] = val
   273  	}
   274  
   275  	// Update the refresh token if the server returned another one.
   276  	if token.RefreshToken != "" && token.RefreshToken != rt {
   277  		newCfg[cfgRefreshToken] = token.RefreshToken
   278  	}
   279  	newCfg[cfgIDToken] = idToken
   280  
   281  	// Persist new config and if successful, update the in memory config.
   282  	if err = p.persister.Persist(newCfg); err != nil {
   283  		return "", fmt.Errorf("could not persist new tokens: %v", err)
   284  	}
   285  	p.cfg = newCfg
   286  
   287  	return idToken, nil
   288  }
   289  
   290  // tokenEndpoint uses OpenID Connect discovery to determine the OAuth2 token
   291  // endpoint for the provider, the endpoint the client will use the refresh
   292  // token against.
   293  func tokenEndpoint(client *http.Client, issuer string) (string, error) {
   294  	// Well known URL for getting OpenID Connect metadata.
   295  	//
   296  	// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
   297  	wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
   298  	resp, err := client.Get(wellKnown)
   299  	if err != nil {
   300  		return "", err
   301  	}
   302  	defer resp.Body.Close()
   303  
   304  	body, err := io.ReadAll(resp.Body)
   305  	if err != nil {
   306  		return "", err
   307  	}
   308  	if resp.StatusCode != http.StatusOK {
   309  		// Don't produce an error that's too huge (e.g. if we get HTML back for some reason).
   310  		const n = 80
   311  		if len(body) > n {
   312  			body = append(body[:n], []byte("...")...)
   313  		}
   314  		return "", fmt.Errorf("oidc: failed to query metadata endpoint %s: %q", resp.Status, body)
   315  	}
   316  
   317  	// Metadata object. We only care about the token_endpoint, the thing endpoint
   318  	// we'll be refreshing against.
   319  	//
   320  	// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
   321  	var metadata struct {
   322  		TokenURL string `json:"token_endpoint"`
   323  	}
   324  	if err := json.Unmarshal(body, &metadata); err != nil {
   325  		return "", fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
   326  	}
   327  	if metadata.TokenURL == "" {
   328  		return "", fmt.Errorf("oidc: discovery object doesn't contain a token_endpoint")
   329  	}
   330  	return metadata.TokenURL, nil
   331  }
   332  
   333  func idTokenExpired(now func() time.Time, idToken string) (bool, error) {
   334  	parts := strings.Split(idToken, ".")
   335  	if len(parts) != 3 {
   336  		return false, fmt.Errorf("ID Token is not a valid JWT")
   337  	}
   338  
   339  	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
   340  	if err != nil {
   341  		return false, err
   342  	}
   343  	var claims struct {
   344  		Expiry jsonTime `json:"exp"`
   345  	}
   346  	if err := json.Unmarshal(payload, &claims); err != nil {
   347  		return false, fmt.Errorf("parsing claims: %v", err)
   348  	}
   349  
   350  	return now().Add(expiryDelta).Before(time.Time(claims.Expiry)), nil
   351  }
   352  
   353  // jsonTime is a json.Unmarshaler that parses a unix timestamp.
   354  // Because JSON numbers don't differentiate between ints and floats,
   355  // we want to ensure we can parse either.
   356  type jsonTime time.Time
   357  
   358  func (j *jsonTime) UnmarshalJSON(b []byte) error {
   359  	var n json.Number
   360  	if err := json.Unmarshal(b, &n); err != nil {
   361  		return err
   362  	}
   363  	var unix int64
   364  
   365  	if t, err := n.Int64(); err == nil {
   366  		unix = t
   367  	} else {
   368  		f, err := n.Float64()
   369  		if err != nil {
   370  			return err
   371  		}
   372  		unix = int64(f)
   373  	}
   374  	*j = jsonTime(time.Unix(unix, 0))
   375  	return nil
   376  }
   377  
   378  func (j jsonTime) MarshalJSON() ([]byte, error) {
   379  	return json.Marshal(time.Time(j).Unix())
   380  }