github.com/System-Glitch/goyave/v2@v2.10.3-0.20200819142921-51011e75d504/middleware.go (about)

     1  package goyave
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/url"
    10  	"runtime/debug"
    11  	"strings"
    12  
    13  	"github.com/System-Glitch/goyave/v2/config"
    14  	"github.com/System-Glitch/goyave/v2/helper/filesystem"
    15  	"github.com/System-Glitch/goyave/v2/lang"
    16  )
    17  
    18  // Middleware function generating middleware handler function.
    19  //
    20  // Request data is available to middleware, but bear in mind that
    21  // it had not been validated yet. That means that you can modify or
    22  // filter data. (Trim strings for example)
    23  type Middleware func(Handler) Handler
    24  
    25  // recoveryMiddleware is a middleware that recovers from panic and sends a 500 error code.
    26  // If debugging is enabled in the config and the default status handler for the 500 status code
    27  // had not been changed, the error is also written in the response.
    28  func recoveryMiddleware(next Handler) Handler {
    29  	return func(response *Response, r *Request) {
    30  		panicked := true
    31  		defer func() {
    32  			if err := recover(); err != nil || panicked {
    33  				ErrLogger.Println(err)
    34  				response.err = err
    35  				if config.GetBool("app.debug") {
    36  					response.stacktrace = string(debug.Stack())
    37  				}
    38  				response.Status(http.StatusInternalServerError)
    39  			}
    40  		}()
    41  
    42  		next(response, r)
    43  		panicked = false
    44  	}
    45  }
    46  
    47  // parseRequestMiddleware is a middleware that parses the request data.
    48  //
    49  // If the parsing fails, the request's data is set to nil. If it succeeds
    50  // and there is no data, the request's data is set to an empty map.
    51  //
    52  // If the "Content-Type: application/json" header is set, the middleware
    53  // will attempt to unmarshal the request's body.
    54  //
    55  // This middleware doesn't drain the request body to maximize compatibility
    56  // with native handlers.
    57  //
    58  // The maximum length of the data is limited by the "maxUploadSize" config entry.
    59  // If a request exceeds the maximum size, the middleware doesn't call "next()" and
    60  // sets the response status code to "413 Payload Too Large".
    61  func parseRequestMiddleware(next Handler) Handler {
    62  	return func(response *Response, request *Request) {
    63  
    64  		request.Data = nil
    65  		maxSize := int64(config.GetFloat("server.maxUploadSize") * 1024 * 1024)
    66  		maxValueBytes := maxSize
    67  		var bodyBuf bytes.Buffer
    68  		n, err := io.CopyN(&bodyBuf, request.httpRequest.Body, maxValueBytes+1)
    69  		request.httpRequest.Body.Close()
    70  		if err == nil || err == io.EOF {
    71  			maxValueBytes -= n
    72  			if maxValueBytes < 0 {
    73  				response.Status(http.StatusRequestEntityTooLarge)
    74  				return
    75  			}
    76  
    77  			bodyBytes := bodyBuf.Bytes()
    78  			contentType := request.httpRequest.Header.Get("Content-Type")
    79  			if strings.HasPrefix(contentType, "application/json") {
    80  				request.Data = make(map[string]interface{}, 10)
    81  				parseQuery(request)
    82  				if err := json.Unmarshal(bodyBytes, &request.Data); err != nil {
    83  					request.Data = nil
    84  				}
    85  				resetRequestBody(request, bodyBytes)
    86  			} else {
    87  				resetRequestBody(request, bodyBytes)
    88  				request.Data = generateFlatMap(request.httpRequest, maxSize)
    89  				resetRequestBody(request, bodyBytes)
    90  			}
    91  		}
    92  
    93  		next(response, request)
    94  	}
    95  }
    96  
    97  func generateFlatMap(request *http.Request, maxSize int64) map[string]interface{} {
    98  	var flatMap map[string]interface{} = make(map[string]interface{})
    99  	err := request.ParseMultipartForm(maxSize)
   100  
   101  	if err != nil {
   102  		if err == http.ErrNotMultipart {
   103  			if err := request.ParseForm(); err != nil {
   104  				return nil
   105  			}
   106  		} else {
   107  			return nil
   108  		}
   109  	}
   110  
   111  	if request.Form != nil {
   112  		flatten(flatMap, request.Form)
   113  	}
   114  	if request.MultipartForm != nil {
   115  		flatten(flatMap, request.MultipartForm.Value)
   116  
   117  		for field := range request.MultipartForm.File {
   118  			flatMap[field] = filesystem.ParseMultipartFiles(request, field)
   119  		}
   120  	}
   121  
   122  	// Source form is not needed anymore, clear it.
   123  	request.Form = nil
   124  	request.PostForm = nil
   125  	request.MultipartForm = nil
   126  
   127  	return flatMap
   128  }
   129  
   130  func flatten(dst map[string]interface{}, values url.Values) {
   131  	for field, value := range values {
   132  		if len(value) > 1 {
   133  			dst[field] = value
   134  		} else {
   135  			dst[field] = value[0]
   136  		}
   137  	}
   138  }
   139  
   140  func parseQuery(request *Request) {
   141  	if uri := request.URI(); uri != nil {
   142  		queryParams, err := url.ParseQuery(uri.RawQuery)
   143  		if err == nil {
   144  			flatten(request.Data, queryParams)
   145  		}
   146  	}
   147  }
   148  
   149  func resetRequestBody(request *Request, bodyBytes []byte) {
   150  	request.httpRequest.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
   151  }
   152  
   153  // corsMiddleware is the middleware handling CORS, using the options set in the router.
   154  // This middleware is automatically inserted first to the router's list of middleware
   155  // if the latter has defined CORS Options.
   156  func corsMiddleware(next Handler) Handler {
   157  	return func(response *Response, request *Request) {
   158  		if request.corsOptions == nil {
   159  			next(response, request)
   160  			return
   161  		}
   162  
   163  		options := request.corsOptions
   164  		headers := response.Header()
   165  		requestHeaders := request.Header()
   166  
   167  		options.ConfigureCommon(headers, requestHeaders)
   168  
   169  		if request.Method() == http.MethodOptions && requestHeaders.Get("Access-Control-Request-Method") != "" {
   170  			options.HandlePreflight(headers, requestHeaders)
   171  			if options.OptionsPassthrough {
   172  				next(response, request)
   173  			} else {
   174  				response.WriteHeader(http.StatusNoContent)
   175  			}
   176  		} else {
   177  			next(response, request)
   178  		}
   179  	}
   180  }
   181  
   182  // validateRequestMiddleware is a middleware that validates the request and sends a 422 error code
   183  // if the validation rules are not met.
   184  func validateRequestMiddleware(next Handler) Handler {
   185  	return func(response *Response, r *Request) {
   186  		errsBag := r.validate()
   187  		if errsBag == nil {
   188  			next(response, r)
   189  			return
   190  		}
   191  
   192  		var code int
   193  		if r.Data == nil {
   194  			code = http.StatusBadRequest
   195  		} else {
   196  			code = http.StatusUnprocessableEntity
   197  		}
   198  		response.JSON(code, errsBag)
   199  	}
   200  }
   201  
   202  // languageMiddleware is a middleware that sets the language of a request.
   203  //
   204  // Uses the "Accept-Language" header to determine which language to use. If
   205  // the header is not set or the language is not available, uses the default
   206  // language as fallback.
   207  //
   208  // If "*" is provided, the default language will be used.
   209  // If multiple languages are given, the first available language will be used,
   210  // and if none are available, the default language will be used.
   211  // If no variant is given (for example "en"), the first available variant will be used.
   212  // For example, if "en-US" and "en-UK" are available and the request accepts "en",
   213  // "en-US" will be used.
   214  func languageMiddleware(next Handler) Handler {
   215  	return func(response *Response, request *Request) {
   216  		if header := request.Header().Get("Accept-Language"); len(header) > 0 {
   217  			request.Lang = lang.DetectLanguage(header)
   218  		} else {
   219  			request.Lang = config.GetString("app.defaultLanguage")
   220  		}
   221  		next(response, request)
   222  	}
   223  }