github.com/unionj-cloud/go-doudou@v1.3.8-0.20221011095552-0088008e5b31/framework/http/middleware.go (about)

     1  package ddhttp
     2  
     3  import (
     4  	"context"
     5  	"crypto/subtle"
     6  	"fmt"
     7  	"github.com/apolloconfig/agollo/v4/storage"
     8  	"github.com/ascarter/requestid"
     9  	"github.com/felixge/httpsnoop"
    10  	"github.com/opentracing-contrib/go-stdlib/nethttp"
    11  	"github.com/opentracing/opentracing-go"
    12  	"github.com/pkg/errors"
    13  	"github.com/slok/goresilience"
    14  	"github.com/slok/goresilience/bulkhead"
    15  	"github.com/uber/jaeger-client-go"
    16  	"github.com/unionj-cloud/go-doudou/framework/configmgr"
    17  	"github.com/unionj-cloud/go-doudou/framework/http/model"
    18  	"github.com/unionj-cloud/go-doudou/framework/internal/config"
    19  	"github.com/unionj-cloud/go-doudou/toolkit/stringutils"
    20  	logger "github.com/unionj-cloud/go-doudou/toolkit/zlogger"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"net/url"
    25  	"os"
    26  	"runtime/debug"
    27  	"strings"
    28  	"time"
    29  )
    30  
    31  type httpConfigListener struct {
    32  	configmgr.BaseApolloListener
    33  }
    34  
    35  func NewHttpConfigListener() *httpConfigListener {
    36  	return &httpConfigListener{}
    37  }
    38  
    39  func (c *httpConfigListener) OnChange(event *storage.ChangeEvent) {
    40  	c.Lock.Lock()
    41  	defer c.Lock.Unlock()
    42  	if !c.SkippedFirstEvent {
    43  		c.SkippedFirstEvent = true
    44  		return
    45  	}
    46  	for key, value := range event.Changes {
    47  		upperKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_"))
    48  		if strings.HasPrefix(upperKey, "GDD_MANAGE_") {
    49  			_ = os.Setenv(upperKey, fmt.Sprint(value.NewValue))
    50  		}
    51  	}
    52  }
    53  
    54  func CallbackOnChange(listener *httpConfigListener) func(event *configmgr.NacosChangeEvent) {
    55  	return func(event *configmgr.NacosChangeEvent) {
    56  		changes := make(map[string]*storage.ConfigChange)
    57  		for k, v := range event.Changes {
    58  			changes[k] = &storage.ConfigChange{
    59  				OldValue:   v.OldValue,
    60  				NewValue:   v.NewValue,
    61  				ChangeType: storage.ConfigChangeType(v.ChangeType),
    62  			}
    63  		}
    64  		changeEvent := &storage.ChangeEvent{
    65  			Changes: changes,
    66  		}
    67  		listener.OnChange(changeEvent)
    68  	}
    69  }
    70  
    71  func InitialiseRemoteConfigListener() {
    72  	listener := &httpConfigListener{}
    73  	configType := config.GddConfigRemoteType.LoadOrDefault(config.DefaultGddConfigRemoteType)
    74  	switch configType {
    75  	case "":
    76  		return
    77  	case config.NacosConfigType:
    78  		dataIdStr := config.GddNacosConfigDataid.LoadOrDefault(config.DefaultGddNacosConfigDataid)
    79  		dataIds := strings.Split(dataIdStr, ",")
    80  		listener.SkippedFirstEvent = true
    81  		for _, dataId := range dataIds {
    82  			configmgr.NacosClient.AddChangeListener(configmgr.NacosConfigListenerParam{
    83  				DataId:   "__" + dataId + "__" + "ddhttp",
    84  				OnChange: CallbackOnChange(listener),
    85  			})
    86  		}
    87  	case config.ApolloConfigType:
    88  		configmgr.ApolloClient.AddChangeListener(listener)
    89  	default:
    90  		logger.Warn().Msgf("[go-doudou] from ddhttp pkg: unknown config type: %s\n", configType)
    91  	}
    92  }
    93  
    94  func init() {
    95  	InitialiseRemoteConfigListener()
    96  }
    97  
    98  // metrics logs some metrics for http request
    99  func metrics(inner http.Handler) http.Handler {
   100  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   101  		m := httpsnoop.CaptureMetrics(inner, w, r)
   102  		logger.Info().
   103  			Msgf("%s\t%s\t%s\t%d\t%d\t%s", r.RemoteAddr,
   104  				r.Method,
   105  				r.URL,
   106  				m.Code,
   107  				m.Written,
   108  				m.Duration.String())
   109  
   110  	})
   111  }
   112  
   113  // log logs http request body and response body for debugging
   114  func log(inner http.Handler) http.Handler {
   115  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   116  		var (
   117  			reqBodyCopy io.ReadCloser
   118  			err         error
   119  			traceId     string
   120  		)
   121  		if reqBodyCopy, r.Body, err = model.CopyReqBody(r.Body); err != nil {
   122  			logger.Error().Err(err).Msg("call copyReqBody(r.Body) error")
   123  		}
   124  
   125  		rec := httptest.NewRecorder()
   126  		start := time.Now()
   127  		inner.ServeHTTP(rec, r)
   128  		elapsed := time.Since(start)
   129  		reqBody := model.GetReqBody(reqBodyCopy, r)
   130  		rid, _ := requestid.FromContext(r.Context())
   131  		span := opentracing.SpanFromContext(r.Context())
   132  		if jspan, ok := span.(*jaeger.Span); ok {
   133  			traceId = jspan.SpanContext().TraceID().String()
   134  		}
   135  		respBody := model.GetRespBody(rec)
   136  		reqQuery := r.URL.RawQuery
   137  		if unescape, err := url.QueryUnescape(reqQuery); err == nil {
   138  			reqQuery = unescape
   139  		}
   140  		fields := map[string]interface{}{
   141  			"remoteAddr":        r.RemoteAddr,
   142  			"httpMethod":        r.Method,
   143  			"requestUrl":        r.URL.String(),
   144  			"proto":             r.Proto,
   145  			"host":              r.Host,
   146  			"reqContentLength":  r.ContentLength,
   147  			"reqHeader":         r.Header,
   148  			"requestId":         rid,
   149  			"reqQuery":          reqQuery,
   150  			"reqBody":           reqBody,
   151  			"respBody":          respBody,
   152  			"statusCode":        rec.Result().StatusCode,
   153  			"respHeader":        rec.Result().Header,
   154  			"respContentLength": rec.Body.Len(),
   155  			"elapsedTime":       elapsed.String(),
   156  			"elapsed":           elapsed.Milliseconds(),
   157  			"span":              span,
   158  			"traceId":           traceId,
   159  		}
   160  		var reqLog string
   161  		if reqLog, err = model.JsonMarshalIndent(fields, "", "    ", true); err != nil {
   162  			reqLog = fmt.Sprintf("call jsonMarshalIndent(fields, \"\", \"    \", true) error: %s", err)
   163  		}
   164  		logger.Info().Fields(fields).Msg(reqLog)
   165  		header := rec.Result().Header
   166  		for k, v := range header {
   167  			w.Header()[k] = v
   168  		}
   169  		w.WriteHeader(rec.Result().StatusCode)
   170  		rec.Body.WriteTo(w)
   171  	})
   172  }
   173  
   174  // rest set Content-Type to application/json
   175  func rest(inner http.Handler) http.Handler {
   176  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   177  		if stringutils.IsEmpty(w.Header().Get("Content-Type")) {
   178  			w.Header().Set("Content-Type", "application/json; charset=UTF-8")
   179  		}
   180  		inner.ServeHTTP(w, r)
   181  	})
   182  }
   183  
   184  // fallbackContentType set fallback response Content-Type to contentType
   185  func fallbackContentType(contentType string) func(inner http.Handler) http.Handler {
   186  	return func(inner http.Handler) http.Handler {
   187  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   188  			if stringutils.IsEmpty(w.Header().Get("Content-Type")) {
   189  				w.Header().Set("Content-Type", contentType)
   190  			}
   191  			inner.ServeHTTP(w, r)
   192  		})
   193  	}
   194  }
   195  
   196  // basicAuth adds http basic auth validation
   197  func basicAuth() func(inner http.Handler) http.Handler {
   198  	username := config.DefaultGddManageUser
   199  	password := config.DefaultGddManagePass
   200  	return func(inner http.Handler) http.Handler {
   201  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   202  			if stringutils.IsNotEmpty(config.GddManageUser.Load()) {
   203  				username = config.GddManageUser.Load()
   204  			}
   205  			if stringutils.IsNotEmpty(config.GddManagePass.Load()) {
   206  				password = config.GddManagePass.Load()
   207  			}
   208  			user, pass, ok := r.BasicAuth()
   209  			if !ok || subtle.ConstantTimeCompare([]byte(user), []byte(username)) != 1 || subtle.ConstantTimeCompare([]byte(pass), []byte(password)) != 1 {
   210  				w.Header().Set("WWW-Authenticate", `Basic realm="Provide user name and password"`)
   211  				w.WriteHeader(401)
   212  				w.Write([]byte("Unauthorised.\n"))
   213  				return
   214  			}
   215  			inner.ServeHTTP(w, r)
   216  		})
   217  	}
   218  }
   219  
   220  // recovery handles panic from processing incoming http request
   221  func recovery(inner http.Handler) http.Handler {
   222  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   223  		defer func() {
   224  			if e := recover(); e != nil {
   225  				statusCode := http.StatusInternalServerError
   226  				respErr := fmt.Sprintf("%v", e)
   227  				if err, ok := e.(error); ok {
   228  					if errors.Is(err, context.Canceled) {
   229  						statusCode = http.StatusBadRequest
   230  					} else {
   231  						var bizErr model.BizError
   232  						if errors.As(err, &bizErr) {
   233  							statusCode = bizErr.StatusCode
   234  							if bizErr.Cause != nil {
   235  								e = bizErr.Cause
   236  							}
   237  							respErr = bizErr.Error()
   238  						}
   239  					}
   240  				}
   241  				logger.Error().Msgf("panic: %+v\n\nstacktrace from panic: %s\n", e, string(debug.Stack()))
   242  				http.Error(w, respErr, statusCode)
   243  			}
   244  		}()
   245  		inner.ServeHTTP(w, r)
   246  	})
   247  }
   248  
   249  // tracing add jaeger tracing middleware
   250  func tracing(inner http.Handler) http.Handler {
   251  	return nethttp.Middleware(
   252  		opentracing.GlobalTracer(),
   253  		inner,
   254  		nethttp.OperationNameFunc(func(r *http.Request) string {
   255  			return fmt.Sprintf("HTTP %s: %s", r.Method, r.URL.Path)
   256  		}))
   257  }
   258  
   259  var RunnerChain = goresilience.RunnerChain
   260  
   261  // BulkHead add bulk head pattern middleware based on https://github.com/slok/goresilience
   262  // workers is the number of workers in the execution pool.
   263  // maxWaitTime is the max time an incoming request will wait to execute before being dropped its execution and return 429 response.
   264  func BulkHead(workers int, maxWaitTime time.Duration) func(inner http.Handler) http.Handler {
   265  	runner := RunnerChain(
   266  		bulkhead.NewMiddleware(bulkhead.Config{
   267  			Workers:     workers,
   268  			MaxWaitTime: maxWaitTime,
   269  		}),
   270  	)
   271  	return func(inner http.Handler) http.Handler {
   272  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   273  			err := runner.Run(r.Context(), func(_ context.Context) error {
   274  				inner.ServeHTTP(w, r)
   275  				return nil
   276  			})
   277  			if err != nil {
   278  				http.Error(w, "too many requests", http.StatusTooManyRequests)
   279  			}
   280  		})
   281  	}
   282  }
   283  
   284  func BodyMaxBytes(n int64) func(inner http.Handler) http.Handler {
   285  	return func(inner http.Handler) http.Handler {
   286  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   287  			r2 := *r
   288  			r2.Body = http.MaxBytesReader(w, r.Body, n)
   289  			inner.ServeHTTP(w, &r2)
   290  		})
   291  	}
   292  }