github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/middlewares/csrf_test.go (about)

     1  package middlewares
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"net/url"
     7  	"strings"
     8  	"testing"
     9  
    10  	"github.com/cozy/cozy-stack/pkg/assets/dynamic"
    11  	"github.com/cozy/cozy-stack/pkg/config/config"
    12  	"github.com/cozy/cozy-stack/pkg/utils"
    13  	"github.com/cozy/cozy-stack/tests/testutils"
    14  	"github.com/labstack/echo/v4"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  func TestCsrf(t *testing.T) {
    20  	if testing.Short() {
    21  		t.Skip("an instance is required for this test: test skipped due to the use of --short flag")
    22  	}
    23  
    24  	config.UseTestFile(t)
    25  	config.GetConfig().Assets = "../../assets"
    26  	setup := testutils.NewSetup(t, t.Name())
    27  
    28  	setup.SetupSwiftTest()
    29  	require.NoError(t, dynamic.InitDynamicAssetFS(config.FsURL().String()), "Could not init dynamic FS")
    30  
    31  	t.Run("CSRF", func(t *testing.T) {
    32  		e := echo.New()
    33  		req := httptest.NewRequest(http.MethodGet, "/", nil)
    34  		rec := httptest.NewRecorder()
    35  		c := e.NewContext(req, rec)
    36  		csrf := CSRFWithConfig(CSRFConfig{
    37  			TokenLength: 16,
    38  		})
    39  		h := csrf(func(c echo.Context) error {
    40  			return c.String(http.StatusOK, "test")
    41  		})
    42  
    43  		// Generate CSRF token
    44  		assert.NoError(t, h(c))
    45  		assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
    46  
    47  		// Without CSRF cookie
    48  		req = httptest.NewRequest(http.MethodPost, "/", nil)
    49  		rec = httptest.NewRecorder()
    50  		c = e.NewContext(req, rec)
    51  		assert.Error(t, h(c))
    52  
    53  		// Empty/invalid CSRF token
    54  		req = httptest.NewRequest(http.MethodPost, "/", nil)
    55  		rec = httptest.NewRecorder()
    56  		c = e.NewContext(req, rec)
    57  		req.Header.Set(echo.HeaderXCSRFToken, "")
    58  		assert.Error(t, h(c))
    59  
    60  		// Valid CSRF token
    61  		token := utils.RandomString(16)
    62  		req.Header.Set(echo.HeaderCookie, "_csrf="+token)
    63  		req.Header.Set(echo.HeaderXCSRFToken, token)
    64  		if assert.NoError(t, h(c)) {
    65  			assert.Equal(t, http.StatusOK, rec.Code)
    66  		}
    67  	})
    68  
    69  	t.Run("CSRFTokenFromForm", func(t *testing.T) {
    70  		f := make(url.Values)
    71  		f.Set("csrf", "token")
    72  		e := echo.New()
    73  		req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
    74  		req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
    75  		c := e.NewContext(req, nil)
    76  		token, err := csrfTokenFromForm("csrf")(c)
    77  		if assert.NoError(t, err) {
    78  			assert.Equal(t, "token", token)
    79  		}
    80  		_, err = csrfTokenFromForm("invalid")(c)
    81  		assert.Error(t, err)
    82  	})
    83  
    84  	t.Run("CSRFTokenFromQuery", func(t *testing.T) {
    85  		q := make(url.Values)
    86  		q.Set("csrf", "token")
    87  		e := echo.New()
    88  		req := httptest.NewRequest(http.MethodGet, "/", nil)
    89  		req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
    90  		req.URL.RawQuery = q.Encode()
    91  		c := e.NewContext(req, nil)
    92  		token, err := csrfTokenFromQuery("csrf")(c)
    93  		if assert.NoError(t, err) {
    94  			assert.Equal(t, "token", token)
    95  		}
    96  		_, err = csrfTokenFromQuery("invalid")(c)
    97  		assert.Error(t, err)
    98  		csrfTokenFromQuery("csrf")
    99  	})
   100  }