github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/oauth2/middleware.go (about) 1 package oauth2 2 3 import ( 4 "context" 5 "net/http" 6 "strings" 7 8 "github.com/hellofresh/janus/pkg/errors" 9 "github.com/hellofresh/janus/pkg/metrics" 10 obs "github.com/hellofresh/janus/pkg/observability" 11 "github.com/hellofresh/stats-go/bucket" 12 log "github.com/sirupsen/logrus" 13 "go.opencensus.io/stats" 14 ) 15 16 const ( 17 tokensSection = "tokens" 18 ) 19 20 // Enums for keys to be stored in a session context - this is how gorilla expects 21 // these to be implemented and is lifted pretty much from docs 22 var ( 23 AuthHeaderValue = ContextKey("auth_header") 24 25 // ErrAuthorizationFieldNotFound is used when the http Authorization header is missing from the request 26 ErrAuthorizationFieldNotFound = errors.New(http.StatusBadRequest, "authorization field missing") 27 // ErrBearerMalformed is used when the Bearer string in the Authorization header is not found or is malformed 28 ErrBearerMalformed = errors.New(http.StatusBadRequest, "bearer token malformed") 29 // ErrAccessTokenNotAuthorized is used when the access token is not found on the storage 30 ErrAccessTokenNotAuthorized = errors.New(http.StatusUnauthorized, "access token not authorized") 31 ) 32 33 // ContextKey is used to create context keys that are concurrent safe 34 type ContextKey string 35 36 func (c ContextKey) String() string { 37 return "janus." + string(c) 38 } 39 40 // NewKeyExistsMiddleware creates a new instance of KeyExistsMiddleware 41 func NewKeyExistsMiddleware(manager Manager) func(http.Handler) http.Handler { 42 return func(handler http.Handler) http.Handler { 43 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 44 log.Debug("Starting Oauth2KeyExists middleware") 45 statsClient := metrics.WithContext(r.Context()) 46 47 logger := log.WithFields(log.Fields{ 48 "path": r.RequestURI, 49 "origin": r.RemoteAddr, 50 }) 51 52 // We're using OAuth, start checking for access keys 53 authHeaderValue := r.Header.Get("Authorization") 54 parts := strings.Split(authHeaderValue, " ") 55 if len(parts) < 2 { 56 logger.Warn("Attempted access with malformed header, no auth header found.") 57 statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "header"}, nil, false) 58 stats.Record(r.Context(), obs.MOAuth2MissingHeader.M(1)) 59 errors.Handler(w, r, ErrAuthorizationFieldNotFound) 60 return 61 } 62 statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "header"}, nil, true) 63 64 if strings.ToLower(parts[0]) != "bearer" { 65 logger.Warn("Bearer token malformed") 66 statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "malformed"}, nil, false) 67 stats.Record(r.Context(), obs.MOAuth2MalformedHeader.M(1)) 68 errors.Handler(w, r, ErrBearerMalformed) 69 return 70 } 71 statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "malformed"}, nil, true) 72 73 accessToken := parts[1] 74 keyExists := manager.IsKeyAuthorized(r.Context(), accessToken) 75 statsClient.TrackOperation(tokensSection, bucket.MetricOperation{"key-exists", "authorized"}, nil, keyExists) 76 if keyExists { 77 stats.Record(r.Context(), obs.MOAuth2Authorized.M(1)) 78 } else { 79 stats.Record(r.Context(), obs.MOAuth2Unauthorized.M(1)) 80 } 81 82 if !keyExists { 83 log.WithFields(log.Fields{ 84 "path": r.RequestURI, 85 "origin": r.RemoteAddr, 86 "key": accessToken, 87 }).Debug("Attempted access with invalid key.") 88 errors.Handler(w, r, ErrAccessTokenNotAuthorized) 89 return 90 } 91 92 ctx := context.WithValue(r.Context(), AuthHeaderValue, accessToken) 93 handler.ServeHTTP(w, r.WithContext(ctx)) 94 }) 95 } 96 }