goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/middleware.go (about)

     1  package goyave
     2  
     3  import (
     4  	"net/http"
     5  	"strings"
     6  
     7  	"gorm.io/gorm"
     8  	"goyave.dev/goyave/v5/cors"
     9  	"goyave.dev/goyave/v5/util/errors"
    10  	"goyave.dev/goyave/v5/validation"
    11  )
    12  
    13  // Middleware are special handlers executed in a stack above the controller handler.
    14  // They allow to inspect and filter requests, transform responses or provide additional
    15  // information to the next handlers in the stack.
    16  // Example uses are authentication, authorization, logging, panic recovery, CORS,
    17  // validation, gzip compression.
    18  type Middleware interface {
    19  	Composable
    20  	Handle(Handler) Handler
    21  }
    22  
    23  type middlewareHolder struct {
    24  	middleware []Middleware
    25  }
    26  
    27  func (h *middlewareHolder) applyMiddleware(handler Handler) Handler {
    28  	for i := len(h.middleware) - 1; i >= 0; i-- {
    29  		handler = h.middleware[i].Handle(handler)
    30  	}
    31  	return handler
    32  }
    33  
    34  // GetMiddleware returns a copy of the middleware applied on this holder.
    35  func (h *middlewareHolder) GetMiddleware() []Middleware {
    36  	return append(make([]Middleware, 0, len(h.middleware)), h.middleware...)
    37  }
    38  
    39  func findMiddleware[T Middleware](m []Middleware) T {
    40  	for _, middleware := range m {
    41  		if m, ok := middleware.(T); ok {
    42  			return m
    43  		}
    44  	}
    45  	var zero T
    46  	return zero
    47  }
    48  
    49  func hasMiddleware[T Middleware](m []Middleware) bool {
    50  	for _, middleware := range m {
    51  		if _, ok := middleware.(T); ok {
    52  			return true
    53  		}
    54  	}
    55  	return false
    56  }
    57  
    58  // routeHasMiddleware returns true if the given route or any of its
    59  // parents has a middleware of the T type.
    60  func routeHasMiddleware[T Middleware](route *Route) bool {
    61  	return hasMiddleware[T](route.middleware)
    62  }
    63  
    64  // routerHasMiddleware returns true if the given route or any of its
    65  // parents has a middleware of the T type. Also returns true if the middleware
    66  // is present as global middleware.
    67  func routerHasMiddleware[T Middleware](router *Router) bool {
    68  	return hasMiddleware[T](router.globalMiddleware.middleware) || hasMiddleware[T](router.middleware) || (router.parent != nil && routerHasMiddleware[T](router.parent))
    69  }
    70  
    71  // recoveryMiddleware is a middleware that recovers from panic and sends a 500 error code.
    72  // If debugging is enabled in the config and the default status handler for the 500 status code
    73  // had not been changed, the error is also written in the response.
    74  type recoveryMiddleware struct {
    75  	Component
    76  }
    77  
    78  func (m *recoveryMiddleware) Handle(next Handler) Handler {
    79  	return func(response *Response, request *Request) {
    80  		panicked := true
    81  		defer func() {
    82  			if err := recover(); err != nil || panicked {
    83  				e := errors.NewSkip(err, 4).(*errors.Error) // Skipped: runtime.Callers, NewSkip, this func, runtime.panic
    84  				m.Logger().Error(e)
    85  				response.err = e
    86  				response.status = http.StatusInternalServerError // Force status override
    87  			}
    88  		}()
    89  
    90  		next(response, request)
    91  		panicked = false
    92  	}
    93  }
    94  
    95  // languageMiddleware is a middleware that sets the language of a request.
    96  //
    97  // Uses the "Accept-Language" header to determine which language to use. If
    98  // the header is not set or the language is not available, uses the default
    99  // language as fallback.
   100  //
   101  // If "*" is provided, the default language will be used.
   102  // If multiple languages are given, the first available language will be used,
   103  // and if none are available, the default language will be used.
   104  // If no variant is given (for example "en"), the first available variant will be used.
   105  // For example, if "en-US" and "en-UK" are available and the request accepts "en",
   106  // "en-US" will be used.
   107  type languageMiddleware struct {
   108  	Component
   109  }
   110  
   111  func (m *languageMiddleware) Handle(next Handler) Handler {
   112  	return func(response *Response, request *Request) {
   113  		if header := request.Header().Get("Accept-Language"); len(header) > 0 {
   114  			request.Lang = m.Lang().DetectLanguage(header)
   115  		} else {
   116  			request.Lang = m.Lang().GetDefault()
   117  		}
   118  		next(response, request)
   119  	}
   120  }
   121  
   122  // validateRequestMiddleware is a middleware that validates the request.
   123  // If validation is not rules are not met, sets the response status to 422 Unprocessable Entity
   124  // or 400 Bad Request and the response error (which can be retrieved with `GetError()`) to the
   125  // `validation.Errors` returned by the validator.
   126  // This data can then be used in a status handler.
   127  // This middleware requires the parse middleware.
   128  type validateRequestMiddleware struct {
   129  	Component
   130  	BodyRules  RuleSetFunc
   131  	QueryRules RuleSetFunc
   132  }
   133  
   134  func (m *validateRequestMiddleware) Handle(next Handler) Handler {
   135  	return func(response *Response, r *Request) {
   136  		extra := map[any]any{
   137  			validation.ExtraRequest{}: r,
   138  		}
   139  		contentType := r.Header().Get("Content-Type")
   140  
   141  		var db *gorm.DB
   142  		if m.Config().GetString("database.connection") != "none" {
   143  			db = m.DB().WithContext(r.Context())
   144  		}
   145  		var errsBag *validation.Errors
   146  		var queryErrsBag *validation.Errors
   147  		var errors []error
   148  		if m.QueryRules != nil {
   149  			opt := &validation.Options{
   150  				Data:                     r.Query,
   151  				Rules:                    m.QueryRules(r).AsRules(),
   152  				ConvertSingleValueArrays: true,
   153  				Language:                 r.Lang,
   154  				DB:                       db,
   155  				Config:                   m.Config(),
   156  				Logger:                   m.Logger(),
   157  				Extra:                    extra,
   158  			}
   159  			r.Extra[ExtraQueryValidationRules{}] = opt.Rules
   160  			var err []error
   161  			queryErrsBag, err = validation.Validate(opt)
   162  			if queryErrsBag != nil {
   163  				r.Extra[ExtraQueryValidationError{}] = queryErrsBag
   164  			}
   165  			if err != nil {
   166  				errors = append(errors, err...)
   167  			}
   168  		}
   169  		if m.BodyRules != nil {
   170  			opt := &validation.Options{
   171  				Data:                     r.Data,
   172  				Rules:                    m.BodyRules(r).AsRules(),
   173  				ConvertSingleValueArrays: !strings.HasPrefix(contentType, "application/json"),
   174  				Language:                 r.Lang,
   175  				DB:                       db,
   176  				Config:                   m.Config(),
   177  				Logger:                   m.Logger(),
   178  				Extra:                    extra,
   179  			}
   180  			r.Extra[ExtraBodyValidationRules{}] = opt.Rules
   181  			var err []error
   182  			errsBag, err = validation.Validate(opt)
   183  			if errsBag != nil {
   184  				r.Extra[ExtraValidationError{}] = errsBag
   185  			}
   186  			if err != nil {
   187  				errors = append(errors, err...)
   188  			}
   189  			r.Data = opt.Data
   190  		}
   191  
   192  		if len(errors) != 0 {
   193  			response.Error(errors)
   194  			return
   195  		}
   196  
   197  		if errsBag != nil || queryErrsBag != nil {
   198  			response.Status(http.StatusUnprocessableEntity)
   199  			return
   200  		}
   201  
   202  		next(response, r)
   203  	}
   204  }
   205  
   206  type corsMiddleware struct {
   207  	Component
   208  }
   209  
   210  func (m *corsMiddleware) Handle(next Handler) Handler {
   211  	return func(response *Response, request *Request) {
   212  		o, ok := request.Route.LookupMeta(MetaCORS)
   213  		if !ok || o == nil || o == (*cors.Options)(nil) {
   214  			next(response, request)
   215  			return
   216  		}
   217  
   218  		options := o.(*cors.Options)
   219  		headers := response.Header()
   220  		requestHeaders := request.Header()
   221  
   222  		if request.Method() == http.MethodOptions && requestHeaders.Get("Access-Control-Request-Method") == "" {
   223  			response.Status(http.StatusBadRequest)
   224  			return
   225  		}
   226  
   227  		options.ConfigureCommon(headers, requestHeaders)
   228  
   229  		if request.Method() == http.MethodOptions {
   230  			options.HandlePreflight(headers, requestHeaders)
   231  			if options.OptionsPassthrough {
   232  				next(response, request)
   233  			} else {
   234  				response.WriteHeader(http.StatusNoContent)
   235  			}
   236  		} else {
   237  			next(response, request)
   238  		}
   239  	}
   240  }