github.com/xmidt-org/webpa-common@v1.11.9/secure/validator.go (about)

     1  package secure
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"regexp"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/SermoDigital/jose/jws"
    12  	"github.com/SermoDigital/jose/jwt"
    13  	"github.com/xmidt-org/webpa-common/secure/key"
    14  )
    15  
    16  var (
    17  	ErrorNoProtectedHeader = errors.New("Missing protected header")
    18  	ErrorNoSigningMethod   = errors.New("Signing method (alg) is missing or unrecognized")
    19  )
    20  
    21  // Validator describes the behavior of a type which can validate tokens
    22  type Validator interface {
    23  	// Validate asserts that the given token is valid, most often verifying
    24  	// the credentials in the token.  A separate error is returned to indicate
    25  	// any problems during validation, such as the inability to access a network resource.
    26  	// In general, the contract of this method is that a Token passes validation
    27  	// if and only if it returns BOTH true and a nil error.
    28  	Validate(context.Context, *Token) (bool, error)
    29  }
    30  
    31  // ValidatorFunc is a function type that implements Validator
    32  type ValidatorFunc func(context.Context, *Token) (bool, error)
    33  
    34  func (v ValidatorFunc) Validate(ctx context.Context, token *Token) (bool, error) {
    35  	return v(ctx, token)
    36  }
    37  
    38  // Validators is an aggregate Validator.  A Validators instance considers a token
    39  // valid if any of its validators considers it valid.  An empty Validators rejects
    40  // all tokens.
    41  type Validators []Validator
    42  
    43  func (v Validators) Validate(ctx context.Context, token *Token) (valid bool, err error) {
    44  	for _, validator := range v {
    45  		if valid, err = validator.Validate(ctx, token); valid && err == nil {
    46  			return
    47  		}
    48  	}
    49  
    50  	return
    51  }
    52  
    53  // ExactMatchValidator simply matches a token's value (exluding the prefix, such as "Basic"),
    54  // to a string.
    55  type ExactMatchValidator string
    56  
    57  func (v ExactMatchValidator) Validate(ctx context.Context, token *Token) (bool, error) {
    58  	for _, value := range strings.Split(string(v), ",") {
    59  		if value == token.value {
    60  			return true, nil
    61  		}
    62  	}
    63  
    64  	return false, nil
    65  }
    66  
    67  // JWSValidator provides validation for JWT tokens encoded as JWS.
    68  type JWSValidator struct {
    69  	DefaultKeyId  string
    70  	Resolver      key.Resolver
    71  	Parser        JWSParser
    72  	JWTValidators []*jwt.Validator
    73  	measures      *JWTValidationMeasures
    74  }
    75  
    76  // capabilityValidation determines if a claim's capability is valid
    77  func capabilityValidation(ctx context.Context, capability string) (valid_capabilities bool) {
    78  	pieces := strings.Split(capability, ":")
    79  
    80  	if len(pieces) == 5 &&
    81  		pieces[0] == "x1" &&
    82  		pieces[1] == "webpa" {
    83  
    84  		method_value, ok := ctx.Value("method").(string)
    85  		if ok && (pieces[4] == "all" || strings.EqualFold(pieces[4], method_value)) {
    86  			claimPath := fmt.Sprintf("/%s/[^/]+/%s", pieces[2], pieces[3])
    87  			valid_capabilities, _ = regexp.MatchString(claimPath, ctx.Value("path").(string))
    88  		}
    89  	}
    90  
    91  	return
    92  }
    93  
    94  func (v JWSValidator) Validate(ctx context.Context, token *Token) (valid bool, err error) {
    95  	if token.Type() != Bearer {
    96  		return
    97  	}
    98  
    99  	parser := v.Parser
   100  	if parser == nil {
   101  		parser = DefaultJWSParser
   102  	}
   103  
   104  	jwsToken, err := parser.ParseJWS(token)
   105  	if err != nil {
   106  		return
   107  	}
   108  
   109  	protected := jwsToken.Protected()
   110  	if len(protected) == 0 {
   111  		err = ErrorNoProtectedHeader
   112  		return
   113  	}
   114  
   115  	alg, _ := protected.Get("alg").(string)
   116  	signingMethod := jws.GetSigningMethod(alg)
   117  	if signingMethod == nil {
   118  		err = ErrorNoSigningMethod
   119  		return
   120  	}
   121  
   122  	keyId, _ := protected.Get("kid").(string)
   123  	if len(keyId) == 0 {
   124  		keyId = v.DefaultKeyId
   125  	}
   126  
   127  	pair, err := v.Resolver.ResolveKey(keyId)
   128  	if err != nil {
   129  		return
   130  	}
   131  
   132  	// validate the signature
   133  	if len(v.JWTValidators) > 0 {
   134  		// all JWS implementations also implement jwt.JWT
   135  		err = jwsToken.(jwt.JWT).Validate(pair.Public(), signingMethod, v.JWTValidators...)
   136  	} else {
   137  		err = jwsToken.Verify(pair.Public(), signingMethod)
   138  	}
   139  
   140  	if nil != err {
   141  		if v.measures != nil {
   142  
   143  			//capture specific cases of interest, default to global (invalid_signature) reason
   144  			switch err {
   145  			case jwt.ErrTokenIsExpired:
   146  				v.measures.ValidationReason.With("reason", "expired_token").Add(1)
   147  				break
   148  			case jwt.ErrTokenNotYetValid:
   149  				v.measures.ValidationReason.With("reason", "premature_token").Add(1)
   150  				break
   151  
   152  			default:
   153  				v.measures.ValidationReason.With("reason", "invalid_signature").Add(1)
   154  			}
   155  		}
   156  		return
   157  	}
   158  
   159  	// validate jwt token claims capabilities
   160  	if caps, capOkay := jwsToken.Payload().(jws.Claims).Get("capabilities").([]interface{}); capOkay && len(caps) > 0 {
   161  
   162  		/*  commenting out for now
   163  		    1. remove code in use below
   164  		    2. make sure to bring a back tests for this as well.
   165  		        - TestJWSValidatorCapabilities()
   166  
   167  				for c := 0; c < len(caps); c++ {
   168  					if cap_value, ok := caps[c].(string); ok {
   169  						if valid = capabilityValidation(ctx, cap_value); valid {
   170  							return
   171  						}
   172  					}
   173  				}
   174  		*/
   175  		// *****  REMOVE THIS CODE AFTER BRING BACK THE COMMENTED CODE ABOVE *****
   176  		// ***** vvvvvvvvvvvvvvv *****
   177  
   178  		// successful validation
   179  		if v.measures != nil {
   180  			v.measures.ValidationReason.With("reason", "ok").Add(1)
   181  		}
   182  
   183  		return true, nil
   184  		// ***** ^^^^^^^^^^^^^^^ *****
   185  
   186  	}
   187  
   188  	// This fail
   189  	return
   190  }
   191  
   192  //DefineMeasures defines the metrics tool used by JWSValidator
   193  func (v *JWSValidator) DefineMeasures(m *JWTValidationMeasures) {
   194  	v.measures = m
   195  }
   196  
   197  // JWTValidatorFactory is a configurable factory for *jwt.Validator instances
   198  type JWTValidatorFactory struct {
   199  	Expected  jwt.Claims `json:"expected"`
   200  	ExpLeeway int        `json:"expLeeway"`
   201  	NbfLeeway int        `json:"nbfLeeway"`
   202  	measures  *JWTValidationMeasures
   203  }
   204  
   205  func (f *JWTValidatorFactory) expLeeway() time.Duration {
   206  	if f.ExpLeeway > 0 {
   207  		return time.Duration(f.ExpLeeway) * time.Second
   208  	}
   209  
   210  	return 0
   211  }
   212  
   213  func (f *JWTValidatorFactory) nbfLeeway() time.Duration {
   214  	if f.NbfLeeway > 0 {
   215  		return time.Duration(f.NbfLeeway) * time.Second
   216  	}
   217  
   218  	return 0
   219  }
   220  
   221  //DefineMeasures helps establish the metrics tools
   222  func (f *JWTValidatorFactory) DefineMeasures(m *JWTValidationMeasures) {
   223  	f.measures = m
   224  }
   225  
   226  // New returns a jwt.Validator using the configuration expected claims (if any)
   227  // and a validator function that checks the exp and nbf claims.
   228  //
   229  // The SermoDigital library doesn't appear to do anything with the EXP and NBF
   230  // members of jwt.Validator, but this Factory Method populates them anyway.
   231  func (f *JWTValidatorFactory) New(custom ...jwt.ValidateFunc) *jwt.Validator {
   232  	expLeeway := f.expLeeway()
   233  	nbfLeeway := f.nbfLeeway()
   234  
   235  	var validateFunc jwt.ValidateFunc
   236  	customCount := len(custom)
   237  	if customCount > 0 {
   238  		validateFunc = func(claims jwt.Claims) (err error) {
   239  			now := time.Now()
   240  			err = claims.Validate(now, expLeeway, nbfLeeway)
   241  			for index := 0; index < customCount && err == nil; index++ {
   242  				err = custom[index](claims)
   243  			}
   244  
   245  			f.observeMeasures(claims, now, expLeeway, nbfLeeway, err)
   246  
   247  			return
   248  		}
   249  	} else {
   250  		// if no custom validate functions were passed, use a simpler function
   251  		validateFunc = func(claims jwt.Claims) (err error) {
   252  			now := time.Now()
   253  			err = claims.Validate(now, expLeeway, nbfLeeway)
   254  
   255  			f.observeMeasures(claims, now, expLeeway, nbfLeeway, err)
   256  
   257  			return
   258  		}
   259  	}
   260  
   261  	return &jwt.Validator{
   262  		Expected: f.Expected,
   263  		EXP:      expLeeway,
   264  		NBF:      nbfLeeway,
   265  		Fn:       validateFunc,
   266  	}
   267  }
   268  
   269  func (f *JWTValidatorFactory) observeMeasures(claims jwt.Claims, now time.Time, expLeeway, nbfLeeway time.Duration, err error) {
   270  	if f.measures == nil {
   271  		return // measure tools are not defined, skip
   272  	}
   273  
   274  	//how far did we land from the NBF (in seconds): ie. -1 means 1 sec before, 1 means 1 sec after
   275  	if nbf, nbfPresent := claims.NotBefore(); nbfPresent {
   276  		nbf = nbf.Add(-nbfLeeway)
   277  		offsetToNBF := now.Sub(nbf).Seconds()
   278  		f.measures.NBFHistogram.Observe(offsetToNBF)
   279  	}
   280  
   281  	//how far did we land from the EXP (in seconds): ie. -1 means 1 sec before, 1 means 1 sec after
   282  	if exp, expPresent := claims.Expiration(); expPresent {
   283  		exp = exp.Add(expLeeway)
   284  		offsetToEXP := now.Sub(exp).Seconds()
   285  		f.measures.ExpHistogram.Observe(offsetToEXP)
   286  	}
   287  }