github.com/brycereitano/goa@v0.0.0-20170315073847-8ffa6c85e265/middleware/security/jwt/jwt.go (about)

     1  package jwt
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/rsa"
     6  	"fmt"
     7  	"net/http"
     8  	"sort"
     9  	"strings"
    10  
    11  	jwt "github.com/dgrijalva/jwt-go"
    12  	"github.com/goadesign/goa"
    13  	"golang.org/x/net/context"
    14  )
    15  
    16  // New returns a middleware to be used with the JWTSecurity DSL definitions of goa.  It supports the
    17  // scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated.
    18  //
    19  // The steps taken by the middleware are:
    20  //
    21  //     1. Validate the "Bearer" token present in the "Authorization" header against the key(s)
    22  //        given to New
    23  //     2. If scopes are defined in the design for the action validate them against the "scopes" JWT
    24  //        claim
    25  //
    26  // The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library.
    27  //
    28  // validationKeys can be one of these:
    29  //
    30  //     * []byte
    31  //     * string
    32  //     * an *rsa.PublicKey
    33  //     * an *ecdsa.PublicKey
    34  //     * a slice of any of the above
    35  //
    36  // Keys of type string or []byte are interpreted according to the signing method defined in the JWT
    37  // token's `typ` header element: `HS`, `RS`, `ES`, etc.
    38  //
    39  // You can define an optional function to do additional validations on the token once the signature
    40  // and the claims requirements are proven to be valid.  Example:
    41  //
    42  //    validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
    43  //        token := jwt.ContextJWT(ctx)
    44  //        if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" {
    45  //            return jwt.ErrJWTError("you are not uncle ben's")
    46  //        }
    47  //    })
    48  //
    49  // Mount the middleware with the generated UseXX function where XX is the name of the scheme as
    50  // defined in the design, e.g.:
    51  //
    52  //    jwtResolver, _ := jwt.NewSimpleResolver("secret")
    53  //    app.UseJWT(jwt.New(jwtResolver, validationHandler, app.NewJWTSecurity()))
    54  //
    55  func New(resolver KeyResolver, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware {
    56  	return func(nextHandler goa.Handler) goa.Handler {
    57  		return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    58  			// TODO: implement the QUERY string handler too
    59  			if scheme.In != goa.LocHeader {
    60  				return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In)
    61  			}
    62  			val := req.Header.Get(scheme.Name)
    63  			if val == "" {
    64  				return ErrJWTError(fmt.Sprintf("missing header %q", scheme.Name))
    65  			}
    66  
    67  			if !strings.HasPrefix(strings.ToLower(val), "bearer ") {
    68  				return ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Bearer JWT-token...'", val))
    69  			}
    70  
    71  			incomingToken := strings.Split(val, " ")[1]
    72  
    73  			rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(resolver.SelectKeys(req))
    74  
    75  			var (
    76  				token     *jwt.Token
    77  				err       error
    78  				validated = false
    79  			)
    80  
    81  			if len(rsaKeys) > 0 {
    82  				token, err = validateRSAKeys(rsaKeys, "RS", incomingToken)
    83  				if err == nil {
    84  					validated = true
    85  				}
    86  			}
    87  
    88  			if !validated && len(ecdsaKeys) > 0 {
    89  				token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken)
    90  				if err == nil {
    91  					validated = true
    92  				}
    93  			}
    94  
    95  			if !validated && len(hmacKeys) > 0 {
    96  				token, err = validateHMACKeys(hmacKeys, "HS", incomingToken)
    97  				if err == nil {
    98  					validated = true
    99  				}
   100  			}
   101  
   102  			if !validated {
   103  				return ErrJWTError("JWT validation failed")
   104  			}
   105  
   106  			scopesInClaim, scopesInClaimList, err := parseClaimScopes(token)
   107  			if err != nil {
   108  				goa.LogError(ctx, err.Error())
   109  				return ErrJWTError(err)
   110  			}
   111  
   112  			requiredScopes := goa.ContextRequiredScopes(ctx)
   113  
   114  			for _, scope := range requiredScopes {
   115  				if !scopesInClaim[scope] {
   116  					msg := "authorization failed: required 'scopes' not present in JWT claim"
   117  					return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList)
   118  				}
   119  			}
   120  
   121  			ctx = WithJWT(ctx, token)
   122  			if validationFunc != nil {
   123  				nextHandler = validationFunc(nextHandler)
   124  			}
   125  			return nextHandler(ctx, rw, req)
   126  		}
   127  	}
   128  }
   129  
   130  // partitionKeys sorts keys by their type.
   131  func partitionKeys(keys []Key) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) {
   132  	var (
   133  		rsaKeys   []*rsa.PublicKey
   134  		ecdsaKeys []*ecdsa.PublicKey
   135  		hmacKeys  [][]byte
   136  	)
   137  
   138  	for _, key := range keys {
   139  		switch k := key.(type) {
   140  		case *rsa.PublicKey:
   141  			rsaKeys = append(rsaKeys, k)
   142  		case *ecdsa.PublicKey:
   143  			ecdsaKeys = append(ecdsaKeys, k)
   144  		case []byte:
   145  			hmacKeys = append(hmacKeys, k)
   146  		case string:
   147  			hmacKeys = append(hmacKeys, []byte(k))
   148  		}
   149  	}
   150  
   151  	return rsaKeys, ecdsaKeys, hmacKeys
   152  }
   153  
   154  // parseClaimScopes parses the "scopes" parameter in the Claims. It supports two formats:
   155  //
   156  // * a list of string
   157  //
   158  // * a single string with space-separated scopes (akin to OAuth2's "scope").
   159  func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) {
   160  	scopesInClaim := make(map[string]bool)
   161  	var scopesInClaimList []string
   162  	claims, ok := token.Claims.(jwt.MapClaims)
   163  	if !ok {
   164  		return nil, nil, fmt.Errorf("unsupported claims shape")
   165  	}
   166  	if claims["scopes"] != nil {
   167  		switch scopes := claims["scopes"].(type) {
   168  		case string:
   169  			for _, scope := range strings.Split(scopes, " ") {
   170  				scopesInClaim[scope] = true
   171  				scopesInClaimList = append(scopesInClaimList, scope)
   172  			}
   173  		case []interface{}:
   174  			for _, scope := range scopes {
   175  				if val, ok := scope.(string); ok {
   176  					scopesInClaim[val] = true
   177  					scopesInClaimList = append(scopesInClaimList, val)
   178  				}
   179  			}
   180  		default:
   181  			return nil, nil, fmt.Errorf("unsupported 'scopes' format in incoming JWT claim, was type %T", scopes)
   182  		}
   183  	}
   184  	sort.Strings(scopesInClaimList)
   185  	return scopesInClaim, scopesInClaimList, nil
   186  }
   187  
   188  func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
   189  	for _, pubkey := range rsaKeys {
   190  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   191  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   192  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   193  			}
   194  			return pubkey, nil
   195  		})
   196  		if err == nil {
   197  			return
   198  		}
   199  	}
   200  	return
   201  }
   202  
   203  func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
   204  	for _, pubkey := range ecdsaKeys {
   205  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   206  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   207  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   208  			}
   209  			return pubkey, nil
   210  		})
   211  		if err == nil {
   212  			return
   213  		}
   214  	}
   215  	return
   216  }
   217  
   218  func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) {
   219  	for _, key := range hmacKeys {
   220  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   221  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   222  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   223  			}
   224  			return key, nil
   225  		})
   226  		if err == nil {
   227  			return
   228  		}
   229  	}
   230  	return
   231  }