github.com/xmidt-org/webpa-common@v1.11.9/secure/validator.go (about) 1 package secure 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "regexp" 8 "strings" 9 "time" 10 11 "github.com/SermoDigital/jose/jws" 12 "github.com/SermoDigital/jose/jwt" 13 "github.com/xmidt-org/webpa-common/secure/key" 14 ) 15 16 var ( 17 ErrorNoProtectedHeader = errors.New("Missing protected header") 18 ErrorNoSigningMethod = errors.New("Signing method (alg) is missing or unrecognized") 19 ) 20 21 // Validator describes the behavior of a type which can validate tokens 22 type Validator interface { 23 // Validate asserts that the given token is valid, most often verifying 24 // the credentials in the token. A separate error is returned to indicate 25 // any problems during validation, such as the inability to access a network resource. 26 // In general, the contract of this method is that a Token passes validation 27 // if and only if it returns BOTH true and a nil error. 28 Validate(context.Context, *Token) (bool, error) 29 } 30 31 // ValidatorFunc is a function type that implements Validator 32 type ValidatorFunc func(context.Context, *Token) (bool, error) 33 34 func (v ValidatorFunc) Validate(ctx context.Context, token *Token) (bool, error) { 35 return v(ctx, token) 36 } 37 38 // Validators is an aggregate Validator. A Validators instance considers a token 39 // valid if any of its validators considers it valid. An empty Validators rejects 40 // all tokens. 41 type Validators []Validator 42 43 func (v Validators) Validate(ctx context.Context, token *Token) (valid bool, err error) { 44 for _, validator := range v { 45 if valid, err = validator.Validate(ctx, token); valid && err == nil { 46 return 47 } 48 } 49 50 return 51 } 52 53 // ExactMatchValidator simply matches a token's value (exluding the prefix, such as "Basic"), 54 // to a string. 55 type ExactMatchValidator string 56 57 func (v ExactMatchValidator) Validate(ctx context.Context, token *Token) (bool, error) { 58 for _, value := range strings.Split(string(v), ",") { 59 if value == token.value { 60 return true, nil 61 } 62 } 63 64 return false, nil 65 } 66 67 // JWSValidator provides validation for JWT tokens encoded as JWS. 68 type JWSValidator struct { 69 DefaultKeyId string 70 Resolver key.Resolver 71 Parser JWSParser 72 JWTValidators []*jwt.Validator 73 measures *JWTValidationMeasures 74 } 75 76 // capabilityValidation determines if a claim's capability is valid 77 func capabilityValidation(ctx context.Context, capability string) (valid_capabilities bool) { 78 pieces := strings.Split(capability, ":") 79 80 if len(pieces) == 5 && 81 pieces[0] == "x1" && 82 pieces[1] == "webpa" { 83 84 method_value, ok := ctx.Value("method").(string) 85 if ok && (pieces[4] == "all" || strings.EqualFold(pieces[4], method_value)) { 86 claimPath := fmt.Sprintf("/%s/[^/]+/%s", pieces[2], pieces[3]) 87 valid_capabilities, _ = regexp.MatchString(claimPath, ctx.Value("path").(string)) 88 } 89 } 90 91 return 92 } 93 94 func (v JWSValidator) Validate(ctx context.Context, token *Token) (valid bool, err error) { 95 if token.Type() != Bearer { 96 return 97 } 98 99 parser := v.Parser 100 if parser == nil { 101 parser = DefaultJWSParser 102 } 103 104 jwsToken, err := parser.ParseJWS(token) 105 if err != nil { 106 return 107 } 108 109 protected := jwsToken.Protected() 110 if len(protected) == 0 { 111 err = ErrorNoProtectedHeader 112 return 113 } 114 115 alg, _ := protected.Get("alg").(string) 116 signingMethod := jws.GetSigningMethod(alg) 117 if signingMethod == nil { 118 err = ErrorNoSigningMethod 119 return 120 } 121 122 keyId, _ := protected.Get("kid").(string) 123 if len(keyId) == 0 { 124 keyId = v.DefaultKeyId 125 } 126 127 pair, err := v.Resolver.ResolveKey(keyId) 128 if err != nil { 129 return 130 } 131 132 // validate the signature 133 if len(v.JWTValidators) > 0 { 134 // all JWS implementations also implement jwt.JWT 135 err = jwsToken.(jwt.JWT).Validate(pair.Public(), signingMethod, v.JWTValidators...) 136 } else { 137 err = jwsToken.Verify(pair.Public(), signingMethod) 138 } 139 140 if nil != err { 141 if v.measures != nil { 142 143 //capture specific cases of interest, default to global (invalid_signature) reason 144 switch err { 145 case jwt.ErrTokenIsExpired: 146 v.measures.ValidationReason.With("reason", "expired_token").Add(1) 147 break 148 case jwt.ErrTokenNotYetValid: 149 v.measures.ValidationReason.With("reason", "premature_token").Add(1) 150 break 151 152 default: 153 v.measures.ValidationReason.With("reason", "invalid_signature").Add(1) 154 } 155 } 156 return 157 } 158 159 // validate jwt token claims capabilities 160 if caps, capOkay := jwsToken.Payload().(jws.Claims).Get("capabilities").([]interface{}); capOkay && len(caps) > 0 { 161 162 /* commenting out for now 163 1. remove code in use below 164 2. make sure to bring a back tests for this as well. 165 - TestJWSValidatorCapabilities() 166 167 for c := 0; c < len(caps); c++ { 168 if cap_value, ok := caps[c].(string); ok { 169 if valid = capabilityValidation(ctx, cap_value); valid { 170 return 171 } 172 } 173 } 174 */ 175 // ***** REMOVE THIS CODE AFTER BRING BACK THE COMMENTED CODE ABOVE ***** 176 // ***** vvvvvvvvvvvvvvv ***** 177 178 // successful validation 179 if v.measures != nil { 180 v.measures.ValidationReason.With("reason", "ok").Add(1) 181 } 182 183 return true, nil 184 // ***** ^^^^^^^^^^^^^^^ ***** 185 186 } 187 188 // This fail 189 return 190 } 191 192 //DefineMeasures defines the metrics tool used by JWSValidator 193 func (v *JWSValidator) DefineMeasures(m *JWTValidationMeasures) { 194 v.measures = m 195 } 196 197 // JWTValidatorFactory is a configurable factory for *jwt.Validator instances 198 type JWTValidatorFactory struct { 199 Expected jwt.Claims `json:"expected"` 200 ExpLeeway int `json:"expLeeway"` 201 NbfLeeway int `json:"nbfLeeway"` 202 measures *JWTValidationMeasures 203 } 204 205 func (f *JWTValidatorFactory) expLeeway() time.Duration { 206 if f.ExpLeeway > 0 { 207 return time.Duration(f.ExpLeeway) * time.Second 208 } 209 210 return 0 211 } 212 213 func (f *JWTValidatorFactory) nbfLeeway() time.Duration { 214 if f.NbfLeeway > 0 { 215 return time.Duration(f.NbfLeeway) * time.Second 216 } 217 218 return 0 219 } 220 221 //DefineMeasures helps establish the metrics tools 222 func (f *JWTValidatorFactory) DefineMeasures(m *JWTValidationMeasures) { 223 f.measures = m 224 } 225 226 // New returns a jwt.Validator using the configuration expected claims (if any) 227 // and a validator function that checks the exp and nbf claims. 228 // 229 // The SermoDigital library doesn't appear to do anything with the EXP and NBF 230 // members of jwt.Validator, but this Factory Method populates them anyway. 231 func (f *JWTValidatorFactory) New(custom ...jwt.ValidateFunc) *jwt.Validator { 232 expLeeway := f.expLeeway() 233 nbfLeeway := f.nbfLeeway() 234 235 var validateFunc jwt.ValidateFunc 236 customCount := len(custom) 237 if customCount > 0 { 238 validateFunc = func(claims jwt.Claims) (err error) { 239 now := time.Now() 240 err = claims.Validate(now, expLeeway, nbfLeeway) 241 for index := 0; index < customCount && err == nil; index++ { 242 err = custom[index](claims) 243 } 244 245 f.observeMeasures(claims, now, expLeeway, nbfLeeway, err) 246 247 return 248 } 249 } else { 250 // if no custom validate functions were passed, use a simpler function 251 validateFunc = func(claims jwt.Claims) (err error) { 252 now := time.Now() 253 err = claims.Validate(now, expLeeway, nbfLeeway) 254 255 f.observeMeasures(claims, now, expLeeway, nbfLeeway, err) 256 257 return 258 } 259 } 260 261 return &jwt.Validator{ 262 Expected: f.Expected, 263 EXP: expLeeway, 264 NBF: nbfLeeway, 265 Fn: validateFunc, 266 } 267 } 268 269 func (f *JWTValidatorFactory) observeMeasures(claims jwt.Claims, now time.Time, expLeeway, nbfLeeway time.Duration, err error) { 270 if f.measures == nil { 271 return // measure tools are not defined, skip 272 } 273 274 //how far did we land from the NBF (in seconds): ie. -1 means 1 sec before, 1 means 1 sec after 275 if nbf, nbfPresent := claims.NotBefore(); nbfPresent { 276 nbf = nbf.Add(-nbfLeeway) 277 offsetToNBF := now.Sub(nbf).Seconds() 278 f.measures.NBFHistogram.Observe(offsetToNBF) 279 } 280 281 //how far did we land from the EXP (in seconds): ie. -1 means 1 sec before, 1 means 1 sec after 282 if exp, expPresent := claims.Expiration(); expPresent { 283 exp = exp.Add(expLeeway) 284 offsetToEXP := now.Sub(exp).Seconds() 285 f.measures.ExpHistogram.Observe(offsetToEXP) 286 } 287 }