github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/handler.go (about)

     1  package znet
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"net/url"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/sohaha/zlsgo/zdi"
    14  	"github.com/sohaha/zlsgo/zlog"
    15  	"github.com/sohaha/zlsgo/zreflect"
    16  	"github.com/sohaha/zlsgo/zutil"
    17  )
    18  
    19  type (
    20  	invokerCodeText func() (int, string)
    21  )
    22  
    23  var (
    24  	_ zdi.PreInvoker = (*invokerCodeText)(nil)
    25  )
    26  
    27  func (h invokerCodeText) Invoke(_ []interface{}) ([]reflect.Value, error) {
    28  	code, text := h()
    29  	return []reflect.Value{zreflect.ValueOf(code), reflect.ValueOf(text)}, nil
    30  }
    31  
    32  func defErrorHandler() ErrHandlerFunc {
    33  	return func(c *Context, err error) {
    34  		c.String(500, err.Error())
    35  	}
    36  }
    37  
    38  // RewriteErrorHandler rewrite error handler
    39  func RewriteErrorHandler(handler ErrHandlerFunc) Handler {
    40  	return func(c *Context) {
    41  		c.renderError = handler
    42  		c.Next()
    43  	}
    44  }
    45  
    46  // Recovery is a middleware that recovers from panics anywhere in the chain
    47  func Recovery(handler ErrHandlerFunc) Handler {
    48  	return func(c *Context) {
    49  		defer func() {
    50  			if err := recover(); err != nil {
    51  				errMsg, ok := err.(error)
    52  				if !ok {
    53  					errMsg = errors.New(fmt.Sprint(err))
    54  				}
    55  				handler(c, errMsg)
    56  			}
    57  		}()
    58  		c.Next()
    59  	}
    60  }
    61  
    62  func requestLog(c *Context) {
    63  	if c.Engine.IsDebug() {
    64  		var status string
    65  		end := time.Now()
    66  		statusCode := zutil.GetBuff()
    67  		latency := end.Sub(c.startTime)
    68  		code := c.prevData.Code.Load()
    69  		statusCode.WriteString(" ")
    70  		statusCode.WriteString(strconv.FormatInt(int64(code), 10))
    71  		statusCode.WriteString(" ")
    72  		s := statusCode.String()
    73  		zutil.PutBuff(statusCode)
    74  		switch {
    75  		case code >= 200 && code <= 299:
    76  			status = c.Log.ColorBackgroundWrap(zlog.ColorBlack, zlog.ColorGreen, s)
    77  		case code >= 300 && code <= 399:
    78  			status = c.Log.ColorBackgroundWrap(zlog.ColorBlack, zlog.ColorYellow, s)
    79  		default:
    80  			status = c.Log.ColorBackgroundWrap(zlog.ColorBlack, zlog.ColorRed, s)
    81  		}
    82  		clientIP := c.GetClientIP()
    83  		if clientIP == "" {
    84  			clientIP = "unknown"
    85  		}
    86  		ft := fmt.Sprintf("%s %15s %15v %%s %%s", status, clientIP, latency)
    87  		c.Log.Success(routeLog(c.Log, ft, c.Request.Method, c.Request.RequestURI))
    88  	}
    89  }
    90  
// errURLQuerySemicolon is the warning text for a request whose URL query
// contains ';', which is no longer accepted as a query separator (see
// golang.org/issue/25192). NOTE(review): the wording appears to mirror the
// message net/http logs in this situation; its use site is not visible here.
const errURLQuerySemicolon = "http: URL query contains semicolon, which is no longer a supported separator; parts of the query may be stripped when parsed; see golang.org/issue/25192\n"
    92  
    93  func allowQuerySemicolons(r *http.Request) {
    94  	// clopy of net/http.AllowQuerySemicolons.
    95  	if s := r.URL.RawQuery; strings.Contains(s, ";") {
    96  		r2 := new(http.Request)
    97  		*r2 = *r
    98  		r2.URL = new(url.URL)
    99  		*r2.URL = *r.URL
   100  		r2.URL.RawQuery = strings.Replace(s, ";", "&", -1)
   101  		*r = *r2
   102  	}
   103  }