github.com/gofiber/fiber/v2@v2.47.0/middleware/csrf/csrf_test.go (about) 1 package csrf 2 3 import ( 4 "net/http/httptest" 5 "strings" 6 "testing" 7 8 "github.com/gofiber/fiber/v2" 9 "github.com/gofiber/fiber/v2/utils" 10 11 "github.com/valyala/fasthttp" 12 ) 13 14 func Test_CSRF(t *testing.T) { 15 t.Parallel() 16 app := fiber.New() 17 18 app.Use(New()) 19 20 app.Post("/", func(c *fiber.Ctx) error { 21 return c.SendStatus(fiber.StatusOK) 22 }) 23 24 h := app.Handler() 25 ctx := &fasthttp.RequestCtx{} 26 27 methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace} 28 29 for _, method := range methods { 30 // Generate CSRF token 31 ctx.Request.Header.SetMethod(method) 32 h(ctx) 33 34 // Without CSRF cookie 35 ctx.Request.Reset() 36 ctx.Response.Reset() 37 ctx.Request.Header.SetMethod(fiber.MethodPost) 38 h(ctx) 39 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 40 41 // Empty/invalid CSRF token 42 ctx.Request.Reset() 43 ctx.Response.Reset() 44 ctx.Request.Header.SetMethod(fiber.MethodPost) 45 ctx.Request.Header.Set(HeaderName, "johndoe") 46 h(ctx) 47 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 48 49 // Valid CSRF token 50 ctx.Request.Reset() 51 ctx.Response.Reset() 52 ctx.Request.Header.SetMethod(method) 53 h(ctx) 54 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 55 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 56 57 ctx.Request.Reset() 58 ctx.Response.Reset() 59 ctx.Request.Header.SetMethod(fiber.MethodPost) 60 ctx.Request.Header.Set(HeaderName, token) 61 h(ctx) 62 utils.AssertEqual(t, 200, ctx.Response.StatusCode()) 63 } 64 } 65 66 // go test -run Test_CSRF_Next 67 func Test_CSRF_Next(t *testing.T) { 68 t.Parallel() 69 app := fiber.New() 70 app.Use(New(Config{ 71 Next: func(_ *fiber.Ctx) bool { 72 return true 73 }, 74 })) 75 76 resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) 77 utils.AssertEqual(t, nil, err) 78 utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) 79 } 80 81 func Test_CSRF_Invalid_KeyLookup(t *testing.T) { 82 t.Parallel() 83 defer func() { 84 utils.AssertEqual(t, "[CSRF] KeyLookup must in the form of <source>:<key>", recover()) 85 }() 86 app := fiber.New() 87 88 app.Use(New(Config{KeyLookup: "I:am:invalid"})) 89 90 app.Post("/", func(c *fiber.Ctx) error { 91 return c.SendStatus(fiber.StatusOK) 92 }) 93 94 h := app.Handler() 95 ctx := &fasthttp.RequestCtx{} 96 ctx.Request.Header.SetMethod(fiber.MethodGet) 97 h(ctx) 98 } 99 100 func Test_CSRF_From_Form(t *testing.T) { 101 t.Parallel() 102 app := fiber.New() 103 104 app.Use(New(Config{KeyLookup: "form:_csrf"})) 105 106 app.Post("/", func(c *fiber.Ctx) error { 107 return c.SendStatus(fiber.StatusOK) 108 }) 109 110 h := app.Handler() 111 ctx := &fasthttp.RequestCtx{} 112 113 // Invalid CSRF token 114 ctx.Request.Header.SetMethod(fiber.MethodPost) 115 ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) 116 h(ctx) 117 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 118 119 // Generate CSRF token 120 ctx.Request.Reset() 121 ctx.Response.Reset() 122 ctx.Request.Header.SetMethod(fiber.MethodGet) 123 h(ctx) 124 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 125 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 126 127 ctx.Request.Header.SetMethod(fiber.MethodPost) 128 ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) 129 ctx.Request.SetBodyString("_csrf=" + token) 130 h(ctx) 131 utils.AssertEqual(t, 200, ctx.Response.StatusCode()) 132 } 133 134 func Test_CSRF_From_Query(t *testing.T) { 135 t.Parallel() 136 app := fiber.New() 137 138 app.Use(New(Config{KeyLookup: "query:_csrf"})) 139 140 app.Post("/", func(c *fiber.Ctx) error { 141 return c.SendStatus(fiber.StatusOK) 142 }) 143 144 h := app.Handler() 145 ctx := &fasthttp.RequestCtx{} 146 147 // Invalid CSRF token 148 ctx.Request.Header.SetMethod(fiber.MethodPost) 149 ctx.Request.SetRequestURI("/?_csrf=" + utils.UUID()) 150 h(ctx) 151 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 152 153 // Generate CSRF token 154 ctx.Request.Reset() 155 ctx.Response.Reset() 156 ctx.Request.Header.SetMethod(fiber.MethodGet) 157 ctx.Request.SetRequestURI("/") 158 h(ctx) 159 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 160 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 161 162 ctx.Request.Reset() 163 ctx.Response.Reset() 164 ctx.Request.SetRequestURI("/?_csrf=" + token) 165 ctx.Request.Header.SetMethod(fiber.MethodPost) 166 h(ctx) 167 utils.AssertEqual(t, 200, ctx.Response.StatusCode()) 168 utils.AssertEqual(t, "OK", string(ctx.Response.Body())) 169 } 170 171 func Test_CSRF_From_Param(t *testing.T) { 172 t.Parallel() 173 app := fiber.New() 174 175 csrfGroup := app.Group("/:csrf", New(Config{KeyLookup: "param:csrf"})) 176 177 csrfGroup.Post("/", func(c *fiber.Ctx) error { 178 return c.SendStatus(fiber.StatusOK) 179 }) 180 181 h := app.Handler() 182 ctx := &fasthttp.RequestCtx{} 183 184 // Invalid CSRF token 185 ctx.Request.Header.SetMethod(fiber.MethodPost) 186 ctx.Request.SetRequestURI("/" + utils.UUID()) 187 h(ctx) 188 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 189 190 // Generate CSRF token 191 ctx.Request.Reset() 192 ctx.Response.Reset() 193 ctx.Request.Header.SetMethod(fiber.MethodGet) 194 ctx.Request.SetRequestURI("/" + utils.UUID()) 195 h(ctx) 196 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 197 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 198 199 ctx.Request.Reset() 200 ctx.Response.Reset() 201 ctx.Request.SetRequestURI("/" + token) 202 ctx.Request.Header.SetMethod(fiber.MethodPost) 203 h(ctx) 204 utils.AssertEqual(t, 200, ctx.Response.StatusCode()) 205 utils.AssertEqual(t, "OK", string(ctx.Response.Body())) 206 } 207 208 func Test_CSRF_From_Cookie(t *testing.T) { 209 t.Parallel() 210 app := fiber.New() 211 212 csrfGroup := app.Group("/", New(Config{KeyLookup: "cookie:csrf"})) 213 214 csrfGroup.Post("/", func(c *fiber.Ctx) error { 215 return c.SendStatus(fiber.StatusOK) 216 }) 217 218 h := app.Handler() 219 ctx := &fasthttp.RequestCtx{} 220 221 // Invalid CSRF token 222 ctx.Request.Header.SetMethod(fiber.MethodPost) 223 ctx.Request.SetRequestURI("/") 224 ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+utils.UUID()+";") 225 h(ctx) 226 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 227 228 // Generate CSRF token 229 ctx.Request.Reset() 230 ctx.Response.Reset() 231 ctx.Request.Header.SetMethod(fiber.MethodGet) 232 ctx.Request.SetRequestURI("/") 233 h(ctx) 234 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 235 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 236 237 ctx.Request.Reset() 238 ctx.Response.Reset() 239 ctx.Request.Header.SetMethod(fiber.MethodPost) 240 ctx.Request.Header.Set(fiber.HeaderCookie, "csrf="+token+";") 241 ctx.Request.SetRequestURI("/") 242 h(ctx) 243 utils.AssertEqual(t, 200, ctx.Response.StatusCode()) 244 utils.AssertEqual(t, "OK", string(ctx.Response.Body())) 245 } 246 247 func Test_CSRF_From_Custom(t *testing.T) { 248 t.Parallel() 249 app := fiber.New() 250 251 extractor := func(c *fiber.Ctx) (string, error) { 252 body := string(c.Body()) 253 // Generate the correct extractor to get the token from the correct location 254 selectors := strings.Split(body, "=") 255 256 if len(selectors) != 2 || selectors[1] == "" { 257 return "", errMissingParam 258 } 259 return selectors[1], nil 260 } 261 262 app.Use(New(Config{Extractor: extractor})) 263 264 app.Post("/", func(c *fiber.Ctx) error { 265 return c.SendStatus(fiber.StatusOK) 266 }) 267 268 h := app.Handler() 269 ctx := &fasthttp.RequestCtx{} 270 271 // Invalid CSRF token 272 ctx.Request.Header.SetMethod(fiber.MethodPost) 273 ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) 274 h(ctx) 275 utils.AssertEqual(t, 403, ctx.Response.StatusCode()) 276 277 // Generate CSRF token 278 ctx.Request.Reset() 279 ctx.Response.Reset() 280 ctx.Request.Header.SetMethod(fiber.MethodGet) 281 h(ctx) 282 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 283 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 284 285 ctx.Request.Header.SetMethod(fiber.MethodPost) 286 ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) 287 ctx.Request.SetBodyString("_csrf=" + token) 288 h(ctx) 289 utils.AssertEqual(t, 200, ctx.Response.StatusCode()) 290 } 291 292 func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) { 293 t.Parallel() 294 app := fiber.New() 295 296 errHandler := func(ctx *fiber.Ctx, err error) error { 297 utils.AssertEqual(t, errTokenNotFound, err) 298 return ctx.Status(419).Send([]byte("invalid CSRF token")) 299 } 300 301 app.Use(New(Config{ErrorHandler: errHandler})) 302 303 app.Post("/", func(c *fiber.Ctx) error { 304 return c.SendStatus(fiber.StatusOK) 305 }) 306 307 h := app.Handler() 308 ctx := &fasthttp.RequestCtx{} 309 310 // Generate CSRF token 311 ctx.Request.Header.SetMethod(fiber.MethodGet) 312 h(ctx) 313 314 // invalid CSRF token 315 ctx.Request.Reset() 316 ctx.Response.Reset() 317 ctx.Request.Header.SetMethod(fiber.MethodPost) 318 ctx.Request.Header.Set(HeaderName, "johndoe") 319 h(ctx) 320 utils.AssertEqual(t, 419, ctx.Response.StatusCode()) 321 utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body())) 322 } 323 324 func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) { 325 t.Parallel() 326 app := fiber.New() 327 328 errHandler := func(ctx *fiber.Ctx, err error) error { 329 utils.AssertEqual(t, errMissingHeader, err) 330 return ctx.Status(419).Send([]byte("empty CSRF token")) 331 } 332 333 app.Use(New(Config{ErrorHandler: errHandler})) 334 335 app.Post("/", func(c *fiber.Ctx) error { 336 return c.SendStatus(fiber.StatusOK) 337 }) 338 339 h := app.Handler() 340 ctx := &fasthttp.RequestCtx{} 341 342 // Generate CSRF token 343 ctx.Request.Header.SetMethod(fiber.MethodGet) 344 h(ctx) 345 346 // empty CSRF token 347 ctx.Request.Reset() 348 ctx.Response.Reset() 349 ctx.Request.Header.SetMethod(fiber.MethodPost) 350 h(ctx) 351 utils.AssertEqual(t, 419, ctx.Response.StatusCode()) 352 utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body())) 353 } 354 355 // TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase 356 // func Test_CSRF_UnsafeHeaderValue(t *testing.T) { 357 // t.Parallel() 358 // app := fiber.New() 359 360 // app.Use(New()) 361 // app.Get("/", func(c *fiber.Ctx) error { 362 // return c.SendStatus(fiber.StatusOK) 363 // }) 364 // app.Get("/test", func(c *fiber.Ctx) error { 365 // return c.SendStatus(fiber.StatusOK) 366 // }) 367 // app.Post("/", func(c *fiber.Ctx) error { 368 // return c.SendStatus(fiber.StatusOK) 369 // }) 370 371 // resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) 372 // utils.AssertEqual(t, nil, err) 373 // utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) 374 375 // var token string 376 // for _, c := range resp.Cookies() { 377 // if c.Name != ConfigDefault.CookieName { 378 // continue 379 // } 380 // token = c.Value 381 // break 382 // } 383 384 // fmt.Println("token", token) 385 386 // getReq := httptest.NewRequest(fiber.MethodGet, "/", nil) 387 // getReq.Header.Set(HeaderName, token) 388 // resp, err = app.Test(getReq) 389 390 // getReq = httptest.NewRequest(fiber.MethodGet, "/test", nil) 391 // getReq.Header.Set("X-Requested-With", "XMLHttpRequest") 392 // getReq.Header.Set(fiber.HeaderCacheControl, "no") 393 // getReq.Header.Set(HeaderName, token) 394 395 // resp, err = app.Test(getReq) 396 397 // getReq.Header.Set(fiber.HeaderAccept, "*/*") 398 // getReq.Header.Del(HeaderName) 399 // resp, err = app.Test(getReq) 400 401 // postReq := httptest.NewRequest(fiber.MethodPost, "/", nil) 402 // postReq.Header.Set("X-Requested-With", "XMLHttpRequest") 403 // postReq.Header.Set(HeaderName, token) 404 // resp, err = app.Test(postReq) 405 // } 406 407 // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4 408 func Benchmark_Middleware_CSRF_Check(b *testing.B) { 409 app := fiber.New() 410 411 app.Use(New()) 412 app.Get("/", func(c *fiber.Ctx) error { 413 return c.SendStatus(fiber.StatusTeapot) 414 }) 415 416 fctx := &fasthttp.RequestCtx{} 417 h := app.Handler() 418 ctx := &fasthttp.RequestCtx{} 419 420 // Generate CSRF token 421 ctx.Request.Header.SetMethod(fiber.MethodGet) 422 h(ctx) 423 token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) 424 token = strings.Split(strings.Split(token, ";")[0], "=")[1] 425 426 ctx.Request.Header.SetMethod(fiber.MethodPost) 427 ctx.Request.Header.Set(HeaderName, token) 428 429 b.ReportAllocs() 430 b.ResetTimer() 431 432 for n := 0; n < b.N; n++ { 433 h(fctx) 434 } 435 436 utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) 437 } 438 439 // go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4 440 func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) { 441 app := fiber.New() 442 443 app.Use(New()) 444 app.Get("/", func(c *fiber.Ctx) error { 445 return c.SendStatus(fiber.StatusTeapot) 446 }) 447 448 fctx := &fasthttp.RequestCtx{} 449 h := app.Handler() 450 ctx := &fasthttp.RequestCtx{} 451 452 // Generate CSRF token 453 ctx.Request.Header.SetMethod(fiber.MethodGet) 454 b.ReportAllocs() 455 b.ResetTimer() 456 457 for n := 0; n < b.N; n++ { 458 h(fctx) 459 } 460 461 utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) 462 }