github.com/Axway/agent-sdk@v1.1.101/pkg/authz/oauth/provider.go (about) 1 package oauth 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "net/http" 8 "strings" 9 "time" 10 11 "github.com/Axway/agent-sdk/pkg/api" 12 coreapi "github.com/Axway/agent-sdk/pkg/api" 13 corecfg "github.com/Axway/agent-sdk/pkg/config" 14 "github.com/Axway/agent-sdk/pkg/util/log" 15 ) 16 17 // ProviderType - type of provider 18 type ProviderType int 19 20 // Provider - interface for external IdP provider 21 type Provider interface { 22 GetName() string 23 GetTitle() string 24 GetIssuer() string 25 GetTokenEndpoint() string 26 GetMTLSTokenEndpoint() string 27 GetAuthorizationEndpoint() string 28 GetSupportedScopes() []string 29 GetSupportedGrantTypes() []string 30 GetSupportedTokenAuthMethods() []string 31 GetSupportedResponseMethod() []string 32 RegisterClient(clientMetadata ClientMetadata) (ClientMetadata, error) 33 UnregisterClient(clientID, accessToken string) error 34 Validate() error 35 GetConfig() corecfg.IDPConfig 36 GetMetadata() *AuthorizationServerMetadata 37 } 38 39 type provider struct { 40 logger log.FieldLogger 41 cfg corecfg.IDPConfig 42 metadataURL string 43 extraProperties map[string]string 44 requestHeaders map[string]string 45 queryParameters map[string]string 46 apiClient coreapi.Client 47 authServerMetadata *AuthorizationServerMetadata 48 authClient AuthClient 49 idpType typedIDP 50 } 51 52 type typedIDP interface { 53 getAuthorizationHeaderPrefix() string 54 preProcessClientRequest(clientRequest *clientMetadata) 55 } 56 57 type providerOptions struct { 58 authServerMetadata *AuthorizationServerMetadata 59 } 60 61 func WithAuthServerMetadata(metadata *AuthorizationServerMetadata) func(*providerOptions) { 62 return func(p *providerOptions) { 63 p.authServerMetadata = metadata 64 } 65 } 66 67 // NewProvider - create a new IdP provider 68 func NewProvider(idp corecfg.IDPConfig, tlsCfg corecfg.TLSConfig, proxyURL string, clientTimeout time.Duration, opts ...func(*providerOptions)) (Provider, error) { 69 logger := log.NewFieldLogger(). 70 WithComponent("provider"). 71 WithPackage("sdk.agent.authz.oauth") 72 73 pOpts := &providerOptions{} 74 for _, opt := range opts { 75 opt(pOpts) 76 } 77 78 apiClient := coreapi.NewClient(tlsCfg, proxyURL, coreapi.WithTimeout(clientTimeout)) 79 var idpType typedIDP 80 switch idp.GetIDPType() { 81 case TypeOkta: 82 idpType = &okta{} 83 default: // keycloak, generic 84 idpType = &genericIDP{} 85 } 86 87 p := &provider{ 88 logger: logger, 89 metadataURL: idp.GetMetadataURL(), 90 cfg: idp, 91 extraProperties: idp.GetExtraProperties(), 92 requestHeaders: idp.GetRequestHeaders(), 93 queryParameters: idp.GetQueryParams(), 94 apiClient: apiClient, 95 idpType: idpType, 96 authServerMetadata: pOpts.authServerMetadata, 97 } 98 99 if p.authServerMetadata == nil { 100 metadata, err := p.fetchMetadata() 101 if err != nil { 102 p.logger. 103 WithField("provider", p.cfg.GetIDPName()). 104 WithField("type", p.cfg.GetIDPType()). 105 WithField("metadata-url", p.metadataURL). 106 WithError(err). 107 Error("unable to fetch OAuth authorization server metadata") 108 return nil, err 109 } 110 111 p.authServerMetadata = metadata 112 } 113 114 // No OAuth client is needed to request token for access token based authentication to IdP 115 if p.cfg.GetAuthConfig() != nil && p.cfg.GetAuthConfig().GetType() != corecfg.AccessToken { 116 authClient, err := p.createAuthClient() 117 if err != nil { 118 return nil, err 119 } 120 p.authClient = authClient 121 } 122 return p, nil 123 } 124 125 func FetchMetadata(apiClient api.Client, metadataURL string) (*AuthorizationServerMetadata, error) { 126 if apiClient == nil || metadataURL == "" { 127 return nil, errors.New("unexpected arguments") 128 } 129 request := coreapi.Request{ 130 Method: coreapi.GET, 131 URL: metadataURL, 132 } 133 134 response, err := apiClient.Send(request) 135 if err != nil { 136 return nil, err 137 } 138 139 if response.Code == http.StatusOK { 140 authSrvMetadata := &AuthorizationServerMetadata{} 141 err = json.Unmarshal(response.Body, authSrvMetadata) 142 return authSrvMetadata, err 143 } 144 return nil, fmt.Errorf("error fetching metadata status code: %d, body: %s", response.Code, string(response.Body)) 145 146 } 147 148 func (p *provider) fetchMetadata() (*AuthorizationServerMetadata, error) { 149 return FetchMetadata(p.apiClient, p.metadataURL) 150 } 151 152 func (p *provider) createAuthClient() (AuthClient, error) { 153 switch p.cfg.GetAuthConfig().GetType() { 154 case corecfg.Client: 155 fallthrough 156 case corecfg.ClientSecretPost: 157 return p.createClientSecretPostAuthClient() 158 case corecfg.ClientSecretBasic: 159 return p.createClientSecretBasicAuthClient() 160 case corecfg.ClientSecretJWT: 161 return p.createClientSecretJWTAuthClient() 162 case corecfg.PrivateKeyJWT: 163 return p.createPrivateKeyJWTAuthClient() 164 case corecfg.TLSClientAuth: 165 fallthrough 166 case corecfg.SelfSignedTLSClientAuth: 167 return p.createTLSAuthClient() 168 default: 169 return nil, fmt.Errorf("%s", "unknown IdP auth type") 170 } 171 } 172 173 func (p *provider) createClientSecretPostAuthClient() (AuthClient, error) { 174 return NewAuthClient(p.GetTokenEndpoint(), p.apiClient, 175 WithServerName(p.cfg.GetIDPName()), 176 WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()), 177 WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()), 178 WithClientSecretPostAuth(p.cfg.GetAuthConfig().GetClientID(), p.cfg.GetAuthConfig().GetClientSecret(), p.cfg.GetAuthConfig().GetClientScope())) 179 } 180 181 func (p *provider) createClientSecretBasicAuthClient() (AuthClient, error) { 182 return NewAuthClient(p.GetTokenEndpoint(), p.apiClient, 183 WithServerName(p.cfg.GetIDPName()), 184 WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()), 185 WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()), 186 WithClientSecretBasicAuth(p.cfg.GetAuthConfig().GetClientID(), p.cfg.GetAuthConfig().GetClientSecret(), p.cfg.GetAuthConfig().GetClientScope())) 187 } 188 189 func (p *provider) createClientSecretJWTAuthClient() (AuthClient, error) { 190 return NewAuthClient(p.GetTokenEndpoint(), p.apiClient, 191 WithServerName(p.cfg.GetIDPName()), 192 WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()), 193 WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()), 194 WithClientSecretJwtAuth( 195 p.cfg.GetAuthConfig().GetClientID(), 196 p.cfg.GetAuthConfig().GetClientSecret(), 197 p.cfg.GetAuthConfig().GetClientScope(), 198 p.cfg.GetAuthConfig().GetClientID(), 199 p.authServerMetadata.Issuer, 200 p.cfg.GetAuthConfig().GetTokenSigningMethod(), 201 )) 202 } 203 204 func (p *provider) createPrivateKeyJWTAuthClient() (AuthClient, error) { 205 keyReader := NewKeyReader( 206 p.cfg.GetAuthConfig().GetPrivateKey(), 207 p.cfg.GetAuthConfig().GetPublicKey(), 208 p.cfg.GetAuthConfig().GetKeyPassword(), 209 ) 210 privateKey, keyErr := keyReader.GetPrivateKey() 211 if keyErr != nil { 212 return nil, keyErr 213 } 214 215 publicKey, keyErr := keyReader.GetPublicKey() 216 if keyErr != nil { 217 return nil, keyErr 218 } 219 return NewAuthClient(p.GetTokenEndpoint(), p.apiClient, 220 WithServerName(p.cfg.GetIDPName()), 221 WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()), 222 WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()), 223 WithKeyPairAuth( 224 p.cfg.GetAuthConfig().GetClientID(), 225 p.cfg.GetAuthConfig().GetClientID(), 226 p.authServerMetadata.Issuer, 227 privateKey, 228 publicKey, 229 p.cfg.GetAuthConfig().GetClientScope(), 230 p.cfg.GetAuthConfig().GetTokenSigningMethod(), 231 ), 232 ) 233 } 234 235 func (p *provider) createTLSAuthClient() (AuthClient, error) { 236 return NewAuthClient(p.GetMTLSTokenEndpoint(), p.apiClient, 237 WithServerName(p.cfg.GetIDPName()), 238 WithRequestHeaders(p.cfg.GetAuthConfig().GetRequestHeaders()), 239 WithQueryParams(p.cfg.GetAuthConfig().GetQueryParams()), 240 WithTLSClientAuth(p.cfg.GetAuthConfig().GetClientID(), p.cfg.GetAuthConfig().GetClientScope())) 241 } 242 243 // GetName - returns the name of the provider 244 func (p *provider) GetName() string { 245 return p.cfg.GetIDPName() 246 } 247 248 // GetTitle - returns the friendly name of the provider 249 func (p *provider) GetTitle() string { 250 return p.cfg.GetIDPTitle() 251 } 252 253 // GetIssuer - returns the issuer for the provider 254 func (p *provider) GetIssuer() string { 255 if p.authServerMetadata != nil { 256 return p.authServerMetadata.Issuer 257 } 258 return "" 259 } 260 261 func (p *provider) useTLSAuth() bool { 262 if p.cfg.GetAuthConfig() == nil { 263 return false 264 } 265 return p.cfg.GetAuthConfig().GetType() == corecfg.TLSClientAuth || p.cfg.GetAuthConfig().GetType() == corecfg.SelfSignedTLSClientAuth 266 } 267 268 // GetTokenEndpoint - return the token endpoint URL 269 func (p *provider) GetTokenEndpoint() string { 270 return p.authServerMetadata.TokenEndpoint 271 } 272 273 func (p *provider) GetMTLSTokenEndpoint() string { 274 if p.authServerMetadata != nil { 275 if p.authServerMetadata.MTLSEndPointAlias != nil && p.authServerMetadata.MTLSEndPointAlias.TokenEndpoint != "" { 276 return p.authServerMetadata.MTLSEndPointAlias.TokenEndpoint 277 } 278 return p.authServerMetadata.TokenEndpoint 279 } 280 return "" 281 } 282 283 // GetAuthorizationEndpoint - return authorization endpoint 284 func (p *provider) GetAuthorizationEndpoint() string { 285 if p.authServerMetadata != nil { 286 return p.authServerMetadata.AuthorizationEndpoint 287 } 288 return "" 289 } 290 291 // GetSupportedScopes - returns the global scopes supported by provider 292 func (p *provider) GetSupportedScopes() []string { 293 if p.authServerMetadata != nil { 294 return p.authServerMetadata.ScopesSupported 295 } 296 return []string{""} 297 } 298 299 // GetSupportedGrantTypes - returns the grant type supported by provider 300 func (p *provider) GetSupportedGrantTypes() []string { 301 if p.authServerMetadata != nil { 302 return p.authServerMetadata.GrantTypesSupported 303 } 304 return []string{""} 305 } 306 307 // GetSupportedTokenAuthMethods - returns the token auth method supported by provider 308 func (p *provider) GetSupportedTokenAuthMethods() []string { 309 if p.authServerMetadata != nil { 310 return p.authServerMetadata.TokenEndpointAuthMethodSupported 311 } 312 return []string{""} 313 314 } 315 316 // GetSupportedResponseMethod - returns the token response method supported by provider 317 func (p *provider) GetSupportedResponseMethod() []string { 318 if p.authServerMetadata != nil { 319 return p.authServerMetadata.ResponseTypesSupported 320 } 321 return []string{""} 322 } 323 324 func (p *provider) getClientRegistrationEndpoint() string { 325 registrationEndpoint := p.authServerMetadata.RegistrationEndpoint 326 if p.useTLSAuth() && 327 p.authServerMetadata.MTLSEndPointAlias != nil && p.authServerMetadata.MTLSEndPointAlias.RegistrationEndpoint != "" { 328 registrationEndpoint = p.authServerMetadata.MTLSEndPointAlias.RegistrationEndpoint 329 } 330 return registrationEndpoint 331 } 332 333 func (p *provider) prepareHeaders(authPrefix, token string) map[string]string { 334 headers := make(map[string]string) 335 for key, value := range p.requestHeaders { 336 headers[key] = value 337 } 338 headers[hdrAuthorization] = authPrefix + " " + token 339 headers[hdrContentType] = mimeApplicationJSON 340 return headers 341 } 342 343 // RegisterClient - register the OAuth client with IDP 344 func (p *provider) RegisterClient(clientReq ClientMetadata) (ClientMetadata, error) { 345 authPrefix := p.idpType.getAuthorizationHeaderPrefix() 346 err := p.enrichClientReq(clientReq) 347 if err != nil { 348 return nil, err 349 } 350 351 clientBuffer, err := json.Marshal(clientReq) 352 if err != nil { 353 return nil, err 354 } 355 356 token, err := p.getClientToken() 357 if err != nil { 358 return nil, err 359 } 360 361 request := coreapi.Request{ 362 Method: coreapi.POST, 363 URL: p.getClientRegistrationEndpoint(), 364 QueryParams: p.queryParameters, 365 Headers: p.prepareHeaders(authPrefix, token), 366 Body: clientBuffer, 367 } 368 369 response, err := p.apiClient.Send(request) 370 if err != nil { 371 return nil, err 372 } 373 374 if response.Code == http.StatusCreated || response.Code == http.StatusOK { 375 clientRes := &clientMetadata{} 376 err = json.Unmarshal(response.Body, clientRes) 377 if !p.cfg.GetAuthConfig().UseRegistrationAccessToken() { 378 clientRes.RegistrationAccessToken = "" 379 } 380 381 p.logger. 382 WithField("provider", p.cfg.GetIDPName()). 383 WithField("client-name", clientReq.GetClientName()). 384 WithField("client-id", clientReq.GetClientName()). 385 WithField("grant-type", clientReq.GetGrantTypes()). 386 WithField("token-auth-method", clientReq.GetTokenEndpointAuthMethod()). 387 WithField("response-type", clientReq.GetResponseTypes()). 388 WithField("redirect-url", clientReq.GetRedirectURIs()). 389 Info("registered client") 390 return clientRes, err 391 } 392 393 err = fmt.Errorf("error status code: %d, body: %s", response.Code, string(response.Body)) 394 p.logger. 395 WithField("provider", p.cfg.GetIDPName()). 396 WithField("client-name", clientReq.GetClientName()). 397 WithField("grant-type", clientReq.GetGrantTypes()). 398 WithField("token-auth-method", clientReq.GetTokenEndpointAuthMethod()). 399 WithField("response-type", clientReq.GetResponseTypes()). 400 WithField("redirect-url", clientReq.GetRedirectURIs()). 401 Error(err.Error()) 402 403 return nil, err 404 } 405 406 func (p *provider) enrichClientReq(clientReq ClientMetadata) error { 407 clientRequest, ok := clientReq.(*clientMetadata) 408 if !ok { 409 return fmt.Errorf("unrecognized client request metadata") 410 } 411 412 p.applyClientDefaults(clientRequest) 413 414 clientRequest.extraProperties = p.extraProperties 415 416 p.idpType.preProcessClientRequest(clientRequest) 417 p.preProcessResponseType(clientRequest) 418 return nil 419 } 420 421 func (p *provider) applyClientDefaults(clientRequest *clientMetadata) { 422 // Default the values from config if not set on the request 423 if len(clientRequest.GetScopes()) == 0 { 424 clientRequest.Scope = strings.Split(p.cfg.GetClientScopes(), " ") 425 } 426 427 if len(clientRequest.GetGrantTypes()) == 0 { 428 clientRequest.GrantTypes = []string{p.cfg.GetGrantType()} 429 } 430 431 if clientRequest.TokenEndpointAuthMethod == "" { 432 clientRequest.TokenEndpointAuthMethod = p.cfg.GetAuthMethod() 433 } 434 } 435 436 func (p *provider) preProcessResponseType(clientRequest *clientMetadata) { 437 for _, grantTypes := range clientRequest.GrantTypes { 438 switch grantTypes { 439 case GrantTypeAuthorizationCode: 440 if !hasResponseType(clientRequest, AuthResponseCode) { 441 addResponseType(clientRequest, AuthResponseCode) 442 } 443 case GrantTypeImplicit: 444 if !hasResponseType(clientRequest, AuthResponseToken) { 445 addResponseType(clientRequest, AuthResponseToken) 446 } 447 } 448 } 449 } 450 451 func hasResponseType(clientRequest *clientMetadata, responseType string) bool { 452 for _, clientResponseType := range clientRequest.ResponseTypes { 453 if clientResponseType == responseType { 454 return true 455 } 456 } 457 return false 458 } 459 460 func addResponseType(clientRequest *clientMetadata, responseType string) { 461 if clientRequest.ResponseTypes == nil { 462 clientRequest.ResponseTypes = make([]string, 0) 463 } 464 clientRequest.ResponseTypes = append(clientRequest.ResponseTypes, responseType) 465 } 466 467 // UnregisterClient - removes the OAuth client from IDP 468 func (p *provider) UnregisterClient(clientID, accessToken string) error { 469 authPrefix := p.idpType.getAuthorizationHeaderPrefix() 470 if accessToken == "" { 471 token, err := p.getClientToken() 472 if err != nil { 473 return err 474 } 475 accessToken = token 476 } 477 478 request := coreapi.Request{ 479 Method: coreapi.DELETE, 480 URL: p.getClientRegistrationEndpoint() + "/" + clientID, 481 QueryParams: p.queryParameters, 482 Headers: p.prepareHeaders(authPrefix, accessToken), 483 } 484 485 response, err := p.apiClient.Send(request) 486 if err != nil { 487 return err 488 } 489 490 if response.Code != http.StatusNoContent { 491 err := fmt.Errorf("error status code: %d, body: %s", response.Code, string(response.Body)) 492 p.logger. 493 WithField("provider", p.cfg.GetIDPName()). 494 WithField("client-id", clientID). 495 Error(err.Error()) 496 return err 497 } 498 499 p.logger. 500 WithField("provider", p.cfg.GetIDPName()). 501 WithField("client-id", clientID). 502 Info("unregistered client") 503 return nil 504 } 505 506 func (p *provider) getClientToken() (string, error) { 507 if p.authClient != nil { 508 useTokenCache := p.cfg.GetAuthConfig().UseTokenCache() 509 return p.authClient.FetchToken(useTokenCache) 510 } 511 return p.cfg.GetAuthConfig().GetAccessToken(), nil 512 } 513 514 func (p *provider) GetConfig() corecfg.IDPConfig { 515 return p.cfg 516 } 517 518 func (p *provider) GetMetadata() *AuthorizationServerMetadata { 519 return p.authServerMetadata 520 } 521 522 func (p *provider) Validate() error { 523 // Validate fetching token using client id/secret with oauth flow 524 // how to validate accessToken 525 // validate if the auth used has authorization? 526 _, err := p.getClientToken() 527 if err != nil { 528 return err 529 } 530 return nil 531 }