github.com/zak-blake/goa@v1.4.1/middleware/security/jwt/jwt.go (about)

     1  package jwt
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/rsa"
     6  	"fmt"
     7  	"net/http"
     8  	"sort"
     9  	"strings"
    10  
    11  	"context"
    12  
    13  	jwt "github.com/dgrijalva/jwt-go"
    14  	"github.com/goadesign/goa"
    15  )
    16  
    17  // New returns a middleware to be used with the JWTSecurity DSL definitions of goa.  It supports the
    18  // scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated.
    19  //
    20  // The steps taken by the middleware are:
    21  //
    22  //     1. Extract the "Bearer" token from the Authorization header or query parameter
    23  //     2. Validate the "Bearer" token against the key(s)
    24  //        given to New
    25  //     3. If scopes are defined in the design for the action, validate them
    26  //        against the scopes presented by the JWT in the claim "scope", or if
    27  //        that's not defined, "scopes".
    28  //
    29  // The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library.
    30  //
    31  // validationKeys can be one of these:
    32  //
    33  //     * a string (for HMAC)
    34  //     * a []byte (for HMAC)
    35  //     * an rsa.PublicKey
    36  //     * an ecdsa.PublicKey
    37  //     * a slice of any of the above
    38  //
    39  // The type of the keys determine the algorithm that will be used to do the check.  The goal of
    40  // having lists of keys is to allow for key rotation, still check the previous keys until rotation
    41  // has been completed.
    42  //
    43  // You can define an optional function to do additional validations on the token once the signature
    44  // and the claims requirements are proven to be valid.  Example:
    45  //
    46  //    validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
    47  //        token := jwt.ContextJWT(ctx)
    48  //        if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" {
    49  //            return jwt.ErrJWTError("you are not uncle ben's")
    50  //        }
    51  //    })
    52  //
    53  // Mount the middleware with the generated UseXX function where XX is the name of the scheme as
    54  // defined in the design, e.g.:
    55  //
    56  //    app.UseJWT(jwt.New("secret", validationHandler, app.NewJWTSecurity()))
    57  //
    58  func New(validationKeys interface{}, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware {
    59  	var rsaKeys []*rsa.PublicKey
    60  	var hmacKeys [][]byte
    61  
    62  	rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(validationKeys)
    63  
    64  	return func(nextHandler goa.Handler) goa.Handler {
    65  		return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    66  			var (
    67  				incomingToken string
    68  				err           error
    69  			)
    70  
    71  			if scheme.In == goa.LocHeader {
    72  				if incomingToken, err = extractTokenFromHeader(scheme.Name, req); err != nil {
    73  					return err
    74  				}
    75  			} else if scheme.In == goa.LocQuery {
    76  				if incomingToken, err = extractTokenFromQueryParam(scheme.Name, req); err != nil {
    77  					return err
    78  				}
    79  			} else {
    80  				return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In)
    81  			}
    82  
    83  			var (
    84  				token     *jwt.Token
    85  				validated = false
    86  			)
    87  
    88  			if len(rsaKeys) > 0 {
    89  				token, err = validateRSAKeys(rsaKeys, "RS", incomingToken)
    90  				validated = err == nil
    91  			}
    92  
    93  			if !validated && len(ecdsaKeys) > 0 {
    94  				token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken)
    95  				validated = err == nil
    96  			}
    97  
    98  			if !validated && len(hmacKeys) > 0 {
    99  				token, err = validateHMACKeys(hmacKeys, "HS", incomingToken)
   100  				//validated = err == nil
   101  			}
   102  
   103  			if err != nil {
   104  				return ErrJWTError(fmt.Sprintf("JWT validation failed: %s", err))
   105  			}
   106  
   107  			scopesInClaim, scopesInClaimList, err := parseClaimScopes(token)
   108  			if err != nil {
   109  				goa.LogError(ctx, err.Error())
   110  				return ErrJWTError(err)
   111  			}
   112  
   113  			requiredScopes := goa.ContextRequiredScopes(ctx)
   114  
   115  			for _, scope := range requiredScopes {
   116  				if !scopesInClaim[scope] {
   117  					msg := "authorization failed: required 'scope' or 'scopes' not present in JWT claim"
   118  					return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList)
   119  				}
   120  			}
   121  
   122  			ctx = WithJWT(ctx, token)
   123  			if validationFunc != nil {
   124  				nextHandler = validationFunc(nextHandler)
   125  			}
   126  			return nextHandler(ctx, rw, req)
   127  		}
   128  	}
   129  }
   130  
   131  func extractTokenFromHeader(schemeName string, req *http.Request) (string, error) {
   132  	val := req.Header.Get(schemeName)
   133  	if val == "" {
   134  		return "", ErrJWTError(fmt.Sprintf("missing header %q", schemeName))
   135  	}
   136  
   137  	if !strings.HasPrefix(strings.ToLower(val), "bearer ") {
   138  		return "", ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Bearer JWT-token...'", val))
   139  	}
   140  
   141  	incomingToken := strings.Split(val, " ")[1]
   142  
   143  	return incomingToken, nil
   144  }
   145  
   146  func extractTokenFromQueryParam(schemeName string, req *http.Request) (string, error) {
   147  	incomingToken := req.URL.Query().Get(schemeName)
   148  	if incomingToken == "" {
   149  		return "", ErrJWTError(fmt.Sprintf("missing parameter %q", schemeName))
   150  	}
   151  
   152  	return incomingToken, nil
   153  }
   154  
   155  // validScopeClaimKeys are the claims under which scopes may be found in a token
   156  var validScopeClaimKeys = []string{"scope", "scopes"}
   157  
   158  // parseClaimScopes parses the "scope" or "scopes" parameter in the Claims. It
   159  // supports two formats:
   160  //
   161  // * a list of strings
   162  //
   163  // * a single string with space-separated scopes (akin to OAuth2's "scope").
   164  //
   165  // An empty string is an explicit claim of no scopes.
   166  func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) {
   167  	scopesInClaim := make(map[string]bool)
   168  	var scopesInClaimList []string
   169  	claims, ok := token.Claims.(jwt.MapClaims)
   170  	if !ok {
   171  		return nil, nil, fmt.Errorf("unsupport claims shape")
   172  	}
   173  	for _, k := range validScopeClaimKeys {
   174  		if rawscopes, ok := claims[k]; ok && rawscopes != nil {
   175  			switch scopes := rawscopes.(type) {
   176  			case string:
   177  				for _, scope := range strings.Split(scopes, " ") {
   178  					scopesInClaim[scope] = true
   179  					scopesInClaimList = append(scopesInClaimList, scope)
   180  				}
   181  			case []interface{}:
   182  				for _, scope := range scopes {
   183  					if val, ok := scope.(string); ok {
   184  						scopesInClaim[val] = true
   185  						scopesInClaimList = append(scopesInClaimList, val)
   186  					}
   187  				}
   188  			default:
   189  				return nil, nil, fmt.Errorf("unsupported scope format in incoming JWT claim, was type %T", scopes)
   190  			}
   191  			break
   192  		}
   193  	}
   194  	sort.Strings(scopesInClaimList)
   195  	return scopesInClaim, scopesInClaimList, nil
   196  }
   197  
   198  // ErrJWTError is the error returned by this middleware when any sort of validation or assertion
   199  // fails during processing.
   200  var ErrJWTError = goa.NewErrorClass("jwt_security_error", 401)
   201  
   202  type contextKey int
   203  
   204  const (
   205  	jwtKey contextKey = iota + 1
   206  )
   207  
   208  // partitionKeys sorts keys by their type.
   209  func partitionKeys(k interface{}) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) {
   210  	var (
   211  		rsaKeys   []*rsa.PublicKey
   212  		ecdsaKeys []*ecdsa.PublicKey
   213  		hmacKeys  [][]byte
   214  	)
   215  
   216  	switch typed := k.(type) {
   217  	case []byte:
   218  		hmacKeys = append(hmacKeys, typed)
   219  	case [][]byte:
   220  		hmacKeys = typed
   221  	case string:
   222  		hmacKeys = append(hmacKeys, []byte(typed))
   223  	case []string:
   224  		for _, s := range typed {
   225  			hmacKeys = append(hmacKeys, []byte(s))
   226  		}
   227  	case *rsa.PublicKey:
   228  		rsaKeys = append(rsaKeys, typed)
   229  	case []*rsa.PublicKey:
   230  		rsaKeys = typed
   231  	case *ecdsa.PublicKey:
   232  		ecdsaKeys = append(ecdsaKeys, typed)
   233  	case []*ecdsa.PublicKey:
   234  		ecdsaKeys = typed
   235  	}
   236  
   237  	return rsaKeys, ecdsaKeys, hmacKeys
   238  }
   239  
   240  func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
   241  	for _, pubkey := range rsaKeys {
   242  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   243  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   244  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   245  			}
   246  			return pubkey, nil
   247  		})
   248  		if err == nil {
   249  			return
   250  		}
   251  	}
   252  	return
   253  }
   254  
   255  func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
   256  	for _, pubkey := range ecdsaKeys {
   257  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   258  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   259  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   260  			}
   261  			return pubkey, nil
   262  		})
   263  		if err == nil {
   264  			return
   265  		}
   266  	}
   267  	return
   268  }
   269  
   270  func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) {
   271  	for _, key := range hmacKeys {
   272  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   273  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   274  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   275  			}
   276  			return key, nil
   277  		})
   278  		if err == nil {
   279  			return
   280  		}
   281  	}
   282  	return
   283  }