github.com/boomhut/fiber/v2@v2.0.0-20230603160335-b65c856e57d3/router.go (about)

     1  // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
     2  // 🤖 Github Repository: https://github.com/gofiber/fiber
     3  // 📌 API Documentation: https://docs.gofiber.io
     4  
     5  package fiber
     6  
     7  import (
     8  	"fmt"
     9  	"sort"
    10  	"strconv"
    11  	"strings"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/boomhut/fiber/v2/utils"
    16  
    17  	"github.com/valyala/fasthttp"
    18  )
    19  
    20  // Router defines all router handle interface, including app and group router.
    21  type Router interface {
    22  	Use(args ...interface{}) Router
    23  
    24  	Get(path string, handlers ...Handler) Router
    25  	Head(path string, handlers ...Handler) Router
    26  	Post(path string, handlers ...Handler) Router
    27  	Put(path string, handlers ...Handler) Router
    28  	Delete(path string, handlers ...Handler) Router
    29  	Connect(path string, handlers ...Handler) Router
    30  	Options(path string, handlers ...Handler) Router
    31  	Trace(path string, handlers ...Handler) Router
    32  	Patch(path string, handlers ...Handler) Router
    33  
    34  	Add(method, path string, handlers ...Handler) Router
    35  	Static(prefix, root string, config ...Static) Router
    36  	All(path string, handlers ...Handler) Router
    37  
    38  	Group(prefix string, handlers ...Handler) Router
    39  
    40  	Route(prefix string, fn func(router Router), name ...string) Router
    41  
    42  	Mount(prefix string, fiber *App) Router
    43  
    44  	Name(name string) Router
    45  }
    46  
    47  // Route is a struct that holds all metadata for each registered handler.
    48  type Route struct {
    49  	// always keep in sync with the copy method "app.copyRoute"
    50  	// Data for routing
    51  	pos         uint32      // Position in stack -> important for the sort of the matched routes
    52  	use         bool        // USE matches path prefixes
    53  	mount       bool        // Indicated a mounted app on a specific route
    54  	star        bool        // Path equals '*'
    55  	root        bool        // Path equals '/'
    56  	path        string      // Prettified path
    57  	routeParser routeParser // Parameter parser
    58  	group       *Group      // Group instance. used for routes in groups
    59  
    60  	// Public fields
    61  	Method string `json:"method"` // HTTP method
    62  	Name   string `json:"name"`   // Route's name
    63  	//nolint:revive // Having both a Path (uppercase) and a path (lowercase) is fine
    64  	Path     string    `json:"path"`   // Original registered route path
    65  	Params   []string  `json:"params"` // Case sensitive param keys
    66  	Handlers []Handler `json:"-"`      // Ctx handlers
    67  }
    68  
    69  func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool {
    70  	// root detectionPath check
    71  	if r.root && detectionPath == "/" {
    72  		return true
    73  		// '*' wildcard matches any detectionPath
    74  	} else if r.star {
    75  		if len(path) > 1 {
    76  			params[0] = path[1:]
    77  		} else {
    78  			params[0] = ""
    79  		}
    80  		return true
    81  	}
    82  	// Does this route have parameters
    83  	if len(r.Params) > 0 {
    84  		// Match params
    85  		if match := r.routeParser.getMatch(detectionPath, path, params, r.use); match {
    86  			// Get params from the path detectionPath
    87  			return match
    88  		}
    89  	}
    90  	// Is this route a Middleware?
    91  	if r.use {
    92  		// Single slash will match or detectionPath prefix
    93  		if r.root || strings.HasPrefix(detectionPath, r.path) {
    94  			return true
    95  		}
    96  		// Check for a simple detectionPath match
    97  	} else if len(r.path) == len(detectionPath) && r.path == detectionPath {
    98  		return true
    99  	}
   100  	// No match
   101  	return false
   102  }
   103  
   104  func (app *App) next(c *Ctx) (bool, error) {
   105  	// Get stack length
   106  	tree, ok := app.treeStack[c.methodINT][c.treePath]
   107  	if !ok {
   108  		tree = app.treeStack[c.methodINT][""]
   109  	}
   110  	lenTree := len(tree) - 1
   111  
   112  	// Loop over the route stack starting from previous index
   113  	for c.indexRoute < lenTree {
   114  		// Increment route index
   115  		c.indexRoute++
   116  
   117  		// Get *Route
   118  		route := tree[c.indexRoute]
   119  
   120  		var match bool
   121  		var err error
   122  		// skip for mounted apps
   123  		if route.mount {
   124  			continue
   125  		}
   126  
   127  		// Check if it matches the request path
   128  		match = route.match(c.detectionPath, c.path, &c.values)
   129  		if !match {
   130  			// No match, next route
   131  			continue
   132  		}
   133  		// Pass route reference and param values
   134  		c.route = route
   135  
   136  		// Non use handler matched
   137  		if !c.matched && !route.use {
   138  			c.matched = true
   139  		}
   140  
   141  		// Execute first handler of route
   142  		c.indexHandler = 0
   143  		if len(route.Handlers) > 0 {
   144  			err = route.Handlers[0](c)
   145  		}
   146  		return match, err // Stop scanning the stack
   147  	}
   148  
   149  	// If c.Next() does not match, return 404
   150  	err := NewError(StatusNotFound, "Cannot "+c.method+" "+c.pathOriginal)
   151  	if !c.matched && app.methodExist(c) {
   152  		// If no match, scan stack again if other methods match the request
   153  		// Moved from app.handler because middleware may break the route chain
   154  		err = ErrMethodNotAllowed
   155  	}
   156  	return false, err
   157  }
   158  
   159  func (app *App) handler(rctx *fasthttp.RequestCtx) { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. See https://github.com/golangci/golangci-lint/issues/3476
   160  	// Acquire Ctx with fasthttp request from pool
   161  	c := app.AcquireCtx(rctx)
   162  	defer app.ReleaseCtx(c)
   163  
   164  	// handle invalid http method directly
   165  	if c.methodINT == -1 {
   166  		_ = c.Status(StatusBadRequest).SendString("Invalid http method") //nolint:errcheck // It is fine to ignore the error here
   167  		return
   168  	}
   169  
   170  	// Find match in stack
   171  	match, err := app.next(c)
   172  	if err != nil {
   173  		if catch := c.app.ErrorHandler(c, err); catch != nil {
   174  			_ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here
   175  		}
   176  		// TODO: Do we need to return here?
   177  	}
   178  	// Generate ETag if enabled
   179  	if match && app.config.ETag {
   180  		setETag(c, false)
   181  	}
   182  }
   183  
   184  func (app *App) addPrefixToRoute(prefix string, route *Route) *Route {
   185  	prefixedPath := getGroupPath(prefix, route.Path)
   186  	prettyPath := prefixedPath
   187  	// Case sensitive routing, all to lowercase
   188  	if !app.config.CaseSensitive {
   189  		prettyPath = utils.ToLower(prettyPath)
   190  	}
   191  	// Strict routing, remove trailing slashes
   192  	if !app.config.StrictRouting && len(prettyPath) > 1 {
   193  		prettyPath = utils.TrimRight(prettyPath, '/')
   194  	}
   195  
   196  	route.Path = prefixedPath
   197  	route.path = RemoveEscapeChar(prettyPath)
   198  	route.routeParser = parseRoute(prettyPath)
   199  	route.root = false
   200  	route.star = false
   201  
   202  	return route
   203  }
   204  
   205  func (*App) copyRoute(route *Route) *Route {
   206  	return &Route{
   207  		// Router booleans
   208  		use:   route.use,
   209  		mount: route.mount,
   210  		star:  route.star,
   211  		root:  route.root,
   212  
   213  		// Path data
   214  		path:        route.path,
   215  		routeParser: route.routeParser,
   216  		Params:      route.Params,
   217  
   218  		// misc
   219  		pos: route.pos,
   220  
   221  		// Public data
   222  		Path:     route.Path,
   223  		Method:   route.Method,
   224  		Handlers: route.Handlers,
   225  	}
   226  }
   227  
   228  func (app *App) register(method, pathRaw string, group *Group, handlers ...Handler) Router {
   229  	// Uppercase HTTP methods
   230  	method = utils.ToUpper(method)
   231  	// Check if the HTTP method is valid unless it's USE
   232  	if method != methodUse && app.methodInt(method) == -1 {
   233  		panic(fmt.Sprintf("add: invalid http method %s\n", method))
   234  	}
   235  	// is mounted app
   236  	isMount := group != nil && group.app != app
   237  	// A route requires atleast one ctx handler
   238  	if len(handlers) == 0 && !isMount {
   239  		panic(fmt.Sprintf("missing handler in route: %s\n", pathRaw))
   240  	}
   241  	// Cannot have an empty path
   242  	if pathRaw == "" {
   243  		pathRaw = "/"
   244  	}
   245  	// Path always start with a '/'
   246  	if pathRaw[0] != '/' {
   247  		pathRaw = "/" + pathRaw
   248  	}
   249  	// Create a stripped path in-case sensitive / trailing slashes
   250  	pathPretty := pathRaw
   251  	// Case sensitive routing, all to lowercase
   252  	if !app.config.CaseSensitive {
   253  		pathPretty = utils.ToLower(pathPretty)
   254  	}
   255  	// Strict routing, remove trailing slashes
   256  	if !app.config.StrictRouting && len(pathPretty) > 1 {
   257  		pathPretty = utils.TrimRight(pathPretty, '/')
   258  	}
   259  	// Is layer a middleware?
   260  	isUse := method == methodUse
   261  	// Is path a direct wildcard?
   262  	isStar := pathPretty == "/*"
   263  	// Is path a root slash?
   264  	isRoot := pathPretty == "/"
   265  	// Parse path parameters
   266  	parsedRaw := parseRoute(pathRaw)
   267  	parsedPretty := parseRoute(pathPretty)
   268  
   269  	// Create route metadata without pointer
   270  	route := Route{
   271  		// Router booleans
   272  		use:   isUse,
   273  		mount: isMount,
   274  		star:  isStar,
   275  		root:  isRoot,
   276  
   277  		// Path data
   278  		path:        RemoveEscapeChar(pathPretty),
   279  		routeParser: parsedPretty,
   280  		Params:      parsedRaw.params,
   281  
   282  		// Group data
   283  		group: group,
   284  
   285  		// Public data
   286  		Path:     pathRaw,
   287  		Method:   method,
   288  		Handlers: handlers,
   289  	}
   290  	// Increment global handler count
   291  	atomic.AddUint32(&app.handlersCount, uint32(len(handlers)))
   292  
   293  	// Middleware route matches all HTTP methods
   294  	if isUse {
   295  		// Add route to all HTTP methods stack
   296  		for _, m := range app.config.RequestMethods {
   297  			// Create a route copy to avoid duplicates during compression
   298  			r := route
   299  			app.addRoute(m, &r, isMount)
   300  		}
   301  	} else {
   302  		// Add route to stack
   303  		app.addRoute(method, &route, isMount)
   304  	}
   305  	return app
   306  }
   307  
   308  func (app *App) registerStatic(prefix, root string, config ...Static) Router {
   309  	// For security we want to restrict to the current work directory.
   310  	if root == "" {
   311  		root = "."
   312  	}
   313  	// Cannot have an empty prefix
   314  	if prefix == "" {
   315  		prefix = "/"
   316  	}
   317  	// Prefix always start with a '/' or '*'
   318  	if prefix[0] != '/' {
   319  		prefix = "/" + prefix
   320  	}
   321  	// in case sensitive routing, all to lowercase
   322  	if !app.config.CaseSensitive {
   323  		prefix = utils.ToLower(prefix)
   324  	}
   325  	// Strip trailing slashes from the root path
   326  	if len(root) > 0 && root[len(root)-1] == '/' {
   327  		root = root[:len(root)-1]
   328  	}
   329  	// Is prefix a direct wildcard?
   330  	isStar := prefix == "/*"
   331  	// Is prefix a root slash?
   332  	isRoot := prefix == "/"
   333  	// Is prefix a partial wildcard?
   334  	if strings.Contains(prefix, "*") {
   335  		// /john* -> /john
   336  		isStar = true
   337  		prefix = strings.Split(prefix, "*")[0]
   338  		// Fix this later
   339  	}
   340  	prefixLen := len(prefix)
   341  	if prefixLen > 1 && prefix[prefixLen-1:] == "/" {
   342  		// /john/ -> /john
   343  		prefixLen--
   344  		prefix = prefix[:prefixLen]
   345  	}
   346  	const cacheDuration = 10 * time.Second
   347  	// Fileserver settings
   348  	fs := &fasthttp.FS{
   349  		Root:                 root,
   350  		AllowEmptyRoot:       true,
   351  		GenerateIndexPages:   false,
   352  		AcceptByteRange:      false,
   353  		Compress:             false,
   354  		CompressedFileSuffix: app.config.CompressedFileSuffix,
   355  		CacheDuration:        cacheDuration,
   356  		IndexNames:           []string{"index.html"},
   357  		PathRewrite: func(fctx *fasthttp.RequestCtx) []byte {
   358  			path := fctx.Path()
   359  			if len(path) >= prefixLen {
   360  				if isStar && app.getString(path[0:prefixLen]) == prefix {
   361  					path = append(path[0:0], '/')
   362  				} else {
   363  					path = path[prefixLen:]
   364  					if len(path) == 0 || path[len(path)-1] != '/' {
   365  						path = append(path, '/')
   366  					}
   367  				}
   368  			}
   369  			if len(path) > 0 && path[0] != '/' {
   370  				path = append([]byte("/"), path...)
   371  			}
   372  			return path
   373  		},
   374  		PathNotFound: func(fctx *fasthttp.RequestCtx) {
   375  			fctx.Response.SetStatusCode(StatusNotFound)
   376  		},
   377  	}
   378  
   379  	// Set config if provided
   380  	var cacheControlValue string
   381  	var modifyResponse Handler
   382  	if len(config) > 0 {
   383  		maxAge := config[0].MaxAge
   384  		if maxAge > 0 {
   385  			cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
   386  		}
   387  		fs.CacheDuration = config[0].CacheDuration
   388  		fs.Compress = config[0].Compress
   389  		fs.AcceptByteRange = config[0].ByteRange
   390  		fs.GenerateIndexPages = config[0].Browse
   391  		if config[0].Index != "" {
   392  			fs.IndexNames = []string{config[0].Index}
   393  		}
   394  		modifyResponse = config[0].ModifyResponse
   395  	}
   396  	fileHandler := fs.NewRequestHandler()
   397  	handler := func(c *Ctx) error {
   398  		// Don't execute middleware if Next returns true
   399  		if len(config) != 0 && config[0].Next != nil && config[0].Next(c) {
   400  			return c.Next()
   401  		}
   402  		// Serve file
   403  		fileHandler(c.fasthttp)
   404  		// Sets the response Content-Disposition header to attachment if the Download option is true
   405  		if len(config) > 0 && config[0].Download {
   406  			c.Attachment()
   407  		}
   408  		// Return request if found and not forbidden
   409  		status := c.fasthttp.Response.StatusCode()
   410  		if status != StatusNotFound && status != StatusForbidden {
   411  			if len(cacheControlValue) > 0 {
   412  				c.fasthttp.Response.Header.Set(HeaderCacheControl, cacheControlValue)
   413  			}
   414  			if modifyResponse != nil {
   415  				return modifyResponse(c)
   416  			}
   417  			return nil
   418  		}
   419  		// Reset response to default
   420  		c.fasthttp.SetContentType("") // Issue #420
   421  		c.fasthttp.Response.SetStatusCode(StatusOK)
   422  		c.fasthttp.Response.SetBodyString("")
   423  		// Next middleware
   424  		return c.Next()
   425  	}
   426  
   427  	// Create route metadata without pointer
   428  	route := Route{
   429  		// Router booleans
   430  		use:  true,
   431  		root: isRoot,
   432  		path: prefix,
   433  		// Public data
   434  		Method:   MethodGet,
   435  		Path:     prefix,
   436  		Handlers: []Handler{handler},
   437  	}
   438  	// Increment global handler count
   439  	atomic.AddUint32(&app.handlersCount, 1)
   440  	// Add route to stack
   441  	app.addRoute(MethodGet, &route)
   442  	// Add HEAD route
   443  	app.addRoute(MethodHead, &route)
   444  	return app
   445  }
   446  
   447  func (app *App) addRoute(method string, route *Route, isMounted ...bool) {
   448  	// Check mounted routes
   449  	var mounted bool
   450  	if len(isMounted) > 0 {
   451  		mounted = isMounted[0]
   452  	}
   453  
   454  	// Get unique HTTP method identifier
   455  	m := app.methodInt(method)
   456  
   457  	// prevent identically route registration
   458  	l := len(app.stack[m])
   459  	if l > 0 && app.stack[m][l-1].Path == route.Path && route.use == app.stack[m][l-1].use && !route.mount && !app.stack[m][l-1].mount {
   460  		preRoute := app.stack[m][l-1]
   461  		preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
   462  	} else {
   463  		// Increment global route position
   464  		route.pos = atomic.AddUint32(&app.routesCount, 1)
   465  		route.Method = method
   466  		// Add route to the stack
   467  		app.stack[m] = append(app.stack[m], route)
   468  		app.routesRefreshed = true
   469  	}
   470  
   471  	// Execute onRoute hooks & change latestRoute if not adding mounted route
   472  	if !mounted {
   473  		app.mutex.Lock()
   474  		app.latestRoute = route
   475  		if err := app.hooks.executeOnRouteHooks(*route); err != nil {
   476  			panic(err)
   477  		}
   478  		app.mutex.Unlock()
   479  	}
   480  }
   481  
   482  // buildTree build the prefix tree from the previously registered routes
   483  func (app *App) buildTree() *App {
   484  	if !app.routesRefreshed {
   485  		return app
   486  	}
   487  
   488  	// loop all the methods and stacks and create the prefix tree
   489  	for m := range app.config.RequestMethods {
   490  		tsMap := make(map[string][]*Route)
   491  		for _, route := range app.stack[m] {
   492  			treePath := ""
   493  			if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= 3 {
   494  				treePath = route.routeParser.segs[0].Const[:3]
   495  			}
   496  			// create tree stack
   497  			tsMap[treePath] = append(tsMap[treePath], route)
   498  		}
   499  		app.treeStack[m] = tsMap
   500  	}
   501  
   502  	// loop the methods and tree stacks and add global stack and sort everything
   503  	for m := range app.config.RequestMethods {
   504  		tsMap := app.treeStack[m]
   505  		for treePart := range tsMap {
   506  			if treePart != "" {
   507  				// merge global tree routes in current tree stack
   508  				tsMap[treePart] = uniqueRouteStack(append(tsMap[treePart], tsMap[""]...))
   509  			}
   510  			// sort tree slices with the positions
   511  			slc := tsMap[treePart]
   512  			sort.Slice(slc, func(i, j int) bool { return slc[i].pos < slc[j].pos })
   513  		}
   514  	}
   515  	app.routesRefreshed = false
   516  
   517  	return app
   518  }