goyave.dev/goyave/v4@v4.4.11/middleware_test.go (about)

     1  package goyave
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"mime/multipart"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"path"
    12  	"path/filepath"
    13  	"runtime"
    14  	"strings"
    15  	"testing"
    16  
    17  	"goyave.dev/goyave/v4/config"
    18  	"goyave.dev/goyave/v4/cors"
    19  	"goyave.dev/goyave/v4/lang"
    20  	"goyave.dev/goyave/v4/util/fsutil"
    21  	"goyave.dev/goyave/v4/validation"
    22  )
    23  
    24  type MiddlewareTestSuite struct {
    25  	TestSuite
    26  }
    27  
    28  func (suite *MiddlewareTestSuite) SetupSuite() {
    29  	lang.LoadDefault()
    30  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
    31  }
    32  
    33  func addFileToRequest(writer *multipart.Writer, path, name, fileName string) {
    34  	file, err := os.Open(path)
    35  	if err != nil {
    36  		panic(err)
    37  	}
    38  	defer file.Close()
    39  	part, err := writer.CreateFormFile(name, fileName)
    40  	if err != nil {
    41  		panic(err)
    42  	}
    43  	_, err = io.Copy(part, file)
    44  	if err != nil {
    45  		panic(err)
    46  	}
    47  }
    48  
    49  func createTestFileRequest(route string, files ...string) *http.Request {
    50  	_, filename, _, _ := runtime.Caller(1)
    51  
    52  	body := &bytes.Buffer{}
    53  	writer := multipart.NewWriter(body)
    54  	for _, p := range files {
    55  		fp := path.Dir(filename) + "/" + p
    56  		addFileToRequest(writer, fp, "file", filepath.Base(fp))
    57  	}
    58  	field, err := writer.CreateFormField("field")
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	_, err = io.Copy(field, strings.NewReader("world"))
    63  	if err != nil {
    64  		panic(err)
    65  	}
    66  
    67  	err = writer.Close()
    68  	if err != nil {
    69  		panic(err)
    70  	}
    71  
    72  	req, err := http.NewRequest("POST", route, body)
    73  	if err != nil {
    74  		panic(err)
    75  	}
    76  	req.Header.Set("Content-Type", writer.FormDataContentType())
    77  	return req
    78  }
    79  
    80  func testMiddleware(middleware Middleware, rawRequest *http.Request, data map[string]interface{}, rules validation.RuleSet, corsOptions *cors.Options, handler func(*Response, *Request)) *http.Response {
    81  	request := &Request{
    82  		httpRequest: rawRequest,
    83  		corsOptions: corsOptions,
    84  		Data:        data,
    85  		Rules:       rules.AsRules(),
    86  		Lang:        "en-US",
    87  		Params:      map[string]string{},
    88  	}
    89  	response := newResponse(httptest.NewRecorder(), nil)
    90  	middleware(handler)(response, request)
    91  
    92  	return response.responseWriter.(*httptest.ResponseRecorder).Result()
    93  }
    94  
    95  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewarePanic() {
    96  	response := newResponse(httptest.NewRecorder(), nil)
    97  	err := fmt.Errorf("error message")
    98  	recoveryMiddleware(func(response *Response, r *Request) {
    99  		panic(err)
   100  	})(response, &Request{})
   101  	suite.Equal(err, response.GetError())
   102  	suite.NotEmpty(response.GetStacktrace())
   103  	suite.Equal(500, response.status)
   104  }
   105  
   106  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNoPanic() {
   107  	response := newResponse(httptest.NewRecorder(), nil)
   108  	recoveryMiddleware(func(response *Response, r *Request) {
   109  		response.String(200, "message")
   110  	})(response, &Request{})
   111  
   112  	resp := response.responseWriter.(*httptest.ResponseRecorder).Result()
   113  	suite.Nil(response.GetError())
   114  	suite.Empty(response.GetStacktrace())
   115  	suite.Equal(200, response.status)
   116  	suite.Equal(200, resp.StatusCode)
   117  
   118  	body, err := io.ReadAll(resp.Body)
   119  	resp.Body.Close()
   120  	suite.Nil(err)
   121  	suite.Equal("message", string(body))
   122  }
   123  
   124  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNilPanic() {
   125  	response := newResponse(httptest.NewRecorder(), nil)
   126  	recoveryMiddleware(func(response *Response, r *Request) {
   127  		panic(nil)
   128  	})(response, &Request{})
   129  	suite.Nil(response.GetError())
   130  	suite.NotEmpty(response.GetStacktrace())
   131  	suite.Equal(500, response.status)
   132  }
   133  
   134  func (suite *MiddlewareTestSuite) TestLanguageMiddleware() {
   135  	defaultLanguage = config.GetString("app.defaultLanguage")
   136  	executed := false
   137  	rawRequest := httptest.NewRequest("GET", "/test-route", strings.NewReader("body"))
   138  	rawRequest.Header.Set("Accept-Language", "en-US")
   139  	res := testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   140  		suite.Equal("en-US", r.Lang)
   141  		executed = true
   142  	})
   143  	res.Body.Close()
   144  	suite.True(executed)
   145  
   146  	rawRequest = httptest.NewRequest("GET", "/test-route", strings.NewReader("body"))
   147  	res = testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   148  		suite.Equal("en-US", r.Lang)
   149  		executed = true
   150  	})
   151  	res.Body.Close()
   152  
   153  	suite.True(executed)
   154  }
   155  
   156  func (suite *MiddlewareTestSuite) TestParsePostRequestMiddleware() {
   157  	executed := false
   158  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   159  	rawRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
   160  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   161  		suite.Equal("hello world", r.Data["string"])
   162  		suite.Equal("42", r.Data["number"])
   163  		executed = true
   164  	})
   165  	suite.True(executed)
   166  	res.Body.Close()
   167  }
   168  
   169  func (suite *MiddlewareTestSuite) TestParseGetRequestMiddleware() {
   170  	executed := false
   171  	rawRequest := httptest.NewRequest("GET", "/test-route?string=hello%20world&number=42", nil)
   172  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   173  		suite.Equal("hello world", r.Data["string"])
   174  		suite.Equal("42", r.Data["number"])
   175  		executed = true
   176  	})
   177  	suite.True(executed)
   178  	res.Body.Close()
   179  
   180  	executed = false
   181  	rawRequest = httptest.NewRequest("GET", "/test-route?%9", nil)
   182  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   183  		suite.Nil(r.Data)
   184  		executed = true
   185  	})
   186  	suite.True(executed)
   187  	res.Body.Close()
   188  }
   189  
   190  func (suite *MiddlewareTestSuite) TestParseJsonRequestMiddleware() {
   191  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   192  	rawRequest.Header.Set("Content-Type", "application/json")
   193  	executed := false
   194  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   195  		suite.Equal("hello world", r.Data["string"])
   196  		suite.Equal(42.0, r.Data["number"])
   197  		slice, ok := r.Data["array"].([]interface{})
   198  		suite.True(ok)
   199  		suite.Equal(2, len(slice))
   200  		suite.Equal("val1", slice[0])
   201  		suite.Equal("val2", slice[1])
   202  		executed = true
   203  	})
   204  	suite.True(executed)
   205  	res.Body.Close()
   206  
   207  	executed = false
   208  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]")) // Missing closing braces
   209  	rawRequest.Header.Set("Content-Type", "application/json")
   210  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   211  		suite.Nil(r.Data)
   212  		executed = true
   213  	})
   214  	suite.True(executed)
   215  	res.Body.Close()
   216  
   217  	// Test with query parameters
   218  	executed = false
   219  	rawRequest = httptest.NewRequest("POST", "/test-route?query=param", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   220  	rawRequest.Header.Set("Content-Type", "application/json")
   221  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   222  		suite.NotNil(r.Data)
   223  		suite.Equal("param", r.Data["query"])
   224  		executed = true
   225  	})
   226  	suite.True(executed)
   227  	res.Body.Close()
   228  
   229  	executed = false
   230  	rawRequest = httptest.NewRequest("POST", "/test-route?%9", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}")) // Invalid query param
   231  	rawRequest.Header.Set("Content-Type", "application/json")
   232  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   233  		suite.Nil(r.Data)
   234  		executed = true
   235  	})
   236  	suite.True(executed)
   237  	res.Body.Close()
   238  
   239  	// Test with charset (#101)
   240  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   241  	rawRequest.Header.Set("Content-Type", "application/json; charset=utf-8")
   242  	executed = false
   243  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   244  		suite.Equal("hello world", r.Data["string"])
   245  		suite.Equal(42.0, r.Data["number"])
   246  		slice, ok := r.Data["array"].([]interface{})
   247  		suite.True(ok)
   248  		suite.Equal(2, len(slice))
   249  		suite.Equal("val1", slice[0])
   250  		suite.Equal("val2", slice[1])
   251  		executed = true
   252  	})
   253  	res.Body.Close()
   254  	suite.True(executed)
   255  
   256  }
   257  
   258  func (suite *MiddlewareTestSuite) TestParseMultipartRequestMiddleware() {
   259  	executed := false
   260  	rawRequest := createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   261  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   262  		suite.Equal(3, len(r.Data))
   263  		suite.Equal("hello", r.Data["test"])
   264  		suite.Equal("world", r.Data["field"])
   265  		files, ok := r.Data["file"].([]fsutil.File)
   266  		suite.True(ok)
   267  		suite.Equal(1, len(files))
   268  		executed = true
   269  	})
   270  	suite.True(executed)
   271  	res.Body.Close()
   272  
   273  	// Test payload too large
   274  	prev := config.Get("server.maxUploadSize")
   275  	config.Set("server.maxUploadSize", -10.0)
   276  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   277  	rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   278  
   279  	request := createTestRequest(rawRequest)
   280  	response := newResponse(httptest.NewRecorder(), nil)
   281  	parseRequestMiddleware(nil)(response, request)
   282  	suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus())
   283  	config.Set("server.maxUploadSize", prev)
   284  
   285  	prev = config.Get("server.maxUploadSize")
   286  	config.Set("server.maxUploadSize", 0.0006)
   287  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   288  	rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   289  
   290  	request = createTestRequest(rawRequest)
   291  	response = newResponse(httptest.NewRecorder(), nil)
   292  	parseRequestMiddleware(nil)(response, request)
   293  	suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus())
   294  	config.Set("server.maxUploadSize", prev)
   295  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   296  }
   297  
   298  func (suite *MiddlewareTestSuite) TestParseMultipartOverrideMiddleware() {
   299  	executed := false
   300  	rawRequest := createTestFileRequest("/test-route?field=hello", "resources/img/logo/goyave_16.png")
   301  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   302  		suite.Equal(2, len(r.Data))
   303  		suite.Equal("world", r.Data["field"])
   304  		files, ok := r.Data["file"].([]fsutil.File)
   305  		suite.True(ok)
   306  		suite.Equal(1, len(files))
   307  		executed = true
   308  	})
   309  	suite.True(executed)
   310  	res.Body.Close()
   311  }
   312  
   313  func (suite *MiddlewareTestSuite) TestParseMiddlewareWithArray() {
   314  	executed := false
   315  	rawRequest := httptest.NewRequest("GET", "/test-route?arr=hello&arr=world", nil)
   316  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   317  		arr, ok := r.Data["arr"].([]string)
   318  		suite.True(ok)
   319  		if ok {
   320  			suite.Equal(2, len(arr))
   321  			suite.Equal("hello", arr[0])
   322  			suite.Equal("world", arr[1])
   323  		}
   324  		executed = true
   325  	})
   326  	suite.True(executed)
   327  	res.Body.Close()
   328  
   329  	body := &bytes.Buffer{}
   330  	writer := multipart.NewWriter(body)
   331  	field, err := writer.CreateFormField("field")
   332  	if err != nil {
   333  		panic(err)
   334  	}
   335  	_, err = io.Copy(field, strings.NewReader("hello"))
   336  	if err != nil {
   337  		panic(err)
   338  	}
   339  
   340  	field, err = writer.CreateFormField("field")
   341  	if err != nil {
   342  		panic(err)
   343  	}
   344  	_, err = io.Copy(field, strings.NewReader("world"))
   345  	if err != nil {
   346  		panic(err)
   347  	}
   348  
   349  	err = writer.Close()
   350  	if err != nil {
   351  		panic(err)
   352  	}
   353  
   354  	executed = false
   355  	rawRequest, err = http.NewRequest("POST", "/test-route", body)
   356  	if err != nil {
   357  		panic(err)
   358  	}
   359  	rawRequest.Header.Set("Content-Type", writer.FormDataContentType())
   360  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   361  		suite.Equal(1, len(r.Data))
   362  		arr, ok := r.Data["field"].([]string)
   363  		suite.True(ok)
   364  		if ok {
   365  			suite.Equal(2, len(arr))
   366  			suite.Equal("hello", arr[0])
   367  			suite.Equal("world", arr[1])
   368  		}
   369  		executed = true
   370  	})
   371  	suite.True(executed)
   372  	res.Body.Close()
   373  }
   374  
   375  func (suite *MiddlewareTestSuite) TestValidateMiddleware() {
   376  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   377  	rawRequest.Header.Set("Content-Type", "application/json")
   378  	data := map[string]interface{}{
   379  		"string": "hello world",
   380  		"number": 42,
   381  	}
   382  	rules := validation.RuleSet{
   383  		"string": validation.List{"required", "string"},
   384  		"number": validation.List{"required", "numeric", "min:10"},
   385  	}
   386  	request := suite.CreateTestRequest(rawRequest)
   387  	request.Data = data
   388  	request.Rules = rules.AsRules()
   389  	result := suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {})
   390  	suite.Equal(http.StatusNoContent, result.StatusCode)
   391  	result.Body.Close()
   392  
   393  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   394  	rawRequest.Header.Set("Content-Type", "application/json")
   395  	data = map[string]interface{}{
   396  		"string": "hello world",
   397  		"number": 42,
   398  	}
   399  	rules = validation.RuleSet{
   400  		"string": validation.List{"required", "string"},
   401  		"number": validation.List{"required", "numeric", "min:50"},
   402  	}
   403  
   404  	request = suite.CreateTestRequest(rawRequest)
   405  	request.Data = data
   406  	request.Rules = rules.AsRules()
   407  	result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {})
   408  	body, err := io.ReadAll(result.Body)
   409  	if err != nil {
   410  		panic(err)
   411  	}
   412  	result.Body.Close()
   413  	suite.Equal(http.StatusUnprocessableEntity, result.StatusCode)
   414  	suite.Equal("{\"validationError\":{\"number\":{\"errors\":[\"The number must be at least 50.\"]}}}\n", string(body))
   415  
   416  	rawRequest = httptest.NewRequest("POST", "/test-route", nil)
   417  	rawRequest.Header.Set("Content-Type", "application/json")
   418  	request = suite.CreateTestRequest(rawRequest)
   419  	request.Data = nil
   420  	request.Rules = rules.AsRules()
   421  	result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {})
   422  	body, err = io.ReadAll(result.Body)
   423  	if err != nil {
   424  		panic(err)
   425  	}
   426  	result.Body.Close()
   427  	suite.Equal(http.StatusBadRequest, result.StatusCode)
   428  	suite.Equal("{\"validationError\":{\"[data]\":{\"errors\":[\"Malformed JSON\"]}}}\n", string(body))
   429  }
   430  
   431  func (suite *MiddlewareTestSuite) TestCORSMiddleware() {
   432  	// No CORS options
   433  	rawRequest := httptest.NewRequest("GET", "/test-route", nil)
   434  	result := testMiddleware(corsMiddleware, rawRequest, nil, nil, nil, func(response *Response, r *Request) {})
   435  	suite.Equal(200, result.StatusCode)
   436  	result.Body.Close()
   437  
   438  	// Preflight
   439  	options := cors.Default()
   440  	rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil)
   441  	rawRequest.Header.Set("Origin", "https://google.com")
   442  	rawRequest.Header.Set("Access-Control-Request-Method", "GET")
   443  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   444  		response.String(200, "Hi!")
   445  	})
   446  	body, err := io.ReadAll(result.Body)
   447  	if err != nil {
   448  		panic(err)
   449  	}
   450  	result.Body.Close()
   451  	suite.Equal(204, result.StatusCode)
   452  	suite.Empty(body)
   453  
   454  	// Preflight passthrough
   455  	options = cors.Default()
   456  	options.OptionsPassthrough = true
   457  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   458  		response.String(200, "Passthrough")
   459  	})
   460  	body, err = io.ReadAll(result.Body)
   461  	if err != nil {
   462  		panic(err)
   463  	}
   464  	result.Body.Close()
   465  	suite.Equal(200, result.StatusCode)
   466  	suite.Equal("Passthrough", string(body))
   467  
   468  	// Preflight without Access-Control-Request-Method
   469  	rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil)
   470  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   471  		response.String(200, "Hi!")
   472  	})
   473  	body, err = io.ReadAll(result.Body)
   474  	if err != nil {
   475  		panic(err)
   476  	}
   477  	result.Body.Close()
   478  	suite.Equal(200, result.StatusCode)
   479  	suite.Equal("Hi!", string(body))
   480  
   481  	// Actual request
   482  	options = cors.Default()
   483  	options.AllowedOrigins = []string{"https://google.com", "https://images.google.com"}
   484  	rawRequest = httptest.NewRequest("GET", "/test-route", nil)
   485  	rawRequest.Header.Set("Origin", "https://images.google.com")
   486  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   487  		response.String(200, "Hi!")
   488  	})
   489  	body, err = io.ReadAll(result.Body)
   490  	if err != nil {
   491  		panic(err)
   492  	}
   493  	result.Body.Close()
   494  	suite.Equal("Hi!", string(body))
   495  	suite.Equal("https://images.google.com", result.Header.Get("Access-Control-Allow-Origin"))
   496  	suite.Equal("Origin", result.Header.Get("Vary"))
   497  }
   498  
   499  func TestMiddlewareTestSuite(t *testing.T) {
   500  	RunTest(t, new(MiddlewareTestSuite))
   501  }