github.com/Axway/agent-sdk@v1.1.101/pkg/authz/oauth/provider.go (about)

     1  package oauth
     2  
import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/Axway/agent-sdk/pkg/api"
	coreapi "github.com/Axway/agent-sdk/pkg/api"
	corecfg "github.com/Axway/agent-sdk/pkg/config"
	"github.com/Axway/agent-sdk/pkg/util/log"
)
    16  
    17  // ProviderType - type of provider
    18  type ProviderType int
    19  
    20  // Provider - interface for external IdP provider
    21  type Provider interface {
    22  	GetName() string
    23  	GetTitle() string
    24  	GetIssuer() string
    25  	GetTokenEndpoint() string
    26  	GetMTLSTokenEndpoint() string
    27  	GetAuthorizationEndpoint() string
    28  	GetSupportedScopes() []string
    29  	GetSupportedGrantTypes() []string
    30  	GetSupportedTokenAuthMethods() []string
    31  	GetSupportedResponseMethod() []string
    32  	RegisterClient(clientMetadata ClientMetadata) (ClientMetadata, error)
    33  	UnregisterClient(clientID, accessToken string) error
    34  	Validate() error
    35  	GetConfig() corecfg.IDPConfig
    36  	GetMetadata() *AuthorizationServerMetadata
    37  }
    38  
    39  type provider struct {
    40  	logger             log.FieldLogger
    41  	cfg                corecfg.IDPConfig
    42  	metadataURL        string
    43  	extraProperties    map[string]string
    44  	requestHeaders     map[string]string
    45  	queryParameters    map[string]string
    46  	apiClient          coreapi.Client
    47  	authServerMetadata *AuthorizationServerMetadata
    48  	authClient         AuthClient
    49  	idpType            typedIDP
    50  }
    51  
    52  type typedIDP interface {
    53  	getAuthorizationHeaderPrefix() string
    54  	preProcessClientRequest(clientRequest *clientMetadata)
    55  }
    56  
    57  type providerOptions struct {
    58  	authServerMetadata *AuthorizationServerMetadata
    59  }
    60  
    61  func WithAuthServerMetadata(metadata *AuthorizationServerMetadata) func(*providerOptions) {
    62  	return func(p *providerOptions) {
    63  		p.authServerMetadata = metadata
    64  	}
    65  }
    66  
    67  // NewProvider - create a new IdP provider
    68  func NewProvider(idp corecfg.IDPConfig, tlsCfg corecfg.TLSConfig, proxyURL string, clientTimeout time.Duration, opts ...func(*providerOptions)) (Provider, error) {
    69  	logger := log.NewFieldLogger().
    70  		WithComponent("provider").
    71  		WithPackage("sdk.agent.authz.oauth")
    72  
    73  	pOpts := &providerOptions{}
    74  	for _, opt := range opts {
    75  		opt(pOpts)
    76  	}
    77  
    78  	apiClient := coreapi.NewClient(tlsCfg, proxyURL, coreapi.WithTimeout(clientTimeout))
    79  	var idpType typedIDP
    80  	switch idp.GetIDPType() {
    81  	case TypeOkta:
    82  		idpType = &okta{}
    83  	default: // keycloak, generic
    84  		idpType = &genericIDP{}
    85  	}
    86  
    87  	p := &provider{
    88  		logger:             logger,
    89  		metadataURL:        idp.GetMetadataURL(),
    90  		cfg:                idp,
    91  		extraProperties:    idp.GetExtraProperties(),
    92  		requestHeaders:     idp.GetRequestHeaders(),
    93  		queryParameters:    idp.GetQueryParams(),
    94  		apiClient:          apiClient,
    95  		idpType:            idpType,
    96  		authServerMetadata: pOpts.authServerMetadata,
    97  	}
    98  
    99  	if p.authServerMetadata == nil {
   100  		metadata, err := p.fetchMetadata()
   101  		if err != nil {
   102  			p.logger.
   103  				WithField("provider", p.cfg.GetIDPName()).
   104  				WithField("type", p.cfg.GetIDPType()).
   105  				WithField("metadata-url", p.metadataURL).
   106  				WithError(err).
   107  				Error("unable to fetch OAuth authorization server metadata")
   108  			return nil, err
   109  		}
   110  
   111  		p.authServerMetadata = metadata
   112  	}
   113  
   114  	// No OAuth client is needed to request token for access token based authentication to IdP
   115  	if p.cfg.GetAuthConfig() != nil && p.cfg.GetAuthConfig().GetType() != corecfg.AccessToken {
   116  		authClient, err := p.createAuthClient()
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  		p.authClient = authClient
   121  	}
   122  	return p, nil
   123  }
   124  
   125  func FetchMetadata(apiClient api.Client, metadataURL string) (*AuthorizationServerMetadata, error) {
   126  	if apiClient == nil || metadataURL == "" {
   127  		return nil, errors.New("unexpected arguments")
   128  	}
   129  	request := coreapi.Request{
   130  		Method: coreapi.GET,
   131  		URL:    metadataURL,
   132  	}
   133  
   134  	response, err := apiClient.Send(request)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  
   139  	if response.Code == http.StatusOK {
   140  		authSrvMetadata := &AuthorizationServerMetadata{}
   141  		err = json.Unmarshal(response.Body, authSrvMetadata)
   142  		return authSrvMetadata, err
   143  	}
   144  	return nil, fmt.Errorf("error fetching metadata status code: %d, body: %s", response.Code, string(response.Body))
   145  
   146  }
   147  
   148  func (p *provider) fetchMetadata() (*AuthorizationServerMetadata, error) {
   149  	return FetchMetadata(p.apiClient, p.metadataURL)
   150  }
   151  
   152  func (p *provider) createAuthClient() (AuthClient, error) {
   153  	switch p.cfg.GetAuthConfig().GetType() {
   154  	case corecfg.Client:
   155  		fallthrough
   156  	case corecfg.ClientSecretPost:
   157  		return p.createClientSecretPostAuthClient()
   158  	case corecfg.ClientSecretBasic:
   159  		return p.createClientSecretBasicAuthClient()
   160  	case corecfg.ClientSecretJWT:
   161  		return p.createClientSecretJWTAuthClient()
   162  	case corecfg.PrivateKeyJWT:
   163  		return p.createPrivateKeyJWTAuthClient()
   164  	case corecfg.TLSClientAuth:
   165  		fallthrough
   166  	case corecfg.SelfSignedTLSClientAuth:
   167  		return p.createTLSAuthClient()
   168  	default:
   169  		return nil, fmt.Errorf("%s", "unknown IdP auth type")
   170  	}
   171  }
   172  
   173  func (p *provider) createClientSecretPostAuthClient() (AuthClient, error) {
   174  	return NewAuthClient(p.GetTokenEndpoint(), p.apiClient,
   175  		WithServerName(p.cfg.GetIDPName()),
   176  		WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()),
   177  		WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()),
   178  		WithClientSecretPostAuth(p.cfg.GetAuthConfig().GetClientID(), p.cfg.GetAuthConfig().GetClientSecret(), p.cfg.GetAuthConfig().GetClientScope()))
   179  }
   180  
   181  func (p *provider) createClientSecretBasicAuthClient() (AuthClient, error) {
   182  	return NewAuthClient(p.GetTokenEndpoint(), p.apiClient,
   183  		WithServerName(p.cfg.GetIDPName()),
   184  		WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()),
   185  		WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()),
   186  		WithClientSecretBasicAuth(p.cfg.GetAuthConfig().GetClientID(), p.cfg.GetAuthConfig().GetClientSecret(), p.cfg.GetAuthConfig().GetClientScope()))
   187  }
   188  
   189  func (p *provider) createClientSecretJWTAuthClient() (AuthClient, error) {
   190  	return NewAuthClient(p.GetTokenEndpoint(), p.apiClient,
   191  		WithServerName(p.cfg.GetIDPName()),
   192  		WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()),
   193  		WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()),
   194  		WithClientSecretJwtAuth(
   195  			p.cfg.GetAuthConfig().GetClientID(),
   196  			p.cfg.GetAuthConfig().GetClientSecret(),
   197  			p.cfg.GetAuthConfig().GetClientScope(),
   198  			p.cfg.GetAuthConfig().GetClientID(),
   199  			p.authServerMetadata.Issuer,
   200  			p.cfg.GetAuthConfig().GetTokenSigningMethod(),
   201  		))
   202  }
   203  
   204  func (p *provider) createPrivateKeyJWTAuthClient() (AuthClient, error) {
   205  	keyReader := NewKeyReader(
   206  		p.cfg.GetAuthConfig().GetPrivateKey(),
   207  		p.cfg.GetAuthConfig().GetPublicKey(),
   208  		p.cfg.GetAuthConfig().GetKeyPassword(),
   209  	)
   210  	privateKey, keyErr := keyReader.GetPrivateKey()
   211  	if keyErr != nil {
   212  		return nil, keyErr
   213  	}
   214  
   215  	publicKey, keyErr := keyReader.GetPublicKey()
   216  	if keyErr != nil {
   217  		return nil, keyErr
   218  	}
   219  	return NewAuthClient(p.GetTokenEndpoint(), p.apiClient,
   220  		WithServerName(p.cfg.GetIDPName()),
   221  		WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()),
   222  		WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()),
   223  		WithKeyPairAuth(
   224  			p.cfg.GetAuthConfig().GetClientID(),
   225  			p.cfg.GetAuthConfig().GetClientID(),
   226  			p.authServerMetadata.Issuer,
   227  			privateKey,
   228  			publicKey,
   229  			p.cfg.GetAuthConfig().GetClientScope(),
   230  			p.cfg.GetAuthConfig().GetTokenSigningMethod(),
   231  		),
   232  	)
   233  }
   234  
   235  func (p *provider) createTLSAuthClient() (AuthClient, error) {
   236  	return NewAuthClient(p.GetMTLSTokenEndpoint(), p.apiClient,
   237  		WithServerName(p.cfg.GetIDPName()),
   238  		WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()),
   239  		WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()),
   240  		WithTLSClientAuth(p.cfg.GetAuthConfig().GetClientID(), p.cfg.GetAuthConfig().GetClientScope()))
   241  }
   242  
   243  // GetName - returns the name of the provider
   244  func (p *provider) GetName() string {
   245  	return p.cfg.GetIDPName()
   246  }
   247  
   248  // GetTitle - returns the friendly name of the provider
   249  func (p *provider) GetTitle() string {
   250  	return p.cfg.GetIDPTitle()
   251  }
   252  
   253  // GetIssuer - returns the issuer for the provider
   254  func (p *provider) GetIssuer() string {
   255  	if p.authServerMetadata != nil {
   256  		return p.authServerMetadata.Issuer
   257  	}
   258  	return ""
   259  }
   260  
   261  func (p *provider) useTLSAuth() bool {
   262  	if p.cfg.GetAuthConfig() == nil {
   263  		return false
   264  	}
   265  	return p.cfg.GetAuthConfig().GetType() == corecfg.TLSClientAuth || p.cfg.GetAuthConfig().GetType() == corecfg.SelfSignedTLSClientAuth
   266  }
   267  
   268  // GetTokenEndpoint - return the token endpoint URL
   269  func (p *provider) GetTokenEndpoint() string {
   270  	return p.authServerMetadata.TokenEndpoint
   271  }
   272  
   273  func (p *provider) GetMTLSTokenEndpoint() string {
   274  	if p.authServerMetadata != nil {
   275  		if p.authServerMetadata.MTLSEndPointAlias != nil && p.authServerMetadata.MTLSEndPointAlias.TokenEndpoint != "" {
   276  			return p.authServerMetadata.MTLSEndPointAlias.TokenEndpoint
   277  		}
   278  		return p.authServerMetadata.TokenEndpoint
   279  	}
   280  	return ""
   281  }
   282  
   283  // GetAuthorizationEndpoint - return authorization endpoint
   284  func (p *provider) GetAuthorizationEndpoint() string {
   285  	if p.authServerMetadata != nil {
   286  		return p.authServerMetadata.AuthorizationEndpoint
   287  	}
   288  	return ""
   289  }
   290  
   291  // GetSupportedScopes - returns the global scopes supported by provider
   292  func (p *provider) GetSupportedScopes() []string {
   293  	if p.authServerMetadata != nil {
   294  		return p.authServerMetadata.ScopesSupported
   295  	}
   296  	return []string{""}
   297  }
   298  
   299  // GetSupportedGrantTypes - returns the grant type supported by provider
   300  func (p *provider) GetSupportedGrantTypes() []string {
   301  	if p.authServerMetadata != nil {
   302  		return p.authServerMetadata.GrantTypesSupported
   303  	}
   304  	return []string{""}
   305  }
   306  
   307  // GetSupportedTokenAuthMethods - returns the token auth method supported by provider
   308  func (p *provider) GetSupportedTokenAuthMethods() []string {
   309  	if p.authServerMetadata != nil {
   310  		return p.authServerMetadata.TokenEndpointAuthMethodSupported
   311  	}
   312  	return []string{""}
   313  
   314  }
   315  
   316  // GetSupportedResponseMethod - returns the token response method supported by provider
   317  func (p *provider) GetSupportedResponseMethod() []string {
   318  	if p.authServerMetadata != nil {
   319  		return p.authServerMetadata.ResponseTypesSupported
   320  	}
   321  	return []string{""}
   322  }
   323  
   324  func (p *provider) getClientRegistrationEndpoint() string {
   325  	registrationEndpoint := p.authServerMetadata.RegistrationEndpoint
   326  	if p.useTLSAuth() &&
   327  		p.authServerMetadata.MTLSEndPointAlias != nil && p.authServerMetadata.MTLSEndPointAlias.RegistrationEndpoint != "" {
   328  		registrationEndpoint = p.authServerMetadata.MTLSEndPointAlias.RegistrationEndpoint
   329  	}
   330  	return registrationEndpoint
   331  }
   332  
   333  func (p *provider) prepareHeaders(authPrefix, token string) map[string]string {
   334  	headers := make(map[string]string)
   335  	for key, value := range p.requestHeaders {
   336  		headers[key] = value
   337  	}
   338  	headers[hdrAuthorization] = authPrefix + " " + token
   339  	headers[hdrContentType] = mimeApplicationJSON
   340  	return headers
   341  }
   342  
   343  // RegisterClient - register the OAuth client with IDP
   344  func (p *provider) RegisterClient(clientReq ClientMetadata) (ClientMetadata, error) {
   345  	authPrefix := p.idpType.getAuthorizationHeaderPrefix()
   346  	err := p.enrichClientReq(clientReq)
   347  	if err != nil {
   348  		return nil, err
   349  	}
   350  
   351  	clientBuffer, err := json.Marshal(clientReq)
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	token, err := p.getClientToken()
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  
   361  	request := coreapi.Request{
   362  		Method:      coreapi.POST,
   363  		URL:         p.getClientRegistrationEndpoint(),
   364  		QueryParams: p.queryParameters,
   365  		Headers:     p.prepareHeaders(authPrefix, token),
   366  		Body:        clientBuffer,
   367  	}
   368  
   369  	response, err := p.apiClient.Send(request)
   370  	if err != nil {
   371  		return nil, err
   372  	}
   373  
   374  	if response.Code == http.StatusCreated || response.Code == http.StatusOK {
   375  		clientRes := &clientMetadata{}
   376  		err = json.Unmarshal(response.Body, clientRes)
   377  		if !p.cfg.GetAuthConfig().UseRegistrationAccessToken() {
   378  			clientRes.RegistrationAccessToken = ""
   379  		}
   380  
   381  		p.logger.
   382  			WithField("provider", p.cfg.GetIDPName()).
   383  			WithField("client-name", clientReq.GetClientName()).
   384  			WithField("client-id", clientReq.GetClientName()).
   385  			WithField("grant-type", clientReq.GetGrantTypes()).
   386  			WithField("token-auth-method", clientReq.GetTokenEndpointAuthMethod()).
   387  			WithField("response-type", clientReq.GetResponseTypes()).
   388  			WithField("redirect-url", clientReq.GetRedirectURIs()).
   389  			Info("registered client")
   390  		return clientRes, err
   391  	}
   392  
   393  	err = fmt.Errorf("error status code: %d, body: %s", response.Code, string(response.Body))
   394  	p.logger.
   395  		WithField("provider", p.cfg.GetIDPName()).
   396  		WithField("client-name", clientReq.GetClientName()).
   397  		WithField("grant-type", clientReq.GetGrantTypes()).
   398  		WithField("token-auth-method", clientReq.GetTokenEndpointAuthMethod()).
   399  		WithField("response-type", clientReq.GetResponseTypes()).
   400  		WithField("redirect-url", clientReq.GetRedirectURIs()).
   401  		Error(err.Error())
   402  
   403  	return nil, err
   404  }
   405  
   406  func (p *provider) enrichClientReq(clientReq ClientMetadata) error {
   407  	clientRequest, ok := clientReq.(*clientMetadata)
   408  	if !ok {
   409  		return fmt.Errorf("unrecognized client request metadata")
   410  	}
   411  
   412  	p.applyClientDefaults(clientRequest)
   413  
   414  	clientRequest.extraProperties = p.extraProperties
   415  
   416  	p.idpType.preProcessClientRequest(clientRequest)
   417  	p.preProcessResponseType(clientRequest)
   418  	return nil
   419  }
   420  
   421  func (p *provider) applyClientDefaults(clientRequest *clientMetadata) {
   422  	// Default the values from config if not set on the request
   423  	if len(clientRequest.GetScopes()) == 0 {
   424  		clientRequest.Scope = strings.Split(p.cfg.GetClientScopes(), " ")
   425  	}
   426  
   427  	if len(clientRequest.GetGrantTypes()) == 0 {
   428  		clientRequest.GrantTypes = []string{p.cfg.GetGrantType()}
   429  	}
   430  
   431  	if clientRequest.TokenEndpointAuthMethod == "" {
   432  		clientRequest.TokenEndpointAuthMethod = p.cfg.GetAuthMethod()
   433  	}
   434  }
   435  
   436  func (p *provider) preProcessResponseType(clientRequest *clientMetadata) {
   437  	for _, grantTypes := range clientRequest.GrantTypes {
   438  		switch grantTypes {
   439  		case GrantTypeAuthorizationCode:
   440  			if !hasResponseType(clientRequest, AuthResponseCode) {
   441  				addResponseType(clientRequest, AuthResponseCode)
   442  			}
   443  		case GrantTypeImplicit:
   444  			if !hasResponseType(clientRequest, AuthResponseToken) {
   445  				addResponseType(clientRequest, AuthResponseToken)
   446  			}
   447  		}
   448  	}
   449  }
   450  
   451  func hasResponseType(clientRequest *clientMetadata, responseType string) bool {
   452  	for _, clientResponseType := range clientRequest.ResponseTypes {
   453  		if clientResponseType == responseType {
   454  			return true
   455  		}
   456  	}
   457  	return false
   458  }
   459  
   460  func addResponseType(clientRequest *clientMetadata, responseType string) {
   461  	if clientRequest.ResponseTypes == nil {
   462  		clientRequest.ResponseTypes = make([]string, 0)
   463  	}
   464  	clientRequest.ResponseTypes = append(clientRequest.ResponseTypes, responseType)
   465  }
   466  
   467  // UnregisterClient - removes the OAuth client from IDP
   468  func (p *provider) UnregisterClient(clientID, accessToken string) error {
   469  	authPrefix := p.idpType.getAuthorizationHeaderPrefix()
   470  	if accessToken == "" {
   471  		token, err := p.getClientToken()
   472  		if err != nil {
   473  			return err
   474  		}
   475  		accessToken = token
   476  	}
   477  
   478  	request := coreapi.Request{
   479  		Method:      coreapi.DELETE,
   480  		URL:         p.getClientRegistrationEndpoint() + "/" + clientID,
   481  		QueryParams: p.queryParameters,
   482  		Headers:     p.prepareHeaders(authPrefix, accessToken),
   483  	}
   484  
   485  	response, err := p.apiClient.Send(request)
   486  	if err != nil {
   487  		return err
   488  	}
   489  
   490  	if response.Code != http.StatusNoContent {
   491  		err := fmt.Errorf("error status code: %d, body: %s", response.Code, string(response.Body))
   492  		p.logger.
   493  			WithField("provider", p.cfg.GetIDPName()).
   494  			WithField("client-id", clientID).
   495  			Error(err.Error())
   496  		return err
   497  	}
   498  
   499  	p.logger.
   500  		WithField("provider", p.cfg.GetIDPName()).
   501  		WithField("client-id", clientID).
   502  		Info("unregistered client")
   503  	return nil
   504  }
   505  
   506  func (p *provider) getClientToken() (string, error) {
   507  	if p.authClient != nil {
   508  		useTokenCache := p.cfg.GetAuthConfig().UseTokenCache()
   509  		return p.authClient.FetchToken(useTokenCache)
   510  	}
   511  	return p.cfg.GetAuthConfig().GetAccessToken(), nil
   512  }
   513  
   514  func (p *provider) GetConfig() corecfg.IDPConfig {
   515  	return p.cfg
   516  }
   517  
   518  func (p *provider) GetMetadata() *AuthorizationServerMetadata {
   519  	return p.authServerMetadata
   520  }
   521  
   522  func (p *provider) Validate() error {
   523  	// Validate fetching token using client id/secret with oauth flow
   524  	// how to validate accessToken
   525  	// validate if the auth used has authorization?
   526  	_, err := p.getClientToken()
   527  	if err != nil {
   528  		return err
   529  	}
   530  	return nil
   531  }