github.com/erda-project/erda-infra@v1.0.9/providers/httpserver/data_bind.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package httpserver
    16  
    17  import (
    18  	"bytes"
    19  	"encoding"
    20  	"encoding/json"
    21  	"encoding/xml"
    22  	"errors"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"net/http"
    26  	"reflect"
    27  	"strconv"
    28  	"strings"
    29  
    30  	"github.com/labstack/echo"
    31  )
    32  
    // dataBinder implements echo's Binder interface for this package (see its
    // Bind method). It holds no state, so its zero value is ready to use.
    type dataBinder struct{}
    34  
    35  // Bind implements the `Binder#Bind` function.
    36  func (b *dataBinder) Bind(i interface{}, c echo.Context) (err error) {
    37  	req := c.Request()
    38  	if req.ContentLength > 0 {
    39  		ctype := req.Header.Get(echo.HeaderContentType)
    40  		if len(ctype) <= 0 {
    41  			ctype = echo.MIMEApplicationJSON
    42  		}
    43  		body, err := ioutil.ReadAll(req.Body)
    44  		if err != nil {
    45  			return fmt.Errorf("fail to read body: %s", err)
    46  		}
    47  		req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
    48  		switch {
    49  		case strings.HasPrefix(ctype, echo.MIMEApplicationJSON):
    50  			if err = json.Unmarshal(body, i); err != nil {
    51  				if ute, ok := err.(*json.UnmarshalTypeError); ok {
    52  					return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
    53  				} else if se, ok := err.(*json.SyntaxError); ok {
    54  					return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
    55  				}
    56  				return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
    57  			}
    58  		case strings.HasPrefix(ctype, echo.MIMEApplicationXML), strings.HasPrefix(ctype, echo.MIMETextXML):
    59  			if err = xml.Unmarshal(body, i); err != nil {
    60  				if ute, ok := err.(*xml.UnsupportedTypeError); ok {
    61  					return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
    62  				} else if se, ok := err.(*xml.SyntaxError); ok {
    63  					return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
    64  				}
    65  				return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
    66  			}
    67  		case strings.HasPrefix(ctype, echo.MIMEApplicationForm), strings.HasPrefix(ctype, echo.MIMEMultipartForm):
    68  			params, err := c.FormParams()
    69  			if err != nil {
    70  				return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
    71  			}
    72  			if err = b.bindData(i, params, "form"); err != nil {
    73  				return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
    74  			}
    75  		default:
    76  			return echo.ErrUnsupportedMediaType
    77  		}
    78  	}
    79  	typ := reflect.TypeOf(i)
    80  	for typ.Kind() == reflect.Ptr {
    81  		typ = typ.Elem()
    82  	}
    83  	if typ.Kind() == reflect.Struct {
    84  		if err = b.bindData(i, c.QueryParams(), "query"); err != nil {
    85  			return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
    86  		}
    87  		params := map[string][]string{}
    88  		if ctx, ok := c.(*context); ok && len(ctx.vars) > 0 {
    89  			for k, v := range ctx.vars {
    90  				params[k] = []string{v}
    91  			}
    92  		} else {
    93  			names := c.ParamNames()
    94  			values := c.ParamValues()
    95  			for i, name := range names {
    96  				params[name] = []string{values[i]}
    97  			}
    98  		}
    99  		if err := b.bindData(i, params, "param"); err != nil {
   100  			return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
   101  		}
   102  	}
   103  	return
   104  }
   105  
   106  func (b *dataBinder) bindData(ptr interface{}, data map[string][]string, tag string) error {
   107  	if ptr == nil || len(data) == 0 {
   108  		return nil
   109  	}
   110  	typ := reflect.TypeOf(ptr).Elem()
   111  	val := reflect.ValueOf(ptr).Elem()
   112  
   113  	// Map
   114  	if typ.Kind() == reflect.Map {
   115  		for k, v := range data {
   116  			val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
   117  		}
   118  		return nil
   119  	}
   120  
   121  	// !struct
   122  	if typ.Kind() != reflect.Struct {
   123  		return errors.New("binding element must be a struct")
   124  	}
   125  
   126  	for i := 0; i < typ.NumField(); i++ {
   127  		typeField := typ.Field(i)
   128  		structField := val.Field(i)
   129  		if !structField.CanSet() {
   130  			continue
   131  		}
   132  		structFieldKind := structField.Kind()
   133  		inputFieldName := typeField.Tag.Get(tag)
   134  
   135  		if inputFieldName == "" {
   136  			inputFieldName = typeField.Name
   137  			// If tag is nil, we inspect if the field is a struct.
   138  			if _, ok := structField.Addr().Interface().(echo.BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
   139  				if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
   140  					return err
   141  				}
   142  				continue
   143  			}
   144  		}
   145  
   146  		inputValue, exists := data[inputFieldName]
   147  		if !exists {
   148  			// Go json.Unmarshal supports case insensitive binding.  However the
   149  			// url params are bound case sensitive which is inconsistent.  To
   150  			// fix this we must check all of the map values in a
   151  			// case-insensitive search.
   152  			for k, v := range data {
   153  				if strings.EqualFold(k, inputFieldName) {
   154  					inputValue = v
   155  					exists = true
   156  					break
   157  				}
   158  			}
   159  		}
   160  
   161  		if !exists {
   162  			continue
   163  		}
   164  
   165  		// Call this first, in case we're dealing with an alias to an array type
   166  		if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
   167  			if err != nil {
   168  				return err
   169  			}
   170  			continue
   171  		}
   172  
   173  		numElems := len(inputValue)
   174  		if structFieldKind == reflect.Slice && numElems > 0 {
   175  			sliceOf := structField.Type().Elem().Kind()
   176  			slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
   177  			for j := 0; j < numElems; j++ {
   178  				if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
   179  					return err
   180  				}
   181  			}
   182  			val.Field(i).Set(slice)
   183  		} else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
   184  			return err
   185  
   186  		}
   187  	}
   188  	return nil
   189  }
   190  
   191  func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
   192  	// But also call it here, in case we're dealing with an array of BindUnmarshalers
   193  	if ok, err := unmarshalField(valueKind, val, structField); ok {
   194  		return err
   195  	}
   196  
   197  	switch valueKind {
   198  	case reflect.Ptr:
   199  		return setWithProperType(structField.Elem().Kind(), val, structField.Elem())
   200  	case reflect.Int:
   201  		return setIntField(val, 0, structField)
   202  	case reflect.Int8:
   203  		return setIntField(val, 8, structField)
   204  	case reflect.Int16:
   205  		return setIntField(val, 16, structField)
   206  	case reflect.Int32:
   207  		return setIntField(val, 32, structField)
   208  	case reflect.Int64:
   209  		return setIntField(val, 64, structField)
   210  	case reflect.Uint:
   211  		return setUintField(val, 0, structField)
   212  	case reflect.Uint8:
   213  		return setUintField(val, 8, structField)
   214  	case reflect.Uint16:
   215  		return setUintField(val, 16, structField)
   216  	case reflect.Uint32:
   217  		return setUintField(val, 32, structField)
   218  	case reflect.Uint64:
   219  		return setUintField(val, 64, structField)
   220  	case reflect.Bool:
   221  		return setBoolField(val, structField)
   222  	case reflect.Float32:
   223  		return setFloatField(val, 32, structField)
   224  	case reflect.Float64:
   225  		return setFloatField(val, 64, structField)
   226  	case reflect.String:
   227  		structField.SetString(val)
   228  	default:
   229  		return errors.New("unknown type")
   230  	}
   231  	return nil
   232  }
   233  
   234  func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
   235  	switch valueKind {
   236  	case reflect.Ptr:
   237  		return unmarshalFieldPtr(val, field)
   238  	default:
   239  		return unmarshalFieldNonPtr(val, field)
   240  	}
   241  }
   242  
   243  func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
   244  	fieldIValue := field.Addr().Interface()
   245  	if unmarshaler, ok := fieldIValue.(echo.BindUnmarshaler); ok {
   246  		return true, unmarshaler.UnmarshalParam(value)
   247  	}
   248  	if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok {
   249  		return true, unmarshaler.UnmarshalText([]byte(value))
   250  	}
   251  
   252  	return false, nil
   253  }
   254  
   255  func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
   256  	if field.IsNil() {
   257  		// Initialize the pointer to a nil value
   258  		field.Set(reflect.New(field.Type().Elem()))
   259  	}
   260  	return unmarshalFieldNonPtr(value, field.Elem())
   261  }
   262  
   // setIntField parses value as a base-10 signed integer that must fit in
   // bitSize bits (0 meaning the size of int) and stores it in field. An empty
   // value is treated as "0". When parsing fails, field is left unchanged and
   // the strconv error is returned.
   func setIntField(value string, bitSize int, field reflect.Value) error {
   	text := value
   	if text == "" {
   		text = "0"
   	}
   	parsed, err := strconv.ParseInt(text, 10, bitSize)
   	if err != nil {
   		return err
   	}
   	field.SetInt(parsed)
   	return nil
   }
   273  
   // setUintField parses value as a base-10 unsigned integer that must fit in
   // bitSize bits (0 meaning the size of uint) and stores it in field. An empty
   // value is treated as "0". When parsing fails, field is left unchanged and
   // the strconv error is returned.
   func setUintField(value string, bitSize int, field reflect.Value) error {
   	text := value
   	if text == "" {
   		text = "0"
   	}
   	parsed, err := strconv.ParseUint(text, 10, bitSize)
   	if err != nil {
   		return err
   	}
   	field.SetUint(parsed)
   	return nil
   }
   284  
   // setBoolField parses value with strconv.ParseBool and stores it in field.
   // An empty value is treated as "false". When parsing fails, field is left
   // unchanged and the strconv error is returned.
   func setBoolField(value string, field reflect.Value) error {
   	text := value
   	if text == "" {
   		text = "false"
   	}
   	parsed, err := strconv.ParseBool(text)
   	if err != nil {
   		return err
   	}
   	field.SetBool(parsed)
   	return nil
   }
   295  
   // setFloatField parses value as a floating-point number of the given
   // bitSize (32 or 64) and stores it in field. An empty value is treated as
   // "0.0". When parsing fails, field is left unchanged and the strconv error
   // is returned.
   func setFloatField(value string, bitSize int, field reflect.Value) error {
   	text := value
   	if text == "" {
   		text = "0.0"
   	}
   	parsed, err := strconv.ParseFloat(text, bitSize)
   	if err != nil {
   		return err
   	}
   	field.SetFloat(parsed)
   	return nil
   }