github.com/Ne0nd0g/gophish@v0.7.1-0.20190220040016-11493024a07d/middleware/middleware.go (about)

     1  package middleware
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/gophish/gophish/auth"
    10  	ctx "github.com/gophish/gophish/context"
    11  	"github.com/gophish/gophish/models"
    12  	"github.com/gorilla/csrf"
    13  )
    14  
    15  // CSRFExemptPrefixes are a list of routes that are exempt from CSRF protection
    16  var CSRFExemptPrefixes = []string{
    17  	"/api",
    18  }
    19  
    20  // CSRFExceptions is a middleware that prevents CSRF checks on routes listed in
    21  // CSRFExemptPrefixes.
    22  func CSRFExceptions(handler http.Handler) http.HandlerFunc {
    23  	return func(w http.ResponseWriter, r *http.Request) {
    24  		for _, prefix := range CSRFExemptPrefixes {
    25  			if strings.HasPrefix(r.URL.Path, prefix) {
    26  				r = csrf.UnsafeSkipCheck(r)
    27  				break
    28  			}
    29  		}
    30  		handler.ServeHTTP(w, r)
    31  	}
    32  }
    33  
    34  // GetContext wraps each request in a function which fills in the context for a given request.
    35  // This includes setting the User and Session keys and values as necessary for use in later functions.
    36  func GetContext(handler http.Handler) http.HandlerFunc {
    37  	// Set the context here
    38  	return func(w http.ResponseWriter, r *http.Request) {
    39  		// Parse the request form
    40  		err := r.ParseForm()
    41  		if err != nil {
    42  			http.Error(w, "Error parsing request", http.StatusInternalServerError)
    43  		}
    44  		// Set the context appropriately here.
    45  		// Set the session
    46  		session, _ := auth.Store.Get(r, "gophish")
    47  		// Put the session in the context so that we can
    48  		// reuse the values in different handlers
    49  		r = ctx.Set(r, "session", session)
    50  		if id, ok := session.Values["id"]; ok {
    51  			u, err := models.GetUser(id.(int64))
    52  			if err != nil {
    53  				r = ctx.Set(r, "user", nil)
    54  			} else {
    55  				r = ctx.Set(r, "user", u)
    56  			}
    57  		} else {
    58  			r = ctx.Set(r, "user", nil)
    59  		}
    60  		handler.ServeHTTP(w, r)
    61  		// Remove context contents
    62  		ctx.Clear(r)
    63  	}
    64  }
    65  
    66  // RequireAPIKey ensures that a valid API key is set as either the api_key GET
    67  // parameter, or a Bearer token.
    68  func RequireAPIKey(handler http.Handler) http.Handler {
    69  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    70  		w.Header().Set("Access-Control-Allow-Origin", "*")
    71  		if r.Method == "OPTIONS" {
    72  			w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
    73  			w.Header().Set("Access-Control-Max-Age", "1000")
    74  			w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept")
    75  			return
    76  		}
    77  		r.ParseForm()
    78  		ak := r.Form.Get("api_key")
    79  		// If we can't get the API key, we'll also check for the
    80  		// Authorization Bearer token
    81  		if ak == "" {
    82  			tokens, ok := r.Header["Authorization"]
    83  			if ok && len(tokens) >= 1 {
    84  				ak = tokens[0]
    85  				ak = strings.TrimPrefix(ak, "Bearer ")
    86  			}
    87  		}
    88  		if ak == "" {
    89  			JSONError(w, http.StatusUnauthorized, "API Key not set")
    90  			return
    91  		}
    92  		u, err := models.GetUserByAPIKey(ak)
    93  		if err != nil {
    94  			JSONError(w, http.StatusUnauthorized, "Invalid API Key")
    95  			return
    96  		}
    97  		r = ctx.Set(r, "user", u)
    98  		r = ctx.Set(r, "user_id", u.Id)
    99  		r = ctx.Set(r, "api_key", ak)
   100  		handler.ServeHTTP(w, r)
   101  	})
   102  }
   103  
   104  // RequireLogin checks to see if the user is currently logged in.
   105  // If not, the function returns a 302 redirect to the login page.
   106  func RequireLogin(handler http.Handler) http.HandlerFunc {
   107  	return func(w http.ResponseWriter, r *http.Request) {
   108  		if u := ctx.Get(r, "user"); u != nil {
   109  			handler.ServeHTTP(w, r)
   110  		} else {
   111  			q := r.URL.Query()
   112  			q.Set("next", r.URL.Path)
   113  			http.Redirect(w, r, fmt.Sprintf("/login?%s", q.Encode()), http.StatusTemporaryRedirect)
   114  		}
   115  	}
   116  }
   117  
   118  // EnforceViewOnly is a global middleware that limits the ability to edit
   119  // objects to accounts with the PermissionModifyObjects permission.
   120  func EnforceViewOnly(next http.Handler) http.Handler {
   121  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   122  		// If the request is for any non-GET HTTP method, e.g. POST, PUT,
   123  		// or DELETE, we need to ensure the user has the appropriate
   124  		// permission.
   125  		if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions {
   126  			user := ctx.Get(r, "user").(models.User)
   127  			access, err := user.HasPermission(models.PermissionModifyObjects)
   128  			if err != nil {
   129  				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   130  				return
   131  			}
   132  			if !access {
   133  				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
   134  				return
   135  			}
   136  		}
   137  		next.ServeHTTP(w, r)
   138  	})
   139  }
   140  
   141  // RequirePermission checks to see if the user has the requested permission
   142  // before executing the handler. If the request is unauthorized, a JSONError
   143  // is returned.
   144  func RequirePermission(perm string) func(http.Handler) http.HandlerFunc {
   145  	return func(next http.Handler) http.HandlerFunc {
   146  		return func(w http.ResponseWriter, r *http.Request) {
   147  			user := ctx.Get(r, "user").(models.User)
   148  			access, err := user.HasPermission(perm)
   149  			if err != nil {
   150  				JSONError(w, http.StatusInternalServerError, err.Error())
   151  				return
   152  			}
   153  			if !access {
   154  				JSONError(w, http.StatusForbidden, http.StatusText(http.StatusForbidden))
   155  				return
   156  			}
   157  			next.ServeHTTP(w, r)
   158  		}
   159  	}
   160  }
   161  
   162  // JSONError returns an error in JSON format with the given
   163  // status code and message
   164  func JSONError(w http.ResponseWriter, c int, m string) {
   165  	cj, _ := json.MarshalIndent(models.Response{Success: false, Message: m}, "", "  ")
   166  	w.Header().Set("Content-Type", "application/json")
   167  	w.WriteHeader(c)
   168  	fmt.Fprintf(w, "%s", cj)
   169  }