github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/context.go (about)

     1  package znet
     2  
     3  import (
     4  	"net/http"
     5  	"net/textproto"
     6  	"net/url"
     7  	"strconv"
     8  	"strings"
     9  	"unicode"
    10  
    11  	"github.com/sohaha/zlsgo/zdi"
    12  )
    13  
    14  // Host Get the current Host
    15  func (c *Context) Host(full ...bool) string {
    16  	scheme := c.Request.Header.Get("X-Forwarded-Proto")
    17  	if scheme == "" {
    18  		scheme = "https"
    19  		if c.Request.TLS == nil {
    20  			scheme = "http"
    21  		}
    22  	}
    23  	host := c.Request.Host
    24  	if len(full) > 0 && full[0] {
    25  		host += c.Request.URL.String()
    26  	}
    27  	return scheme + "://" + host
    28  }
    29  
    30  // CompletionLink Complete the link and add the current domain name if it is not linked
    31  func (c *Context) CompletionLink(link string) string {
    32  	if strings.HasPrefix(link, "http://") || strings.HasPrefix(link, "https://") {
    33  		return link
    34  	}
    35  	finalLink := c.Host()
    36  	if !strings.HasPrefix(link, "/") {
    37  		finalLink = finalLink + "/"
    38  	}
    39  	finalLink = finalLink + link
    40  	return finalLink
    41  }
    42  
    43  // IsWebsocket Is Websocket
    44  func (c *Context) IsWebsocket() bool {
    45  	if strings.Contains(strings.ToLower(c.GetHeader("Connection")), "upgrade") &&
    46  		strings.ToLower(c.GetHeader("Upgrade")) == "websocket" {
    47  		return true
    48  	}
    49  	return false
    50  }
    51  
    52  // IsSSE Is SSE
    53  func (c *Context) IsSSE() bool {
    54  	return strings.ToLower(c.GetHeader("Accept")) == "text/event-stream"
    55  }
    56  
    57  // IsAjax IsAjax
    58  func (c *Context) IsAjax() bool {
    59  	return c.GetHeader("X-Requested-With") == "XMLHttpRequest"
    60  }
    61  
    62  // GetClientIP Client IP
    63  func (c *Context) GetClientIP() (IP string) {
    64  	IP = ClientPublicIP(c.Request)
    65  	if IP == "" {
    66  		IP = ClientIP(c.Request)
    67  	}
    68  	return
    69  }
    70  
    71  // GetHeader Get Header
    72  func (c *Context) GetHeader(key string) string {
    73  	return c.Request.Header.Get(key)
    74  }
    75  
    76  // SetHeader Set Header
    77  func (c *Context) SetHeader(key, value string) {
    78  	key = textproto.CanonicalMIMEHeaderKey(key)
    79  	c.mu.Lock()
    80  	if value == "" {
    81  		delete(c.header, key)
    82  	} else {
    83  		c.header[key] = append(c.header[key], value)
    84  	}
    85  	c.mu.Unlock()
    86  }
    87  
// write finalizes the response for this context. It does the following,
// in order:
//   - proceeds only on the first call (guarded by a CAS on c.done);
//   - runs any middleware still pending via Next;
//   - copies headers recorded with SetHeader onto c.Writer;
//   - writes the status code and the body taken from PrevContent.
//
// Status and body are skipped when there is no request or when the
// request's context has already ended. When the engine is in debug mode,
// requestLog is invoked after the status/body step.
func (c *Context) write() {
	// Only the caller that switches done from false to true continues;
	// any later call returns immediately.
	if !c.done.CAS(false, true) {
		return
	}

	// Drain the remaining middleware/handler chain before building the response.
	c.Next()

	data := c.PrevContent()
	// NOTE(review): disabled alternative; an unset (0) code now defaults to 200 below.
	// data.Code.CAS(0, http.StatusInternalServerError)

	// Flush recorded headers. For each key, the first value replaces whatever
	// the writer already holds and later values are added alongside it.
	// NOTE(review): c.header is read here without c.mu, unlike SetHeader.
	// This is presumably safe because the handler chain has finished; confirm
	// that no goroutine can still call SetHeader at this point.
	for key, value := range c.header {
		for i := range value {
			header := value[i]
			if i == 0 {
				c.Writer.Header().Set(key, header)
			} else {
				c.Writer.Header().Add(key, header)
			}
		}
	}

	// Nothing more to send when there is no request, or when its context is
	// done. Per net/http, that happens e.g. when the client's connection closes.
	if c.Request == nil || c.Request.Context().Err() != nil {
		return
	}

	// Registered after the early return above, so only requests that get this
	// far are passed to requestLog, and only in debug mode.
	defer func() {
		if c.Engine.IsDebug() {
			requestLog(c)
		}
	}()

	// A zero status code means "not set": default to 200 and store it back on data.
	code := int(data.Code.Load())
	if code == 0 {
		code = http.StatusOK
		data.Code.Store(int32(code))
	}
	size := len(data.Content)
	if size > 0 {
		// Body present: explicit Content-Length, then status line, then body.
		c.Writer.Header().Set("Content-Length", strconv.Itoa(size))
		c.Writer.WriteHeader(code)
		_, err := c.Writer.Write(data.Content)
		if err != nil {
			// Write failures are logged, not propagated.
			c.Log.Error(err)
		}
		return
	}
	// No body: only non-200 codes need an explicit WriteHeader, because
	// net/http responds with 200 when the handler writes nothing.
	if code != 200 {
		c.Writer.WriteHeader(code)
	}
}
   138  
   139  // Next middleware, if current middleware has been stopped, it will return false
   140  func (c *Context) Next() bool {
   141  	for {
   142  		if c.stopHandle.Load() {
   143  			return false
   144  		}
   145  		c.mu.RLock()
   146  		n := len(c.middleware) > 0
   147  		c.mu.RUnlock()
   148  		if !n {
   149  			return true
   150  		}
   151  		c.next()
   152  	}
   153  }
   154  
   155  func (c *Context) next() {
   156  	if c.stopHandle.Load() {
   157  		return
   158  	}
   159  	c.mu.Lock()
   160  	n := c.middleware[0]
   161  	c.middleware = c.middleware[1:]
   162  	c.mu.Unlock()
   163  	err := n(c)
   164  	if err != nil {
   165  		c.renderError(c, err)
   166  		c.Abort()
   167  	}
   168  }
   169  
   170  // SetCookie Set Cookie
   171  func (c *Context) SetCookie(name, value string, maxAge ...int) {
   172  	a := 0
   173  	if len(maxAge) > 0 {
   174  		a = maxAge[0]
   175  	}
   176  	cookie := &http.Cookie{
   177  		Name:     name,
   178  		Value:    value,
   179  		Path:     "/",
   180  		HttpOnly: true,
   181  		MaxAge:   a,
   182  	}
   183  	c.Writer.Header().Add("Set-Cookie", cookie.String())
   184  }
   185  
   186  // GetCookie Get Cookie
   187  func (c *Context) GetCookie(name string) string {
   188  	cookie, err := c.Request.Cookie(name)
   189  	if err != nil {
   190  		return ""
   191  	}
   192  	v, _ := url.QueryUnescape(cookie.Value)
   193  	return v
   194  }
   195  
   196  // GetReferer request referer
   197  func (c *Context) GetReferer() string {
   198  	return c.Request.Header.Get("Referer")
   199  }
   200  
   201  // GetUserAgent http request UserAgent
   202  func (c *Context) GetUserAgent() string {
   203  	return c.Request.Header.Get("User-Agent")
   204  }
   205  
   206  // ContentType returns the Content-Type header of the request
   207  func (c *Context) ContentType(contentText ...string) string {
   208  	var content string
   209  	if len(contentText) > 0 {
   210  		content = contentText[0]
   211  	} else {
   212  		content = c.GetHeader("Content-Type")
   213  	}
   214  	for i := 0; i < len(content); i++ {
   215  		char := content[i]
   216  		if char == ' ' || char == ';' {
   217  			return content[:i]
   218  		}
   219  	}
   220  	return content
   221  }
   222  
   223  // WithValue context sharing data
   224  func (c *Context) WithValue(key string, value interface{}) *Context {
   225  	c.mu.Lock()
   226  	c.customizeData[key] = value
   227  	c.mu.Unlock()
   228  	return c
   229  }
   230  
   231  // Value get context sharing data
   232  func (c *Context) Value(key string, def ...interface{}) (value interface{}, ok bool) {
   233  	c.mu.RLock()
   234  	value, ok = c.customizeData[key]
   235  	if !ok && (len(def) > 0) {
   236  		value = def[0]
   237  	}
   238  	c.mu.RUnlock()
   239  	return
   240  }
   241  
   242  // Value get context sharing data
   243  func (c *Context) MustValue(key string, def ...interface{}) (value interface{}) {
   244  	value, _ = c.Value(key, def...)
   245  	return
   246  }
   247  
   248  func (c *Context) Injector() zdi.Injector {
   249  	return c.injector
   250  }
   251  
   252  func (c *Context) FileAttachment(filepath, filename string) {
   253  	if isASCII(filename) {
   254  		c.Writer.Header().Set("Content-Disposition", `attachment; filename="`+strings.Replace(filename, "\"", "\\\"", -1)+`"`)
   255  	} else {
   256  		c.Writer.Header().Set("Content-Disposition", `attachment; filename*=UTF-8''`+url.QueryEscape(filename))
   257  	}
   258  	http.ServeFile(c.Writer, c.Request, filepath)
   259  }
   260  
   261  // https://stackoverflow.com/questions/53069040/checking-a-string-contains-only-ascii-characters
   262  func isASCII(s string) bool {
   263  	for i := 0; i < len(s); i++ {
   264  		if s[i] > unicode.MaxASCII {
   265  			return false
   266  		}
   267  	}
   268  	return true
   269  }