github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/jwt/parser.go (about)

     1  package jwt
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/x509"
     6  	"encoding/pem"
     7  	"errors"
     8  	"net/http"
     9  	"strings"
    10  
    11  	"github.com/dgrijalva/jwt-go"
    12  )
    13  
    14  var (
    15  	// ErrSigningMethodMismatch is the error returned when token is signed with the method other than verified
    16  	ErrSigningMethodMismatch = errors.New("signing method mismatch")
    17  	// ErrFailedToParseToken is the error returned when token is failed to parse and validate against secret and expiration date
    18  	ErrFailedToParseToken = errors.New("failed to parse token")
    19  	// ErrUnsupportedSigningMethod is the error returned when token is signed with unsupported by the library method
    20  	ErrUnsupportedSigningMethod = errors.New("unsupported signing method")
    21  	// ErrInvalidPEMBlock is the error returned for keys expected to be PEM-encoded
    22  	ErrInvalidPEMBlock = errors.New("invalid RSA: not PEM-encoded")
    23  	// ErrNotRSAPublicKey is the error returned for invalid RSA public key
    24  	ErrNotRSAPublicKey = errors.New("invalid RSA: expected PUBLIC KEY block type")
    25  	// ErrBadPublicKey is the error returned for invalid RSA public key
    26  	ErrBadPublicKey = errors.New("invalid RSA: failed to assert public key")
    27  )
    28  
    29  // SigningMethod defines signing method algorithm and key
    30  type SigningMethod struct {
    31  	// Alg defines JWT signing algorithm. Possible values are: HS256, HS384, HS512, RS256, RS384, RS512
    32  	Alg string `json:"alg"`
    33  	Key string `json:"key"`
    34  }
    35  
    36  // ParserConfig configures the way JWT Parser gets and validates token
    37  type ParserConfig struct {
    38  	// SigningMethods defines chain of token signature verification algorithm/key pairs.
    39  	SigningMethods []SigningMethod
    40  
    41  	// TokenLookup is a string in the form of "<source>:<name>" that is used
    42  	// to extract token from the request.
    43  	// Optional. Default value "header:Authorization".
    44  	// Possible values:
    45  	// - "header:<name>"
    46  	// - "query:<name>"
    47  	// - "cookie:<name>"
    48  	TokenLookup string
    49  
    50  	// Leeway is the time in seconds to account for clock skew when checking nbf, iat or expiration times
    51  	Leeway int64
    52  }
    53  
    54  // NewParserConfig creates a new instance of ParserConfig
    55  func NewParserConfig(leeway int64, signingMethod ...SigningMethod) ParserConfig {
    56  	return ParserConfig{
    57  		SigningMethods: signingMethod,
    58  		TokenLookup:    "header:Authorization",
    59  		Leeway:         leeway,
    60  	}
    61  }
    62  
    63  // Parser struct
    64  type Parser struct {
    65  	Config ParserConfig
    66  }
    67  
    68  // NewParser creates a new instance of Parser
    69  func NewParser(config ParserConfig) *Parser {
    70  	return &Parser{config}
    71  }
    72  
    73  // ParseFromRequest tries to extract and validate token from request.
    74  // See "Guard.TokenLookup" for possible ways to pass token in request.
    75  func (jp *Parser) ParseFromRequest(r *http.Request) (*jwt.Token, error) {
    76  	var token string
    77  	var err error
    78  
    79  	parts := strings.Split(jp.Config.TokenLookup, ":")
    80  	switch parts[0] {
    81  	case "header":
    82  		token, err = jp.jwtFromHeader(r, parts[1])
    83  	case "query":
    84  		token, err = jp.jwtFromQuery(r, parts[1])
    85  	case "cookie":
    86  		token, err = jp.jwtFromCookie(r, parts[1])
    87  	}
    88  
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	return jp.Parse(token)
    94  }
    95  
    96  // Parse a JWT token and validates it
    97  func (jp *Parser) Parse(tokenString string) (*jwt.Token, error) {
    98  	for _, method := range jp.Config.SigningMethods {
    99  		token, err := jwt.ParseWithClaims(tokenString, NewJanusClaims(jp.Config.Leeway), func(token *jwt.Token) (interface{}, error) {
   100  			if token.Method.Alg() != method.Alg {
   101  				return nil, ErrSigningMethodMismatch
   102  			}
   103  
   104  			switch token.Method.(type) {
   105  			case *jwt.SigningMethodHMAC:
   106  				return []byte(method.Key), nil
   107  			case *jwt.SigningMethodRSA:
   108  				block, _ := pem.Decode([]byte(method.Key))
   109  				if block == nil {
   110  					return nil, ErrInvalidPEMBlock
   111  				}
   112  				if got, want := block.Type, "PUBLIC KEY"; got != want {
   113  					return nil, ErrNotRSAPublicKey
   114  				}
   115  				pub, err := x509.ParsePKIXPublicKey(block.Bytes)
   116  				if nil != err {
   117  					return nil, err
   118  				}
   119  
   120  				if _, ok := pub.(*rsa.PublicKey); !ok {
   121  					return nil, ErrBadPublicKey
   122  				}
   123  
   124  				return pub, nil
   125  			default:
   126  				return nil, ErrUnsupportedSigningMethod
   127  			}
   128  		})
   129  
   130  		if err != nil {
   131  			if err == ErrSigningMethodMismatch {
   132  				continue
   133  			}
   134  
   135  			if validationErr, ok := err.(*jwt.ValidationError); ok && (validationErr.Errors&jwt.ValidationErrorUnverifiable > 0 || validationErr.Errors&jwt.ValidationErrorSignatureInvalid > 0) {
   136  				continue
   137  			}
   138  		}
   139  
   140  		return token, err
   141  	}
   142  
   143  	return nil, ErrFailedToParseToken
   144  }
   145  
   146  // GetMapClaims returns a map version of Claims Section
   147  func (jp *Parser) GetMapClaims(token *jwt.Token) (jwt.MapClaims, bool) {
   148  	claims, ok := token.Claims.(*JanusClaims)
   149  	if !ok {
   150  		return jwt.MapClaims{}, ok
   151  	}
   152  	return claims.MapClaims, ok
   153  }
   154  
   155  func (jp *Parser) jwtFromHeader(r *http.Request, key string) (string, error) {
   156  	authHeader := r.Header.Get(key)
   157  
   158  	if authHeader == "" {
   159  		return "", errors.New("auth header empty")
   160  	}
   161  
   162  	parts := strings.SplitN(authHeader, " ", 2)
   163  	if !(len(parts) == 2 && parts[0] == "Bearer") {
   164  		return "", errors.New("invalid auth header")
   165  	}
   166  
   167  	return parts[1], nil
   168  }
   169  
   170  func (jp *Parser) jwtFromQuery(r *http.Request, key string) (string, error) {
   171  	token := r.URL.Query().Get(key)
   172  
   173  	if token == "" {
   174  		return "", errors.New("query token empty")
   175  	}
   176  
   177  	return token, nil
   178  }
   179  
   180  func (jp *Parser) jwtFromCookie(r *http.Request, key string) (string, error) {
   181  	cookie, _ := r.Cookie(key)
   182  
   183  	if nil == cookie {
   184  		return "", errors.New("cookie token empty")
   185  	}
   186  
   187  	return cookie.Value, nil
   188  }