github.com/goldeneggg/goa@v1.3.1/middleware/security/jwt/jwt.go (about) 1 package jwt 2 3 import ( 4 "crypto/ecdsa" 5 "crypto/rsa" 6 "fmt" 7 "net/http" 8 "sort" 9 "strings" 10 11 "context" 12 13 jwt "github.com/dgrijalva/jwt-go" 14 "github.com/goadesign/goa" 15 ) 16 17 // New returns a middleware to be used with the JWTSecurity DSL definitions of goa. It supports the 18 // scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated. 19 // 20 // The steps taken by the middleware are: 21 // 22 // 1. Validate the "Bearer" token present in the "Authorization" header against the key(s) 23 // given to New 24 // 2. If scopes are defined in the design for the action, validate them 25 // against the scopes presented by the JWT in the claim "scope", or if 26 // that's not defined, "scopes". 27 // 28 // The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library. 29 // 30 // validationKeys can be one of these: 31 // 32 // * a string (for HMAC) 33 // * a []byte (for HMAC) 34 // * an rsa.PublicKey 35 // * an ecdsa.PublicKey 36 // * a slice of any of the above 37 // 38 // The type of the keys determine the algorithm that will be used to do the check. The goal of 39 // having lists of keys is to allow for key rotation, still check the previous keys until rotation 40 // has been completed. 41 // 42 // You can define an optional function to do additional validations on the token once the signature 43 // and the claims requirements are proven to be valid. Example: 44 // 45 // validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 46 // token := jwt.ContextJWT(ctx) 47 // if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" { 48 // return jwt.ErrJWTError("you are not uncle ben's") 49 // } 50 // }) 51 // 52 // Mount the middleware with the generated UseXX function where XX is the name of the scheme as 53 // defined in the design, e.g.: 54 // 55 // app.UseJWT(jwt.New("secret", validationHandler, app.NewJWTSecurity())) 56 // 57 func New(validationKeys interface{}, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware { 58 var rsaKeys []*rsa.PublicKey 59 var hmacKeys [][]byte 60 61 rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(validationKeys) 62 63 return func(nextHandler goa.Handler) goa.Handler { 64 return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 65 // TODO: implement the QUERY string handler too 66 if scheme.In != goa.LocHeader { 67 return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In) 68 } 69 val := req.Header.Get(scheme.Name) 70 if val == "" { 71 return ErrJWTError(fmt.Sprintf("missing header %q", scheme.Name)) 72 } 73 74 if !strings.HasPrefix(strings.ToLower(val), "bearer ") { 75 return ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Authorization: Bearer JWT-token...'", val)) 76 } 77 78 incomingToken := strings.Split(val, " ")[1] 79 80 var ( 81 token *jwt.Token 82 err error 83 validated = false 84 ) 85 86 if len(rsaKeys) > 0 { 87 token, err = validateRSAKeys(rsaKeys, "RS", incomingToken) 88 validated = err == nil 89 } 90 91 if !validated && len(ecdsaKeys) > 0 { 92 token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken) 93 validated = err == nil 94 } 95 96 if !validated && len(hmacKeys) > 0 { 97 token, err = validateHMACKeys(hmacKeys, "HS", incomingToken) 98 //validated = err == nil 99 } 100 101 if err != nil { 102 return ErrJWTError(fmt.Sprintf("JWT validation failed: %s", err)) 103 } 104 105 scopesInClaim, scopesInClaimList, err := parseClaimScopes(token) 106 if err != nil { 107 goa.LogError(ctx, err.Error()) 108 return ErrJWTError(err) 109 } 110 111 requiredScopes := goa.ContextRequiredScopes(ctx) 112 113 for _, scope := range requiredScopes { 114 if !scopesInClaim[scope] { 115 msg := "authorization failed: required 'scope' or 'scopes' not present in JWT claim" 116 return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList) 117 } 118 } 119 120 ctx = WithJWT(ctx, token) 121 if validationFunc != nil { 122 nextHandler = validationFunc(nextHandler) 123 } 124 return nextHandler(ctx, rw, req) 125 } 126 } 127 } 128 129 // validScopeClaimKeys are the claims under which scopes may be found in a token 130 var validScopeClaimKeys = []string{"scope", "scopes"} 131 132 // parseClaimScopes parses the "scope" or "scopes" parameter in the Claims. It 133 // supports two formats: 134 // 135 // * a list of strings 136 // 137 // * a single string with space-separated scopes (akin to OAuth2's "scope"). 138 // 139 // An empty string is an explicit claim of no scopes. 140 func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) { 141 scopesInClaim := make(map[string]bool) 142 var scopesInClaimList []string 143 claims, ok := token.Claims.(jwt.MapClaims) 144 if !ok { 145 return nil, nil, fmt.Errorf("unsupport claims shape") 146 } 147 for _, k := range validScopeClaimKeys { 148 if rawscopes, ok := claims[k]; ok && rawscopes != nil { 149 switch scopes := rawscopes.(type) { 150 case string: 151 for _, scope := range strings.Split(scopes, " ") { 152 scopesInClaim[scope] = true 153 scopesInClaimList = append(scopesInClaimList, scope) 154 } 155 case []interface{}: 156 for _, scope := range scopes { 157 if val, ok := scope.(string); ok { 158 scopesInClaim[val] = true 159 scopesInClaimList = append(scopesInClaimList, val) 160 } 161 } 162 default: 163 return nil, nil, fmt.Errorf("unsupported scope format in incoming JWT claim, was type %T", scopes) 164 } 165 break 166 } 167 } 168 sort.Strings(scopesInClaimList) 169 return scopesInClaim, scopesInClaimList, nil 170 } 171 172 // ErrJWTError is the error returned by this middleware when any sort of validation or assertion 173 // fails during processing. 174 var ErrJWTError = goa.NewErrorClass("jwt_security_error", 401) 175 176 type contextKey int 177 178 const ( 179 jwtKey contextKey = iota + 1 180 ) 181 182 // partitionKeys sorts keys by their type. 183 func partitionKeys(k interface{}) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) { 184 var ( 185 rsaKeys []*rsa.PublicKey 186 ecdsaKeys []*ecdsa.PublicKey 187 hmacKeys [][]byte 188 ) 189 190 switch typed := k.(type) { 191 case []byte: 192 hmacKeys = append(hmacKeys, typed) 193 case [][]byte: 194 hmacKeys = typed 195 case string: 196 hmacKeys = append(hmacKeys, []byte(typed)) 197 case []string: 198 for _, s := range typed { 199 hmacKeys = append(hmacKeys, []byte(s)) 200 } 201 case *rsa.PublicKey: 202 rsaKeys = append(rsaKeys, typed) 203 case []*rsa.PublicKey: 204 rsaKeys = typed 205 case *ecdsa.PublicKey: 206 ecdsaKeys = append(ecdsaKeys, typed) 207 case []*ecdsa.PublicKey: 208 ecdsaKeys = typed 209 } 210 211 return rsaKeys, ecdsaKeys, hmacKeys 212 } 213 214 func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) { 215 for _, pubkey := range rsaKeys { 216 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 217 if !strings.HasPrefix(token.Method.Alg(), algo) { 218 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 219 } 220 return pubkey, nil 221 }) 222 if err == nil { 223 return 224 } 225 } 226 return 227 } 228 229 func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) { 230 for _, pubkey := range ecdsaKeys { 231 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 232 if !strings.HasPrefix(token.Method.Alg(), algo) { 233 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 234 } 235 return pubkey, nil 236 }) 237 if err == nil { 238 return 239 } 240 } 241 return 242 } 243 244 func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) { 245 for _, key := range hmacKeys { 246 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 247 if !strings.HasPrefix(token.Method.Alg(), algo) { 248 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 249 } 250 return key, nil 251 }) 252 if err == nil { 253 return 254 } 255 } 256 return 257 }