github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/input.go (about)

     1  package znet
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"io/ioutil"
     7  	"mime"
     8  	"mime/multipart"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"strings"
    13  
    14  	"github.com/sohaha/zlsgo/zfile"
    15  	"github.com/sohaha/zlsgo/zjson"
    16  	"github.com/sohaha/zlsgo/zstring"
    17  )
    18  
    19  func (c *Context) initQuery() {
    20  	if c.cacheQuery != nil {
    21  		return
    22  	}
    23  	c.cacheQuery = c.Request.URL.Query()
    24  }
    25  
    26  func (c *Context) initPostForm() {
    27  	if c.cacheForm != nil {
    28  		return
    29  	}
    30  	form := make(url.Values)
    31  	if c.Request.PostForm == nil {
    32  		(func() {
    33  			body, err := c.GetDataRaw()
    34  			if err != nil {
    35  				return
    36  			}
    37  			values, _ := url.ParseQuery(body)
    38  			c.Request.PostForm = values
    39  			v := c.ContentType()
    40  			if v == mimeMultipartPOSTForm {
    41  				_ = c.ParseMultipartForm()
    42  			}
    43  			form = c.Request.PostForm
    44  		})()
    45  	}
    46  	c.cacheForm = form
    47  }
    48  
    49  // GetParam Get the value of the param inside the route
    50  func (c *Context) GetParam(key string) string {
    51  	return c.GetAllParam()[key]
    52  }
    53  
    54  // GetAllParam Get the value of all param in the route
    55  func (c *Context) GetAllParam() map[string]string {
    56  	if values, ok := c.Request.Context().Value(Utils.ContextKey).(map[string]string); ok {
    57  		return values
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  // GetAllQuery Get All Queryst
    64  func (c *Context) GetAllQuery() url.Values {
    65  	c.initQuery()
    66  	return c.cacheQuery
    67  }
    68  
    69  // GetAllQueryMaps Get All Queryst Maps
    70  func (c *Context) GetAllQueryMaps() map[string]string {
    71  	c.initQuery()
    72  	arr := map[string]string{}
    73  	for key, v := range c.cacheQuery {
    74  		arr[key] = v[0]
    75  	}
    76  	return arr
    77  }
    78  
    79  // GetQueryArray Get Query Array
    80  func (c *Context) GetQueryArray(key string) ([]string, bool) {
    81  	c.initQuery()
    82  	if values, ok := c.cacheQuery[key]; ok && len(values) > 0 {
    83  		return values, true
    84  	}
    85  	return []string{}, false
    86  }
    87  
    88  // GetQuery Get Query
    89  func (c *Context) GetQuery(key string) (string, bool) {
    90  	if values, ok := c.GetQueryArray(key); ok {
    91  		return values[0], ok
    92  	}
    93  	return "", false
    94  }
    95  
    96  // DefaultQuery Get Query Or Default
    97  func (c *Context) DefaultQuery(key string, def string) string {
    98  	if value, ok := c.GetQuery(key); ok {
    99  		return value
   100  	}
   101  	return def
   102  }
   103  
   104  // GetQueryMap Get Query Map
   105  func (c *Context) GetQueryMap(key string) (map[string]string, bool) {
   106  	return c.get(c.cacheQuery, key)
   107  }
   108  
   109  // QueryMap Get Query Map
   110  func (c *Context) QueryMap(key string) map[string]string {
   111  	v, _ := c.get(c.cacheQuery, key)
   112  	return v
   113  }
   114  
   115  // DefaultPostForm Get Form Or Default
   116  func (c *Context) DefaultPostForm(key, def string) string {
   117  	if value, ok := c.GetPostForm(key); ok {
   118  		return value
   119  	}
   120  	return def
   121  }
   122  
   123  // GetPostForm Get PostForm
   124  func (c *Context) GetPostForm(key string) (string, bool) {
   125  	if values, ok := c.GetPostFormArray(key); ok {
   126  		return values[0], ok
   127  	}
   128  	return "", false
   129  }
   130  
   131  // DefaultFormOrQuery  Get Form Or Query
   132  func (c *Context) DefaultFormOrQuery(key string, def string) string {
   133  	if value, ok := c.GetPostForm(key); ok {
   134  		return value
   135  	}
   136  	return c.DefaultQuery(key, def)
   137  }
   138  
   139  // GetPostFormArray Get Post FormArray
   140  func (c *Context) GetPostFormArray(key string) ([]string, bool) {
   141  	req := c.Request
   142  	postForm := c.GetPostFormAll()
   143  	if values := postForm[key]; len(values) > 0 {
   144  		return values, true
   145  	}
   146  	if req.MultipartForm != nil && req.MultipartForm.File != nil {
   147  		if values := req.MultipartForm.Value[key]; len(values) > 0 {
   148  			return values, true
   149  		}
   150  	}
   151  	return []string{}, false
   152  }
   153  
   154  // GetPostFormAll Get PostForm All
   155  func (c *Context) GetPostFormAll() (value url.Values) {
   156  	c.initPostForm()
   157  	value = c.cacheForm
   158  	return
   159  }
   160  
   161  // PostFormMap PostForm Map
   162  func (c *Context) PostFormMap(key string) map[string]string {
   163  	v, _ := c.GetPostFormMap(key)
   164  	return v
   165  }
   166  
   167  // GetPostFormMap Get PostForm Map
   168  func (c *Context) GetPostFormMap(key string) (map[string]string, bool) {
   169  	postForm := c.GetPostFormAll()
   170  	dicts, exist := c.get(postForm, key)
   171  	if !exist && c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
   172  		dicts, exist = c.get(c.Request.MultipartForm.Value, key)
   173  	}
   174  
   175  	return dicts, exist
   176  }
   177  
   178  // GetJSON Get JSON
   179  func (c *Context) GetJSON(key string) *zjson.Res {
   180  	j, _ := c.GetJSONs()
   181  
   182  	return j.Get(key)
   183  }
   184  
   185  // GetJSONs Get JSONs
   186  func (c *Context) GetJSONs() (json *zjson.Res, err error) {
   187  	if c.cacheJSON != nil {
   188  		return c.cacheJSON, nil
   189  	}
   190  	json = &zjson.Res{}
   191  	var body string
   192  	body, err = c.GetDataRaw()
   193  	if err != nil {
   194  		return
   195  	}
   196  	if !zjson.Valid(body) {
   197  		err = errors.New("illegal json format")
   198  		return
   199  	}
   200  
   201  	json = zjson.Parse(body)
   202  	c.cacheJSON = json
   203  	return
   204  }
   205  
   206  // GetDataRaw Get Raw Data
   207  func (c *Context) GetDataRaw() (string, error) {
   208  	body, err := c.GetDataRawBytes()
   209  	if err != nil {
   210  		return "", err
   211  	}
   212  	return zstring.Bytes2String(body), err
   213  }
   214  
   215  // GetDataRawBytes Get Raw Data
   216  func (c *Context) GetDataRawBytes() ([]byte, error) {
   217  	if c.rawData != nil {
   218  		return c.rawData, nil
   219  	}
   220  	var err error
   221  	if c.Request.Body == nil {
   222  		err = errors.New("request body is nil")
   223  		return nil, err
   224  	}
   225  	var body []byte
   226  	body, err = ioutil.ReadAll(c.Request.Body)
   227  	if err == nil {
   228  		c.rawData = body
   229  	}
   230  	return c.rawData, err
   231  }
   232  
   233  func (c *Context) get(m map[string][]string, key string) (map[string]string, bool) {
   234  	d := make(map[string]string)
   235  	e := false
   236  	for k, v := range m {
   237  		if i := strings.IndexByte(k, '['); i >= 1 && k[0:i] == key {
   238  			if j := strings.IndexByte(k[i+1:], ']'); j >= 1 {
   239  				e = true
   240  				d[k[i+1:][:j]] = v[0]
   241  			}
   242  		}
   243  	}
   244  	return d, e
   245  }
   246  
   247  // FormFile FormFile
   248  func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
   249  	f, err := c.FormFiles(name)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  
   254  	return f[0], err
   255  }
   256  
   257  // FormFiles Multiple FormFile
   258  func (c *Context) FormFiles(name string) (files []*multipart.FileHeader, err error) {
   259  	var multipartForm *multipart.Form
   260  	multipartForm, err = c.MultipartForm()
   261  
   262  	if multipartForm == nil || multipartForm.File == nil {
   263  		err = errors.New("file is empty")
   264  		return
   265  	}
   266  
   267  	files = make([]*multipart.FileHeader, 0, 1)
   268  	if fhs := multipartForm.File[name]; len(fhs) > 0 {
   269  		for i := range fhs {
   270  			files = append(files, fhs[i])
   271  		}
   272  	}
   273  
   274  	if len(files) == 0 {
   275  		return nil, errors.New("file is empty")
   276  	}
   277  
   278  	return
   279  }
   280  
   281  // MultipartForm MultipartForm
   282  func (c *Context) MultipartForm() (*multipart.Form, error) {
   283  	err := c.ParseMultipartForm()
   284  	return c.Request.MultipartForm, err
   285  }
   286  
   287  // SaveUploadedFile Save Uploaded File
   288  func (c *Context) SaveUploadedFile(file *multipart.FileHeader, dist string) error {
   289  	src, err := file.Open()
   290  	if err != nil {
   291  		return err
   292  	}
   293  	defer src.Close()
   294  
   295  	dist = zfile.RealPath(dist)
   296  	out, err := os.Create(dist)
   297  	if err != nil {
   298  		return err
   299  	}
   300  	defer out.Close()
   301  	_, err = io.Copy(out, src)
   302  	if err != nil {
   303  		return err
   304  	}
   305  
   306  	return nil
   307  }
   308  
   309  func (c *Context) ParseMultipartForm(maxMultipartMemory ...int64) error {
   310  	if c.Request.MultipartForm != nil {
   311  		return nil
   312  	}
   313  
   314  	mr, err := c.multipartReader(false)
   315  	if err != nil {
   316  		return err
   317  	}
   318  
   319  	maxMemory := c.Engine.MaxMultipartMemory
   320  	if len(maxMultipartMemory) > 0 && maxMultipartMemory[0] > 0 {
   321  		maxMemory = maxMultipartMemory[0]
   322  	}
   323  	f, err := mr.ReadForm(maxMemory)
   324  	if err != nil {
   325  		return err
   326  	}
   327  
   328  	if c.Request.PostForm == nil {
   329  		c.Request.PostForm = make(url.Values)
   330  	}
   331  
   332  	for k, v := range f.Value {
   333  		c.Request.PostForm[k] = append(c.Request.PostForm[k], v...)
   334  	}
   335  
   336  	c.Request.MultipartForm = f
   337  	return nil
   338  }
   339  
   340  func (c *Context) multipartReader(allowMixed bool) (*multipart.Reader, error) {
   341  	v := c.Request.Header.Get("Content-Type")
   342  	if v == "" {
   343  		return nil, http.ErrNotMultipart
   344  	}
   345  	d, params, err := mime.ParseMediaType(v)
   346  	if err != nil || !(d == "multipart/form-data" || allowMixed && d == "multipart/mixed") {
   347  		return nil, http.ErrNotMultipart
   348  	}
   349  	boundary, ok := params["boundary"]
   350  	if !ok {
   351  		return nil, http.ErrMissingBoundary
   352  	}
   353  	body, err := c.GetDataRaw()
   354  	if err != nil {
   355  		return nil, err
   356  	}
   357  	return multipart.NewReader(strings.NewReader(body), boundary), nil
   358  }