github.com/goldeneggg/goa@v1.3.1/middleware/security/jwt/jwt.go (about)

     1  package jwt
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/rsa"
     6  	"fmt"
     7  	"net/http"
     8  	"sort"
     9  	"strings"
    10  
    11  	"context"
    12  
    13  	jwt "github.com/dgrijalva/jwt-go"
    14  	"github.com/goadesign/goa"
    15  )
    16  
    17  // New returns a middleware to be used with the JWTSecurity DSL definitions of goa.  It supports the
    18  // scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated.
    19  //
    20  // The steps taken by the middleware are:
    21  //
    22  //     1. Validate the "Bearer" token present in the "Authorization" header against the key(s)
    23  //        given to New
    24  //     2. If scopes are defined in the design for the action, validate them
    25  //        against the scopes presented by the JWT in the claim "scope", or if
    26  //        that's not defined, "scopes".
    27  //
    28  // The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library.
    29  //
    30  // validationKeys can be one of these:
    31  //
    32  //     * a string (for HMAC)
    33  //     * a []byte (for HMAC)
    34  //     * an rsa.PublicKey
    35  //     * an ecdsa.PublicKey
    36  //     * a slice of any of the above
    37  //
    38  // The type of the keys determine the algorithm that will be used to do the check.  The goal of
    39  // having lists of keys is to allow for key rotation, still check the previous keys until rotation
    40  // has been completed.
    41  //
    42  // You can define an optional function to do additional validations on the token once the signature
    43  // and the claims requirements are proven to be valid.  Example:
    44  //
    45  //    validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
    46  //        token := jwt.ContextJWT(ctx)
    47  //        if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" {
    48  //            return jwt.ErrJWTError("you are not uncle ben's")
    49  //        }
    50  //    })
    51  //
    52  // Mount the middleware with the generated UseXX function where XX is the name of the scheme as
    53  // defined in the design, e.g.:
    54  //
    55  //    app.UseJWT(jwt.New("secret", validationHandler, app.NewJWTSecurity()))
    56  //
    57  func New(validationKeys interface{}, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware {
    58  	var rsaKeys []*rsa.PublicKey
    59  	var hmacKeys [][]byte
    60  
    61  	rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(validationKeys)
    62  
    63  	return func(nextHandler goa.Handler) goa.Handler {
    64  		return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    65  			// TODO: implement the QUERY string handler too
    66  			if scheme.In != goa.LocHeader {
    67  				return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In)
    68  			}
    69  			val := req.Header.Get(scheme.Name)
    70  			if val == "" {
    71  				return ErrJWTError(fmt.Sprintf("missing header %q", scheme.Name))
    72  			}
    73  
    74  			if !strings.HasPrefix(strings.ToLower(val), "bearer ") {
    75  				return ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Authorization: Bearer JWT-token...'", val))
    76  			}
    77  
    78  			incomingToken := strings.Split(val, " ")[1]
    79  
    80  			var (
    81  				token     *jwt.Token
    82  				err       error
    83  				validated = false
    84  			)
    85  
    86  			if len(rsaKeys) > 0 {
    87  				token, err = validateRSAKeys(rsaKeys, "RS", incomingToken)
    88  				validated = err == nil
    89  			}
    90  
    91  			if !validated && len(ecdsaKeys) > 0 {
    92  				token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken)
    93  				validated = err == nil
    94  			}
    95  
    96  			if !validated && len(hmacKeys) > 0 {
    97  				token, err = validateHMACKeys(hmacKeys, "HS", incomingToken)
    98  				//validated = err == nil
    99  			}
   100  
   101  			if err != nil {
   102  				return ErrJWTError(fmt.Sprintf("JWT validation failed: %s", err))
   103  			}
   104  
   105  			scopesInClaim, scopesInClaimList, err := parseClaimScopes(token)
   106  			if err != nil {
   107  				goa.LogError(ctx, err.Error())
   108  				return ErrJWTError(err)
   109  			}
   110  
   111  			requiredScopes := goa.ContextRequiredScopes(ctx)
   112  
   113  			for _, scope := range requiredScopes {
   114  				if !scopesInClaim[scope] {
   115  					msg := "authorization failed: required 'scope' or 'scopes' not present in JWT claim"
   116  					return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList)
   117  				}
   118  			}
   119  
   120  			ctx = WithJWT(ctx, token)
   121  			if validationFunc != nil {
   122  				nextHandler = validationFunc(nextHandler)
   123  			}
   124  			return nextHandler(ctx, rw, req)
   125  		}
   126  	}
   127  }
   128  
   129  // validScopeClaimKeys are the claims under which scopes may be found in a token
   130  var validScopeClaimKeys = []string{"scope", "scopes"}
   131  
   132  // parseClaimScopes parses the "scope" or "scopes" parameter in the Claims. It
   133  // supports two formats:
   134  //
   135  // * a list of strings
   136  //
   137  // * a single string with space-separated scopes (akin to OAuth2's "scope").
   138  //
   139  // An empty string is an explicit claim of no scopes.
   140  func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) {
   141  	scopesInClaim := make(map[string]bool)
   142  	var scopesInClaimList []string
   143  	claims, ok := token.Claims.(jwt.MapClaims)
   144  	if !ok {
   145  		return nil, nil, fmt.Errorf("unsupport claims shape")
   146  	}
   147  	for _, k := range validScopeClaimKeys {
   148  		if rawscopes, ok := claims[k]; ok && rawscopes != nil {
   149  			switch scopes := rawscopes.(type) {
   150  			case string:
   151  				for _, scope := range strings.Split(scopes, " ") {
   152  					scopesInClaim[scope] = true
   153  					scopesInClaimList = append(scopesInClaimList, scope)
   154  				}
   155  			case []interface{}:
   156  				for _, scope := range scopes {
   157  					if val, ok := scope.(string); ok {
   158  						scopesInClaim[val] = true
   159  						scopesInClaimList = append(scopesInClaimList, val)
   160  					}
   161  				}
   162  			default:
   163  				return nil, nil, fmt.Errorf("unsupported scope format in incoming JWT claim, was type %T", scopes)
   164  			}
   165  			break
   166  		}
   167  	}
   168  	sort.Strings(scopesInClaimList)
   169  	return scopesInClaim, scopesInClaimList, nil
   170  }
   171  
   172  // ErrJWTError is the error returned by this middleware when any sort of validation or assertion
   173  // fails during processing.
   174  var ErrJWTError = goa.NewErrorClass("jwt_security_error", 401)
   175  
   176  type contextKey int
   177  
   178  const (
   179  	jwtKey contextKey = iota + 1
   180  )
   181  
   182  // partitionKeys sorts keys by their type.
   183  func partitionKeys(k interface{}) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) {
   184  	var (
   185  		rsaKeys   []*rsa.PublicKey
   186  		ecdsaKeys []*ecdsa.PublicKey
   187  		hmacKeys  [][]byte
   188  	)
   189  
   190  	switch typed := k.(type) {
   191  	case []byte:
   192  		hmacKeys = append(hmacKeys, typed)
   193  	case [][]byte:
   194  		hmacKeys = typed
   195  	case string:
   196  		hmacKeys = append(hmacKeys, []byte(typed))
   197  	case []string:
   198  		for _, s := range typed {
   199  			hmacKeys = append(hmacKeys, []byte(s))
   200  		}
   201  	case *rsa.PublicKey:
   202  		rsaKeys = append(rsaKeys, typed)
   203  	case []*rsa.PublicKey:
   204  		rsaKeys = typed
   205  	case *ecdsa.PublicKey:
   206  		ecdsaKeys = append(ecdsaKeys, typed)
   207  	case []*ecdsa.PublicKey:
   208  		ecdsaKeys = typed
   209  	}
   210  
   211  	return rsaKeys, ecdsaKeys, hmacKeys
   212  }
   213  
   214  func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
   215  	for _, pubkey := range rsaKeys {
   216  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   217  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   218  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   219  			}
   220  			return pubkey, nil
   221  		})
   222  		if err == nil {
   223  			return
   224  		}
   225  	}
   226  	return
   227  }
   228  
   229  func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
   230  	for _, pubkey := range ecdsaKeys {
   231  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   232  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   233  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   234  			}
   235  			return pubkey, nil
   236  		})
   237  		if err == nil {
   238  			return
   239  		}
   240  	}
   241  	return
   242  }
   243  
   244  func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) {
   245  	for _, key := range hmacKeys {
   246  		token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
   247  			if !strings.HasPrefix(token.Method.Alg(), algo) {
   248  				return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
   249  			}
   250  			return key, nil
   251  		})
   252  		if err == nil {
   253  			return
   254  		}
   255  	}
   256  	return
   257  }