// Source: github.com/openshift-online/ocm-sdk-go@v0.1.473/authentication/transport_wrapper.go

/*
Copyright (c) 2021 Red Hat, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// This file contains the implementations of a transport wrapper that implements token
// authentication.

package authentication

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v4"
	"github.com/golang-jwt/jwt/v4"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"

	"github.com/openshift-online/ocm-sdk-go/internal"
	"github.com/openshift-online/ocm-sdk-go/logging"
)

// Default values used by the Build method of TransportWrapperBuilder when the corresponding
// builder methods haven't been called:
const (
	// DefaultTokenURL is the token endpoint used when the TokenURL builder method isn't called
	// or is given an empty string.
	// #nosec G101
	DefaultTokenURL = "https://sso.redhat.com/auth/realms/redhat-external/protocol/openid-connect/token"

	// DefaultClientID is the client identifier used when no identifier is given with the Client
	// builder method.
	DefaultClientID = "cloud-services"

	// DefaultClientSecret is the client secret used when no secret is given with the Client
	// builder method.
	DefaultClientSecret = ""

	// FedRAMPTokenURL and FedRAMPClientID are the token endpoint and client identifier intended
	// for the FedRAMP environment. The Build method doesn't apply them automatically; callers
	// have to pass them explicitly with the TokenURL and Client builder methods.
	FedRAMPTokenURL = "https://sso.openshiftusgov.com/realms/redhat-external/protocol/openid-connect/token"
	FedRAMPClientID = "console-dot"
)

// DefaultScopes is the set of scopes used by default, when the Scopes builder method isn't called
// or is called without values:
var DefaultScopes = []string{
	"openid",
}

// TransportWrapperBuilder contains the data and logic needed to create a TransportWrapper, which
// adds authorization tokens to requests. Don't create objects of this type directly; use the
// NewTransportWrapper function instead.
type TransportWrapperBuilder struct {
	// Fields used for basic functionality:
	logger       logging.Logger
	tokenURL     string
	clientID     string
	clientSecret string
	user         string
	password     string

	// Tokens given with the Tokens method, in the order they were added. The Build method
	// decides the kind of each one (access, refresh or pull secret access token).
	tokens []string

	scopes []string
	agent  string

	// Settings passed to the client selector that creates the HTTP clients used to request
	// tokens. The trusted CA sources are the values accepted by the TrustedCA method.
	trustedCAs        []interface{}
	insecure          bool
	transportWrappers []func(http.RoundTripper) http.RoundTripper

	// Fields used for metrics:
	metricsSubsystem  string
	metricsRegisterer prometheus.Registerer
}

// TransportWrapper contains the data and logic needed to wrap an HTTP round tripper with another
// one that adds authorization tokens to requests.
type TransportWrapper struct {
	// Fields used for basic functionality:
	logger         logging.Logger
	clientID       string
	clientSecret   string
	user           string
	password       string
	scopes         []string
	agent          string
	clientSelector *internal.ClientSelector
	tokenURL       string
	tokenServer    *internal.ServerAddress

	// Token state. The tokens method holds tokenMutex while it reads and updates these fields,
	// so that concurrent requests don't try to obtain new tokens at the same time.
	tokenMutex            *sync.Mutex
	tokenParser           *jwt.Parser
	accessToken           *tokenInfo
	refreshToken          *tokenInfo
	pullSecretAccessToken *tokenInfo

	// Fields used for metrics. The metric vectors are nil when metrics are disabled.
	metricsSubsystem    string
	metricsRegisterer   prometheus.Registerer
	tokenCountMetric    *prometheus.CounterVec
	tokenDurationMetric *prometheus.HistogramVec
}

// roundTripper is a round tripper that adds authorization tokens to requests. The tokens are
// obtained from the owner wrapper, and the requests are then passed to the wrapped transport.
type roundTripper struct {
	owner     *TransportWrapper
	logger    logging.Logger
	transport http.RoundTripper
}

// Make sure that we implement the interface:
var _ http.RoundTripper = (*roundTripper)(nil)

// NewTransportWrapper creates a new builder that can then be used to configure and create a new
// authentication transport wrapper. The builder uses the default Prometheus registerer unless the
// MetricsRegisterer method is used to change it.
122 func NewTransportWrapper() *TransportWrapperBuilder { 123 return &TransportWrapperBuilder{ 124 metricsRegisterer: prometheus.DefaultRegisterer, 125 } 126 } 127 128 // Logger sets the logger that will be used by the wrapper and by the transports that it creates. 129 func (b *TransportWrapperBuilder) Logger(value logging.Logger) *TransportWrapperBuilder { 130 b.logger = value 131 return b 132 } 133 134 // TokenURL sets the URL that will be used to request OpenID access tokens. The default is 135 // `https://sso.redhat.com/auth/realms/cloud-services/protocol/openid-connect/token`. 136 func (b *TransportWrapperBuilder) TokenURL(url string) *TransportWrapperBuilder { 137 b.tokenURL = url 138 return b 139 } 140 141 // Client sets OpenID client identifier and secret that will be used to request OpenID tokens. The 142 // default identifier is `cloud-services`. The default secret is the empty string. When these two 143 // values are provided and no user name and password is provided, the round trippers will use the 144 // client credentials grant to obtain the token. For example, to create a connection using the 145 // client credentials grant do the following: 146 // 147 // // Use the client credentials grant: 148 // wrapper, err := authentication.NewTransportWrapper(). 149 // Client("myclientid", "myclientsecret"). 150 // Build() 151 // 152 // Note that some OpenID providers (Keycloak, for example) require the client identifier also for 153 // the resource owner password grant. In that case use the set only the identifier, and let the 154 // secret blank. For example: 155 // 156 // // Use the resource owner password grant: 157 // wrapper, err := authentication.NewTransportWrapper(). 158 // User("myuser", "mypassword"). 159 // Client("myclientid", ""). 160 // Build() 161 // 162 // Note the empty client secret. 
163 func (b *TransportWrapperBuilder) Client(id string, secret string) *TransportWrapperBuilder { 164 b.clientID = id 165 b.clientSecret = secret 166 return b 167 } 168 169 // User sets the user name and password that will be used to request OpenID access tokens. When 170 // these two values are provided the round trippers will use the resource owner password grant type 171 // to obtain the token. For example: 172 // 173 // // Use the resource owner password grant: 174 // wrapper, err := authentication.NewTransportWrapper(). 175 // User("myuser", "mypassword"). 176 // Build() 177 // 178 // Note that some OpenID providers (Keycloak, for example) require the client identifier also for 179 // the resource owner password grant. In that case use the set only the identifier, and let the 180 // secret blank. For example: 181 // 182 // // Use the resource owner password grant: 183 // wrapper, err := authentication.NewConnectionBuilder(). 184 // User("myuser", "mypassword"). 185 // Client("myclientid", ""). 186 // Build() 187 // 188 // Note the empty client secret. 189 func (b *TransportWrapperBuilder) User(name string, password string) *TransportWrapperBuilder { 190 b.user = name 191 b.password = password 192 return b 193 } 194 195 // Scopes sets the OpenID scopes that will be included in the token request. The default is to use 196 // the `openid` scope. If this method is used then that default will be completely replaced, so you 197 // will need to specify it explicitly if you want to use it. For example, if you want to add the 198 // scope 'myscope' without loosing the default you will have to do something like this: 199 // 200 // // Create a wrapper with the default 'openid' scope and some additional scopes: 201 // wrapper, err := authentication.NewTransportWrapper(). 202 // User("myuser", "mypassword"). 203 // Scopes("openid", "myscope", "yourscope"). 204 // Build() 205 // 206 // If you just want to use the default 'openid' then there is no need to use this method. 
207 func (b *TransportWrapperBuilder) Scopes(values ...string) *TransportWrapperBuilder { 208 b.scopes = make([]string, len(values)) 209 copy(b.scopes, values) 210 return b 211 } 212 213 // Tokens sets the OpenID tokens that will be used to authenticate. Multiple types of tokens are 214 // accepted, and used according to their type. For example, you can pass a single access token, or 215 // an access token and a refresh token, or just a refresh token. If no token is provided then the 216 // round trippers will the user name and password or the client identifier and client secret (see 217 // the User and Client methods) to request new ones. 218 // 219 // If the wrapper is created with these tokens and no user or client credentials, it will stop 220 // working when both tokens expire. That can happen, for example, if the connection isn't used for a 221 // period of time longer than the life of the refresh token. 222 func (b *TransportWrapperBuilder) Tokens(tokens ...string) *TransportWrapperBuilder { 223 b.tokens = append(b.tokens, tokens...) 224 return b 225 } 226 227 // Agent sets the `User-Agent` header that the round trippers will use in all the HTTP requests. The 228 // default is `OCM-SDK` followed by an slash and the version of the SDK, for example `OCM/0.0.0`. 229 func (b *TransportWrapperBuilder) Agent(agent string) *TransportWrapperBuilder { 230 b.agent = agent 231 return b 232 } 233 234 // TrustedCA sets a source that contains he certificate authorities that will be trusted by the HTTP 235 // client used to request tokens. If this isn't explicitly specified then the clients will trust the 236 // certificate authorities trusted by default by the system. The value can be a *x509.CertPool or a 237 // string, anything else will cause an error when Build method is called. If it is a *x509.CertPool 238 // then the value will replace any other source given before. If it is a string then it should be 239 // the name of a PEM file. 
The contents of that file will be added to the previously given sources. 240 func (b *TransportWrapperBuilder) TrustedCA(value interface{}) *TransportWrapperBuilder { 241 if value != nil { 242 b.trustedCAs = append(b.trustedCAs, value) 243 } 244 return b 245 } 246 247 // TrustedCAs sets a list of sources that contains he certificate authorities that will be trusted 248 // by the HTTP client used to request tokens. See the documentation of the TrustedCA method for more 249 // information about the accepted values. 250 func (b *TransportWrapperBuilder) TrustedCAs(values ...interface{}) *TransportWrapperBuilder { 251 for _, value := range values { 252 b.TrustedCA(value) 253 } 254 return b 255 } 256 257 // Insecure enables insecure communication with the OpenID server. This disables verification of TLS 258 // certificates and host names and it isn't recommended for a production environment. 259 func (b *TransportWrapperBuilder) Insecure(flag bool) *TransportWrapperBuilder { 260 b.insecure = flag 261 return b 262 } 263 264 // TransportWrapper adds a function that will be used to wrap the transports of the HTTP client used 265 // to request tokens. If used multiple times the transport wrappers will be called in the same order 266 // that they are added. 
267 func (b *TransportWrapperBuilder) TransportWrapper( 268 value func(http.RoundTripper) http.RoundTripper) *TransportWrapperBuilder { 269 if value != nil { 270 b.transportWrappers = append(b.transportWrappers, value) 271 } 272 return b 273 } 274 275 // TransportWrappers adds a list of functions that will be used to wrap the transports of the HTTP 276 // client used to request tokens 277 func (b *TransportWrapperBuilder) TransportWrappers( 278 values ...func(http.RoundTripper) http.RoundTripper) *TransportWrapperBuilder { 279 for _, value := range values { 280 b.TransportWrapper(value) 281 } 282 return b 283 } 284 285 // MetricsSubsystem sets the name of the subsystem that will be used by the wrapper to register 286 // metrics with Prometheus. If this isn't explicitly specified, or if it is an empty string, then no 287 // metrics will be registered. For example, if the value is `api_outbound` then the following 288 // metrics will be registered: 289 // 290 // api_outbound_token_request_count - Number of token requests sent. 291 // api_outbound_token_request_duration_sum - Total time to send token requests, in seconds. 292 // api_outbound_token_request_duration_count - Total number of token requests measured. 293 // api_outbound_token_request_duration_bucket - Number of token requests organized in buckets. 294 // 295 // The duration buckets metrics contain an `le` label that indicates the upper bound. For example if 296 // the `le` label is `1` then the value will be the number of requests that were processed in less 297 // than one second. 298 // 299 // code - HTTP response code, for example 200 or 500. 300 // 301 // The value of the `code` label will be zero when sending the request failed without a response 302 // code, for example if it wasn't possible to open the connection, or if there was a timeout waiting 303 // for the response. 
304 // 305 // Note that setting this attribute is not enough to have metrics published, you also need to 306 // create and start a metrics server, as described in the documentation of the Prometheus library. 307 func (b *TransportWrapperBuilder) MetricsSubsystem(value string) *TransportWrapperBuilder { 308 b.metricsSubsystem = value 309 return b 310 } 311 312 // MetricsRegisterer sets the Prometheus registerer that will be used to register the metrics. The 313 // default is to use the default Prometheus registerer and there is usually no need to change that. 314 // This is intended for unit tests, where it is convenient to have a registerer that doesn't 315 // interfere with the rest of the system. 316 func (b *TransportWrapperBuilder) MetricsRegisterer( 317 value prometheus.Registerer) *TransportWrapperBuilder { 318 if value == nil { 319 value = prometheus.DefaultRegisterer 320 } 321 b.metricsRegisterer = value 322 return b 323 } 324 325 // Build uses the information stored in the builder to create a new transport wrapper. 
326 func (b *TransportWrapperBuilder) Build(ctx context.Context) (result *TransportWrapper, err error) { 327 // Check parameters: 328 if b.logger == nil { 329 err = fmt.Errorf("logger is mandatory") 330 return 331 } 332 333 // Check that we have some kind of credentials or a token: 334 haveTokens := len(b.tokens) > 0 335 havePassword := b.user != "" && b.password != "" 336 haveSecret := b.clientID != "" && b.clientSecret != "" 337 if !haveTokens && !havePassword && !haveSecret { 338 err = fmt.Errorf( 339 "either a token, an user name and password or a client identifier and secret are " + 340 "necessary, but none has been provided", 341 ) 342 return 343 } 344 345 // Create the token parser: 346 tokenParser := &jwt.Parser{} 347 348 // Parse the tokens: 349 var accessToken *tokenInfo 350 var refreshToken *tokenInfo 351 var pullSecretAccessToken *tokenInfo 352 for i, text := range b.tokens { 353 var object *jwt.Token 354 355 object, _, err = tokenParser.ParseUnverified(text, jwt.MapClaims{}) 356 if err != nil { 357 b.logger.Debug( 358 ctx, 359 "Can't parse token %d, will assume that it is either an "+ 360 "opaque refresh token or pull secret access token: %v", 361 i, err, 362 ) 363 364 // Attempt to detect/parse the token as a pull-secret access token 365 err := parsePullSecretAccessToken(text) 366 if err != nil { 367 b.logger.Debug( 368 ctx, 369 "Can't parse pull secret access token %d, will assume "+ 370 "that it is an opaque refresh token: %v", 371 i, err, 372 ) 373 374 // Not a pull-secret access token, so assume a opaque refresh token 375 refreshToken = &tokenInfo{ 376 text: text, 377 } 378 continue 379 } 380 381 // Parsing as a pull-secret access token was successful, treat it as such 382 pullSecretAccessToken = &tokenInfo{ 383 text: text, 384 } 385 continue 386 } 387 388 claims, ok := object.Claims.(jwt.MapClaims) 389 if !ok { 390 err = fmt.Errorf("claims of token %d are of type '%T'", i, claims) 391 return 392 } 393 claim, ok := claims["token_use"] 394 if !ok { 
395 claim, ok = claims["typ"] 396 if !ok { 397 // When the token doesn't have the `typ` claim we will use the position to 398 // decide: first token should be the access token and second should be the 399 // refresh token. That is consistent with the signature of the method that 400 // returns the tokens. 401 switch i { 402 case 0: 403 b.logger.Debug( 404 ctx, 405 "First token doesn't have a 'typ' claim, will assume "+ 406 "that it is an access token", 407 ) 408 accessToken = &tokenInfo{ 409 text: text, 410 object: object, 411 } 412 continue 413 case 1: 414 b.logger.Debug( 415 ctx, 416 "Second token doesn't have a 'typ' claim, will assume "+ 417 "that it is a refresh token", 418 ) 419 refreshToken = &tokenInfo{ 420 text: text, 421 object: object, 422 } 423 continue 424 default: 425 err = fmt.Errorf("token %d doesn't contain the 'typ' claim", i) 426 return 427 } 428 } 429 } 430 typ, ok := claim.(string) 431 if !ok { 432 err = fmt.Errorf("claim 'type' of token %d is of type '%T'", i, claim) 433 return 434 } 435 switch strings.ToLower(typ) { 436 case "access", "bearer": 437 accessToken = &tokenInfo{ 438 text: text, 439 object: object, 440 } 441 case "refresh", "offline": 442 refreshToken = &tokenInfo{ 443 text: text, 444 object: object, 445 } 446 default: 447 err = fmt.Errorf("type '%s' of token %d is unknown", typ, i) 448 return 449 } 450 } 451 452 // Set the default authentication details, if needed: 453 tokenURL := b.tokenURL 454 if tokenURL == "" { 455 tokenURL = DefaultTokenURL 456 b.logger.Debug( 457 ctx, 458 "Token URL wasn't provided, will use the default '%s'", 459 tokenURL, 460 ) 461 } 462 tokenServer, err := internal.ParseServerAddress(ctx, tokenURL) 463 if err != nil { 464 err = fmt.Errorf("can't parse token URL '%s': %w", tokenURL, err) 465 return 466 } 467 clientID := b.clientID 468 if clientID == "" { 469 clientID = DefaultClientID 470 b.logger.Debug( 471 ctx, 472 "Client identifier wasn't provided, will use the default '%s'", 473 clientID, 474 ) 475 } 
476 clientSecret := b.clientSecret 477 if clientSecret == "" { 478 clientSecret = DefaultClientSecret 479 b.logger.Debug( 480 ctx, 481 "Client secret wasn't provided, will use the default", 482 ) 483 } 484 485 // Set the default authentication scopes, if needed: 486 scopes := b.scopes 487 if len(scopes) == 0 { 488 scopes = DefaultScopes 489 } else { 490 scopes = make([]string, len(b.scopes)) 491 copy(scopes, b.scopes) 492 } 493 494 // Create the client selector: 495 clientSelector, err := internal.NewClientSelector(). 496 Logger(b.logger). 497 TrustedCAs(b.trustedCAs...). 498 Insecure(b.insecure). 499 TransportWrappers(b.transportWrappers...). 500 Build(ctx) 501 if err != nil { 502 return 503 } 504 505 // Register the metrics: 506 var tokenCountMetric *prometheus.CounterVec 507 var tokenDurationMetric *prometheus.HistogramVec 508 if b.metricsSubsystem != "" && b.metricsRegisterer != nil { 509 tokenCountMetric = prometheus.NewCounterVec( 510 prometheus.CounterOpts{ 511 Subsystem: b.metricsSubsystem, 512 Name: "token_request_count", 513 Help: "Number of token requests sent.", 514 }, 515 tokenMetricsLabels, 516 ) 517 err = b.metricsRegisterer.Register(tokenCountMetric) 518 if err != nil { 519 registered, ok := err.(prometheus.AlreadyRegisteredError) 520 if ok { 521 tokenCountMetric = registered.ExistingCollector.(*prometheus.CounterVec) 522 err = nil //nolint:all 523 } else { 524 return 525 } 526 } 527 528 tokenDurationMetric = prometheus.NewHistogramVec( 529 prometheus.HistogramOpts{ 530 Subsystem: b.metricsSubsystem, 531 Name: "token_request_duration", 532 Help: "Token request duration in seconds.", 533 Buckets: []float64{ 534 0.1, 535 1.0, 536 10.0, 537 30.0, 538 }, 539 }, 540 tokenMetricsLabels, 541 ) 542 err = b.metricsRegisterer.Register(tokenDurationMetric) 543 if err != nil { 544 registered, ok := err.(prometheus.AlreadyRegisteredError) 545 if ok { 546 tokenDurationMetric = registered.ExistingCollector.(*prometheus.HistogramVec) 547 err = nil 548 } else { 549 
return 550 } 551 } 552 } 553 554 // Create and populate the object: 555 result = &TransportWrapper{ 556 logger: b.logger, 557 clientID: clientID, 558 clientSecret: clientSecret, 559 user: b.user, 560 password: b.password, 561 scopes: scopes, 562 agent: b.agent, 563 clientSelector: clientSelector, 564 tokenURL: tokenURL, 565 tokenServer: tokenServer, 566 tokenMutex: &sync.Mutex{}, 567 tokenParser: tokenParser, 568 accessToken: accessToken, 569 refreshToken: refreshToken, 570 pullSecretAccessToken: pullSecretAccessToken, 571 metricsSubsystem: b.metricsSubsystem, 572 metricsRegisterer: b.metricsRegisterer, 573 tokenCountMetric: tokenCountMetric, 574 tokenDurationMetric: tokenDurationMetric, 575 } 576 577 return 578 } 579 580 // Logger returns the logger that is used by the wrapper. 581 func (w *TransportWrapper) Logger() logging.Logger { 582 return w.logger 583 } 584 585 // TokenURL returns the URL that the connection is using request OpenID access tokens. 586 func (w *TransportWrapper) TokenURL() string { 587 return w.tokenURL 588 } 589 590 // Client returns OpenID client identifier and secret that the wrapper is using to request OpenID 591 // access tokens. 592 func (w *TransportWrapper) Client() (id, secret string) { 593 id = w.clientID 594 secret = w.clientSecret 595 return 596 } 597 598 // User returns the user name and password that the wrapper is using to request OpenID access 599 // tokens. 600 func (w *TransportWrapper) User() (user, password string) { 601 user = w.user 602 password = w.password 603 return 604 } 605 606 // Scopes returns the OpenID scopes that the wrapper is using to request OpenID access tokens. 607 func (w *TransportWrapper) Scopes() []string { 608 result := make([]string, len(w.scopes)) 609 copy(result, w.scopes) 610 return result 611 } 612 613 // Wrap creates a new round tripper that wraps the given one and populates the authorization header. 
614 func (w *TransportWrapper) Wrap(transport http.RoundTripper) http.RoundTripper { 615 return &roundTripper{ 616 owner: w, 617 logger: w.logger, 618 transport: transport, 619 } 620 } 621 622 // Close releases all the resources used by the wrapper. 623 func (w *TransportWrapper) Close() error { 624 err := w.clientSelector.Close() 625 if err != nil { 626 return err 627 } 628 return nil 629 } 630 631 // RoundTrip is the implementation of the round tripper interface. 632 func (t *roundTripper) RoundTrip(request *http.Request) (response *http.Response, err error) { 633 // Get the context: 634 ctx := request.Context() 635 636 // Get the access token: 637 token, _, err := t.owner.Tokens(ctx) 638 if err != nil { 639 err = fmt.Errorf("can't get access token: %w", err) 640 return 641 } 642 643 // Add the authorization header: 644 if request.Header == nil { 645 request.Header = make(http.Header) 646 } 647 648 // If the access token is a pull-secret-access-token type, a 649 // different Authorization header must be used 650 if token != "" { 651 if err := parsePullSecretAccessToken(token); err == nil { 652 // It is a pull-secret access token 653 request.Header.Set("Authorization", "AccessToken "+token) 654 } else { 655 request.Header.Set("Authorization", "Bearer "+token) 656 } 657 } 658 659 // Call the wrapped transport: 660 response, err = t.transport.RoundTrip(request) 661 662 return 663 } 664 665 // Tokens returns the access and refresh tokens that are currently in use by the wrapper. If it is 666 // necessary to request new tokens because they weren't requested yet, or because they are expired, 667 // this method will do it and will return an error if it fails. 668 // 669 // If new tokens are needed the request will be retried with an exponential backoff. 
670 func (w *TransportWrapper) Tokens(ctx context.Context, expiresIn ...time.Duration) (access, 671 refresh string, err error) { 672 expiresDuration := tokenExpiry 673 if len(expiresIn) == 1 { 674 expiresDuration = expiresIn[0] 675 } 676 677 // Configure the back-off so that it honours the deadline of the context passed 678 // to the method. Note that we need to specify explicitly the type of the variable 679 // because the backoff.NewExponentialBackOff function returns the implementation 680 // type but backoff.WithContext returns the interface instead. 681 exponentialBackoffMethod := backoff.NewExponentialBackOff() 682 exponentialBackoffMethod.MaxElapsedTime = 15 * time.Second 683 var backoffMethod backoff.BackOff = exponentialBackoffMethod 684 if ctx != nil { 685 backoffMethod = backoff.WithContext(backoffMethod, ctx) 686 } 687 688 attempt := 0 689 operation := func() error { 690 attempt++ 691 var code int 692 code, access, refresh, err = w.tokens(ctx, attempt, expiresDuration) 693 if err != nil { 694 if code >= http.StatusInternalServerError { 695 w.logger.Debug( 696 ctx, 697 "Can't get tokens, got HTTP code %d, will retry: %v", 698 code, err, 699 ) 700 return err 701 } 702 w.logger.Debug( 703 ctx, 704 "Can't get tokens, got HTTP code %d, will not retry: %v", 705 code, err, 706 ) 707 return backoff.Permanent(err) 708 } 709 710 if attempt > 1 { 711 w.logger.Debug(ctx, "Got tokens on attempt %d", attempt) 712 } else { 713 w.logger.Debug(ctx, "Got tokens on first attempt") 714 } 715 return nil 716 } 717 718 // nolint 719 backoff.Retry(operation, backoffMethod) 720 return access, refresh, err 721 } 722 723 func (w *TransportWrapper) tokens(ctx context.Context, attempt int, 724 minRemaining time.Duration) (code int, access, refresh string, err error) { 725 // We need to make sure that this method isn't execute concurrently, as we will be updating 726 // multiple attributes of the connection: 727 w.tokenMutex.Lock() 728 defer w.tokenMutex.Unlock() 729 730 // A 
pull-secret access token can just be used as-is 731 if w.pullSecretAccessToken != nil { 732 access = w.pullSecretAccessToken.text 733 return 734 } 735 736 // Check the expiration times of the tokens: 737 now := time.Now() 738 var accessExpires bool 739 var accessRemaining time.Duration 740 if w.accessToken != nil { 741 accessExpires, accessRemaining, err = tokenRemaining(w.accessToken, now) 742 if err != nil { 743 return 744 } 745 } 746 var refreshExpires bool 747 var refreshRemaining time.Duration 748 if w.refreshToken != nil { 749 refreshExpires, refreshRemaining, err = tokenRemaining(w.refreshToken, now) 750 if err != nil { 751 return 752 } 753 } 754 if w.logger.DebugEnabled() { 755 w.debugExpiry(ctx, "Bearer", w.accessToken, accessExpires, accessRemaining) 756 w.debugExpiry(ctx, "Refresh", w.refreshToken, refreshExpires, refreshRemaining) 757 } 758 759 // If the access token is available and it isn't expired or about to expire then we can 760 // return the current tokens directly: 761 if w.accessToken != nil && (!accessExpires || accessRemaining >= minRemaining) { 762 access, refresh = w.currentTokens() 763 return 764 } 765 766 // At this point we know that the access token is unavailable, expired or about to expire. 767 w.logger.Debug(ctx, "Trying to get new tokens (attempt %d)", attempt) 768 769 // If we have a client identifier and secret we should use the client credentials grant even 770 // if we have a valid refresh token. Having both is a side effect of a incorrect behaviour 771 // of an old version of the SSO server. Note that we don't ignore the returned refresh token 772 // in that case, not because we will use it, but because we return it to the caller and we 773 // don't want to change that deprecated behaviour yet. 
774 if w.haveSecret() { 775 code, _, err = w.sendClientCredentialsForm(ctx, attempt) 776 if err != nil { 777 return 778 } 779 access, refresh = w.currentTokens() 780 return 781 } 782 783 // At this point we know that we don't have client credentials, so we should try to use the 784 // refresh token if available and not expired. 785 if w.refreshToken != nil && (!refreshExpires || refreshRemaining >= minRemaining) { 786 code, _, err = w.sendRefreshForm(ctx, attempt) 787 if err != nil { 788 return 789 } 790 access, refresh = w.currentTokens() 791 return 792 } 793 794 // Now we know that both the access and refresh tokens are unavailable, expired or about to 795 // expire. We also know that we don't have client credentials, but we may still have a user 796 // name and password. 797 if w.havePassword() { 798 code, _, err = w.sendPasswordForm(ctx, attempt) 799 if err != nil { 800 return 801 } 802 access, refresh = w.currentTokens() 803 return 804 } 805 806 // Here we know that the access and refresh tokens are unavailable, expired or about to 807 // expire. We also know that we don't have credentials to request new ones. But we could 808 // still use the refresh token if it isn't completely expired. 809 if w.refreshToken != nil && refreshRemaining > 0 { 810 w.logger.Warn( 811 ctx, 812 "Refresh token expires in only %s, but there is no other mechanism to "+ 813 "obtain a new token, so will try to use it anyhow", 814 refreshRemaining, 815 ) 816 code, _, err = w.sendRefreshForm(ctx, attempt) 817 if err != nil { 818 return 819 } 820 access, refresh = w.currentTokens() 821 return 822 } 823 824 // At this point we know that the access token is expired or about to expire. We know also 825 // that the refresh token is unavailable or completely expired. And we know that we don't 826 // have credentials to request new tokens. But we can still use the access token if it isn't 827 // expired. 
828 if w.accessToken != nil && accessRemaining > 0 { 829 w.logger.Warn( 830 ctx, 831 "Access token expires in only %s, but there is no other mechanism to "+ 832 "obtain a new token, so will try to use it anyhow", 833 accessRemaining, 834 ) 835 access, refresh = w.currentTokens() 836 return 837 } 838 839 // There is no way to get a valid access token, so all we can do is report the failure: 840 err = fmt.Errorf( 841 "access and refresh tokens are unavailable or expired, and there are no " + 842 "password or client secret to request new ones", 843 ) 844 845 return 846 } 847 848 // currentTokens returns the current tokens without trying to send any request to refresh them, and 849 // checking that they are actually available. If they aren't available then it will return empty 850 // strings. 851 func (w *TransportWrapper) currentTokens() (access, refresh string) { 852 if w.accessToken != nil { 853 access = w.accessToken.text 854 } 855 if w.refreshToken != nil { 856 refresh = w.refreshToken.text 857 } 858 return 859 } 860 861 func (w *TransportWrapper) sendClientCredentialsForm(ctx context.Context, attempt int) (code int, 862 result *internal.TokenResponse, err error) { 863 form := url.Values{} 864 headers := map[string]string{} 865 w.logger.Debug(ctx, "Requesting new token using the client credentials grant") 866 form.Set(grantTypeField, clientCredentialsGrant) 867 form.Set(clientIDField, w.clientID) 868 form.Set(scopeField, strings.Join(w.scopes, " ")) 869 // Encode client_id and client_secret to use as basic auth 870 // https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 871 auth := fmt.Sprintf("%s:%s", w.clientID, w.clientSecret) 872 hash := base64.StdEncoding.EncodeToString([]byte(auth)) 873 headers["Authorization"] = fmt.Sprintf("Basic %s", hash) 874 return w.sendForm(ctx, form, headers, attempt) 875 } 876 877 func (w *TransportWrapper) sendPasswordForm(ctx context.Context, attempt int) (code int, 878 result *internal.TokenResponse, err error) { 879 form 
:= url.Values{} 880 w.logger.Debug(ctx, "Requesting new token using the password grant") 881 form.Set(grantTypeField, passwordGrant) 882 form.Set(clientIDField, w.clientID) 883 form.Set(usernameField, w.user) 884 form.Set(passwordField, w.password) 885 form.Set(scopeField, strings.Join(w.scopes, " ")) 886 return w.sendForm(ctx, form, nil, attempt) 887 } 888 889 func (w *TransportWrapper) sendRefreshForm(ctx context.Context, attempt int) (code int, 890 result *internal.TokenResponse, err error) { 891 w.logger.Debug(ctx, "Requesting new token using the refresh token grant") 892 form := url.Values{} 893 form.Set(grantTypeField, refreshTokenGrant) 894 form.Set(clientIDField, w.clientID) 895 form.Set(refreshTokenField, w.refreshToken.text) 896 code, result, err = w.sendForm(ctx, form, nil, attempt) 897 return 898 } 899 900 func (w *TransportWrapper) sendForm(ctx context.Context, form url.Values, headers map[string]string, 901 attempt int) (code int, result *internal.TokenResponse, err error) { 902 // Measure the time that it takes to send the request and receive the response: 903 start := time.Now() 904 code, result, err = w.sendFormTimed(ctx, form, headers) 905 elapsed := time.Since(start) 906 907 // Update the metrics: 908 if w.tokenCountMetric != nil || w.tokenDurationMetric != nil { 909 labels := map[string]string{ 910 metricsAttemptLabel: strconv.Itoa(attempt), 911 metricsCodeLabel: strconv.Itoa(code), 912 } 913 if w.tokenCountMetric != nil { 914 w.tokenCountMetric.With(labels).Inc() 915 } 916 if w.tokenDurationMetric != nil { 917 w.tokenDurationMetric.With(labels).Observe(elapsed.Seconds()) 918 } 919 } 920 921 // Return the original error: 922 return 923 } 924 925 func (w *TransportWrapper) sendFormTimed(ctx context.Context, form url.Values, headers map[string]string) (code int, 926 result *internal.TokenResponse, err error) { 927 // Create the HTTP request: 928 body := []byte(form.Encode()) 929 request, err := http.NewRequest(http.MethodPost, w.tokenURL, 
bytes.NewReader(body)) 930 request.Close = true 931 header := request.Header 932 if w.agent != "" { 933 header.Set("User-Agent", w.agent) 934 } 935 header.Set("Content-Type", "application/x-www-form-urlencoded") 936 header.Set("Accept", "application/json") 937 // Add any additional headers: 938 for k, v := range headers { 939 header.Set(k, v) 940 } 941 if err != nil { 942 err = fmt.Errorf("can't create request: %w", err) 943 return 944 } 945 946 // Set the context: 947 if ctx != nil { 948 request = request.WithContext(ctx) 949 } 950 951 // Select the HTTP client: 952 client, err := w.clientSelector.Select(ctx, w.tokenServer) 953 if err != nil { 954 return 955 } 956 957 // Send the HTTP request: 958 response, err := client.Do(request) 959 if err != nil { 960 err = fmt.Errorf("can't send request: %w", err) 961 return 962 } 963 defer response.Body.Close() 964 965 code = response.StatusCode 966 967 // Check that the response content type is JSON: 968 err = internal.CheckContentType(response) 969 if err != nil { 970 return 971 } 972 973 // Read the response body: 974 body, err = io.ReadAll(response.Body) 975 if err != nil { 976 err = fmt.Errorf("can't read response: %w", err) 977 return 978 } 979 980 // Parse the response body: 981 result = &internal.TokenResponse{} 982 err = json.Unmarshal(body, result) 983 if err != nil { 984 err = fmt.Errorf("can't parse JSON response: %w", err) 985 return 986 } 987 if result.Error != nil { 988 if result.ErrorDescription != nil { 989 err = fmt.Errorf("%s: %s", *result.Error, *result.ErrorDescription) 990 return 991 } 992 err = fmt.Errorf("%s", *result.Error) 993 return 994 } 995 if response.StatusCode != http.StatusOK { 996 err = fmt.Errorf("token response status code is '%d'", response.StatusCode) 997 return 998 } 999 if result.TokenType != nil && !strings.EqualFold(*result.TokenType, "bearer") { 1000 err = fmt.Errorf("expected 'bearer' token type but got '%s'", *result.TokenType) 1001 return 1002 } 1003 1004 // The response should 
always contains the access token, regardless of the kind of grant 1005 // that was used: 1006 var accessTokenText string 1007 var accessTokenObject *jwt.Token 1008 var accessToken *tokenInfo 1009 if result.AccessToken == nil { 1010 err = fmt.Errorf("no access token was received") 1011 return 1012 } 1013 accessTokenText = *result.AccessToken 1014 accessTokenObject, _, err = w.tokenParser.ParseUnverified( 1015 accessTokenText, 1016 jwt.MapClaims{}, 1017 ) 1018 if err != nil { 1019 return 1020 } 1021 if accessTokenText != "" { 1022 accessToken = &tokenInfo{ 1023 text: accessTokenText, 1024 object: accessTokenObject, 1025 } 1026 } 1027 1028 // If a refresh token is not included in the response, we can safely assume that the old 1029 // one is still valid and does not need to be discarded 1030 // https://datatracker.ietf.org/doc/html/rfc6749#section-6 1031 var refreshTokenText string 1032 var refreshTokenObject *jwt.Token 1033 var refreshToken *tokenInfo 1034 if result.RefreshToken == nil { 1035 if w.refreshToken != nil && w.refreshToken.text != "" { 1036 result.RefreshToken = &w.refreshToken.text 1037 } 1038 } else { 1039 refreshTokenText = *result.RefreshToken 1040 refreshTokenObject, _, err = w.tokenParser.ParseUnverified( 1041 refreshTokenText, 1042 jwt.MapClaims{}, 1043 ) 1044 if err != nil { 1045 w.logger.Debug( 1046 ctx, 1047 "Refresh token can't be parsed, will assume it is opaque: %v", 1048 err, 1049 ) 1050 err = nil 1051 } 1052 } 1053 if refreshTokenText != "" { 1054 refreshToken = &tokenInfo{ 1055 text: refreshTokenText, 1056 object: refreshTokenObject, 1057 } 1058 } 1059 1060 // Save the new tokens: 1061 if accessToken != nil { 1062 w.accessToken = accessToken 1063 } 1064 if refreshToken != nil { 1065 w.refreshToken = refreshToken 1066 } 1067 1068 return 1069 } 1070 1071 func (w *TransportWrapper) havePassword() bool { 1072 return w.user != "" && w.password != "" 1073 } 1074 1075 func (w *TransportWrapper) haveSecret() bool { 1076 return w.clientID != "" && 
w.clientSecret != "" 1077 } 1078 1079 // debugExpiry sends to the log information about the expiration of the given token. 1080 func (w *TransportWrapper) debugExpiry(ctx context.Context, typ string, token *tokenInfo, 1081 expires bool, left time.Duration) { 1082 if token != nil { 1083 if expires { 1084 if left < 0 { 1085 w.logger.Debug(ctx, "%s token expired %s ago", typ, -left) 1086 } else if left > 0 { 1087 w.logger.Debug(ctx, "%s token expires in %s", typ, left) 1088 } else { 1089 w.logger.Debug(ctx, "%s token expired just now", typ) 1090 } 1091 } 1092 } else { 1093 w.logger.Debug(ctx, "%s token isn't available", typ) 1094 } 1095 } 1096 1097 // parsePullSecretAccessToken will parse the supplied token to verify conformity 1098 // with that of a pull secret access token. A pull secret access token is of the 1099 // form <cluster id>:<Base64d pull secret token>. 1100 func parsePullSecretAccessToken(text string) error { 1101 elems := strings.Split(text, ":") 1102 if len(elems) != 2 { 1103 return fmt.Errorf("unparseable pull secret token") 1104 } 1105 _, err := uuid.Parse(elems[0]) 1106 if err != nil { 1107 return fmt.Errorf("unparseable pull secret token cluster ID") 1108 } 1109 _, err = base64.StdEncoding.DecodeString(elems[1]) 1110 if err != nil { 1111 return fmt.Errorf("unparseable pull secret token value") 1112 } 1113 return nil 1114 } 1115 1116 // Names of fields in the token form: 1117 const ( 1118 grantTypeField = "grant_type" 1119 clientIDField = "client_id" 1120 usernameField = "username" 1121 passwordField = "password" 1122 refreshTokenField = "refresh_token" 1123 scopeField = "scope" 1124 ) 1125 1126 // Grant kinds: 1127 const ( 1128 clientCredentialsGrant = "client_credentials" 1129 passwordGrant = "password" 1130 refreshTokenGrant = "refresh_token" 1131 ) 1132 1133 const ( 1134 tokenExpiry = 1 * time.Minute 1135 ) 1136 1137 // Names of the labels added to metrics: 1138 const ( 1139 metricsAttemptLabel = "attempt" 1140 metricsCodeLabel = "code" 1141 ) 
1142 1143 // Array of labels added to token metrics: 1144 var tokenMetricsLabels = []string{ 1145 metricsAttemptLabel, 1146 metricsCodeLabel, 1147 }