github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/middleware_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	"github.com/gophish/gophish/config"
    10  	ctx "github.com/gophish/gophish/context"
    11  	"github.com/gophish/gophish/models"
    12  )
    13  
    14  var successHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    15  	w.Write([]byte("success"))
    16  })
    17  
    18  type testContext struct {
    19  	apiKey string
    20  }
    21  
    22  func setupTest(t *testing.T) *testContext {
    23  	conf := &config.Config{
    24  		DBName:         "sqlite3",
    25  		DBPath:         ":memory:",
    26  		MigrationsPath: "../db/db_sqlite3/migrations/",
    27  	}
    28  	err := models.Setup(conf)
    29  	if err != nil {
    30  		t.Fatalf("Failed creating database: %v", err)
    31  	}
    32  	// Get the API key to use for these tests
    33  	u, err := models.GetUser(1)
    34  	if err != nil {
    35  		t.Fatalf("error getting user: %v", err)
    36  	}
    37  	ctx := &testContext{}
    38  	ctx.apiKey = u.ApiKey
    39  	return ctx
    40  }
    41  
    42  // MiddlewarePermissionTest maps an expected HTTP Method to an expected HTTP
    43  // status code
    44  type MiddlewarePermissionTest map[string]int
    45  
    46  // TestEnforceViewOnly ensures that only users with the ModifyObjects
    47  // permission have the ability to send non-GET requests.
    48  func TestEnforceViewOnly(t *testing.T) {
    49  	setupTest(t)
    50  	permissionTests := map[string]MiddlewarePermissionTest{
    51  		models.RoleAdmin: MiddlewarePermissionTest{
    52  			http.MethodGet:     http.StatusOK,
    53  			http.MethodHead:    http.StatusOK,
    54  			http.MethodOptions: http.StatusOK,
    55  			http.MethodPost:    http.StatusOK,
    56  			http.MethodPut:     http.StatusOK,
    57  			http.MethodDelete:  http.StatusOK,
    58  		},
    59  		models.RoleUser: MiddlewarePermissionTest{
    60  			http.MethodGet:     http.StatusOK,
    61  			http.MethodHead:    http.StatusOK,
    62  			http.MethodOptions: http.StatusOK,
    63  			http.MethodPost:    http.StatusOK,
    64  			http.MethodPut:     http.StatusOK,
    65  			http.MethodDelete:  http.StatusOK,
    66  		},
    67  	}
    68  	for r, checks := range permissionTests {
    69  		role, err := models.GetRoleBySlug(r)
    70  		if err != nil {
    71  			t.Fatalf("error getting role by slug: %v", err)
    72  		}
    73  
    74  		for method, expected := range checks {
    75  			req := httptest.NewRequest(method, "/", nil)
    76  			response := httptest.NewRecorder()
    77  
    78  			req = ctx.Set(req, "user", models.User{
    79  				Role:   role,
    80  				RoleID: role.ID,
    81  			})
    82  
    83  			EnforceViewOnly(successHandler).ServeHTTP(response, req)
    84  			got := response.Code
    85  			if got != expected {
    86  				t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
    87  			}
    88  		}
    89  	}
    90  }
    91  
    92  func TestRequirePermission(t *testing.T) {
    93  	setupTest(t)
    94  	middleware := RequirePermission(models.PermissionModifySystem)
    95  	handler := middleware(successHandler)
    96  
    97  	permissionTests := map[string]int{
    98  		models.RoleUser:  http.StatusForbidden,
    99  		models.RoleAdmin: http.StatusOK,
   100  	}
   101  
   102  	for role, expected := range permissionTests {
   103  		req := httptest.NewRequest(http.MethodGet, "/", nil)
   104  		response := httptest.NewRecorder()
   105  		// Test that with the requested permission, the request succeeds
   106  		role, err := models.GetRoleBySlug(role)
   107  		if err != nil {
   108  			t.Fatalf("error getting role by slug: %v", err)
   109  		}
   110  		req = ctx.Set(req, "user", models.User{
   111  			Role:   role,
   112  			RoleID: role.ID,
   113  		})
   114  		handler.ServeHTTP(response, req)
   115  		got := response.Code
   116  		if got != expected {
   117  			t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
   118  		}
   119  	}
   120  }
   121  
   122  func TestRequireAPIKey(t *testing.T) {
   123  	setupTest(t)
   124  	req := httptest.NewRequest(http.MethodGet, "/", nil)
   125  	req.Header.Set("Content-Type", "application/json")
   126  	response := httptest.NewRecorder()
   127  	// Test that making a request without an API key is denied
   128  	RequireAPIKey(successHandler).ServeHTTP(response, req)
   129  	expected := http.StatusUnauthorized
   130  	got := response.Code
   131  	if got != expected {
   132  		t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
   133  	}
   134  }
   135  
   136  func TestCORSHeaders(t *testing.T) {
   137  	setupTest(t)
   138  	req := httptest.NewRequest(http.MethodOptions, "/", nil)
   139  	response := httptest.NewRecorder()
   140  	RequireAPIKey(successHandler).ServeHTTP(response, req)
   141  	expected := "POST, GET, OPTIONS, PUT, DELETE"
   142  	got := response.Result().Header.Get("Access-Control-Allow-Methods")
   143  	if got != expected {
   144  		t.Fatalf("incorrect cors options received. expected %s got %s", expected, got)
   145  	}
   146  }
   147  
   148  func TestInvalidAPIKey(t *testing.T) {
   149  	setupTest(t)
   150  	req := httptest.NewRequest(http.MethodGet, "/", nil)
   151  	query := req.URL.Query()
   152  	query.Set("api_key", "bogus-api-key")
   153  	req.URL.RawQuery = query.Encode()
   154  	req.Header.Set("Content-Type", "application/json")
   155  	response := httptest.NewRecorder()
   156  	RequireAPIKey(successHandler).ServeHTTP(response, req)
   157  	expected := http.StatusUnauthorized
   158  	got := response.Code
   159  	if got != expected {
   160  		t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
   161  	}
   162  }
   163  
   164  func TestBearerToken(t *testing.T) {
   165  	testCtx := setupTest(t)
   166  	req := httptest.NewRequest(http.MethodGet, "/", nil)
   167  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", testCtx.apiKey))
   168  	req.Header.Set("Content-Type", "application/json")
   169  	response := httptest.NewRecorder()
   170  	RequireAPIKey(successHandler).ServeHTTP(response, req)
   171  	expected := http.StatusOK
   172  	got := response.Code
   173  	if got != expected {
   174  		t.Fatalf("incorrect status code received. expected %d got %d", expected, got)
   175  	}
   176  }
   177  
   178  func TestPasswordResetRequired(t *testing.T) {
   179  	req := httptest.NewRequest(http.MethodGet, "/", nil)
   180  	req = ctx.Set(req, "user", models.User{
   181  		PasswordChangeRequired: true,
   182  	})
   183  	response := httptest.NewRecorder()
   184  	RequireLogin(successHandler).ServeHTTP(response, req)
   185  	gotStatus := response.Code
   186  	expectedStatus := http.StatusTemporaryRedirect
   187  	if gotStatus != expectedStatus {
   188  		t.Fatalf("incorrect status code received. expected %d got %d", expectedStatus, gotStatus)
   189  	}
   190  	expectedLocation := "/reset_password?next=%2F"
   191  	gotLocation := response.Header().Get("Location")
   192  	if gotLocation != expectedLocation {
   193  		t.Fatalf("incorrect location header received. expected %s got %s", expectedLocation, gotLocation)
   194  	}
   195  }
   196  
   197  func TestApplySecurityHeaders(t *testing.T) {
   198  	expected := map[string]string{
   199  		"Content-Security-Policy": "frame-ancestors 'none';",
   200  		"X-Frame-Options":         "DENY",
   201  	}
   202  	req := httptest.NewRequest(http.MethodGet, "/", nil)
   203  	response := httptest.NewRecorder()
   204  	ApplySecurityHeaders(successHandler).ServeHTTP(response, req)
   205  	for header, value := range expected {
   206  		got := response.Header().Get(header)
   207  		if got != value {
   208  			t.Fatalf("incorrect security header received for %s: expected %s got %s", header, value, got)
   209  		}
   210  	}
   211  }