goyave.dev/goyave/v4@v4.4.11/middleware_test.go (about) 1 package goyave 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "mime/multipart" 8 "net/http" 9 "net/http/httptest" 10 "os" 11 "path" 12 "path/filepath" 13 "runtime" 14 "strings" 15 "testing" 16 17 "goyave.dev/goyave/v4/config" 18 "goyave.dev/goyave/v4/cors" 19 "goyave.dev/goyave/v4/lang" 20 "goyave.dev/goyave/v4/util/fsutil" 21 "goyave.dev/goyave/v4/validation" 22 ) 23 24 type MiddlewareTestSuite struct { 25 TestSuite 26 } 27 28 func (suite *MiddlewareTestSuite) SetupSuite() { 29 lang.LoadDefault() 30 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 31 } 32 33 func addFileToRequest(writer *multipart.Writer, path, name, fileName string) { 34 file, err := os.Open(path) 35 if err != nil { 36 panic(err) 37 } 38 defer file.Close() 39 part, err := writer.CreateFormFile(name, fileName) 40 if err != nil { 41 panic(err) 42 } 43 _, err = io.Copy(part, file) 44 if err != nil { 45 panic(err) 46 } 47 } 48 49 func createTestFileRequest(route string, files ...string) *http.Request { 50 _, filename, _, _ := runtime.Caller(1) 51 52 body := &bytes.Buffer{} 53 writer := multipart.NewWriter(body) 54 for _, p := range files { 55 fp := path.Dir(filename) + "/" + p 56 addFileToRequest(writer, fp, "file", filepath.Base(fp)) 57 } 58 field, err := writer.CreateFormField("field") 59 if err != nil { 60 panic(err) 61 } 62 _, err = io.Copy(field, strings.NewReader("world")) 63 if err != nil { 64 panic(err) 65 } 66 67 err = writer.Close() 68 if err != nil { 69 panic(err) 70 } 71 72 req, err := http.NewRequest("POST", route, body) 73 if err != nil { 74 panic(err) 75 } 76 req.Header.Set("Content-Type", writer.FormDataContentType()) 77 return req 78 } 79 80 func testMiddleware(middleware Middleware, rawRequest *http.Request, data map[string]interface{}, rules validation.RuleSet, corsOptions *cors.Options, handler func(*Response, *Request)) *http.Response { 81 request := &Request{ 82 httpRequest: rawRequest, 83 corsOptions: corsOptions, 84 Data: data, 85 Rules: rules.AsRules(), 86 Lang: "en-US", 87 Params: map[string]string{}, 88 } 89 response := newResponse(httptest.NewRecorder(), nil) 90 middleware(handler)(response, request) 91 92 return response.responseWriter.(*httptest.ResponseRecorder).Result() 93 } 94 95 func (suite *MiddlewareTestSuite) TestRecoveryMiddlewarePanic() { 96 response := newResponse(httptest.NewRecorder(), nil) 97 err := fmt.Errorf("error message") 98 recoveryMiddleware(func(response *Response, r *Request) { 99 panic(err) 100 })(response, &Request{}) 101 suite.Equal(err, response.GetError()) 102 suite.NotEmpty(response.GetStacktrace()) 103 suite.Equal(500, response.status) 104 } 105 106 func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNoPanic() { 107 response := newResponse(httptest.NewRecorder(), nil) 108 recoveryMiddleware(func(response *Response, r *Request) { 109 response.String(200, "message") 110 })(response, &Request{}) 111 112 resp := response.responseWriter.(*httptest.ResponseRecorder).Result() 113 suite.Nil(response.GetError()) 114 suite.Empty(response.GetStacktrace()) 115 suite.Equal(200, response.status) 116 suite.Equal(200, resp.StatusCode) 117 118 body, err := io.ReadAll(resp.Body) 119 resp.Body.Close() 120 suite.Nil(err) 121 suite.Equal("message", string(body)) 122 } 123 124 func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNilPanic() { 125 response := newResponse(httptest.NewRecorder(), nil) 126 recoveryMiddleware(func(response *Response, r *Request) { 127 panic(nil) 128 })(response, &Request{}) 129 suite.Nil(response.GetError()) 130 suite.NotEmpty(response.GetStacktrace()) 131 suite.Equal(500, response.status) 132 } 133 134 func (suite *MiddlewareTestSuite) TestLanguageMiddleware() { 135 defaultLanguage = config.GetString("app.defaultLanguage") 136 executed := false 137 rawRequest := httptest.NewRequest("GET", "/test-route", strings.NewReader("body")) 138 rawRequest.Header.Set("Accept-Language", "en-US") 139 res := testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 140 suite.Equal("en-US", r.Lang) 141 executed = true 142 }) 143 res.Body.Close() 144 suite.True(executed) 145 146 rawRequest = httptest.NewRequest("GET", "/test-route", strings.NewReader("body")) 147 res = testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 148 suite.Equal("en-US", r.Lang) 149 executed = true 150 }) 151 res.Body.Close() 152 153 suite.True(executed) 154 } 155 156 func (suite *MiddlewareTestSuite) TestParsePostRequestMiddleware() { 157 executed := false 158 rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42")) 159 rawRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") 160 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 161 suite.Equal("hello world", r.Data["string"]) 162 suite.Equal("42", r.Data["number"]) 163 executed = true 164 }) 165 suite.True(executed) 166 res.Body.Close() 167 } 168 169 func (suite *MiddlewareTestSuite) TestParseGetRequestMiddleware() { 170 executed := false 171 rawRequest := httptest.NewRequest("GET", "/test-route?string=hello%20world&number=42", nil) 172 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 173 suite.Equal("hello world", r.Data["string"]) 174 suite.Equal("42", r.Data["number"]) 175 executed = true 176 }) 177 suite.True(executed) 178 res.Body.Close() 179 180 executed = false 181 rawRequest = httptest.NewRequest("GET", "/test-route?%9", nil) 182 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 183 suite.Nil(r.Data) 184 executed = true 185 }) 186 suite.True(executed) 187 res.Body.Close() 188 } 189 190 func (suite *MiddlewareTestSuite) TestParseJsonRequestMiddleware() { 191 rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) 192 rawRequest.Header.Set("Content-Type", "application/json") 193 executed := false 194 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 195 suite.Equal("hello world", r.Data["string"]) 196 suite.Equal(42.0, r.Data["number"]) 197 slice, ok := r.Data["array"].([]interface{}) 198 suite.True(ok) 199 suite.Equal(2, len(slice)) 200 suite.Equal("val1", slice[0]) 201 suite.Equal("val2", slice[1]) 202 executed = true 203 }) 204 suite.True(executed) 205 res.Body.Close() 206 207 executed = false 208 rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]")) // Missing closing braces 209 rawRequest.Header.Set("Content-Type", "application/json") 210 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 211 suite.Nil(r.Data) 212 executed = true 213 }) 214 suite.True(executed) 215 res.Body.Close() 216 217 // Test with query parameters 218 executed = false 219 rawRequest = httptest.NewRequest("POST", "/test-route?query=param", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) 220 rawRequest.Header.Set("Content-Type", "application/json") 221 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 222 suite.NotNil(r.Data) 223 suite.Equal("param", r.Data["query"]) 224 executed = true 225 }) 226 suite.True(executed) 227 res.Body.Close() 228 229 executed = false 230 rawRequest = httptest.NewRequest("POST", "/test-route?%9", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) // Invalid query param 231 rawRequest.Header.Set("Content-Type", "application/json") 232 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 233 suite.Nil(r.Data) 234 executed = true 235 }) 236 suite.True(executed) 237 res.Body.Close() 238 239 // Test with charset (#101) 240 rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) 241 rawRequest.Header.Set("Content-Type", "application/json; charset=utf-8") 242 executed = false 243 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 244 suite.Equal("hello world", r.Data["string"]) 245 suite.Equal(42.0, r.Data["number"]) 246 slice, ok := r.Data["array"].([]interface{}) 247 suite.True(ok) 248 suite.Equal(2, len(slice)) 249 suite.Equal("val1", slice[0]) 250 suite.Equal("val2", slice[1]) 251 executed = true 252 }) 253 res.Body.Close() 254 suite.True(executed) 255 256 } 257 258 func (suite *MiddlewareTestSuite) TestParseMultipartRequestMiddleware() { 259 executed := false 260 rawRequest := createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png") 261 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 262 suite.Equal(3, len(r.Data)) 263 suite.Equal("hello", r.Data["test"]) 264 suite.Equal("world", r.Data["field"]) 265 files, ok := r.Data["file"].([]fsutil.File) 266 suite.True(ok) 267 suite.Equal(1, len(files)) 268 executed = true 269 }) 270 suite.True(executed) 271 res.Body.Close() 272 273 // Test payload too large 274 prev := config.Get("server.maxUploadSize") 275 config.Set("server.maxUploadSize", -10.0) 276 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 277 rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png") 278 279 request := createTestRequest(rawRequest) 280 response := newResponse(httptest.NewRecorder(), nil) 281 parseRequestMiddleware(nil)(response, request) 282 suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus()) 283 config.Set("server.maxUploadSize", prev) 284 285 prev = config.Get("server.maxUploadSize") 286 config.Set("server.maxUploadSize", 0.0006) 287 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 288 rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png") 289 290 request = createTestRequest(rawRequest) 291 response = newResponse(httptest.NewRecorder(), nil) 292 parseRequestMiddleware(nil)(response, request) 293 suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus()) 294 config.Set("server.maxUploadSize", prev) 295 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 296 } 297 298 func (suite *MiddlewareTestSuite) TestParseMultipartOverrideMiddleware() { 299 executed := false 300 rawRequest := createTestFileRequest("/test-route?field=hello", "resources/img/logo/goyave_16.png") 301 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 302 suite.Equal(2, len(r.Data)) 303 suite.Equal("world", r.Data["field"]) 304 files, ok := r.Data["file"].([]fsutil.File) 305 suite.True(ok) 306 suite.Equal(1, len(files)) 307 executed = true 308 }) 309 suite.True(executed) 310 res.Body.Close() 311 } 312 313 func (suite *MiddlewareTestSuite) TestParseMiddlewareWithArray() { 314 executed := false 315 rawRequest := httptest.NewRequest("GET", "/test-route?arr=hello&arr=world", nil) 316 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 317 arr, ok := r.Data["arr"].([]string) 318 suite.True(ok) 319 if ok { 320 suite.Equal(2, len(arr)) 321 suite.Equal("hello", arr[0]) 322 suite.Equal("world", arr[1]) 323 } 324 executed = true 325 }) 326 suite.True(executed) 327 res.Body.Close() 328 329 body := &bytes.Buffer{} 330 writer := multipart.NewWriter(body) 331 field, err := writer.CreateFormField("field") 332 if err != nil { 333 panic(err) 334 } 335 _, err = io.Copy(field, strings.NewReader("hello")) 336 if err != nil { 337 panic(err) 338 } 339 340 field, err = writer.CreateFormField("field") 341 if err != nil { 342 panic(err) 343 } 344 _, err = io.Copy(field, strings.NewReader("world")) 345 if err != nil { 346 panic(err) 347 } 348 349 err = writer.Close() 350 if err != nil { 351 panic(err) 352 } 353 354 executed = false 355 rawRequest, err = http.NewRequest("POST", "/test-route", body) 356 if err != nil { 357 panic(err) 358 } 359 rawRequest.Header.Set("Content-Type", writer.FormDataContentType()) 360 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 361 suite.Equal(1, len(r.Data)) 362 arr, ok := r.Data["field"].([]string) 363 suite.True(ok) 364 if ok { 365 suite.Equal(2, len(arr)) 366 suite.Equal("hello", arr[0]) 367 suite.Equal("world", arr[1]) 368 } 369 executed = true 370 }) 371 suite.True(executed) 372 res.Body.Close() 373 } 374 375 func (suite *MiddlewareTestSuite) TestValidateMiddleware() { 376 rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42")) 377 rawRequest.Header.Set("Content-Type", "application/json") 378 data := map[string]interface{}{ 379 "string": "hello world", 380 "number": 42, 381 } 382 rules := validation.RuleSet{ 383 "string": validation.List{"required", "string"}, 384 "number": validation.List{"required", "numeric", "min:10"}, 385 } 386 request := suite.CreateTestRequest(rawRequest) 387 request.Data = data 388 request.Rules = rules.AsRules() 389 result := suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {}) 390 suite.Equal(http.StatusNoContent, result.StatusCode) 391 result.Body.Close() 392 393 rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42")) 394 rawRequest.Header.Set("Content-Type", "application/json") 395 data = map[string]interface{}{ 396 "string": "hello world", 397 "number": 42, 398 } 399 rules = validation.RuleSet{ 400 "string": validation.List{"required", "string"}, 401 "number": validation.List{"required", "numeric", "min:50"}, 402 } 403 404 request = suite.CreateTestRequest(rawRequest) 405 request.Data = data 406 request.Rules = rules.AsRules() 407 result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {}) 408 body, err := io.ReadAll(result.Body) 409 if err != nil { 410 panic(err) 411 } 412 result.Body.Close() 413 suite.Equal(http.StatusUnprocessableEntity, result.StatusCode) 414 suite.Equal("{\"validationError\":{\"number\":{\"errors\":[\"The number must be at least 50.\"]}}}\n", string(body)) 415 416 rawRequest = httptest.NewRequest("POST", "/test-route", nil) 417 rawRequest.Header.Set("Content-Type", "application/json") 418 request = suite.CreateTestRequest(rawRequest) 419 request.Data = nil 420 request.Rules = rules.AsRules() 421 result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {}) 422 body, err = io.ReadAll(result.Body) 423 if err != nil { 424 panic(err) 425 } 426 result.Body.Close() 427 suite.Equal(http.StatusBadRequest, result.StatusCode) 428 suite.Equal("{\"validationError\":{\"[data]\":{\"errors\":[\"Malformed JSON\"]}}}\n", string(body)) 429 } 430 431 func (suite *MiddlewareTestSuite) TestCORSMiddleware() { 432 // No CORS options 433 rawRequest := httptest.NewRequest("GET", "/test-route", nil) 434 result := testMiddleware(corsMiddleware, rawRequest, nil, nil, nil, func(response *Response, r *Request) {}) 435 suite.Equal(200, result.StatusCode) 436 result.Body.Close() 437 438 // Preflight 439 options := cors.Default() 440 rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil) 441 rawRequest.Header.Set("Origin", "https://google.com") 442 rawRequest.Header.Set("Access-Control-Request-Method", "GET") 443 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 444 response.String(200, "Hi!") 445 }) 446 body, err := io.ReadAll(result.Body) 447 if err != nil { 448 panic(err) 449 } 450 result.Body.Close() 451 suite.Equal(204, result.StatusCode) 452 suite.Empty(body) 453 454 // Preflight passthrough 455 options = cors.Default() 456 options.OptionsPassthrough = true 457 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 458 response.String(200, "Passthrough") 459 }) 460 body, err = io.ReadAll(result.Body) 461 if err != nil { 462 panic(err) 463 } 464 result.Body.Close() 465 suite.Equal(200, result.StatusCode) 466 suite.Equal("Passthrough", string(body)) 467 468 // Preflight without Access-Control-Request-Method 469 rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil) 470 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 471 response.String(200, "Hi!") 472 }) 473 body, err = io.ReadAll(result.Body) 474 if err != nil { 475 panic(err) 476 } 477 result.Body.Close() 478 suite.Equal(200, result.StatusCode) 479 suite.Equal("Hi!", string(body)) 480 481 // Actual request 482 options = cors.Default() 483 options.AllowedOrigins = []string{"https://google.com", "https://images.google.com"} 484 rawRequest = httptest.NewRequest("GET", "/test-route", nil) 485 rawRequest.Header.Set("Origin", "https://images.google.com") 486 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 487 response.String(200, "Hi!") 488 }) 489 body, err = io.ReadAll(result.Body) 490 if err != nil { 491 panic(err) 492 } 493 result.Body.Close() 494 suite.Equal("Hi!", string(body)) 495 suite.Equal("https://images.google.com", result.Header.Get("Access-Control-Allow-Origin")) 496 suite.Equal("Origin", result.Header.Get("Vary")) 497 } 498 499 func TestMiddlewareTestSuite(t *testing.T) { 500 RunTest(t, new(MiddlewareTestSuite)) 501 }