github.com/GuanceCloud/cliutils@v1.1.21/network/http/gin.go (about)

     1  // Unless explicitly stated otherwise all files in this repository are licensed
     2  // under the MIT License.
     3  // This product includes software developed at Guance Cloud (https://www.guance.com/).
     4  // Copyright 2021-present Guance, Inc.
     5  
     6  package http
     7  
     8  import (
     9  	"bytes"
    10  	"compress/gzip"
    11  	"crypto/md5" //nolint:gosec
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/http"
    16  	"net/textproto"
    17  	"os"
    18  	"sort"
    19  	"strconv"
    20  	"strings"
    21  
    22  	"github.com/GuanceCloud/cliutils"
    23  	"github.com/GuanceCloud/cliutils/logger"
    24  	"github.com/gin-gonic/gin"
    25  )
    26  
// HTTP header names exported by this package.
//
// XToken, XDatakitUUID, XRP, XPrecision and XLua are part of the default CORS
// allow-list (defaultCORSHeader); XTraceID is read and written by
// TraceIDMiddleware. The remaining names are not used within this file.
const (
	XAgentIP       = "X-Agent-Ip"
	XAgentUID      = "X-Agent-Uid"
	XCQRP          = "X-CQ-RP"
	XDatakitInfo   = "X-Datakit-Info"
	XDatakitUUID   = "X-Datakit-UUID" // deprecated
	XDBUUID        = "X-DB-UUID"
	XDomainName    = "X-Domain-Name"
	XLua           = "X-Lua"
	XPrecision     = "X-Precision"
	XRP            = "X-RP"
	XSource        = "X-Source"
	XTableName     = "X-Table-Name"
	XToken         = "X-Token"
	XTraceID       = "X-Trace-Id"
	XVersion       = "X-Version"
	XWorkspaceUUID = "X-Workspace-UUID"
)
    45  
const (
	// HeaderWildcard is the "*" entry of a header list. CORSHeaders.String
	// always renders it after every other header name.
	HeaderWildcard = "*"
	// HeaderGlue is the separator placed between header names when a header
	// list is rendered as a single header value.
	HeaderGlue     = ", "
)
    50  
var (
	// defaultCORSHeader is the set of request headers the CORS middlewares
	// always allow.
	//
	// Although CORS-safelisted request headers(Accept/Accept-Language/Content-Language/Content-Type) are always allowed
	// and don't usually need to be listed in Access-Control-Allow-Headers,
	// listing them anyway will circumvent the additional restrictions that apply.
	// see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers#bypassing_additional_restrictions
	defaultCORSHeader = newCORSHeaders([]string{
		"Content-Type",
		"Content-Length",
		"Accept-Encoding",
		"X-CSRF-Token",
		"Authorization",
		"Accept",
		"Accept-Language",
		"Content-Language",
		"Origin",
		"Cache-Control",
		"X-Requested-With",

		// dataflux headers
		XToken,
		XDatakitUUID,
		XRP,
		XPrecision,
		XLua,
		"*",
	})
	// allowHeaders is defaultCORSHeader rendered once, at package
	// initialization, as an Access-Control-Allow-Headers value.
	allowHeaders      = defaultCORSHeader.String()
	// realIPHeader lists the request header names GinLogFormatter looks up to
	// find the client IP. Every entry is read with Header.Get, including
	// "RemoteAddr"; a later non-empty entry overrides an earlier one.
	realIPHeader      = []string{"X-Forwarded-For", "X-Real-IP", "RemoteAddr"}
	// MaxRequestBodyLen is used by RequestLoggerMiddleware when shortening the
	// captured response body before logging it. Init overrides it from the
	// MAX_REQUEST_BODY_LEN environment variable.
	MaxRequestBodyLen = 128

	// l is the package logger; Init re-binds it via logger.SLogger.
	l = logger.DefaultSLogger("gin")
)
    83  
    84  func Init() {
    85  	l = logger.SLogger("gin")
    86  
    87  	if v, ok := os.LookupEnv("MAX_REQUEST_BODY_LEN"); ok {
    88  		if i, err := strconv.ParseInt(v, 10, 64); err != nil {
    89  			l.Warnf("invalid MAX_REQUEST_BODY_LEN, expect int, got %s, ignored", v)
    90  		} else {
    91  			MaxRequestBodyLen = int(i)
    92  		}
    93  	}
    94  }
    95  
    96  type CORSHeaders map[string]struct{}
    97  
    98  func newCORSHeaders(headers []string) CORSHeaders {
    99  	ch := make(CORSHeaders, len(headers))
   100  	for _, header := range headers {
   101  		header = strings.TrimSpace(header)
   102  		if header == "" {
   103  			continue
   104  		}
   105  		ch[textproto.CanonicalMIMEHeaderKey(header)] = struct{}{}
   106  	}
   107  	return ch
   108  }
   109  
   110  func (c CORSHeaders) String() string {
   111  	headers := make([]string, 0, len(c))
   112  	hasWildcard := false
   113  	for k := range c {
   114  		if k == HeaderWildcard {
   115  			hasWildcard = true
   116  			continue
   117  		}
   118  		headers = append(headers, k)
   119  	}
   120  
   121  	sort.Strings(headers)
   122  
   123  	if hasWildcard {
   124  		headers = append(headers, "*")
   125  	}
   126  
   127  	return strings.Join(headers, HeaderGlue)
   128  }
   129  
   130  func (c CORSHeaders) Add(requestHeaders string) string {
   131  	if requestHeaders == "" {
   132  		return allowHeaders
   133  	}
   134  	headers := make([]string, 0)
   135  	for _, key := range strings.Split(requestHeaders, ",") {
   136  		key = strings.TrimSpace(key)
   137  		if key == "" {
   138  			continue
   139  		}
   140  		key = textproto.CanonicalMIMEHeaderKey(key)
   141  		if _, ok := c[key]; !ok {
   142  			headers = append(headers, key)
   143  		}
   144  	}
   145  	if len(headers) == 0 {
   146  		return allowHeaders
   147  	}
   148  	return strings.Join(headers, HeaderGlue) + HeaderGlue + allowHeaders
   149  }
   150  
   151  func GinLogFormatter(param gin.LogFormatterParams) string {
   152  	realIP := param.ClientIP
   153  	for _, h := range realIPHeader {
   154  		if v := param.Request.Header.Get(h); v != "" {
   155  			realIP = v
   156  		}
   157  	}
   158  
   159  	if param.ErrorMessage != "" {
   160  		return fmt.Sprintf("[GIN] %v | %3d | %8v | %15s | %-7s %#v -> %s\n",
   161  			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
   162  			param.StatusCode,
   163  			param.Latency,
   164  			net.ParseIP(realIP),
   165  			param.Method,
   166  			param.Path,
   167  			param.ErrorMessage)
   168  	} else {
   169  		return fmt.Sprintf("[GIN] %v | %3d | %8v | %15s | %-7s %#v\n",
   170  			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
   171  			param.StatusCode,
   172  			param.Latency,
   173  			net.ParseIP(realIP),
   174  			param.Method,
   175  			param.Path)
   176  	}
   177  }
   178  
   179  func CORSMiddleware(c *gin.Context) {
   180  	allowOrigin := c.GetHeader("origin")
   181  	requestHeaders := c.GetHeader("Access-Control-Request-Headers")
   182  	if allowOrigin == "" {
   183  		allowOrigin = "*"
   184  	}
   185  	c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin)
   186  	c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
   187  	if requestHeaders != "" {
   188  		c.Writer.Header().Set("Access-Control-Allow-Headers", defaultCORSHeader.Add(requestHeaders))
   189  	} else {
   190  		c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeaders)
   191  	}
   192  	c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
   193  
   194  	// The default value is only 5 seconds, so we explicitly set it to reduce the count of OPTIONS requests.
   195  	// see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age#directives
   196  	c.Writer.Header().Set("Access-Control-Max-Age", "7200")
   197  	if c.Request.Method == "OPTIONS" {
   198  		c.AbortWithStatus(http.StatusNoContent)
   199  		return
   200  	}
   201  	c.Next()
   202  }
   203  
   204  func CORSMiddlewareV2(allowedOrigins []string) gin.HandlerFunc {
   205  	return func(c *gin.Context) {
   206  		allowOrigin := c.GetHeader("origin")
   207  		requestHeaders := c.GetHeader("Access-Control-Request-Headers")
   208  		if allowOrigin == "" {
   209  			allowOrigin = "*"
   210  		}
   211  		if originIsAllowed(allowOrigin, allowedOrigins) {
   212  			c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin)
   213  			c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
   214  			if requestHeaders != "" {
   215  				c.Writer.Header().Set("Access-Control-Allow-Headers", defaultCORSHeader.Add(requestHeaders))
   216  			} else {
   217  				c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeaders)
   218  			}
   219  			c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
   220  
   221  			// The default value is only 5 seconds, so we explicitly set it to reduce the count of OPTIONS requests.
   222  			// see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age#directives
   223  			c.Writer.Header().Set("Access-Control-Max-Age", "7200")
   224  		}
   225  		if c.Request.Method == "OPTIONS" {
   226  			c.AbortWithStatus(http.StatusNoContent)
   227  			return
   228  		}
   229  		c.Next()
   230  	}
   231  }
   232  
   233  func originIsAllowed(origin string, allowedOrigins []string) bool {
   234  	if len(allowedOrigins) == 0 {
   235  		return true
   236  	}
   237  	for _, allowedOrigin := range allowedOrigins {
   238  		if origin == allowedOrigin {
   239  			return true
   240  		}
   241  	}
   242  	return false
   243  }
   244  
   245  func TraceIDMiddleware(c *gin.Context) {
   246  	if c.Request.Method == `OPTIONS` {
   247  		c.Next()
   248  	} else {
   249  		tid := c.Request.Header.Get(XTraceID)
   250  		if tid == "" {
   251  			tid = cliutils.XID(`trace_`)
   252  			c.Request.Header.Set(XTraceID, tid)
   253  		}
   254  
   255  		c.Writer.Header().Set(XTraceID, tid)
   256  		c.Next()
   257  	}
   258  }
   259  
   260  func FormatRequest(r *http.Request) string {
   261  	// Add the request string
   262  	url := fmt.Sprintf("%v %v %v", r.Method, r.URL, r.Proto)
   263  	request := []string{url}
   264  
   265  	// Add the host
   266  	request = append(request, fmt.Sprintf("Host: %v", r.Host))
   267  	// Loop through headers
   268  
   269  	for name, headers := range r.Header {
   270  		for _, h := range headers {
   271  			request = append(request, fmt.Sprintf("%v: %v", name, h))
   272  		}
   273  	}
   274  
   275  	// Return the request as a string
   276  	return strings.Join(request, "|")
   277  }
   278  
   279  type bodyLoggerWriter struct {
   280  	gin.ResponseWriter
   281  	body *bytes.Buffer
   282  }
   283  
   284  func (w bodyLoggerWriter) Write(b []byte) (int, error) {
   285  	w.body.Write(b)
   286  	return w.ResponseWriter.Write(b)
   287  }
   288  
   289  func RequestLoggerMiddleware(c *gin.Context) {
   290  	w := &bodyLoggerWriter{
   291  		ResponseWriter: c.Writer,
   292  		body:           bytes.NewBufferString(``),
   293  	}
   294  
   295  	c.Writer = w
   296  	c.Next()
   297  
   298  	body := w.body.String()
   299  
   300  	l.Infof("%s %s %d, RemoteAddr: %s, Request: [%s], Body: %s",
   301  		c.Request.Method,
   302  		c.Request.URL,
   303  		c.Writer.Status(),
   304  		c.Request.RemoteAddr,
   305  		FormatRequest(c.Request),
   306  		body[:len(body)%MaxRequestBodyLen]+"...")
   307  }
   308  
   309  func GinReadWithMD5(c *gin.Context) (buf []byte, md5str string, err error) {
   310  	buf, err = readBody(c)
   311  	if err != nil {
   312  		return
   313  	}
   314  
   315  	md5str = fmt.Sprintf("%x", md5.Sum(buf)) //nolint:gosec
   316  
   317  	if c.Request.Header.Get("Content-Encoding") == "gzip" {
   318  		buf, err = Unzip(buf)
   319  	}
   320  
   321  	return
   322  }
   323  
   324  func GinRead(c *gin.Context) (buf []byte, err error) {
   325  	buf, err = readBody(c)
   326  	if err != nil {
   327  		return
   328  	}
   329  
   330  	if c.Request.Header.Get("Content-Encoding") == "gzip" {
   331  		buf, err = Unzip(buf)
   332  	}
   333  
   334  	return
   335  }
   336  
   337  func GinGetArg(c *gin.Context, hdr, param string) (v string, err error) {
   338  	v = c.Request.Header.Get(hdr)
   339  	if v == "" {
   340  		v = c.Query(param)
   341  		if v == "" {
   342  			err = fmt.Errorf("HTTP header %s and query param %s missing", hdr, param)
   343  		}
   344  	}
   345  	return
   346  }
   347  
   348  func Unzip(in []byte) (out []byte, err error) {
   349  	gzr, err := gzip.NewReader(bytes.NewBuffer(in))
   350  	if err != nil {
   351  		return
   352  	}
   353  
   354  	out, err = io.ReadAll(gzr)
   355  	if err != nil {
   356  		return
   357  	}
   358  
   359  	if err := gzr.Close(); err != nil {
   360  		_ = err // pass
   361  	}
   362  	return
   363  }
   364  
   365  func readBody(c *gin.Context) ([]byte, error) {
   366  	body, err := io.ReadAll(c.Request.Body)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	defer func() {
   372  		if err := c.Request.Body.Close(); err != nil {
   373  			_ = err // pass
   374  		}
   375  	}()
   376  	return body, nil
   377  }