github.com/lunarobliq/gophish@v0.8.1-0.20230523153303-93511002234d/middleware/middleware_test.go (about) 1 package middleware 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 9 "github.com/gophish/gophish/config" 10 ctx "github.com/gophish/gophish/context" 11 "github.com/gophish/gophish/models" 12 ) 13 14 var successHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 15 w.Write([]byte("success")) 16 }) 17 18 type testContext struct { 19 apiKey string 20 } 21 22 func setupTest(t *testing.T) *testContext { 23 conf := &config.Config{ 24 DBName: "sqlite3", 25 DBPath: ":memory:", 26 MigrationsPath: "../db/db_sqlite3/migrations/", 27 } 28 err := models.Setup(conf) 29 if err != nil { 30 t.Fatalf("Failed creating database: %v", err) 31 } 32 // Get the API key to use for these tests 33 u, err := models.GetUser(1) 34 if err != nil { 35 t.Fatalf("error getting user: %v", err) 36 } 37 ctx := &testContext{} 38 ctx.apiKey = u.ApiKey 39 return ctx 40 } 41 42 // MiddlewarePermissionTest maps an expected HTTP Method to an expected HTTP 43 // status code 44 type MiddlewarePermissionTest map[string]int 45 46 // TestEnforceViewOnly ensures that only users with the ModifyObjects 47 // permission have the ability to send non-GET requests. 48 func TestEnforceViewOnly(t *testing.T) { 49 setupTest(t) 50 permissionTests := map[string]MiddlewarePermissionTest{ 51 models.RoleAdmin: MiddlewarePermissionTest{ 52 http.MethodGet: http.StatusOK, 53 http.MethodHead: http.StatusOK, 54 http.MethodOptions: http.StatusOK, 55 http.MethodPost: http.StatusOK, 56 http.MethodPut: http.StatusOK, 57 http.MethodDelete: http.StatusOK, 58 }, 59 models.RoleUser: MiddlewarePermissionTest{ 60 http.MethodGet: http.StatusOK, 61 http.MethodHead: http.StatusOK, 62 http.MethodOptions: http.StatusOK, 63 http.MethodPost: http.StatusOK, 64 http.MethodPut: http.StatusOK, 65 http.MethodDelete: http.StatusOK, 66 }, 67 } 68 for r, checks := range permissionTests { 69 role, err := models.GetRoleBySlug(r) 70 if err != nil { 71 t.Fatalf("error getting role by slug: %v", err) 72 } 73 74 for method, expected := range checks { 75 req := httptest.NewRequest(method, "/", nil) 76 response := httptest.NewRecorder() 77 78 req = ctx.Set(req, "user", models.User{ 79 Role: role, 80 RoleID: role.ID, 81 }) 82 83 EnforceViewOnly(successHandler).ServeHTTP(response, req) 84 got := response.Code 85 if got != expected { 86 t.Fatalf("incorrect status code received. expected %d got %d", expected, got) 87 } 88 } 89 } 90 } 91 92 func TestRequirePermission(t *testing.T) { 93 setupTest(t) 94 middleware := RequirePermission(models.PermissionModifySystem) 95 handler := middleware(successHandler) 96 97 permissionTests := map[string]int{ 98 models.RoleUser: http.StatusForbidden, 99 models.RoleAdmin: http.StatusOK, 100 } 101 102 for role, expected := range permissionTests { 103 req := httptest.NewRequest(http.MethodGet, "/", nil) 104 response := httptest.NewRecorder() 105 // Test that with the requested permission, the request succeeds 106 role, err := models.GetRoleBySlug(role) 107 if err != nil { 108 t.Fatalf("error getting role by slug: %v", err) 109 } 110 req = ctx.Set(req, "user", models.User{ 111 Role: role, 112 RoleID: role.ID, 113 }) 114 handler.ServeHTTP(response, req) 115 got := response.Code 116 if got != expected { 117 t.Fatalf("incorrect status code received. expected %d got %d", expected, got) 118 } 119 } 120 } 121 122 func TestRequireAPIKey(t *testing.T) { 123 setupTest(t) 124 req := httptest.NewRequest(http.MethodGet, "/", nil) 125 req.Header.Set("Content-Type", "application/json") 126 response := httptest.NewRecorder() 127 // Test that making a request without an API key is denied 128 RequireAPIKey(successHandler).ServeHTTP(response, req) 129 expected := http.StatusUnauthorized 130 got := response.Code 131 if got != expected { 132 t.Fatalf("incorrect status code received. expected %d got %d", expected, got) 133 } 134 } 135 136 func TestCORSHeaders(t *testing.T) { 137 setupTest(t) 138 req := httptest.NewRequest(http.MethodOptions, "/", nil) 139 response := httptest.NewRecorder() 140 RequireAPIKey(successHandler).ServeHTTP(response, req) 141 expected := "POST, GET, OPTIONS, PUT, DELETE" 142 got := response.Result().Header.Get("Access-Control-Allow-Methods") 143 if got != expected { 144 t.Fatalf("incorrect cors options received. expected %s got %s", expected, got) 145 } 146 } 147 148 func TestInvalidAPIKey(t *testing.T) { 149 setupTest(t) 150 req := httptest.NewRequest(http.MethodGet, "/", nil) 151 query := req.URL.Query() 152 query.Set("api_key", "bogus-api-key") 153 req.URL.RawQuery = query.Encode() 154 req.Header.Set("Content-Type", "application/json") 155 response := httptest.NewRecorder() 156 RequireAPIKey(successHandler).ServeHTTP(response, req) 157 expected := http.StatusUnauthorized 158 got := response.Code 159 if got != expected { 160 t.Fatalf("incorrect status code received. expected %d got %d", expected, got) 161 } 162 } 163 164 func TestBearerToken(t *testing.T) { 165 testCtx := setupTest(t) 166 req := httptest.NewRequest(http.MethodGet, "/", nil) 167 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", testCtx.apiKey)) 168 req.Header.Set("Content-Type", "application/json") 169 response := httptest.NewRecorder() 170 RequireAPIKey(successHandler).ServeHTTP(response, req) 171 expected := http.StatusOK 172 got := response.Code 173 if got != expected { 174 t.Fatalf("incorrect status code received. expected %d got %d", expected, got) 175 } 176 } 177 178 func TestPasswordResetRequired(t *testing.T) { 179 req := httptest.NewRequest(http.MethodGet, "/", nil) 180 req = ctx.Set(req, "user", models.User{ 181 PasswordChangeRequired: true, 182 }) 183 response := httptest.NewRecorder() 184 RequireLogin(successHandler).ServeHTTP(response, req) 185 gotStatus := response.Code 186 expectedStatus := http.StatusTemporaryRedirect 187 if gotStatus != expectedStatus { 188 t.Fatalf("incorrect status code received. expected %d got %d", expectedStatus, gotStatus) 189 } 190 expectedLocation := "/reset_password?next=%2F" 191 gotLocation := response.Header().Get("Location") 192 if gotLocation != expectedLocation { 193 t.Fatalf("incorrect location header received. expected %s got %s", expectedLocation, gotLocation) 194 } 195 } 196 197 func TestApplySecurityHeaders(t *testing.T) { 198 expected := map[string]string{ 199 "Content-Security-Policy": "frame-ancestors 'none';", 200 "X-Frame-Options": "DENY", 201 } 202 req := httptest.NewRequest(http.MethodGet, "/", nil) 203 response := httptest.NewRecorder() 204 ApplySecurityHeaders(successHandler).ServeHTTP(response, req) 205 for header, value := range expected { 206 got := response.Header().Get(header) 207 if got != value { 208 t.Fatalf("incorrect security header received for %s: expected %s got %s", header, value, got) 209 } 210 } 211 }