github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/credentials/oauth2.go (about)

     1  package credentials
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"os/user"
    14  	"path/filepath"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/golang-jwt/jwt/v4"
    23  
    24  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff"
    25  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/secret"
    26  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    27  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    28  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xstring"
    29  	"github.com/ydb-platform/ydb-go-sdk/v3/retry"
    30  )
    31  
    32  const (
    33  	defaultRequestTimeout      = time.Second * 10
    34  	defaultSyncExchangeTimeout = time.Second * 20
    35  	defaultJWTTokenTTL         = 3600 * time.Second
    36  	updateTimeDivider          = 2
    37  	retryJitterLimit           = 0.5
    38  	syncRetryFastSlot          = time.Millisecond * 100
    39  	syncRetrySlowSlot          = time.Millisecond * 300
    40  	syncRetryCeiling           = 1
    41  	backgroundRetryFastSlot    = time.Millisecond * 10
    42  	backgroundRetrySlowSlot    = time.Millisecond * 300
    43  	backgroundRetryFastCeiling = 12
    44  	backgroundRetrySlowCeiling = 7
    45  )
    46  
// Backoff policies for retrying the token exchange. Each one is built from the
// retry tuning constants: a slot duration, a ceiling and the shared
// retryJitterLimit.
//
// NOTE(review): the call sites are outside this part of the file. Judging by
// the names, the sync* policies presumably apply while a caller is waiting for
// a token and the background* policies to background refresh; the Fast/Slow
// variants presumably match errors marked retry.TypeFastBackoff /
// retry.TypeSlowBackoff — confirm.
var (
	syncRetryFastBackoff = backoff.New(
		backoff.WithSlotDuration(syncRetryFastSlot),
		backoff.WithCeiling(syncRetryCeiling),
		backoff.WithJitterLimit(retryJitterLimit),
	)
	syncRetrySlowBackoff = backoff.New(
		backoff.WithSlotDuration(syncRetrySlowSlot),
		backoff.WithCeiling(syncRetryCeiling),
		backoff.WithJitterLimit(retryJitterLimit),
	)
	backgroundRetryFastBackoff = backoff.New(
		backoff.WithSlotDuration(backgroundRetryFastSlot),
		backoff.WithCeiling(backgroundRetryFastCeiling),
		backoff.WithJitterLimit(retryJitterLimit),
	)
	backgroundRetrySlowBackoff = backoff.New(
		backoff.WithSlotDuration(backgroundRetrySlowSlot),
		backoff.WithCeiling(backgroundRetrySlowCeiling),
		backoff.WithJitterLimit(retryJitterLimit),
	)
)
    69  
    70  var (
    71  	errCouldNotReadFile           = errors.New("could not read file")
    72  	errCouldNotParseHomeDir       = errors.New("could not parse home dir")
    73  	errEmptyTokenEndpointError    = errors.New("OAuth2 token exchange: empty token endpoint")
    74  	errCouldNotParseResponse      = errors.New("OAuth2 token exchange: could not parse response")
    75  	errCouldNotExchangeToken      = errors.New("OAuth2 token exchange: could not exchange token")
    76  	errUnsupportedTokenType       = errors.New("OAuth2 token exchange: unsupported token type")
    77  	errIncorrectExpirationTime    = errors.New("OAuth2 token exchange: incorrect expiration time")
    78  	errDifferentScope             = errors.New("OAuth2 token exchange: got different scope")
    79  	errEmptyAccessToken           = errors.New("OAuth2 token exchange: got empty access token")
    80  	errCouldNotMakeHTTPRequest    = errors.New("OAuth2 token exchange: could not make http request")
    81  	errCouldNotApplyOption        = errors.New("OAuth2 token exchange: could not apply option")
    82  	errCouldNotCreateTokenSource  = errors.New("OAuth2 token exchange: could not create TokenSource")
    83  	errNoSigningMethodError       = errors.New("JWT token source: no signing method")
    84  	errNoPrivateKeyError          = errors.New("JWT token source: no private key")
    85  	errCouldNotSignJWTToken       = errors.New("JWT token source: could not sign jwt token")
    86  	errCouldNotApplyJWTOption     = errors.New("JWT token source: could not apply option")
    87  	errCouldNotparsePrivateKey    = errors.New("JWT token source: could not parse private key from PEM")
    88  	errCouldNotReadPrivateKeyFile = errors.New("JWT token source: could not read from private key file")
    89  	errCouldNotParseBase64Secret  = errors.New("JWT token source: could not parse base64 secret")
    90  	errCouldNotReadConfigFile     = errors.New("OAuth2 token exchange file: could not read from config file")
    91  	errCouldNotUnmarshalJSON      = errors.New("OAuth2 token exchange file: could not unmarshal json config file")
    92  	errUnknownTokenSourceType     = errors.New("OAuth2 token exchange file: incorrect \"type\" parameter: only \"JWT\" and \"FIXED\" are supported") //nolint:lll
    93  	errTokenAndTokenTypeRequired  = errors.New("OAuth2 token exchange file: \"token\" and \"token-type\" are required")
    94  	errAlgAndKeyRequired          = errors.New("OAuth2 token exchange file: \"alg\" and \"private-key\" are required")
    95  	errUnsupportedSigningMethod   = errors.New("OAuth2 token exchange file: signing method not supported")
    96  	errTTLMustBePositive          = errors.New("OAuth2 token exchange file: \"ttl\" must be positive value")
    97  )
    98  
    99  func readFileContent(filePath string) ([]byte, error) {
   100  	if len(filePath) > 0 && filePath[0] == '~' {
   101  		usr, err := user.Current()
   102  		if err != nil {
   103  			return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseHomeDir, err))
   104  		}
   105  		filePath = filepath.Join(usr.HomeDir, filePath[1:])
   106  	}
   107  	bytes, err := os.ReadFile(filePath)
   108  	if err != nil {
   109  		return nil, xerrors.WithStackTrace(fmt.Errorf("%w %s: %w", errCouldNotReadFile, filePath, err))
   110  	}
   111  
   112  	return bytes, nil
   113  }
   114  
// Oauth2TokenExchangeCredentialsOption configures oauth2TokenExchange
// credentials. NewOauth2TokenExchangeCredentials skips nil options, applies the
// remaining ones in order and stops at the first one that returns an error.
type Oauth2TokenExchangeCredentialsOption interface {
	// ApplyOauth2CredentialsOption applies the option to c; a non-nil error
	// aborts credentials construction.
	ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error
}
   118  
   119  // TokenEndpoint
   120  type tokenEndpointOption string
   121  
   122  func (endpoint tokenEndpointOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   123  	c.tokenEndpoint = string(endpoint)
   124  
   125  	return nil
   126  }
   127  
   128  func WithTokenEndpoint(endpoint string) tokenEndpointOption {
   129  	return tokenEndpointOption(endpoint)
   130  }
   131  
   132  // GrantType
   133  type grantTypeOption string
   134  
   135  func (grantType grantTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   136  	c.grantType = string(grantType)
   137  
   138  	return nil
   139  }
   140  
   141  func WithGrantType(grantType string) grantTypeOption {
   142  	return grantTypeOption(grantType)
   143  }
   144  
   145  // Resource
   146  type resourceOption []string
   147  
   148  func (resource resourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   149  	c.resource = append(c.resource, resource...)
   150  
   151  	return nil
   152  }
   153  
   154  func WithResource(resource string, resources ...string) resourceOption {
   155  	return append([]string{resource}, resources...)
   156  }
   157  
   158  // RequestedTokenType
   159  type requestedTokenTypeOption string
   160  
   161  func (requestedTokenType requestedTokenTypeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   162  	c.requestedTokenType = string(requestedTokenType)
   163  
   164  	return nil
   165  }
   166  
   167  func WithRequestedTokenType(requestedTokenType string) requestedTokenTypeOption {
   168  	return requestedTokenTypeOption(requestedTokenType)
   169  }
   170  
   171  // Audience
   172  type audienceOption []string
   173  
   174  func (audience audienceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   175  	c.audience = append(c.audience, audience...)
   176  
   177  	return nil
   178  }
   179  
   180  func WithAudience(audience string, audiences ...string) audienceOption {
   181  	return append([]string{audience}, audiences...)
   182  }
   183  
   184  // Scope
   185  type scopeOption []string
   186  
   187  func (scope scopeOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   188  	c.scope = append(c.scope, scope...)
   189  
   190  	return nil
   191  }
   192  
   193  func WithScope(scope string, scopes ...string) scopeOption {
   194  	return append([]string{scope}, scopes...)
   195  }
   196  
   197  // RequestTimeout
   198  type requestTimeoutOption time.Duration
   199  
   200  func (timeout requestTimeoutOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   201  	c.requestTimeout = time.Duration(timeout)
   202  
   203  	return nil
   204  }
   205  
   206  func WithRequestTimeout(timeout time.Duration) requestTimeoutOption {
   207  	return requestTimeoutOption(timeout)
   208  }
   209  
   210  // SyncExchangeTimeout
   211  type syncExchangeTimeoutOption time.Duration
   212  
   213  func (timeout syncExchangeTimeoutOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   214  	c.syncExchangeTimeout = time.Duration(timeout)
   215  
   216  	return nil
   217  }
   218  
   219  func WithSyncExchangeTimeout(timeout time.Duration) syncExchangeTimeoutOption {
   220  	return syncExchangeTimeoutOption(timeout)
   221  }
   222  
   223  const (
   224  	SubjectTokenSourceType = 1
   225  	ActorTokenSourceType   = 2
   226  )
   227  
   228  // SubjectTokenSource/ActorTokenSource
   229  type tokenSourceOption struct {
   230  	source          TokenSource
   231  	createFunc      func() (TokenSource, error)
   232  	tokenSourceType int
   233  }
   234  
   235  func (tokenSource *tokenSourceOption) ApplyOauth2CredentialsOption(c *oauth2TokenExchange) error {
   236  	src := tokenSource.source
   237  	var err error
   238  	if src == nil {
   239  		src, err = tokenSource.createFunc()
   240  		if err != nil {
   241  			return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotCreateTokenSource, err))
   242  		}
   243  	}
   244  	switch tokenSource.tokenSourceType {
   245  	case SubjectTokenSourceType:
   246  		c.subjectTokenSource = src
   247  	case ActorTokenSourceType:
   248  		c.actorTokenSource = src
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func WithSubjectToken(subjectToken TokenSource) *tokenSourceOption {
   255  	return &tokenSourceOption{
   256  		source:          subjectToken,
   257  		tokenSourceType: SubjectTokenSourceType,
   258  	}
   259  }
   260  
   261  func WithFixedSubjectToken(token, tokenType string) *tokenSourceOption {
   262  	return &tokenSourceOption{
   263  		createFunc: func() (TokenSource, error) {
   264  			return NewFixedTokenSource(token, tokenType), nil
   265  		},
   266  		tokenSourceType: SubjectTokenSourceType,
   267  	}
   268  }
   269  
   270  func WithJWTSubjectToken(opts ...JWTTokenSourceOption) *tokenSourceOption {
   271  	return &tokenSourceOption{
   272  		createFunc: func() (TokenSource, error) {
   273  			return NewJWTTokenSource(opts...)
   274  		},
   275  		tokenSourceType: SubjectTokenSourceType,
   276  	}
   277  }
   278  
   279  // ActorTokenSource
   280  func WithActorToken(actorToken TokenSource) *tokenSourceOption {
   281  	return &tokenSourceOption{
   282  		source:          actorToken,
   283  		tokenSourceType: ActorTokenSourceType,
   284  	}
   285  }
   286  
   287  func WithFixedActorToken(token, tokenType string) *tokenSourceOption {
   288  	return &tokenSourceOption{
   289  		createFunc: func() (TokenSource, error) {
   290  			return NewFixedTokenSource(token, tokenType), nil
   291  		},
   292  		tokenSourceType: ActorTokenSourceType,
   293  	}
   294  }
   295  
   296  func WithJWTActorToken(opts ...JWTTokenSourceOption) *tokenSourceOption {
   297  	return &tokenSourceOption{
   298  		createFunc: func() (TokenSource, error) {
   299  			return NewJWTTokenSource(opts...)
   300  		},
   301  		tokenSourceType: ActorTokenSourceType,
   302  	}
   303  }
   304  
// oauth2TokenExchange holds the configuration of an OAuth 2.0 token exchange
// and the most recently received token. Instances are created by
// NewOauth2TokenExchangeCredentials, which fills in the defaults noted below.
type oauth2TokenExchange struct {
	// URL the exchange request is POSTed to; must be non-empty.
	tokenEndpoint string

	// grant_type parameter
	// urn:ietf:params:oauth:grant-type:token-exchange by default
	grantType string

	// Values sent as repeated "resource" and "audience" parameters; scope
	// values are joined with spaces into the "scope" parameter. Empty entries
	// are skipped when the request is built.
	resource []string
	audience []string
	scope    []string

	// requested_token_type parameter
	// urn:ietf:params:oauth:token-type:access_token by default
	requestedTokenType string

	// Source of subject_token / subject_token_type; nothing is sent when nil.
	subjectTokenSource TokenSource

	// Source of actor_token / actor_token_type; nothing is sent when nil.
	actorTokenSource TokenSource

	// Http request timeout
	// defaultRequestTimeout (10 seconds) by default
	requestTimeout time.Duration

	// Timeout when performing synchronous token exchange
	// It is used when getting token for the first time
	// or when it is already expired
	// defaultSyncExchangeTimeout (20 seconds) by default
	syncExchangeTimeout time.Duration

	// Received data, written by updateToken:
	//   receivedToken           - "Bearer " + access_token
	//   updateTokenTime         - response time + expires_in/updateTimeDivider seconds
	//   receivedTokenExpireTime - response time + expires_in seconds
	receivedToken           string
	updateTokenTime         time.Time
	receivedTokenExpireTime time.Time

	// NOTE(review): mutex presumably guards the received* fields; the locking
	// code is outside this part of the file — confirm.
	mutex    sync.RWMutex
	updating atomic.Bool // true if separate goroutine is run and updates token in background

	// Call site captured by stack.Record(1) in NewOauth2TokenExchangeCredentials.
	sourceInfo string
}
   343  
   344  func NewOauth2TokenExchangeCredentials(
   345  	opts ...Oauth2TokenExchangeCredentialsOption,
   346  ) (*oauth2TokenExchange, error) {
   347  	c := &oauth2TokenExchange{
   348  		grantType:           "urn:ietf:params:oauth:grant-type:token-exchange",
   349  		requestedTokenType:  "urn:ietf:params:oauth:token-type:access_token",
   350  		requestTimeout:      defaultRequestTimeout,
   351  		syncExchangeTimeout: defaultSyncExchangeTimeout,
   352  		sourceInfo:          stack.Record(1),
   353  	}
   354  
   355  	var err error
   356  	for _, opt := range opts {
   357  		if opt != nil {
   358  			err = opt.ApplyOauth2CredentialsOption(c)
   359  			if err != nil {
   360  				return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotApplyOption, err))
   361  			}
   362  		}
   363  	}
   364  
   365  	if c.tokenEndpoint == "" {
   366  		return nil, xerrors.WithStackTrace(errEmptyTokenEndpointError)
   367  	}
   368  
   369  	return c, nil
   370  }
   371  
   372  type privateKeyLoadOptionFunc func(string) JWTTokenSourceOption
   373  
   374  type signingMethodDescription struct {
   375  	method        jwt.SigningMethod
   376  	keyLoadOption privateKeyLoadOptionFunc
   377  }
   378  
   379  func getHMACPrivateKeyOption(privateKey string) JWTTokenSourceOption {
   380  	return WithHMACSecretKeyBase64Content(privateKey)
   381  }
   382  
   383  func getRSAPrivateKeyOption(privateKey string) JWTTokenSourceOption {
   384  	return WithRSAPrivateKeyPEMContent([]byte(privateKey))
   385  }
   386  
   387  func getECPrivateKeyOption(privateKey string) JWTTokenSourceOption {
   388  	return WithECPrivateKeyPEMContent([]byte(privateKey))
   389  }
   390  
   391  var signingMethodsRegistry = map[string]signingMethodDescription{
   392  	"HS256": {
   393  		method:        jwt.SigningMethodHS256,
   394  		keyLoadOption: getHMACPrivateKeyOption,
   395  	},
   396  	"HS384": {
   397  		method:        jwt.SigningMethodHS384,
   398  		keyLoadOption: getHMACPrivateKeyOption,
   399  	},
   400  	"HS512": {
   401  		method:        jwt.SigningMethodHS512,
   402  		keyLoadOption: getHMACPrivateKeyOption,
   403  	},
   404  	"RS256": {
   405  		method:        jwt.SigningMethodRS256,
   406  		keyLoadOption: getRSAPrivateKeyOption,
   407  	},
   408  	"RS384": {
   409  		method:        jwt.SigningMethodRS384,
   410  		keyLoadOption: getRSAPrivateKeyOption,
   411  	},
   412  	"RS512": {
   413  		method:        jwt.SigningMethodRS512,
   414  		keyLoadOption: getRSAPrivateKeyOption,
   415  	},
   416  	"PS256": {
   417  		method:        jwt.SigningMethodPS256,
   418  		keyLoadOption: getRSAPrivateKeyOption,
   419  	},
   420  	"PS384": {
   421  		method:        jwt.SigningMethodPS384,
   422  		keyLoadOption: getRSAPrivateKeyOption,
   423  	},
   424  	"PS512": {
   425  		method:        jwt.SigningMethodPS512,
   426  		keyLoadOption: getRSAPrivateKeyOption,
   427  	},
   428  	"ES256": {
   429  		method:        jwt.SigningMethodES256,
   430  		keyLoadOption: getECPrivateKeyOption,
   431  	},
   432  	"ES384": {
   433  		method:        jwt.SigningMethodES384,
   434  		keyLoadOption: getECPrivateKeyOption,
   435  	},
   436  	"ES512": {
   437  		method:        jwt.SigningMethodES512,
   438  		keyLoadOption: getECPrivateKeyOption,
   439  	},
   440  }
   441  
   442  func GetSupportedOauth2TokenExchangeJwtAlgorithms() []string {
   443  	algs := make([]string, len(signingMethodsRegistry))
   444  	i := 0
   445  	for alg := range signingMethodsRegistry {
   446  		algs[i] = alg
   447  		i++
   448  	}
   449  	sort.Strings(algs)
   450  
   451  	return algs
   452  }
   453  
   454  type stringOrArrayConfig struct {
   455  	Values []string
   456  }
   457  
   458  func (a *stringOrArrayConfig) UnmarshalJSON(data []byte) error {
   459  	// Case 1: string
   460  	var s string
   461  	err := json.Unmarshal(data, &s)
   462  	if err == nil {
   463  		a.Values = []string{s}
   464  
   465  		return nil
   466  	}
   467  
   468  	var arr []string
   469  	err = json.Unmarshal(data, &arr)
   470  	if err != nil {
   471  		return xerrors.WithStackTrace(err)
   472  	}
   473  	a.Values = arr
   474  
   475  	return nil
   476  }
   477  
   478  type prettyTTL struct {
   479  	Value time.Duration
   480  }
   481  
   482  func (d *prettyTTL) UnmarshalJSON(data []byte) error {
   483  	var s string
   484  	err := json.Unmarshal(data, &s)
   485  	if err != nil {
   486  		return xerrors.WithStackTrace(err)
   487  	}
   488  	d.Value, err = time.ParseDuration(s)
   489  	if err != nil {
   490  		return xerrors.WithStackTrace(err)
   491  	}
   492  	if d.Value <= 0 {
   493  		return xerrors.WithStackTrace(fmt.Errorf("%w, but got: %q", errTTLMustBePositive, s))
   494  	}
   495  
   496  	return xerrors.WithStackTrace(err)
   497  }
   498  
// oauth2TokenSourceConfig is one token source section ("subject-credentials"
// or "actor-credentials") of the token exchange config file. "type" selects the
// kind and is matched case-insensitively: "FIXED" uses token and token-type,
// "JWT" builds a signed JWT from the remaining fields.
//
//nolint:tagliatelle
type oauth2TokenSourceConfig struct {
	Type string `json:"type"`

	// Fixed: both fields are required.
	Token     string `json:"token"`
	TokenType string `json:"token-type"`

	// JWT: "alg" and "private-key" are required. "alg" must be a key of
	// signingMethodsRegistry (looked up after upper-casing); "private-key" is
	// the key content, interpreted by the algorithm's key loader. The other
	// fields are optional and only forwarded when set.
	Algorithm  string               `json:"alg"`
	PrivateKey string               `json:"private-key"`
	KeyID      string               `json:"kid"`
	Issuer     string               `json:"iss"`
	Subject    string               `json:"sub"`
	Audience   *stringOrArrayConfig `json:"aud"` // string or array of strings
	ID         string               `json:"jti"`
	TTL        *prettyTTL           `json:"ttl"` // Go duration string, e.g. "1h"; must be positive
}
   517  
   518  func signingMethodNotSupportedError(method string) error {
   519  	var supported string
   520  	for i, alg := range GetSupportedOauth2TokenExchangeJwtAlgorithms() {
   521  		if i != 0 {
   522  			supported += ", "
   523  		}
   524  		supported += "\""
   525  		supported += alg
   526  		supported += "\""
   527  	}
   528  
   529  	return fmt.Errorf("%w: %q. Supported signing methods are %s", errUnsupportedSigningMethod, method, supported)
   530  }
   531  
   532  func (cfg *oauth2TokenSourceConfig) applyConfigFixed(tokenSrcType int) (*tokenSourceOption, error) {
   533  	if cfg.Token == "" || cfg.TokenType == "" {
   534  		return nil, xerrors.WithStackTrace(errTokenAndTokenTypeRequired)
   535  	}
   536  
   537  	return &tokenSourceOption{
   538  		createFunc: func() (TokenSource, error) {
   539  			return NewFixedTokenSource(cfg.Token, cfg.TokenType), nil
   540  		},
   541  		tokenSourceType: tokenSrcType,
   542  	}, nil
   543  }
   544  
   545  func (cfg *oauth2TokenSourceConfig) applyConfigFixedJWT(tokenSrcType int) (*tokenSourceOption, error) {
   546  	var opts []JWTTokenSourceOption
   547  
   548  	if cfg.Algorithm == "" || cfg.PrivateKey == "" {
   549  		return nil, xerrors.WithStackTrace(errAlgAndKeyRequired)
   550  	}
   551  
   552  	signingMethodDesc, signingMethodFound := signingMethodsRegistry[strings.ToUpper(cfg.Algorithm)]
   553  	if !signingMethodFound {
   554  		return nil, xerrors.WithStackTrace(signingMethodNotSupportedError(cfg.Algorithm))
   555  	}
   556  
   557  	opts = append(opts,
   558  		WithSigningMethod(signingMethodDesc.method),
   559  		signingMethodDesc.keyLoadOption(cfg.PrivateKey),
   560  	)
   561  
   562  	if cfg.KeyID != "" {
   563  		opts = append(opts, WithKeyID(cfg.KeyID))
   564  	}
   565  
   566  	if cfg.Issuer != "" {
   567  		opts = append(opts, WithIssuer(cfg.Issuer))
   568  	}
   569  
   570  	if cfg.Subject != "" {
   571  		opts = append(opts, WithSubject(cfg.Subject))
   572  	}
   573  
   574  	if cfg.Audience != nil && len(cfg.Audience.Values) > 0 {
   575  		opts = append(opts, WithAudience(cfg.Audience.Values[0], cfg.Audience.Values[1:]...))
   576  	}
   577  
   578  	if cfg.ID != "" {
   579  		opts = append(opts, WithID(cfg.ID))
   580  	}
   581  
   582  	if cfg.TTL != nil {
   583  		opts = append(opts, WithTokenTTL(cfg.TTL.Value))
   584  	}
   585  
   586  	return &tokenSourceOption{
   587  		createFunc: func() (TokenSource, error) {
   588  			return NewJWTTokenSource(opts...)
   589  		},
   590  		tokenSourceType: tokenSrcType,
   591  	}, nil
   592  }
   593  
   594  func (cfg *oauth2TokenSourceConfig) applyConfig(tokenSrcType int) (*tokenSourceOption, error) {
   595  	if strings.EqualFold(cfg.Type, "FIXED") {
   596  		return cfg.applyConfigFixed(tokenSrcType)
   597  	}
   598  
   599  	if strings.EqualFold(cfg.Type, "JWT") {
   600  		return cfg.applyConfigFixedJWT(tokenSrcType)
   601  	}
   602  
   603  	return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %q", errUnknownTokenSourceType, cfg.Type))
   604  }
   605  
// oauth2Config is the top-level structure of the token exchange config file
// read by NewOauth2TokenExchangeCredentialsFile. Only the fields present in
// the file are turned into options (see applyConfig). The token endpoint may
// come from the file or from explicit options, but one of them must set it.
//
//nolint:tagliatelle
type oauth2Config struct {
	GrantType          string               `json:"grant-type"`
	Resource           *stringOrArrayConfig `json:"res"`   // string or array of strings
	Audience           *stringOrArrayConfig `json:"aud"`   // string or array of strings
	Scope              *stringOrArrayConfig `json:"scope"` // string or array of strings
	RequestedTokenType string               `json:"requested-token-type"`
	TokenEndpoint      string               `json:"token-endpoint"`

	SubjectCreds *oauth2TokenSourceConfig `json:"subject-credentials"`
	ActorCreds   *oauth2TokenSourceConfig `json:"actor-credentials"`
}
   618  
   619  func (cfg *oauth2Config) applyConfig(opts *[]Oauth2TokenExchangeCredentialsOption) error {
   620  	if cfg.GrantType != "" {
   621  		*opts = append(*opts, WithGrantType(cfg.GrantType))
   622  	}
   623  
   624  	if cfg.Resource != nil && len(cfg.Resource.Values) > 0 {
   625  		*opts = append(*opts, WithResource(cfg.Resource.Values[0], cfg.Resource.Values[1:]...))
   626  	}
   627  
   628  	if cfg.Audience != nil && len(cfg.Audience.Values) > 0 {
   629  		*opts = append(*opts, WithAudience(cfg.Audience.Values[0], cfg.Audience.Values[1:]...))
   630  	}
   631  
   632  	if cfg.Scope != nil && len(cfg.Scope.Values) > 0 {
   633  		*opts = append(*opts, WithScope(cfg.Scope.Values[0], cfg.Scope.Values[1:]...))
   634  	}
   635  
   636  	if cfg.RequestedTokenType != "" {
   637  		*opts = append(*opts, WithRequestedTokenType(cfg.RequestedTokenType))
   638  	}
   639  
   640  	if cfg.TokenEndpoint != "" {
   641  		*opts = append(*opts, WithTokenEndpoint(cfg.TokenEndpoint))
   642  	}
   643  
   644  	if cfg.SubjectCreds != nil {
   645  		opt, err := cfg.SubjectCreds.applyConfig(SubjectTokenSourceType)
   646  		if err != nil {
   647  			return xerrors.WithStackTrace(err)
   648  		}
   649  		*opts = append(*opts, opt)
   650  	}
   651  
   652  	if cfg.ActorCreds != nil {
   653  		opt, err := cfg.ActorCreds.applyConfig(ActorTokenSourceType)
   654  		if err != nil {
   655  			return xerrors.WithStackTrace(err)
   656  		}
   657  		*opts = append(*opts, opt)
   658  	}
   659  
   660  	return nil
   661  }
   662  
   663  func NewOauth2TokenExchangeCredentialsFile(
   664  	configFilePath string,
   665  	opts ...Oauth2TokenExchangeCredentialsOption,
   666  ) (*oauth2TokenExchange, error) {
   667  	configFileData, err := readFileContent(configFilePath)
   668  	if err != nil {
   669  		return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadConfigFile, err))
   670  	}
   671  
   672  	var cfg oauth2Config
   673  	if err = json.Unmarshal(configFileData, &cfg); err != nil {
   674  		return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotUnmarshalJSON, err))
   675  	}
   676  
   677  	var fullOptions []Oauth2TokenExchangeCredentialsOption
   678  	err = cfg.applyConfig(&fullOptions)
   679  	if err != nil {
   680  		return nil, err
   681  	}
   682  
   683  	// Add additional options
   684  	for _, opt := range opts {
   685  		if opt != nil {
   686  			fullOptions = append(fullOptions, opt)
   687  		}
   688  	}
   689  
   690  	return NewOauth2TokenExchangeCredentials(fullOptions...)
   691  }
   692  
   693  func (provider *oauth2TokenExchange) getScopeParam() string {
   694  	var scope string
   695  	if len(provider.scope) != 0 {
   696  		for _, s := range provider.scope {
   697  			if s != "" {
   698  				if scope != "" {
   699  					scope += " "
   700  				}
   701  				scope += s
   702  			}
   703  		}
   704  	}
   705  
   706  	return scope
   707  }
   708  
   709  func (provider *oauth2TokenExchange) addTokenSrc(params *url.Values, src TokenSource, tName, tTypeName string) error {
   710  	if src != nil {
   711  		token, err := src.Token()
   712  		if err != nil {
   713  			return xerrors.WithStackTrace(err)
   714  		}
   715  		params.Set(tName, token.Token)
   716  		params.Set(tTypeName, token.TokenType)
   717  	}
   718  
   719  	return nil
   720  }
   721  
   722  func (provider *oauth2TokenExchange) getRequestParams() (string, error) {
   723  	params := url.Values{}
   724  	params.Set("grant_type", provider.grantType)
   725  	for _, res := range provider.resource {
   726  		if res != "" {
   727  			params.Add("resource", res)
   728  		}
   729  	}
   730  	for _, aud := range provider.audience {
   731  		if aud != "" {
   732  			params.Add("audience", aud)
   733  		}
   734  	}
   735  	scope := provider.getScopeParam()
   736  	if scope != "" {
   737  		params.Set("scope", scope)
   738  	}
   739  
   740  	params.Set("requested_token_type", provider.requestedTokenType)
   741  
   742  	err := provider.addTokenSrc(&params, provider.subjectTokenSource, "subject_token", "subject_token_type")
   743  	if err != nil {
   744  		return "", xerrors.WithStackTrace(err)
   745  	}
   746  
   747  	err = provider.addTokenSrc(&params, provider.actorTokenSource, "actor_token", "actor_token_type")
   748  	if err != nil {
   749  		return "", xerrors.WithStackTrace(err)
   750  	}
   751  
   752  	return params.Encode(), nil
   753  }
   754  
   755  func (provider *oauth2TokenExchange) processTokenExchangeResponse(
   756  	result *http.Response,
   757  	now time.Time,
   758  	retryAllErrors bool,
   759  ) (*tokenResponse, error) {
   760  	data, err := readResponseBody(result)
   761  	if err != nil {
   762  		return nil, err
   763  	}
   764  
   765  	if result.StatusCode != http.StatusOK {
   766  		return nil, provider.handleErrorResponse(result, data, retryAllErrors)
   767  	}
   768  
   769  	parsedResponse, err := parseTokenResponse(data, retryAllErrors)
   770  	if err != nil {
   771  		return nil, err
   772  	}
   773  
   774  	if err := validateTokenResponse(parsedResponse, provider); err != nil {
   775  		return nil, err
   776  	}
   777  
   778  	parsedResponse.Now = now
   779  
   780  	return parsedResponse, nil
   781  }
   782  
   783  func readResponseBody(result *http.Response) ([]byte, error) {
   784  	if result.Body != nil {
   785  		data, err := io.ReadAll(result.Body)
   786  		if err != nil {
   787  			return nil, xerrors.WithStackTrace(xerrors.Retryable(err,
   788  				xerrors.WithBackoff(retry.TypeFastBackoff),
   789  			))
   790  		}
   791  
   792  		return data, nil
   793  	}
   794  
   795  	return make([]byte, 0), nil
   796  }
   797  
   798  func makeError(result *http.Response, err error, retryAllErrors bool) error {
   799  	if result != nil {
   800  		if result.StatusCode == http.StatusRequestTimeout ||
   801  			result.StatusCode == http.StatusGatewayTimeout ||
   802  			result.StatusCode == http.StatusTooManyRequests ||
   803  			result.StatusCode == http.StatusInternalServerError ||
   804  			result.StatusCode == http.StatusBadGateway ||
   805  			result.StatusCode == http.StatusServiceUnavailable {
   806  			return xerrors.WithStackTrace(xerrors.Retryable(err,
   807  				xerrors.WithBackoff(retry.TypeSlowBackoff),
   808  			))
   809  		}
   810  	}
   811  	if retryAllErrors {
   812  		return xerrors.WithStackTrace(xerrors.Retryable(err,
   813  			xerrors.WithBackoff(retry.TypeFastBackoff),
   814  		))
   815  	}
   816  
   817  	return xerrors.WithStackTrace(err)
   818  }
   819  
   820  func (provider *oauth2TokenExchange) handleErrorResponse(
   821  	result *http.Response,
   822  	data []byte,
   823  	retryAllErrors bool,
   824  ) error {
   825  	description := result.Status
   826  
   827  	//nolint:tagliatelle
   828  	type errorResponse struct {
   829  		ErrorName        string `json:"error"`
   830  		ErrorDescription string `json:"error_description"`
   831  		ErrorURI         string `json:"error_uri"`
   832  	}
   833  	var parsedErrorResponse errorResponse
   834  	if err := json.Unmarshal(data, &parsedErrorResponse); err != nil {
   835  		description += ", could not parse response: " + err.Error()
   836  
   837  		return makeError(
   838  			result,
   839  			xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description)),
   840  			retryAllErrors)
   841  	}
   842  
   843  	if parsedErrorResponse.ErrorName != "" {
   844  		description += ", error: " + parsedErrorResponse.ErrorName
   845  	}
   846  	if parsedErrorResponse.ErrorDescription != "" {
   847  		description += fmt.Sprintf(", description: %q", parsedErrorResponse.ErrorDescription)
   848  	}
   849  	if parsedErrorResponse.ErrorURI != "" {
   850  		description += ", error_uri: " + parsedErrorResponse.ErrorURI
   851  	}
   852  
   853  	return makeError(
   854  		result,
   855  		xerrors.WithStackTrace(fmt.Errorf("%w: %s", errCouldNotExchangeToken, description)),
   856  		retryAllErrors)
   857  }
   858  
// tokenResponse is a successful token endpoint response. Now is not part of the
// JSON payload: processTokenExchangeResponse sets it to the time supplied by its
// caller, and updateToken derives the expiry and refresh times from it.
//
//nolint:tagliatelle
type tokenResponse struct {
	AccessToken string    `json:"access_token"` // must be non-empty
	TokenType   string    `json:"token_type"`   // must be "bearer" (case-insensitive)
	ExpiresIn   int64     `json:"expires_in"`   // lifetime in seconds; must be positive
	Scope       string    `json:"scope"`        // if present, must equal the requested scope string
	Now         time.Time `json:"-"`
}
   867  
   868  func parseTokenResponse(data []byte, retryAllErrors bool) (*tokenResponse, error) {
   869  	var parsedResponse tokenResponse
   870  	if err := json.Unmarshal(data, &parsedResponse); err != nil {
   871  		return nil, makeError(
   872  			nil,
   873  			xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseResponse, err)),
   874  			retryAllErrors)
   875  	}
   876  
   877  	return &parsedResponse, nil
   878  }
   879  
   880  func validateTokenResponse(parsedResponse *tokenResponse, provider *oauth2TokenExchange) error {
   881  	if !strings.EqualFold(parsedResponse.TokenType, "bearer") {
   882  		return xerrors.WithStackTrace(
   883  			fmt.Errorf("%w: %q", errUnsupportedTokenType, parsedResponse.TokenType))
   884  	}
   885  
   886  	if parsedResponse.ExpiresIn <= 0 {
   887  		return xerrors.WithStackTrace(
   888  			fmt.Errorf("%w: %d", errIncorrectExpirationTime, parsedResponse.ExpiresIn))
   889  	}
   890  
   891  	if parsedResponse.Scope != "" {
   892  		scope := provider.getScopeParam()
   893  		if parsedResponse.Scope != scope {
   894  			return xerrors.WithStackTrace(
   895  				fmt.Errorf("%w. Expected %q, but got %q", errDifferentScope, scope, parsedResponse.Scope))
   896  		}
   897  	}
   898  
   899  	if parsedResponse.AccessToken == "" {
   900  		return xerrors.WithStackTrace(errEmptyAccessToken)
   901  	}
   902  
   903  	return nil
   904  }
   905  
   906  func (provider *oauth2TokenExchange) updateToken(parsedResponse *tokenResponse) {
   907  	provider.receivedToken = "Bearer " + parsedResponse.AccessToken
   908  
   909  	expireDelta := time.Duration(parsedResponse.ExpiresIn) * time.Second
   910  	provider.receivedTokenExpireTime = parsedResponse.Now.Add(expireDelta)
   911  
   912  	updateDelta := time.Duration(parsedResponse.ExpiresIn/updateTimeDivider) * time.Second
   913  	provider.updateTokenTime = parsedResponse.Now.Add(updateDelta)
   914  }
   915  
   916  // performExchangeTokenRequest is a read only func that performs request. Can be used without lock
   917  func (provider *oauth2TokenExchange) performExchangeTokenRequest(
   918  	ctx context.Context,
   919  	retryAllErrors bool,
   920  ) (*tokenResponse, error) {
   921  	body, err := provider.getRequestParams()
   922  	if err != nil {
   923  		return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err))
   924  	}
   925  
   926  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, provider.tokenEndpoint, strings.NewReader(body))
   927  	if err != nil {
   928  		return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotMakeHTTPRequest, err))
   929  	}
   930  	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
   931  	req.Header.Add("Content-Length", strconv.Itoa(len(body)))
   932  	req.Close = true
   933  
   934  	client := http.Client{
   935  		Transport: http.DefaultTransport,
   936  		Timeout:   provider.requestTimeout,
   937  	}
   938  
   939  	now := time.Now()
   940  	result, err := client.Do(req)
   941  	if err != nil {
   942  		return nil, xerrors.WithStackTrace(xerrors.Retryable(
   943  			fmt.Errorf("%w: %w", errCouldNotExchangeToken, err),
   944  			xerrors.WithBackoff(retry.TypeFastBackoff),
   945  		))
   946  	}
   947  
   948  	defer result.Body.Close()
   949  
   950  	return provider.processTokenExchangeResponse(result, now, retryAllErrors)
   951  }
   952  
   953  // exchangeTokenSync exchanges token synchronously, must be called under lock
   954  func (provider *oauth2TokenExchange) exchangeTokenSync(ctx context.Context) error {
   955  	retryAllErrors := provider.receivedToken != "" // already received token => all params are correct, can retry
   956  
   957  	ctx, cancelFunc := context.WithTimeout(ctx, provider.syncExchangeTimeout)
   958  	defer cancelFunc()
   959  
   960  	response, err := retry.RetryWithResult[*tokenResponse](
   961  		ctx,
   962  		func(ctx context.Context) (*tokenResponse, error) {
   963  			return provider.performExchangeTokenRequest(ctx, retryAllErrors)
   964  		},
   965  		retry.WithFastBackoff(syncRetryFastBackoff),
   966  		retry.WithSlowBackoff(syncRetrySlowBackoff),
   967  	)
   968  	if err != nil {
   969  		return xerrors.WithStackTrace(err)
   970  	}
   971  
   972  	provider.updateToken(response)
   973  
   974  	return nil
   975  }
   976  
// exchangeTokenInBackground refreshes the token without blocking callers of Token.
//
// It is started as a goroutine by checkBackgroundUpdate, which first flips
// provider.updating from false to true. The deferred Store(false) clears that flag when
// this goroutine exits, so a later call can start another refresh.
//
// The retries run under a context whose deadline is the current token's expiry time,
// and every error is treated as retryable (retryAllErrors=true). Once the token has
// expired, Token falls back to the synchronous exchange path instead.
func (provider *oauth2TokenExchange) exchangeTokenInBackground() {
	defer provider.updating.Store(false)

	// Only the expiry time is read under the read lock; the network exchange itself runs
	// without holding provider.mutex, so concurrent Token callers are not blocked by it.
	provider.mutex.RLock()
	ctx, cancelFunc := context.WithDeadline(context.Background(), provider.receivedTokenExpireTime)
	provider.mutex.RUnlock()
	defer cancelFunc()

	response, err := retry.RetryWithResult[*tokenResponse](
		ctx,
		func(ctx context.Context) (*tokenResponse, error) {
			return provider.performExchangeTokenRequest(ctx, true)
		},
		retry.WithFastBackoff(backgroundRetryFastBackoff),
		retry.WithSlowBackoff(backgroundRetrySlowBackoff),
	)
	if err != nil {
		return // best effort: the cached token keeps being served until it expires
	}

	// Writing the new token requires the exclusive lock.
	provider.mutex.Lock()
	defer provider.mutex.Unlock()
	provider.updateToken(response)
}
  1001  
  1002  func (provider *oauth2TokenExchange) checkBackgroundUpdate(now time.Time) {
  1003  	if provider.needUpdate(now) && !provider.updating.Load() {
  1004  		if provider.updating.CompareAndSwap(false, true) {
  1005  			go provider.exchangeTokenInBackground()
  1006  		}
  1007  	}
  1008  }
  1009  
  1010  func (provider *oauth2TokenExchange) expired(now time.Time) bool {
  1011  	return now.Compare(provider.receivedTokenExpireTime) > 0
  1012  }
  1013  
  1014  func (provider *oauth2TokenExchange) needUpdate(now time.Time) bool {
  1015  	return now.Compare(provider.updateTokenTime) > 0
  1016  }
  1017  
  1018  func (provider *oauth2TokenExchange) fastCheck(now time.Time) string {
  1019  	provider.mutex.RLock()
  1020  	defer provider.mutex.RUnlock()
  1021  
  1022  	if !provider.expired(now) {
  1023  		provider.checkBackgroundUpdate(now)
  1024  
  1025  		return provider.receivedToken
  1026  	}
  1027  
  1028  	return ""
  1029  }
  1030  
  1031  func (provider *oauth2TokenExchange) Token(ctx context.Context) (string, error) {
  1032  	now := time.Now()
  1033  
  1034  	token := provider.fastCheck(now)
  1035  	if token != "" {
  1036  		return token, nil
  1037  	}
  1038  
  1039  	provider.mutex.Lock()
  1040  	defer provider.mutex.Unlock()
  1041  
  1042  	if !provider.expired(now) { // for the case of concurrent call
  1043  		return provider.receivedToken, nil
  1044  	}
  1045  
  1046  	if err := provider.exchangeTokenSync(ctx); err != nil {
  1047  		return "", err
  1048  	}
  1049  
  1050  	return provider.receivedToken, nil
  1051  }
  1052  
  1053  func (provider *oauth2TokenExchange) String() string {
  1054  	buffer := xstring.Buffer()
  1055  	defer buffer.Free()
  1056  	fmt.Fprintf(
  1057  		buffer,
  1058  		"OAuth2TokenExchange{Endpoint:%q,GrantType:%s,Resource:%v,Audience:%v,Scope:%v,RequestedTokenType:%s",
  1059  		provider.tokenEndpoint,
  1060  		provider.grantType,
  1061  		provider.resource,
  1062  		provider.audience,
  1063  		provider.scope,
  1064  		provider.requestedTokenType,
  1065  	)
  1066  	if provider.subjectTokenSource != nil {
  1067  		fmt.Fprintf(buffer, ",SubjectToken:%s", provider.subjectTokenSource)
  1068  	}
  1069  	if provider.actorTokenSource != nil {
  1070  		fmt.Fprintf(buffer, ",ActorToken:%s", provider.actorTokenSource)
  1071  	}
  1072  	if provider.sourceInfo != "" {
  1073  		fmt.Fprintf(buffer, ",From:%q", provider.sourceInfo)
  1074  	}
  1075  	buffer.WriteByte('}')
  1076  
  1077  	return buffer.String()
  1078  }
  1079  
// Token is a token produced by a TokenSource, together with its type.
type Token struct {
	// Token is the token value itself (for jwtTokenSource, the signed JWT string).
	Token string

	// TokenType is the token type according to the OAuth 2.0 token exchange protocol,
	// https://www.rfc-editor.org/rfc/rfc8693#TokenTypeIdentifiers
	// for example urn:ietf:params:oauth:token-type:jwt
	TokenType string
}
  1088  
// TokenSource supplies tokens on demand. Implementations in this file are
// fixedTokenSource, which always returns the same token, and jwtTokenSource, which
// signs a new JWT on every call.
type TokenSource interface {
	Token() (Token, error)
}
  1092  
// fixedTokenSource is a TokenSource that always returns the same preconfigured token.
// Create it with NewFixedTokenSource.
type fixedTokenSource struct {
	fixedToken Token
}
  1096  
  1097  func (s *fixedTokenSource) Token() (Token, error) {
  1098  	return s.fixedToken, nil
  1099  }
  1100  
  1101  func (s *fixedTokenSource) String() string {
  1102  	buffer := xstring.Buffer()
  1103  	defer buffer.Free()
  1104  	fmt.Fprintf(
  1105  		buffer,
  1106  		"FixedTokenSource{Token:%q,Type:%s}",
  1107  		secret.Token(s.fixedToken.Token),
  1108  		s.fixedToken.TokenType,
  1109  	)
  1110  
  1111  	return buffer.String()
  1112  }
  1113  
  1114  func NewFixedTokenSource(token, tokenType string) *fixedTokenSource {
  1115  	return &fixedTokenSource{
  1116  		fixedToken: Token{
  1117  			Token:     token,
  1118  			TokenType: tokenType,
  1119  		},
  1120  	}
  1121  }
  1122  
// JWTTokenSourceOption configures a jwtTokenSource. NewJWTTokenSource applies the
// options in the order given, skipping nil ones; the first non-nil error aborts
// construction.
type JWTTokenSourceOption interface {
	ApplyJWTTokenSourceOption(s *jwtTokenSource) error
}
  1126  
  1127  // Issuer
  1128  type issuerOption string
  1129  
  1130  func (issuer issuerOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1131  	s.issuer = string(issuer)
  1132  
  1133  	return nil
  1134  }
  1135  
  1136  func WithIssuer(issuer string) issuerOption {
  1137  	return issuerOption(issuer)
  1138  }
  1139  
  1140  // Subject
  1141  type subjectOption string
  1142  
  1143  func (subject subjectOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1144  	s.subject = string(subject)
  1145  
  1146  	return nil
  1147  }
  1148  
  1149  func WithSubject(subject string) subjectOption {
  1150  	return subjectOption(subject)
  1151  }
  1152  
  1153  // Audience
  1154  func (audience audienceOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1155  	s.audience = append(s.audience, audience...)
  1156  
  1157  	return nil
  1158  }
  1159  
  1160  // ID
  1161  type idOption string
  1162  
  1163  func (id idOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1164  	s.id = string(id)
  1165  
  1166  	return nil
  1167  }
  1168  
  1169  func WithID(id string) idOption {
  1170  	return idOption(id)
  1171  }
  1172  
  1173  // TokenTTL
  1174  type tokenTTLOption time.Duration
  1175  
  1176  func (ttl tokenTTLOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1177  	s.tokenTTL = time.Duration(ttl)
  1178  
  1179  	return nil
  1180  }
  1181  
  1182  func WithTokenTTL(ttl time.Duration) tokenTTLOption {
  1183  	return tokenTTLOption(ttl)
  1184  }
  1185  
  1186  // SigningMethod
  1187  type signingMethodOption struct {
  1188  	method jwt.SigningMethod
  1189  }
  1190  
  1191  func (method *signingMethodOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1192  	s.signingMethod = method.method
  1193  
  1194  	return nil
  1195  }
  1196  
  1197  func WithSigningMethod(method jwt.SigningMethod) *signingMethodOption {
  1198  	return &signingMethodOption{method}
  1199  }
  1200  
  1201  // SigningMethodName
  1202  type signingMethodNameOption struct {
  1203  	method string
  1204  }
  1205  
  1206  func (method *signingMethodNameOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1207  	signingMethodDesc, signingMethodFound := signingMethodsRegistry[strings.ToUpper(method.method)]
  1208  	if !signingMethodFound {
  1209  		return xerrors.WithStackTrace(signingMethodNotSupportedError(method.method))
  1210  	}
  1211  
  1212  	s.signingMethod = signingMethodDesc.method
  1213  
  1214  	return nil
  1215  }
  1216  
  1217  func WithSigningMethodName(method string) *signingMethodNameOption {
  1218  	return &signingMethodNameOption{method}
  1219  }
  1220  
  1221  // KeyID
  1222  type keyIDOption string
  1223  
  1224  func (id keyIDOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1225  	s.keyID = string(id)
  1226  
  1227  	return nil
  1228  }
  1229  
  1230  func WithKeyID(id string) keyIDOption {
  1231  	return keyIDOption(id)
  1232  }
  1233  
  1234  // PrivateKey
  1235  type privateKeyOption struct {
  1236  	key interface{}
  1237  }
  1238  
  1239  func (key *privateKeyOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1240  	s.privateKey = key.key
  1241  
  1242  	return nil
  1243  }
  1244  
  1245  func WithPrivateKey(key interface{}) *privateKeyOption {
  1246  	return &privateKeyOption{key}
  1247  }
  1248  
  1249  // PrivateKey
  1250  type rsaPrivateKeyPemContentOption struct {
  1251  	keyContent []byte
  1252  }
  1253  
  1254  func (key *rsaPrivateKeyPemContentOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1255  	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(key.keyContent)
  1256  	if err != nil {
  1257  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotparsePrivateKey, err))
  1258  	}
  1259  	s.privateKey = privateKey
  1260  
  1261  	return nil
  1262  }
  1263  
  1264  func WithRSAPrivateKeyPEMContent(key []byte) *rsaPrivateKeyPemContentOption {
  1265  	return &rsaPrivateKeyPemContentOption{key}
  1266  }
  1267  
  1268  // PrivateKey
  1269  type rsaPrivateKeyPemFileOption struct {
  1270  	path string
  1271  }
  1272  
  1273  func (key *rsaPrivateKeyPemFileOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1274  	bytes, err := readFileContent(key.path)
  1275  	if err != nil {
  1276  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadPrivateKeyFile, err))
  1277  	}
  1278  
  1279  	o := rsaPrivateKeyPemContentOption{bytes}
  1280  
  1281  	return o.ApplyJWTTokenSourceOption(s)
  1282  }
  1283  
  1284  func WithRSAPrivateKeyPEMFile(path string) *rsaPrivateKeyPemFileOption {
  1285  	return &rsaPrivateKeyPemFileOption{path}
  1286  }
  1287  
  1288  // PrivateKey
  1289  type ecPrivateKeyPemContentOption struct {
  1290  	keyContent []byte
  1291  }
  1292  
  1293  func (key *ecPrivateKeyPemContentOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1294  	privateKey, err := jwt.ParseECPrivateKeyFromPEM(key.keyContent)
  1295  	if err != nil {
  1296  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotparsePrivateKey, err))
  1297  	}
  1298  	s.privateKey = privateKey
  1299  
  1300  	return nil
  1301  }
  1302  
  1303  func WithECPrivateKeyPEMContent(key []byte) *ecPrivateKeyPemContentOption {
  1304  	return &ecPrivateKeyPemContentOption{key}
  1305  }
  1306  
  1307  // PrivateKey
  1308  type ecPrivateKeyPemFileOption struct {
  1309  	path string
  1310  }
  1311  
  1312  func (key *ecPrivateKeyPemFileOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1313  	bytes, err := readFileContent(key.path)
  1314  	if err != nil {
  1315  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadPrivateKeyFile, err))
  1316  	}
  1317  
  1318  	o := ecPrivateKeyPemContentOption{bytes}
  1319  
  1320  	return o.ApplyJWTTokenSourceOption(s)
  1321  }
  1322  
  1323  func WithECPrivateKeyPEMFile(path string) *ecPrivateKeyPemFileOption {
  1324  	return &ecPrivateKeyPemFileOption{path}
  1325  }
  1326  
  1327  // Key
  1328  type hmacSecretKeyContentOption struct {
  1329  	keyContent []byte
  1330  }
  1331  
  1332  func (key *hmacSecretKeyContentOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1333  	s.privateKey = key.keyContent
  1334  
  1335  	return nil
  1336  }
  1337  
  1338  func WithHMACSecretKey(key []byte) *hmacSecretKeyContentOption {
  1339  	return &hmacSecretKeyContentOption{key}
  1340  }
  1341  
  1342  // Key
  1343  type hmacSecretKeyBase64ContentOption struct {
  1344  	base64KeyContent string
  1345  }
  1346  
  1347  func (key *hmacSecretKeyBase64ContentOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1348  	keyData, err := base64.StdEncoding.DecodeString(key.base64KeyContent)
  1349  	if err != nil {
  1350  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotParseBase64Secret, err))
  1351  	}
  1352  	s.privateKey = keyData
  1353  
  1354  	return nil
  1355  }
  1356  
  1357  func WithHMACSecretKeyBase64Content(base64KeyContent string) *hmacSecretKeyBase64ContentOption {
  1358  	return &hmacSecretKeyBase64ContentOption{base64KeyContent}
  1359  }
  1360  
  1361  // Key
  1362  type hmacSecretKeyFileOption struct {
  1363  	path string
  1364  }
  1365  
  1366  func (key *hmacSecretKeyFileOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1367  	bytes, err := readFileContent(key.path)
  1368  	if err != nil {
  1369  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadPrivateKeyFile, err))
  1370  	}
  1371  
  1372  	s.privateKey = bytes
  1373  
  1374  	return nil
  1375  }
  1376  
  1377  func WithHMACSecretKeyFile(path string) *hmacSecretKeyFileOption {
  1378  	return &hmacSecretKeyFileOption{path}
  1379  }
  1380  
  1381  // Key
  1382  type hmacSecretKeyBase64FileOption struct {
  1383  	path string
  1384  }
  1385  
  1386  func (key *hmacSecretKeyBase64FileOption) ApplyJWTTokenSourceOption(s *jwtTokenSource) error {
  1387  	bytes, err := readFileContent(key.path)
  1388  	if err != nil {
  1389  		return xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotReadPrivateKeyFile, err))
  1390  	}
  1391  
  1392  	o := hmacSecretKeyBase64ContentOption{string(bytes)}
  1393  
  1394  	return o.ApplyJWTTokenSourceOption(s)
  1395  }
  1396  
  1397  func WithHMACSecretKeyBase64File(path string) *hmacSecretKeyBase64FileOption {
  1398  	return &hmacSecretKeyBase64FileOption{path}
  1399  }
  1400  
  1401  func NewJWTTokenSource(opts ...JWTTokenSourceOption) (*jwtTokenSource, error) {
  1402  	s := &jwtTokenSource{
  1403  		tokenTTL: defaultJWTTokenTTL,
  1404  	}
  1405  
  1406  	var err error
  1407  	for _, opt := range opts {
  1408  		if opt != nil {
  1409  			err = opt.ApplyJWTTokenSourceOption(s)
  1410  			if err != nil {
  1411  				return nil, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotApplyJWTOption, err))
  1412  			}
  1413  		}
  1414  	}
  1415  
  1416  	if s.signingMethod == nil {
  1417  		return nil, xerrors.WithStackTrace(errNoSigningMethodError)
  1418  	}
  1419  
  1420  	if s.privateKey == nil {
  1421  		return nil, xerrors.WithStackTrace(errNoPrivateKeyError)
  1422  	}
  1423  
  1424  	return s, nil
  1425  }
  1426  
// jwtTokenSource is a TokenSource that signs a fresh JWT on every Token call.
// Construct it with NewJWTTokenSource, which rejects configurations that lack a
// signing method or a key.
type jwtTokenSource struct {
	signingMethod jwt.SigningMethod
	// keyID is written to the "kid" JWT header.
	keyID         string
	privateKey    interface{} // signing key; for symmetric (HMAC) algorithms the raw secret bytes

	// JWT claims
	issuer   string
	subject  string
	audience []string
	id       string
	// tokenTTL is the lifetime of each JWT: ExpiresAt = signing time + tokenTTL.
	tokenTTL time.Duration
}
  1439  
  1440  func (s *jwtTokenSource) Token() (Token, error) {
  1441  	var (
  1442  		now    = time.Now()
  1443  		issued = jwt.NewNumericDate(now.UTC())
  1444  		expire = jwt.NewNumericDate(now.Add(s.tokenTTL).UTC())
  1445  		err    error
  1446  	)
  1447  	t := jwt.Token{
  1448  		Header: map[string]interface{}{
  1449  			"typ": "JWT",
  1450  			"alg": s.signingMethod.Alg(),
  1451  			"kid": s.keyID,
  1452  		},
  1453  		Claims: jwt.RegisteredClaims{
  1454  			Issuer:    s.issuer,
  1455  			Subject:   s.subject,
  1456  			IssuedAt:  issued,
  1457  			Audience:  s.audience,
  1458  			ExpiresAt: expire,
  1459  			ID:        s.id,
  1460  		},
  1461  		Method: s.signingMethod,
  1462  	}
  1463  
  1464  	var token Token
  1465  	token.Token, err = t.SignedString(s.privateKey)
  1466  	if err != nil {
  1467  		return token, xerrors.WithStackTrace(fmt.Errorf("%w: %w", errCouldNotSignJWTToken, err))
  1468  	}
  1469  	token.TokenType = "urn:ietf:params:oauth:token-type:jwt"
  1470  
  1471  	return token, nil
  1472  }
  1473  
  1474  func (s *jwtTokenSource) String() string {
  1475  	buffer := xstring.Buffer()
  1476  	defer buffer.Free()
  1477  	fmt.Fprintf(
  1478  		buffer,
  1479  		"JWTTokenSource{Method:%s,KeyID:%s,Issuer:%q,Subject:%q,Audience:%v,ID:%s,TokenTTL:%s}",
  1480  		s.signingMethod.Alg(),
  1481  		s.keyID,
  1482  		s.issuer,
  1483  		s.subject,
  1484  		s.audience,
  1485  		s.id,
  1486  		s.tokenTTL,
  1487  	)
  1488  
  1489  	return buffer.String()
  1490  }