github.com/greenpau/go-authcrunch@v1.1.4/pkg/idp/oauth/provider.go (about)

     1  // Copyright 2022 Paul Greenberg greenpau@outlook.com
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth
    16  
    17  import (
    18  	"encoding/json"
    19  	"github.com/greenpau/go-authcrunch/pkg/authn/enums/operator"
    20  	"github.com/greenpau/go-authcrunch/pkg/authn/icons"
    21  	"github.com/greenpau/go-authcrunch/pkg/errors"
    22  	"github.com/greenpau/go-authcrunch/pkg/requests"
    23  	"go.uber.org/zap"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"regexp"
    27  	"strings"
    28  	"time"
    29  )
    30  
    31  const (
    32  	providerKind = "oauth"
    33  )
    34  
// IdentityProvider represents an OAuth-based identity provider.
//
// It is created by NewIdentityProvider. Configure derives most of the
// remaining fields from config and, via fetchConfig, from the provider's
// metadata document and JWKS endpoint.
//
// NOTE(review): when config.DelayStart is positive, Configure runs
// fetchConfig in a background goroutine that writes metadata, keys and the
// endpoint URL fields. No lock guarding these fields is visible here —
// confirm that readers tolerate this.
type IdentityProvider struct {
	// config is the configuration passed to NewIdentityProvider. The struct
	// tags have no effect because the field is unexported.
	config           *Config `json:"config,omitempty" xml:"config,omitempty" yaml:"config,omitempty"`
	// metadata is the decoded metadata document fetched from
	// config.MetadataURL.
	metadata         map[string]interface{}
	// keys holds the JWKS keys fetched from keysURL, indexed by key ID.
	keys             map[string]*JwksKey
	// authorizationURL and tokenURL come from config when set there,
	// otherwise from the metadata document.
	authorizationURL string
	tokenURL         string
	// keysURL is the metadata document's jwks_uri.
	keysURL          string
	// logoutURL is the metadata document's end_session_endpoint; for the
	// "cognito" driver it is derived from authorizationURL instead.
	logoutURL        string
	// The UserInfo API endpoint URL, taken from the metadata document's
	// userinfo_endpoint. See
	// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
	// for details.
	userInfoURL string
	// The regex filters for user groups extracted via the UserInfo API. If
	// a group matches a filter, the group is included in the user roles
	// issued by the portal.
	userGroupFilters []*regexp.Regexp
	// The regex filters for user orgs extracted from an identity provider.
	userOrgFilters []*regexp.Regexp
	// The name of the server hosting the OAuth 2.0 IdP. For example, with
	// public GitLab the server would be gitlab.com; for a hosted instance
	// it could be gitlab.mydomain.com. Configure copies it from
	// config.ServerName.
	serverName             string
	// lastKeyFetch and keyFetchAttempts throttle JWKS downloads in
	// fetchKeysURL.
	lastKeyFetch           time.Time
	keyFetchAttempts       int
	// Feature toggles set by Configure from config and from driver-specific
	// defaults. disableKeyVerification also makes fetchConfig skip the JWKS
	// download; the other toggles are read outside this part of the file.
	disableKeyVerification bool
	disablePassGrantType   bool
	disableResponseType    bool
	disableNonce           bool
	disableScope           bool
	enableAcceptHeader     bool
	// enableBodyDecoder is not set anywhere in this part of the file.
	enableBodyDecoder      bool
	// Lookup sets built by Configure from config.RequiredTokenFields,
	// config.Scopes and config.UserInfoFields; every value is true.
	requiredTokenFields    map[string]interface{}
	scopeMap               map[string]interface{}
	userInfoFields         map[string]interface{}
	// userInfoRolesFieldName is config.UserInfoRolesFieldName, or "roles"
	// when that is empty.
	userInfoRolesFieldName string
	// state stores cached state IDs. NewIdentityProvider starts
	// manageStateManager on it in a goroutine.
	state         *stateManager
	logger        *zap.Logger
	// browserConfig is set only when config.TLSInsecureSkipVerify is true;
	// presumably consumed by newBrowser — TODO confirm.
	browserConfig *browserConfig
	// configured is set to true once Configure completes without error.
	configured    bool
	// disableEmailClaimCheck disables the check for the presence of the
	// email field in a token.
	disableEmailClaimCheck bool
}
    80  
    81  // NewIdentityProvider returns an instance of IdentityProvider.
    82  func NewIdentityProvider(cfg *Config, logger *zap.Logger) (*IdentityProvider, error) {
    83  	if logger == nil {
    84  		return nil, errors.ErrIdentityProviderConfigureLoggerNotFound
    85  	}
    86  
    87  	b := &IdentityProvider{
    88  		config: cfg,
    89  		state:  newStateManager(),
    90  		keys:   make(map[string]*JwksKey),
    91  		logger: logger,
    92  	}
    93  
    94  	if err := b.config.Validate(); err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	go manageStateManager(b.state)
    99  
   100  	return b, nil
   101  }
   102  
// GetRealm returns the authentication realm, taken from config.Realm.
func (b *IdentityProvider) GetRealm() string {
	return b.config.Realm
}
   107  
// GetName returns the name associated with this identity provider, taken
// from config.Name.
func (b *IdentityProvider) GetName() string {
	return b.config.Name
}
   112  
// GetKind returns the authentication method associated with this identity
// provider; for this type it is always providerKind ("oauth").
func (b *IdentityProvider) GetKind() string {
	return providerKind
}
   117  
// Configured returns true if the identity provider was configured, i.e.
// Configure has run to completion without returning an error. When
// config.DelayStart is positive this does not imply that the remote metadata
// and keys have been fetched yet.
func (b *IdentityProvider) Configured() bool {
	return b.configured
}
   122  
   123  // GetConfig returns IdentityProvider configuration.
   124  func (b *IdentityProvider) GetConfig() map[string]interface{} {
   125  	var m map[string]interface{}
   126  	j, _ := json.Marshal(b.config)
   127  	json.Unmarshal(j, &m)
   128  	return m
   129  }
   130  
   131  // ScopeExists returns true if any of the provided scopes exist.
   132  func (b *IdentityProvider) ScopeExists(scopes ...string) bool {
   133  	for _, scope := range scopes {
   134  		if _, exists := b.scopeMap[scope]; exists {
   135  			return true
   136  		}
   137  	}
   138  	return false
   139  }
   140  
   141  // Request performs the requested identity provider operation.
   142  func (b *IdentityProvider) Request(op operator.Type, r *requests.Request) error {
   143  	switch op {
   144  	case operator.Authenticate:
   145  		return b.Authenticate(r)
   146  	}
   147  	return errors.ErrOperatorNotSupported.WithArgs(op)
   148  }
   149  
   150  // Configure configures IdentityProvider.
   151  func (b *IdentityProvider) Configure() error {
   152  	if b.config.EmailClaimCheckDisabled {
   153  		b.disableEmailClaimCheck = true
   154  	}
   155  	if b.config.KeyVerificationDisabled {
   156  		b.disableKeyVerification = true
   157  	}
   158  	if b.config.PassGrantTypeDisabled {
   159  		b.disablePassGrantType = true
   160  	}
   161  	if b.config.ResponseTypeDisabled {
   162  		b.disableResponseType = true
   163  	}
   164  	if b.config.NonceDisabled {
   165  		b.disableNonce = true
   166  	}
   167  	if b.config.ScopeDisabled {
   168  		b.disableScope = true
   169  	}
   170  
   171  	if b.config.AcceptHeaderEnabled {
   172  		b.enableAcceptHeader = true
   173  	}
   174  
   175  	if b.config.AuthorizationURL != "" {
   176  		b.authorizationURL = b.config.AuthorizationURL
   177  	}
   178  	if b.config.TokenURL != "" {
   179  		b.tokenURL = b.config.TokenURL
   180  	}
   181  
   182  	if b.config.TLSInsecureSkipVerify {
   183  		b.browserConfig = &browserConfig{
   184  			TLSInsecureSkipVerify: true,
   185  		}
   186  	}
   187  
   188  	b.scopeMap = make(map[string]interface{})
   189  	for _, scope := range b.config.Scopes {
   190  		b.scopeMap[scope] = true
   191  	}
   192  
   193  	switch b.config.Driver {
   194  	case "generic":
   195  	case "okta":
   196  	case "google":
   197  	case "gitlab":
   198  	case "azure":
   199  	case "github":
   200  		b.disableKeyVerification = true
   201  		b.disablePassGrantType = true
   202  		b.disableResponseType = true
   203  		b.disableNonce = true
   204  		b.enableAcceptHeader = true
   205  	case "facebook":
   206  		b.disableKeyVerification = true
   207  		b.disablePassGrantType = true
   208  		b.disableResponseType = true
   209  		b.disableNonce = true
   210  		b.enableAcceptHeader = true
   211  	case "discord":
   212  		b.disableKeyVerification = true
   213  		b.disableNonce = true
   214  		b.enableAcceptHeader = true
   215  	case "linkedin":
   216  		b.disableNonce = true
   217  	case "nextcloud":
   218  		b.disableKeyVerification = true
   219  	}
   220  
   221  	b.serverName = b.config.ServerName
   222  
   223  	b.requiredTokenFields = make(map[string]interface{})
   224  	for _, fieldName := range b.config.RequiredTokenFields {
   225  		b.requiredTokenFields[fieldName] = true
   226  	}
   227  
   228  	b.userInfoFields = make(map[string]interface{})
   229  	for _, fieldName := range b.config.UserInfoFields {
   230  		b.userInfoFields[fieldName] = true
   231  	}
   232  
   233  	if b.config.UserInfoRolesFieldName != "" {
   234  		b.userInfoRolesFieldName = b.config.UserInfoRolesFieldName
   235  	} else {
   236  		b.userInfoRolesFieldName = "roles"
   237  	}
   238  
   239  	// Configure user group filters, if any.
   240  	for _, pattern := range b.config.UserGroupFilters {
   241  		b.userGroupFilters = append(b.userGroupFilters, regexp.MustCompile(pattern))
   242  	}
   243  
   244  	// Configure user org filters, if any.
   245  	for _, pattern := range b.config.UserOrgFilters {
   246  		b.userOrgFilters = append(b.userOrgFilters, regexp.MustCompile(pattern))
   247  	}
   248  
   249  	if b.config.DelayStart > 0 {
   250  		go b.fetchConfig()
   251  	} else {
   252  		if err := b.fetchConfig(); err != nil {
   253  			return err
   254  		}
   255  	}
   256  
   257  	b.logger.Info(
   258  		"successfully configured OAuth 2.0 identity provider",
   259  		zap.String("provider", b.config.Driver),
   260  		zap.String("client_id", b.config.ClientID),
   261  		zap.String("server_id", b.config.ServerID),
   262  		zap.String("domain_name", b.config.DomainName),
   263  		zap.Any("metadata", b.metadata),
   264  		zap.Any("jwks_keys", b.keys),
   265  		zap.Strings("required_token_fields", b.config.RequiredTokenFields),
   266  		zap.Int("delayed_by", b.config.DelayStart),
   267  		zap.Int("retry_attempts", b.config.RetryAttempts),
   268  		zap.Int("retry_interval", b.config.RetryInterval),
   269  		zap.Strings("scopes", b.config.Scopes),
   270  		zap.Any("login_icon", b.config.LoginIcon),
   271  	)
   272  
   273  	b.configured = true
   274  	return nil
   275  }
   276  
   277  func (b *IdentityProvider) fetchConfig() error {
   278  	if b.config.DelayStart > 0 {
   279  		b.logger.Debug(
   280  			"Delaying identity provider configuration",
   281  			zap.String("identity_provider_name", b.config.Name),
   282  			zap.Int("delayed_by", b.config.DelayStart),
   283  		)
   284  		time.Sleep(time.Duration(b.config.DelayStart) * time.Second)
   285  	}
   286  
   287  	if b.authorizationURL == "" {
   288  		if b.config.RetryAttempts > 0 {
   289  			for i := 0; i < b.config.RetryAttempts; i++ {
   290  				err := b.fetchMetadataURL()
   291  				if err == nil {
   292  					break
   293  				}
   294  				if i >= (b.config.RetryAttempts - 1) {
   295  					return errors.ErrIdentityProviderOauthMetadataFetchFailed.WithArgs(err)
   296  				}
   297  				b.logger.Debug(
   298  					"fetchMetadataURL failed",
   299  					zap.String("identity_provider_name", b.config.Name),
   300  					zap.Int("attempt_id", i),
   301  					zap.Error(errors.ErrIdentityProviderOauthMetadataFetchFailed.WithArgs(err)),
   302  				)
   303  				time.Sleep(time.Duration(b.config.RetryInterval) * time.Second)
   304  			}
   305  		} else {
   306  			if err := b.fetchMetadataURL(); err != nil {
   307  				b.logger.Debug(
   308  					"fetchMetadataURL failed",
   309  					zap.String("identity_provider_name", b.config.Name),
   310  					zap.Error(errors.ErrIdentityProviderOauthMetadataFetchFailed.WithArgs(err)),
   311  				)
   312  				return errors.ErrIdentityProviderOauthMetadataFetchFailed.WithArgs(err)
   313  			}
   314  		}
   315  		b.logger.Debug(
   316  			"fetchMetadataURL succeeded",
   317  			zap.String("identity_provider_name", b.config.Name),
   318  			zap.Any("metadata", b.metadata),
   319  			zap.Any("userinfo_endpoint", b.userInfoURL),
   320  		)
   321  	}
   322  
   323  	if !b.disableKeyVerification {
   324  		if b.config.RetryAttempts > 0 {
   325  			for i := 0; i < b.config.RetryAttempts; i++ {
   326  				err := b.fetchKeysURL()
   327  				if err == nil {
   328  					break
   329  				}
   330  				if i >= (b.config.RetryAttempts - 1) {
   331  					return errors.ErrIdentityProviderOauthKeyFetchFailed.WithArgs(err)
   332  				}
   333  				b.logger.Debug(
   334  					"fetchKeysURL failed",
   335  					zap.String("identity_provider_name", b.config.Name),
   336  					zap.Int("attempt_id", i),
   337  					zap.Error(errors.ErrIdentityProviderOauthKeyFetchFailed.WithArgs(err)),
   338  				)
   339  				time.Sleep(time.Duration(b.config.RetryInterval) * time.Second)
   340  			}
   341  		} else {
   342  			if err := b.fetchKeysURL(); err != nil {
   343  				return errors.ErrIdentityProviderOauthKeyFetchFailed.WithArgs(err)
   344  			}
   345  		}
   346  	}
   347  	return nil
   348  }
   349  
   350  func (b *IdentityProvider) fetchMetadataURL() error {
   351  	cli, err := b.newBrowser()
   352  	if err != nil {
   353  		return err
   354  	}
   355  	req, err := http.NewRequest("GET", b.config.MetadataURL, nil)
   356  	resp, err := cli.Do(req)
   357  	if err != nil {
   358  		return err
   359  	}
   360  	respBody, err := ioutil.ReadAll(resp.Body)
   361  	resp.Body.Close()
   362  	if err != nil {
   363  		return err
   364  	}
   365  	if err := json.Unmarshal(respBody, &b.metadata); err != nil {
   366  		return err
   367  	}
   368  	for _, k := range []string{"authorization_endpoint", "token_endpoint", "jwks_uri"} {
   369  		if _, exists := b.metadata[k]; !exists {
   370  			return errors.ErrIdentityProviderOauthMetadataFieldNotFound.WithArgs(k, b.config.Driver)
   371  		}
   372  	}
   373  	b.authorizationURL = b.metadata["authorization_endpoint"].(string)
   374  	b.tokenURL = b.metadata["token_endpoint"].(string)
   375  	b.keysURL = b.metadata["jwks_uri"].(string)
   376  	if _, exists := b.metadata["userinfo_endpoint"]; exists {
   377  		b.userInfoURL = b.metadata["userinfo_endpoint"].(string)
   378  	}
   379  	if _, exists := b.metadata["end_session_endpoint"]; exists {
   380  		b.logoutURL = b.metadata["end_session_endpoint"].(string)
   381  	}
   382  
   383  	switch b.config.Driver {
   384  	case "cognito":
   385  		b.logoutURL = strings.ReplaceAll(b.authorizationURL, "oauth2/authorize", "logout")
   386  	}
   387  	return nil
   388  }
   389  
   390  func (b *IdentityProvider) countFetchKeysAttempt() {
   391  	b.lastKeyFetch = time.Now().UTC()
   392  	b.keyFetchAttempts++
   393  	return
   394  }
   395  
   396  func (b *IdentityProvider) fetchKeysURL() error {
   397  	if b.keyFetchAttempts > 3 {
   398  		timeDiff := time.Now().UTC().Sub(b.lastKeyFetch).Minutes()
   399  		if timeDiff < 5 {
   400  			return errors.ErrIdentityProviderOauthJwksKeysTooManyAttempts
   401  		}
   402  		b.lastKeyFetch = time.Now().UTC()
   403  		b.keyFetchAttempts = 0
   404  	}
   405  	b.countFetchKeysAttempt()
   406  
   407  	//  Create new http client instance.
   408  	cli, err := b.newBrowser()
   409  	if err != nil {
   410  		return err
   411  	}
   412  	req, err := http.NewRequest("GET", b.keysURL, nil)
   413  	if err != nil {
   414  		return err
   415  	}
   416  
   417  	// Fetch data from the URL.
   418  	resp, err := cli.Do(req)
   419  	if err != nil {
   420  		return err
   421  	}
   422  
   423  	respBody, err := ioutil.ReadAll(resp.Body)
   424  	resp.Body.Close()
   425  	if err != nil {
   426  		return err
   427  	}
   428  	data := make(map[string]interface{})
   429  
   430  	if err := json.Unmarshal(respBody, &data); err != nil {
   431  		return err
   432  	}
   433  
   434  	if _, exists := data["keys"]; !exists {
   435  		return errors.ErrIdentityProviderOauthJwksResponseKeysNotFound
   436  	}
   437  
   438  	jwksJSON, err := json.Marshal(data["keys"])
   439  	if err != nil {
   440  		return errors.ErrIdentityProviderOauthJwksKeysParseFailed.WithArgs(err)
   441  	}
   442  
   443  	keys := []*JwksKey{}
   444  	if err := json.Unmarshal(jwksJSON, &keys); err != nil {
   445  		return err
   446  	}
   447  
   448  	if len(keys) < 1 {
   449  		return errors.ErrIdentityProviderOauthJwksKeysNotFound
   450  	}
   451  
   452  	for _, k := range keys {
   453  		if err := k.Validate(); err != nil {
   454  			return errors.ErrIdentityProviderOauthJwksInvalidKey.WithArgs(err)
   455  		}
   456  		b.keys[k.KeyID] = k
   457  	}
   458  
   459  	return nil
   460  }
   461  
// GetLoginIcon returns the instance of the icon associated with the provider,
// taken from config.LoginIcon.
func (b *IdentityProvider) GetLoginIcon() *icons.LoginIcon {
	return b.config.LoginIcon
}
   466  
   467  // GetLogoutURL returns the logout URL associated with the provider.
   468  func (b *IdentityProvider) GetLogoutURL() string {
   469  	switch b.config.Driver {
   470  	case "cognito":
   471  		return b.logoutURL + "?client_id=" + b.config.ClientID
   472  	}
   473  	return b.logoutURL
   474  }
   475  
// GetDriver returns the name of the driver associated with the provider,
// taken from config.Driver (for example "github" or "cognito").
func (b *IdentityProvider) GetDriver() string {
	return b.config.Driver
}
   480  
   481  // GetIdentityTokenCookieName returns the name of the identity token cookie associated with the provider.
   482  func (b *IdentityProvider) GetIdentityTokenCookieName() string {
   483  	if b.config.IdentityTokenCookieEnabled {
   484  		return b.config.IdentityTokenCookieName
   485  	}
   486  	return ""
   487  }