istio.io/istio@v0.0.0-20240520182934-d79c90f27776/security/pkg/util/jwtutil.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package util
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/base64"
    20  	"encoding/json"
    21  	"fmt"
    22  	"strings"
    23  	"time"
    24  )
    25  
    26  // GetExp returns token expiration time, or error on failures.
    27  func GetExp(token string) (time.Time, error) {
    28  	claims, err := parseJwtClaims(token)
    29  	if err != nil {
    30  		return time.Time{}, err
    31  	}
    32  
    33  	if claims["exp"] == nil {
    34  		// The JWT doesn't have "exp", so it's always valid. E.g., the K8s first party JWT.
    35  		return time.Time{}, nil
    36  	}
    37  
    38  	var expiration time.Time
    39  	switch exp := claims["exp"].(type) {
    40  	case float64:
    41  		expiration = time.Unix(int64(exp), 0)
    42  	case json.Number:
    43  		v, _ := exp.Int64()
    44  		expiration = time.Unix(v, 0)
    45  	}
    46  	return expiration, nil
    47  }
    48  
    49  // GetAud returns the claim `aud` from the token. Returns nil if not found.
    50  func GetAud(token string) ([]string, error) {
    51  	claims, err := parseJwtClaims(token)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	rawAud := claims["aud"]
    57  	if rawAud == nil {
    58  		return nil, fmt.Errorf("no aud in the token claims")
    59  	}
    60  
    61  	data, err := json.Marshal(rawAud)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	var singleAud string
    67  	if err = json.Unmarshal(data, &singleAud); err == nil {
    68  		return []string{singleAud}, nil
    69  	}
    70  
    71  	var listAud []string
    72  	if err = json.Unmarshal(data, &listAud); err == nil {
    73  		return listAud, nil
    74  	}
    75  
    76  	return nil, err
    77  }
    78  
    79  type jwtPayload struct {
    80  	// Aud is JWT token audience - used to identify 3p tokens.
    81  	// It is empty for the default K8S tokens.
    82  	Aud []string `json:"aud"`
    83  }
    84  
    85  // IsK8SUnbound detects if the token is a K8S unbound token.
    86  // It is a regular JWT with no audience and expiration, which can
    87  // be exchanged with bound tokens with audience.
    88  //
    89  // This is used to determine if we check audience in the token.
    90  // Clients should not use unbound tokens except in cases where
    91  // bound tokens are not possible.
    92  func IsK8SUnbound(jwt string) bool {
    93  	aud, f := ExtractJwtAud(jwt)
    94  	if !f {
    95  		return false // unbound tokens are valid JWT
    96  	}
    97  
    98  	return len(aud) == 0
    99  }
   100  
   101  // ExtractJwtAud extracts the audiences from a JWT token. If aud cannot be parse, the bool will be set
   102  // to false. This distinguishes aud=[] from not parsed.
   103  func ExtractJwtAud(jwt string) ([]string, bool) {
   104  	jwtSplit := strings.Split(jwt, ".")
   105  	if len(jwtSplit) != 3 {
   106  		return nil, false
   107  	}
   108  	payload := jwtSplit[1]
   109  
   110  	payloadBytes, err := DecodeJwtPart(payload)
   111  	if err != nil {
   112  		return nil, false
   113  	}
   114  
   115  	structuredPayload := jwtPayload{}
   116  	err = json.Unmarshal(payloadBytes, &structuredPayload)
   117  	if err != nil {
   118  		return nil, false
   119  	}
   120  
   121  	return structuredPayload.Aud, true
   122  }
   123  
   124  func parseJwtClaims(token string) (map[string]any, error) {
   125  	parts := strings.Split(token, ".")
   126  	if len(parts) != 3 {
   127  		return nil, fmt.Errorf("token contains an invalid number of segments: %d, expected: 3", len(parts))
   128  	}
   129  
   130  	// Decode the second part.
   131  	claimBytes, err := DecodeJwtPart(parts[1])
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	dec := json.NewDecoder(bytes.NewBuffer(claimBytes))
   136  
   137  	claims := make(map[string]any)
   138  	if err := dec.Decode(&claims); err != nil {
   139  		return nil, fmt.Errorf("failed to decode the JWT claims")
   140  	}
   141  	return claims, nil
   142  }
   143  
   144  func DecodeJwtPart(seg string) ([]byte, error) {
   145  	if l := len(seg) % 4; l > 0 {
   146  		seg += strings.Repeat("=", 4-l)
   147  	}
   148  
   149  	return base64.URLEncoding.DecodeString(seg)
   150  }