github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/middlewares/csrf_test.go (about) 1 package middlewares 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "net/url" 7 "strings" 8 "testing" 9 10 "github.com/cozy/cozy-stack/pkg/assets/dynamic" 11 "github.com/cozy/cozy-stack/pkg/config/config" 12 "github.com/cozy/cozy-stack/pkg/utils" 13 "github.com/cozy/cozy-stack/tests/testutils" 14 "github.com/labstack/echo/v4" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/require" 17 ) 18 19 func TestCsrf(t *testing.T) { 20 if testing.Short() { 21 t.Skip("an instance is required for this test: test skipped due to the use of --short flag") 22 } 23 24 config.UseTestFile(t) 25 config.GetConfig().Assets = "../../assets" 26 setup := testutils.NewSetup(t, t.Name()) 27 28 setup.SetupSwiftTest() 29 require.NoError(t, dynamic.InitDynamicAssetFS(config.FsURL().String()), "Could not init dynamic FS") 30 31 t.Run("CSRF", func(t *testing.T) { 32 e := echo.New() 33 req := httptest.NewRequest(http.MethodGet, "/", nil) 34 rec := httptest.NewRecorder() 35 c := e.NewContext(req, rec) 36 csrf := CSRFWithConfig(CSRFConfig{ 37 TokenLength: 16, 38 }) 39 h := csrf(func(c echo.Context) error { 40 return c.String(http.StatusOK, "test") 41 }) 42 43 // Generate CSRF token 44 assert.NoError(t, h(c)) 45 assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") 46 47 // Without CSRF cookie 48 req = httptest.NewRequest(http.MethodPost, "/", nil) 49 rec = httptest.NewRecorder() 50 c = e.NewContext(req, rec) 51 assert.Error(t, h(c)) 52 53 // Empty/invalid CSRF token 54 req = httptest.NewRequest(http.MethodPost, "/", nil) 55 rec = httptest.NewRecorder() 56 c = e.NewContext(req, rec) 57 req.Header.Set(echo.HeaderXCSRFToken, "") 58 assert.Error(t, h(c)) 59 60 // Valid CSRF token 61 token := utils.RandomString(16) 62 req.Header.Set(echo.HeaderCookie, "_csrf="+token) 63 req.Header.Set(echo.HeaderXCSRFToken, token) 64 if assert.NoError(t, h(c)) { 65 assert.Equal(t, http.StatusOK, rec.Code) 66 } 67 }) 68 69 t.Run("CSRFTokenFromForm", func(t *testing.T) { 70 f := make(url.Values) 71 f.Set("csrf", "token") 72 e := echo.New() 73 req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) 74 req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) 75 c := e.NewContext(req, nil) 76 token, err := csrfTokenFromForm("csrf")(c) 77 if assert.NoError(t, err) { 78 assert.Equal(t, "token", token) 79 } 80 _, err = csrfTokenFromForm("invalid")(c) 81 assert.Error(t, err) 82 }) 83 84 t.Run("CSRFTokenFromQuery", func(t *testing.T) { 85 q := make(url.Values) 86 q.Set("csrf", "token") 87 e := echo.New() 88 req := httptest.NewRequest(http.MethodGet, "/", nil) 89 req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) 90 req.URL.RawQuery = q.Encode() 91 c := e.NewContext(req, nil) 92 token, err := csrfTokenFromQuery("csrf")(c) 93 if assert.NoError(t, err) { 94 assert.Equal(t, "token", token) 95 } 96 _, err = csrfTokenFromQuery("invalid")(c) 97 assert.Error(t, err) 98 csrfTokenFromQuery("csrf") 99 }) 100 }