github.com/argoproj/argo-cd/v3@v3.2.1/util/jwt/jwt.go (about)

     1  package jwt
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"strings"
     7  	"time"
     8  
     9  	jwtgo "github.com/golang-jwt/jwt/v5"
    10  )
    11  
    12  // MapClaims converts a jwt.Claims to a MapClaims
    13  func MapClaims(claims jwtgo.Claims) (jwtgo.MapClaims, error) {
    14  	if mapClaims, ok := claims.(*jwtgo.MapClaims); ok {
    15  		return *mapClaims, nil
    16  	}
    17  	claimsBytes, err := json.Marshal(claims)
    18  	if err != nil {
    19  		return nil, err
    20  	}
    21  	var mapClaims jwtgo.MapClaims
    22  	err = json.Unmarshal(claimsBytes, &mapClaims)
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  	return mapClaims, nil
    27  }
    28  
    29  // StringField extracts a field from the claims as a string
    30  func StringField(claims jwtgo.MapClaims, fieldName string) string {
    31  	if fieldIf, ok := claims[fieldName]; ok {
    32  		if field, ok := fieldIf.(string); ok {
    33  			return field
    34  		}
    35  	}
    36  	return ""
    37  }
    38  
    39  // Float64Field extracts a field from the claims as a float64
    40  func Float64Field(claims jwtgo.MapClaims, fieldName string) float64 {
    41  	if fieldIf, ok := claims[fieldName]; ok {
    42  		if field, ok := fieldIf.(float64); ok {
    43  			return field
    44  		}
    45  	}
    46  	return 0
    47  }
    48  
    49  // GetScopeValues extracts the values of specified scopes from the claims
    50  func GetScopeValues(claims jwtgo.MapClaims, scopes []string) []string {
    51  	groups := make([]string, 0)
    52  	for i := range scopes {
    53  		scopeIf, ok := claims[scopes[i]]
    54  		if !ok {
    55  			continue
    56  		}
    57  
    58  		switch val := scopeIf.(type) {
    59  		case []any:
    60  			for _, groupIf := range val {
    61  				group, ok := groupIf.(string)
    62  				if ok {
    63  					groups = append(groups, group)
    64  				}
    65  			}
    66  		case []string:
    67  			groups = append(groups, val...)
    68  		case string:
    69  			groups = append(groups, val)
    70  		}
    71  	}
    72  
    73  	return groups
    74  }
    75  
    76  func numField(m jwtgo.MapClaims, key string) (int64, error) {
    77  	field, ok := m[key]
    78  	if !ok {
    79  		return 0, fmt.Errorf("token does not have %s claim", key)
    80  	}
    81  	switch val := field.(type) {
    82  	case float64:
    83  		return int64(val), nil
    84  	case json.Number:
    85  		return val.Int64()
    86  	case int64:
    87  		return val, nil
    88  	default:
    89  		return 0, fmt.Errorf("%s '%v' is not a number", key, val)
    90  	}
    91  }
    92  
    93  // IssuedAt returns the issued at as an int64
    94  func IssuedAt(m jwtgo.MapClaims) (int64, error) {
    95  	return numField(m, "iat")
    96  }
    97  
    98  // IssuedAtTime returns the issued at as a time.Time
    99  func IssuedAtTime(m jwtgo.MapClaims) (time.Time, error) {
   100  	iat, err := IssuedAt(m)
   101  	return time.Unix(iat, 0), err
   102  }
   103  
   104  // ExpirationTime returns the expiration as a time.Time
   105  func ExpirationTime(m jwtgo.MapClaims) (time.Time, error) {
   106  	exp, err := numField(m, "exp")
   107  	return time.Unix(exp, 0), err
   108  }
   109  
   110  func Claims(in any) jwtgo.Claims {
   111  	claims, ok := in.(jwtgo.Claims)
   112  	if ok {
   113  		return claims
   114  	}
   115  	return nil
   116  }
   117  
   118  // IsMember returns whether or not the user's claims is a member of any of the groups
   119  func IsMember(claims jwtgo.Claims, groups []string, scopes []string) bool {
   120  	mapClaims, err := MapClaims(claims)
   121  	if err != nil {
   122  		return false
   123  	}
   124  	// O(n^2) loop
   125  	for _, userGroup := range GetGroups(mapClaims, scopes) {
   126  		for _, group := range groups {
   127  			if userGroup == group {
   128  				return true
   129  			}
   130  		}
   131  	}
   132  	return false
   133  }
   134  
   135  func GetGroups(mapClaims jwtgo.MapClaims, scopes []string) []string {
   136  	return GetScopeValues(mapClaims, scopes)
   137  }
   138  
   139  func IsValid(token string) bool {
   140  	return len(strings.SplitN(token, ".", 3)) == 3
   141  }
   142  
   143  // GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use
   144  func GetUserIdentifier(c jwtgo.MapClaims) string {
   145  	if c == nil {
   146  		return ""
   147  	}
   148  
   149  	// Fallback to sub if federated_claims.user_id is not set.
   150  	fallback := StringField(c, "sub")
   151  
   152  	f := c["federated_claims"]
   153  	if f == nil {
   154  		return fallback
   155  	}
   156  	federatedClaims, ok := f.(map[string]any)
   157  	if !ok {
   158  		return fallback
   159  	}
   160  
   161  	userId, ok := federatedClaims["user_id"].(string)
   162  	if !ok || userId == "" {
   163  		return fallback
   164  	}
   165  
   166  	return userId
   167  }