github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/utils.go (about)

     1  package znet
     2  
import (
	"context"
	"errors"
	"html/template"
	"net"
	"net/http"
	"path"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/sohaha/zlsgo/zcache"
	"github.com/sohaha/zlsgo/zdi"
	"github.com/sohaha/zlsgo/zfile"
	"github.com/sohaha/zlsgo/zstring"
	"github.com/sohaha/zlsgo/zutil"
)
    20  
    21  type utils struct {
    22  	ContextKey contextKeyType
    23  }
    24  
    25  var Utils = utils{
    26  	ContextKey: contextKeyType{},
    27  }
    28  
    29  const (
    30  	defaultPattern = `[^\/]+`
    31  	idPattern      = `[\d]+`
    32  	idKey          = `id`
    33  	allPattern     = `.*`
    34  	allKey         = `*`
    35  )
    36  
    37  var matchCache = zcache.NewFast(func(o *zcache.Options) {
    38  	o.LRU2Cap = 100
    39  })
    40  
    41  // URLMatchAndParse checks if the request matches the route path and returns a map of the parsed
    42  func (_ utils) URLMatchAndParse(requestURL string, path string) (matchParams map[string]string, ok bool) {
    43  	var (
    44  		pattern   string
    45  		matchName []string
    46  	)
    47  	matchParams, ok = make(map[string]string), true
    48  	if v, ok := matchCache.Get(path); ok {
    49  		m := v.([]string)
    50  		pattern = m[0]
    51  		matchName = m[1:]
    52  	} else {
    53  		res := strings.Split(path, "/")
    54  		pattern, matchName = parsePattern(res, "/")
    55  		matchCache.Set(path, append([]string{pattern}, matchName...))
    56  	}
    57  
    58  	if pattern == "" {
    59  		return nil, false
    60  	}
    61  
    62  	rr, err := zstring.RegexExtract(pattern, requestURL)
    63  	if err != nil || len(rr) == 0 {
    64  		return nil, false
    65  	}
    66  
    67  	if rr[0] == requestURL {
    68  		rr = rr[1:]
    69  		if len(matchName) != 0 {
    70  			for k, v := range rr {
    71  				if key := matchName[k]; key != "" {
    72  					matchParams[key] = v
    73  				}
    74  			}
    75  		}
    76  		return
    77  	}
    78  
    79  	return nil, false
    80  }
    81  
    82  func parsePattern(res []string, prefix string) (string, []string) {
    83  	var (
    84  		matchName []string
    85  		pattern   string
    86  	)
    87  	l := len(res)
    88  	for i := 0; i < l; i++ {
    89  		str := res[i]
    90  		if str == "" {
    91  			continue
    92  		}
    93  		if strings.HasSuffix(str, "\\") && i < l-1 {
    94  			res[i+1] = str[:len(str)-1] + "/" + res[i+1]
    95  			continue
    96  		}
    97  		pattern = pattern + prefix
    98  		l := len(str) - 1
    99  		i := strings.IndexRune(str, ')')
   100  		i2 := strings.IndexRune(str, '(')
   101  		firstChar := str[0]
   102  		// TODO Need to optimize
   103  		if i2 != -1 && i != -1 {
   104  			r, err := regexp.Compile(str)
   105  			if err != nil {
   106  				return "", nil
   107  			}
   108  			names := r.SubexpNames()
   109  			matchName = append(matchName, names[1:]...)
   110  			pattern = pattern + str
   111  		} else if firstChar == ':' {
   112  			matchStr := str
   113  			res := strings.Split(matchStr, ":")
   114  			key := res[1]
   115  			if key == "full" {
   116  				key = allKey
   117  			}
   118  			matchName = append(matchName, key)
   119  			if key == idKey {
   120  				pattern = pattern + "(" + idPattern + ")"
   121  			} else if key == allKey {
   122  				pattern = pattern + "(" + allPattern + ")"
   123  			} else {
   124  				pattern = pattern + "(" + defaultPattern + ")"
   125  			}
   126  		} else if firstChar == '*' {
   127  			pattern = pattern + "(" + allPattern + ")"
   128  			matchName = append(matchName, allKey)
   129  		} else {
   130  			i := strings.IndexRune(str, '}')
   131  			i2 := strings.IndexRune(str, '{')
   132  			if i2 != -1 && i != -1 {
   133  				if i == l && i2 == 0 {
   134  					matchStr := str[1:l]
   135  					res := strings.Split(matchStr, ":")
   136  					matchName = append(matchName, res[0])
   137  					pattern = pattern + "(" + res[1] + ")"
   138  				} else {
   139  					if i2 != 0 {
   140  						p, m := parsePattern([]string{str[:i2]}, "")
   141  						if p != "" {
   142  							pattern = pattern + p
   143  							matchName = append(matchName, m...)
   144  						}
   145  						str = str[i2:]
   146  					}
   147  					if i >= 0 {
   148  						ni := i - i2
   149  						if ni < 0 {
   150  							return "", nil
   151  						}
   152  						matchStr := str[1:ni]
   153  						res := strings.Split(matchStr, ":")
   154  						matchName = append(matchName, res[0])
   155  						pattern = pattern + "(" + res[1] + ")"
   156  						p, m := parsePattern([]string{str[ni+1:]}, "")
   157  						if p != "" {
   158  							pattern = pattern + p
   159  							matchName = append(matchName, m...)
   160  						}
   161  					} else {
   162  						pattern = pattern + str
   163  					}
   164  				}
   165  			} else {
   166  				pattern = pattern + str
   167  			}
   168  		}
   169  	}
   170  
   171  	return pattern, matchName
   172  }
   173  
   174  func getAddr(addr string) string {
   175  	var port int
   176  	if strings.Contains(addr, ":") {
   177  		port, _ = strconv.Atoi(strings.Split(addr, ":")[1])
   178  	} else {
   179  		port, _ = strconv.Atoi(addr)
   180  		addr = ":" + addr
   181  	}
   182  	if port != 0 {
   183  		return addr
   184  	}
   185  	port, _ = Port(port, true)
   186  	return ":" + strconv.Itoa(port)
   187  }
   188  
   189  func getHostname(addr string, isTls bool) string {
   190  	hostname := "http://"
   191  	if isTls {
   192  		hostname = "https://"
   193  	}
   194  	return hostname + resolveHostname(addr)
   195  }
   196  
   197  func (u utils) TreeFind(t *Tree, path string) (handlerFn, []handlerFn, bool) {
   198  	nodes := t.Find(path, false)
   199  	for i := range nodes {
   200  		node := nodes[i]
   201  		if node.handle != nil {
   202  			if node.path == path {
   203  				return node.handle, node.middleware, true
   204  			}
   205  		}
   206  	}
   207  
   208  	if len(nodes) == 0 {
   209  		res := strings.Split(path, "/")
   210  		p := ""
   211  		if len(res) == 1 {
   212  			p = res[0]
   213  		} else {
   214  			p = res[1]
   215  		}
   216  		nodes := t.Find(p, true)
   217  		for _, node := range nodes {
   218  			if handler := node.handle; handler != nil && node.path != path {
   219  				if matchParamsMap, ok := u.URLMatchAndParse(path, node.path); ok {
   220  					return func(c *Context) error {
   221  						req := c.Request
   222  						ctx := context.WithValue(req.Context(), u.ContextKey, matchParamsMap)
   223  						c.Request = req.WithContext(ctx)
   224  						return node.Handle()(c)
   225  					}, node.middleware, true
   226  				}
   227  			}
   228  		}
   229  	}
   230  	return nil, nil, false
   231  }
   232  
   233  func (_ utils) CompletionPath(p, prefix string) string {
   234  	suffix := strings.HasSuffix(p, "/")
   235  	p = strings.TrimLeft(p, "/")
   236  	prefix = strings.TrimRight(prefix, "/")
   237  	path := zstring.TrimSpace(path.Join("/", prefix, p))
   238  
   239  	if path == "" {
   240  		path = "/"
   241  	} else if suffix && path != "/" {
   242  		path = path + "/"
   243  	}
   244  
   245  	return path
   246  }
   247  
   248  // func (utils) IsAbort(c *Context) bool {
   249  // 	return c.stopHandle.Load()
   250  // }
   251  
   252  // AppendHandler append handler to context, Use caution
   253  func (utils) AppendHandler(c *Context, handlers ...Handler) {
   254  	hl := len(handlers)
   255  	if hl == 0 {
   256  		return
   257  	}
   258  
   259  	for i := range handlers {
   260  		c.middleware = append(c.middleware, Utils.ParseHandlerFunc(handlers[i]))
   261  	}
   262  }
   263  
   264  func resolveAddr(addrString string, tlsConfig ...TlsCfg) addrSt {
   265  	cfg := addrSt{
   266  		addr: addrString,
   267  	}
   268  	if len(tlsConfig) > 0 {
   269  		cfg.Cert = tlsConfig[0].Cert
   270  		cfg.HTTPAddr = tlsConfig[0].HTTPAddr
   271  		cfg.HTTPProcessing = tlsConfig[0].HTTPProcessing
   272  		cfg.Key = tlsConfig[0].Key
   273  		cfg.Config = tlsConfig[0].Config
   274  	}
   275  	return cfg
   276  }
   277  
   278  func resolveHostname(addrString string) string {
   279  	if strings.Index(addrString, ":") == 0 {
   280  		return "127.0.0.1" + addrString
   281  	}
   282  	return addrString
   283  }
   284  
   285  func templateParse(templateFile []string, funcMap template.FuncMap) (t *template.Template, err error) {
   286  	if len(templateFile) == 0 {
   287  		return nil, errors.New("template file cannot be empty")
   288  	}
   289  	file := templateFile[0]
   290  	if len(file) <= 255 && zfile.FileExist(file) {
   291  		for i := range templateFile {
   292  			templateFile[i] = zfile.RealPath(templateFile[i])
   293  		}
   294  		t, err = template.ParseFiles(templateFile...)
   295  		if err == nil && funcMap != nil {
   296  			t.Funcs(funcMap)
   297  		}
   298  	} else {
   299  		t = template.New("")
   300  		if funcMap != nil {
   301  			t.Funcs(funcMap)
   302  		}
   303  		t, err = t.Parse(file)
   304  	}
   305  	return
   306  }
   307  
   308  type tlsRedirectHandler struct {
   309  	Domain string
   310  }
   311  
   312  func (t *tlsRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   313  	http.Redirect(w, r, t.Domain+r.URL.String(), http.StatusMovedPermanently)
   314  }
   315  
   316  func (e *Engine) NewContext(w http.ResponseWriter, req *http.Request) *Context {
   317  	return &Context{
   318  		Writer:        w,
   319  		Request:       req,
   320  		Engine:        e,
   321  		Log:           e.Log,
   322  		Cache:         Cache,
   323  		startTime:     time.Time{},
   324  		header:        map[string][]string{},
   325  		customizeData: map[string]interface{}{},
   326  		stopHandle:    zutil.NewBool(false),
   327  		done:          zutil.NewBool(false),
   328  		prevData: &PrevData{
   329  			Code: zutil.NewInt32(0),
   330  			Type: ContentTypePlain,
   331  		},
   332  	}
   333  }
   334  
   335  func (c *Context) clone(w http.ResponseWriter, r *http.Request) {
   336  	c.Request = r
   337  	c.Writer = w
   338  	c.injector = zdi.New(c.Engine.injector)
   339  	c.injector.Maps(c)
   340  	c.startTime = time.Now()
   341  	c.renderError = defErrorHandler()
   342  	c.stopHandle.Store(false)
   343  	c.done.Store(false)
   344  }
   345  
   346  func (e *Engine) acquireContext() *Context {
   347  	return e.pool.Get().(*Context)
   348  }
   349  
   350  func (e *Engine) releaseContext(c *Context) {
   351  	c.prevData.Code.Store(0)
   352  	c.mu.Lock()
   353  	c.middleware = c.middleware[0:0]
   354  	c.customizeData = map[string]interface{}{}
   355  	c.header = map[string][]string{}
   356  	c.render = nil
   357  	c.renderError = nil
   358  	c.cacheJSON = nil
   359  	c.cacheQuery = nil
   360  	c.cacheForm = nil
   361  	c.injector = nil
   362  	c.rawData = nil
   363  	c.prevData.Content = c.prevData.Content[0:0]
   364  	c.prevData.Type = ContentTypePlain
   365  	c.mu.Unlock()
   366  	e.pool.Put(c)
   367  }
   368  
   369  func (s *serverMap) GetAddr() string {
   370  	return s.srv.Addr
   371  }