github.com/zak-blake/goa@v1.4.1/middleware/security/jwt/jwt.go (about) 1 package jwt 2 3 import ( 4 "crypto/ecdsa" 5 "crypto/rsa" 6 "fmt" 7 "net/http" 8 "sort" 9 "strings" 10 11 "context" 12 13 jwt "github.com/dgrijalva/jwt-go" 14 "github.com/goadesign/goa" 15 ) 16 17 // New returns a middleware to be used with the JWTSecurity DSL definitions of goa. It supports the 18 // scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated. 19 // 20 // The steps taken by the middleware are: 21 // 22 // 1. Extract the "Bearer" token from the Authorization header or query parameter 23 // 2. Validate the "Bearer" token against the key(s) 24 // given to New 25 // 3. If scopes are defined in the design for the action, validate them 26 // against the scopes presented by the JWT in the claim "scope", or if 27 // that's not defined, "scopes". 28 // 29 // The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library. 30 // 31 // validationKeys can be one of these: 32 // 33 // * a string (for HMAC) 34 // * a []byte (for HMAC) 35 // * an rsa.PublicKey 36 // * an ecdsa.PublicKey 37 // * a slice of any of the above 38 // 39 // The type of the keys determine the algorithm that will be used to do the check. The goal of 40 // having lists of keys is to allow for key rotation, still check the previous keys until rotation 41 // has been completed. 42 // 43 // You can define an optional function to do additional validations on the token once the signature 44 // and the claims requirements are proven to be valid. Example: 45 // 46 // validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 47 // token := jwt.ContextJWT(ctx) 48 // if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" { 49 // return jwt.ErrJWTError("you are not uncle ben's") 50 // } 51 // }) 52 // 53 // Mount the middleware with the generated UseXX function where XX is the name of the scheme as 54 // defined in the design, e.g.: 55 // 56 // app.UseJWT(jwt.New("secret", validationHandler, app.NewJWTSecurity())) 57 // 58 func New(validationKeys interface{}, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware { 59 var rsaKeys []*rsa.PublicKey 60 var hmacKeys [][]byte 61 62 rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(validationKeys) 63 64 return func(nextHandler goa.Handler) goa.Handler { 65 return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 66 var ( 67 incomingToken string 68 err error 69 ) 70 71 if scheme.In == goa.LocHeader { 72 if incomingToken, err = extractTokenFromHeader(scheme.Name, req); err != nil { 73 return err 74 } 75 } else if scheme.In == goa.LocQuery { 76 if incomingToken, err = extractTokenFromQueryParam(scheme.Name, req); err != nil { 77 return err 78 } 79 } else { 80 return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In) 81 } 82 83 var ( 84 token *jwt.Token 85 validated = false 86 ) 87 88 if len(rsaKeys) > 0 { 89 token, err = validateRSAKeys(rsaKeys, "RS", incomingToken) 90 validated = err == nil 91 } 92 93 if !validated && len(ecdsaKeys) > 0 { 94 token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken) 95 validated = err == nil 96 } 97 98 if !validated && len(hmacKeys) > 0 { 99 token, err = validateHMACKeys(hmacKeys, "HS", incomingToken) 100 //validated = err == nil 101 } 102 103 if err != nil { 104 return ErrJWTError(fmt.Sprintf("JWT validation failed: %s", err)) 105 } 106 107 scopesInClaim, scopesInClaimList, err := parseClaimScopes(token) 108 if err != nil { 109 goa.LogError(ctx, err.Error()) 110 return ErrJWTError(err) 111 } 112 113 requiredScopes := goa.ContextRequiredScopes(ctx) 114 115 for _, scope := range requiredScopes { 116 if !scopesInClaim[scope] { 117 msg := "authorization failed: required 'scope' or 'scopes' not present in JWT claim" 118 return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList) 119 } 120 } 121 122 ctx = WithJWT(ctx, token) 123 if validationFunc != nil { 124 nextHandler = validationFunc(nextHandler) 125 } 126 return nextHandler(ctx, rw, req) 127 } 128 } 129 } 130 131 func extractTokenFromHeader(schemeName string, req *http.Request) (string, error) { 132 val := req.Header.Get(schemeName) 133 if val == "" { 134 return "", ErrJWTError(fmt.Sprintf("missing header %q", schemeName)) 135 } 136 137 if !strings.HasPrefix(strings.ToLower(val), "bearer ") { 138 return "", ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Bearer JWT-token...'", val)) 139 } 140 141 incomingToken := strings.Split(val, " ")[1] 142 143 return incomingToken, nil 144 } 145 146 func extractTokenFromQueryParam(schemeName string, req *http.Request) (string, error) { 147 incomingToken := req.URL.Query().Get(schemeName) 148 if incomingToken == "" { 149 return "", ErrJWTError(fmt.Sprintf("missing parameter %q", schemeName)) 150 } 151 152 return incomingToken, nil 153 } 154 155 // validScopeClaimKeys are the claims under which scopes may be found in a token 156 var validScopeClaimKeys = []string{"scope", "scopes"} 157 158 // parseClaimScopes parses the "scope" or "scopes" parameter in the Claims. It 159 // supports two formats: 160 // 161 // * a list of strings 162 // 163 // * a single string with space-separated scopes (akin to OAuth2's "scope"). 164 // 165 // An empty string is an explicit claim of no scopes. 166 func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) { 167 scopesInClaim := make(map[string]bool) 168 var scopesInClaimList []string 169 claims, ok := token.Claims.(jwt.MapClaims) 170 if !ok { 171 return nil, nil, fmt.Errorf("unsupport claims shape") 172 } 173 for _, k := range validScopeClaimKeys { 174 if rawscopes, ok := claims[k]; ok && rawscopes != nil { 175 switch scopes := rawscopes.(type) { 176 case string: 177 for _, scope := range strings.Split(scopes, " ") { 178 scopesInClaim[scope] = true 179 scopesInClaimList = append(scopesInClaimList, scope) 180 } 181 case []interface{}: 182 for _, scope := range scopes { 183 if val, ok := scope.(string); ok { 184 scopesInClaim[val] = true 185 scopesInClaimList = append(scopesInClaimList, val) 186 } 187 } 188 default: 189 return nil, nil, fmt.Errorf("unsupported scope format in incoming JWT claim, was type %T", scopes) 190 } 191 break 192 } 193 } 194 sort.Strings(scopesInClaimList) 195 return scopesInClaim, scopesInClaimList, nil 196 } 197 198 // ErrJWTError is the error returned by this middleware when any sort of validation or assertion 199 // fails during processing. 200 var ErrJWTError = goa.NewErrorClass("jwt_security_error", 401) 201 202 type contextKey int 203 204 const ( 205 jwtKey contextKey = iota + 1 206 ) 207 208 // partitionKeys sorts keys by their type. 209 func partitionKeys(k interface{}) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) { 210 var ( 211 rsaKeys []*rsa.PublicKey 212 ecdsaKeys []*ecdsa.PublicKey 213 hmacKeys [][]byte 214 ) 215 216 switch typed := k.(type) { 217 case []byte: 218 hmacKeys = append(hmacKeys, typed) 219 case [][]byte: 220 hmacKeys = typed 221 case string: 222 hmacKeys = append(hmacKeys, []byte(typed)) 223 case []string: 224 for _, s := range typed { 225 hmacKeys = append(hmacKeys, []byte(s)) 226 } 227 case *rsa.PublicKey: 228 rsaKeys = append(rsaKeys, typed) 229 case []*rsa.PublicKey: 230 rsaKeys = typed 231 case *ecdsa.PublicKey: 232 ecdsaKeys = append(ecdsaKeys, typed) 233 case []*ecdsa.PublicKey: 234 ecdsaKeys = typed 235 } 236 237 return rsaKeys, ecdsaKeys, hmacKeys 238 } 239 240 func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) { 241 for _, pubkey := range rsaKeys { 242 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 243 if !strings.HasPrefix(token.Method.Alg(), algo) { 244 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 245 } 246 return pubkey, nil 247 }) 248 if err == nil { 249 return 250 } 251 } 252 return 253 } 254 255 func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) { 256 for _, pubkey := range ecdsaKeys { 257 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 258 if !strings.HasPrefix(token.Method.Alg(), algo) { 259 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 260 } 261 return pubkey, nil 262 }) 263 if err == nil { 264 return 265 } 266 } 267 return 268 } 269 270 func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) { 271 for _, key := range hmacKeys { 272 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 273 if !strings.HasPrefix(token.Method.Alg(), algo) { 274 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 275 } 276 return key, nil 277 }) 278 if err == nil { 279 return 280 } 281 } 282 return 283 }