github.com/TeaOSLab/EdgeNode@v1.3.8/internal/waf/checkpoints/request_upload.go (about)

     1  package checkpoints
     2  
     3  import (
     4  	"bytes"
     5  	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
     6  	"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
     7  	"github.com/iwind/TeaGo/lists"
     8  	"github.com/iwind/TeaGo/maps"
     9  	"github.com/iwind/TeaGo/types"
    10  	"io"
    11  	"net/http"
    12  	"path/filepath"
    13  	"regexp"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  var multipartHeaderRegexp = regexp.MustCompile(`(?i)(?:^|\r\n)--+\w+\r\n((([\w-]+: .+)\r\n)+)`)
    19  var multipartHeaderContentRangeRegexp = regexp.MustCompile(`/(\d+)`)
    20  
    21  // RequestUploadCheckpoint ${requestUpload.arg}
    22  type RequestUploadCheckpoint struct {
    23  	Checkpoint
    24  }
    25  
    26  func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
    27  	if this.RequestBodyIsEmpty(req) {
    28  		value = ""
    29  		return
    30  	}
    31  
    32  	value = ""
    33  	if param == "minSize" || param == "maxSize" {
    34  		value = 0
    35  	}
    36  
    37  	if req.WAFRaw().Method != http.MethodPost {
    38  		return
    39  	}
    40  
    41  	if req.WAFRaw().Body == nil {
    42  		return
    43  	}
    44  
    45  	hasRequestBody = true
    46  
    47  	var requestContentLength = req.WAFRaw().ContentLength
    48  
    49  	var fields []string
    50  	var minSize int64
    51  	var maxSize int64
    52  	var names []string
    53  	var extensions []string
    54  
    55  	if requestContentLength <= req.WAFMaxRequestSize() { // full read
    56  		if req.WAFRaw().MultipartForm == nil {
    57  			var bodyData = req.WAFGetCacheBody()
    58  			if len(bodyData) == 0 {
    59  				data, err := req.WAFReadBody(req.WAFMaxRequestSize())
    60  				if err != nil {
    61  					sysErr = err
    62  					return
    63  				}
    64  
    65  				bodyData = data
    66  				req.WAFSetCacheBody(data)
    67  				defer req.WAFRestoreBody(data)
    68  			}
    69  			var oldBody = req.WAFRaw().Body
    70  			req.WAFRaw().Body = io.NopCloser(bytes.NewBuffer(bodyData))
    71  			err := req.WAFRaw().ParseMultipartForm(req.WAFMaxRequestSize())
    72  			if err == nil {
    73  				for field, files := range req.WAFRaw().MultipartForm.File {
    74  					if param == "field" {
    75  						fields = append(fields, field)
    76  					} else if param == "minSize" {
    77  						for _, file := range files {
    78  							if minSize == 0 || minSize > file.Size {
    79  								minSize = file.Size
    80  							}
    81  						}
    82  					} else if param == "maxSize" {
    83  						for _, file := range files {
    84  							if maxSize < file.Size {
    85  								maxSize = file.Size
    86  							}
    87  						}
    88  					} else if param == "name" {
    89  						for _, file := range files {
    90  							if !lists.ContainsString(names, file.Filename) {
    91  								names = append(names, file.Filename)
    92  							}
    93  						}
    94  					} else if param == "ext" {
    95  						for _, file := range files {
    96  							if len(file.Filename) > 0 {
    97  								exit := strings.ToLower(filepath.Ext(file.Filename))
    98  								if !lists.ContainsString(extensions, exit) {
    99  									extensions = append(extensions, exit)
   100  								}
   101  							}
   102  						}
   103  					}
   104  				}
   105  			}
   106  
   107  			// 还原
   108  			req.WAFRaw().Body = oldBody
   109  
   110  			if err != nil {
   111  				userErr = err
   112  				return
   113  			}
   114  
   115  			if req.WAFRaw().MultipartForm == nil {
   116  				return
   117  			}
   118  		}
   119  	} else { // read first part
   120  		var bodyData = req.WAFGetCacheBody()
   121  		if len(bodyData) == 0 {
   122  			data, err := req.WAFReadBody(req.WAFMaxRequestSize())
   123  			if err != nil {
   124  				sysErr = err
   125  				return
   126  			}
   127  
   128  			bodyData = data
   129  			req.WAFSetCacheBody(data)
   130  			defer req.WAFRestoreBody(data)
   131  		}
   132  
   133  		var subMatches = multipartHeaderRegexp.FindAllSubmatch(bodyData, -1)
   134  		for _, subMatch := range subMatches {
   135  			var headers = bytes.Split(subMatch[1], []byte{'\r', '\n'})
   136  			var partContentLength int64 = -1
   137  			for _, header := range headers {
   138  				if len(header) > 2 {
   139  					var kv = bytes.SplitN(header, []byte{':'}, 2)
   140  					if len(kv) == 2 {
   141  						var k = kv[0]
   142  						var v = kv[1]
   143  						switch string(bytes.ToLower(k)) {
   144  						case "content-disposition":
   145  							var props = bytes.Split(v, []byte{';', ' '})
   146  							for _, prop := range props {
   147  								var propKV = bytes.SplitN(prop, []byte{'='}, 2)
   148  								if len(propKV) == 2 {
   149  									var propValue = string(propKV[1])
   150  									switch string(propKV[0]) {
   151  									case "name":
   152  										if param == "field" {
   153  											propValue, _ = strconv.Unquote(propValue)
   154  											fields = append(fields, propValue)
   155  										}
   156  									case "filename":
   157  										if param == "name" {
   158  											propValue, _ = strconv.Unquote(propValue)
   159  											names = append(names, propValue)
   160  										} else if param == "ext" {
   161  											propValue, _ = strconv.Unquote(propValue)
   162  											extensions = append(extensions, strings.ToLower(filepath.Ext(propValue)))
   163  										}
   164  									}
   165  								}
   166  							}
   167  						case "content-range":
   168  							if partContentLength <= 0 {
   169  								var contentRange = multipartHeaderContentRangeRegexp.FindSubmatch(v)
   170  								if len(contentRange) >= 2 {
   171  									partContentLength = types.Int64(string(contentRange[1]))
   172  								}
   173  							}
   174  						case "content-length":
   175  							if partContentLength <= 0 {
   176  								partContentLength = types.Int64(string(v))
   177  							}
   178  						}
   179  					}
   180  				}
   181  			}
   182  
   183  			// minSize & maxSize
   184  			if partContentLength > 0 {
   185  				if param == "minSize" && (minSize == 0 /** not set yet **/ || partContentLength < minSize) {
   186  					minSize = partContentLength
   187  				} else if param == "maxSize" && partContentLength > maxSize {
   188  					maxSize = partContentLength
   189  				}
   190  			}
   191  		}
   192  	}
   193  
   194  	if param == "field" { // field
   195  		value = strings.Join(fields, ",")
   196  	} else if param == "minSize" { // minSize
   197  		if minSize == 0 && requestContentLength > 0 {
   198  			minSize = requestContentLength
   199  		}
   200  		value = minSize
   201  	} else if param == "maxSize" { // maxSize
   202  		if maxSize == 0 && requestContentLength > 0 {
   203  			maxSize = requestContentLength
   204  		}
   205  		value = maxSize
   206  	} else if param == "name" { // name
   207  		value = strings.Join(names, ",")
   208  	} else if param == "ext" { // ext
   209  		value = strings.Join(extensions, ",")
   210  	}
   211  
   212  	return
   213  }
   214  
   215  func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
   216  	if this.IsRequest() {
   217  		return this.RequestValue(req, param, options, ruleId)
   218  	}
   219  	return
   220  }
   221  
   222  func (this *RequestUploadCheckpoint) ParamOptions() *ParamOptions {
   223  	option := NewParamOptions()
   224  	option.AddParam("最小文件尺寸", "minSize")
   225  	option.AddParam("最大文件尺寸", "maxSize")
   226  	option.AddParam("扩展名(如.txt)", "ext")
   227  	option.AddParam("原始文件名", "name")
   228  	option.AddParam("表单字段名", "field")
   229  	return option
   230  }
   231  
   232  func (this *RequestUploadCheckpoint) CacheLife() utils.CacheLife {
   233  	return utils.CacheMiddleLife
   234  }