istio.io/istio@v0.0.0-20240520182934-d79c90f27776/security/pkg/util/jwtutil.go (about) 1 // Copyright Istio Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package util 16 17 import ( 18 "bytes" 19 "encoding/base64" 20 "encoding/json" 21 "fmt" 22 "strings" 23 "time" 24 ) 25 26 // GetExp returns token expiration time, or error on failures. 27 func GetExp(token string) (time.Time, error) { 28 claims, err := parseJwtClaims(token) 29 if err != nil { 30 return time.Time{}, err 31 } 32 33 if claims["exp"] == nil { 34 // The JWT doesn't have "exp", so it's always valid. E.g., the K8s first party JWT. 35 return time.Time{}, nil 36 } 37 38 var expiration time.Time 39 switch exp := claims["exp"].(type) { 40 case float64: 41 expiration = time.Unix(int64(exp), 0) 42 case json.Number: 43 v, _ := exp.Int64() 44 expiration = time.Unix(v, 0) 45 } 46 return expiration, nil 47 } 48 49 // GetAud returns the claim `aud` from the token. Returns nil if not found. 50 func GetAud(token string) ([]string, error) { 51 claims, err := parseJwtClaims(token) 52 if err != nil { 53 return nil, err 54 } 55 56 rawAud := claims["aud"] 57 if rawAud == nil { 58 return nil, fmt.Errorf("no aud in the token claims") 59 } 60 61 data, err := json.Marshal(rawAud) 62 if err != nil { 63 return nil, err 64 } 65 66 var singleAud string 67 if err = json.Unmarshal(data, &singleAud); err == nil { 68 return []string{singleAud}, nil 69 } 70 71 var listAud []string 72 if err = json.Unmarshal(data, &listAud); err == nil { 73 return listAud, nil 74 } 75 76 return nil, err 77 } 78 79 type jwtPayload struct { 80 // Aud is JWT token audience - used to identify 3p tokens. 81 // It is empty for the default K8S tokens. 82 Aud []string `json:"aud"` 83 } 84 85 // IsK8SUnbound detects if the token is a K8S unbound token. 86 // It is a regular JWT with no audience and expiration, which can 87 // be exchanged with bound tokens with audience. 88 // 89 // This is used to determine if we check audience in the token. 90 // Clients should not use unbound tokens except in cases where 91 // bound tokens are not possible. 92 func IsK8SUnbound(jwt string) bool { 93 aud, f := ExtractJwtAud(jwt) 94 if !f { 95 return false // unbound tokens are valid JWT 96 } 97 98 return len(aud) == 0 99 } 100 101 // ExtractJwtAud extracts the audiences from a JWT token. If aud cannot be parse, the bool will be set 102 // to false. This distinguishes aud=[] from not parsed. 103 func ExtractJwtAud(jwt string) ([]string, bool) { 104 jwtSplit := strings.Split(jwt, ".") 105 if len(jwtSplit) != 3 { 106 return nil, false 107 } 108 payload := jwtSplit[1] 109 110 payloadBytes, err := DecodeJwtPart(payload) 111 if err != nil { 112 return nil, false 113 } 114 115 structuredPayload := jwtPayload{} 116 err = json.Unmarshal(payloadBytes, &structuredPayload) 117 if err != nil { 118 return nil, false 119 } 120 121 return structuredPayload.Aud, true 122 } 123 124 func parseJwtClaims(token string) (map[string]any, error) { 125 parts := strings.Split(token, ".") 126 if len(parts) != 3 { 127 return nil, fmt.Errorf("token contains an invalid number of segments: %d, expected: 3", len(parts)) 128 } 129 130 // Decode the second part. 131 claimBytes, err := DecodeJwtPart(parts[1]) 132 if err != nil { 133 return nil, err 134 } 135 dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) 136 137 claims := make(map[string]any) 138 if err := dec.Decode(&claims); err != nil { 139 return nil, fmt.Errorf("failed to decode the JWT claims") 140 } 141 return claims, nil 142 } 143 144 func DecodeJwtPart(seg string) ([]byte, error) { 145 if l := len(seg) % 4; l > 0 { 146 seg += strings.Repeat("=", 4-l) 147 } 148 149 return base64.URLEncoding.DecodeString(seg) 150 }