github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/auth-middleware/middleware.go (about) 1 package authmiddleware 2 3 import ( 4 "context" 5 "crypto/rsa" 6 "crypto/sha256" 7 "fmt" 8 "net/http" 9 "strings" 10 "sync" 11 12 "github.com/kyma-incubator/compass/components/director/internal/domain/tenant" 13 "github.com/kyma-incubator/compass/components/director/pkg/idtokenclaims" 14 15 authenticator_director "github.com/kyma-incubator/compass/components/director/internal/authenticator" 16 "github.com/kyma-incubator/compass/components/director/internal/domain/scenariogroups" 17 18 "github.com/kyma-incubator/compass/components/director/internal/nsadapter/httputil" 19 20 "github.com/form3tech-oss/jwt-go" 21 "github.com/kyma-incubator/compass/components/director/internal/domain/client" 22 "github.com/kyma-incubator/compass/components/director/pkg/apperrors" 23 "github.com/kyma-incubator/compass/components/director/pkg/log" 24 "github.com/kyma-incubator/compass/components/hydrator/pkg/authenticator" 25 "github.com/lestrrat-go/iter/arrayiter" 26 "github.com/lestrrat-go/jwx/jwk" 27 "github.com/pkg/errors" 28 ) 29 30 const ( 31 // AuthorizationHeaderKey missing godoc 32 AuthorizationHeaderKey = "Authorization" 33 // JwksKeyIDKey missing godoc 34 JwksKeyIDKey = "kid" 35 ) 36 37 const ( 38 logKeyConsumerType = "consumer-type" 39 logKeyTokenClientID = "token-client-id" 40 logKeyFlow = "flow" 41 ctxScenarioGroupsKey = "scenario_groups" 42 ) 43 44 // ClaimsValidator missing godoc 45 // 46 //go:generate mockery --name=ClaimsValidator --output=automock --outpkg=automock --case=underscore --disable-version-string 47 type ClaimsValidator interface { 48 Validate(context.Context, idtokenclaims.Claims) error 49 } 50 51 // Authenticator missing godoc 52 type Authenticator struct { 53 httpClient *http.Client 54 jwksEndpoint string 55 allowJWTSigningNone bool 56 cachedJWKS jwk.Set 57 clientIDHeaderKey string 58 mux sync.RWMutex 59 claimsValidator ClaimsValidator 60 } 61 62 // New missing godoc 63 func New(httpClient *http.Client, jwksEndpoint string, allowJWTSigningNone bool, clientIDHeaderKey string, claimsValidator ClaimsValidator) *Authenticator { 64 return &Authenticator{ 65 httpClient: httpClient, 66 jwksEndpoint: jwksEndpoint, 67 allowJWTSigningNone: allowJWTSigningNone, 68 clientIDHeaderKey: clientIDHeaderKey, 69 claimsValidator: claimsValidator, 70 } 71 } 72 73 // SynchronizeJWKS missing godoc 74 func (a *Authenticator) SynchronizeJWKS(ctx context.Context) error { 75 log.C(ctx).Info("Synchronizing JWKS...") 76 a.mux.Lock() 77 defer a.mux.Unlock() 78 79 jwks, err := authenticator_director.FetchJWK(ctx, a.jwksEndpoint, jwk.WithHTTPClient(a.httpClient)) 80 if err != nil { 81 return errors.Wrapf(err, "while fetching JWKS from endpoint %s", a.jwksEndpoint) 82 } 83 84 a.cachedJWKS = jwks 85 log.C(ctx).Info("Successfully synchronized JWKS") 86 87 return nil 88 } 89 90 // SetJWKSEndpoint missing godoc 91 func (a *Authenticator) SetJWKSEndpoint(url string) { 92 a.jwksEndpoint = url 93 } 94 95 // Handler missing godoc 96 func (a *Authenticator) Handler() func(next http.Handler) http.Handler { 97 return func(next http.Handler) http.Handler { 98 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 99 ctx := r.Context() 100 tokenClaims, statusCode, err := a.processToken(ctx, r) 101 if err != nil { 102 apperrors.WriteAppError(ctx, w, err, statusCode) 103 return 104 } 105 106 ctx = tokenClaims.ContextWithClaims(ctx) 107 108 ctx = a.storeHeadersDataInContext(ctx, r) 109 110 next.ServeHTTP(w, r.WithContext(ctx)) 111 }) 112 } 113 } 114 115 // KymaAdapterHandler performs authorization checks on requests to the Kyma adapter 116 func (a *Authenticator) KymaAdapterHandler() func(next http.Handler) http.Handler { 117 return func(next http.Handler) http.Handler { 118 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 119 ctx := r.Context() 120 tokenClaims, statusCode, err := a.processToken(ctx, r) 121 if err != nil { 122 apperrors.WriteAppError(ctx, w, err, statusCode) 123 return 124 } 125 126 ctx = tokenClaims.ContextWithClaimsAndProviderTenant(ctx) 127 128 ctx = a.storeHeadersDataInContext(ctx, r) 129 130 next.ServeHTTP(w, r.WithContext(ctx)) 131 }) 132 } 133 } 134 135 // NSAdapterHandler performs authorization checks on requests to the Notifications Service Adapter 136 func (a *Authenticator) NSAdapterHandler() func(next http.Handler) http.Handler { 137 return func(next http.Handler) http.Handler { 138 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 139 ctx := r.Context() 140 141 bearerToken, err := a.getBearerToken(r) 142 if err != nil { 143 log.C(ctx).WithError(err).Errorf("An error has occurred while getting token from header. Error code: %d: %v", http.StatusBadRequest, err) 144 httputil.RespondWithError(ctx, w, http.StatusUnauthorized, httputil.Error{ 145 Code: http.StatusUnauthorized, 146 Message: "missing or invalid authorization token", 147 }) 148 return 149 } 150 151 tokenClaims, err := a.parseClaimsWithRetry(ctx, bearerToken) 152 if err != nil { 153 log.C(ctx).WithError(err).Errorf("An error has occurred while parsing claims: %v", err) 154 httputil.RespondWithError(ctx, w, http.StatusUnauthorized, httputil.Error{ 155 Code: http.StatusUnauthorized, 156 Message: "missing or invalid authorization token", 157 }) 158 return 159 } 160 161 if err := a.claimsValidator.Validate(ctx, *tokenClaims); err != nil { 162 log.C(ctx).WithError(err).Errorf("An error has occurred while validating claims: %v", err) 163 httputil.RespondWithError(ctx, w, http.StatusUnauthorized, httputil.Error{ 164 Code: http.StatusUnauthorized, 165 Message: "missing or invalid authorization token", 166 }) 167 return 168 } 169 170 next.ServeHTTP(w, r.WithContext(ctx)) 171 }) 172 } 173 } 174 175 func (a *Authenticator) parseClaimsWithRetry(ctx context.Context, bearerToken string) (*idtokenclaims.Claims, error) { 176 parsedClaims, err := a.parseClaims(ctx, bearerToken) 177 if err != nil { 178 validationErr, ok := err.(*jwt.ValidationError) 179 if !ok || (validationErr.Inner != rsa.ErrVerification && !apperrors.IsKeyDoesNotExist(validationErr.Inner)) { 180 return nil, apperrors.NewUnauthorizedError(err.Error()) 181 } 182 183 if err := a.SynchronizeJWKS(ctx); err != nil { 184 return nil, apperrors.InternalErrorFrom(err, "while synchronizing JWKS during parsing token") 185 } 186 187 parsedClaims, err = a.parseClaims(ctx, bearerToken) 188 if err != nil { 189 log.C(ctx).WithError(err).Errorf("Failed to parse claims: %v", err) 190 return nil, apperrors.NewUnauthorizedError(err.Error()) 191 } 192 } 193 194 return parsedClaims, nil 195 } 196 197 func (a *Authenticator) getBearerToken(r *http.Request) (string, error) { 198 reqToken := r.Header.Get(AuthorizationHeaderKey) 199 if reqToken == "" { 200 return "", apperrors.NewUnauthorizedError("invalid bearer token") 201 } 202 203 reqToken = strings.TrimPrefix(reqToken, "Bearer ") 204 return reqToken, nil 205 } 206 207 func (a *Authenticator) parseClaims(ctx context.Context, bearerToken string) (*idtokenclaims.Claims, error) { 208 parsed := idtokenclaims.Claims{} 209 210 if _, err := jwt.ParseWithClaims(bearerToken, &parsed, a.getKeyFunc(ctx)); err != nil { 211 return &parsed, err 212 } 213 214 return &parsed, nil 215 } 216 217 func (a *Authenticator) getKeyFunc(ctx context.Context) func(token *jwt.Token) (interface{}, error) { 218 return func(token *jwt.Token) (interface{}, error) { 219 a.mux.RLock() 220 defer a.mux.RUnlock() 221 222 unsupportedErr := fmt.Errorf("unexpected signing method: %v", token.Method.Alg()) 223 224 switch token.Method.Alg() { 225 case jwt.SigningMethodRS256.Name: 226 keyID, err := a.getKeyID(*token) 227 if err != nil { 228 log.C(ctx).WithError(err).Errorf("An error occurred while getting the token signing key ID: %v", err) 229 return nil, errors.Wrap(err, "while getting the key ID") 230 } 231 232 if a.cachedJWKS == nil { 233 log.C(ctx).Debugf("Empty JWKS cache... Signing key %s is not found", keyID) 234 return nil, apperrors.NewKeyDoesNotExistError(keyID) 235 } 236 237 keyIterator := &authenticator.JWTKeyIterator{ 238 AlgorithmCriteria: func(alg string) bool { 239 return token.Method.Alg() == alg 240 }, 241 IDCriteria: func(id string) bool { 242 return id == keyID 243 }, 244 } 245 246 if err := arrayiter.Walk(ctx, a.cachedJWKS, keyIterator); err != nil { 247 log.C(ctx).WithError(err).Errorf("An error occurred while walking through the JWKS: %v", err) 248 return nil, err 249 } 250 251 if keyIterator.ResultingKey == nil { 252 log.C(ctx).Debugf("Signing key %s is not found", keyID) 253 return nil, apperrors.NewKeyDoesNotExistError(keyID) 254 } 255 256 return keyIterator.ResultingKey, nil 257 case jwt.SigningMethodNone.Alg(): 258 if !a.allowJWTSigningNone { 259 return nil, unsupportedErr 260 } 261 return jwt.UnsafeAllowNoneSignatureType, nil 262 } 263 264 return nil, unsupportedErr 265 } 266 } 267 268 func (a *Authenticator) getKeyID(token jwt.Token) (string, error) { 269 keyID, ok := token.Header[JwksKeyIDKey] 270 if !ok { 271 return "", apperrors.NewInternalError("unable to find the key ID in the token") 272 } 273 274 keyIDStr, ok := keyID.(string) 275 if !ok { 276 return "", apperrors.NewInternalError("unable to cast the key ID to a string") 277 } 278 279 return keyIDStr, nil 280 } 281 282 func (a *Authenticator) processToken(ctx context.Context, r *http.Request) (*idtokenclaims.Claims, int, error) { 283 bearerToken, err := a.getBearerToken(r) 284 if err != nil { 285 log.C(ctx).WithError(err).Errorf("An error has occurred while getting token from header. Error code: %d: %v", http.StatusBadRequest, err) 286 return nil, http.StatusBadRequest, err 287 } 288 289 tokenClaims, err := a.parseClaimsWithRetry(ctx, bearerToken) 290 if err != nil { 291 log.C(ctx).WithError(err).Errorf("An error has occurred while parsing claims: %v", err) 292 return nil, http.StatusUnauthorized, err 293 } 294 295 if mdc := log.MdcFromContext(ctx); nil != mdc { 296 mdc.Set(logKeyConsumerType, tokenClaims.ConsumerType) 297 mdc.Set(logKeyFlow, tokenClaims.Flow) 298 mdc.SetIfNotEmpty(logKeyTokenClientID, tokenClaims.TokenClientID) 299 } 300 301 if err := a.claimsValidator.Validate(ctx, *tokenClaims); err != nil { 302 log.C(ctx).WithError(err).Errorf("An error has occurred while validating claims: %v", err) 303 switch apperrors.ErrorCode(err) { 304 case apperrors.TenantNotFound: 305 return nil, http.StatusBadRequest, err 306 default: 307 return nil, http.StatusUnauthorized, err 308 } 309 } 310 311 return tokenClaims, 0, nil 312 } 313 314 func (a *Authenticator) storeHeadersDataInContext(ctx context.Context, r *http.Request) context.Context { 315 if clientUser := r.Header.Get(a.clientIDHeaderKey); clientUser != "" { 316 log.C(ctx).Infof("Found %s header in request with value: REDACTED_%x", a.clientIDHeaderKey, sha256.Sum256([]byte(clientUser))) 317 ctx = client.SaveToContext(ctx, clientUser) 318 } 319 320 if scenarioGroupsValue := r.Header.Get(ctxScenarioGroupsKey); scenarioGroupsValue != "" { 321 log.C(ctx).Infof("Found %s header in request with value: %s", ctxScenarioGroupsKey, scenarioGroupsValue) 322 groups := strings.Split(strings.ToUpper(scenarioGroupsValue), ",") 323 324 ctx = scenariogroups.SaveToContext(ctx, groups) 325 } 326 327 return ctx 328 } 329 330 // LoadExternalTenantFromContext extracts the external tenant ID stored in the context object 331 func LoadExternalTenantFromContext(ctx context.Context) (string, error) { 332 tenantFromContext, err := tenant.LoadTenantPairFromContext(ctx) 333 if err != nil { 334 return "", err 335 } 336 return tenantFromContext.ExternalID, nil 337 } 338 339 // SaveToContext stores the tenant in the context object 340 func SaveToContext(ctx context.Context, internalID, externalID string) context.Context { 341 return tenant.SaveToContext(ctx, internalID, externalID) 342 }