github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/middleware.go (about)

     1  package oauth2
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"strings"
     7  
     8  	"github.com/hellofresh/janus/pkg/errors"
     9  	"github.com/hellofresh/janus/pkg/metrics"
    10  	obs "github.com/hellofresh/janus/pkg/observability"
    11  	"github.com/hellofresh/stats-go/bucket"
    12  	log "github.com/sirupsen/logrus"
    13  	"go.opencensus.io/stats"
    14  )
    15  
    16  const (
    17  	tokensSection = "tokens"
    18  )
    19  
    20  // Enums for keys to be stored in a session context - this is how gorilla expects
    21  // these to be implemented and is lifted pretty much from docs
    22  var (
    23  	AuthHeaderValue = ContextKey("auth_header")
    24  
    25  	// ErrAuthorizationFieldNotFound is used when the http Authorization header is missing from the request
    26  	ErrAuthorizationFieldNotFound = errors.New(http.StatusBadRequest, "authorization field missing")
    27  	// ErrBearerMalformed is used when the Bearer string in the Authorization header is not found or is malformed
    28  	ErrBearerMalformed = errors.New(http.StatusBadRequest, "bearer token malformed")
    29  	// ErrAccessTokenNotAuthorized is used when the access token is not found on the storage
    30  	ErrAccessTokenNotAuthorized = errors.New(http.StatusUnauthorized, "access token not authorized")
    31  )
    32  
    33  // ContextKey is used to create context keys that are concurrent safe
    34  type ContextKey string
    35  
    36  func (c ContextKey) String() string {
    37  	return "janus." + string(c)
    38  }
    39  
    40  // NewKeyExistsMiddleware creates a new instance of KeyExistsMiddleware
    41  func NewKeyExistsMiddleware(manager Manager) func(http.Handler) http.Handler {
    42  	return func(handler http.Handler) http.Handler {
    43  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    44  			log.Debug("Starting Oauth2KeyExists middleware")
    45  			statsClient := metrics.WithContext(r.Context())
    46  
    47  			logger := log.WithFields(log.Fields{
    48  				"path":   r.RequestURI,
    49  				"origin": r.RemoteAddr,
    50  			})
    51  
    52  			// We're using OAuth, start checking for access keys
    53  			authHeaderValue := r.Header.Get("Authorization")
    54  			parts := strings.Split(authHeaderValue, " ")
    55  			if len(parts) < 2 {
    56  				logger.Warn("Attempted access with malformed header, no auth header found.")
    57  				statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "header"}, nil, false)
    58  				stats.Record(r.Context(), obs.MOAuth2MissingHeader.M(1))
    59  				errors.Handler(w, r, ErrAuthorizationFieldNotFound)
    60  				return
    61  			}
    62  			statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "header"}, nil, true)
    63  
    64  			if strings.ToLower(parts[0]) != "bearer" {
    65  				logger.Warn("Bearer token malformed")
    66  				statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "malformed"}, nil, false)
    67  				stats.Record(r.Context(), obs.MOAuth2MalformedHeader.M(1))
    68  				errors.Handler(w, r, ErrBearerMalformed)
    69  				return
    70  			}
    71  			statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "malformed"}, nil, true)
    72  
    73  			accessToken := parts[1]
    74  			keyExists := manager.IsKeyAuthorized(r.Context(), accessToken)
    75  			statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "authorized"}, nil, keyExists)
    76  			if keyExists {
    77  				stats.Record(r.Context(), obs.MOAuth2Authorized.M(1))
    78  			} else {
    79  				stats.Record(r.Context(), obs.MOAuth2Unauthorized.M(1))
    80  			}
    81  
    82  			if !keyExists {
    83  				log.WithFields(log.Fields{
    84  					"path":   r.RequestURI,
    85  					"origin": r.RemoteAddr,
    86  					"key":    accessToken,
    87  				}).Debug("Attempted access with invalid key.")
    88  				errors.Handler(w, r, ErrAccessTokenNotAuthorized)
    89  				return
    90  			}
    91  
    92  			ctx := context.WithValue(r.Context(), AuthHeaderValue, accessToken)
    93  			handler.ServeHTTP(w, r.WithContext(ctx))
    94  		})
    95  	}
    96  }