github.com/olivierlemoal/gophish@v0.9.0/middleware/middleware.go (about)

     1  package middleware
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	ctx "github.com/gophish/gophish/context"
    10  	"github.com/gophish/gophish/models"
    11  	"github.com/gorilla/csrf"
    12  )
    13  
    14  // CSRFExemptPrefixes are a list of routes that are exempt from CSRF protection
    15  var CSRFExemptPrefixes = []string{
    16  	"/api",
    17  }
    18  
    19  // CSRFExceptions is a middleware that prevents CSRF checks on routes listed in
    20  // CSRFExemptPrefixes.
    21  func CSRFExceptions(handler http.Handler) http.HandlerFunc {
    22  	return func(w http.ResponseWriter, r *http.Request) {
    23  		for _, prefix := range CSRFExemptPrefixes {
    24  			if strings.HasPrefix(r.URL.Path, prefix) {
    25  				r = csrf.UnsafeSkipCheck(r)
    26  				break
    27  			}
    28  		}
    29  		handler.ServeHTTP(w, r)
    30  	}
    31  }
    32  
    33  // Use allows us to stack middleware to process the request
    34  // Example taken from https://github.com/gorilla/mux/pull/36#issuecomment-25849172
    35  func Use(handler http.HandlerFunc, mid ...func(http.Handler) http.HandlerFunc) http.HandlerFunc {
    36  	for _, m := range mid {
    37  		handler = m(handler)
    38  	}
    39  	return handler
    40  }
    41  
    42  // GetContext wraps each request in a function which fills in the context for a given request.
    43  // This includes setting the User and Session keys and values as necessary for use in later functions.
    44  func GetContext(handler http.Handler) http.HandlerFunc {
    45  	// Set the context here
    46  	return func(w http.ResponseWriter, r *http.Request) {
    47  		// Parse the request form
    48  		err := r.ParseForm()
    49  		if err != nil {
    50  			http.Error(w, "Error parsing request", http.StatusInternalServerError)
    51  		}
    52  		// Set the context appropriately here.
    53  		// Set the session
    54  		session, _ := Store.Get(r, "gophish")
    55  		// Put the session in the context so that we can
    56  		// reuse the values in different handlers
    57  		r = ctx.Set(r, "session", session)
    58  		if id, ok := session.Values["id"]; ok {
    59  			u, err := models.GetUser(id.(int64))
    60  			if err != nil {
    61  				r = ctx.Set(r, "user", nil)
    62  			} else {
    63  				r = ctx.Set(r, "user", u)
    64  			}
    65  		} else {
    66  			r = ctx.Set(r, "user", nil)
    67  		}
    68  		handler.ServeHTTP(w, r)
    69  		// Remove context contents
    70  		ctx.Clear(r)
    71  	}
    72  }
    73  
    74  // RequireAPIKey ensures that a valid API key is set as either the api_key GET
    75  // parameter, or a Bearer token.
    76  func RequireAPIKey(handler http.Handler) http.Handler {
    77  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  		w.Header().Set("Access-Control-Allow-Origin", "*")
    79  		if r.Method == "OPTIONS" {
    80  			w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
    81  			w.Header().Set("Access-Control-Max-Age", "1000")
    82  			w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
    83  			return
    84  		}
    85  		r.ParseForm()
    86  		ak := r.Form.Get("api_key")
    87  		// If we can't get the API key, we'll also check for the
    88  		// Authorization Bearer token
    89  		if ak == "" {
    90  			tokens, ok := r.Header["Authorization"]
    91  			if ok && len(tokens) >= 1 {
    92  				ak = tokens[0]
    93  				ak = strings.TrimPrefix(ak, "Bearer ")
    94  			}
    95  		}
    96  		if ak == "" {
    97  			JSONError(w, http.StatusUnauthorized, "API Key not set")
    98  			return
    99  		}
   100  		u, err := models.GetUserByAPIKey(ak)
   101  		if err != nil {
   102  			JSONError(w, http.StatusUnauthorized, "Invalid API Key")
   103  			return
   104  		}
   105  		r = ctx.Set(r, "user", u)
   106  		r = ctx.Set(r, "user_id", u.Id)
   107  		r = ctx.Set(r, "api_key", ak)
   108  		handler.ServeHTTP(w, r)
   109  	})
   110  }
   111  
   112  // RequireLogin checks to see if the user is currently logged in.
   113  // If not, the function returns a 302 redirect to the login page.
   114  func RequireLogin(handler http.Handler) http.HandlerFunc {
   115  	return func(w http.ResponseWriter, r *http.Request) {
   116  		if u := ctx.Get(r, "user"); u != nil {
   117  			handler.ServeHTTP(w, r)
   118  			return
   119  		}
   120  		q := r.URL.Query()
   121  		q.Set("next", r.URL.Path)
   122  		http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
   123  		return
   124  	}
   125  }
   126  
   127  // EnforceViewOnly is a global middleware that limits the ability to edit
   128  // objects to accounts with the PermissionModifyObjects permission.
   129  func EnforceViewOnly(next http.Handler) http.Handler {
   130  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   131  		// If the request is for any non-GET HTTP method, e.g. POST, PUT,
   132  		// or DELETE, we need to ensure the user has the appropriate
   133  		// permission.
   134  		if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
   135  			user := ctx.Get(r, "user").(models.User)
   136  			access, err := user.HasPermission(models.PermissionModifyObjects)
   137  			if err != nil {
   138  				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   139  				return
   140  			}
   141  			if !access {
   142  				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
   143  				return
   144  			}
   145  		}
   146  		next.ServeHTTP(w, r)
   147  	})
   148  }
   149  
   150  // RequirePermission checks to see if the user has the requested permission
   151  // before executing the handler. If the request is unauthorized, a JSONError
   152  // is returned.
   153  func RequirePermission(perm string) func(http.Handler) http.HandlerFunc {
   154  	return func(next http.Handler) http.HandlerFunc {
   155  		return func(w http.ResponseWriter, r *http.Request) {
   156  			user := ctx.Get(r, "user").(models.User)
   157  			access, err := user.HasPermission(perm)
   158  			if err != nil {
   159  				JSONError(w, http.StatusInternalServerError, err.Error())
   160  				return
   161  			}
   162  			if !access {
   163  				JSONError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
   164  				return
   165  			}
   166  			next.ServeHTTP(w, r)
   167  		}
   168  	}
   169  }
   170  
   171  // JSONError returns an error in JSON format with the given
   172  // status code and message
   173  func JSONError(w http.ResponseWriter, c int, m string) {
   174  	cj, _ := json.MarshalIndent(models.Response{Success: false, Message: m}, "", "  ")
   175  	w.Header().Set("Content-Type", "application/json")
   176  	w.WriteHeader(c)
   177  	fmt.Fprintf(w, "%s", cj)
   178  }