github.com/System-Glitch/goyave/v2@v2.10.3-0.20200819142921-51011e75d504/middleware_test.go (about)

     1  package goyave
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"mime/multipart"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"runtime"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/System-Glitch/goyave/v2/config"
    19  	"github.com/System-Glitch/goyave/v2/cors"
    20  	"github.com/System-Glitch/goyave/v2/helper/filesystem"
    21  	"github.com/System-Glitch/goyave/v2/lang"
    22  	"github.com/System-Glitch/goyave/v2/validation"
    23  )
    24  
    25  type MiddlewareTestSuite struct {
    26  	TestSuite
    27  }
    28  
    29  func (suite *MiddlewareTestSuite) SetupSuite() {
    30  	lang.LoadDefault()
    31  }
    32  
    33  func addFileToRequest(writer *multipart.Writer, path, name, fileName string) {
    34  	file, err := os.Open(path)
    35  	if err != nil {
    36  		panic(err)
    37  	}
    38  	defer file.Close()
    39  	part, err := writer.CreateFormFile(name, fileName)
    40  	if err != nil {
    41  		panic(err)
    42  	}
    43  	_, err = io.Copy(part, file)
    44  	if err != nil {
    45  		panic(err)
    46  	}
    47  }
    48  
    49  func createTestFileRequest(route string, files ...string) *http.Request {
    50  	_, filename, _, _ := runtime.Caller(1)
    51  
    52  	body := &bytes.Buffer{}
    53  	writer := multipart.NewWriter(body)
    54  	for _, p := range files {
    55  		fp := path.Dir(filename) + "/" + p
    56  		addFileToRequest(writer, fp, "file", filepath.Base(fp))
    57  	}
    58  	field, err := writer.CreateFormField("field")
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	_, err = io.Copy(field, strings.NewReader("world"))
    63  	if err != nil {
    64  		panic(err)
    65  	}
    66  
    67  	err = writer.Close()
    68  	if err != nil {
    69  		panic(err)
    70  	}
    71  
    72  	req, err := http.NewRequest("POST", route, body)
    73  	if err != nil {
    74  		panic(err)
    75  	}
    76  	req.Header.Set("Content-Type", writer.FormDataContentType())
    77  	return req
    78  }
    79  
    80  func testMiddleware(middleware Middleware, rawRequest *http.Request, data map[string]interface{}, rules validation.RuleSet, corsOptions *cors.Options, handler func(*Response, *Request)) *http.Response {
    81  	request := &Request{
    82  		httpRequest: rawRequest,
    83  		corsOptions: corsOptions,
    84  		Data:        data,
    85  		Rules:       rules.AsRules(),
    86  		Lang:        "en-US",
    87  		Params:      map[string]string{},
    88  	}
    89  	response := newResponse(httptest.NewRecorder(), nil)
    90  	middleware(handler)(response, request)
    91  
    92  	return response.responseWriter.(*httptest.ResponseRecorder).Result()
    93  }
    94  
    95  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewarePanic() {
    96  	response := newResponse(httptest.NewRecorder(), nil)
    97  	err := fmt.Errorf("error message")
    98  	recoveryMiddleware(func(response *Response, r *Request) {
    99  		panic(err)
   100  	})(response, &Request{})
   101  	suite.Equal(err, response.GetError())
   102  	suite.Equal(500, response.status)
   103  }
   104  
   105  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNoPanic() {
   106  	response := newResponse(httptest.NewRecorder(), nil)
   107  	recoveryMiddleware(func(response *Response, r *Request) {
   108  		response.String(200, "message")
   109  	})(response, &Request{})
   110  
   111  	resp := response.responseWriter.(*httptest.ResponseRecorder).Result()
   112  	suite.Nil(response.GetError())
   113  	suite.Equal(200, response.status)
   114  	suite.Equal(200, resp.StatusCode)
   115  
   116  	body, err := ioutil.ReadAll(resp.Body)
   117  	resp.Body.Close()
   118  	suite.Nil(err)
   119  	suite.Equal("message", string(body))
   120  }
   121  
   122  func (suite *MiddlewareTestSuite) TestRecoveryMiddlewareNilPanic() {
   123  	response := newResponse(httptest.NewRecorder(), nil)
   124  	recoveryMiddleware(func(response *Response, r *Request) {
   125  		panic(nil)
   126  	})(response, &Request{})
   127  	suite.Nil(response.GetError())
   128  	suite.Equal(500, response.status)
   129  }
   130  
   131  func (suite *MiddlewareTestSuite) TestLanguageMiddleware() {
   132  	executed := false
   133  	rawRequest := httptest.NewRequest("GET", "/test-route", strings.NewReader("body"))
   134  	rawRequest.Header.Set("Accept-Language", "en-US")
   135  	testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   136  		suite.Equal("en-US", r.Lang)
   137  		executed = true
   138  	})
   139  	suite.True(executed)
   140  
   141  	rawRequest = httptest.NewRequest("GET", "/test-route", strings.NewReader("body"))
   142  	testMiddleware(languageMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   143  		suite.Equal("en-US", r.Lang)
   144  		executed = true
   145  	})
   146  	suite.True(executed)
   147  }
   148  
   149  func (suite *MiddlewareTestSuite) TestParsePostRequestMiddleware() {
   150  	executed := false
   151  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   152  	rawRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
   153  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   154  		suite.Equal("hello world", r.Data["string"])
   155  		suite.Equal("42", r.Data["number"])
   156  		executed = true
   157  	})
   158  	suite.True(executed)
   159  }
   160  
   161  func (suite *MiddlewareTestSuite) TestParseGetRequestMiddleware() {
   162  	executed := false
   163  	rawRequest := httptest.NewRequest("GET", "/test-route?string=hello%20world&number=42", nil)
   164  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   165  		suite.Equal("hello world", r.Data["string"])
   166  		suite.Equal("42", r.Data["number"])
   167  		executed = true
   168  	})
   169  	suite.True(executed)
   170  
   171  	executed = false
   172  	rawRequest = httptest.NewRequest("GET", "/test-route?%9", nil)
   173  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   174  		suite.Nil(r.Data)
   175  		executed = true
   176  	})
   177  	suite.True(executed)
   178  }
   179  
   180  func (suite *MiddlewareTestSuite) TestParseJsonRequestMiddleware() {
   181  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   182  	rawRequest.Header.Set("Content-Type", "application/json")
   183  	executed := false
   184  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   185  		suite.Equal("hello world", r.Data["string"])
   186  		suite.Equal(42.0, r.Data["number"])
   187  		slice, ok := r.Data["array"].([]interface{})
   188  		suite.True(ok)
   189  		suite.Equal(2, len(slice))
   190  		suite.Equal("val1", slice[0])
   191  		suite.Equal("val2", slice[1])
   192  		executed = true
   193  	})
   194  	suite.True(executed)
   195  
   196  	executed = false
   197  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]")) // Missing closing braces
   198  	rawRequest.Header.Set("Content-Type", "application/json")
   199  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   200  		suite.Nil(r.Data)
   201  		executed = true
   202  	})
   203  	suite.True(executed)
   204  
   205  	// Test with query parameters
   206  	executed = false
   207  	rawRequest = httptest.NewRequest("POST", "/test-route?query=param", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   208  	rawRequest.Header.Set("Content-Type", "application/json")
   209  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   210  		suite.NotNil(r.Data)
   211  		suite.Equal("param", r.Data["query"])
   212  		executed = true
   213  	})
   214  	suite.True(executed)
   215  
   216  	// Test with charset (#101)
   217  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("{\"string\":\"hello world\", \"number\":42, \"array\":[\"val1\",\"val2\"]}"))
   218  	rawRequest.Header.Set("Content-Type", "application/json; charset=utf-8")
   219  	executed = false
   220  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   221  		suite.Equal("hello world", r.Data["string"])
   222  		suite.Equal(42.0, r.Data["number"])
   223  		slice, ok := r.Data["array"].([]interface{})
   224  		suite.True(ok)
   225  		suite.Equal(2, len(slice))
   226  		suite.Equal("val1", slice[0])
   227  		suite.Equal("val2", slice[1])
   228  		executed = true
   229  	})
   230  	suite.True(executed)
   231  
   232  }
   233  
   234  func (suite *MiddlewareTestSuite) TestParseMultipartRequestMiddleware() {
   235  	executed := false
   236  	rawRequest := createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   237  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   238  		suite.Equal(3, len(r.Data))
   239  		suite.Equal("hello", r.Data["test"])
   240  		suite.Equal("world", r.Data["field"])
   241  		files, ok := r.Data["file"].([]filesystem.File)
   242  		suite.True(ok)
   243  		suite.Equal(1, len(files))
   244  		executed = true
   245  	})
   246  	suite.True(executed)
   247  
   248  	// Test payload too large
   249  	prev := config.Get("server.maxUploadSize")
   250  	config.Set("server.maxUploadSize", -10.0)
   251  	rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   252  
   253  	request := createTestRequest(rawRequest)
   254  	response := newResponse(httptest.NewRecorder(), nil)
   255  	parseRequestMiddleware(nil)(response, request)
   256  	suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus())
   257  	config.Set("server.maxUploadSize", prev)
   258  
   259  	prev = config.Get("server.maxUploadSize")
   260  	config.Set("server.maxUploadSize", 0.0006)
   261  	rawRequest = createTestFileRequest("/test-route?test=hello", "resources/img/logo/goyave_16.png")
   262  
   263  	request = createTestRequest(rawRequest)
   264  	response = newResponse(httptest.NewRecorder(), nil)
   265  	parseRequestMiddleware(nil)(response, request)
   266  	suite.Equal(http.StatusRequestEntityTooLarge, response.GetStatus())
   267  	config.Set("server.maxUploadSize", prev)
   268  }
   269  
   270  func (suite *MiddlewareTestSuite) TestParseMultipartOverrideMiddleware() {
   271  	executed := false
   272  	rawRequest := createTestFileRequest("/test-route?field=hello", "resources/img/logo/goyave_16.png")
   273  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   274  		suite.Equal(2, len(r.Data))
   275  		suite.Equal("world", r.Data["field"])
   276  		files, ok := r.Data["file"].([]filesystem.File)
   277  		suite.True(ok)
   278  		suite.Equal(1, len(files))
   279  		executed = true
   280  	})
   281  	suite.True(executed)
   282  }
   283  
   284  func (suite *MiddlewareTestSuite) TestParseMiddlewareWithArray() {
   285  	executed := false
   286  	rawRequest := httptest.NewRequest("GET", "/test-route?arr=hello&arr=world", nil)
   287  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   288  		arr, ok := r.Data["arr"].([]string)
   289  		suite.True(ok)
   290  		if ok {
   291  			suite.Equal(2, len(arr))
   292  			suite.Equal("hello", arr[0])
   293  			suite.Equal("world", arr[1])
   294  		}
   295  		executed = true
   296  	})
   297  	suite.True(executed)
   298  
   299  	body := &bytes.Buffer{}
   300  	writer := multipart.NewWriter(body)
   301  	field, err := writer.CreateFormField("field")
   302  	if err != nil {
   303  		panic(err)
   304  	}
   305  	_, err = io.Copy(field, strings.NewReader("hello"))
   306  	if err != nil {
   307  		panic(err)
   308  	}
   309  
   310  	field, err = writer.CreateFormField("field")
   311  	if err != nil {
   312  		panic(err)
   313  	}
   314  	_, err = io.Copy(field, strings.NewReader("world"))
   315  	if err != nil {
   316  		panic(err)
   317  	}
   318  
   319  	err = writer.Close()
   320  	if err != nil {
   321  		panic(err)
   322  	}
   323  
   324  	executed = false
   325  	rawRequest, err = http.NewRequest("POST", "/test-route", body)
   326  	if err != nil {
   327  		panic(err)
   328  	}
   329  	rawRequest.Header.Set("Content-Type", writer.FormDataContentType())
   330  	testMiddleware(parseRequestMiddleware, rawRequest, nil, validation.RuleSet{}, nil, func(response *Response, r *Request) {
   331  		suite.Equal(1, len(r.Data))
   332  		arr, ok := r.Data["field"].([]string)
   333  		suite.True(ok)
   334  		if ok {
   335  			suite.Equal(2, len(arr))
   336  			suite.Equal("hello", arr[0])
   337  			suite.Equal("world", arr[1])
   338  		}
   339  		executed = true
   340  	})
   341  	suite.True(executed)
   342  }
   343  
   344  func (suite *MiddlewareTestSuite) TestValidateMiddleware() {
   345  	rawRequest := httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   346  	rawRequest.Header.Set("Content-Type", "application/json")
   347  	data := map[string]interface{}{
   348  		"string": "hello world",
   349  		"number": 42,
   350  	}
   351  	rules := validation.RuleSet{
   352  		"string": {"required", "string"},
   353  		"number": {"required", "numeric", "min:10"},
   354  	}
   355  	result := testMiddleware(validateRequestMiddleware, rawRequest, data, rules, nil, func(response *Response, r *Request) {})
   356  	suite.Equal(200, result.StatusCode)
   357  
   358  	rawRequest = httptest.NewRequest("POST", "/test-route", strings.NewReader("string=hello%20world&number=42"))
   359  	rawRequest.Header.Set("Content-Type", "application/json")
   360  	data = map[string]interface{}{
   361  		"string": "hello world",
   362  		"number": 42,
   363  	}
   364  	rules = validation.RuleSet{
   365  		"string": {"required", "string"},
   366  		"number": {"required", "numeric", "min:50"},
   367  	}
   368  	result = testMiddleware(validateRequestMiddleware, rawRequest, data, rules, nil, func(response *Response, r *Request) {})
   369  	body, err := ioutil.ReadAll(result.Body)
   370  	if err != nil {
   371  		panic(err)
   372  	}
   373  	suite.Equal(422, result.StatusCode)
   374  	suite.Equal("{\"validationError\":{\"number\":[\"The number must be at least 50.\"]}}\n", string(body))
   375  
   376  	rawRequest = httptest.NewRequest("POST", "/test-route", nil)
   377  	rawRequest.Header.Set("Content-Type", "application/json")
   378  	result = testMiddleware(validateRequestMiddleware, rawRequest, nil, rules, nil, func(response *Response, r *Request) {})
   379  	body, err = ioutil.ReadAll(result.Body)
   380  	if err != nil {
   381  		panic(err)
   382  	}
   383  	suite.Equal(400, result.StatusCode)
   384  	suite.Equal("{\"validationError\":{\"error\":[\"Malformed JSON\"]}}\n", string(body))
   385  }
   386  
   387  func (suite *MiddlewareTestSuite) TestCORSMiddleware() {
   388  	// No CORS options
   389  	rawRequest := httptest.NewRequest("GET", "/test-route", nil)
   390  	result := testMiddleware(corsMiddleware, rawRequest, nil, nil, nil, func(response *Response, r *Request) {})
   391  	suite.Equal(200, result.StatusCode)
   392  
   393  	// Preflight
   394  	options := cors.Default()
   395  	rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil)
   396  	rawRequest.Header.Set("Origin", "https://google.com")
   397  	rawRequest.Header.Set("Access-Control-Request-Method", "GET")
   398  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   399  		response.String(200, "Hi!")
   400  	})
   401  	body, err := ioutil.ReadAll(result.Body)
   402  	if err != nil {
   403  		panic(err)
   404  	}
   405  	suite.Equal(204, result.StatusCode)
   406  	suite.Empty(body)
   407  
   408  	// Preflight passthrough
   409  	options = cors.Default()
   410  	options.OptionsPassthrough = true
   411  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   412  		response.String(200, "Passthrough")
   413  	})
   414  	body, err = ioutil.ReadAll(result.Body)
   415  	if err != nil {
   416  		panic(err)
   417  	}
   418  	suite.Equal(200, result.StatusCode)
   419  	suite.Equal("Passthrough", string(body))
   420  
   421  	// Preflight without Access-Control-Request-Method
   422  	rawRequest = httptest.NewRequest("OPTIONS", "/test-route", nil)
   423  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   424  		response.String(200, "Hi!")
   425  	})
   426  	body, err = ioutil.ReadAll(result.Body)
   427  	if err != nil {
   428  		panic(err)
   429  	}
   430  	suite.Equal(200, result.StatusCode)
   431  	suite.Equal("Hi!", string(body))
   432  
   433  	// Actual request
   434  	options = cors.Default()
   435  	options.AllowedOrigins = []string{"https://google.com", "https://images.google.com"}
   436  	rawRequest = httptest.NewRequest("GET", "/test-route", nil)
   437  	rawRequest.Header.Set("Origin", "https://images.google.com")
   438  	result = testMiddleware(corsMiddleware, rawRequest, nil, nil, options, func(response *Response, r *Request) {
   439  		response.String(200, "Hi!")
   440  	})
   441  	body, err = ioutil.ReadAll(result.Body)
   442  	if err != nil {
   443  		panic(err)
   444  	}
   445  	suite.Equal("Hi!", string(body))
   446  	suite.Equal("https://images.google.com", result.Header.Get("Access-Control-Allow-Origin"))
   447  	suite.Equal("Origin", result.Header.Get("Vary"))
   448  }
   449  
   450  func TestMiddlewareTestSuite(t *testing.T) {
   451  	RunTest(t, new(MiddlewareTestSuite))
   452  }