github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/auth-middleware/middleware.go (about)

     1  package authmiddleware
     2  
     3  import (
     4  	"context"
     5  	"crypto/rsa"
     6  	"crypto/sha256"
     7  	"fmt"
     8  	"net/http"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/kyma-incubator/compass/components/director/internal/domain/tenant"
    13  	"github.com/kyma-incubator/compass/components/director/pkg/idtokenclaims"
    14  
    15  	authenticator_director "github.com/kyma-incubator/compass/components/director/internal/authenticator"
    16  	"github.com/kyma-incubator/compass/components/director/internal/domain/scenariogroups"
    17  
    18  	"github.com/kyma-incubator/compass/components/director/internal/nsadapter/httputil"
    19  
    20  	"github.com/form3tech-oss/jwt-go"
    21  	"github.com/kyma-incubator/compass/components/director/internal/domain/client"
    22  	"github.com/kyma-incubator/compass/components/director/pkg/apperrors"
    23  	"github.com/kyma-incubator/compass/components/director/pkg/log"
    24  	"github.com/kyma-incubator/compass/components/hydrator/pkg/authenticator"
    25  	"github.com/lestrrat-go/iter/arrayiter"
    26  	"github.com/lestrrat-go/jwx/jwk"
    27  	"github.com/pkg/errors"
    28  )
    29  
    30  const (
    31  	// AuthorizationHeaderKey missing godoc
    32  	AuthorizationHeaderKey = "Authorization"
    33  	// JwksKeyIDKey missing godoc
    34  	JwksKeyIDKey = "kid"
    35  )
    36  
    37  const (
    38  	logKeyConsumerType   = "consumer-type"
    39  	logKeyTokenClientID  = "token-client-id"
    40  	logKeyFlow           = "flow"
    41  	ctxScenarioGroupsKey = "scenario_groups"
    42  )
    43  
    44  // ClaimsValidator missing godoc
    45  //
    46  //go:generate mockery --name=ClaimsValidator --output=automock --outpkg=automock --case=underscore --disable-version-string
    47  type ClaimsValidator interface {
    48  	Validate(context.Context, idtokenclaims.Claims) error
    49  }
    50  
    51  // Authenticator missing godoc
    52  type Authenticator struct {
    53  	httpClient          *http.Client
    54  	jwksEndpoint        string
    55  	allowJWTSigningNone bool
    56  	cachedJWKS          jwk.Set
    57  	clientIDHeaderKey   string
    58  	mux                 sync.RWMutex
    59  	claimsValidator     ClaimsValidator
    60  }
    61  
    62  // New missing godoc
    63  func New(httpClient *http.Client, jwksEndpoint string, allowJWTSigningNone bool, clientIDHeaderKey string, claimsValidator ClaimsValidator) *Authenticator {
    64  	return &Authenticator{
    65  		httpClient:          httpClient,
    66  		jwksEndpoint:        jwksEndpoint,
    67  		allowJWTSigningNone: allowJWTSigningNone,
    68  		clientIDHeaderKey:   clientIDHeaderKey,
    69  		claimsValidator:     claimsValidator,
    70  	}
    71  }
    72  
    73  // SynchronizeJWKS missing godoc
    74  func (a *Authenticator) SynchronizeJWKS(ctx context.Context) error {
    75  	log.C(ctx).Info("Synchronizing JWKS...")
    76  	a.mux.Lock()
    77  	defer a.mux.Unlock()
    78  
    79  	jwks, err := authenticator_director.FetchJWK(ctx, a.jwksEndpoint, jwk.WithHTTPClient(a.httpClient))
    80  	if err != nil {
    81  		return errors.Wrapf(err, "while fetching JWKS from endpoint %s", a.jwksEndpoint)
    82  	}
    83  
    84  	a.cachedJWKS = jwks
    85  	log.C(ctx).Info("Successfully synchronized JWKS")
    86  
    87  	return nil
    88  }
    89  
    90  // SetJWKSEndpoint missing godoc
    91  func (a *Authenticator) SetJWKSEndpoint(url string) {
    92  	a.jwksEndpoint = url
    93  }
    94  
    95  // Handler missing godoc
    96  func (a *Authenticator) Handler() func(next http.Handler) http.Handler {
    97  	return func(next http.Handler) http.Handler {
    98  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    99  			ctx := r.Context()
   100  			tokenClaims, statusCode, err := a.processToken(ctx, r)
   101  			if err != nil {
   102  				apperrors.WriteAppError(ctx, w, err, statusCode)
   103  				return
   104  			}
   105  
   106  			ctx = tokenClaims.ContextWithClaims(ctx)
   107  
   108  			ctx = a.storeHeadersDataInContext(ctx, r)
   109  
   110  			next.ServeHTTP(w, r.WithContext(ctx))
   111  		})
   112  	}
   113  }
   114  
   115  // KymaAdapterHandler performs authorization checks on requests to the Kyma adapter
   116  func (a *Authenticator) KymaAdapterHandler() func(next http.Handler) http.Handler {
   117  	return func(next http.Handler) http.Handler {
   118  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   119  			ctx := r.Context()
   120  			tokenClaims, statusCode, err := a.processToken(ctx, r)
   121  			if err != nil {
   122  				apperrors.WriteAppError(ctx, w, err, statusCode)
   123  				return
   124  			}
   125  
   126  			ctx = tokenClaims.ContextWithClaimsAndProviderTenant(ctx)
   127  
   128  			ctx = a.storeHeadersDataInContext(ctx, r)
   129  
   130  			next.ServeHTTP(w, r.WithContext(ctx))
   131  		})
   132  	}
   133  }
   134  
   135  // NSAdapterHandler performs authorization checks on requests to the Notifications Service Adapter
   136  func (a *Authenticator) NSAdapterHandler() func(next http.Handler) http.Handler {
   137  	return func(next http.Handler) http.Handler {
   138  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   139  			ctx := r.Context()
   140  
   141  			bearerToken, err := a.getBearerToken(r)
   142  			if err != nil {
   143  				log.C(ctx).WithError(err).Errorf("An error has occurred while getting token from header. Error code: %d: %v", http.StatusBadRequest, err)
   144  				httputil.RespondWithError(ctx, w, http.StatusUnauthorized, httputil.Error{
   145  					Code:    http.StatusUnauthorized,
   146  					Message: "missing or invalid authorization token",
   147  				})
   148  				return
   149  			}
   150  
   151  			tokenClaims, err := a.parseClaimsWithRetry(ctx, bearerToken)
   152  			if err != nil {
   153  				log.C(ctx).WithError(err).Errorf("An error has occurred while parsing claims: %v", err)
   154  				httputil.RespondWithError(ctx, w, http.StatusUnauthorized, httputil.Error{
   155  					Code:    http.StatusUnauthorized,
   156  					Message: "missing or invalid authorization token",
   157  				})
   158  				return
   159  			}
   160  
   161  			if err := a.claimsValidator.Validate(ctx, *tokenClaims); err != nil {
   162  				log.C(ctx).WithError(err).Errorf("An error has occurred while validating claims: %v", err)
   163  				httputil.RespondWithError(ctx, w, http.StatusUnauthorized, httputil.Error{
   164  					Code:    http.StatusUnauthorized,
   165  					Message: "missing or invalid authorization token",
   166  				})
   167  				return
   168  			}
   169  
   170  			next.ServeHTTP(w, r.WithContext(ctx))
   171  		})
   172  	}
   173  }
   174  
   175  func (a *Authenticator) parseClaimsWithRetry(ctx context.Context, bearerToken string) (*idtokenclaims.Claims, error) {
   176  	parsedClaims, err := a.parseClaims(ctx, bearerToken)
   177  	if err != nil {
   178  		validationErr, ok := err.(*jwt.ValidationError)
   179  		if !ok || (validationErr.Inner != rsa.ErrVerification && !apperrors.IsKeyDoesNotExist(validationErr.Inner)) {
   180  			return nil, apperrors.NewUnauthorizedError(err.Error())
   181  		}
   182  
   183  		if err := a.SynchronizeJWKS(ctx); err != nil {
   184  			return nil, apperrors.InternalErrorFrom(err, "while synchronizing JWKS during parsing token")
   185  		}
   186  
   187  		parsedClaims, err = a.parseClaims(ctx, bearerToken)
   188  		if err != nil {
   189  			log.C(ctx).WithError(err).Errorf("Failed to parse claims: %v", err)
   190  			return nil, apperrors.NewUnauthorizedError(err.Error())
   191  		}
   192  	}
   193  
   194  	return parsedClaims, nil
   195  }
   196  
   197  func (a *Authenticator) getBearerToken(r *http.Request) (string, error) {
   198  	reqToken := r.Header.Get(AuthorizationHeaderKey)
   199  	if reqToken == "" {
   200  		return "", apperrors.NewUnauthorizedError("invalid bearer token")
   201  	}
   202  
   203  	reqToken = strings.TrimPrefix(reqToken, "Bearer ")
   204  	return reqToken, nil
   205  }
   206  
   207  func (a *Authenticator) parseClaims(ctx context.Context, bearerToken string) (*idtokenclaims.Claims, error) {
   208  	parsed := idtokenclaims.Claims{}
   209  
   210  	if _, err := jwt.ParseWithClaims(bearerToken, &parsed, a.getKeyFunc(ctx)); err != nil {
   211  		return &parsed, err
   212  	}
   213  
   214  	return &parsed, nil
   215  }
   216  
   217  func (a *Authenticator) getKeyFunc(ctx context.Context) func(token *jwt.Token) (interface{}, error) {
   218  	return func(token *jwt.Token) (interface{}, error) {
   219  		a.mux.RLock()
   220  		defer a.mux.RUnlock()
   221  
   222  		unsupportedErr := fmt.Errorf("unexpected signing method: %v", token.Method.Alg())
   223  
   224  		switch token.Method.Alg() {
   225  		case jwt.SigningMethodRS256.Name:
   226  			keyID, err := a.getKeyID(*token)
   227  			if err != nil {
   228  				log.C(ctx).WithError(err).Errorf("An error occurred while getting the token signing key ID: %v", err)
   229  				return nil, errors.Wrap(err, "while getting the key ID")
   230  			}
   231  
   232  			if a.cachedJWKS == nil {
   233  				log.C(ctx).Debugf("Empty JWKS cache... Signing key %s is not found", keyID)
   234  				return nil, apperrors.NewKeyDoesNotExistError(keyID)
   235  			}
   236  
   237  			keyIterator := &authenticator.JWTKeyIterator{
   238  				AlgorithmCriteria: func(alg string) bool {
   239  					return token.Method.Alg() == alg
   240  				},
   241  				IDCriteria: func(id string) bool {
   242  					return id == keyID
   243  				},
   244  			}
   245  
   246  			if err := arrayiter.Walk(ctx, a.cachedJWKS, keyIterator); err != nil {
   247  				log.C(ctx).WithError(err).Errorf("An error occurred while walking through the JWKS: %v", err)
   248  				return nil, err
   249  			}
   250  
   251  			if keyIterator.ResultingKey == nil {
   252  				log.C(ctx).Debugf("Signing key %s is not found", keyID)
   253  				return nil, apperrors.NewKeyDoesNotExistError(keyID)
   254  			}
   255  
   256  			return keyIterator.ResultingKey, nil
   257  		case jwt.SigningMethodNone.Alg():
   258  			if !a.allowJWTSigningNone {
   259  				return nil, unsupportedErr
   260  			}
   261  			return jwt.UnsafeAllowNoneSignatureType, nil
   262  		}
   263  
   264  		return nil, unsupportedErr
   265  	}
   266  }
   267  
   268  func (a *Authenticator) getKeyID(token jwt.Token) (string, error) {
   269  	keyID, ok := token.Header[JwksKeyIDKey]
   270  	if !ok {
   271  		return "", apperrors.NewInternalError("unable to find the key ID in the token")
   272  	}
   273  
   274  	keyIDStr, ok := keyID.(string)
   275  	if !ok {
   276  		return "", apperrors.NewInternalError("unable to cast the key ID to a string")
   277  	}
   278  
   279  	return keyIDStr, nil
   280  }
   281  
   282  func (a *Authenticator) processToken(ctx context.Context, r *http.Request) (*idtokenclaims.Claims, int, error) {
   283  	bearerToken, err := a.getBearerToken(r)
   284  	if err != nil {
   285  		log.C(ctx).WithError(err).Errorf("An error has occurred while getting token from header. Error code: %d: %v", http.StatusBadRequest, err)
   286  		return nil, http.StatusBadRequest, err
   287  	}
   288  
   289  	tokenClaims, err := a.parseClaimsWithRetry(ctx, bearerToken)
   290  	if err != nil {
   291  		log.C(ctx).WithError(err).Errorf("An error has occurred while parsing claims: %v", err)
   292  		return nil, http.StatusUnauthorized, err
   293  	}
   294  
   295  	if mdc := log.MdcFromContext(ctx); nil != mdc {
   296  		mdc.Set(logKeyConsumerType, tokenClaims.ConsumerType)
   297  		mdc.Set(logKeyFlow, tokenClaims.Flow)
   298  		mdc.SetIfNotEmpty(logKeyTokenClientID, tokenClaims.TokenClientID)
   299  	}
   300  
   301  	if err := a.claimsValidator.Validate(ctx, *tokenClaims); err != nil {
   302  		log.C(ctx).WithError(err).Errorf("An error has occurred while validating claims: %v", err)
   303  		switch apperrors.ErrorCode(err) {
   304  		case apperrors.TenantNotFound:
   305  			return nil, http.StatusBadRequest, err
   306  		default:
   307  			return nil, http.StatusUnauthorized, err
   308  		}
   309  	}
   310  
   311  	return tokenClaims, 0, nil
   312  }
   313  
   314  func (a *Authenticator) storeHeadersDataInContext(ctx context.Context, r *http.Request) context.Context {
   315  	if clientUser := r.Header.Get(a.clientIDHeaderKey); clientUser != "" {
   316  		log.C(ctx).Infof("Found %s header in request with value: REDACTED_%x", a.clientIDHeaderKey, sha256.Sum256([]byte(clientUser)))
   317  		ctx = client.SaveToContext(ctx, clientUser)
   318  	}
   319  
   320  	if scenarioGroupsValue := r.Header.Get(ctxScenarioGroupsKey); scenarioGroupsValue != "" {
   321  		log.C(ctx).Infof("Found %s header in request with value: %s", ctxScenarioGroupsKey, scenarioGroupsValue)
   322  		groups := strings.Split(strings.ToUpper(scenarioGroupsValue), ",")
   323  
   324  		ctx = scenariogroups.SaveToContext(ctx, groups)
   325  	}
   326  
   327  	return ctx
   328  }
   329  
   330  // LoadExternalTenantFromContext extracts the external tenant ID stored in the context object
   331  func LoadExternalTenantFromContext(ctx context.Context) (string, error) {
   332  	tenantFromContext, err := tenant.LoadTenantPairFromContext(ctx)
   333  	if err != nil {
   334  		return "", err
   335  	}
   336  	return tenantFromContext.ExternalID, nil
   337  }
   338  
   339  // SaveToContext stores the tenant in the context object
   340  func SaveToContext(ctx context.Context, internalID, externalID string) context.Context {
   341  	return tenant.SaveToContext(ctx, internalID, externalID)
   342  }