
     1  package session
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"math/rand"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	oidc ""
	""
	log ""
	""
	""

	""
	""
	""
	""
	""
	""
	httputil ""
	jwtutil ""
	oidcutil ""
	passwordutil ""
	""
)
// SessionManager generates and validates JWT tokens for login sessions.
// It also verifies local username/password logins (tracking failed attempts
// through a UserStateStorage) and verifies tokens issued by an external OIDC
// provider.
type SessionManager struct {
	// settingsMgr supplies the Argo CD settings: server signature, local
	// accounts and SSO configuration.
	settingsMgr                   *settings.SettingsManager
	// projectsLister resolves AppProjects when validating project-role tokens.
	projectsLister                v1alpha1.AppProjectNamespaceLister
	// client is the HTTP client handed to the OIDC provider.
	client                        *http.Client
	// prov caches the OIDC provider; it is created lazily by provider().
	prov                          oidcutil.Provider
	// storage holds the failed-login records, keyed by username.
	storage                       UserStateStorage
	// sleep applies the password verification delay; defaults to time.Sleep
	// (presumably swappable so tests need not actually sleep).
	sleep                         func(d time.Duration)
	// verificationDelayNoiseEnabled toggles the random response delay in
	// VerifyUsernamePassword.
	verificationDelayNoiseEnabled bool
}
    45  type inMemoryUserStateStorage struct {
    46  	attempts map[string]LoginAttempts
    47  }
    49  func NewInMemoryUserStateStorage() *inMemoryUserStateStorage {
    50  	return &inMemoryUserStateStorage{attempts: map[string]LoginAttempts{}}
    51  }
    53  func (storage *inMemoryUserStateStorage) GetLoginAttempts(attempts *map[string]LoginAttempts) error {
    54  	*attempts = storage.attempts
    55  	return nil
    56  }
    58  func (storage *inMemoryUserStateStorage) SetLoginAttempts(attempts map[string]LoginAttempts) error {
    59  	storage.attempts = attempts
    60  	return nil
    61  }
    63  type UserStateStorage interface {
    64  	GetLoginAttempts(attempts *map[string]LoginAttempts) error
    65  	SetLoginAttempts(attempts map[string]LoginAttempts) error
    66  }
    68  // LoginAttempts is a timestamped counter for failed login attempts
    69  type LoginAttempts struct {
    70  	// Time of the last failed login
    71  	LastFailed time.Time `json:"lastFailed"`
    72  	// Number of consecutive login failures
    73  	FailCount int `json:"failCount"`
    74  }
const (
	// SessionManagerClaimsIssuer fills the "iss" field of the token.
	SessionManagerClaimsIssuer = "argocd"

	// invalidLoginError, for security purposes, doesn't say whether the username or password was invalid.  This does not mitigate the potential for timing attacks to determine which is which.
	// Timing-based probing is countered separately in VerifyUsernamePassword
	// (a random response delay, plus a dummy password hash for unknown users).
	invalidLoginError         = "Invalid username or password"
	// blankPasswordError is returned when an empty password is submitted.
	blankPasswordError        = "Blank passwords are not allowed"
	// accountDisabled is a format string; %s is the account name.
	accountDisabled           = "Account %s is disabled"
	// usernameTooLongError is a format string; %d is the maximum length in bytes.
	usernameTooLongError      = "Username is too long (%d bytes max)"
	// userDoesNotHaveCapability is a format string: account name, then the
	// missing capability.
	userDoesNotHaveCapability = "Account %s does not have %s capability"
)
    88  const (
    89  	// Maximum length of username, too keep the cache's memory signature low
    90  	maxUsernameLength = 32
    91  	// The default maximum session cache size
    92  	defaultMaxCacheSize = 1000
    93  	// The default number of maximum login failures before delay kicks in
    94  	defaultMaxLoginFailures = 5
    95  	// The default time in seconds for the failure window
    96  	defaultFailureWindow = 300
    97  	// The password verification delay max
    98  	verificationDelayNoiseMin = 500 * time.Millisecond
    99  	// The password verification delay max
   100  	verificationDelayNoiseMax = 1000 * time.Millisecond
   102  	// environment variables to control rate limiter behaviour:
   104  	// Max number of login failures before login delay kicks in
   107  	// Number of maximum seconds the login is allowed to delay for. Default: 300 (5 minutes).
   108  	envLoginFailureWindowSeconds = "ARGOCD_SESSION_FAILURE_WINDOW_SECONDS"
   110  	// Max number of stored usernames
   111  	envLoginMaxCacheSize = "ARGOCD_SESSION_MAX_CACHE_SIZE"
   112  )
   114  var (
   115  	InvalidLoginErr = status.Errorf(codes.Unauthenticated, invalidLoginError)
   116  )
   118  // Returns the maximum cache size as number of entries
   119  func getMaximumCacheSize() int {
   120  	return env.ParseNumFromEnv(envLoginMaxCacheSize, defaultMaxCacheSize, 1, math.MaxInt32)
   121  }
   123  // Returns the maximum number of login failures before login delay kicks in
   124  func getMaxLoginFailures() int {
   125  	return env.ParseNumFromEnv(envLoginMaxFailCount, defaultMaxLoginFailures, 1, math.MaxInt32)
   126  }
   128  // Returns the number of maximum seconds the login is allowed to delay for
   129  func getLoginFailureWindow() time.Duration {
   130  	return time.Duration(env.ParseNumFromEnv(envLoginFailureWindowSeconds, defaultFailureWindow, 0, math.MaxInt32))
   131  }
   133  // NewSessionManager creates a new session manager from Argo CD settings
   134  func NewSessionManager(settingsMgr *settings.SettingsManager, projectsLister v1alpha1.AppProjectNamespaceLister, dexServerAddr string, storage UserStateStorage) *SessionManager {
   135  	s := SessionManager{
   136  		settingsMgr:                   settingsMgr,
   137  		storage:                       storage,
   138  		sleep:                         time.Sleep,
   139  		projectsLister:                projectsLister,
   140  		verificationDelayNoiseEnabled: true,
   141  	}
   142  	settings, err := settingsMgr.GetSettings()
   143  	if err != nil {
   144  		panic(err)
   145  	}
   146  	tlsConfig := settings.TLSConfig()
   147  	if tlsConfig != nil {
   148  		tlsConfig.InsecureSkipVerify = true
   149  	}
   150  	s.client = &http.Client{
   151  		Transport: &http.Transport{
   152  			TLSClientConfig: tlsConfig,
   153  			Proxy:           http.ProxyFromEnvironment,
   154  			Dial: (&net.Dialer{
   155  				Timeout:   30 * time.Second,
   156  				KeepAlive: 30 * time.Second,
   157  			}).Dial,
   158  			TLSHandshakeTimeout:   10 * time.Second,
   159  			ExpectContinueTimeout: 1 * time.Second,
   160  		},
   161  	}
   162  	if settings.DexConfig != "" {
   163  		s.client.Transport = dex.NewDexRewriteURLRoundTripper(dexServerAddr, s.client.Transport)
   164  	}
   165  	if os.Getenv(common.EnvVarSSODebug) == "1" {
   166  		s.client.Transport = httputil.DebugTransport{T: s.client.Transport}
   167  	}
   169  	return &s
   170  }
   172  // Create creates a new token for a given subject (user) and returns it as a string.
   173  // Passing a value of `0` for secondsBeforeExpiry creates a token that never expires.
   174  // The id parameter holds an optional unique JWT token identifier and stored as a standard claim "jti" in the JWT token.
   175  func (mgr *SessionManager) Create(subject string, secondsBeforeExpiry int64, id string) (string, error) {
   176  	// Create a new token object, specifying signing method and the claims
   177  	// you would like it to contain.
   178  	now := time.Now().UTC()
   179  	claims := jwt.StandardClaims{
   180  		IssuedAt:  jwt.At(now),
   181  		Issuer:    SessionManagerClaimsIssuer,
   182  		NotBefore: jwt.At(now),
   183  		Subject:   subject,
   184  		ID:        id,
   185  	}
   186  	if secondsBeforeExpiry > 0 {
   187  		expires := now.Add(time.Duration(secondsBeforeExpiry) * time.Second)
   188  		claims.ExpiresAt = jwt.At(expires)
   189  	}
   191  	return mgr.signClaims(claims)
   192  }
// standardClaims mirrors jwt.StandardClaims but stores the time fields
// ("exp", "iat", "nbf") as int64 Unix seconds. signClaims marshals standard
// claims through this type so those fields are emitted as integers. Zero or
// empty fields are omitted from the JSON.
type standardClaims struct {
	Audience  jwt.ClaimStrings `json:"aud,omitempty"`
	ExpiresAt int64            `json:"exp,omitempty"`
	ID        string           `json:"jti,omitempty"`
	IssuedAt  int64            `json:"iat,omitempty"`
	Issuer    string           `json:"iss,omitempty"`
	NotBefore int64            `json:"nbf,omitempty"`
	Subject   string           `json:"sub,omitempty"`
}
   204  func unixTimeOrZero(t *jwt.Time) int64 {
   205  	if t == nil {
   206  		return 0
   207  	}
   208  	return t.Unix()
   209  }
   211  func (mgr *SessionManager) signClaims(claims jwt.Claims) (string, error) {
   212  	// log.Infof("Issuing claims: %v", claims)
   213  	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
   214  	settings, err := mgr.settingsMgr.GetSettings()
   215  	if err != nil {
   216  		return "", err
   217  	}
   218  	// workaround for
   219  	// According to "iat" and other time fields must contain
   220  	// number of seconds from 1970-01-01T00:00:00Z UTC until the specified UTC date/time.
   221  	// The marshals time as non integer.
   222  	return token.SignedString(settings.ServerSignature, jwt.WithMarshaller(func(ctx jwt.CodingContext, v interface{}) ([]byte, error) {
   223  		if std, ok := v.(jwt.StandardClaims); ok {
   224  			return json.Marshal(standardClaims{
   225  				Audience:  std.Audience,
   226  				ExpiresAt: unixTimeOrZero(std.ExpiresAt),
   227  				ID:        std.ID,
   228  				IssuedAt:  unixTimeOrZero(std.IssuedAt),
   229  				Issuer:    std.Issuer,
   230  				NotBefore: unixTimeOrZero(std.NotBefore),
   231  				Subject:   std.Subject,
   232  			})
   233  		}
   234  		return json.Marshal(v)
   235  	}))
   236  }
   238  // Parse tries to parse the provided string and returns the token claims for local login.
   239  func (mgr *SessionManager) Parse(tokenString string) (jwt.Claims, error) {
   240  	// Parse takes the token string and a function for looking up the key. The latter is especially
   241  	// useful if you use multiple keys for your application.  The standard is to use 'kid' in the
   242  	// head of the token to identify which key to use, but the parsed token (head and claims) is provided
   243  	// to the callback, providing flexibility.
   244  	var claims jwt.MapClaims
   245  	argoCDSettings, err := mgr.settingsMgr.GetSettings()
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  	token, err := jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (interface{}, error) {
   250  		// Don't forget to validate the alg is what you expect:
   251  		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
   252  			return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
   253  		}
   254  		return argoCDSettings.ServerSignature, nil
   255  	})
   256  	if err != nil {
   257  		return nil, err
   258  	}
   260  	issuedAt, err := jwtutil.IssuedAtTime(claims)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   265  	subject := jwtutil.StringField(claims, "sub")
   266  	id := jwtutil.StringField(claims, "jti")
   268  	if projName, role, ok := rbacpolicy.GetProjectRoleFromSubject(subject); ok {
   269  		proj, err := mgr.projectsLister.Get(projName)
   270  		if err != nil {
   271  			return nil, err
   272  		}
   273  		_, _, err = proj.GetJWTToken(role, issuedAt.Unix(), id)
   274  		if err != nil {
   275  			return nil, err
   276  		}
   278  		return token.Claims, nil
   279  	}
   281  	account, err := mgr.settingsMgr.GetAccount(subject)
   282  	if err != nil {
   283  		return nil, err
   284  	}
   286  	if !account.Enabled {
   287  		return nil, fmt.Errorf("account %s is disabled", subject)
   288  	}
   290  	var capability settings.AccountCapability
   291  	if id != "" {
   292  		capability = settings.AccountCapabilityApiKey
   293  	} else {
   294  		capability = settings.AccountCapabilityLogin
   295  	}
   296  	if !account.HasCapability(capability) {
   297  		return nil, fmt.Errorf("account %s does not have '%s' capability", subject, capability)
   298  	}
   300  	if id != "" && account.TokenIndex(id) == -1 {
   301  		return nil, fmt.Errorf("account %s does not have token with id %s", subject, id)
   302  	}
   304  	if account.PasswordMtime != nil && issuedAt.Before(*account.PasswordMtime) {
   305  		return nil, fmt.Errorf("Account password has changed since token issued")
   306  	}
   307  	return token.Claims, nil
   308  }
   310  // GetLoginFailures retrieves the login failure information from the cache
   311  func (mgr *SessionManager) GetLoginFailures() map[string]LoginAttempts {
   312  	// Get failures from the cache
   313  	var failures map[string]LoginAttempts
   314  	err :=
   315  	if err != nil {
   316  		if err != appstate.ErrCacheMiss {
   317  			log.Errorf("Could not retrieve login attempts: %v", err)
   318  		}
   319  		failures = make(map[string]LoginAttempts)
   320  	}
   322  	return failures
   323  }
   325  func expireOldFailedAttempts(maxAge time.Duration, failures *map[string]LoginAttempts) int {
   326  	expiredCount := 0
   327  	for key, attempt := range *failures {
   328  		if time.Since(attempt.LastFailed) > maxAge*time.Second {
   329  			expiredCount += 1
   330  			delete(*failures, key)
   331  		}
   332  	}
   333  	return expiredCount
   334  }
   336  // Updates the failure count for a given username. If failed is true, increases the counter. Otherwise, sets counter back to 0.
   337  func (mgr *SessionManager) updateFailureCount(username string, failed bool) {
   339  	failures := mgr.GetLoginFailures()
   341  	// Expire old entries in the cache if we have a failure window defined.
   342  	if window := getLoginFailureWindow(); window > 0 {
   343  		count := expireOldFailedAttempts(window, &failures)
   344  		if count > 0 {
   345  			log.Infof("Expired %d entries from session cache due to max age reached", count)
   346  		}
   347  	}
   349  	// If we exceed a certain cache size, we need to remove random entries to
   350  	// prevent overbloating the cache with fake entries, as this could lead to
   351  	// memory exhaustion and ultimately in a DoS. We remove a single entry to
   352  	// replace it with the new one.
   353  	//
   354  	// Chances are that we remove the one that is under active attack, but this
   355  	// chance is low (1:cache_size)
   356  	if failed && len(failures) >= getMaximumCacheSize() {
   357  		log.Warnf("Session cache size exceeds %d entries, removing random entry", getMaximumCacheSize())
   358  		idx := rand.Intn(len(failures) - 1)
   359  		var rmUser string
   360  		i := 0
   361  		for key := range failures {
   362  			if i == idx {
   363  				rmUser = key
   364  				delete(failures, key)
   365  				break
   366  			}
   367  			i++
   368  		}
   369  		log.Infof("Deleted entry for user %s from cache", rmUser)
   370  	}
   372  	attempt, ok := failures[username]
   373  	if !ok {
   374  		attempt = LoginAttempts{FailCount: 0}
   375  	}
   377  	// On login failure, increase fail count and update last failed timestamp.
   378  	// On login success, remove the entry from the cache.
   379  	if failed {
   380  		attempt.FailCount += 1
   381  		attempt.LastFailed = time.Now()
   382  		failures[username] = attempt
   383  		log.Warnf("User %s failed login %d time(s)", username, attempt.FailCount)
   384  	} else {
   385  		if attempt.FailCount > 0 {
   386  			// Forget username for cache size enforcement, since entry in cache was deleted
   387  			delete(failures, username)
   388  		}
   389  	}
   391  	err :=
   392  	if err != nil {
   393  		log.Errorf("Could not update login attempts: %v", err)
   394  	}
   396  }
   398  // Get the current login failure attempts for given username
   399  func (mgr *SessionManager) getFailureCount(username string) LoginAttempts {
   400  	failures := mgr.GetLoginFailures()
   401  	attempt, ok := failures[username]
   402  	if !ok {
   403  		attempt = LoginAttempts{FailCount: 0}
   404  	}
   405  	return attempt
   406  }
   408  // Calculate a login delay for the given login attempt
   409  func (mgr *SessionManager) exceededFailedLoginAttempts(attempt LoginAttempts) bool {
   410  	maxFails := getMaxLoginFailures()
   411  	failureWindow := getLoginFailureWindow()
   413  	// Whether we are in the failure window for given attempt
   414  	inWindow := func() bool {
   415  		if failureWindow == 0 || time.Since(attempt.LastFailed).Seconds() <= float64(failureWindow) {
   416  			return true
   417  		}
   418  		return false
   419  	}
   421  	// If we reached max failed attempts within failure window, we need to calc the delay
   422  	if attempt.FailCount >= maxFails && inWindow() {
   423  		return true
   424  	}
   426  	return false
   427  }
// VerifyUsernamePassword verifies if a username/password combo is correct
// for a local account.
//
// It returns nil only when the password matches, the account is enabled and
// the account has the login capability. Other outcomes are gRPC status errors:
//   - Unauthenticated (blankPasswordError) for an empty password.
//   - InvalidArgument (usernameTooLongError) when username is longer than
//     maxUsernameLength bytes.
//   - InvalidLoginErr for unknown accounts, wrong passwords, and users that
//     already exceeded the failed-login limit.
//   - Unauthenticated for disabled accounts or accounts lacking the login
//     capability.
//
// Errors from settingsMgr.GetAccount other than NotFound are returned as-is.
// Only unknown accounts and wrong passwords increment the failure counter;
// a fully successful login resets it.
func (mgr *SessionManager) VerifyUsernamePassword(username string, password string) error {
	if password == "" {
		return status.Errorf(codes.Unauthenticated, blankPasswordError)
	}
	// Enforce maximum length of username on local accounts
	if len(username) > maxUsernameLength {
		return status.Errorf(codes.InvalidArgument, usernameTooLongError, maxUsernameLength)
	}
	// The two checks above return before the deferred delay is registered, so
	// they are not padded. Every return below is.
	start := time.Now()
	if mgr.verificationDelayNoiseEnabled {
		defer func() {
			// introduces random delay to protect from timing-based user enumeration attack
			// The call is padded to a random total duration in
			// [verificationDelayNoiseMin, verificationDelayNoiseMax). Time already
			// spent since start counts toward it; nothing is added once exceeded.
			delayNanoseconds := verificationDelayNoiseMin.Nanoseconds() +
				int64(rand.Intn(int(verificationDelayNoiseMax.Nanoseconds()-verificationDelayNoiseMin.Nanoseconds())))
				// take into account amount of time spent since the request start
			delayNanoseconds = delayNanoseconds - time.Since(start).Nanoseconds()
			if delayNanoseconds > 0 {
				mgr.sleep(time.Duration(delayNanoseconds))
			}
		}()
	}
	// Users already over the failed-login limit are rejected before their
	// password is checked, and their counter is not incremented further.
	attempt := mgr.getFailureCount(username)
	if mgr.exceededFailedLoginAttempts(attempt) {
		log.Warnf("User %s had too many failed logins (%d)", username, attempt.FailCount)
		return InvalidLoginErr
	}
	account, err := mgr.settingsMgr.GetAccount(username)
	if err != nil {
		// An unknown account counts as a failed attempt and gets the same
		// generic error as a wrong password. Other lookup errors are returned unchanged.
		if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
			mgr.updateFailureCount(username, true)
			err = InvalidLoginErr
		}
		// to prevent time-based user enumeration, we must perform a password
		// hash cycle to keep response time consistent (if the function were
		// to continue and not return here)
		_, _ = passwordutil.HashPassword("for_consistent_response_time")
		return err
	}
	// The error result of VerifyPassword is discarded; only `valid` decides.
	valid, _ := passwordutil.VerifyPassword(password, account.PasswordHash)
	if !valid {
		mgr.updateFailureCount(username, true)
		return InvalidLoginErr
	}
	// Reached only with a correct password. The rejections below neither
	// increment nor reset the failure counter.
	if !account.Enabled {
		return status.Errorf(codes.Unauthenticated, accountDisabled, username)
	}
	if !account.HasCapability(settings.AccountCapabilityLogin) {
		return status.Errorf(codes.Unauthenticated, userDoesNotHaveCapability, username, settings.AccountCapabilityLogin)
	}
	// Successful login: clear this user's failure record.
	mgr.updateFailureCount(username, false)
	return nil
}
   489  // VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
   490  // We choose how to verify based on the issuer.
   491  func (mgr *SessionManager) VerifyToken(tokenString string) (jwt.Claims, error) {
   492  	parser := &jwt.Parser{
   493  		ValidationHelper: jwt.NewValidationHelper(jwt.WithoutClaimsValidation(), jwt.WithoutAudienceValidation()),
   494  	}
   495  	var claims jwt.StandardClaims
   496  	_, _, err := parser.ParseUnverified(tokenString, &claims)
   497  	if err != nil {
   498  		return nil, err
   499  	}
   500  	switch claims.Issuer {
   501  	case SessionManagerClaimsIssuer:
   502  		// Argo CD signed token
   503  		return mgr.Parse(tokenString)
   504  	default:
   505  		// IDP signed token
   506  		prov, err := mgr.provider()
   507  		if err != nil {
   508  			return claims, err
   509  		}
   511  		// Token must be verified for at least one audience
   512  		// TODO(jannfis): Is this the right way? Shouldn't we know our audience and only validate for the correct one?
   513  		var idToken *oidc.IDToken
   514  		for _, aud := range claims.Audience {
   515  			idToken, err = prov.Verify(aud, tokenString)
   516  			if err == nil {
   517  				break
   518  			}
   519  		}
   520  		if err != nil {
   521  			return claims, err
   522  		}
   523  		var claims jwt.MapClaims
   524  		err = idToken.Claims(&claims)
   525  		return claims, err
   526  	}
   527  }
   529  func (mgr *SessionManager) provider() (oidcutil.Provider, error) {
   530  	if mgr.prov != nil {
   531  		return mgr.prov, nil
   532  	}
   533  	settings, err := mgr.settingsMgr.GetSettings()
   534  	if err != nil {
   535  		return nil, err
   536  	}
   537  	if !settings.IsSSOConfigured() {
   538  		return nil, fmt.Errorf("SSO is not configured")
   539  	}
   540  	mgr.prov = oidcutil.NewOIDCProvider(settings.IssuerURL(), mgr.client)
   541  	return mgr.prov, nil
   542  }
   544  func LoggedIn(ctx context.Context) bool {
   545  	return Sub(ctx) != ""
   546  }
   548  // Username is a helper to extract a human readable username from a context
   549  func Username(ctx context.Context) string {
   550  	mapClaims, ok := mapClaims(ctx)
   551  	if !ok {
   552  		return ""
   553  	}
   554  	switch jwtutil.StringField(mapClaims, "iss") {
   555  	case SessionManagerClaimsIssuer:
   556  		return jwtutil.StringField(mapClaims, "sub")
   557  	default:
   558  		return jwtutil.StringField(mapClaims, "email")
   559  	}
   560  }
   562  func Iss(ctx context.Context) string {
   563  	mapClaims, ok := mapClaims(ctx)
   564  	if !ok {
   565  		return ""
   566  	}
   567  	return jwtutil.StringField(mapClaims, "iss")
   568  }
   570  func Iat(ctx context.Context) (time.Time, error) {
   571  	mapClaims, ok := mapClaims(ctx)
   572  	if !ok {
   573  		return time.Time{}, errors.New("unable to extract token claims")
   574  	}
   575  	return jwtutil.IssuedAtTime(mapClaims)
   576  }
   578  func Sub(ctx context.Context) string {
   579  	mapClaims, ok := mapClaims(ctx)
   580  	if !ok {
   581  		return ""
   582  	}
   583  	return jwtutil.StringField(mapClaims, "sub")
   584  }
   586  func Groups(ctx context.Context, scopes []string) []string {
   587  	mapClaims, ok := mapClaims(ctx)
   588  	if !ok {
   589  		return nil
   590  	}
   591  	return jwtutil.GetGroups(mapClaims, scopes)
   592  }
   594  func mapClaims(ctx context.Context) (jwt.MapClaims, bool) {
   595  	claims, ok := ctx.Value("claims").(jwt.Claims)
   596  	if !ok {
   597  		return nil, false
   598  	}
   599  	mapClaims, err := jwtutil.MapClaims(claims)
   600  	if err != nil {
   601  		return nil, false
   602  	}
   603  	return mapClaims, true
   604  }