github.com/erda-project/erda-infra@v1.0.9/providers/httpserver/handler.go (about) 1 // Copyright (c) 2021 Terminus, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package httpserver 16 17 import ( 18 "bytes" 19 "encoding/json" 20 "fmt" 21 "io" 22 "io/ioutil" 23 "net/http" 24 "reflect" 25 "strconv" 26 27 "github.com/erda-project/erda-infra/providers/httpserver/server" 28 "github.com/go-playground/validator" 29 "github.com/labstack/echo" 30 "github.com/recallsong/go-utils/errorx" 31 "github.com/recallsong/go-utils/reflectx" 32 ) 33 34 type ( 35 // Response . 36 Response interface { 37 Status(Context) int 38 ReadCloser(Context) io.ReadCloser 39 Error(Context) error 40 } 41 // ResponseGetter . 42 ResponseGetter interface { 43 Response(ctx Context) Response 44 } 45 // Interceptor . 46 Interceptor func(handler func(ctx Context) error) func(ctx Context) error 47 ) 48 49 func getInterceptors(options []interface{}) []server.MiddlewareFunc { 50 var list []server.MiddlewareFunc 51 for _, opt := range options { 52 var inter Interceptor 53 switch val := opt.(type) { 54 case Interceptor: 55 inter = val 56 case func(handler func(ctx Context) error) func(ctx Context) error: 57 inter = Interceptor(val) 58 case server.MiddlewareFunc: 59 list = append(list, val) 60 case func(server.HandlerFunc) server.HandlerFunc: 61 list = append(list, val) 62 default: 63 continue 64 } 65 if inter != nil { 66 list = append(list, func(fn server.HandlerFunc) server.HandlerFunc { 67 handler := inter(func(ctx Context) error { 68 return fn(ctx.(*context)) 69 }) 70 return func(ctx server.Context) error { 71 return handler(ctx.(*context)) 72 } 73 }) 74 } 75 } 76 return list 77 } 78 79 func (r *router) add(method, path string, handler interface{}, inters []server.MiddlewareFunc, outer server.MiddlewareFunc) server.HandlerFunc { 80 var echoHandler server.HandlerFunc 81 switch fn := handler.(type) { 82 case server.HandlerFunc: 83 echoHandler = fn 84 case func(server.Context) error: 85 echoHandler = server.HandlerFunc(fn) 86 case func(server.Context): 87 echoHandler = server.HandlerFunc(func(ctx server.Context) error { 88 fn(ctx) 89 return nil 90 }) 91 case http.HandlerFunc: 92 echoHandler = server.HandlerFunc(func(ctx server.Context) error { 93 fn(ctx.Response(), ctx.Request()) 94 return nil 95 }) 96 case func(http.ResponseWriter, *http.Request): 97 echoHandler = server.HandlerFunc(func(ctx server.Context) error { 98 fn(ctx.Response(), ctx.Request()) 99 return nil 100 }) 101 case func(*http.Request, http.ResponseWriter): 102 echoHandler = server.HandlerFunc(func(ctx server.Context) error { 103 fn(ctx.Request(), ctx.Response()) 104 return nil 105 }) 106 case http.Handler: 107 echoHandler = server.HandlerFunc(func(ctx server.Context) error { 108 fn.ServeHTTP(ctx.Response(), ctx.Request()) 109 return nil 110 }) 111 default: 112 echoHandler = r.handlerWrap(handler) 113 if echoHandler == nil { 114 panic(fmt.Errorf("%s %s: not support http server handler type: %v", method, path, handler)) 115 } 116 } 117 if outer != nil { 118 list := make([]server.MiddlewareFunc, 1+len(r.interceptors)+len(inters)) 119 list[0] = outer 120 copy(list[1:], r.interceptors) 121 copy(list[1+len(r.interceptors):], inters) 122 inters = list 123 } else { 124 inters = append(r.interceptors[0:len(r.interceptors):len(r.interceptors)], inters...) 125 } 126 if len(inters) > 0 { 127 handler := echoHandler 128 for i := len(inters) - 1; i >= 0; i-- { 129 handler = inters[i](handler) 130 } 131 echoHandler = handler 132 } 133 r.tx.Add(method, path, echoHandler) 134 return echoHandler 135 } 136 137 var ( 138 readerType = reflect.TypeOf((*io.Reader)(nil)).Elem() 139 readCloserType = reflect.TypeOf((*io.ReadCloser)(nil)).Elem() 140 errorType = reflect.TypeOf((*error)(nil)).Elem() 141 requestType = reflect.TypeOf((*http.Request)(nil)) 142 responseType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() 143 echoContextType = reflect.TypeOf((*server.Context)(nil)).Elem() 144 contextType = reflect.TypeOf((*Context)(nil)).Elem() 145 interfaceType = reflect.TypeOf((*interface{})(nil)).Elem() 146 ) 147 148 func (r *router) handlerWrap(handler interface{}) server.HandlerFunc { 149 typ := reflect.TypeOf(handler) 150 if typ.Kind() == reflect.Func { 151 val := reflect.ValueOf(handler) 152 var argGets []func(ctx server.Context) (interface{}, error) 153 argNum := typ.NumIn() 154 for i := 0; i < argNum; i++ { 155 argTyp := typ.In(i) 156 getter := argGetter(argTyp) 157 if getter == nil { 158 return nil 159 } 160 argGets = append(argGets, getter) 161 } 162 retNum := typ.NumOut() 163 if retNum > 3 { 164 return nil 165 } 166 var retGet func(values []reflect.Value) (*int, io.ReadCloser, io.Reader, interface{}, error) 167 var retIndex [5]*int 168 var hasRet bool 169 for i := 0; i < retNum; i++ { 170 retTyp := typ.Out(i) 171 index := i 172 if retTyp.Kind() == reflect.Int { 173 if retIndex[0] == nil { 174 retIndex[0] = &index 175 hasRet = true 176 continue 177 } 178 } else if retTyp.AssignableTo(readCloserType) { 179 if retIndex[1] == nil { 180 retIndex[1] = &index 181 hasRet = true 182 continue 183 } 184 } else if retTyp.AssignableTo(readerType) { 185 if retIndex[2] == nil { 186 retIndex[2] = &index 187 hasRet = true 188 continue 189 } 190 } else if retTyp == errorType { 191 if retIndex[3] == nil { 192 retIndex[3] = &index 193 hasRet = true 194 continue 195 } 196 } else if retTyp == interfaceType { 197 if retIndex[4] == nil { 198 retIndex[4] = &index 199 hasRet = true 200 continue 201 } 202 } 203 return nil 204 } 205 if hasRet { 206 retGet = func(values []reflect.Value) (status *int, readerCloser io.ReadCloser, reader io.Reader, data interface{}, err error) { 207 if retIndex[0] != nil { 208 val := int(values[*retIndex[0]].Int()) 209 status = &val 210 } 211 if retIndex[1] != nil { 212 val := values[*retIndex[1]].Interface() 213 readerCloser = val.(io.ReadCloser) 214 } 215 if retIndex[2] != nil { 216 val := values[*retIndex[2]].Interface() 217 reader = val.(io.Reader) 218 } 219 if retIndex[3] != nil { 220 val := values[*retIndex[3]].Interface() 221 if val != nil { 222 err = val.(error) 223 } 224 } 225 if retIndex[4] != nil { 226 data = values[*retIndex[4]].Interface() 227 } 228 return 229 } 230 } 231 return server.HandlerFunc(func(ctx server.Context) error { 232 var values []reflect.Value 233 for _, getter := range argGets { 234 val, err := getter(ctx) 235 if err != nil { 236 if _, ok := err.(validator.ValidationErrors); ok { 237 //TODO: custom error encode 238 return ctx.JSON(400, map[string]interface{}{ 239 "success": false, 240 "err": map[string]interface{}{ 241 "code": "400", 242 "msg": err.Error(), 243 }, 244 }) 245 } 246 if herr, ok := err.(*echo.HTTPError); ok { 247 if http.StatusBadRequest <= herr.Code && herr.Code < http.StatusInternalServerError { 248 //TODO: custom error encode 249 ctx.JSON(400, map[string]interface{}{ 250 "success": false, 251 "err": map[string]interface{}{ 252 "code": strconv.Itoa(herr.Code), 253 "msg": herr.Message, 254 }, 255 }) 256 } 257 } 258 return err 259 } 260 value := reflect.ValueOf(val) 261 values = append(values, value) 262 } 263 returns := val.Call(values) 264 if retGet == nil { 265 return nil 266 } 267 status, readCloser, reader, data, err := retGet(returns) 268 if data != nil { 269 var resp Response 270 context := ctx.(Context) 271 switch val := data.(type) { 272 case ResponseGetter: 273 resp = val.Response(context) 274 case Response: 275 resp = val 276 } 277 if resp != nil { 278 rc := resp.ReadCloser(context) 279 if rc != nil { 280 readCloser = rc 281 } 282 statusCode := resp.Status(context) 283 if statusCode > 0 { 284 status = &statusCode 285 } 286 e := resp.Error(context) 287 if e != nil { 288 err = e 289 } 290 } 291 } 292 if status != nil { 293 ctx.Response().WriteHeader(*status) 294 } 295 var errs errorx.Errors 296 if err != nil { 297 errs = append(errs, err) 298 } 299 if readCloser != nil { 300 defer readCloser.Close() 301 _, err = io.Copy(ctx.Response(), readCloser) 302 if err != nil { 303 errs = append(errs, err) 304 } 305 } else if reader != nil { 306 _, err = io.Copy(ctx.Response(), reader) 307 if err != nil { 308 errs = append(errs, err) 309 } 310 } else if data != nil { 311 switch val := data.(type) { 312 case string: 313 _, err = ctx.Response().Write(reflectx.StringToBytes(val)) 314 case []byte: 315 _, err = ctx.Response().Write(val) 316 default: 317 err = json.NewEncoder(ctx.Response()).Encode(data) 318 } 319 if err != nil { 320 errs = append(errs, err) 321 } 322 } 323 return errs.MaybeUnwrap() 324 }) 325 } 326 return nil 327 } 328 329 func argGetter(argTyp reflect.Type) func(ctx server.Context) (interface{}, error) { 330 if argTyp == requestType { 331 return requestGetter 332 } else if argTyp == responseType { 333 return responseGetter 334 } else if argTyp == contextType || argTyp == echoContextType { 335 return contextGetter 336 } else { 337 kind := argTyp.Kind() 338 if kind == reflect.String { 339 return requestBodyStirngGetter 340 } else if kind == reflect.Slice && argTyp.Elem().Kind() == reflect.Uint8 { 341 return requestBodyBytesGetter 342 } 343 typ := argTyp 344 for kind == reflect.Ptr { 345 typ = typ.Elem() 346 kind = typ.Kind() 347 } 348 switch kind { 349 case reflect.Struct: 350 var validate bool 351 for i, num := 0, typ.NumField(); i < num; i++ { 352 if len(typ.Field(i).Tag.Get("validate")) > 0 { 353 validate = true 354 break 355 } 356 } 357 return requestDataBind(argTyp, validate) 358 case reflect.Map, reflect.Interface: 359 return requestDataBind(argTyp, false) 360 case reflect.String: 361 return requestBodyStirngGetter 362 case reflect.Bool, 363 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 364 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 365 reflect.Float32, reflect.Float64, 366 reflect.Array, reflect.Slice: 367 return requestValuesGetter(argTyp) 368 default: 369 return nil 370 } 371 } 372 } 373 374 func requestGetter(ctx server.Context) (interface{}, error) { return ctx.Request(), nil } 375 func responseGetter(ctx server.Context) (interface{}, error) { return ctx.Response(), nil } 376 func contextGetter(ctx server.Context) (interface{}, error) { return ctx, nil } 377 func requestDataBind(typ reflect.Type, validate bool) func(server.Context) (interface{}, error) { 378 return func(ctx server.Context) (data interface{}, err error) { 379 outVal := reflect.New(typ) 380 if typ.Kind() != reflect.Ptr { 381 data = outVal.Interface() 382 err = ctx.Bind(data) 383 } else { 384 eval := outVal.Elem() 385 etype := typ.Elem() 386 for etype.Kind() == reflect.Ptr { 387 v := reflect.New(etype) 388 eval.Set(v) 389 eval = v.Elem() 390 etype = etype.Elem() 391 } 392 switch etype.Kind() { 393 case reflect.Map: 394 v := reflect.New(etype) 395 v.Elem().Set(reflect.MakeMap(etype)) 396 eval.Set(v) 397 case reflect.Slice: 398 v := reflect.New(etype) 399 v.Elem().Set(reflect.MakeSlice(etype, 0, 0)) 400 eval.Set(v) 401 default: 402 eval.Set(reflect.New(etype)) 403 } 404 data = eval.Interface() 405 err = ctx.Bind(data) 406 } 407 if err != nil { 408 return nil, err 409 } 410 if validate { 411 err = ctx.Validate(data) 412 if err != nil { 413 return nil, err 414 } 415 } 416 return outVal.Elem().Interface(), nil 417 } 418 } 419 func requestValuesGetter(typ reflect.Type) func(ctx server.Context) (interface{}, error) { 420 return func(ctx server.Context) (interface{}, error) { 421 out := reflect.New(typ) 422 byts, err := ioutil.ReadAll(ctx.Request().Body) 423 if err != nil { 424 return nil, fmt.Errorf("fail to read body: %s", err) 425 } 426 ctx.Request().Body = ioutil.NopCloser(bytes.NewBuffer(byts)) 427 err = json.Unmarshal(byts, out.Interface()) 428 if err != nil { 429 return nil, fmt.Errorf("fail to Unmarshal body: %s", err) 430 } 431 return out.Elem().Interface(), nil 432 } 433 } 434 func requestBodyBytesGetter(ctx server.Context) (interface{}, error) { 435 byts, err := ioutil.ReadAll(ctx.Request().Body) 436 if err != nil { 437 return nil, fmt.Errorf("fail to read body: %s", err) 438 } 439 ctx.Request().Body = ioutil.NopCloser(bytes.NewBuffer(byts)) 440 return byts, nil 441 } 442 443 func requestBodyStirngGetter(ctx server.Context) (interface{}, error) { 444 byts, err := ioutil.ReadAll(ctx.Request().Body) 445 if err != nil { 446 return "", fmt.Errorf("fail to read body: %s", err) 447 } 448 return reflectx.BytesToString(byts), nil 449 } 450 451 type structValidator struct { 452 validator *validator.Validate 453 } 454 455 // Validate . 456 func (v *structValidator) Validate(i interface{}) error { 457 return v.validator.Struct(i) 458 }