github.com/System-Glitch/goyave/v3@v3.6.1-0.20210226143142-ac2fe42ee80e/middleware_test.go (about) 1 package goyave 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "mime/multipart" 9 "net/http" 10 "net/http/httptest" 11 "os" 12 "path" 13 "path/filepath" 14 "runtime" 15 "strings" 16 "testing" 17 18 "github.com/System-Glitch/goyave/v3/config" 19 "github.com/System-Glitch/goyave/v3/cors" 20 "github.com/System-Glitch/goyave/v3/helper/filesystem" 21 "github.com/System-Glitch/goyave/v3/lang" 22 "github.com/System-Glitch/goyave/v3/validation" 23 ) 24 25 type MiddlewareTestSuite struct { 26 TestSuite 27 } 28 29 func (suite *MiddlewareTestSuite) SetupSuite() { 30 lang.LoadDefault() 31 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 32 } 33 34 func addFileToRequest(writer *multipart.Writer, path, name, fileName string) { 35 file, err := os.Open(path) 36 if err != nil { 37 panic(err) 38 } 39 defer file.Close() 40 part, err := writer.CreateFormFile(name, fileName) 41 if err != nil { 42 panic(err) 43 } 44 _, err = io.Copy(part, file) 45 if err != nil { 46 panic(err) 47 } 48 } 49 50 func createTestFileRequest(route string, files ...string) *http.Request { 51 _, filename, _, _ := runtime.Caller(1) 52 53 body := &bytes.Buffer{} 54 writer := multipart.NewWriter(body) 55 for _, p := range files { 56 fp := path.Dir(filename) + "/" + p 57 addFileToRequest(writer, fp, "file", filepath.Base(fp)) 58 } 59 field, err := writer.CreateFormField("field") 60 if err != nil { 61 panic(err) 62 } 63 _, err = io.Copy(field, strings.NewReader("world")) 64 if err != nil { 65 panic(err) 66 } 67 68 err = writer.Close() 69 if err != nil { 70 panic(err) 71 } 72 73 req, err := http.NewRequest("POST", route, body) 74 if err != nil { 75 panic(err) 76 } 77 req.Header.Set("Content-Type", writer.FormDataContentType()) 78 return req 79 } 80 81 func testMiddleware(middleware Middleware, rawRequest *http.Request, data map[string]interface{}, rules validation.RuleSet, corsOptions *cors.Options, handler func(*Response, *Request)) *http.Response { 82 request := &Request{ 83 httpRequest: rawRequest, 84 corsOptions: corsOptions, 85 Data: data, 86 Rules: rules.AsRules(), 87 Lang: "en-US", 88 Params: map[string]string{}, 89 } 90 response := newResponse(httptest.NewRecorder(), nil) 91 middleware(handler)(response, request) 92 93 return response.responseWriter.(*httptest.ResponseRecorder).Result() 94 } 95 96 func (suite *MiddlewareTestSuite) TestRecoveryMiddlewarePanic() { 97 response := newResponse(httptest.NewRecorder(), nil) 98 err := fmt.Errorf("error message") 99 recoveryMiddleware(func(response *Response, r *Request) { 100 panic(err) 101 })(response, &Request{}) 102 suite.Equal(err, response.GetError()) 103 suite.Empty(response.GetStacktrace()) 104 suite.Equal(500, response.status) 105 106 prev := config.GetBool("app.debug") 107 config.Set("app.debug", true) 108 defer config.Set("app.debug", prev) 109 110 response = newResponse(httptest.NewRecorder(), nil) 111 err = fmt.Errorf("error message") 112 recoveryMiddleware(func(response *Response, r *Request) { 113 panic(err) 114 })(response, &Request{}) 115 suite.Equal(err, response.GetError()) 116 suite.NotEmpty(response.GetStacktrace()) 117 suite.Equal(500, response.status) 118 } 119 120 func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNoPanic() { 121 response := newResponse(httptest.NewRecorder(), nil) 122 recoveryMiddleware(func(response *Response, r *Request) { 123 response.String(200, "message") 124 })(response, &Request{}) 125 126 resp := response.responseWriter.(*httptest.ResponseRecorder).Result() 127 suite.Nil(response.GetError()) 128 suite.Equal(200, response.status) 129 suite.Equal(200, resp.StatusCode) 130 131 body, err := ioutil.ReadAll(resp.Body) 132 resp.Body.Close() 133 suite.Nil(err) 134 suite.Equal("message", string(body)) 135 } 136 137 func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNilPanic() { 138 response := newResponse(httptest.NewRecorder(), nil) 139 recoveryMiddleware(func(response *Response, r *Request) { 140 panic(nil) 141 })(response, &Request{}) 142 suite.Nil(response.GetError()) 143 suite.Equal(500, response.status) 144 } 145 146 func (suite *MiddlewareTestSuite) TestLanguageMiddleware() { 147 defaultLanguage = config.GetString("app.defaultLanguage") 148 executed := false 149 rawRequest := httptest.NewRequest("GET", "/test-route", strings.NewReader("body")) 150 rawRequest.Header.Set("Accept-Language", "en-US") 151 res := testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 152 suite.Equal("en-US", r.Lang) 153 executed = true 154 }) 155 res.Body.Close() 156 suite.True(executed) 157 158 rawRequest = httptest.NewRequest("GET", "/test-route", strings.NewReader("body")) 159 res = testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 160 suite.Equal("en-US", r.Lang) 161 executed = true 162 }) 163 res.Body.Close() 164 165 suite.True(executed) 166 } 167 168 func (suite *MiddlewareTestSuite) TestParsePostRequestMiddleware() { 169 executed := false 170 rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42")) 171 rawRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") 172 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 173 suite.Equal("hello world", r.Data["string"]) 174 suite.Equal("42", r.Data["number"]) 175 executed = true 176 }) 177 suite.True(executed) 178 res.Body.Close() 179 } 180 181 func (suite *MiddlewareTestSuite) TestParseGetRequestMiddleware() { 182 executed := false 183 rawRequest := httptest.NewRequest("GET", "/test-route?string=hello%20world&number=42", nil) 184 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 185 suite.Equal("hello world", r.Data["string"]) 186 suite.Equal("42", r.Data["number"]) 187 executed = true 188 }) 189 suite.True(executed) 190 res.Body.Close() 191 192 executed = false 193 rawRequest = httptest.NewRequest("GET", "/test-route?%9", nil) 194 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 195 suite.Nil(r.Data) 196 executed = true 197 }) 198 suite.True(executed) 199 res.Body.Close() 200 } 201 202 func (suite *MiddlewareTestSuite) TestParseJsonRequestMiddleware() { 203 rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) 204 rawRequest.Header.Set("Content-Type", "application/json") 205 executed := false 206 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 207 suite.Equal("hello world", r.Data["string"]) 208 suite.Equal(42.0, r.Data["number"]) 209 slice, ok := r.Data["array"].([]interface{}) 210 suite.True(ok) 211 suite.Equal(2, len(slice)) 212 suite.Equal("val1", slice[0]) 213 suite.Equal("val2", slice[1]) 214 executed = true 215 }) 216 suite.True(executed) 217 res.Body.Close() 218 219 executed = false 220 rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]")) // Missing closing braces 221 rawRequest.Header.Set("Content-Type", "application/json") 222 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 223 suite.Nil(r.Data) 224 executed = true 225 }) 226 suite.True(executed) 227 res.Body.Close() 228 229 // Test with query parameters 230 executed = false 231 rawRequest = httptest.NewRequest("POST", "/test-route?query=param", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) 232 rawRequest.Header.Set("Content-Type", "application/json") 233 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 234 suite.NotNil(r.Data) 235 suite.Equal("param", r.Data["query"]) 236 executed = true 237 }) 238 suite.True(executed) 239 res.Body.Close() 240 241 // Test with charset (#101) 242 rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) 243 rawRequest.Header.Set("Content-Type", "application/json; charset=utf-8") 244 executed = false 245 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 246 suite.Equal("hello world", r.Data["string"]) 247 suite.Equal(42.0, r.Data["number"]) 248 slice, ok := r.Data["array"].([]interface{}) 249 suite.True(ok) 250 suite.Equal(2, len(slice)) 251 suite.Equal("val1", slice[0]) 252 suite.Equal("val2", slice[1]) 253 executed = true 254 }) 255 res.Body.Close() 256 suite.True(executed) 257 258 } 259 260 func (suite *MiddlewareTestSuite) TestParseMultipartRequestMiddleware() { 261 executed := false 262 rawRequest := createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png") 263 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 264 suite.Equal(3, len(r.Data)) 265 suite.Equal("hello", r.Data["test"]) 266 suite.Equal("world", r.Data["field"]) 267 files, ok := r.Data["file"].([]filesystem.File) 268 suite.True(ok) 269 suite.Equal(1, len(files)) 270 executed = true 271 }) 272 suite.True(executed) 273 res.Body.Close() 274 275 // Test payload too large 276 prev := config.Get("server.maxUploadSize") 277 config.Set("server.maxUploadSize", -10.0) 278 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 279 rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png") 280 281 request := createTestRequest(rawRequest) 282 response := newResponse(httptest.NewRecorder(), nil) 283 parseRequestMiddleware(nil)(response, request) 284 suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus()) 285 config.Set("server.maxUploadSize", prev) 286 287 prev = config.Get("server.maxUploadSize") 288 config.Set("server.maxUploadSize", 0.0006) 289 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 290 rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png") 291 292 request = createTestRequest(rawRequest) 293 response = newResponse(httptest.NewRecorder(), nil) 294 parseRequestMiddleware(nil)(response, request) 295 suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus()) 296 config.Set("server.maxUploadSize", prev) 297 maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024) 298 } 299 300 func (suite *MiddlewareTestSuite) TestParseMultipartOverrideMiddleware() { 301 executed := false 302 rawRequest := createTestFileRequest("/test-route?field=hello", "resources/img/logo/goyave_16.png") 303 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 304 suite.Equal(2, len(r.Data)) 305 suite.Equal("world", r.Data["field"]) 306 files, ok := r.Data["file"].([]filesystem.File) 307 suite.True(ok) 308 suite.Equal(1, len(files)) 309 executed = true 310 }) 311 suite.True(executed) 312 res.Body.Close() 313 } 314 315 func (suite *MiddlewareTestSuite) TestParseMiddlewareWithArray() { 316 executed := false 317 rawRequest := httptest.NewRequest("GET", "/test-route?arr=hello&arr=world", nil) 318 res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 319 arr, ok := r.Data["arr"].([]string) 320 suite.True(ok) 321 if ok { 322 suite.Equal(2, len(arr)) 323 suite.Equal("hello", arr[0]) 324 suite.Equal("world", arr[1]) 325 } 326 executed = true 327 }) 328 suite.True(executed) 329 res.Body.Close() 330 331 body := &bytes.Buffer{} 332 writer := multipart.NewWriter(body) 333 field, err := writer.CreateFormField("field") 334 if err != nil { 335 panic(err) 336 } 337 _, err = io.Copy(field, strings.NewReader("hello")) 338 if err != nil { 339 panic(err) 340 } 341 342 field, err = writer.CreateFormField("field") 343 if err != nil { 344 panic(err) 345 } 346 _, err = io.Copy(field, strings.NewReader("world")) 347 if err != nil { 348 panic(err) 349 } 350 351 err = writer.Close() 352 if err != nil { 353 panic(err) 354 } 355 356 executed = false 357 rawRequest, err = http.NewRequest("POST", "/test-route", body) 358 if err != nil { 359 panic(err) 360 } 361 rawRequest.Header.Set("Content-Type", writer.FormDataContentType()) 362 res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) { 363 suite.Equal(1, len(r.Data)) 364 arr, ok := r.Data["field"].([]string) 365 suite.True(ok) 366 if ok { 367 suite.Equal(2, len(arr)) 368 suite.Equal("hello", arr[0]) 369 suite.Equal("world", arr[1]) 370 } 371 executed = true 372 }) 373 suite.True(executed) 374 res.Body.Close() 375 } 376 377 func (suite *MiddlewareTestSuite) TestValidateMiddleware() { 378 rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42")) 379 rawRequest.Header.Set("Content-Type", "application/json") 380 data := map[string]interface{}{ 381 "string": "hello world", 382 "number": 42, 383 } 384 rules := validation.RuleSet{ 385 "string": {"required", "string"}, 386 "number": {"required", "numeric", "min:10"}, 387 } 388 request := suite.CreateTestRequest(rawRequest) 389 request.Data = data 390 request.Rules = rules.AsRules() 391 result := suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {}) 392 suite.Equal(http.StatusNoContent, result.StatusCode) 393 result.Body.Close() 394 395 rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42")) 396 rawRequest.Header.Set("Content-Type", "application/json") 397 data = map[string]interface{}{ 398 "string": "hello world", 399 "number": 42, 400 } 401 rules = validation.RuleSet{ 402 "string": {"required", "string"}, 403 "number": {"required", "numeric", "min:50"}, 404 } 405 406 request = suite.CreateTestRequest(rawRequest) 407 request.Data = data 408 request.Rules = rules.AsRules() 409 result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {}) 410 body, err := ioutil.ReadAll(result.Body) 411 if err != nil { 412 panic(err) 413 } 414 result.Body.Close() 415 suite.Equal(http.StatusUnprocessableEntity, result.StatusCode) 416 suite.Equal("{\"validationError\":{\"number\":[\"The number must be at least 50.\"]}}\n", string(body)) 417 418 rawRequest = httptest.NewRequest("POST", "/test-route", nil) 419 rawRequest.Header.Set("Content-Type", "application/json") 420 request = suite.CreateTestRequest(rawRequest) 421 request.Data = nil 422 request.Rules = rules.AsRules() 423 result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {}) 424 body, err = ioutil.ReadAll(result.Body) 425 if err != nil { 426 panic(err) 427 } 428 result.Body.Close() 429 suite.Equal(http.StatusBadRequest, result.StatusCode) 430 suite.Equal("{\"validationError\":{\"error\":[\"Malformed JSON\"]}}\n", string(body)) 431 } 432 433 func (suite *MiddlewareTestSuite) TestCORSMiddleware() { 434 // No CORS options 435 rawRequest := httptest.NewRequest("GET", "/test-route", nil) 436 result := testMiddleware(corsMiddleware, rawRequest, nil, nil, nil, func(response *Response, r *Request) {}) 437 suite.Equal(200, result.StatusCode) 438 result.Body.Close() 439 440 // Preflight 441 options := cors.Default() 442 rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil) 443 rawRequest.Header.Set("Origin", "https://google.com") 444 rawRequest.Header.Set("Access-Control-Request-Method", "GET") 445 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 446 response.String(200, "Hi!") 447 }) 448 body, err := ioutil.ReadAll(result.Body) 449 if err != nil { 450 panic(err) 451 } 452 result.Body.Close() 453 suite.Equal(204, result.StatusCode) 454 suite.Empty(body) 455 456 // Preflight passthrough 457 options = cors.Default() 458 options.OptionsPassthrough = true 459 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 460 response.String(200, "Passthrough") 461 }) 462 body, err = ioutil.ReadAll(result.Body) 463 if err != nil { 464 panic(err) 465 } 466 result.Body.Close() 467 suite.Equal(200, result.StatusCode) 468 suite.Equal("Passthrough", string(body)) 469 470 // Preflight without Access-Control-Request-Method 471 rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil) 472 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 473 response.String(200, "Hi!") 474 }) 475 body, err = ioutil.ReadAll(result.Body) 476 if err != nil { 477 panic(err) 478 } 479 result.Body.Close() 480 suite.Equal(200, result.StatusCode) 481 suite.Equal("Hi!", string(body)) 482 483 // Actual request 484 options = cors.Default() 485 options.AllowedOrigins = []string{"https://google.com", "https://images.google.com"} 486 rawRequest = httptest.NewRequest("GET", "/test-route", nil) 487 rawRequest.Header.Set("Origin", "https://images.google.com") 488 result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) { 489 response.String(200, "Hi!") 490 }) 491 body, err = ioutil.ReadAll(result.Body) 492 if err != nil { 493 panic(err) 494 } 495 result.Body.Close() 496 suite.Equal("Hi!", string(body)) 497 suite.Equal("https://images.google.com", result.Header.Get("Access-Control-Allow-Origin")) 498 suite.Equal("Origin", result.Header.Get("Vary")) 499 } 500 501 func TestMiddlewareTestSuite(t *testing.T) { 502 RunTest(t, new(MiddlewareTestSuite)) 503 }