github.com/openshift-online/ocm-sdk-go@v0.1.473/authentication/transport_wrapper.go (about)

     1  /*
     2  Copyright (c) 2021 Red Hat, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8    http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // This file contains the implementations of a transport wrapper that implements token
    18  // authentication.
    19  
    20  package authentication
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"encoding/base64"
    26  	"encoding/json"
    27  	"fmt"
    28  	"io"
    29  	"net/http"
    30  	"net/url"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  
    36  	"github.com/cenkalti/backoff/v4"
    37  	"github.com/golang-jwt/jwt/v4"
    38  	"github.com/google/uuid"
    39  	"github.com/prometheus/client_golang/prometheus"
    40  
    41  	"github.com/openshift-online/ocm-sdk-go/internal"
    42  	"github.com/openshift-online/ocm-sdk-go/logging"
    43  )
    44  
    45  // Default values:
    46  const (
    47  	// #nosec G101
    48  	DefaultTokenURL     = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token"
    49  	DefaultClientID     = "cloud-services"
    50  	DefaultClientSecret = ""
    51  
    52  	FedRAMPTokenURL = "https://sso.openshiftusgov.com/realms/redhat-external/protocol/openid-connect/token"
    53  	FedRAMPClientID = "console-dot"
    54  )
    55  
    56  // DefaultScopes is the ser of scopes used by default:
    57  var DefaultScopes = []string{
    58  	"openid",
    59  }
    60  
    61  // TransportWrapperBuilder contains the data and logic needed to add to requests the authorization
    62  // token. Don't create objects of this type directly; use the NewTransportWrapper function instead.
    63  type TransportWrapperBuilder struct {
    64  	// Fields used for basic functionality:
    65  	logger            logging.Logger
    66  	tokenURL          string
    67  	clientID          string
    68  	clientSecret      string
    69  	user              string
    70  	password          string
    71  	tokens            []string
    72  	scopes            []string
    73  	agent             string
    74  	trustedCAs        []interface{}
    75  	insecure          bool
    76  	transportWrappers []func(http.RoundTripper) http.RoundTripper
    77  
    78  	// Fields used for metrics:
    79  	metricsSubsystem  string
    80  	metricsRegisterer prometheus.Registerer
    81  }
    82  
    83  // TransportWrapper contains the data and logic needed to wrap an HTTP round tripper with another
    84  // one that adds authorization tokens to requests.
    85  type TransportWrapper struct {
    86  	// Fields used for basic functionality:
    87  	logger                logging.Logger
    88  	clientID              string
    89  	clientSecret          string
    90  	user                  string
    91  	password              string
    92  	scopes                []string
    93  	agent                 string
    94  	clientSelector        *internal.ClientSelector
    95  	tokenURL              string
    96  	tokenServer           *internal.ServerAddress
    97  	tokenMutex            *sync.Mutex
    98  	tokenParser           *jwt.Parser
    99  	accessToken           *tokenInfo
   100  	refreshToken          *tokenInfo
   101  	pullSecretAccessToken *tokenInfo
   102  
   103  	// Fields used for metrics:
   104  	metricsSubsystem    string
   105  	metricsRegisterer   prometheus.Registerer
   106  	tokenCountMetric    *prometheus.CounterVec
   107  	tokenDurationMetric *prometheus.HistogramVec
   108  }
   109  
   110  // roundTripper is a round tripper that adds authorization tokens to requests.
   111  type roundTripper struct {
   112  	owner     *TransportWrapper
   113  	logger    logging.Logger
   114  	transport http.RoundTripper
   115  }
   116  
   117  // Make sure that we implement the interface:
   118  var _ http.RoundTripper = (*roundTripper)(nil)
   119  
   120  // NewTransportWrapper creates a new builder that can then be used to configure and create a new
   121  // authentication round tripper.
   122  func NewTransportWrapper() *TransportWrapperBuilder {
   123  	return &TransportWrapperBuilder{
   124  		metricsRegisterer: prometheus.DefaultRegisterer,
   125  	}
   126  }
   127  
   128  // Logger sets the logger that will be used by the wrapper and by the transports that it creates.
   129  func (b *TransportWrapperBuilder) Logger(value logging.Logger) *TransportWrapperBuilder {
   130  	b.logger = value
   131  	return b
   132  }
   133  
   134  // TokenURL sets the URL that will be used to request OpenID access tokens. The default is
   135  // `https://sso.redhat.com/auth/realms/cloud-services/protocol/openid-connect/token`.
   136  func (b *TransportWrapperBuilder) TokenURL(url string) *TransportWrapperBuilder {
   137  	b.tokenURL = url
   138  	return b
   139  }
   140  
   141  // Client sets OpenID client identifier and secret that will be used to request OpenID tokens. The
   142  // default identifier is `cloud-services`. The default secret is the empty string. When these two
   143  // values are provided and no user name and password is provided, the round trippers will use the
   144  // client credentials grant to obtain the token. For example, to create a connection using the
   145  // client credentials grant do the following:
   146  //
   147  //	// Use the client credentials grant:
   148  //	wrapper, err := authentication.NewTransportWrapper().
   149  //		Client("myclientid", "myclientsecret").
   150  //		Build()
   151  //
   152  // Note that some OpenID providers (Keycloak, for example) require the client identifier also for
   153  // the resource owner password grant. In that case use the set only the identifier, and let the
   154  // secret blank. For example:
   155  //
   156  //	// Use the resource owner password grant:
   157  //	wrapper, err := authentication.NewTransportWrapper().
   158  //		User("myuser", "mypassword").
   159  //		Client("myclientid", "").
   160  //		Build()
   161  //
   162  // Note the empty client secret.
   163  func (b *TransportWrapperBuilder) Client(id string, secret string) *TransportWrapperBuilder {
   164  	b.clientID = id
   165  	b.clientSecret = secret
   166  	return b
   167  }
   168  
   169  // User sets the user name and password that will be used to request OpenID access tokens. When
   170  // these two values are provided the round trippers will use the resource owner password grant type
   171  // to obtain the token. For example:
   172  //
   173  //	// Use the resource owner password grant:
   174  //	wrapper, err := authentication.NewTransportWrapper().
   175  //		User("myuser", "mypassword").
   176  //		Build()
   177  //
   178  // Note that some OpenID providers (Keycloak, for example) require the client identifier also for
   179  // the resource owner password grant. In that case use the set only the identifier, and let the
   180  // secret blank. For example:
   181  //
   182  //	// Use the resource owner password grant:
   183  //	wrapper, err := authentication.NewConnectionBuilder().
   184  //		User("myuser", "mypassword").
   185  //		Client("myclientid", "").
   186  //		Build()
   187  //
   188  // Note the empty client secret.
   189  func (b *TransportWrapperBuilder) User(name string, password string) *TransportWrapperBuilder {
   190  	b.user = name
   191  	b.password = password
   192  	return b
   193  }
   194  
   195  // Scopes sets the OpenID scopes that will be included in the token request. The default is to use
   196  // the `openid` scope. If this method is used then that default will be completely replaced, so you
   197  // will need to specify it explicitly if you want to use it. For example, if you want to add the
   198  // scope 'myscope' without loosing the default you will have to do something like this:
   199  //
   200  //	// Create a wrapper with the default 'openid' scope and some additional scopes:
   201  //	wrapper, err := authentication.NewTransportWrapper().
   202  //		User("myuser", "mypassword").
   203  //		Scopes("openid", "myscope", "yourscope").
   204  //		Build()
   205  //
   206  // If you just want to use the default 'openid' then there is no need to use this method.
   207  func (b *TransportWrapperBuilder) Scopes(values ...string) *TransportWrapperBuilder {
   208  	b.scopes = make([]string, len(values))
   209  	copy(b.scopes, values)
   210  	return b
   211  }
   212  
   213  // Tokens sets the OpenID tokens that will be used to authenticate. Multiple types of tokens are
   214  // accepted, and used according to their type. For example, you can pass a single access token, or
   215  // an access token and a refresh token, or just a refresh token. If no token is provided then the
   216  // round trippers will the user name and password or the client identifier and client secret (see
   217  // the User and Client methods) to request new ones.
   218  //
   219  // If the wrapper is created with these tokens and no user or client credentials, it will stop
   220  // working when both tokens expire. That can happen, for example, if the connection isn't used for a
   221  // period of time longer than the life of the refresh token.
   222  func (b *TransportWrapperBuilder) Tokens(tokens ...string) *TransportWrapperBuilder {
   223  	b.tokens = append(b.tokens, tokens...)
   224  	return b
   225  }
   226  
   227  // Agent sets the `User-Agent` header that the round trippers will use in all the HTTP requests. The
   228  // default is `OCM-SDK` followed by an slash and the version of the SDK, for example `OCM/0.0.0`.
   229  func (b *TransportWrapperBuilder) Agent(agent string) *TransportWrapperBuilder {
   230  	b.agent = agent
   231  	return b
   232  }
   233  
   234  // TrustedCA sets a source that contains he certificate authorities that will be trusted by the HTTP
   235  // client used to request tokens. If this isn't explicitly specified then the clients will trust the
   236  // certificate authorities trusted by default by the system. The value can be a *x509.CertPool or a
   237  // string, anything else will cause an error when Build method is called. If it is a *x509.CertPool
   238  // then the value will replace any other source given before. If it is a string then it should be
   239  // the name of a PEM file. The contents of that file will be added to the previously given sources.
   240  func (b *TransportWrapperBuilder) TrustedCA(value interface{}) *TransportWrapperBuilder {
   241  	if value != nil {
   242  		b.trustedCAs = append(b.trustedCAs, value)
   243  	}
   244  	return b
   245  }
   246  
   247  // TrustedCAs sets a list of sources that contains he certificate authorities that will be trusted
   248  // by the HTTP client used to request tokens. See the documentation of the TrustedCA method for more
   249  // information about the accepted values.
   250  func (b *TransportWrapperBuilder) TrustedCAs(values ...interface{}) *TransportWrapperBuilder {
   251  	for _, value := range values {
   252  		b.TrustedCA(value)
   253  	}
   254  	return b
   255  }
   256  
   257  // Insecure enables insecure communication with the OpenID server. This disables verification of TLS
   258  // certificates and host names and it isn't recommended for a production environment.
   259  func (b *TransportWrapperBuilder) Insecure(flag bool) *TransportWrapperBuilder {
   260  	b.insecure = flag
   261  	return b
   262  }
   263  
   264  // TransportWrapper adds a function that will be used to wrap the transports of the HTTP client used
   265  // to request tokens. If used multiple times the transport wrappers will be called in the same order
   266  // that they are added.
   267  func (b *TransportWrapperBuilder) TransportWrapper(
   268  	value func(http.RoundTripper) http.RoundTripper) *TransportWrapperBuilder {
   269  	if value != nil {
   270  		b.transportWrappers = append(b.transportWrappers, value)
   271  	}
   272  	return b
   273  }
   274  
   275  // TransportWrappers adds a list of functions that will be used to wrap the transports of the HTTP
   276  // client used to request tokens
   277  func (b *TransportWrapperBuilder) TransportWrappers(
   278  	values ...func(http.RoundTripper) http.RoundTripper) *TransportWrapperBuilder {
   279  	for _, value := range values {
   280  		b.TransportWrapper(value)
   281  	}
   282  	return b
   283  }
   284  
   285  // MetricsSubsystem sets the name of the subsystem that will be used by the wrapper to register
   286  // metrics with Prometheus. If this isn't explicitly specified, or if it is an empty string, then no
   287  // metrics will be registered. For example, if the value is `api_outbound` then the following
   288  // metrics will be registered:
   289  //
   290  //	api_outbound_token_request_count - Number of token requests sent.
   291  //	api_outbound_token_request_duration_sum - Total time to send token requests, in seconds.
   292  //	api_outbound_token_request_duration_count - Total number of token requests measured.
   293  //	api_outbound_token_request_duration_bucket - Number of token requests organized in buckets.
   294  //
   295  // The duration buckets metrics contain an `le` label that indicates the upper bound. For example if
   296  // the `le` label is `1` then the value will be the number of requests that were processed in less
   297  // than one second.
   298  //
   299  //	code - HTTP response code, for example 200 or 500.
   300  //
   301  // The value of the `code` label will be zero when sending the request failed without a response
   302  // code, for example if it wasn't possible to open the connection, or if there was a timeout waiting
   303  // for the response.
   304  //
   305  // Note that setting this attribute is not enough to have metrics published, you also need to
   306  // create and start a metrics server, as described in the documentation of the Prometheus library.
   307  func (b *TransportWrapperBuilder) MetricsSubsystem(value string) *TransportWrapperBuilder {
   308  	b.metricsSubsystem = value
   309  	return b
   310  }
   311  
   312  // MetricsRegisterer sets the Prometheus registerer that will be used to register the metrics. The
   313  // default is to use the default Prometheus registerer and there is usually no need to change that.
   314  // This is intended for unit tests, where it is convenient to have a registerer that doesn't
   315  // interfere with the rest of the system.
   316  func (b *TransportWrapperBuilder) MetricsRegisterer(
   317  	value prometheus.Registerer) *TransportWrapperBuilder {
   318  	if value == nil {
   319  		value = prometheus.DefaultRegisterer
   320  	}
   321  	b.metricsRegisterer = value
   322  	return b
   323  }
   324  
   325  // Build uses the information stored in the builder to create a new transport wrapper.
   326  func (b *TransportWrapperBuilder) Build(ctx context.Context) (result *TransportWrapper, err error) {
   327  	// Check parameters:
   328  	if b.logger == nil {
   329  		err = fmt.Errorf("logger is mandatory")
   330  		return
   331  	}
   332  
   333  	// Check that we have some kind of credentials or a token:
   334  	haveTokens := len(b.tokens) > 0
   335  	havePassword := b.user != "" && b.password != ""
   336  	haveSecret := b.clientID != "" && b.clientSecret != ""
   337  	if !haveTokens && !havePassword && !haveSecret {
   338  		err = fmt.Errorf(
   339  			"either a token, an user name and password or a client identifier and secret are " +
   340  				"necessary, but none has been provided",
   341  		)
   342  		return
   343  	}
   344  
   345  	// Create the token parser:
   346  	tokenParser := &jwt.Parser{}
   347  
   348  	// Parse the tokens:
   349  	var accessToken *tokenInfo
   350  	var refreshToken *tokenInfo
   351  	var pullSecretAccessToken *tokenInfo
   352  	for i, text := range b.tokens {
   353  		var object *jwt.Token
   354  
   355  		object, _, err = tokenParser.ParseUnverified(text, jwt.MapClaims{})
   356  		if err != nil {
   357  			b.logger.Debug(
   358  				ctx,
   359  				"Can't parse token %d, will assume that it is either an "+
   360  					"opaque refresh token or pull secret access token: %v",
   361  				i, err,
   362  			)
   363  
   364  			// Attempt to detect/parse the token as a pull-secret access token
   365  			err := parsePullSecretAccessToken(text)
   366  			if err != nil {
   367  				b.logger.Debug(
   368  					ctx,
   369  					"Can't parse pull secret access token %d, will assume "+
   370  						"that it is an opaque refresh token: %v",
   371  					i, err,
   372  				)
   373  
   374  				// Not a pull-secret access token, so assume a opaque refresh token
   375  				refreshToken = &tokenInfo{
   376  					text: text,
   377  				}
   378  				continue
   379  			}
   380  
   381  			// Parsing as a pull-secret access token was successful, treat it as such
   382  			pullSecretAccessToken = &tokenInfo{
   383  				text: text,
   384  			}
   385  			continue
   386  		}
   387  
   388  		claims, ok := object.Claims.(jwt.MapClaims)
   389  		if !ok {
   390  			err = fmt.Errorf("claims of token %d are of type '%T'", i, claims)
   391  			return
   392  		}
   393  		claim, ok := claims["token_use"]
   394  		if !ok {
   395  			claim, ok = claims["typ"]
   396  			if !ok {
   397  				// When the token doesn't have the `typ` claim we will use the position to
   398  				// decide: first token should be the access token and second should be the
   399  				// refresh token. That is consistent with the signature of the method that
   400  				// returns the tokens.
   401  				switch i {
   402  				case 0:
   403  					b.logger.Debug(
   404  						ctx,
   405  						"First token doesn't have a 'typ' claim, will assume "+
   406  							"that it is an access token",
   407  					)
   408  					accessToken = &tokenInfo{
   409  						text:   text,
   410  						object: object,
   411  					}
   412  					continue
   413  				case 1:
   414  					b.logger.Debug(
   415  						ctx,
   416  						"Second token doesn't have a 'typ' claim, will assume "+
   417  							"that it is a refresh token",
   418  					)
   419  					refreshToken = &tokenInfo{
   420  						text:   text,
   421  						object: object,
   422  					}
   423  					continue
   424  				default:
   425  					err = fmt.Errorf("token %d doesn't contain the 'typ' claim", i)
   426  					return
   427  				}
   428  			}
   429  		}
   430  		typ, ok := claim.(string)
   431  		if !ok {
   432  			err = fmt.Errorf("claim 'type' of token %d is of type '%T'", i, claim)
   433  			return
   434  		}
   435  		switch strings.ToLower(typ) {
   436  		case "access", "bearer":
   437  			accessToken = &tokenInfo{
   438  				text:   text,
   439  				object: object,
   440  			}
   441  		case "refresh", "offline":
   442  			refreshToken = &tokenInfo{
   443  				text:   text,
   444  				object: object,
   445  			}
   446  		default:
   447  			err = fmt.Errorf("type '%s' of token %d is unknown", typ, i)
   448  			return
   449  		}
   450  	}
   451  
   452  	// Set the default authentication details, if needed:
   453  	tokenURL := b.tokenURL
   454  	if tokenURL == "" {
   455  		tokenURL = DefaultTokenURL
   456  		b.logger.Debug(
   457  			ctx,
   458  			"Token URL wasn't provided, will use the default '%s'",
   459  			tokenURL,
   460  		)
   461  	}
   462  	tokenServer, err := internal.ParseServerAddress(ctx, tokenURL)
   463  	if err != nil {
   464  		err = fmt.Errorf("can't parse token URL '%s': %w", tokenURL, err)
   465  		return
   466  	}
   467  	clientID := b.clientID
   468  	if clientID == "" {
   469  		clientID = DefaultClientID
   470  		b.logger.Debug(
   471  			ctx,
   472  			"Client identifier wasn't provided, will use the default '%s'",
   473  			clientID,
   474  		)
   475  	}
   476  	clientSecret := b.clientSecret
   477  	if clientSecret == "" {
   478  		clientSecret = DefaultClientSecret
   479  		b.logger.Debug(
   480  			ctx,
   481  			"Client secret wasn't provided, will use the default",
   482  		)
   483  	}
   484  
   485  	// Set the default authentication scopes, if needed:
   486  	scopes := b.scopes
   487  	if len(scopes) == 0 {
   488  		scopes = DefaultScopes
   489  	} else {
   490  		scopes = make([]string, len(b.scopes))
   491  		copy(scopes, b.scopes)
   492  	}
   493  
   494  	// Create the client selector:
   495  	clientSelector, err := internal.NewClientSelector().
   496  		Logger(b.logger).
   497  		TrustedCAs(b.trustedCAs...).
   498  		Insecure(b.insecure).
   499  		TransportWrappers(b.transportWrappers...).
   500  		Build(ctx)
   501  	if err != nil {
   502  		return
   503  	}
   504  
   505  	// Register the metrics:
   506  	var tokenCountMetric *prometheus.CounterVec
   507  	var tokenDurationMetric *prometheus.HistogramVec
   508  	if b.metricsSubsystem != "" && b.metricsRegisterer != nil {
   509  		tokenCountMetric = prometheus.NewCounterVec(
   510  			prometheus.CounterOpts{
   511  				Subsystem: b.metricsSubsystem,
   512  				Name:      "token_request_count",
   513  				Help:      "Number of token requests sent.",
   514  			},
   515  			tokenMetricsLabels,
   516  		)
   517  		err = b.metricsRegisterer.Register(tokenCountMetric)
   518  		if err != nil {
   519  			registered, ok := err.(prometheus.AlreadyRegisteredError)
   520  			if ok {
   521  				tokenCountMetric = registered.ExistingCollector.(*prometheus.CounterVec)
   522  				err = nil //nolint:all
   523  			} else {
   524  				return
   525  			}
   526  		}
   527  
   528  		tokenDurationMetric = prometheus.NewHistogramVec(
   529  			prometheus.HistogramOpts{
   530  				Subsystem: b.metricsSubsystem,
   531  				Name:      "token_request_duration",
   532  				Help:      "Token request duration in seconds.",
   533  				Buckets: []float64{
   534  					0.1,
   535  					1.0,
   536  					10.0,
   537  					30.0,
   538  				},
   539  			},
   540  			tokenMetricsLabels,
   541  		)
   542  		err = b.metricsRegisterer.Register(tokenDurationMetric)
   543  		if err != nil {
   544  			registered, ok := err.(prometheus.AlreadyRegisteredError)
   545  			if ok {
   546  				tokenDurationMetric = registered.ExistingCollector.(*prometheus.HistogramVec)
   547  				err = nil
   548  			} else {
   549  				return
   550  			}
   551  		}
   552  	}
   553  
   554  	// Create and populate the object:
   555  	result = &TransportWrapper{
   556  		logger:                b.logger,
   557  		clientID:              clientID,
   558  		clientSecret:          clientSecret,
   559  		user:                  b.user,
   560  		password:              b.password,
   561  		scopes:                scopes,
   562  		agent:                 b.agent,
   563  		clientSelector:        clientSelector,
   564  		tokenURL:              tokenURL,
   565  		tokenServer:           tokenServer,
   566  		tokenMutex:            &sync.Mutex{},
   567  		tokenParser:           tokenParser,
   568  		accessToken:           accessToken,
   569  		refreshToken:          refreshToken,
   570  		pullSecretAccessToken: pullSecretAccessToken,
   571  		metricsSubsystem:      b.metricsSubsystem,
   572  		metricsRegisterer:     b.metricsRegisterer,
   573  		tokenCountMetric:      tokenCountMetric,
   574  		tokenDurationMetric:   tokenDurationMetric,
   575  	}
   576  
   577  	return
   578  }
   579  
   580  // Logger returns the logger that is used by the wrapper.
   581  func (w *TransportWrapper) Logger() logging.Logger {
   582  	return w.logger
   583  }
   584  
   585  // TokenURL returns the URL that the connection is using request OpenID access tokens.
   586  func (w *TransportWrapper) TokenURL() string {
   587  	return w.tokenURL
   588  }
   589  
   590  // Client returns OpenID client identifier and secret that the wrapper is using to request OpenID
   591  // access tokens.
   592  func (w *TransportWrapper) Client() (id, secret string) {
   593  	id = w.clientID
   594  	secret = w.clientSecret
   595  	return
   596  }
   597  
   598  // User returns the user name and password that the wrapper is using to request OpenID access
   599  // tokens.
   600  func (w *TransportWrapper) User() (user, password string) {
   601  	user = w.user
   602  	password = w.password
   603  	return
   604  }
   605  
   606  // Scopes returns the OpenID scopes that the wrapper is using to request OpenID access tokens.
   607  func (w *TransportWrapper) Scopes() []string {
   608  	result := make([]string, len(w.scopes))
   609  	copy(result, w.scopes)
   610  	return result
   611  }
   612  
   613  // Wrap creates a new round tripper that wraps the given one and populates the authorization header.
   614  func (w *TransportWrapper) Wrap(transport http.RoundTripper) http.RoundTripper {
   615  	return &roundTripper{
   616  		owner:     w,
   617  		logger:    w.logger,
   618  		transport: transport,
   619  	}
   620  }
   621  
   622  // Close releases all the resources used by the wrapper.
   623  func (w *TransportWrapper) Close() error {
   624  	err := w.clientSelector.Close()
   625  	if err != nil {
   626  		return err
   627  	}
   628  	return nil
   629  }
   630  
   631  // RoundTrip is the implementation of the round tripper interface.
   632  func (t *roundTripper) RoundTrip(request *http.Request) (response *http.Response, err error) {
   633  	// Get the context:
   634  	ctx := request.Context()
   635  
   636  	// Get the access token:
   637  	token, _, err := t.owner.Tokens(ctx)
   638  	if err != nil {
   639  		err = fmt.Errorf("can't get access token: %w", err)
   640  		return
   641  	}
   642  
   643  	// Add the authorization header:
   644  	if request.Header == nil {
   645  		request.Header = make(http.Header)
   646  	}
   647  
   648  	// If the access token is a pull-secret-access-token type, a
   649  	// different Authorization header must be used
   650  	if token != "" {
   651  		if err := parsePullSecretAccessToken(token); err == nil {
   652  			// It is a pull-secret access token
   653  			request.Header.Set("Authorization", "AccessToken "+token)
   654  		} else {
   655  			request.Header.Set("Authorization", "Bearer "+token)
   656  		}
   657  	}
   658  
   659  	// Call the wrapped transport:
   660  	response, err = t.transport.RoundTrip(request)
   661  
   662  	return
   663  }
   664  
   665  // Tokens returns the access and refresh tokens that are currently in use by the wrapper. If it is
   666  // necessary to request new tokens because they weren't requested yet, or because they are expired,
   667  // this method will do it and will return an error if it fails.
   668  //
   669  // If new tokens are needed the request will be retried with an exponential backoff.
   670  func (w *TransportWrapper) Tokens(ctx context.Context, expiresIn ...time.Duration) (access,
   671  	refresh string, err error) {
   672  	expiresDuration := tokenExpiry
   673  	if len(expiresIn) == 1 {
   674  		expiresDuration = expiresIn[0]
   675  	}
   676  
   677  	// Configure the back-off so that it honours the deadline of the context passed
   678  	// to the method. Note that we need to specify explicitly the type of the variable
   679  	// because the backoff.NewExponentialBackOff function returns the implementation
   680  	// type but backoff.WithContext returns the interface instead.
   681  	exponentialBackoffMethod := backoff.NewExponentialBackOff()
   682  	exponentialBackoffMethod.MaxElapsedTime = 15 * time.Second
   683  	var backoffMethod backoff.BackOff = exponentialBackoffMethod
   684  	if ctx != nil {
   685  		backoffMethod = backoff.WithContext(backoffMethod, ctx)
   686  	}
   687  
   688  	attempt := 0
   689  	operation := func() error {
   690  		attempt++
   691  		var code int
   692  		code, access, refresh, err = w.tokens(ctx, attempt, expiresDuration)
   693  		if err != nil {
   694  			if code >= http.StatusInternalServerError {
   695  				w.logger.Debug(
   696  					ctx,
   697  					"Can't get tokens, got HTTP code %d, will retry: %v",
   698  					code, err,
   699  				)
   700  				return err
   701  			}
   702  			w.logger.Debug(
   703  				ctx,
   704  				"Can't get tokens, got HTTP code %d, will not retry: %v",
   705  				code, err,
   706  			)
   707  			return backoff.Permanent(err)
   708  		}
   709  
   710  		if attempt > 1 {
   711  			w.logger.Debug(ctx, "Got tokens on attempt %d", attempt)
   712  		} else {
   713  			w.logger.Debug(ctx, "Got tokens on first attempt")
   714  		}
   715  		return nil
   716  	}
   717  
   718  	// nolint
   719  	backoff.Retry(operation, backoffMethod)
   720  	return access, refresh, err
   721  }
   722  
   723  func (w *TransportWrapper) tokens(ctx context.Context, attempt int,
   724  	minRemaining time.Duration) (code int, access, refresh string, err error) {
   725  	// We need to make sure that this method isn't execute concurrently, as we will be updating
   726  	// multiple attributes of the connection:
   727  	w.tokenMutex.Lock()
   728  	defer w.tokenMutex.Unlock()
   729  
   730  	// A pull-secret access token can just be used as-is
   731  	if w.pullSecretAccessToken != nil {
   732  		access = w.pullSecretAccessToken.text
   733  		return
   734  	}
   735  
   736  	// Check the expiration times of the tokens:
   737  	now := time.Now()
   738  	var accessExpires bool
   739  	var accessRemaining time.Duration
   740  	if w.accessToken != nil {
   741  		accessExpires, accessRemaining, err = tokenRemaining(w.accessToken, now)
   742  		if err != nil {
   743  			return
   744  		}
   745  	}
   746  	var refreshExpires bool
   747  	var refreshRemaining time.Duration
   748  	if w.refreshToken != nil {
   749  		refreshExpires, refreshRemaining, err = tokenRemaining(w.refreshToken, now)
   750  		if err != nil {
   751  			return
   752  		}
   753  	}
   754  	if w.logger.DebugEnabled() {
   755  		w.debugExpiry(ctx, "Bearer", w.accessToken, accessExpires, accessRemaining)
   756  		w.debugExpiry(ctx, "Refresh", w.refreshToken, refreshExpires, refreshRemaining)
   757  	}
   758  
   759  	// If the access token is available and it isn't expired or about to expire then we can
   760  	// return the current tokens directly:
   761  	if w.accessToken != nil && (!accessExpires || accessRemaining >= minRemaining) {
   762  		access, refresh = w.currentTokens()
   763  		return
   764  	}
   765  
   766  	// At this point we know that the access token is unavailable, expired or about to expire.
   767  	w.logger.Debug(ctx, "Trying to get new tokens (attempt %d)", attempt)
   768  
   769  	// If we have a client identifier and secret we should use the client credentials grant even
   770  	// if we have a valid refresh token. Having both is a side effect of a incorrect behaviour
   771  	// of an old version of the SSO server. Note that we don't ignore the returned refresh token
   772  	// in that case, not because we will use it, but because we return it to the caller and we
   773  	// don't want to change that deprecated behaviour yet.
   774  	if w.haveSecret() {
   775  		code, _, err = w.sendClientCredentialsForm(ctx, attempt)
   776  		if err != nil {
   777  			return
   778  		}
   779  		access, refresh = w.currentTokens()
   780  		return
   781  	}
   782  
   783  	// At this point we know that we don't have client credentials, so we should try to use the
   784  	// refresh token if available and not expired.
   785  	if w.refreshToken != nil && (!refreshExpires || refreshRemaining >= minRemaining) {
   786  		code, _, err = w.sendRefreshForm(ctx, attempt)
   787  		if err != nil {
   788  			return
   789  		}
   790  		access, refresh = w.currentTokens()
   791  		return
   792  	}
   793  
   794  	// Now we know that both the access and refresh tokens are unavailable, expired or about to
   795  	// expire. We also know that we don't have client credentials, but we may still have a user
   796  	// name and password.
   797  	if w.havePassword() {
   798  		code, _, err = w.sendPasswordForm(ctx, attempt)
   799  		if err != nil {
   800  			return
   801  		}
   802  		access, refresh = w.currentTokens()
   803  		return
   804  	}
   805  
   806  	// Here we know that the access and refresh tokens are unavailable, expired or about to
   807  	// expire. We also know that we don't have credentials to request new ones. But we could
   808  	// still use the refresh token if it isn't completely expired.
   809  	if w.refreshToken != nil && refreshRemaining > 0 {
   810  		w.logger.Warn(
   811  			ctx,
   812  			"Refresh token expires in only %s, but there is no other mechanism to "+
   813  				"obtain a new token, so will try to use it anyhow",
   814  			refreshRemaining,
   815  		)
   816  		code, _, err = w.sendRefreshForm(ctx, attempt)
   817  		if err != nil {
   818  			return
   819  		}
   820  		access, refresh = w.currentTokens()
   821  		return
   822  	}
   823  
   824  	// At this point we know that the access token is expired or about to expire. We know also
   825  	// that the refresh token is unavailable or completely expired. And we know that we don't
   826  	// have credentials to request new tokens. But we can still use the access token if it isn't
   827  	// expired.
   828  	if w.accessToken != nil && accessRemaining > 0 {
   829  		w.logger.Warn(
   830  			ctx,
   831  			"Access token expires in only %s, but there is no other mechanism to "+
   832  				"obtain a new token, so will try to use it anyhow",
   833  			accessRemaining,
   834  		)
   835  		access, refresh = w.currentTokens()
   836  		return
   837  	}
   838  
   839  	// There is no way to get a valid access token, so all we can do is report the failure:
   840  	err = fmt.Errorf(
   841  		"access and refresh tokens are unavailable or expired, and there are no " +
   842  			"password or client secret to request new ones",
   843  	)
   844  
   845  	return
   846  }
   847  
   848  // currentTokens returns the current tokens without trying to send any request to refresh them, and
   849  // checking that they are actually available. If they aren't available then it will return empty
   850  // strings.
   851  func (w *TransportWrapper) currentTokens() (access, refresh string) {
   852  	if w.accessToken != nil {
   853  		access = w.accessToken.text
   854  	}
   855  	if w.refreshToken != nil {
   856  		refresh = w.refreshToken.text
   857  	}
   858  	return
   859  }
   860  
   861  func (w *TransportWrapper) sendClientCredentialsForm(ctx context.Context, attempt int) (code int,
   862  	result *internal.TokenResponse, err error) {
   863  	form := url.Values{}
   864  	headers := map[string]string{}
   865  	w.logger.Debug(ctx, "Requesting new token using the client credentials grant")
   866  	form.Set(grantTypeField, clientCredentialsGrant)
   867  	form.Set(clientIDField, w.clientID)
   868  	form.Set(scopeField, strings.Join(w.scopes, " "))
   869  	// Encode client_id and client_secret to use as basic auth
   870  	// https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
   871  	auth := fmt.Sprintf("%s:%s", w.clientID, w.clientSecret)
   872  	hash := base64.StdEncoding.EncodeToString([]byte(auth))
   873  	headers["Authorization"] = fmt.Sprintf("Basic %s", hash)
   874  	return w.sendForm(ctx, form, headers, attempt)
   875  }
   876  
   877  func (w *TransportWrapper) sendPasswordForm(ctx context.Context, attempt int) (code int,
   878  	result *internal.TokenResponse, err error) {
   879  	form := url.Values{}
   880  	w.logger.Debug(ctx, "Requesting new token using the password grant")
   881  	form.Set(grantTypeField, passwordGrant)
   882  	form.Set(clientIDField, w.clientID)
   883  	form.Set(usernameField, w.user)
   884  	form.Set(passwordField, w.password)
   885  	form.Set(scopeField, strings.Join(w.scopes, " "))
   886  	return w.sendForm(ctx, form, nil, attempt)
   887  }
   888  
   889  func (w *TransportWrapper) sendRefreshForm(ctx context.Context, attempt int) (code int,
   890  	result *internal.TokenResponse, err error) {
   891  	w.logger.Debug(ctx, "Requesting new token using the refresh token grant")
   892  	form := url.Values{}
   893  	form.Set(grantTypeField, refreshTokenGrant)
   894  	form.Set(clientIDField, w.clientID)
   895  	form.Set(refreshTokenField, w.refreshToken.text)
   896  	code, result, err = w.sendForm(ctx, form, nil, attempt)
   897  	return
   898  }
   899  
   900  func (w *TransportWrapper) sendForm(ctx context.Context, form url.Values, headers map[string]string,
   901  	attempt int) (code int, result *internal.TokenResponse, err error) {
   902  	// Measure the time that it takes to send the request and receive the response:
   903  	start := time.Now()
   904  	code, result, err = w.sendFormTimed(ctx, form, headers)
   905  	elapsed := time.Since(start)
   906  
   907  	// Update the metrics:
   908  	if w.tokenCountMetric != nil || w.tokenDurationMetric != nil {
   909  		labels := map[string]string{
   910  			metricsAttemptLabel: strconv.Itoa(attempt),
   911  			metricsCodeLabel:    strconv.Itoa(code),
   912  		}
   913  		if w.tokenCountMetric != nil {
   914  			w.tokenCountMetric.With(labels).Inc()
   915  		}
   916  		if w.tokenDurationMetric != nil {
   917  			w.tokenDurationMetric.With(labels).Observe(elapsed.Seconds())
   918  		}
   919  	}
   920  
   921  	// Return the original error:
   922  	return
   923  }
   924  
   925  func (w *TransportWrapper) sendFormTimed(ctx context.Context, form url.Values, headers map[string]string) (code int,
   926  	result *internal.TokenResponse, err error) {
   927  	// Create the HTTP request:
   928  	body := []byte(form.Encode())
   929  	request, err := http.NewRequest(http.MethodPost, w.tokenURL, bytes.NewReader(body))
   930  	request.Close = true
   931  	header := request.Header
   932  	if w.agent != "" {
   933  		header.Set("User-Agent", w.agent)
   934  	}
   935  	header.Set("Content-Type", "application/x-www-form-urlencoded")
   936  	header.Set("Accept", "application/json")
   937  	// Add any additional headers:
   938  	for k, v := range headers {
   939  		header.Set(k, v)
   940  	}
   941  	if err != nil {
   942  		err = fmt.Errorf("can't create request: %w", err)
   943  		return
   944  	}
   945  
   946  	// Set the context:
   947  	if ctx != nil {
   948  		request = request.WithContext(ctx)
   949  	}
   950  
   951  	// Select the HTTP client:
   952  	client, err := w.clientSelector.Select(ctx, w.tokenServer)
   953  	if err != nil {
   954  		return
   955  	}
   956  
   957  	// Send the HTTP request:
   958  	response, err := client.Do(request)
   959  	if err != nil {
   960  		err = fmt.Errorf("can't send request: %w", err)
   961  		return
   962  	}
   963  	defer response.Body.Close()
   964  
   965  	code = response.StatusCode
   966  
   967  	// Check that the response content type is JSON:
   968  	err = internal.CheckContentType(response)
   969  	if err != nil {
   970  		return
   971  	}
   972  
   973  	// Read the response body:
   974  	body, err = io.ReadAll(response.Body)
   975  	if err != nil {
   976  		err = fmt.Errorf("can't read response: %w", err)
   977  		return
   978  	}
   979  
   980  	// Parse the response body:
   981  	result = &internal.TokenResponse{}
   982  	err = json.Unmarshal(body, result)
   983  	if err != nil {
   984  		err = fmt.Errorf("can't parse JSON response: %w", err)
   985  		return
   986  	}
   987  	if result.Error != nil {
   988  		if result.ErrorDescription != nil {
   989  			err = fmt.Errorf("%s: %s", *result.Error, *result.ErrorDescription)
   990  			return
   991  		}
   992  		err = fmt.Errorf("%s", *result.Error)
   993  		return
   994  	}
   995  	if response.StatusCode != http.StatusOK {
   996  		err = fmt.Errorf("token response status code is '%d'", response.StatusCode)
   997  		return
   998  	}
   999  	if result.TokenType != nil && !strings.EqualFold(*result.TokenType, "bearer") {
  1000  		err = fmt.Errorf("expected 'bearer' token type but got '%s'", *result.TokenType)
  1001  		return
  1002  	}
  1003  
  1004  	// The response should always contains the access token, regardless of the kind of grant
  1005  	// that was used:
  1006  	var accessTokenText string
  1007  	var accessTokenObject *jwt.Token
  1008  	var accessToken *tokenInfo
  1009  	if result.AccessToken == nil {
  1010  		err = fmt.Errorf("no access token was received")
  1011  		return
  1012  	}
  1013  	accessTokenText = *result.AccessToken
  1014  	accessTokenObject, _, err = w.tokenParser.ParseUnverified(
  1015  		accessTokenText,
  1016  		jwt.MapClaims{},
  1017  	)
  1018  	if err != nil {
  1019  		return
  1020  	}
  1021  	if accessTokenText != "" {
  1022  		accessToken = &tokenInfo{
  1023  			text:   accessTokenText,
  1024  			object: accessTokenObject,
  1025  		}
  1026  	}
  1027  
  1028  	// If a refresh token is not included in the response, we can safely assume that the old
  1029  	// one is still valid and does not need to be discarded
  1030  	// https://datatracker.ietf.org/doc/html/rfc6749#section-6
  1031  	var refreshTokenText string
  1032  	var refreshTokenObject *jwt.Token
  1033  	var refreshToken *tokenInfo
  1034  	if result.RefreshToken == nil {
  1035  		if w.refreshToken != nil && w.refreshToken.text != "" {
  1036  			result.RefreshToken = &w.refreshToken.text
  1037  		}
  1038  	} else {
  1039  		refreshTokenText = *result.RefreshToken
  1040  		refreshTokenObject, _, err = w.tokenParser.ParseUnverified(
  1041  			refreshTokenText,
  1042  			jwt.MapClaims{},
  1043  		)
  1044  		if err != nil {
  1045  			w.logger.Debug(
  1046  				ctx,
  1047  				"Refresh token can't be parsed, will assume it is opaque: %v",
  1048  				err,
  1049  			)
  1050  			err = nil
  1051  		}
  1052  	}
  1053  	if refreshTokenText != "" {
  1054  		refreshToken = &tokenInfo{
  1055  			text:   refreshTokenText,
  1056  			object: refreshTokenObject,
  1057  		}
  1058  	}
  1059  
  1060  	// Save the new tokens:
  1061  	if accessToken != nil {
  1062  		w.accessToken = accessToken
  1063  	}
  1064  	if refreshToken != nil {
  1065  		w.refreshToken = refreshToken
  1066  	}
  1067  
  1068  	return
  1069  }
  1070  
  1071  func (w *TransportWrapper) havePassword() bool {
  1072  	return w.user != "" && w.password != ""
  1073  }
  1074  
  1075  func (w *TransportWrapper) haveSecret() bool {
  1076  	return w.clientID != "" && w.clientSecret != ""
  1077  }
  1078  
  1079  // debugExpiry sends to the log information about the expiration of the given token.
  1080  func (w *TransportWrapper) debugExpiry(ctx context.Context, typ string, token *tokenInfo,
  1081  	expires bool, left time.Duration) {
  1082  	if token != nil {
  1083  		if expires {
  1084  			if left < 0 {
  1085  				w.logger.Debug(ctx, "%s token expired %s ago", typ, -left)
  1086  			} else if left > 0 {
  1087  				w.logger.Debug(ctx, "%s token expires in %s", typ, left)
  1088  			} else {
  1089  				w.logger.Debug(ctx, "%s token expired just now", typ)
  1090  			}
  1091  		}
  1092  	} else {
  1093  		w.logger.Debug(ctx, "%s token isn't available", typ)
  1094  	}
  1095  }
  1096  
  1097  // parsePullSecretAccessToken will parse the supplied token to verify conformity
  1098  // with that of a pull secret access token. A pull secret access token is of the
  1099  // form <cluster id>:<Base64d pull secret token>.
  1100  func parsePullSecretAccessToken(text string) error {
  1101  	elems := strings.Split(text, ":")
  1102  	if len(elems) != 2 {
  1103  		return fmt.Errorf("unparseable pull secret token")
  1104  	}
  1105  	_, err := uuid.Parse(elems[0])
  1106  	if err != nil {
  1107  		return fmt.Errorf("unparseable pull secret token cluster ID")
  1108  	}
  1109  	_, err = base64.StdEncoding.DecodeString(elems[1])
  1110  	if err != nil {
  1111  		return fmt.Errorf("unparseable pull secret token value")
  1112  	}
  1113  	return nil
  1114  }
  1115  
  1116  // Names of fields in the token form:
  1117  const (
  1118  	grantTypeField    = "grant_type"
  1119  	clientIDField     = "client_id"
  1120  	usernameField     = "username"
  1121  	passwordField     = "password"
  1122  	refreshTokenField = "refresh_token"
  1123  	scopeField        = "scope"
  1124  )
  1125  
  1126  // Grant kinds:
  1127  const (
  1128  	clientCredentialsGrant = "client_credentials"
  1129  	passwordGrant          = "password"
  1130  	refreshTokenGrant      = "refresh_token"
  1131  )
  1132  
  1133  const (
  1134  	tokenExpiry = 1 * time.Minute
  1135  )
  1136  
  1137  // Names of the labels added to metrics:
  1138  const (
  1139  	metricsAttemptLabel = "attempt"
  1140  	metricsCodeLabel    = "code"
  1141  )
  1142  
  1143  // Array of labels added to token metrics:
  1144  var tokenMetricsLabels = []string{
  1145  	metricsAttemptLabel,
  1146  	metricsCodeLabel,
  1147  }