github.com/gofiber/fiber/v2@v2.47.0/middleware/csrf/csrf_test.go (about)

     1  package csrf
     2  
     3  import (
     4  	"net/http/httptest"
     5  	"strings"
     6  	"testing"
     7  
     8  	"github.com/gofiber/fiber/v2"
     9  	"github.com/gofiber/fiber/v2/utils"
    10  
    11  	"github.com/valyala/fasthttp"
    12  )
    13  
    14  func Test_CSRF(t *testing.T) {
    15  	t.Parallel()
    16  	app := fiber.New()
    17  
    18  	app.Use(New())
    19  
    20  	app.Post("/", func(c *fiber.Ctx) error {
    21  		return c.SendStatus(fiber.StatusOK)
    22  	})
    23  
    24  	h := app.Handler()
    25  	ctx := &fasthttp.RequestCtx{}
    26  
    27  	methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace}
    28  
    29  	for _, method := range methods {
    30  		// Generate CSRF token
    31  		ctx.Request.Header.SetMethod(method)
    32  		h(ctx)
    33  
    34  		// Without CSRF cookie
    35  		ctx.Request.Reset()
    36  		ctx.Response.Reset()
    37  		ctx.Request.Header.SetMethod(fiber.MethodPost)
    38  		h(ctx)
    39  		utils.AssertEqual(t, 403, ctx.Response.StatusCode())
    40  
    41  		// Empty/invalid CSRF token
    42  		ctx.Request.Reset()
    43  		ctx.Response.Reset()
    44  		ctx.Request.Header.SetMethod(fiber.MethodPost)
    45  		ctx.Request.Header.Set(HeaderName, "johndoe")
    46  		h(ctx)
    47  		utils.AssertEqual(t, 403, ctx.Response.StatusCode())
    48  
    49  		// Valid CSRF token
    50  		ctx.Request.Reset()
    51  		ctx.Response.Reset()
    52  		ctx.Request.Header.SetMethod(method)
    53  		h(ctx)
    54  		token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
    55  		token = strings.Split(strings.Split(token, ";")[0], "=")[1]
    56  
    57  		ctx.Request.Reset()
    58  		ctx.Response.Reset()
    59  		ctx.Request.Header.SetMethod(fiber.MethodPost)
    60  		ctx.Request.Header.Set(HeaderName, token)
    61  		h(ctx)
    62  		utils.AssertEqual(t, 200, ctx.Response.StatusCode())
    63  	}
    64  }
    65  
    66  // go test -run Test_CSRF_Next
    67  func Test_CSRF_Next(t *testing.T) {
    68  	t.Parallel()
    69  	app := fiber.New()
    70  	app.Use(New(Config{
    71  		Next: func(_ *fiber.Ctx) bool {
    72  			return true
    73  		},
    74  	}))
    75  
    76  	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
    77  	utils.AssertEqual(t, nil, err)
    78  	utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
    79  }
    80  
    81  func Test_CSRF_Invalid_KeyLookup(t *testing.T) {
    82  	t.Parallel()
    83  	defer func() {
    84  		utils.AssertEqual(t, "[CSRF] KeyLookup must in the form of <source>:<key>", recover())
    85  	}()
    86  	app := fiber.New()
    87  
    88  	app.Use(New(Config{KeyLookup: "I:am:invalid"}))
    89  
    90  	app.Post("/", func(c *fiber.Ctx) error {
    91  		return c.SendStatus(fiber.StatusOK)
    92  	})
    93  
    94  	h := app.Handler()
    95  	ctx := &fasthttp.RequestCtx{}
    96  	ctx.Request.Header.SetMethod(fiber.MethodGet)
    97  	h(ctx)
    98  }
    99  
   100  func Test_CSRF_From_Form(t *testing.T) {
   101  	t.Parallel()
   102  	app := fiber.New()
   103  
   104  	app.Use(New(Config{KeyLookup: "form:_csrf"}))
   105  
   106  	app.Post("/", func(c *fiber.Ctx) error {
   107  		return c.SendStatus(fiber.StatusOK)
   108  	})
   109  
   110  	h := app.Handler()
   111  	ctx := &fasthttp.RequestCtx{}
   112  
   113  	// Invalid CSRF token
   114  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   115  	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
   116  	h(ctx)
   117  	utils.AssertEqual(t, 403, ctx.Response.StatusCode())
   118  
   119  	// Generate CSRF token
   120  	ctx.Request.Reset()
   121  	ctx.Response.Reset()
   122  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   123  	h(ctx)
   124  	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
   125  	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
   126  
   127  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   128  	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
   129  	ctx.Request.SetBodyString("_csrf=" + token)
   130  	h(ctx)
   131  	utils.AssertEqual(t, 200, ctx.Response.StatusCode())
   132  }
   133  
   134  func Test_CSRF_From_Query(t *testing.T) {
   135  	t.Parallel()
   136  	app := fiber.New()
   137  
   138  	app.Use(New(Config{KeyLookup: "query:_csrf"}))
   139  
   140  	app.Post("/", func(c *fiber.Ctx) error {
   141  		return c.SendStatus(fiber.StatusOK)
   142  	})
   143  
   144  	h := app.Handler()
   145  	ctx := &fasthttp.RequestCtx{}
   146  
   147  	// Invalid CSRF token
   148  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   149  	ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID())
   150  	h(ctx)
   151  	utils.AssertEqual(t, 403, ctx.Response.StatusCode())
   152  
   153  	// Generate CSRF token
   154  	ctx.Request.Reset()
   155  	ctx.Response.Reset()
   156  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   157  	ctx.Request.SetRequestURI("/")
   158  	h(ctx)
   159  	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
   160  	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
   161  
   162  	ctx.Request.Reset()
   163  	ctx.Response.Reset()
   164  	ctx.Request.SetRequestURI("/?_csrf=" + token)
   165  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   166  	h(ctx)
   167  	utils.AssertEqual(t, 200, ctx.Response.StatusCode())
   168  	utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
   169  }
   170  
   171  func Test_CSRF_From_Param(t *testing.T) {
   172  	t.Parallel()
   173  	app := fiber.New()
   174  
   175  	csrfGroup := app.Group("/:csrf", New(Config{KeyLookup: "param:csrf"}))
   176  
   177  	csrfGroup.Post("/", func(c *fiber.Ctx) error {
   178  		return c.SendStatus(fiber.StatusOK)
   179  	})
   180  
   181  	h := app.Handler()
   182  	ctx := &fasthttp.RequestCtx{}
   183  
   184  	// Invalid CSRF token
   185  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   186  	ctx.Request.SetRequestURI("/" + utils.UUID())
   187  	h(ctx)
   188  	utils.AssertEqual(t, 403, ctx.Response.StatusCode())
   189  
   190  	// Generate CSRF token
   191  	ctx.Request.Reset()
   192  	ctx.Response.Reset()
   193  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   194  	ctx.Request.SetRequestURI("/" + utils.UUID())
   195  	h(ctx)
   196  	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
   197  	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
   198  
   199  	ctx.Request.Reset()
   200  	ctx.Response.Reset()
   201  	ctx.Request.SetRequestURI("/" + token)
   202  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   203  	h(ctx)
   204  	utils.AssertEqual(t, 200, ctx.Response.StatusCode())
   205  	utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
   206  }
   207  
   208  func Test_CSRF_From_Cookie(t *testing.T) {
   209  	t.Parallel()
   210  	app := fiber.New()
   211  
   212  	csrfGroup := app.Group("/", New(Config{KeyLookup: "cookie:csrf"}))
   213  
   214  	csrfGroup.Post("/", func(c *fiber.Ctx) error {
   215  		return c.SendStatus(fiber.StatusOK)
   216  	})
   217  
   218  	h := app.Handler()
   219  	ctx := &fasthttp.RequestCtx{}
   220  
   221  	// Invalid CSRF token
   222  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   223  	ctx.Request.SetRequestURI("/")
   224  	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";")
   225  	h(ctx)
   226  	utils.AssertEqual(t, 403, ctx.Response.StatusCode())
   227  
   228  	// Generate CSRF token
   229  	ctx.Request.Reset()
   230  	ctx.Response.Reset()
   231  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   232  	ctx.Request.SetRequestURI("/")
   233  	h(ctx)
   234  	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
   235  	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
   236  
   237  	ctx.Request.Reset()
   238  	ctx.Response.Reset()
   239  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   240  	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";")
   241  	ctx.Request.SetRequestURI("/")
   242  	h(ctx)
   243  	utils.AssertEqual(t, 200, ctx.Response.StatusCode())
   244  	utils.AssertEqual(t, "OK", string(ctx.Response.Body()))
   245  }
   246  
   247  func Test_CSRF_From_Custom(t *testing.T) {
   248  	t.Parallel()
   249  	app := fiber.New()
   250  
   251  	extractor := func(c *fiber.Ctx) (string, error) {
   252  		body := string(c.Body())
   253  		// Generate the correct extractor to get the token from the correct location
   254  		selectors := strings.Split(body, "=")
   255  
   256  		if len(selectors) != 2 || selectors[1] == "" {
   257  			return "", errMissingParam
   258  		}
   259  		return selectors[1], nil
   260  	}
   261  
   262  	app.Use(New(Config{Extractor: extractor}))
   263  
   264  	app.Post("/", func(c *fiber.Ctx) error {
   265  		return c.SendStatus(fiber.StatusOK)
   266  	})
   267  
   268  	h := app.Handler()
   269  	ctx := &fasthttp.RequestCtx{}
   270  
   271  	// Invalid CSRF token
   272  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   273  	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
   274  	h(ctx)
   275  	utils.AssertEqual(t, 403, ctx.Response.StatusCode())
   276  
   277  	// Generate CSRF token
   278  	ctx.Request.Reset()
   279  	ctx.Response.Reset()
   280  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   281  	h(ctx)
   282  	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
   283  	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
   284  
   285  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   286  	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain)
   287  	ctx.Request.SetBodyString("_csrf=" + token)
   288  	h(ctx)
   289  	utils.AssertEqual(t, 200, ctx.Response.StatusCode())
   290  }
   291  
   292  func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
   293  	t.Parallel()
   294  	app := fiber.New()
   295  
   296  	errHandler := func(ctx *fiber.Ctx, err error) error {
   297  		utils.AssertEqual(t, errTokenNotFound, err)
   298  		return ctx.Status(419).Send([]byte("invalid CSRF token"))
   299  	}
   300  
   301  	app.Use(New(Config{ErrorHandler: errHandler}))
   302  
   303  	app.Post("/", func(c *fiber.Ctx) error {
   304  		return c.SendStatus(fiber.StatusOK)
   305  	})
   306  
   307  	h := app.Handler()
   308  	ctx := &fasthttp.RequestCtx{}
   309  
   310  	// Generate CSRF token
   311  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   312  	h(ctx)
   313  
   314  	// invalid CSRF token
   315  	ctx.Request.Reset()
   316  	ctx.Response.Reset()
   317  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   318  	ctx.Request.Header.Set(HeaderName, "johndoe")
   319  	h(ctx)
   320  	utils.AssertEqual(t, 419, ctx.Response.StatusCode())
   321  	utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body()))
   322  }
   323  
   324  func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
   325  	t.Parallel()
   326  	app := fiber.New()
   327  
   328  	errHandler := func(ctx *fiber.Ctx, err error) error {
   329  		utils.AssertEqual(t, errMissingHeader, err)
   330  		return ctx.Status(419).Send([]byte("empty CSRF token"))
   331  	}
   332  
   333  	app.Use(New(Config{ErrorHandler: errHandler}))
   334  
   335  	app.Post("/", func(c *fiber.Ctx) error {
   336  		return c.SendStatus(fiber.StatusOK)
   337  	})
   338  
   339  	h := app.Handler()
   340  	ctx := &fasthttp.RequestCtx{}
   341  
   342  	// Generate CSRF token
   343  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   344  	h(ctx)
   345  
   346  	// empty CSRF token
   347  	ctx.Request.Reset()
   348  	ctx.Response.Reset()
   349  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   350  	h(ctx)
   351  	utils.AssertEqual(t, 419, ctx.Response.StatusCode())
   352  	utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
   353  }
   354  
   355  // TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
   356  // func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
   357  //  t.Parallel()
   358  // 	app := fiber.New()
   359  
   360  // 	app.Use(New())
   361  // 	app.Get("/", func(c *fiber.Ctx) error {
   362  // 		return c.SendStatus(fiber.StatusOK)
   363  // 	})
   364  // 	app.Get("/test", func(c *fiber.Ctx) error {
   365  // 		return c.SendStatus(fiber.StatusOK)
   366  // 	})
   367  // 	app.Post("/", func(c *fiber.Ctx) error {
   368  // 		return c.SendStatus(fiber.StatusOK)
   369  // 	})
   370  
   371  // 	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
   372  // 	utils.AssertEqual(t, nil, err)
   373  // 	utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
   374  
   375  // 	var token string
   376  // 	for _, c := range resp.Cookies() {
   377  // 		if c.Name != ConfigDefault.CookieName {
   378  // 			continue
   379  // 		}
   380  // 		token = c.Value
   381  // 		break
   382  // 	}
   383  
   384  // 	fmt.Println("token", token)
   385  
   386  // 	getReq := httptest.NewRequest(fiber.MethodGet, "/", nil)
   387  // 	getReq.Header.Set(HeaderName, token)
   388  // 	resp, err = app.Test(getReq)
   389  
   390  // 	getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil)
   391  // 	getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
   392  // 	getReq.Header.Set(fiber.HeaderCacheControl, "no")
   393  // 	getReq.Header.Set(HeaderName, token)
   394  
   395  // 	resp, err = app.Test(getReq)
   396  
   397  // 	getReq.Header.Set(fiber.HeaderAccept, "*/*")
   398  // 	getReq.Header.Del(HeaderName)
   399  // 	resp, err = app.Test(getReq)
   400  
   401  // 	postReq := httptest.NewRequest(fiber.MethodPost, "/", nil)
   402  // 	postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
   403  // 	postReq.Header.Set(HeaderName, token)
   404  // 	resp, err = app.Test(postReq)
   405  // }
   406  
   407  // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
   408  func Benchmark_Middleware_CSRF_Check(b *testing.B) {
   409  	app := fiber.New()
   410  
   411  	app.Use(New())
   412  	app.Get("/", func(c *fiber.Ctx) error {
   413  		return c.SendStatus(fiber.StatusTeapot)
   414  	})
   415  
   416  	fctx := &fasthttp.RequestCtx{}
   417  	h := app.Handler()
   418  	ctx := &fasthttp.RequestCtx{}
   419  
   420  	// Generate CSRF token
   421  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   422  	h(ctx)
   423  	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
   424  	token = strings.Split(strings.Split(token, ";")[0], "=")[1]
   425  
   426  	ctx.Request.Header.SetMethod(fiber.MethodPost)
   427  	ctx.Request.Header.Set(HeaderName, token)
   428  
   429  	b.ReportAllocs()
   430  	b.ResetTimer()
   431  
   432  	for n := 0; n < b.N; n++ {
   433  		h(fctx)
   434  	}
   435  
   436  	utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
   437  }
   438  
   439  // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
   440  func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
   441  	app := fiber.New()
   442  
   443  	app.Use(New())
   444  	app.Get("/", func(c *fiber.Ctx) error {
   445  		return c.SendStatus(fiber.StatusTeapot)
   446  	})
   447  
   448  	fctx := &fasthttp.RequestCtx{}
   449  	h := app.Handler()
   450  	ctx := &fasthttp.RequestCtx{}
   451  
   452  	// Generate CSRF token
   453  	ctx.Request.Header.SetMethod(fiber.MethodGet)
   454  	b.ReportAllocs()
   455  	b.ResetTimer()
   456  
   457  	for n := 0; n < b.N; n++ {
   458  		h(fctx)
   459  	}
   460  
   461  	utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
   462  }