github.com/System-Glitch/goyave/v3@v3.6.1-0.20210226143142-ac2fe42ee80e/middleware_test.go (about)

     1  package goyave
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"mime/multipart"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"runtime"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/System-Glitch/goyave/v3/config"
    19  	"github.com/System-Glitch/goyave/v3/cors"
    20  	"github.com/System-Glitch/goyave/v3/helper/filesystem"
    21  	"github.com/System-Glitch/goyave/v3/lang"
    22  	"github.com/System-Glitch/goyave/v3/validation"
    23  )
    24  
    25  type MiddlewareTestSuite struct {
    26  	TestSuite
    27  }
    28  
    29  func (suite *MiddlewareTestSuite) SetupSuite() {
    30  	lang.LoadDefault()
    31  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
    32  }
    33  
    34  func addFileToRequest(writer *multipart.Writer, path, name, fileName string) {
    35  	file, err := os.Open(path)
    36  	if err != nil {
    37  		panic(err)
    38  	}
    39  	defer file.Close()
    40  	part, err := writer.CreateFormFile(name, fileName)
    41  	if err != nil {
    42  		panic(err)
    43  	}
    44  	_, err = io.Copy(part, file)
    45  	if err != nil {
    46  		panic(err)
    47  	}
    48  }
    49  
    50  func createTestFileRequest(route string, files ...string) *http.Request {
    51  	_, filename, _, _ := runtime.Caller(1)
    52  
    53  	body := &bytes.Buffer{}
    54  	writer := multipart.NewWriter(body)
    55  	for _, p := range files {
    56  		fp := path.Dir(filename) + "/" + p
    57  		addFileToRequest(writer, fp, "file", filepath.Base(fp))
    58  	}
    59  	field, err := writer.CreateFormField("field")
    60  	if err != nil {
    61  		panic(err)
    62  	}
    63  	_, err = io.Copy(field, strings.NewReader("world"))
    64  	if err != nil {
    65  		panic(err)
    66  	}
    67  
    68  	err = writer.Close()
    69  	if err != nil {
    70  		panic(err)
    71  	}
    72  
    73  	req, err := http.NewRequest("POST", route, body)
    74  	if err != nil {
    75  		panic(err)
    76  	}
    77  	req.Header.Set("Content-Type", writer.FormDataContentType())
    78  	return req
    79  }
    80  
    81  func testMiddleware(middleware Middleware, rawRequest *http.Request, data map[string]interface{}, rules validation.RuleSet, corsOptions *cors.Options, handler func(*Response, *Request)) *http.Response {
    82  	request := &Request{
    83  		httpRequest: rawRequest,
    84  		corsOptions: corsOptions,
    85  		Data:        data,
    86  		Rules:       rules.AsRules(),
    87  		Lang:        "en-US",
    88  		Params:      map[string]string{},
    89  	}
    90  	response := newResponse(httptest.NewRecorder(), nil)
    91  	middleware(handler)(response, request)
    92  
    93  	return response.responseWriter.(*httptest.ResponseRecorder).Result()
    94  }
    95  
    96  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewarePanic() {
    97  	response := newResponse(httptest.NewRecorder(), nil)
    98  	err := fmt.Errorf("error message")
    99  	recoveryMiddleware(func(response *Response, r *Request) {
   100  		panic(err)
   101  	})(response, &Request{})
   102  	suite.Equal(err, response.GetError())
   103  	suite.Empty(response.GetStacktrace())
   104  	suite.Equal(500, response.status)
   105  
   106  	prev := config.GetBool("app.debug")
   107  	config.Set("app.debug", true)
   108  	defer config.Set("app.debug", prev)
   109  
   110  	response = newResponse(httptest.NewRecorder(), nil)
   111  	err = fmt.Errorf("error message")
   112  	recoveryMiddleware(func(response *Response, r *Request) {
   113  		panic(err)
   114  	})(response, &Request{})
   115  	suite.Equal(err, response.GetError())
   116  	suite.NotEmpty(response.GetStacktrace())
   117  	suite.Equal(500, response.status)
   118  }
   119  
   120  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNoPanic() {
   121  	response := newResponse(httptest.NewRecorder(), nil)
   122  	recoveryMiddleware(func(response *Response, r *Request) {
   123  		response.String(200, "message")
   124  	})(response, &Request{})
   125  
   126  	resp := response.responseWriter.(*httptest.ResponseRecorder).Result()
   127  	suite.Nil(response.GetError())
   128  	suite.Equal(200, response.status)
   129  	suite.Equal(200, resp.StatusCode)
   130  
   131  	body, err := ioutil.ReadAll(resp.Body)
   132  	resp.Body.Close()
   133  	suite.Nil(err)
   134  	suite.Equal("message", string(body))
   135  }
   136  
   137  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNilPanic() {
   138  	response := newResponse(httptest.NewRecorder(), nil)
   139  	recoveryMiddleware(func(response *Response, r *Request) {
   140  		panic(nil)
   141  	})(response, &Request{})
   142  	suite.Nil(response.GetError())
   143  	suite.Equal(500, response.status)
   144  }
   145  
   146  func (suite *MiddlewareTestSuite) TestLanguageMiddleware() {
   147  	defaultLanguage = config.GetString("app.defaultLanguage")
   148  	executed := false
   149  	rawRequest := httptest.NewRequest("GET", "/test-route", strings.NewReader("body"))
   150  	rawRequest.Header.Set("Accept-Language", "en-US")
   151  	res := testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   152  		suite.Equal("en-US", r.Lang)
   153  		executed = true
   154  	})
   155  	res.Body.Close()
   156  	suite.True(executed)
   157  
   158  	rawRequest = httptest.NewRequest("GET", "/test-route", strings.NewReader("body"))
   159  	res = testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   160  		suite.Equal("en-US", r.Lang)
   161  		executed = true
   162  	})
   163  	res.Body.Close()
   164  
   165  	suite.True(executed)
   166  }
   167  
   168  func (suite *MiddlewareTestSuite) TestParsePostRequestMiddleware() {
   169  	executed := false
   170  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   171  	rawRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
   172  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   173  		suite.Equal("hello world", r.Data["string"])
   174  		suite.Equal("42", r.Data["number"])
   175  		executed = true
   176  	})
   177  	suite.True(executed)
   178  	res.Body.Close()
   179  }
   180  
   181  func (suite *MiddlewareTestSuite) TestParseGetRequestMiddleware() {
   182  	executed := false
   183  	rawRequest := httptest.NewRequest("GET", "/test-route?string=hello%20world&number=42", nil)
   184  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   185  		suite.Equal("hello world", r.Data["string"])
   186  		suite.Equal("42", r.Data["number"])
   187  		executed = true
   188  	})
   189  	suite.True(executed)
   190  	res.Body.Close()
   191  
   192  	executed = false
   193  	rawRequest = httptest.NewRequest("GET", "/test-route?%9", nil)
   194  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   195  		suite.Nil(r.Data)
   196  		executed = true
   197  	})
   198  	suite.True(executed)
   199  	res.Body.Close()
   200  }
   201  
   202  func (suite *MiddlewareTestSuite) TestParseJsonRequestMiddleware() {
   203  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   204  	rawRequest.Header.Set("Content-Type", "application/json")
   205  	executed := false
   206  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   207  		suite.Equal("hello world", r.Data["string"])
   208  		suite.Equal(42.0, r.Data["number"])
   209  		slice, ok := r.Data["array"].([]interface{})
   210  		suite.True(ok)
   211  		suite.Equal(2, len(slice))
   212  		suite.Equal("val1", slice[0])
   213  		suite.Equal("val2", slice[1])
   214  		executed = true
   215  	})
   216  	suite.True(executed)
   217  	res.Body.Close()
   218  
   219  	executed = false
   220  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]")) // Missing closing braces
   221  	rawRequest.Header.Set("Content-Type", "application/json")
   222  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   223  		suite.Nil(r.Data)
   224  		executed = true
   225  	})
   226  	suite.True(executed)
   227  	res.Body.Close()
   228  
   229  	// Test with query parameters
   230  	executed = false
   231  	rawRequest = httptest.NewRequest("POST", "/test-route?query=param", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   232  	rawRequest.Header.Set("Content-Type", "application/json")
   233  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   234  		suite.NotNil(r.Data)
   235  		suite.Equal("param", r.Data["query"])
   236  		executed = true
   237  	})
   238  	suite.True(executed)
   239  	res.Body.Close()
   240  
   241  	// Test with charset (#101)
   242  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   243  	rawRequest.Header.Set("Content-Type", "application/json; charset=utf-8")
   244  	executed = false
   245  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   246  		suite.Equal("hello world", r.Data["string"])
   247  		suite.Equal(42.0, r.Data["number"])
   248  		slice, ok := r.Data["array"].([]interface{})
   249  		suite.True(ok)
   250  		suite.Equal(2, len(slice))
   251  		suite.Equal("val1", slice[0])
   252  		suite.Equal("val2", slice[1])
   253  		executed = true
   254  	})
   255  	res.Body.Close()
   256  	suite.True(executed)
   257  
   258  }
   259  
   260  func (suite *MiddlewareTestSuite) TestParseMultipartRequestMiddleware() {
   261  	executed := false
   262  	rawRequest := createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   263  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   264  		suite.Equal(3, len(r.Data))
   265  		suite.Equal("hello", r.Data["test"])
   266  		suite.Equal("world", r.Data["field"])
   267  		files, ok := r.Data["file"].([]filesystem.File)
   268  		suite.True(ok)
   269  		suite.Equal(1, len(files))
   270  		executed = true
   271  	})
   272  	suite.True(executed)
   273  	res.Body.Close()
   274  
   275  	// Test payload too large
   276  	prev := config.Get("server.maxUploadSize")
   277  	config.Set("server.maxUploadSize", -10.0)
   278  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   279  	rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   280  
   281  	request := createTestRequest(rawRequest)
   282  	response := newResponse(httptest.NewRecorder(), nil)
   283  	parseRequestMiddleware(nil)(response, request)
   284  	suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus())
   285  	config.Set("server.maxUploadSize", prev)
   286  
   287  	prev = config.Get("server.maxUploadSize")
   288  	config.Set("server.maxUploadSize", 0.0006)
   289  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   290  	rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   291  
   292  	request = createTestRequest(rawRequest)
   293  	response = newResponse(httptest.NewRecorder(), nil)
   294  	parseRequestMiddleware(nil)(response, request)
   295  	suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus())
   296  	config.Set("server.maxUploadSize", prev)
   297  	maxPayloadSize = int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
   298  }
   299  
   300  func (suite *MiddlewareTestSuite) TestParseMultipartOverrideMiddleware() {
   301  	executed := false
   302  	rawRequest := createTestFileRequest("/test-route?field=hello", "resources/img/logo/goyave_16.png")
   303  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   304  		suite.Equal(2, len(r.Data))
   305  		suite.Equal("world", r.Data["field"])
   306  		files, ok := r.Data["file"].([]filesystem.File)
   307  		suite.True(ok)
   308  		suite.Equal(1, len(files))
   309  		executed = true
   310  	})
   311  	suite.True(executed)
   312  	res.Body.Close()
   313  }
   314  
   315  func (suite *MiddlewareTestSuite) TestParseMiddlewareWithArray() {
   316  	executed := false
   317  	rawRequest := httptest.NewRequest("GET", "/test-route?arr=hello&arr=world", nil)
   318  	res := testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   319  		arr, ok := r.Data["arr"].([]string)
   320  		suite.True(ok)
   321  		if ok {
   322  			suite.Equal(2, len(arr))
   323  			suite.Equal("hello", arr[0])
   324  			suite.Equal("world", arr[1])
   325  		}
   326  		executed = true
   327  	})
   328  	suite.True(executed)
   329  	res.Body.Close()
   330  
   331  	body := &bytes.Buffer{}
   332  	writer := multipart.NewWriter(body)
   333  	field, err := writer.CreateFormField("field")
   334  	if err != nil {
   335  		panic(err)
   336  	}
   337  	_, err = io.Copy(field, strings.NewReader("hello"))
   338  	if err != nil {
   339  		panic(err)
   340  	}
   341  
   342  	field, err = writer.CreateFormField("field")
   343  	if err != nil {
   344  		panic(err)
   345  	}
   346  	_, err = io.Copy(field, strings.NewReader("world"))
   347  	if err != nil {
   348  		panic(err)
   349  	}
   350  
   351  	err = writer.Close()
   352  	if err != nil {
   353  		panic(err)
   354  	}
   355  
   356  	executed = false
   357  	rawRequest, err = http.NewRequest("POST", "/test-route", body)
   358  	if err != nil {
   359  		panic(err)
   360  	}
   361  	rawRequest.Header.Set("Content-Type", writer.FormDataContentType())
   362  	res = testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   363  		suite.Equal(1, len(r.Data))
   364  		arr, ok := r.Data["field"].([]string)
   365  		suite.True(ok)
   366  		if ok {
   367  			suite.Equal(2, len(arr))
   368  			suite.Equal("hello", arr[0])
   369  			suite.Equal("world", arr[1])
   370  		}
   371  		executed = true
   372  	})
   373  	suite.True(executed)
   374  	res.Body.Close()
   375  }
   376  
   377  func (suite *MiddlewareTestSuite) TestValidateMiddleware() {
   378  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   379  	rawRequest.Header.Set("Content-Type", "application/json")
   380  	data := map[string]interface{}{
   381  		"string": "hello world",
   382  		"number": 42,
   383  	}
   384  	rules := validation.RuleSet{
   385  		"string": {"required", "string"},
   386  		"number": {"required", "numeric", "min:10"},
   387  	}
   388  	request := suite.CreateTestRequest(rawRequest)
   389  	request.Data = data
   390  	request.Rules = rules.AsRules()
   391  	result := suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {})
   392  	suite.Equal(http.StatusNoContent, result.StatusCode)
   393  	result.Body.Close()
   394  
   395  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   396  	rawRequest.Header.Set("Content-Type", "application/json")
   397  	data = map[string]interface{}{
   398  		"string": "hello world",
   399  		"number": 42,
   400  	}
   401  	rules = validation.RuleSet{
   402  		"string": {"required", "string"},
   403  		"number": {"required", "numeric", "min:50"},
   404  	}
   405  
   406  	request = suite.CreateTestRequest(rawRequest)
   407  	request.Data = data
   408  	request.Rules = rules.AsRules()
   409  	result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {})
   410  	body, err := ioutil.ReadAll(result.Body)
   411  	if err != nil {
   412  		panic(err)
   413  	}
   414  	result.Body.Close()
   415  	suite.Equal(http.StatusUnprocessableEntity, result.StatusCode)
   416  	suite.Equal("{\"validationError\":{\"number\":[\"The number must be at least 50.\"]}}\n", string(body))
   417  
   418  	rawRequest = httptest.NewRequest("POST", "/test-route", nil)
   419  	rawRequest.Header.Set("Content-Type", "application/json")
   420  	request = suite.CreateTestRequest(rawRequest)
   421  	request.Data = nil
   422  	request.Rules = rules.AsRules()
   423  	result = suite.Middleware(validateRequestMiddleware, request, func(response *Response, r *Request) {})
   424  	body, err = ioutil.ReadAll(result.Body)
   425  	if err != nil {
   426  		panic(err)
   427  	}
   428  	result.Body.Close()
   429  	suite.Equal(http.StatusBadRequest, result.StatusCode)
   430  	suite.Equal("{\"validationError\":{\"error\":[\"Malformed JSON\"]}}\n", string(body))
   431  }
   432  
   433  func (suite *MiddlewareTestSuite) TestCORSMiddleware() {
   434  	// No CORS options
   435  	rawRequest := httptest.NewRequest("GET", "/test-route", nil)
   436  	result := testMiddleware(corsMiddleware, rawRequest, nil, nil, nil, func(response *Response, r *Request) {})
   437  	suite.Equal(200, result.StatusCode)
   438  	result.Body.Close()
   439  
   440  	// Preflight
   441  	options := cors.Default()
   442  	rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil)
   443  	rawRequest.Header.Set("Origin", "https://google.com")
   444  	rawRequest.Header.Set("Access-Control-Request-Method", "GET")
   445  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   446  		response.String(200, "Hi!")
   447  	})
   448  	body, err := ioutil.ReadAll(result.Body)
   449  	if err != nil {
   450  		panic(err)
   451  	}
   452  	result.Body.Close()
   453  	suite.Equal(204, result.StatusCode)
   454  	suite.Empty(body)
   455  
   456  	// Preflight passthrough
   457  	options = cors.Default()
   458  	options.OptionsPassthrough = true
   459  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   460  		response.String(200, "Passthrough")
   461  	})
   462  	body, err = ioutil.ReadAll(result.Body)
   463  	if err != nil {
   464  		panic(err)
   465  	}
   466  	result.Body.Close()
   467  	suite.Equal(200, result.StatusCode)
   468  	suite.Equal("Passthrough", string(body))
   469  
   470  	// Preflight without Access-Control-Request-Method
   471  	rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil)
   472  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   473  		response.String(200, "Hi!")
   474  	})
   475  	body, err = ioutil.ReadAll(result.Body)
   476  	if err != nil {
   477  		panic(err)
   478  	}
   479  	result.Body.Close()
   480  	suite.Equal(200, result.StatusCode)
   481  	suite.Equal("Hi!", string(body))
   482  
   483  	// Actual request
   484  	options = cors.Default()
   485  	options.AllowedOrigins = []string{"https://google.com", "https://images.google.com"}
   486  	rawRequest = httptest.NewRequest("GET", "/test-route", nil)
   487  	rawRequest.Header.Set("Origin", "https://images.google.com")
   488  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   489  		response.String(200, "Hi!")
   490  	})
   491  	body, err = ioutil.ReadAll(result.Body)
   492  	if err != nil {
   493  		panic(err)
   494  	}
   495  	result.Body.Close()
   496  	suite.Equal("Hi!", string(body))
   497  	suite.Equal("https://images.google.com", result.Header.Get("Access-Control-Allow-Origin"))
   498  	suite.Equal("Origin", result.Header.Get("Vary"))
   499  }
   500  
   501  func TestMiddlewareTestSuite(t *testing.T) {
   502  	RunTest(t, new(MiddlewareTestSuite))
   503  }