github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/middleware.go (about) 1 package middleware 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "net/http" 7 "strings" 8 9 ctx "github.com/gophish/gophish/context" 10 "github.com/gophish/gophish/models" 11 "github.com/gorilla/csrf" 12 ) 13 14 // CSRFExemptPrefixes are a list of routes that are exempt from CSRF protection 15 var CSRFExemptPrefixes = []string{ 16 "/api", 17 } 18 19 // CSRFExceptions is a middleware that prevents CSRF checks on routes listed in 20 // CSRFExemptPrefixes. 21 func CSRFExceptions(handler http.Handler) http.HandlerFunc { 22 return func(w http.ResponseWriter, r *http.Request) { 23 for _, prefix := range CSRFExemptPrefixes { 24 if strings.HasPrefix(r.URL.Path, prefix) { 25 r = csrf.UnsafeSkipCheck(r) 26 break 27 } 28 } 29 handler.ServeHTTP(w, r) 30 } 31 } 32 33 // Use allows us to stack middleware to process the request 34 // Example taken from https://github.com/gorilla/mux/pull/36#issuecomment-25849172 35 func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) http.HandlerFunc { 36 for _, m := range mid { 37 handler = m(handler) 38 } 39 return handler 40 } 41 42 // GetContext wraps each request in a function which fills in the context for a given request. 43 // This includes setting the User and Session keys and values as necessary for use in later functions. 44 func GetContext(handler http.Handler) http.HandlerFunc { 45 // Set the context here 46 return func(w http.ResponseWriter, r *http.Request) { 47 // Parse the request form 48 err := r.ParseForm() 49 if err != nil { 50 http.Error(w, "Error parsing request", http.StatusInternalServerError) 51 } 52 // Set the context appropriately here. 53 // Set the session 54 session, _ := Store.Get(r, "gophish") 55 // Put the session in the context so that we can 56 // reuse the values in different handlers 57 r = ctx.Set(r, "session", session) 58 if id, ok := session.Values["id"]; ok { 59 u, err := models.GetUser(id.(int64)) 60 if err != nil { 61 r = ctx.Set(r, "user", nil) 62 } else { 63 r = ctx.Set(r, "user", u) 64 } 65 } else { 66 r = ctx.Set(r, "user", nil) 67 } 68 handler.ServeHTTP(w, r) 69 // Remove context contents 70 ctx.Clear(r) 71 } 72 } 73 74 // RequireAPIKey ensures that a valid API key is set as either the api_key GET 75 // parameter, or a Bearer token. 76 func RequireAPIKey(handler http.Handler) http.Handler { 77 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 w.Header().Set("Access-Control-Allow-Origin", "*") 79 if r.Method == "OPTIONS" { 80 w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") 81 w.Header().Set("Access-Control-Max-Age", "1000") 82 w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept") 83 return 84 } 85 r.ParseForm() 86 ak := r.Form.Get("api_key") 87 // If we can't get the API key, we'll also check for the 88 // Authorization Bearer token 89 if ak == "" { 90 tokens, ok := r.Header["Authorization"] 91 if ok && len(tokens) >= 1 { 92 ak = tokens[0] 93 ak = strings.TrimPrefix(ak, "Bearer ") 94 } 95 } 96 if ak == "" { 97 JSONError(w, http.StatusUnauthorized, "API Key not set") 98 return 99 } 100 u, err := models.GetUserByAPIKey(ak) 101 if err != nil { 102 JSONError(w, http.StatusUnauthorized, "Invalid API Key") 103 return 104 } 105 r = ctx.Set(r, "user", u) 106 r = ctx.Set(r, "user_id", u.Id) 107 r = ctx.Set(r, "api_key", ak) 108 handler.ServeHTTP(w, r) 109 }) 110 } 111 112 // RequireLogin checks to see if the user is currently logged in. 113 // If not, the function returns a 302 redirect to the login page. 114 func RequireLogin(handler http.Handler) http.HandlerFunc { 115 return func(w http.ResponseWriter, r *http.Request) { 116 if u := ctx.Get(r, "user"); u != nil { 117 // If a password change is required for the user, then redirect them 118 // to the login page 119 currentUser := u.(models.User) 120 if currentUser.PasswordChangeRequired && r.URL.Path != "/reset_password" { 121 q := r.URL.Query() 122 q.Set("next", r.URL.Path) 123 http.Redirect(w, r, fmt.Sprintf("/reset_password?%s", q.Encode()), http.StatusTemporaryRedirect) 124 return 125 } 126 handler.ServeHTTP(w, r) 127 return 128 } 129 q := r.URL.Query() 130 q.Set("next", r.URL.Path) 131 http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect) 132 } 133 } 134 135 // EnforceViewOnly is a global middleware that limits the ability to edit 136 // objects to accounts with the PermissionModifyObjects permission. 137 func EnforceViewOnly(next http.Handler) http.Handler { 138 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 139 // If the request is for any non-GET HTTP method, e.g. POST, PUT, 140 // or DELETE, we need to ensure the user has the appropriate 141 // permission. 142 if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions { 143 user := ctx.Get(r, "user").(models.User) 144 access, err := user.HasPermission(models.PermissionModifyObjects) 145 if err != nil { 146 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 147 return 148 } 149 if !access { 150 http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) 151 return 152 } 153 } 154 next.ServeHTTP(w, r) 155 }) 156 } 157 158 // RequirePermission checks to see if the user has the requested permission 159 // before executing the handler. If the request is unauthorized, a JSONError 160 // is returned. 161 func RequirePermission(perm string) func(http.Handler) http.HandlerFunc { 162 return func(next http.Handler) http.HandlerFunc { 163 return func(w http.ResponseWriter, r *http.Request) { 164 user := ctx.Get(r, "user").(models.User) 165 access, err := user.HasPermission(perm) 166 if err != nil { 167 JSONError(w, http.StatusInternalServerError, err.Error()) 168 return 169 } 170 if !access { 171 JSONError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden)) 172 return 173 } 174 next.ServeHTTP(w, r) 175 } 176 } 177 } 178 179 // ApplySecurityHeaders applies various security headers according to best- 180 // practices. 181 func ApplySecurityHeaders(next http.Handler) http.HandlerFunc { 182 return func(w http.ResponseWriter, r *http.Request) { 183 csp := "frame-ancestors 'none';" 184 w.Header().Set("Content-Security-Policy", csp) 185 w.Header().Set("X-Frame-Options", "DENY") 186 next.ServeHTTP(w, r) 187 } 188 } 189 190 // JSONError returns an error in JSON format with the given 191 // status code and message 192 func JSONError(w http.ResponseWriter, c int, m string) { 193 cj, _ := json.MarshalIndent(models.Response{Success: false, Message: m}, "", " ") 194 w.Header().Set("Content-Type", "application/json") 195 w.WriteHeader(c) 196 fmt.Fprintf(w, "%s", cj) 197 }