github.com/wfusion/gofusion@v1.1.14/http/middleware/xss.go (about)

     1  package middleware
     2  
     3  import (
     4  	"html"
     5  	"io/ioutil"
     6  	"net/url"
     7  	"strings"
     8  
     9  	"github.com/gin-gonic/gin"
    10  	"github.com/gin-gonic/gin/binding"
    11  	"github.com/microcosm-cc/bluemonday"
    12  	"github.com/pkg/errors"
    13  
    14  	"github.com/wfusion/gofusion/common/utils/serialize/json"
    15  )
    16  
    17  func XSS(whitelistURLs []string) gin.HandlerFunc {
    18  	// Do this once for each unique policy, and use the policy for the life of the
    19  	// program Policy creation/editing is not safe to use in multiple goroutines.
    20  	p := bluemonday.UGCPolicy()
    21  
    22  	return func(c *gin.Context) {
    23  		for _, u := range whitelistURLs {
    24  			if strings.HasPrefix(c.Request.URL.String(), u) {
    25  				c.Next()
    26  				return
    27  			}
    28  		}
    29  
    30  		sanitizedQuery, err := xssFilterQuery(p, c.Request.URL.RawQuery)
    31  		if err != nil {
    32  			err = errors.Wrap(err, "filter query")
    33  			_ = c.Error(err)
    34  			c.Abort()
    35  			return
    36  		}
    37  		c.Request.URL.RawQuery = sanitizedQuery
    38  
    39  		var sanitizedBody string
    40  		body, err := c.GetRawData()
    41  		if err != nil {
    42  			err = errors.Wrap(err, "read body")
    43  			_ = c.Error(err)
    44  			c.Abort()
    45  			return
    46  		}
    47  
    48  		// xssFilterJSON() will return error when body is empty.
    49  		if len(body) == 0 {
    50  			c.Next()
    51  			return
    52  		}
    53  
    54  		switch binding.Default(c.Request.Method, c.ContentType()) {
    55  		case binding.JSON:
    56  			if sanitizedBody, err = xssFilterJSON(p, string(body)); err != nil {
    57  				err = errors.Wrap(err, "filter json")
    58  			}
    59  		case binding.FormMultipart:
    60  			sanitizedBody = xssFilterPlain(p, string(body))
    61  		case binding.Form:
    62  			if sanitizedBody, err = xssFilterQuery(p, string(body)); err != nil {
    63  				err = errors.Wrap(err, "filter form")
    64  			}
    65  		}
    66  		if err != nil {
    67  			_ = c.Error(err)
    68  			c.Abort()
    69  			return
    70  		}
    71  
    72  		c.Request.Body = ioutil.NopCloser(strings.NewReader(sanitizedBody))
    73  		c.Next()
    74  	}
    75  }
    76  
    77  func xssFilterQuery(p *bluemonday.Policy, s string) (string, error) {
    78  	values, err := url.ParseQuery(s)
    79  	if err != nil {
    80  		return "", err
    81  	}
    82  
    83  	for k, v := range values {
    84  		values.Del(k)
    85  		for _, vv := range v {
    86  			values.Add(k, xssFilterPlain(p, vv))
    87  		}
    88  	}
    89  
    90  	return values.Encode(), nil
    91  }
    92  
    93  func xssFilterJSON(p *bluemonday.Policy, s string) (string, error) {
    94  	var data any
    95  	if err := json.Unmarshal([]byte(s), &data); err != nil {
    96  		return "", err
    97  	}
    98  
    99  	b := strings.Builder{}
   100  	e := json.NewEncoder(&b)
   101  	e.SetEscapeHTML(false)
   102  	if err := e.Encode(xssFilterJSONData(p, data)); err != nil {
   103  		return "", err
   104  	}
   105  	// use `TrimSpace` to trim newline char add by `Encode`.
   106  	return strings.TrimSpace(b.String()), nil
   107  }
   108  
   109  func xssFilterJSONData(p *bluemonday.Policy, d any) any {
   110  	switch data := d.(type) {
   111  	case []any:
   112  		for i, v := range data {
   113  			data[i] = xssFilterJSONData(p, v)
   114  		}
   115  		return data
   116  	case map[string]any:
   117  		for k, v := range data {
   118  			data[k] = xssFilterJSONData(p, v)
   119  		}
   120  		return data
   121  	case string:
   122  		return xssFilterPlain(p, data)
   123  	default:
   124  		return data
   125  	}
   126  }
   127  
   128  func xssFilterPlain(p *bluemonday.Policy, s string) string {
   129  	sanitized := p.Sanitize(s)
   130  	return html.UnescapeString(sanitized)
   131  }