github.com/brycereitano/goa@v0.0.0-20170315073847-8ffa6c85e265/middleware/security/jwt/jwt.go (about) 1 package jwt 2 3 import ( 4 "crypto/ecdsa" 5 "crypto/rsa" 6 "fmt" 7 "net/http" 8 "sort" 9 "strings" 10 11 jwt "github.com/dgrijalva/jwt-go" 12 "github.com/goadesign/goa" 13 "golang.org/x/net/context" 14 ) 15 16 // New returns a middleware to be used with the JWTSecurity DSL definitions of goa. It supports the 17 // scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated. 18 // 19 // The steps taken by the middleware are: 20 // 21 // 1. Validate the "Bearer" token present in the "Authorization" header against the key(s) 22 // given to New 23 // 2. If scopes are defined in the design for the action validate them against the "scopes" JWT 24 // claim 25 // 26 // The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library. 27 // 28 // validationKeys can be one of these: 29 // 30 // * []byte 31 // * string 32 // * an *rsa.PublicKey 33 // * an *ecdsa.PublicKey 34 // * a slice of any of the above 35 // 36 // Keys of type string or []byte are interpreted according to the signing method defined in the JWT 37 // token's `typ` header element: `HS`, `RS`, `ES`, etc. 38 // 39 // You can define an optional function to do additional validations on the token once the signature 40 // and the claims requirements are proven to be valid. Example: 41 // 42 // validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { 43 // token := jwt.ContextJWT(ctx) 44 // if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" { 45 // return jwt.ErrJWTError("you are not uncle ben's") 46 // } 47 // }) 48 // 49 // Mount the middleware with the generated UseXX function where XX is the name of the scheme as 50 // defined in the design, e.g.: 51 // 52 // jwtResolver, _ := jwt.NewSimpleResolver("secret") 53 // app.UseJWT(jwt.New(jwtResolver, validationHandler, app.NewJWTSecurity())) 54 // 55 func New(resolver KeyResolver, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware { 56 return func(nextHandler goa.Handler) goa.Handler { 57 return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 58 // TODO: implement the QUERY string handler too 59 if scheme.In != goa.LocHeader { 60 return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In) 61 } 62 val := req.Header.Get(scheme.Name) 63 if val == "" { 64 return ErrJWTError(fmt.Sprintf("missing header %q", scheme.Name)) 65 } 66 67 if !strings.HasPrefix(strings.ToLower(val), "bearer ") { 68 return ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Bearer JWT-token...'", val)) 69 } 70 71 incomingToken := strings.Split(val, " ")[1] 72 73 rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(resolver.SelectKeys(req)) 74 75 var ( 76 token *jwt.Token 77 err error 78 validated = false 79 ) 80 81 if len(rsaKeys) > 0 { 82 token, err = validateRSAKeys(rsaKeys, "RS", incomingToken) 83 if err == nil { 84 validated = true 85 } 86 } 87 88 if !validated && len(ecdsaKeys) > 0 { 89 token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken) 90 if err == nil { 91 validated = true 92 } 93 } 94 95 if !validated && len(hmacKeys) > 0 { 96 token, err = validateHMACKeys(hmacKeys, "HS", incomingToken) 97 if err == nil { 98 validated = true 99 } 100 } 101 102 if !validated { 103 return ErrJWTError("JWT validation failed") 104 } 105 106 scopesInClaim, scopesInClaimList, err := parseClaimScopes(token) 107 if err != nil { 108 goa.LogError(ctx, err.Error()) 109 return ErrJWTError(err) 110 } 111 112 requiredScopes := goa.ContextRequiredScopes(ctx) 113 114 for _, scope := range requiredScopes { 115 if !scopesInClaim[scope] { 116 msg := "authorization failed: required 'scopes' not present in JWT claim" 117 return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList) 118 } 119 } 120 121 ctx = WithJWT(ctx, token) 122 if validationFunc != nil { 123 nextHandler = validationFunc(nextHandler) 124 } 125 return nextHandler(ctx, rw, req) 126 } 127 } 128 } 129 130 // partitionKeys sorts keys by their type. 131 func partitionKeys(keys []Key) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) { 132 var ( 133 rsaKeys []*rsa.PublicKey 134 ecdsaKeys []*ecdsa.PublicKey 135 hmacKeys [][]byte 136 ) 137 138 for _, key := range keys { 139 switch k := key.(type) { 140 case *rsa.PublicKey: 141 rsaKeys = append(rsaKeys, k) 142 case *ecdsa.PublicKey: 143 ecdsaKeys = append(ecdsaKeys, k) 144 case []byte: 145 hmacKeys = append(hmacKeys, k) 146 case string: 147 hmacKeys = append(hmacKeys, []byte(k)) 148 } 149 } 150 151 return rsaKeys, ecdsaKeys, hmacKeys 152 } 153 154 // parseClaimScopes parses the "scopes" parameter in the Claims. It supports two formats: 155 // 156 // * a list of string 157 // 158 // * a single string with space-separated scopes (akin to OAuth2's "scope"). 159 func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) { 160 scopesInClaim := make(map[string]bool) 161 var scopesInClaimList []string 162 claims, ok := token.Claims.(jwt.MapClaims) 163 if !ok { 164 return nil, nil, fmt.Errorf("unsupported claims shape") 165 } 166 if claims["scopes"] != nil { 167 switch scopes := claims["scopes"].(type) { 168 case string: 169 for _, scope := range strings.Split(scopes, " ") { 170 scopesInClaim[scope] = true 171 scopesInClaimList = append(scopesInClaimList, scope) 172 } 173 case []interface{}: 174 for _, scope := range scopes { 175 if val, ok := scope.(string); ok { 176 scopesInClaim[val] = true 177 scopesInClaimList = append(scopesInClaimList, val) 178 } 179 } 180 default: 181 return nil, nil, fmt.Errorf("unsupported 'scopes' format in incoming JWT claim, was type %T", scopes) 182 } 183 } 184 sort.Strings(scopesInClaimList) 185 return scopesInClaim, scopesInClaimList, nil 186 } 187 188 func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) { 189 for _, pubkey := range rsaKeys { 190 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 191 if !strings.HasPrefix(token.Method.Alg(), algo) { 192 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 193 } 194 return pubkey, nil 195 }) 196 if err == nil { 197 return 198 } 199 } 200 return 201 } 202 203 func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) { 204 for _, pubkey := range ecdsaKeys { 205 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 206 if !strings.HasPrefix(token.Method.Alg(), algo) { 207 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 208 } 209 return pubkey, nil 210 }) 211 if err == nil { 212 return 213 } 214 } 215 return 216 } 217 218 func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) { 219 for _, key := range hmacKeys { 220 token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) { 221 if !strings.HasPrefix(token.Method.Alg(), algo) { 222 return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"])) 223 } 224 return key, nil 225 }) 226 if err == nil { 227 return 228 } 229 } 230 return 231 }