github.com/System-Glitch/goyave/v3@v3.6.1-0.20210226143142-ac2fe42ee80e/router.go (about)

     1  package goyave
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"os"
     7  	"regexp"
     8  	"strings"
     9  
    10  	"github.com/System-Glitch/goyave/v3/cors"
    11  	"github.com/System-Glitch/goyave/v3/helper/filesystem"
    12  )
    13  
    14  type routeMatcher interface {
    15  	match(req *http.Request, match *routeMatch) bool
    16  }
    17  
    18  // Router registers routes to be matched and executes a handler.
    19  type Router struct {
    20  	parent            *Router
    21  	prefix            string
    22  	corsOptions       *cors.Options
    23  	hasCORSMiddleware bool
    24  
    25  	routes         []*Route
    26  	subrouters     []*Router
    27  	statusHandlers map[int]Handler
    28  	namedRoutes    map[string]*Route
    29  	middlewareHolder
    30  	parametrizeable
    31  }
    32  
    33  // TODO openapi.go: make Router and Route implement methods for OpenAPI format conversion (native support)
    34  
    35  var _ http.Handler = (*Router)(nil) // implements http.Handler
    36  var _ routeMatcher = (*Router)(nil) // implements routeMatcher
    37  
    38  // Handler is a controller or middleware function
    39  type Handler func(*Response, *Request)
    40  
    41  type middlewareHolder struct {
    42  	middleware []Middleware
    43  }
    44  
    45  type routeMatch struct {
    46  	route       *Route
    47  	err         error
    48  	currentPath string
    49  	parameters  map[string]string
    50  }
    51  
    52  var (
    53  	errMatchMethodNotAllowed = errors.New("Method not allowed for this route")
    54  	errMatchNotFound         = errors.New("No match for this URI")
    55  
    56  	methodNotAllowedRoute = newRoute(func(response *Response, request *Request) {
    57  		response.Status(http.StatusMethodNotAllowed)
    58  	})
    59  	notFoundRoute = newRoute(func(response *Response, request *Request) {
    60  		response.Status(http.StatusNotFound)
    61  	})
    62  )
    63  
    64  // PanicStatusHandler for the HTTP 500 error.
    65  // If debugging is enabled, writes the error details to the response and
    66  // print stacktrace in the console.
    67  // If debugging is not enabled, writes `{"error": "Internal Server Error"}`
    68  // to the response.
    69  func PanicStatusHandler(response *Response, request *Request) {
    70  	response.error(response.GetError())
    71  	if response.empty {
    72  		message := map[string]string{
    73  			"error": http.StatusText(response.GetStatus()),
    74  		}
    75  		response.JSON(response.GetStatus(), message)
    76  	}
    77  }
    78  
    79  // ErrorStatusHandler a generic status handler for non-success codes.
    80  // Writes the corresponding status message to the response.
    81  func ErrorStatusHandler(response *Response, request *Request) {
    82  	message := map[string]string{
    83  		"error": http.StatusText(response.GetStatus()),
    84  	}
    85  	response.JSON(response.GetStatus(), message)
    86  }
    87  
    88  // ValidationStatusHandler for HTTP 400 and HTTP 422 errors.
    89  // Writes the validation errors to the response.
    90  func ValidationStatusHandler(response *Response, request *Request) {
    91  	message := map[string]interface{}{"validationError": response.GetError()}
    92  	response.JSON(response.GetStatus(), message)
    93  }
    94  
    95  func newRouter() *Router {
    96  	methodNotAllowedRoute.name = "method-not-allowed"
    97  	// Create a fresh regex cache
    98  	// This cache is set to nil when the server starts
    99  	regexCache = make(map[string]*regexp.Regexp, 5)
   100  
   101  	router := &Router{
   102  		parent:            nil,
   103  		prefix:            "",
   104  		hasCORSMiddleware: false,
   105  		statusHandlers:    make(map[int]Handler, 41),
   106  		namedRoutes:       make(map[string]*Route, 5),
   107  		middlewareHolder: middlewareHolder{
   108  			middleware: make([]Middleware, 0, 3),
   109  		},
   110  	}
   111  	router.StatusHandler(PanicStatusHandler, http.StatusInternalServerError)
   112  	router.StatusHandler(ValidationStatusHandler, http.StatusBadRequest, http.StatusUnprocessableEntity)
   113  	for i := 401; i <= 418; i++ {
   114  		router.StatusHandler(ErrorStatusHandler, i)
   115  	}
   116  	for i := 423; i <= 426; i++ {
   117  		router.StatusHandler(ErrorStatusHandler, i)
   118  	}
   119  	router.StatusHandler(ErrorStatusHandler, 421, 428, 429, 431, 444, 451)
   120  	router.StatusHandler(ErrorStatusHandler, 501, 502, 503, 504, 505, 506, 507, 508, 510, 511)
   121  	router.Middleware(recoveryMiddleware, parseRequestMiddleware, languageMiddleware)
   122  	return router
   123  }
   124  
   125  // ServeHTTP dispatches the handler registered in the matched route.
   126  func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   127  	if req.URL.Scheme != "" && req.URL.Scheme != protocol {
   128  		address := getAddress(protocol) + req.URL.Path
   129  		query := req.URL.Query()
   130  		if len(query) != 0 {
   131  			address += "?" + query.Encode()
   132  		}
   133  		http.Redirect(w, req, address, http.StatusPermanentRedirect)
   134  		return
   135  	}
   136  
   137  	match := routeMatch{currentPath: req.URL.Path}
   138  	r.match(req, &match)
   139  	r.requestHandler(&match, w, req)
   140  }
   141  
   142  func (r *Router) match(req *http.Request, match *routeMatch) bool {
   143  	// Check if router itself matches
   144  	var params []string
   145  	if r.parametrizeable.regex != nil {
   146  		params = r.parametrizeable.regex.FindStringSubmatch(match.currentPath)
   147  	} else {
   148  		params = []string{""}
   149  	}
   150  
   151  	if params != nil {
   152  		match.trimCurrentPath(params[0])
   153  		if len(params) > 1 {
   154  			match.mergeParams(r.makeParameters(params))
   155  		}
   156  
   157  		// Check in subrouters first
   158  		for _, router := range r.subrouters {
   159  			if router.match(req, match) {
   160  				if router.prefix == "" && match.route == methodNotAllowedRoute {
   161  					// This allows route groups with subrouters having empty prefix.
   162  					break
   163  				}
   164  				return true
   165  			}
   166  		}
   167  
   168  		// Check if any route matches
   169  		for _, route := range r.routes {
   170  			if route.match(req, match) {
   171  				return true
   172  			}
   173  		}
   174  	}
   175  
   176  	if match.err == errMatchMethodNotAllowed {
   177  		match.route = methodNotAllowedRoute
   178  		return true
   179  	}
   180  
   181  	match.route = notFoundRoute
   182  	return false
   183  }
   184  
   185  func (r *Router) makeParameters(match []string) map[string]string {
   186  	return r.parametrizeable.makeParameters(match, r.parameters)
   187  }
   188  
   189  // Subrouter create a new sub-router from this router.
   190  // Use subrouters to create route groups and to apply middleware to multiple routes.
   191  // CORS options are also inherited.
   192  func (r *Router) Subrouter(prefix string) *Router {
   193  	if prefix == "/" {
   194  		prefix = ""
   195  	}
   196  
   197  	router := &Router{
   198  		parent:            r,
   199  		prefix:            prefix,
   200  		corsOptions:       r.corsOptions,
   201  		hasCORSMiddleware: r.hasCORSMiddleware,
   202  		statusHandlers:    r.copyStatusHandlers(),
   203  		namedRoutes:       r.namedRoutes,
   204  		routes:            make([]*Route, 0, 5), // Typical CRUD has 5 routes
   205  		middlewareHolder: middlewareHolder{
   206  			middleware: nil,
   207  		},
   208  	}
   209  	router.compileParameters(router.prefix, false)
   210  	r.subrouters = append(r.subrouters, router)
   211  	return router
   212  }
   213  
   214  // Middleware apply one or more middleware to the route group.
   215  func (r *Router) Middleware(middleware ...Middleware) {
   216  	if r.middleware == nil {
   217  		r.middleware = make([]Middleware, 0, 3)
   218  	}
   219  	r.middleware = append(r.middleware, middleware...)
   220  }
   221  
   222  // Route register a new route.
   223  //
   224  // Multiple methods can be passed using a pipe-separated string.
   225  //  "PUT|PATCH"
   226  //
   227  // The validation rules set is optional. If you don't want your route
   228  // to be validated, pass "nil".
   229  //
   230  // If the route matches the "GET" method, the "HEAD" method is automatically added
   231  // to the matcher if it's missing.
   232  //
   233  // If the router has CORS options set, the "OPTIONS" method is automatically added
   234  // to the matcher if it's missing, so it allows preflight requests.
   235  //
   236  // Returns the generated route.
   237  func (r *Router) Route(methods string, uri string, handler Handler) *Route {
   238  	return r.registerRoute(methods, uri, handler)
   239  }
   240  
   241  func (r *Router) registerRoute(methods string, uri string, handler Handler) *Route {
   242  	if r.corsOptions != nil && !strings.Contains(methods, "OPTIONS") {
   243  		methods += "|OPTIONS"
   244  	}
   245  
   246  	if strings.Contains(methods, "GET") && !strings.Contains(methods, "HEAD") {
   247  		methods += "|HEAD"
   248  	}
   249  
   250  	if uri == "/" && r.parent != nil {
   251  		uri = ""
   252  	}
   253  
   254  	route := &Route{
   255  		name:    "",
   256  		uri:     uri,
   257  		methods: strings.Split(methods, "|"),
   258  		parent:  r,
   259  		handler: handler,
   260  	}
   261  	route.compileParameters(route.uri, true)
   262  	r.routes = append(r.routes, route)
   263  	return route
   264  }
   265  
   266  // Get registers a new route with the GET and HEAD methods.
   267  func (r *Router) Get(uri string, handler Handler) *Route {
   268  	return r.registerRoute(http.MethodGet, uri, handler)
   269  }
   270  
   271  // Post registers a new route with the POST method.
   272  func (r *Router) Post(uri string, handler Handler) *Route {
   273  	return r.registerRoute(http.MethodPost, uri, handler)
   274  }
   275  
   276  // Put registers a new route with the PUT method.
   277  func (r *Router) Put(uri string, handler Handler) *Route {
   278  	return r.registerRoute(http.MethodPut, uri, handler)
   279  }
   280  
   281  // Patch registers a new route with the PATCH method.
   282  func (r *Router) Patch(uri string, handler Handler) *Route {
   283  	return r.registerRoute(http.MethodPatch, uri, handler)
   284  }
   285  
   286  // Delete registers a new route with the DELETE method.
   287  func (r *Router) Delete(uri string, handler Handler) *Route {
   288  	return r.registerRoute(http.MethodDelete, uri, handler)
   289  }
   290  
   291  // Options registers a new route wit the OPTIONS method.
   292  func (r *Router) Options(uri string, handler Handler) *Route {
   293  	return r.registerRoute(http.MethodOptions, uri, handler)
   294  }
   295  
   296  // GetRoute get a named route.
   297  // Returns nil if the route doesn't exist.
   298  func (r *Router) GetRoute(name string) *Route {
   299  	return r.namedRoutes[name]
   300  }
   301  
   302  // Static serve a directory and its subdirectories of static resources.
   303  // Set the "download" parameter to true if you want the files to be sent as an attachment
   304  // instead of an inline element.
   305  //
   306  // If no file is given in the url, or if the given file is a directory, the handler will
   307  // send the "index.html" file if it exists.
   308  func (r *Router) Static(uri string, directory string, download bool, middleware ...Middleware) {
   309  	r.registerRoute(http.MethodGet, uri+"{resource:.*}", staticHandler(directory, download)).Middleware(middleware...)
   310  }
   311  
   312  // CORS set the CORS options for this route group.
   313  // If the options are not nil, the CORS middleware is automatically added.
   314  func (r *Router) CORS(options *cors.Options) {
   315  	r.corsOptions = options
   316  	if options != nil && !r.hasCORSMiddleware {
   317  		r.Middleware(corsMiddleware)
   318  		r.hasCORSMiddleware = true
   319  	}
   320  }
   321  
   322  // StatusHandler set a handler for responses with an empty body.
   323  // The handler will be automatically executed if the request's life-cycle reaches its end
   324  // and nothing has been written in the response body.
   325  //
   326  // Multiple status codes can be given. The handler will be executed if one of them matches.
   327  //
   328  // This method can be used to define custom error handlers for example.
   329  //
   330  // Status handlers are inherited as a copy in sub-routers. Modifying a child's status handler
   331  // will not modify its parent's.
   332  //
   333  // Codes in the 400 and 500 ranges have a default status handler.
   334  func (r *Router) StatusHandler(handler Handler, status int, additionalStatuses ...int) {
   335  	r.statusHandlers[status] = handler
   336  	for _, s := range additionalStatuses {
   337  		r.statusHandlers[s] = handler
   338  	}
   339  }
   340  
   341  func staticHandler(directory string, download bool) Handler {
   342  	return func(response *Response, r *Request) {
   343  		file := r.Params["resource"]
   344  		path := cleanStaticPath(directory, file)
   345  
   346  		var err error
   347  		if download {
   348  			err = response.Download(path, file[strings.LastIndex(file, "/")+1:])
   349  		} else {
   350  			err = response.File(path)
   351  		}
   352  
   353  		if _, ok := err.(*os.PathError); err != nil && !ok {
   354  			ErrLogger.Println(err)
   355  		}
   356  	}
   357  }
   358  
   359  func cleanStaticPath(directory string, file string) string {
   360  	file = strings.TrimPrefix(file, "/")
   361  	path := directory + "/" + file
   362  	if filesystem.IsDirectory(path) {
   363  		if !strings.HasSuffix(path, "/") {
   364  			path += "/"
   365  		}
   366  		path += "index.html"
   367  	}
   368  	return path
   369  }
   370  
   371  func (r *Router) copyStatusHandlers() map[int]Handler {
   372  	cpy := make(map[int]Handler, len(r.statusHandlers))
   373  	for key, value := range r.statusHandlers {
   374  		cpy[key] = value
   375  	}
   376  	return cpy
   377  }
   378  
   379  func (r *Router) requestHandler(match *routeMatch, w http.ResponseWriter, rawRequest *http.Request) {
   380  	request := &Request{
   381  		httpRequest: rawRequest,
   382  		route:       match.route,
   383  		corsOptions: r.corsOptions,
   384  		Rules:       match.route.validationRules,
   385  		Params:      match.parameters,
   386  		Extra:       map[string]interface{}{},
   387  	}
   388  	response := newResponse(w, rawRequest)
   389  	handler := match.route.handler
   390  
   391  	// Validate last.
   392  	// Allows custom middleware to be executed after core
   393  	// middleware and before validation.
   394  	handler = validateRequestMiddleware(handler)
   395  
   396  	// Route-specific middleware is executed after router middleware
   397  	handler = match.route.applyMiddleware(handler)
   398  
   399  	parent := match.route.parent
   400  	for parent != nil {
   401  		handler = parent.applyMiddleware(handler)
   402  		parent = parent.parent
   403  	}
   404  
   405  	handler(response, request)
   406  
   407  	r.finalize(response, request)
   408  }
   409  
   410  // finalize the request's life-cycle.
   411  func (r *Router) finalize(response *Response, request *Request) {
   412  	if response.empty {
   413  		if response.status == 0 {
   414  			// If the response is empty, return status 204 to
   415  			// comply with RFC 7231, 6.3.5
   416  			response.Status(http.StatusNoContent)
   417  		} else if statusHandler, ok := r.statusHandlers[response.status]; ok {
   418  			// Status has been set but body is empty.
   419  			// Execute status handler if exists.
   420  			statusHandler(response, request)
   421  		}
   422  	}
   423  
   424  	if !response.wroteHeader && !response.hijacked {
   425  		response.WriteHeader(response.status)
   426  	}
   427  
   428  	response.close()
   429  }
   430  
   431  func (h *middlewareHolder) applyMiddleware(handler Handler) Handler {
   432  	for i := len(h.middleware) - 1; i >= 0; i-- {
   433  		handler = h.middleware[i](handler)
   434  	}
   435  	return handler
   436  }
   437  
   438  func (rm *routeMatch) mergeParams(params map[string]string) {
   439  	if rm.parameters == nil {
   440  		rm.parameters = params
   441  	}
   442  	for k, v := range params {
   443  		rm.parameters[k] = v
   444  	}
   445  }
   446  
   447  func (rm *routeMatch) trimCurrentPath(fullMatch string) {
   448  	rm.currentPath = rm.currentPath[len(fullMatch):]
   449  }