github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/middleware.go (about)

     1  package middleware
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	ctx "github.com/gophish/gophish/context"
    10  	"github.com/gophish/gophish/models"
    11  	"github.com/gorilla/csrf"
    12  )
    13  
    14  // CSRFExemptPrefixes are a list of routes that are exempt from CSRF protection
    15  var CSRFExemptPrefixes = []string{
    16  	"/api",
    17  }
    18  
    19  // CSRFExceptions is a middleware that prevents CSRF checks on routes listed in
    20  // CSRFExemptPrefixes.
    21  func CSRFExceptions(handler http.Handler) http.HandlerFunc {
    22  	return func(w http.ResponseWriter, r *http.Request) {
    23  		for _, prefix := range CSRFExemptPrefixes {
    24  			if strings.HasPrefix(r.URL.Path, prefix) {
    25  				r = csrf.UnsafeSkipCheck(r)
    26  				break
    27  			}
    28  		}
    29  		handler.ServeHTTP(w, r)
    30  	}
    31  }
    32  
    33  // Use allows us to stack middleware to process the request
    34  // Example taken from https://github.com/gorilla/mux/pull/36#issuecomment-25849172
    35  func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) http.HandlerFunc {
    36  	for _, m := range mid {
    37  		handler = m(handler)
    38  	}
    39  	return handler
    40  }
    41  
    42  // GetContext wraps each request in a function which fills in the context for a given request.
    43  // This includes setting the User and Session keys and values as necessary for use in later functions.
    44  func GetContext(handler http.Handler) http.HandlerFunc {
    45  	// Set the context here
    46  	return func(w http.ResponseWriter, r *http.Request) {
    47  		// Parse the request form
    48  		err := r.ParseForm()
    49  		if err != nil {
    50  			http.Error(w, "Error parsing request", http.StatusInternalServerError)
    51  		}
    52  		// Set the context appropriately here.
    53  		// Set the session
    54  		session, _ := Store.Get(r, "gophish")
    55  		// Put the session in the context so that we can
    56  		// reuse the values in different handlers
    57  		r = ctx.Set(r, "session", session)
    58  		if id, ok := session.Values["id"]; ok {
    59  			u, err := models.GetUser(id.(int64))
    60  			if err != nil {
    61  				r = ctx.Set(r, "user", nil)
    62  			} else {
    63  				r = ctx.Set(r, "user", u)
    64  			}
    65  		} else {
    66  			r = ctx.Set(r, "user", nil)
    67  		}
    68  		handler.ServeHTTP(w, r)
    69  		// Remove context contents
    70  		ctx.Clear(r)
    71  	}
    72  }
    73  
    74  // RequireAPIKey ensures that a valid API key is set as either the api_key GET
    75  // parameter, or a Bearer token.
    76  func RequireAPIKey(handler http.Handler) http.Handler {
    77  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  		w.Header().Set("Access-Control-Allow-Origin", "*")
    79  		if r.Method == "OPTIONS" {
    80  			w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
    81  			w.Header().Set("Access-Control-Max-Age", "1000")
    82  			w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
    83  			return
    84  		}
    85  		r.ParseForm()
    86  		ak := r.Form.Get("api_key")
    87  		// If we can't get the API key, we'll also check for the
    88  		// Authorization Bearer token
    89  		if ak == "" {
    90  			tokens, ok := r.Header["Authorization"]
    91  			if ok && len(tokens) >= 1 {
    92  				ak = tokens[0]
    93  				ak = strings.TrimPrefix(ak, "Bearer ")
    94  			}
    95  		}
    96  		if ak == "" {
    97  			JSONError(w, http.StatusUnauthorized, "API Key not set")
    98  			return
    99  		}
   100  		u, err := models.GetUserByAPIKey(ak)
   101  		if err != nil {
   102  			JSONError(w, http.StatusUnauthorized, "Invalid API Key")
   103  			return
   104  		}
   105  		r = ctx.Set(r, "user", u)
   106  		r = ctx.Set(r, "user_id", u.Id)
   107  		r = ctx.Set(r, "api_key", ak)
   108  		handler.ServeHTTP(w, r)
   109  	})
   110  }
   111  
   112  // RequireLogin checks to see if the user is currently logged in.
   113  // If not, the function returns a 302 redirect to the login page.
   114  func RequireLogin(handler http.Handler) http.HandlerFunc {
   115  	return func(w http.ResponseWriter, r *http.Request) {
   116  		if u := ctx.Get(r, "user"); u != nil {
   117  			// If a password change is required for the user, then redirect them
   118  			// to the login page
   119  			currentUser := u.(models.User)
   120  			if currentUser.PasswordChangeRequired && r.URL.Path != "/reset_password" {
   121  				q := r.URL.Query()
   122  				q.Set("next", r.URL.Path)
   123  				http.Redirect(w, r, fmt.Sprintf("/reset_password?%s", q.Encode()), http.StatusTemporaryRedirect)
   124  				return
   125  			}
   126  			handler.ServeHTTP(w, r)
   127  			return
   128  		}
   129  		q := r.URL.Query()
   130  		q.Set("next", r.URL.Path)
   131  		http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
   132  	}
   133  }
   134  
   135  // EnforceViewOnly is a global middleware that limits the ability to edit
   136  // objects to accounts with the PermissionModifyObjects permission.
   137  func EnforceViewOnly(next http.Handler) http.Handler {
   138  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   139  		// If the request is for any non-GET HTTP method, e.g. POST, PUT,
   140  		// or DELETE, we need to ensure the user has the appropriate
   141  		// permission.
   142  		if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
   143  			user := ctx.Get(r, "user").(models.User)
   144  			access, err := user.HasPermission(models.PermissionModifyObjects)
   145  			if err != nil {
   146  				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   147  				return
   148  			}
   149  			if !access {
   150  				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
   151  				return
   152  			}
   153  		}
   154  		next.ServeHTTP(w, r)
   155  	})
   156  }
   157  
   158  // RequirePermission checks to see if the user has the requested permission
   159  // before executing the handler. If the request is unauthorized, a JSONError
   160  // is returned.
   161  func RequirePermission(perm string) func(http.Handler) http.HandlerFunc {
   162  	return func(next http.Handler) http.HandlerFunc {
   163  		return func(w http.ResponseWriter, r *http.Request) {
   164  			user := ctx.Get(r, "user").(models.User)
   165  			access, err := user.HasPermission(perm)
   166  			if err != nil {
   167  				JSONError(w, http.StatusInternalServerError, err.Error())
   168  				return
   169  			}
   170  			if !access {
   171  				JSONError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
   172  				return
   173  			}
   174  			next.ServeHTTP(w, r)
   175  		}
   176  	}
   177  }
   178  
   179  // ApplySecurityHeaders applies various security headers according to best-
   180  // practices.
   181  func ApplySecurityHeaders(next http.Handler) http.HandlerFunc {
   182  	return func(w http.ResponseWriter, r *http.Request) {
   183  		csp := "frame-ancestors 'none';"
   184  		w.Header().Set("Content-Security-Policy", csp)
   185  		w.Header().Set("X-Frame-Options", "DENY")
   186  		next.ServeHTTP(w, r)
   187  	}
   188  }
   189  
   190  // JSONError returns an error in JSON format with the given
   191  // status code and message
   192  func JSONError(w http.ResponseWriter, c int, m string) {
   193  	cj, _ := json.MarshalIndent(models.Response{Success: false, Message: m}, "", "  ")
   194  	w.Header().Set("Content-Type", "application/json")
   195  	w.WriteHeader(c)
   196  	fmt.Fprintf(w, "%s", cj)
   197  }