github.com/unionj-cloud/go-doudou@v1.3.8-0.20221011095552-0088008e5b31/framework/http/middleware.go (about) 1 package ddhttp 2 3 import ( 4 "context" 5 "crypto/subtle" 6 "fmt" 7 "github.com/apolloconfig/agollo/v4/storage" 8 "github.com/ascarter/requestid" 9 "github.com/felixge/httpsnoop" 10 "github.com/opentracing-contrib/go-stdlib/nethttp" 11 "github.com/opentracing/opentracing-go" 12 "github.com/pkg/errors" 13 "github.com/slok/goresilience" 14 "github.com/slok/goresilience/bulkhead" 15 "github.com/uber/jaeger-client-go" 16 "github.com/unionj-cloud/go-doudou/framework/configmgr" 17 "github.com/unionj-cloud/go-doudou/framework/http/model" 18 "github.com/unionj-cloud/go-doudou/framework/internal/config" 19 "github.com/unionj-cloud/go-doudou/toolkit/stringutils" 20 logger "github.com/unionj-cloud/go-doudou/toolkit/zlogger" 21 "io" 22 "net/http" 23 "net/http/httptest" 24 "net/url" 25 "os" 26 "runtime/debug" 27 "strings" 28 "time" 29 ) 30 31 type httpConfigListener struct { 32 configmgr.BaseApolloListener 33 } 34 35 func NewHttpConfigListener() *httpConfigListener { 36 return &httpConfigListener{} 37 } 38 39 func (c *httpConfigListener) OnChange(event *storage.ChangeEvent) { 40 c.Lock.Lock() 41 defer c.Lock.Unlock() 42 if !c.SkippedFirstEvent { 43 c.SkippedFirstEvent = true 44 return 45 } 46 for key, value := range event.Changes { 47 upperKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_")) 48 if strings.HasPrefix(upperKey, "GDD_MANAGE_") { 49 _ = os.Setenv(upperKey, fmt.Sprint(value.NewValue)) 50 } 51 } 52 } 53 54 func CallbackOnChange(listener *httpConfigListener) func(event *configmgr.NacosChangeEvent) { 55 return func(event *configmgr.NacosChangeEvent) { 56 changes := make(map[string]*storage.ConfigChange) 57 for k, v := range event.Changes { 58 changes[k] = &storage.ConfigChange{ 59 OldValue: v.OldValue, 60 NewValue: v.NewValue, 61 ChangeType: storage.ConfigChangeType(v.ChangeType), 62 } 63 } 64 changeEvent := &storage.ChangeEvent{ 65 Changes: changes, 66 } 67 listener.OnChange(changeEvent) 68 } 69 } 70 71 func InitialiseRemoteConfigListener() { 72 listener := &httpConfigListener{} 73 configType := config.GddConfigRemoteType.LoadOrDefault(config.DefaultGddConfigRemoteType) 74 switch configType { 75 case "": 76 return 77 case config.NacosConfigType: 78 dataIdStr := config.GddNacosConfigDataid.LoadOrDefault(config.DefaultGddNacosConfigDataid) 79 dataIds := strings.Split(dataIdStr, ",") 80 listener.SkippedFirstEvent = true 81 for _, dataId := range dataIds { 82 configmgr.NacosClient.AddChangeListener(configmgr.NacosConfigListenerParam{ 83 DataId: "__" + dataId + "__" + "ddhttp", 84 OnChange: CallbackOnChange(listener), 85 }) 86 } 87 case config.ApolloConfigType: 88 configmgr.ApolloClient.AddChangeListener(listener) 89 default: 90 logger.Warn().Msgf("[go-doudou] from ddhttp pkg: unknown config type: %s\n", configType) 91 } 92 } 93 94 func init() { 95 InitialiseRemoteConfigListener() 96 } 97 98 // metrics logs some metrics for http request 99 func metrics(inner http.Handler) http.Handler { 100 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 101 m := httpsnoop.CaptureMetrics(inner, w, r) 102 logger.Info(). 103 Msgf("%s\t%s\t%s\t%d\t%d\t%s", r.RemoteAddr, 104 r.Method, 105 r.URL, 106 m.Code, 107 m.Written, 108 m.Duration.String()) 109 110 }) 111 } 112 113 // log logs http request body and response body for debugging 114 func log(inner http.Handler) http.Handler { 115 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 var ( 117 reqBodyCopy io.ReadCloser 118 err error 119 traceId string 120 ) 121 if reqBodyCopy, r.Body, err = model.CopyReqBody(r.Body); err != nil { 122 logger.Error().Err(err).Msg("call copyReqBody(r.Body) error") 123 } 124 125 rec := httptest.NewRecorder() 126 start := time.Now() 127 inner.ServeHTTP(rec, r) 128 elapsed := time.Since(start) 129 reqBody := model.GetReqBody(reqBodyCopy, r) 130 rid, _ := requestid.FromContext(r.Context()) 131 span := opentracing.SpanFromContext(r.Context()) 132 if jspan, ok := span.(*jaeger.Span); ok { 133 traceId = jspan.SpanContext().TraceID().String() 134 } 135 respBody := model.GetRespBody(rec) 136 reqQuery := r.URL.RawQuery 137 if unescape, err := url.QueryUnescape(reqQuery); err == nil { 138 reqQuery = unescape 139 } 140 fields := map[string]interface{}{ 141 "remoteAddr": r.RemoteAddr, 142 "httpMethod": r.Method, 143 "requestUrl": r.URL.String(), 144 "proto": r.Proto, 145 "host": r.Host, 146 "reqContentLength": r.ContentLength, 147 "reqHeader": r.Header, 148 "requestId": rid, 149 "reqQuery": reqQuery, 150 "reqBody": reqBody, 151 "respBody": respBody, 152 "statusCode": rec.Result().StatusCode, 153 "respHeader": rec.Result().Header, 154 "respContentLength": rec.Body.Len(), 155 "elapsedTime": elapsed.String(), 156 "elapsed": elapsed.Milliseconds(), 157 "span": span, 158 "traceId": traceId, 159 } 160 var reqLog string 161 if reqLog, err = model.JsonMarshalIndent(fields, "", " ", true); err != nil { 162 reqLog = fmt.Sprintf("call jsonMarshalIndent(fields, \"\", \" \", true) error: %s", err) 163 } 164 logger.Info().Fields(fields).Msg(reqLog) 165 header := rec.Result().Header 166 for k, v := range header { 167 w.Header()[k] = v 168 } 169 w.WriteHeader(rec.Result().StatusCode) 170 rec.Body.WriteTo(w) 171 }) 172 } 173 174 // rest set Content-Type to application/json 175 func rest(inner http.Handler) http.Handler { 176 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 177 if stringutils.IsEmpty(w.Header().Get("Content-Type")) { 178 w.Header().Set("Content-Type", "application/json; charset=UTF-8") 179 } 180 inner.ServeHTTP(w, r) 181 }) 182 } 183 184 // fallbackContentType set fallback response Content-Type to contentType 185 func fallbackContentType(contentType string) func(inner http.Handler) http.Handler { 186 return func(inner http.Handler) http.Handler { 187 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 188 if stringutils.IsEmpty(w.Header().Get("Content-Type")) { 189 w.Header().Set("Content-Type", contentType) 190 } 191 inner.ServeHTTP(w, r) 192 }) 193 } 194 } 195 196 // basicAuth adds http basic auth validation 197 func basicAuth() func(inner http.Handler) http.Handler { 198 username := config.DefaultGddManageUser 199 password := config.DefaultGddManagePass 200 return func(inner http.Handler) http.Handler { 201 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 202 if stringutils.IsNotEmpty(config.GddManageUser.Load()) { 203 username = config.GddManageUser.Load() 204 } 205 if stringutils.IsNotEmpty(config.GddManagePass.Load()) { 206 password = config.GddManagePass.Load() 207 } 208 user, pass, ok := r.BasicAuth() 209 if !ok || subtle.ConstantTimeCompare([]byte(user), []byte(username)) != 1 || subtle.ConstantTimeCompare([]byte(pass), []byte(password)) != 1 { 210 w.Header().Set("WWW-Authenticate", `Basic realm="Provide user name and password"`) 211 w.WriteHeader(401) 212 w.Write([]byte("Unauthorised.\n")) 213 return 214 } 215 inner.ServeHTTP(w, r) 216 }) 217 } 218 } 219 220 // recovery handles panic from processing incoming http request 221 func recovery(inner http.Handler) http.Handler { 222 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 223 defer func() { 224 if e := recover(); e != nil { 225 statusCode := http.StatusInternalServerError 226 respErr := fmt.Sprintf("%v", e) 227 if err, ok := e.(error); ok { 228 if errors.Is(err, context.Canceled) { 229 statusCode = http.StatusBadRequest 230 } else { 231 var bizErr model.BizError 232 if errors.As(err, &bizErr) { 233 statusCode = bizErr.StatusCode 234 if bizErr.Cause != nil { 235 e = bizErr.Cause 236 } 237 respErr = bizErr.Error() 238 } 239 } 240 } 241 logger.Error().Msgf("panic: %+v\n\nstacktrace from panic: %s\n", e, string(debug.Stack())) 242 http.Error(w, respErr, statusCode) 243 } 244 }() 245 inner.ServeHTTP(w, r) 246 }) 247 } 248 249 // tracing add jaeger tracing middleware 250 func tracing(inner http.Handler) http.Handler { 251 return nethttp.Middleware( 252 opentracing.GlobalTracer(), 253 inner, 254 nethttp.OperationNameFunc(func(r *http.Request) string { 255 return fmt.Sprintf("HTTP %s: %s", r.Method, r.URL.Path) 256 })) 257 } 258 259 var RunnerChain = goresilience.RunnerChain 260 261 // BulkHead add bulk head pattern middleware based on https://github.com/slok/goresilience 262 // workers is the number of workers in the execution pool. 263 // maxWaitTime is the max time an incoming request will wait to execute before being dropped its execution and return 429 response. 264 func BulkHead(workers int, maxWaitTime time.Duration) func(inner http.Handler) http.Handler { 265 runner := RunnerChain( 266 bulkhead.NewMiddleware(bulkhead.Config{ 267 Workers: workers, 268 MaxWaitTime: maxWaitTime, 269 }), 270 ) 271 return func(inner http.Handler) http.Handler { 272 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 273 err := runner.Run(r.Context(), func(_ context.Context) error { 274 inner.ServeHTTP(w, r) 275 return nil 276 }) 277 if err != nil { 278 http.Error(w, "too many requests", http.StatusTooManyRequests) 279 } 280 }) 281 } 282 } 283 284 func BodyMaxBytes(n int64) func(inner http.Handler) http.Handler { 285 return func(inner http.Handler) http.Handler { 286 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 287 r2 := *r 288 r2.Body = http.MaxBytesReader(w, r.Body, n) 289 inner.ServeHTTP(w, &r2) 290 }) 291 } 292 }