k8s.io/client-go@v0.22.2/plugin/pkg/client/auth/azure/azure.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package azure
    18  
    19  import (
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"net/http"
    24  	"os"
    25  	"strconv"
    26  	"sync"
    27  
    28  	"github.com/Azure/go-autorest/autorest"
    29  	"github.com/Azure/go-autorest/autorest/adal"
    30  	"github.com/Azure/go-autorest/autorest/azure"
    31  	"k8s.io/klog/v2"
    32  
    33  	"k8s.io/apimachinery/pkg/util/net"
    34  	restclient "k8s.io/client-go/rest"
    35  )
    36  
    37  type configMode int
    38  
    39  const (
    40  	azureTokenKey = "azureTokenKey"
    41  	tokenType     = "Bearer"
    42  	authHeader    = "Authorization"
    43  
    44  	cfgClientID     = "client-id"
    45  	cfgTenantID     = "tenant-id"
    46  	cfgAccessToken  = "access-token"
    47  	cfgRefreshToken = "refresh-token"
    48  	cfgExpiresIn    = "expires-in"
    49  	cfgExpiresOn    = "expires-on"
    50  	cfgEnvironment  = "environment"
    51  	cfgApiserverID  = "apiserver-id"
    52  	cfgConfigMode   = "config-mode"
    53  
    54  	configModeDefault       configMode = 0
    55  	configModeOmitSPNPrefix configMode = 1
    56  )
    57  
    58  func init() {
    59  	if err := restclient.RegisterAuthProviderPlugin("azure", newAzureAuthProvider); err != nil {
    60  		klog.Fatalf("Failed to register azure auth plugin: %v", err)
    61  	}
    62  }
    63  
    64  var cache = newAzureTokenCache()
    65  
    66  type azureTokenCache struct {
    67  	lock  sync.Mutex
    68  	cache map[string]*azureToken
    69  }
    70  
    71  func newAzureTokenCache() *azureTokenCache {
    72  	return &azureTokenCache{cache: make(map[string]*azureToken)}
    73  }
    74  
    75  func (c *azureTokenCache) getToken(tokenKey string) *azureToken {
    76  	c.lock.Lock()
    77  	defer c.lock.Unlock()
    78  	return c.cache[tokenKey]
    79  }
    80  
    81  func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) {
    82  	c.lock.Lock()
    83  	defer c.lock.Unlock()
    84  	c.cache[tokenKey] = token
    85  }
    86  
    87  var warnOnce sync.Once
    88  
    89  func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
    90  	// deprecated in v1.22, remove in v1.25
    91  	// this should be updated to use klog.Warningf in v1.24 to more actively warn consumers
    92  	warnOnce.Do(func() {
    93  		klog.V(1).Infof(`WARNING: the azure auth plugin is deprecated in v1.22+, unavailable in v1.25+; use https://github.com/Azure/kubelogin instead.
    94  To learn more, consult https://kubernetes.io/docs/reference/access-authn-authz/authentication/#client-go-credential-plugins`)
    95  	})
    96  
    97  	var (
    98  		ts          tokenSource
    99  		environment azure.Environment
   100  		err         error
   101  		mode        configMode
   102  	)
   103  
   104  	environment, err = azure.EnvironmentFromName(cfg[cfgEnvironment])
   105  	if err != nil {
   106  		environment = azure.PublicCloud
   107  	}
   108  
   109  	mode = configModeDefault
   110  	if cfg[cfgConfigMode] != "" {
   111  		configModeInt, err := strconv.Atoi(cfg[cfgConfigMode])
   112  		if err != nil {
   113  			return nil, fmt.Errorf("failed to parse %s, error: %s", cfgConfigMode, err)
   114  		}
   115  		mode = configMode(configModeInt)
   116  		switch mode {
   117  		case configModeOmitSPNPrefix:
   118  		case configModeDefault:
   119  		default:
   120  			return nil, fmt.Errorf("%s:%s is not a valid mode", cfgConfigMode, cfg[cfgConfigMode])
   121  		}
   122  	}
   123  	ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID], mode)
   124  	if err != nil {
   125  		return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
   126  	}
   127  	cacheSource := newAzureTokenSource(ts, cache, cfg, mode, persister)
   128  
   129  	return &azureAuthProvider{
   130  		tokenSource: cacheSource,
   131  	}, nil
   132  }
   133  
   134  type azureAuthProvider struct {
   135  	tokenSource tokenSource
   136  }
   137  
   138  func (p *azureAuthProvider) Login() error {
   139  	return errors.New("not yet implemented")
   140  }
   141  
   142  func (p *azureAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
   143  	return &azureRoundTripper{
   144  		tokenSource:  p.tokenSource,
   145  		roundTripper: rt,
   146  	}
   147  }
   148  
   149  type azureRoundTripper struct {
   150  	tokenSource  tokenSource
   151  	roundTripper http.RoundTripper
   152  }
   153  
   154  var _ net.RoundTripperWrapper = &azureRoundTripper{}
   155  
   156  func (r *azureRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   157  	if len(req.Header.Get(authHeader)) != 0 {
   158  		return r.roundTripper.RoundTrip(req)
   159  	}
   160  
   161  	token, err := r.tokenSource.Token()
   162  	if err != nil {
   163  		klog.Errorf("Failed to acquire a token: %v", err)
   164  		return nil, fmt.Errorf("acquiring a token for authorization header: %v", err)
   165  	}
   166  
   167  	// clone the request in order to avoid modifying the headers of the original request
   168  	req2 := new(http.Request)
   169  	*req2 = *req
   170  	req2.Header = make(http.Header, len(req.Header))
   171  	for k, s := range req.Header {
   172  		req2.Header[k] = append([]string(nil), s...)
   173  	}
   174  
   175  	req2.Header.Set(authHeader, fmt.Sprintf("%s %s", tokenType, token.token.AccessToken))
   176  
   177  	return r.roundTripper.RoundTrip(req2)
   178  }
   179  
   180  func (r *azureRoundTripper) WrappedRoundTripper() http.RoundTripper { return r.roundTripper }
   181  
   182  type azureToken struct {
   183  	token       adal.Token
   184  	environment string
   185  	clientID    string
   186  	tenantID    string
   187  	apiserverID string
   188  }
   189  
   190  type tokenSource interface {
   191  	Token() (*azureToken, error)
   192  	Refresh(*azureToken) (*azureToken, error)
   193  }
   194  
   195  type azureTokenSource struct {
   196  	source     tokenSource
   197  	cache      *azureTokenCache
   198  	lock       sync.Mutex
   199  	configMode configMode
   200  	cfg        map[string]string
   201  	persister  restclient.AuthProviderConfigPersister
   202  }
   203  
   204  func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, configMode configMode, persister restclient.AuthProviderConfigPersister) tokenSource {
   205  	return &azureTokenSource{
   206  		source:     source,
   207  		cache:      cache,
   208  		cfg:        cfg,
   209  		persister:  persister,
   210  		configMode: configMode,
   211  	}
   212  }
   213  
   214  // Token fetches a token from the cache of configuration if present otherwise
   215  // acquires a new token from the configured source. Automatically refreshes
   216  // the token if expired.
   217  func (ts *azureTokenSource) Token() (*azureToken, error) {
   218  	ts.lock.Lock()
   219  	defer ts.lock.Unlock()
   220  
   221  	var err error
   222  	token := ts.cache.getToken(azureTokenKey)
   223  
   224  	if token != nil && !token.token.IsExpired() {
   225  		return token, nil
   226  	}
   227  
   228  	// retrieve from config if no cache
   229  	if token == nil {
   230  		tokenFromCfg, err := ts.retrieveTokenFromCfg()
   231  
   232  		if err == nil {
   233  			token = tokenFromCfg
   234  		}
   235  	}
   236  
   237  	if token != nil {
   238  		// cache and return if the token is as good
   239  		// avoids frequent persistor calls
   240  		if !token.token.IsExpired() {
   241  			ts.cache.setToken(azureTokenKey, token)
   242  			return token, nil
   243  		}
   244  
   245  		klog.V(4).Info("Refreshing token.")
   246  		tokenFromRefresh, err := ts.Refresh(token)
   247  		switch {
   248  		case err == nil:
   249  			token = tokenFromRefresh
   250  		case autorest.IsTokenRefreshError(err):
   251  			klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err)
   252  			// reset token to nil so that the token source will be used to acquire new
   253  			token = nil
   254  		default:
   255  			return nil, fmt.Errorf("unexpected error when refreshing token: %v", err)
   256  		}
   257  	}
   258  
   259  	if token == nil {
   260  		tokenFromSource, err := ts.source.Token()
   261  		if err != nil {
   262  			return nil, fmt.Errorf("failed acquiring new token: %v", err)
   263  		}
   264  		token = tokenFromSource
   265  	}
   266  
   267  	// sanity check
   268  	if token == nil {
   269  		return nil, fmt.Errorf("unable to acquire token")
   270  	}
   271  
   272  	// corner condition, newly got token is valid but expired
   273  	if token.token.IsExpired() {
   274  		return nil, fmt.Errorf("newly acquired token is expired")
   275  	}
   276  
   277  	err = ts.storeTokenInCfg(token)
   278  	if err != nil {
   279  		return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
   280  	}
   281  	ts.cache.setToken(azureTokenKey, token)
   282  
   283  	return token, nil
   284  }
   285  
   286  func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
   287  	accessToken := ts.cfg[cfgAccessToken]
   288  	if accessToken == "" {
   289  		return nil, fmt.Errorf("no access token in cfg: %s", cfgAccessToken)
   290  	}
   291  	refreshToken := ts.cfg[cfgRefreshToken]
   292  	if refreshToken == "" {
   293  		return nil, fmt.Errorf("no refresh token in cfg: %s", cfgRefreshToken)
   294  	}
   295  	environment := ts.cfg[cfgEnvironment]
   296  	if environment == "" {
   297  		return nil, fmt.Errorf("no environment in cfg: %s", cfgEnvironment)
   298  	}
   299  	clientID := ts.cfg[cfgClientID]
   300  	if clientID == "" {
   301  		return nil, fmt.Errorf("no client ID in cfg: %s", cfgClientID)
   302  	}
   303  	tenantID := ts.cfg[cfgTenantID]
   304  	if tenantID == "" {
   305  		return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
   306  	}
   307  	resourceID := ts.cfg[cfgApiserverID]
   308  	if resourceID == "" {
   309  		return nil, fmt.Errorf("no apiserver ID in cfg: %s", cfgApiserverID)
   310  	}
   311  	expiresIn := ts.cfg[cfgExpiresIn]
   312  	if expiresIn == "" {
   313  		return nil, fmt.Errorf("no expiresIn in cfg: %s", cfgExpiresIn)
   314  	}
   315  	expiresOn := ts.cfg[cfgExpiresOn]
   316  	if expiresOn == "" {
   317  		return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
   318  	}
   319  	tokenAudience := resourceID
   320  	if ts.configMode == configModeDefault {
   321  		tokenAudience = fmt.Sprintf("spn:%s", resourceID)
   322  	}
   323  
   324  	return &azureToken{
   325  		token: adal.Token{
   326  			AccessToken:  accessToken,
   327  			RefreshToken: refreshToken,
   328  			ExpiresIn:    json.Number(expiresIn),
   329  			ExpiresOn:    json.Number(expiresOn),
   330  			NotBefore:    json.Number(expiresOn),
   331  			Resource:     tokenAudience,
   332  			Type:         tokenType,
   333  		},
   334  		environment: environment,
   335  		clientID:    clientID,
   336  		tenantID:    tenantID,
   337  		apiserverID: resourceID,
   338  	}, nil
   339  }
   340  
   341  func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
   342  	newCfg := make(map[string]string)
   343  	newCfg[cfgAccessToken] = token.token.AccessToken
   344  	newCfg[cfgRefreshToken] = token.token.RefreshToken
   345  	newCfg[cfgEnvironment] = token.environment
   346  	newCfg[cfgClientID] = token.clientID
   347  	newCfg[cfgTenantID] = token.tenantID
   348  	newCfg[cfgApiserverID] = token.apiserverID
   349  	newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
   350  	newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
   351  	newCfg[cfgConfigMode] = strconv.Itoa(int(ts.configMode))
   352  
   353  	err := ts.persister.Persist(newCfg)
   354  	if err != nil {
   355  		return fmt.Errorf("persisting the configuration: %v", err)
   356  	}
   357  	ts.cfg = newCfg
   358  	return nil
   359  }
   360  
   361  func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) {
   362  	return ts.source.Refresh(token)
   363  }
   364  
   365  // refresh outdated token with adal.
   366  func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) {
   367  	env, err := azure.EnvironmentFromName(token.environment)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  
   372  	var oauthConfig *adal.OAuthConfig
   373  	if ts.configMode == configModeOmitSPNPrefix {
   374  		oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, token.tenantID, nil)
   375  		if err != nil {
   376  			return nil, fmt.Errorf("building the OAuth configuration without api-version for token refresh: %v", err)
   377  		}
   378  	} else {
   379  		oauthConfig, err = adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID)
   380  		if err != nil {
   381  			return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
   382  		}
   383  	}
   384  
   385  	callback := func(t adal.Token) error {
   386  		return nil
   387  	}
   388  	spt, err := adal.NewServicePrincipalTokenFromManualToken(
   389  		*oauthConfig,
   390  		token.clientID,
   391  		token.apiserverID,
   392  		token.token,
   393  		callback)
   394  	if err != nil {
   395  		return nil, fmt.Errorf("creating new service principal for token refresh: %v", err)
   396  	}
   397  
   398  	if err := spt.Refresh(); err != nil {
   399  		// Caller expects IsTokenRefreshError(err) to trigger prompt.
   400  		return nil, fmt.Errorf("refreshing token: %w", err)
   401  	}
   402  
   403  	return &azureToken{
   404  		token:       spt.Token(),
   405  		environment: token.environment,
   406  		clientID:    token.clientID,
   407  		tenantID:    token.tenantID,
   408  		apiserverID: token.apiserverID,
   409  	}, nil
   410  }
   411  
   412  type azureTokenSourceDeviceCode struct {
   413  	environment azure.Environment
   414  	clientID    string
   415  	tenantID    string
   416  	apiserverID string
   417  	configMode  configMode
   418  }
   419  
   420  func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string, configMode configMode) (tokenSource, error) {
   421  	if clientID == "" {
   422  		return nil, errors.New("client-id is empty")
   423  	}
   424  	if tenantID == "" {
   425  		return nil, errors.New("tenant-id is empty")
   426  	}
   427  	if apiserverID == "" {
   428  		return nil, errors.New("apiserver-id is empty")
   429  	}
   430  	return &azureTokenSourceDeviceCode{
   431  		environment: environment,
   432  		clientID:    clientID,
   433  		tenantID:    tenantID,
   434  		apiserverID: apiserverID,
   435  		configMode:  configMode,
   436  	}, nil
   437  }
   438  
   439  func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
   440  	var (
   441  		oauthConfig *adal.OAuthConfig
   442  		err         error
   443  	)
   444  	if ts.configMode == configModeOmitSPNPrefix {
   445  		oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(ts.environment.ActiveDirectoryEndpoint, ts.tenantID, nil)
   446  		if err != nil {
   447  			return nil, fmt.Errorf("building the OAuth configuration without api-version for device code authentication: %v", err)
   448  		}
   449  	} else {
   450  		oauthConfig, err = adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
   451  		if err != nil {
   452  			return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
   453  		}
   454  	}
   455  	client := &autorest.Client{}
   456  	deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
   457  	if err != nil {
   458  		return nil, fmt.Errorf("initialing the device code authentication: %v", err)
   459  	}
   460  
   461  	_, err = fmt.Fprintln(os.Stderr, *deviceCode.Message)
   462  	if err != nil {
   463  		return nil, fmt.Errorf("prompting the device code message: %v", err)
   464  	}
   465  
   466  	token, err := adal.WaitForUserCompletion(client, deviceCode)
   467  	if err != nil {
   468  		return nil, fmt.Errorf("waiting for device code authentication to complete: %v", err)
   469  	}
   470  
   471  	return &azureToken{
   472  		token:       *token,
   473  		environment: ts.environment.Name,
   474  		clientID:    ts.clientID,
   475  		tenantID:    ts.tenantID,
   476  		apiserverID: ts.apiserverID,
   477  	}, nil
   478  }