github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/utils.go (about) 1 package znet 2 3 import ( 4 "context" 5 "errors" 6 "html/template" 7 "net/http" 8 "path" 9 "regexp" 10 "strconv" 11 "strings" 12 "time" 13 14 "github.com/sohaha/zlsgo/zcache" 15 "github.com/sohaha/zlsgo/zdi" 16 "github.com/sohaha/zlsgo/zfile" 17 "github.com/sohaha/zlsgo/zstring" 18 "github.com/sohaha/zlsgo/zutil" 19 ) 20 21 type utils struct { 22 ContextKey contextKeyType 23 } 24 25 var Utils = utils{ 26 ContextKey: contextKeyType{}, 27 } 28 29 const ( 30 defaultPattern = `[^\/]+` 31 idPattern = `[\d]+` 32 idKey = `id` 33 allPattern = `.*` 34 allKey = `*` 35 ) 36 37 var matchCache = zcache.NewFast(func(o *zcache.Options) { 38 o.LRU2Cap = 100 39 }) 40 41 // URLMatchAndParse checks if the request matches the route path and returns a map of the parsed 42 func (_ utils) URLMatchAndParse(requestURL string, path string) (matchParams map[string]string, ok bool) { 43 var ( 44 pattern string 45 matchName []string 46 ) 47 matchParams, ok = make(map[string]string), true 48 if v, ok := matchCache.Get(path); ok { 49 m := v.([]string) 50 pattern = m[0] 51 matchName = m[1:] 52 } else { 53 res := strings.Split(path, "/") 54 pattern, matchName = parsePattern(res, "/") 55 matchCache.Set(path, append([]string{pattern}, matchName...)) 56 } 57 58 if pattern == "" { 59 return nil, false 60 } 61 62 rr, err := zstring.RegexExtract(pattern, requestURL) 63 if err != nil || len(rr) == 0 { 64 return nil, false 65 } 66 67 if rr[0] == requestURL { 68 rr = rr[1:] 69 if len(matchName) != 0 { 70 for k, v := range rr { 71 if key := matchName[k]; key != "" { 72 matchParams[key] = v 73 } 74 } 75 } 76 return 77 } 78 79 return nil, false 80 } 81 82 func parsePattern(res []string, prefix string) (string, []string) { 83 var ( 84 matchName []string 85 pattern string 86 ) 87 l := len(res) 88 for i := 0; i < l; i++ { 89 str := res[i] 90 if str == "" { 91 continue 92 } 93 if strings.HasSuffix(str, "\\") && i < l-1 { 94 res[i+1] = str[:len(str)-1] + "/" + res[i+1] 95 continue 96 } 97 pattern = pattern + prefix 98 l := len(str) - 1 99 i := strings.IndexRune(str, ')') 100 i2 := strings.IndexRune(str, '(') 101 firstChar := str[0] 102 // TODO Need to optimize 103 if i2 != -1 && i != -1 { 104 r, err := regexp.Compile(str) 105 if err != nil { 106 return "", nil 107 } 108 names := r.SubexpNames() 109 matchName = append(matchName, names[1:]...) 110 pattern = pattern + str 111 } else if firstChar == ':' { 112 matchStr := str 113 res := strings.Split(matchStr, ":") 114 key := res[1] 115 if key == "full" { 116 key = allKey 117 } 118 matchName = append(matchName, key) 119 if key == idKey { 120 pattern = pattern + "(" + idPattern + ")" 121 } else if key == allKey { 122 pattern = pattern + "(" + allPattern + ")" 123 } else { 124 pattern = pattern + "(" + defaultPattern + ")" 125 } 126 } else if firstChar == '*' { 127 pattern = pattern + "(" + allPattern + ")" 128 matchName = append(matchName, allKey) 129 } else { 130 i := strings.IndexRune(str, '}') 131 i2 := strings.IndexRune(str, '{') 132 if i2 != -1 && i != -1 { 133 if i == l && i2 == 0 { 134 matchStr := str[1:l] 135 res := strings.Split(matchStr, ":") 136 matchName = append(matchName, res[0]) 137 pattern = pattern + "(" + res[1] + ")" 138 } else { 139 if i2 != 0 { 140 p, m := parsePattern([]string{str[:i2]}, "") 141 if p != "" { 142 pattern = pattern + p 143 matchName = append(matchName, m...) 144 } 145 str = str[i2:] 146 } 147 if i >= 0 { 148 ni := i - i2 149 if ni < 0 { 150 return "", nil 151 } 152 matchStr := str[1:ni] 153 res := strings.Split(matchStr, ":") 154 matchName = append(matchName, res[0]) 155 pattern = pattern + "(" + res[1] + ")" 156 p, m := parsePattern([]string{str[ni+1:]}, "") 157 if p != "" { 158 pattern = pattern + p 159 matchName = append(matchName, m...) 160 } 161 } else { 162 pattern = pattern + str 163 } 164 } 165 } else { 166 pattern = pattern + str 167 } 168 } 169 } 170 171 return pattern, matchName 172 } 173 174 func getAddr(addr string) string { 175 var port int 176 if strings.Contains(addr, ":") { 177 port, _ = strconv.Atoi(strings.Split(addr, ":")[1]) 178 } else { 179 port, _ = strconv.Atoi(addr) 180 addr = ":" + addr 181 } 182 if port != 0 { 183 return addr 184 } 185 port, _ = Port(port, true) 186 return ":" + strconv.Itoa(port) 187 } 188 189 func getHostname(addr string, isTls bool) string { 190 hostname := "http://" 191 if isTls { 192 hostname = "https://" 193 } 194 return hostname + resolveHostname(addr) 195 } 196 197 func (u utils) TreeFind(t *Tree, path string) (handlerFn, []handlerFn, bool) { 198 nodes := t.Find(path, false) 199 for i := range nodes { 200 node := nodes[i] 201 if node.handle != nil { 202 if node.path == path { 203 return node.handle, node.middleware, true 204 } 205 } 206 } 207 208 if len(nodes) == 0 { 209 res := strings.Split(path, "/") 210 p := "" 211 if len(res) == 1 { 212 p = res[0] 213 } else { 214 p = res[1] 215 } 216 nodes := t.Find(p, true) 217 for _, node := range nodes { 218 if handler := node.handle; handler != nil && node.path != path { 219 if matchParamsMap, ok := u.URLMatchAndParse(path, node.path); ok { 220 return func(c *Context) error { 221 req := c.Request 222 ctx := context.WithValue(req.Context(), u.ContextKey, matchParamsMap) 223 c.Request = req.WithContext(ctx) 224 return node.Handle()(c) 225 }, node.middleware, true 226 } 227 } 228 } 229 } 230 return nil, nil, false 231 } 232 233 func (_ utils) CompletionPath(p, prefix string) string { 234 suffix := strings.HasSuffix(p, "/") 235 p = strings.TrimLeft(p, "/") 236 prefix = strings.TrimRight(prefix, "/") 237 path := zstring.TrimSpace(path.Join("/", prefix, p)) 238 239 if path == "" { 240 path = "/" 241 } else if suffix && path != "/" { 242 path = path + "/" 243 } 244 245 return path 246 } 247 248 // func (utils) IsAbort(c *Context) bool { 249 // return c.stopHandle.Load() 250 // } 251 252 // AppendHandler append handler to context, Use caution 253 func (utils) AppendHandler(c *Context, handlers ...Handler) { 254 hl := len(handlers) 255 if hl == 0 { 256 return 257 } 258 259 for i := range handlers { 260 c.middleware = append(c.middleware, Utils.ParseHandlerFunc(handlers[i])) 261 } 262 } 263 264 func resolveAddr(addrString string, tlsConfig ...TlsCfg) addrSt { 265 cfg := addrSt{ 266 addr: addrString, 267 } 268 if len(tlsConfig) > 0 { 269 cfg.Cert = tlsConfig[0].Cert 270 cfg.HTTPAddr = tlsConfig[0].HTTPAddr 271 cfg.HTTPProcessing = tlsConfig[0].HTTPProcessing 272 cfg.Key = tlsConfig[0].Key 273 cfg.Config = tlsConfig[0].Config 274 } 275 return cfg 276 } 277 278 func resolveHostname(addrString string) string { 279 if strings.Index(addrString, ":") == 0 { 280 return "127.0.0.1" + addrString 281 } 282 return addrString 283 } 284 285 func templateParse(templateFile []string, funcMap template.FuncMap) (t *template.Template, err error) { 286 if len(templateFile) == 0 { 287 return nil, errors.New("template file cannot be empty") 288 } 289 file := templateFile[0] 290 if len(file) <= 255 && zfile.FileExist(file) { 291 for i := range templateFile { 292 templateFile[i] = zfile.RealPath(templateFile[i]) 293 } 294 t, err = template.ParseFiles(templateFile...) 295 if err == nil && funcMap != nil { 296 t.Funcs(funcMap) 297 } 298 } else { 299 t = template.New("") 300 if funcMap != nil { 301 t.Funcs(funcMap) 302 } 303 t, err = t.Parse(file) 304 } 305 return 306 } 307 308 type tlsRedirectHandler struct { 309 Domain string 310 } 311 312 func (t *tlsRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 313 http.Redirect(w, r, t.Domain+r.URL.String(), http.StatusMovedPermanently) 314 } 315 316 func (e *Engine) NewContext(w http.ResponseWriter, req *http.Request) *Context { 317 return &Context{ 318 Writer: w, 319 Request: req, 320 Engine: e, 321 Log: e.Log, 322 Cache: Cache, 323 startTime: time.Time{}, 324 header: map[string][]string{}, 325 customizeData: map[string]interface{}{}, 326 stopHandle: zutil.NewBool(false), 327 done: zutil.NewBool(false), 328 prevData: &PrevData{ 329 Code: zutil.NewInt32(0), 330 Type: ContentTypePlain, 331 }, 332 } 333 } 334 335 func (c *Context) clone(w http.ResponseWriter, r *http.Request) { 336 c.Request = r 337 c.Writer = w 338 c.injector = zdi.New(c.Engine.injector) 339 c.injector.Maps(c) 340 c.startTime = time.Now() 341 c.renderError = defErrorHandler() 342 c.stopHandle.Store(false) 343 c.done.Store(false) 344 } 345 346 func (e *Engine) acquireContext() *Context { 347 return e.pool.Get().(*Context) 348 } 349 350 func (e *Engine) releaseContext(c *Context) { 351 c.prevData.Code.Store(0) 352 c.mu.Lock() 353 c.middleware = c.middleware[0:0] 354 c.customizeData = map[string]interface{}{} 355 c.header = map[string][]string{} 356 c.render = nil 357 c.renderError = nil 358 c.cacheJSON = nil 359 c.cacheQuery = nil 360 c.cacheForm = nil 361 c.injector = nil 362 c.rawData = nil 363 c.prevData.Content = c.prevData.Content[0:0] 364 c.prevData.Type = ContentTypePlain 365 c.mu.Unlock() 366 e.pool.Put(c) 367 } 368 369 func (s *serverMap) GetAddr() string { 370 return s.srv.Addr 371 }