github.com/erda-project/erda-infra@v1.0.9/providers/httpserver/handler.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package httpserver
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/json"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  	"net/http"
    24  	"reflect"
    25  	"strconv"
    26  
    27  	"github.com/erda-project/erda-infra/providers/httpserver/server"
    28  	"github.com/go-playground/validator"
    29  	"github.com/labstack/echo"
    30  	"github.com/recallsong/go-utils/errorx"
    31  	"github.com/recallsong/go-utils/reflectx"
    32  )
    33  
    34  type (
    35  	// Response .
    36  	Response interface {
    37  		Status(Context) int
    38  		ReadCloser(Context) io.ReadCloser
    39  		Error(Context) error
    40  	}
    41  	// ResponseGetter .
    42  	ResponseGetter interface {
    43  		Response(ctx Context) Response
    44  	}
    45  	// Interceptor .
    46  	Interceptor func(handler func(ctx Context) error) func(ctx Context) error
    47  )
    48  
    49  func getInterceptors(options []interface{}) []server.MiddlewareFunc {
    50  	var list []server.MiddlewareFunc
    51  	for _, opt := range options {
    52  		var inter Interceptor
    53  		switch val := opt.(type) {
    54  		case Interceptor:
    55  			inter = val
    56  		case func(handler func(ctx Context) error) func(ctx Context) error:
    57  			inter = Interceptor(val)
    58  		case server.MiddlewareFunc:
    59  			list = append(list, val)
    60  		case func(server.HandlerFunc) server.HandlerFunc:
    61  			list = append(list, val)
    62  		default:
    63  			continue
    64  		}
    65  		if inter != nil {
    66  			list = append(list, func(fn server.HandlerFunc) server.HandlerFunc {
    67  				handler := inter(func(ctx Context) error {
    68  					return fn(ctx.(*context))
    69  				})
    70  				return func(ctx server.Context) error {
    71  					return handler(ctx.(*context))
    72  				}
    73  			})
    74  		}
    75  	}
    76  	return list
    77  }
    78  
    79  func (r *router) add(method, path string, handler interface{}, inters []server.MiddlewareFunc, outer server.MiddlewareFunc) server.HandlerFunc {
    80  	var echoHandler server.HandlerFunc
    81  	switch fn := handler.(type) {
    82  	case server.HandlerFunc:
    83  		echoHandler = fn
    84  	case func(server.Context) error:
    85  		echoHandler = server.HandlerFunc(fn)
    86  	case func(server.Context):
    87  		echoHandler = server.HandlerFunc(func(ctx server.Context) error {
    88  			fn(ctx)
    89  			return nil
    90  		})
    91  	case http.HandlerFunc:
    92  		echoHandler = server.HandlerFunc(func(ctx server.Context) error {
    93  			fn(ctx.Response(), ctx.Request())
    94  			return nil
    95  		})
    96  	case func(http.ResponseWriter, *http.Request):
    97  		echoHandler = server.HandlerFunc(func(ctx server.Context) error {
    98  			fn(ctx.Response(), ctx.Request())
    99  			return nil
   100  		})
   101  	case func(*http.Request, http.ResponseWriter):
   102  		echoHandler = server.HandlerFunc(func(ctx server.Context) error {
   103  			fn(ctx.Request(), ctx.Response())
   104  			return nil
   105  		})
   106  	case http.Handler:
   107  		echoHandler = server.HandlerFunc(func(ctx server.Context) error {
   108  			fn.ServeHTTP(ctx.Response(), ctx.Request())
   109  			return nil
   110  		})
   111  	default:
   112  		echoHandler = r.handlerWrap(handler)
   113  		if echoHandler == nil {
   114  			panic(fmt.Errorf("%s %s: not support http server handler type: %v", method, path, handler))
   115  		}
   116  	}
   117  	if outer != nil {
   118  		list := make([]server.MiddlewareFunc, 1+len(r.interceptors)+len(inters))
   119  		list[0] = outer
   120  		copy(list[1:], r.interceptors)
   121  		copy(list[1+len(r.interceptors):], inters)
   122  		inters = list
   123  	} else {
   124  		inters = append(r.interceptors[0:len(r.interceptors):len(r.interceptors)], inters...)
   125  	}
   126  	if len(inters) > 0 {
   127  		handler := echoHandler
   128  		for i := len(inters) - 1; i >= 0; i-- {
   129  			handler = inters[i](handler)
   130  		}
   131  		echoHandler = handler
   132  	}
   133  	r.tx.Add(method, path, echoHandler)
   134  	return echoHandler
   135  }
   136  
   137  var (
   138  	readerType      = reflect.TypeOf((*io.Reader)(nil)).Elem()
   139  	readCloserType  = reflect.TypeOf((*io.ReadCloser)(nil)).Elem()
   140  	errorType       = reflect.TypeOf((*error)(nil)).Elem()
   141  	requestType     = reflect.TypeOf((*http.Request)(nil))
   142  	responseType    = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem()
   143  	echoContextType = reflect.TypeOf((*server.Context)(nil)).Elem()
   144  	contextType     = reflect.TypeOf((*Context)(nil)).Elem()
   145  	interfaceType   = reflect.TypeOf((*interface{})(nil)).Elem()
   146  )
   147  
   148  func (r *router) handlerWrap(handler interface{}) server.HandlerFunc {
   149  	typ := reflect.TypeOf(handler)
   150  	if typ.Kind() == reflect.Func {
   151  		val := reflect.ValueOf(handler)
   152  		var argGets []func(ctx server.Context) (interface{}, error)
   153  		argNum := typ.NumIn()
   154  		for i := 0; i < argNum; i++ {
   155  			argTyp := typ.In(i)
   156  			getter := argGetter(argTyp)
   157  			if getter == nil {
   158  				return nil
   159  			}
   160  			argGets = append(argGets, getter)
   161  		}
   162  		retNum := typ.NumOut()
   163  		if retNum > 3 {
   164  			return nil
   165  		}
   166  		var retGet func(values []reflect.Value) (*int, io.ReadCloser, io.Reader, interface{}, error)
   167  		var retIndex [5]*int
   168  		var hasRet bool
   169  		for i := 0; i < retNum; i++ {
   170  			retTyp := typ.Out(i)
   171  			index := i
   172  			if retTyp.Kind() == reflect.Int {
   173  				if retIndex[0] == nil {
   174  					retIndex[0] = &index
   175  					hasRet = true
   176  					continue
   177  				}
   178  			} else if retTyp.AssignableTo(readCloserType) {
   179  				if retIndex[1] == nil {
   180  					retIndex[1] = &index
   181  					hasRet = true
   182  					continue
   183  				}
   184  			} else if retTyp.AssignableTo(readerType) {
   185  				if retIndex[2] == nil {
   186  					retIndex[2] = &index
   187  					hasRet = true
   188  					continue
   189  				}
   190  			} else if retTyp == errorType {
   191  				if retIndex[3] == nil {
   192  					retIndex[3] = &index
   193  					hasRet = true
   194  					continue
   195  				}
   196  			} else if retTyp == interfaceType {
   197  				if retIndex[4] == nil {
   198  					retIndex[4] = &index
   199  					hasRet = true
   200  					continue
   201  				}
   202  			}
   203  			return nil
   204  		}
   205  		if hasRet {
   206  			retGet = func(values []reflect.Value) (status *int, readerCloser io.ReadCloser, reader io.Reader, data interface{}, err error) {
   207  				if retIndex[0] != nil {
   208  					val := int(values[*retIndex[0]].Int())
   209  					status = &val
   210  				}
   211  				if retIndex[1] != nil {
   212  					val := values[*retIndex[1]].Interface()
   213  					readerCloser = val.(io.ReadCloser)
   214  				}
   215  				if retIndex[2] != nil {
   216  					val := values[*retIndex[2]].Interface()
   217  					reader = val.(io.Reader)
   218  				}
   219  				if retIndex[3] != nil {
   220  					val := values[*retIndex[3]].Interface()
   221  					if val != nil {
   222  						err = val.(error)
   223  					}
   224  				}
   225  				if retIndex[4] != nil {
   226  					data = values[*retIndex[4]].Interface()
   227  				}
   228  				return
   229  			}
   230  		}
   231  		return server.HandlerFunc(func(ctx server.Context) error {
   232  			var values []reflect.Value
   233  			for _, getter := range argGets {
   234  				val, err := getter(ctx)
   235  				if err != nil {
   236  					if _, ok := err.(validator.ValidationErrors); ok {
   237  						//TODO: custom error encode
   238  						return ctx.JSON(400, map[string]interface{}{
   239  							"success": false,
   240  							"err": map[string]interface{}{
   241  								"code": "400",
   242  								"msg":  err.Error(),
   243  							},
   244  						})
   245  					}
   246  					if herr, ok := err.(*echo.HTTPError); ok {
   247  						if http.StatusBadRequest <= herr.Code && herr.Code < http.StatusInternalServerError {
   248  							//TODO: custom error encode
   249  							ctx.JSON(400, map[string]interface{}{
   250  								"success": false,
   251  								"err": map[string]interface{}{
   252  									"code": strconv.Itoa(herr.Code),
   253  									"msg":  herr.Message,
   254  								},
   255  							})
   256  						}
   257  					}
   258  					return err
   259  				}
   260  				value := reflect.ValueOf(val)
   261  				values = append(values, value)
   262  			}
   263  			returns := val.Call(values)
   264  			if retGet == nil {
   265  				return nil
   266  			}
   267  			status, readCloser, reader, data, err := retGet(returns)
   268  			if data != nil {
   269  				var resp Response
   270  				context := ctx.(Context)
   271  				switch val := data.(type) {
   272  				case ResponseGetter:
   273  					resp = val.Response(context)
   274  				case Response:
   275  					resp = val
   276  				}
   277  				if resp != nil {
   278  					rc := resp.ReadCloser(context)
   279  					if rc != nil {
   280  						readCloser = rc
   281  					}
   282  					statusCode := resp.Status(context)
   283  					if statusCode > 0 {
   284  						status = &statusCode
   285  					}
   286  					e := resp.Error(context)
   287  					if e != nil {
   288  						err = e
   289  					}
   290  				}
   291  			}
   292  			if status != nil {
   293  				ctx.Response().WriteHeader(*status)
   294  			}
   295  			var errs errorx.Errors
   296  			if err != nil {
   297  				errs = append(errs, err)
   298  			}
   299  			if readCloser != nil {
   300  				defer readCloser.Close()
   301  				_, err = io.Copy(ctx.Response(), readCloser)
   302  				if err != nil {
   303  					errs = append(errs, err)
   304  				}
   305  			} else if reader != nil {
   306  				_, err = io.Copy(ctx.Response(), reader)
   307  				if err != nil {
   308  					errs = append(errs, err)
   309  				}
   310  			} else if data != nil {
   311  				switch val := data.(type) {
   312  				case string:
   313  					_, err = ctx.Response().Write(reflectx.StringToBytes(val))
   314  				case []byte:
   315  					_, err = ctx.Response().Write(val)
   316  				default:
   317  					err = json.NewEncoder(ctx.Response()).Encode(data)
   318  				}
   319  				if err != nil {
   320  					errs = append(errs, err)
   321  				}
   322  			}
   323  			return errs.MaybeUnwrap()
   324  		})
   325  	}
   326  	return nil
   327  }
   328  
   329  func argGetter(argTyp reflect.Type) func(ctx server.Context) (interface{}, error) {
   330  	if argTyp == requestType {
   331  		return requestGetter
   332  	} else if argTyp == responseType {
   333  		return responseGetter
   334  	} else if argTyp == contextType || argTyp == echoContextType {
   335  		return contextGetter
   336  	} else {
   337  		kind := argTyp.Kind()
   338  		if kind == reflect.String {
   339  			return requestBodyStirngGetter
   340  		} else if kind == reflect.Slice && argTyp.Elem().Kind() == reflect.Uint8 {
   341  			return requestBodyBytesGetter
   342  		}
   343  		typ := argTyp
   344  		for kind == reflect.Ptr {
   345  			typ = typ.Elem()
   346  			kind = typ.Kind()
   347  		}
   348  		switch kind {
   349  		case reflect.Struct:
   350  			var validate bool
   351  			for i, num := 0, typ.NumField(); i < num; i++ {
   352  				if len(typ.Field(i).Tag.Get("validate")) > 0 {
   353  					validate = true
   354  					break
   355  				}
   356  			}
   357  			return requestDataBind(argTyp, validate)
   358  		case reflect.Map, reflect.Interface:
   359  			return requestDataBind(argTyp, false)
   360  		case reflect.String:
   361  			return requestBodyStirngGetter
   362  		case reflect.Bool,
   363  			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   364  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   365  			reflect.Float32, reflect.Float64,
   366  			reflect.Array, reflect.Slice:
   367  			return requestValuesGetter(argTyp)
   368  		default:
   369  			return nil
   370  		}
   371  	}
   372  }
   373  
   374  func requestGetter(ctx server.Context) (interface{}, error)  { return ctx.Request(), nil }
   375  func responseGetter(ctx server.Context) (interface{}, error) { return ctx.Response(), nil }
   376  func contextGetter(ctx server.Context) (interface{}, error)  { return ctx, nil }
   377  func requestDataBind(typ reflect.Type, validate bool) func(server.Context) (interface{}, error) {
   378  	return func(ctx server.Context) (data interface{}, err error) {
   379  		outVal := reflect.New(typ)
   380  		if typ.Kind() != reflect.Ptr {
   381  			data = outVal.Interface()
   382  			err = ctx.Bind(data)
   383  		} else {
   384  			eval := outVal.Elem()
   385  			etype := typ.Elem()
   386  			for etype.Kind() == reflect.Ptr {
   387  				v := reflect.New(etype)
   388  				eval.Set(v)
   389  				eval = v.Elem()
   390  				etype = etype.Elem()
   391  			}
   392  			switch etype.Kind() {
   393  			case reflect.Map:
   394  				v := reflect.New(etype)
   395  				v.Elem().Set(reflect.MakeMap(etype))
   396  				eval.Set(v)
   397  			case reflect.Slice:
   398  				v := reflect.New(etype)
   399  				v.Elem().Set(reflect.MakeSlice(etype, 0, 0))
   400  				eval.Set(v)
   401  			default:
   402  				eval.Set(reflect.New(etype))
   403  			}
   404  			data = eval.Interface()
   405  			err = ctx.Bind(data)
   406  		}
   407  		if err != nil {
   408  			return nil, err
   409  		}
   410  		if validate {
   411  			err = ctx.Validate(data)
   412  			if err != nil {
   413  				return nil, err
   414  			}
   415  		}
   416  		return outVal.Elem().Interface(), nil
   417  	}
   418  }
   419  func requestValuesGetter(typ reflect.Type) func(ctx server.Context) (interface{}, error) {
   420  	return func(ctx server.Context) (interface{}, error) {
   421  		out := reflect.New(typ)
   422  		byts, err := ioutil.ReadAll(ctx.Request().Body)
   423  		if err != nil {
   424  			return nil, fmt.Errorf("fail to read body: %s", err)
   425  		}
   426  		ctx.Request().Body = ioutil.NopCloser(bytes.NewBuffer(byts))
   427  		err = json.Unmarshal(byts, out.Interface())
   428  		if err != nil {
   429  			return nil, fmt.Errorf("fail to Unmarshal body: %s", err)
   430  		}
   431  		return out.Elem().Interface(), nil
   432  	}
   433  }
   434  func requestBodyBytesGetter(ctx server.Context) (interface{}, error) {
   435  	byts, err := ioutil.ReadAll(ctx.Request().Body)
   436  	if err != nil {
   437  		return nil, fmt.Errorf("fail to read body: %s", err)
   438  	}
   439  	ctx.Request().Body = ioutil.NopCloser(bytes.NewBuffer(byts))
   440  	return byts, nil
   441  }
   442  
   443  func requestBodyStirngGetter(ctx server.Context) (interface{}, error) {
   444  	byts, err := ioutil.ReadAll(ctx.Request().Body)
   445  	if err != nil {
   446  		return "", fmt.Errorf("fail to read body: %s", err)
   447  	}
   448  	return reflectx.BytesToString(byts), nil
   449  }
   450  
   451  type structValidator struct {
   452  	validator *validator.Validate
   453  }
   454  
   455  // Validate .
   456  func (v *structValidator) Validate(i interface{}) error {
   457  	return v.validator.Struct(i)
   458  }