github.com/argoproj/argo-cd/v3@v3.2.1/util/jwt/jwt.go (about) 1 package jwt 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "strings" 7 "time" 8 9 jwtgo "github.com/golang-jwt/jwt/v5" 10 ) 11 12 // MapClaims converts a jwt.Claims to a MapClaims 13 func MapClaims(claims jwtgo.Claims) (jwtgo.MapClaims, error) { 14 if mapClaims, ok := claims.(*jwtgo.MapClaims); ok { 15 return *mapClaims, nil 16 } 17 claimsBytes, err := json.Marshal(claims) 18 if err != nil { 19 return nil, err 20 } 21 var mapClaims jwtgo.MapClaims 22 err = json.Unmarshal(claimsBytes, &mapClaims) 23 if err != nil { 24 return nil, err 25 } 26 return mapClaims, nil 27 } 28 29 // StringField extracts a field from the claims as a string 30 func StringField(claims jwtgo.MapClaims, fieldName string) string { 31 if fieldIf, ok := claims[fieldName]; ok { 32 if field, ok := fieldIf.(string); ok { 33 return field 34 } 35 } 36 return "" 37 } 38 39 // Float64Field extracts a field from the claims as a float64 40 func Float64Field(claims jwtgo.MapClaims, fieldName string) float64 { 41 if fieldIf, ok := claims[fieldName]; ok { 42 if field, ok := fieldIf.(float64); ok { 43 return field 44 } 45 } 46 return 0 47 } 48 49 // GetScopeValues extracts the values of specified scopes from the claims 50 func GetScopeValues(claims jwtgo.MapClaims, scopes []string) []string { 51 groups := make([]string, 0) 52 for i := range scopes { 53 scopeIf, ok := claims[scopes[i]] 54 if !ok { 55 continue 56 } 57 58 switch val := scopeIf.(type) { 59 case []any: 60 for _, groupIf := range val { 61 group, ok := groupIf.(string) 62 if ok { 63 groups = append(groups, group) 64 } 65 } 66 case []string: 67 groups = append(groups, val...) 68 case string: 69 groups = append(groups, val) 70 } 71 } 72 73 return groups 74 } 75 76 func numField(m jwtgo.MapClaims, key string) (int64, error) { 77 field, ok := m[key] 78 if !ok { 79 return 0, fmt.Errorf("token does not have %s claim", key) 80 } 81 switch val := field.(type) { 82 case float64: 83 return int64(val), nil 84 case json.Number: 85 return val.Int64() 86 case int64: 87 return val, nil 88 default: 89 return 0, fmt.Errorf("%s '%v' is not a number", key, val) 90 } 91 } 92 93 // IssuedAt returns the issued at as an int64 94 func IssuedAt(m jwtgo.MapClaims) (int64, error) { 95 return numField(m, "iat") 96 } 97 98 // IssuedAtTime returns the issued at as a time.Time 99 func IssuedAtTime(m jwtgo.MapClaims) (time.Time, error) { 100 iat, err := IssuedAt(m) 101 return time.Unix(iat, 0), err 102 } 103 104 // ExpirationTime returns the expiration as a time.Time 105 func ExpirationTime(m jwtgo.MapClaims) (time.Time, error) { 106 exp, err := numField(m, "exp") 107 return time.Unix(exp, 0), err 108 } 109 110 func Claims(in any) jwtgo.Claims { 111 claims, ok := in.(jwtgo.Claims) 112 if ok { 113 return claims 114 } 115 return nil 116 } 117 118 // IsMember returns whether or not the user's claims is a member of any of the groups 119 func IsMember(claims jwtgo.Claims, groups []string, scopes []string) bool { 120 mapClaims, err := MapClaims(claims) 121 if err != nil { 122 return false 123 } 124 // O(n^2) loop 125 for _, userGroup := range GetGroups(mapClaims, scopes) { 126 for _, group := range groups { 127 if userGroup == group { 128 return true 129 } 130 } 131 } 132 return false 133 } 134 135 func GetGroups(mapClaims jwtgo.MapClaims, scopes []string) []string { 136 return GetScopeValues(mapClaims, scopes) 137 } 138 139 func IsValid(token string) bool { 140 return len(strings.SplitN(token, ".", 3)) == 3 141 } 142 143 // GetUserIdentifier returns a consistent user identifier, checking federated_claims.user_id when Dex is in use 144 func GetUserIdentifier(c jwtgo.MapClaims) string { 145 if c == nil { 146 return "" 147 } 148 149 // Fallback to sub if federated_claims.user_id is not set. 150 fallback := StringField(c, "sub") 151 152 f := c["federated_claims"] 153 if f == nil { 154 return fallback 155 } 156 federatedClaims, ok := f.(map[string]any) 157 if !ok { 158 return fallback 159 } 160 161 userId, ok := federatedClaims["user_id"].(string) 162 if !ok || userId == "" { 163 return fallback 164 } 165 166 return userId 167 }