github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/router.go (about)

     1  package znet
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"path"
     8  	"regexp"
     9  	"strings"
    10  
    11  	"github.com/sohaha/zlsgo/zfile"
    12  )
    13  
var (
	// ErrGenerateParameters is returned when generating a route with wrong parameters.
	ErrGenerateParameters = errors.New("params contains wrong parameters")

	// ErrNotFoundRoute is returned when generating a route that can not find route in tree.
	ErrNotFoundRoute = errors.New("cannot find route in tree")

	// ErrNotFoundMethod is returned when generating a route that can not find method in tree.
	ErrNotFoundMethod = errors.New("cannot find method in tree")

	// ErrPatternGrammar is returned when generating a route whose pattern has a grammar error.
	ErrPatternGrammar = errors.New("pattern grammar error")

	// methods is the set of HTTP methods a route may be registered for.
	methods = map[string]struct{}{
		http.MethodGet:     {},
		http.MethodPost:    {},
		http.MethodPut:     {},
		http.MethodDelete:  {},
		http.MethodPatch:   {},
		http.MethodHead:    {},
		http.MethodOptions: {},
		http.MethodConnect: {},
		http.MethodTrace:   {},
	}
)
    39  
type (
	// contextKeyType is the private value structure for each request.
	contextKeyType struct{}
)
    44  
    45  func temporarilyTurnOffTheLog(e *Engine, msg string) func() {
    46  	mode := e.webMode
    47  	e.webMode = prodCode
    48  	return func() {
    49  		e.webMode = mode
    50  		if e.IsDebug() {
    51  			e.Log.Debug(msg)
    52  		}
    53  	}
    54  }
    55  
    56  func (e *Engine) StaticFS(relativePath string, fs http.FileSystem, moreHandler ...Handler) {
    57  	var urlPattern string
    58  
    59  	ap := Utils.CompletionPath(relativePath, e.router.prefix)
    60  	f := fmt.Sprintf("%%s %%-40s -> %s/", zfile.SafePath(fmt.Sprintf("%s", fs)))
    61  	if e.webMode == testCode {
    62  		f = "%s %-40s"
    63  	}
    64  	log := temporarilyTurnOffTheLog(e, routeLog(e.Log, f, "FILE", ap))
    65  	fileServer := http.StripPrefix(ap, http.FileServer(fs))
    66  	handler := func(c *Context) {
    67  		for key, value := range c.header {
    68  			for i := range value {
    69  				header := value[i]
    70  				if i == 0 {
    71  					c.Writer.Header().Set(key, header)
    72  				} else {
    73  					c.Writer.Header().Add(key, header)
    74  				}
    75  			}
    76  		}
    77  		fileServer.ServeHTTP(c.Writer, c.Request)
    78  	}
    79  	if strings.HasSuffix(relativePath, "/") {
    80  		urlPattern = path.Join(relativePath, "*")
    81  		e.GET(relativePath, handler, moreHandler...)
    82  	} else {
    83  		urlPattern = path.Join(relativePath, "/*")
    84  		e.GET(relativePath, func(c *Context) {
    85  			c.Redirect(relativePath + "/")
    86  		}, moreHandler...)
    87  		e.GET(relativePath+"/", handler, moreHandler...)
    88  	}
    89  	e.GET(urlPattern, handler, moreHandler...)
    90  	e.HEAD(urlPattern, handler, moreHandler...)
    91  	e.OPTIONS(urlPattern, handler, moreHandler...)
    92  	log()
    93  }
    94  
    95  func (e *Engine) Static(relativePath, root string, moreHandler ...Handler) {
    96  	e.StaticFS(relativePath, http.Dir(root), moreHandler...)
    97  }
    98  
    99  func (e *Engine) StaticFile(relativePath, filepath string) {
   100  	handler := func(c *Context) {
   101  		c.File(filepath)
   102  	}
   103  
   104  	tip := routeLog(e.Log, "%s %-40s -> "+zfile.SafePath(filepath)+"/", "FILE", relativePath)
   105  	if e.webMode == testCode {
   106  		tip = routeLog(e.Log, "%s %-40s", "FILE", relativePath)
   107  	}
   108  	log := temporarilyTurnOffTheLog(e, tip)
   109  	e.GET(relativePath, handler)
   110  	e.HEAD(relativePath, handler)
   111  	log()
   112  }
   113  
   114  func (e *Engine) Any(path string, action Handler, moreHandler ...Handler) *Engine {
   115  	middleware, firstMiddleware := handlerFuncs(moreHandler)
   116  	_, l, ok := e.handleAny(path, Utils.ParseHandlerFunc(action), middleware, firstMiddleware)
   117  
   118  	if ok {
   119  		routeAddLog(e, "ANY", Utils.CompletionPath(path, e.router.prefix), action, l)
   120  	}
   121  
   122  	return e
   123  }
   124  
   125  func (e *Engine) Customize(method, path string, action Handler, moreHandler ...Handler) *Engine {
   126  	method = strings.ToUpper(method)
   127  	return e.Handle(method, path, action, moreHandler...)
   128  }
   129  
   130  func (e *Engine) GET(path string, action Handler, moreHandler ...Handler) *Engine {
   131  	return e.Handle(http.MethodGet, path, action, moreHandler...)
   132  }
   133  
   134  func (e *Engine) POST(path string, action Handler, moreHandler ...Handler) *Engine {
   135  	return e.Handle(http.MethodPost, path, action, moreHandler...)
   136  }
   137  
   138  func (e *Engine) DELETE(path string, action Handler, moreHandler ...Handler) *Engine {
   139  	return e.Handle(http.MethodDelete, path, action, moreHandler...)
   140  }
   141  
   142  func (e *Engine) PUT(path string, action Handler, moreHandler ...Handler) *Engine {
   143  	return e.Handle(http.MethodPut, path, action, moreHandler...)
   144  }
   145  
   146  func (e *Engine) PATCH(path string, action Handler, moreHandler ...Handler) *Engine {
   147  	return e.Handle(http.MethodPatch, path, action, moreHandler...)
   148  }
   149  
   150  func (e *Engine) HEAD(path string, action Handler, moreHandler ...Handler) *Engine {
   151  	return e.Handle(http.MethodHead, path, action, moreHandler...)
   152  }
   153  
   154  func (e *Engine) OPTIONS(path string, action Handler, moreHandler ...Handler) *Engine {
   155  	return e.Handle(http.MethodOptions, path, action, moreHandler...)
   156  }
   157  
   158  func (e *Engine) CONNECT(path string, action Handler, moreHandler ...Handler) *Engine {
   159  	return e.Handle(http.MethodConnect, path, action, moreHandler...)
   160  }
   161  
   162  func (e *Engine) TRACE(path string, action Handler, moreHandler ...Handler) *Engine {
   163  	return e.Handle(http.MethodTrace, path, action, moreHandler...)
   164  }
   165  
   166  func (e *Engine) GETAndName(path string, action Handler, routeName string) *Engine {
   167  	e.router.parameters.routeName = routeName
   168  	defer func() { e.router.parameters.routeName = "" }()
   169  	return e.GET(path, action)
   170  }
   171  
   172  func (e *Engine) POSTAndName(path string, action Handler, routeName string) *Engine {
   173  	e.router.parameters.routeName = routeName
   174  	defer func() { e.router.parameters.routeName = "" }()
   175  	return e.POST(path, action)
   176  }
   177  
   178  func (e *Engine) DELETEAndName(path string, action Handler, routeName string) *Engine {
   179  	e.router.parameters.routeName = routeName
   180  	defer func() { e.router.parameters.routeName = "" }()
   181  	return e.DELETE(path, action)
   182  }
   183  
   184  func (e *Engine) PUTAndName(path string, action Handler, routeName string) *Engine {
   185  	e.router.parameters.routeName = routeName
   186  	defer func() { e.router.parameters.routeName = "" }()
   187  	return e.PUT(path, action)
   188  }
   189  
   190  func (e *Engine) PATCHAndName(path string, action Handler, routeName string) *Engine {
   191  	e.router.parameters.routeName = routeName
   192  	defer func() { e.router.parameters.routeName = "" }()
   193  	return e.PATCH(path, action)
   194  }
   195  
   196  func (e *Engine) HEADAndName(path string, action Handler, routeName string) *Engine {
   197  	e.router.parameters.routeName = routeName
   198  	defer func() { e.router.parameters.routeName = "" }()
   199  	return e.HEAD(path, action)
   200  }
   201  
   202  func (e *Engine) OPTIONSAndName(path string, action Handler, routeName string) *Engine {
   203  	e.router.parameters.routeName = routeName
   204  	defer func() { e.router.parameters.routeName = "" }()
   205  	return e.OPTIONS(path, action)
   206  }
   207  
   208  func (e *Engine) CONNECTAndName(path string, action Handler, routeName string) *Engine {
   209  	e.router.parameters.routeName = routeName
   210  	defer func() { e.router.parameters.routeName = "" }()
   211  	return e.CONNECT(path, action)
   212  }
   213  
   214  func (e *Engine) TRACEAndName(path string, action Handler, routeName string) *Engine {
   215  	e.router.parameters.routeName = routeName
   216  	defer func() { e.router.parameters.routeName = "" }()
   217  	return e.TRACE(path, action)
   218  }
   219  
   220  func (e *Engine) Group(prefix string, groupHandle ...func(e *Engine)) (engine *Engine) {
   221  	if prefix == "" {
   222  		return e
   223  	}
   224  	rprefix := e.router.prefix
   225  	if rprefix != "" {
   226  		prefix = Utils.CompletionPath(prefix, rprefix)
   227  	}
   228  	middleware := make([]handlerFn, len(e.router.middleware))
   229  	copy(middleware, e.router.middleware)
   230  	route := &router{
   231  		prefix:     prefix,
   232  		trees:      e.router.trees,
   233  		middleware: middleware,
   234  		notFound:   e.router.notFound,
   235  	}
   236  	engine = &Engine{
   237  		router:              route,
   238  		webMode:             e.webMode,
   239  		webModeName:         e.webModeName,
   240  		MaxMultipartMemory:  e.MaxMultipartMemory,
   241  		customMethodType:    e.customMethodType,
   242  		Log:                 e.Log,
   243  		Cache:               e.Cache,
   244  		BindStructCase:      e.BindStructCase,
   245  		BindStructDelimiter: e.BindStructDelimiter,
   246  		BindStructSuffix:    e.BindStructSuffix,
   247  		templateFuncMap:     e.templateFuncMap,
   248  		template:            e.template,
   249  		injector:            e.injector,
   250  	}
   251  	engine.pool.New = func() interface{} {
   252  		return e.NewContext(nil, nil)
   253  	}
   254  	if len(groupHandle) > 0 {
   255  		groupHandle[0](engine)
   256  	}
   257  	return
   258  }
   259  
   260  func (e *Engine) GenerateURL(method string, routeName string, params map[string]string) (string, error) {
   261  	tree, ok := e.router.trees[method]
   262  	if !ok {
   263  		return "", ErrNotFoundMethod
   264  	}
   265  
   266  	route, ok := tree.routes[routeName]
   267  	if !ok {
   268  		return "", ErrNotFoundRoute
   269  	}
   270  
   271  	ps := strings.Split(route.path, "/")
   272  	l := len(ps)
   273  	segments := make([]string, 0, l)
   274  	for i := 0; i < l; i++ {
   275  		segment := ps[i]
   276  		if segment != "" {
   277  			if string(segment[0]) == ":" {
   278  				key := params[segment[1:]]
   279  				re := regexp.MustCompile(defaultPattern)
   280  				if one := re.Find([]byte(key)); one == nil {
   281  					return "", ErrGenerateParameters
   282  				}
   283  				segments = append(segments, key)
   284  				continue
   285  			}
   286  
   287  			if string(segment[0]) == "{" {
   288  				segmentLen := len(segment)
   289  				if string(segment[segmentLen-1]) == "}" {
   290  					splitRes := strings.Split(segment[1:segmentLen-1], ":")
   291  					re := regexp.MustCompile(splitRes[1])
   292  					key := params[splitRes[0]]
   293  					if one := re.Find([]byte(key)); one == nil {
   294  						return "", ErrGenerateParameters
   295  					}
   296  					segments = append(segments, key)
   297  					continue
   298  				}
   299  
   300  				return "", ErrPatternGrammar
   301  			}
   302  			if string(segment[len(segment)-1]) == "}" && string(segment[0]) != "{" {
   303  				return "", ErrPatternGrammar
   304  			}
   305  		}
   306  
   307  		segments = append(segments, segment)
   308  
   309  		continue
   310  	}
   311  
   312  	return strings.Join(segments, "/"), nil
   313  }
   314  
   315  func (e *Engine) PreHandler(preHandler Handler) {
   316  	e.preHandler = preHandler
   317  }
   318  
   319  func (e *Engine) NotFoundHandler(handler Handler) {
   320  	e.router.notFound = Utils.ParseHandlerFunc(handler)
   321  }
   322  
   323  // Deprecated: please use znet.Recovery(func(c *Context, err error) {})
   324  // PanicHandler is used for handling panics
   325  func (e *Engine) PanicHandler(handler ErrHandlerFunc) {
   326  	e.Use(Recovery(handler))
   327  }
   328  
   329  // GetTrees Load Trees
   330  func (e *Engine) GetTrees() map[string]*Tree {
   331  	return e.router.trees
   332  }
   333  
   334  // Handle registers new request handlerFn
   335  func (e *Engine) Handle(method string, path string, action Handler, moreHandler ...Handler) *Engine {
   336  	handler, firsthandle := handlerFuncs(moreHandler)
   337  	p, l, ok := e.addHandle(method, path, Utils.ParseHandlerFunc(action), firsthandle, handler)
   338  	if !ok {
   339  		return e
   340  	}
   341  
   342  	routeAddLog(e, method, p, action, l)
   343  	return e
   344  }
   345  
   346  func (e *Engine) addHandle(method string, path string, handle handlerFn, beforehandle []handlerFn, moreHandler []handlerFn) (string, int, bool) {
   347  	if _, ok := methods[method]; !ok {
   348  		e.Log.Fatal(method + " is invalid method")
   349  	}
   350  
   351  	tree, ok := e.router.trees[method]
   352  	if !ok {
   353  		tree = NewTree()
   354  		e.router.trees[method] = tree
   355  	}
   356  
   357  	path = Utils.CompletionPath(path, e.router.prefix)
   358  	if routeName := e.router.parameters.routeName; routeName != "" {
   359  		tree.parameters.routeName = routeName
   360  	}
   361  
   362  	nodes := tree.Find(path, false)
   363  	if len(nodes) > 0 {
   364  		node := nodes[0]
   365  		if e.webMode != quietCode && node.path == path && node.handle != nil {
   366  			e.Log.Track("duplicate route definition: ["+method+"]"+path, 3, 1)
   367  			return "", 0, false
   368  		}
   369  	}
   370  
   371  	middleware := make([]handlerFn, len(e.router.middleware))
   372  	{
   373  		copy(middleware, e.router.middleware)
   374  		if len(moreHandler) > 0 {
   375  			middleware = append(middleware, moreHandler...)
   376  		}
   377  	}
   378  
   379  	if len(beforehandle) > 0 {
   380  		middleware = append(beforehandle, middleware...)
   381  	}
   382  
   383  	tree.Add(path, handle, middleware...)
   384  	tree.parameters.routeName = ""
   385  	return path, len(middleware) + 1, true
   386  }
   387  
   388  func (e *Engine) handleAny(path string, handle handlerFn, beforehandle []handlerFn, moreHandler []handlerFn) (p string, l int, ok bool) {
   389  	for key := range methods {
   390  		p, l, ok = e.addHandle(key, path, handle, beforehandle, moreHandler)
   391  		if !ok {
   392  			return p, l, false
   393  		}
   394  	}
   395  	return
   396  }
   397  
   398  func (e *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   399  	p := req.URL.Path
   400  	if !e.ShowFavicon && p == "/favicon.ico" {
   401  		return
   402  	}
   403  
   404  	c := e.acquireContext()
   405  	c.clone(w, req)
   406  	defer func() {
   407  		c.write()
   408  		e.releaseContext(c)
   409  	}()
   410  
   411  	if e.AllowQuerySemicolons {
   412  		allowQuerySemicolons(c.Request)
   413  	}
   414  
   415  	// custom method type
   416  	if req.Method == "POST" && e.customMethodType != "" {
   417  		if tmpType := c.GetHeader(e.customMethodType); tmpType != "" {
   418  			req.Method = strings.ToUpper(tmpType)
   419  		}
   420  	}
   421  	if e.preHandler != nil {
   422  		if preHandler, ok := e.preHandler.(func(*Context) bool); ok {
   423  			if preHandler(c) {
   424  				return
   425  			}
   426  		} else {
   427  			err := Utils.ParseHandlerFunc(e.preHandler)(c)
   428  			if err != nil {
   429  				c.renderError(c, err)
   430  				c.Abort()
   431  				return
   432  			}
   433  		}
   434  	}
   435  	if c.stopHandle.Load() {
   436  		return
   437  	}
   438  
   439  	if _, ok := e.router.trees[req.Method]; !ok {
   440  		e.handleNotFound(c)
   441  		return
   442  	}
   443  
   444  	if e.FindHandle(c, req, p, true) {
   445  		e.handleNotFound(c)
   446  	}
   447  }
   448  
   449  func (e *Engine) FindHandle(rw *Context, req *http.Request, requestURL string, applyMiddleware bool) (not bool) {
   450  	t, ok := e.router.trees[req.Method]
   451  	if !ok {
   452  		return true
   453  	}
   454  
   455  	handler, middleware, ok := Utils.TreeFind(t, requestURL)
   456  	if !ok {
   457  		return true
   458  	}
   459  
   460  	if applyMiddleware {
   461  		handleAction(rw, handler, middleware)
   462  	} else {
   463  		handleAction(rw, handler, []handlerFn{})
   464  	}
   465  	return false
   466  }
   467  
   468  func (e *Engine) Use(middleware ...Handler) {
   469  	if len(middleware) > 0 {
   470  		middleware, firstMiddleware := handlerFuncs(middleware)
   471  		e.router.middleware = append(firstMiddleware, e.router.middleware...)
   472  		e.router.middleware = append(e.router.middleware, middleware...)
   473  	}
   474  }
   475  
   476  func (e *Engine) handleNotFound(c *Context) {
   477  	middleware := e.router.middleware
   478  	c.prevData.Code.Store(http.StatusNotFound)
   479  
   480  	if e.router.notFound != nil {
   481  		handleAction(c, e.router.notFound, middleware)
   482  		return
   483  	}
   484  
   485  	handleAction(c, func(_ *Context) error {
   486  		c.Byte(404, []byte("404 page not found"))
   487  		return nil
   488  	}, middleware)
   489  }
   490  
   491  func (e *Engine) HandleNotFound(c *Context) {
   492  	e.handleNotFound(c)
   493  	c.stopHandle.Store(true)
   494  }
   495  
   496  func handleAction(c *Context, handler handlerFn, middleware []handlerFn) {
   497  	c.middleware = append(middleware, handler)
   498  	c.Next()
   499  }
   500  
   501  // Match checks if the request matches the route pattern
   502  func (e *Engine) Match(requestURL string, path string) bool {
   503  	_, ok := Utils.URLMatchAndParse(requestURL, path)
   504  	return ok
   505  }