github.com/boomhut/fiber/v2@v2.0.0-20230603160335-b65c856e57d3/internal/schema/decoder.go (about)

     1  // Copyright 2012 The Gorilla Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package schema
     6  
     7  import (
     8  	"encoding"
     9  	"errors"
    10  	"fmt"
    11  	"reflect"
    12  	"strings"
    13  )
    14  
    15  // NewDecoder returns a new Decoder.
    16  func NewDecoder() *Decoder {
    17  	return &Decoder{cache: newCache()}
    18  }
    19  
    20  // Decoder decodes values from a map[string][]string to a struct.
    21  type Decoder struct {
    22  	cache             *cache
    23  	zeroEmpty         bool
    24  	ignoreUnknownKeys bool
    25  }
    26  
    27  // SetAliasTag changes the tag used to locate custom field aliases.
    28  // The default tag is "schema".
    29  func (d *Decoder) SetAliasTag(tag string) {
    30  	d.cache.tag = tag
    31  }
    32  
    33  // ZeroEmpty controls the behaviour when the decoder encounters empty values
    34  // in a map.
    35  // If z is true and a key in the map has the empty string as a value
    36  // then the corresponding struct field is set to the zero value.
    37  // If z is false then empty strings are ignored.
    38  //
    39  // The default value is false, that is empty values do not change
    40  // the value of the struct field.
    41  func (d *Decoder) ZeroEmpty(z bool) {
    42  	d.zeroEmpty = z
    43  }
    44  
    45  // IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown
    46  // keys in the map.
    47  // If i is true and an unknown field is encountered, it is ignored. This is
    48  // similar to how unknown keys are handled by encoding/json.
    49  // If i is false then Decode will return an error. Note that any valid keys
    50  // will still be decoded in to the target struct.
    51  //
    52  // To preserve backwards compatibility, the default value is false.
    53  func (d *Decoder) IgnoreUnknownKeys(i bool) {
    54  	d.ignoreUnknownKeys = i
    55  }
    56  
    57  // RegisterConverter registers a converter function for a custom type.
    58  func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) {
    59  	d.cache.registerConverter(value, converterFunc)
    60  }
    61  
    62  // Decode decodes a map[string][]string to a struct.
    63  //
    64  // The first parameter must be a pointer to a struct.
    65  //
    66  // The second parameter is a map, typically url.Values from an HTTP request.
    67  // Keys are "paths" in dotted notation to the struct fields and nested structs.
    68  //
    69  // See the package documentation for a full explanation of the mechanics.
    70  func (d *Decoder) Decode(dst interface{}, src map[string][]string) error {
    71  	v := reflect.ValueOf(dst)
    72  	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
    73  		return errors.New("schema: interface must be a pointer to struct")
    74  	}
    75  	v = v.Elem()
    76  	t := v.Type()
    77  	multiError := MultiError{}
    78  	for path, values := range src {
    79  		if parts, err := d.cache.parsePath(path, t); err == nil {
    80  			if err = d.decode(v, path, parts, values); err != nil {
    81  				multiError[path] = err
    82  			}
    83  		} else if !d.ignoreUnknownKeys {
    84  			multiError[path] = UnknownKeyError{Key: path}
    85  		}
    86  	}
    87  	multiError.merge(d.checkRequired(t, src))
    88  	if len(multiError) > 0 {
    89  		return multiError
    90  	}
    91  	return nil
    92  }
    93  
    94  // checkRequired checks whether required fields are empty
    95  //
    96  // check type t recursively if t has struct fields.
    97  //
    98  // src is the source map for decoding, we use it here to see if those required fields are included in src
    99  func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError {
   100  	m, errs := d.findRequiredFields(t, "", "")
   101  	for key, fields := range m {
   102  		if isEmptyFields(fields, src) {
   103  			errs[key] = EmptyFieldError{Key: key}
   104  		}
   105  	}
   106  	return errs
   107  }
   108  
   109  // findRequiredFields recursively searches the struct type t for required fields.
   110  //
   111  // canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation
   112  // for nested struct fields. canonicalPrefix is a complete path which never omits
   113  // any embedded struct fields. searchPrefix is a user-friendly path which may omit
   114  // some embedded struct fields to point promoted fields.
   115  func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) {
   116  	struc := d.cache.get(t)
   117  	if struc == nil {
   118  		// unexpect, cache.get never return nil
   119  		return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")}
   120  	}
   121  
   122  	m := map[string][]fieldWithPrefix{}
   123  	errs := MultiError{}
   124  	for _, f := range struc.fields {
   125  		if f.typ.Kind() == reflect.Struct {
   126  			fcprefix := canonicalPrefix + f.canonicalAlias + "."
   127  			for _, fspath := range f.paths(searchPrefix) {
   128  				fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".")
   129  				for key, fields := range fm {
   130  					m[key] = append(m[key], fields...)
   131  				}
   132  				errs.merge(ferrs)
   133  			}
   134  		}
   135  		if f.isRequired {
   136  			key := canonicalPrefix + f.canonicalAlias
   137  			m[key] = append(m[key], fieldWithPrefix{
   138  				fieldInfo: f,
   139  				prefix:    searchPrefix,
   140  			})
   141  		}
   142  	}
   143  	return m, errs
   144  }
   145  
   146  type fieldWithPrefix struct {
   147  	*fieldInfo
   148  	prefix string
   149  }
   150  
   151  // isEmptyFields returns true if all of specified fields are empty.
   152  func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool {
   153  	for _, f := range fields {
   154  		for _, path := range f.paths(f.prefix) {
   155  			v, ok := src[path]
   156  			if ok && !isEmpty(f.typ, v) {
   157  				return false
   158  			}
   159  			for key := range src {
   160  				// issue references:
   161  				// https://github.com/gofiber/fiber/issues/1414
   162  				// https://github.com/gorilla/schema/issues/176
   163  				nested := strings.IndexByte(key, '.') != -1
   164  
   165  				// for non required nested structs
   166  				c1 := strings.HasSuffix(f.prefix, ".") && key == path
   167  
   168  				// for required nested structs
   169  				c2 := f.prefix == "" && nested && strings.HasPrefix(key, path)
   170  
   171  				// for non nested fields
   172  				c3 := f.prefix == "" && !nested && key == path
   173  				if !isEmpty(f.typ, src[key]) && (c1 || c2 || c3) {
   174  					return false
   175  				}
   176  			}
   177  		}
   178  	}
   179  	return true
   180  }
   181  
   182  // isEmpty returns true if value is empty for specific type
   183  func isEmpty(t reflect.Type, value []string) bool {
   184  	if len(value) == 0 {
   185  		return true
   186  	}
   187  	switch t.Kind() {
   188  	case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type:
   189  		return len(value[0]) == 0
   190  	}
   191  	return false
   192  }
   193  
   194  // decode fills a struct field using a parsed path.
   195  func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error {
   196  	// Get the field walking the struct fields by index.
   197  	for _, name := range parts[0].path {
   198  		if v.Type().Kind() == reflect.Ptr {
   199  			if v.IsNil() {
   200  				v.Set(reflect.New(v.Type().Elem()))
   201  			}
   202  			v = v.Elem()
   203  		}
   204  
   205  		// alloc embedded structs
   206  		if v.Type().Kind() == reflect.Struct {
   207  			for i := 0; i < v.NumField(); i++ {
   208  				field := v.Field(i)
   209  				if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous {
   210  					field.Set(reflect.New(field.Type().Elem()))
   211  				}
   212  			}
   213  		}
   214  
   215  		v = v.FieldByName(name)
   216  	}
   217  	// Don't even bother for unexported fields.
   218  	if !v.CanSet() {
   219  		return nil
   220  	}
   221  
   222  	// Dereference if needed.
   223  	t := v.Type()
   224  	if t.Kind() == reflect.Ptr {
   225  		t = t.Elem()
   226  		if v.IsNil() {
   227  			v.Set(reflect.New(t))
   228  		}
   229  		v = v.Elem()
   230  	}
   231  
   232  	// Slice of structs. Let's go recursive.
   233  	if len(parts) > 1 {
   234  		idx := parts[0].index
   235  		if v.IsNil() || v.Len() < idx+1 {
   236  			value := reflect.MakeSlice(t, idx+1, idx+1)
   237  			if v.Len() < idx+1 {
   238  				// Resize it.
   239  				reflect.Copy(value, v)
   240  			}
   241  			v.Set(value)
   242  		}
   243  		return d.decode(v.Index(idx), path, parts[1:], values)
   244  	}
   245  
   246  	// Get the converter early in case there is one for a slice type.
   247  	conv := d.cache.converter(t)
   248  	m := isTextUnmarshaler(v)
   249  	if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
   250  		var items []reflect.Value
   251  		elemT := t.Elem()
   252  		isPtrElem := elemT.Kind() == reflect.Ptr
   253  		if isPtrElem {
   254  			elemT = elemT.Elem()
   255  		}
   256  
   257  		// Try to get a converter for the element type.
   258  		conv := d.cache.converter(elemT)
   259  		if conv == nil {
   260  			conv = builtinConverters[elemT.Kind()]
   261  			if conv == nil {
   262  				// As we are not dealing with slice of structs here, we don't need to check if the type
   263  				// implements TextUnmarshaler interface
   264  				return fmt.Errorf("schema: converter not found for %v", elemT)
   265  			}
   266  		}
   267  
   268  		for key, value := range values {
   269  			if value == "" {
   270  				if d.zeroEmpty {
   271  					items = append(items, reflect.Zero(elemT))
   272  				}
   273  			} else if m.IsValid {
   274  				u := reflect.New(elemT)
   275  				if m.IsSliceElementPtr {
   276  					u = reflect.New(reflect.PtrTo(elemT).Elem())
   277  				}
   278  				if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
   279  					return ConversionError{
   280  						Key:   path,
   281  						Type:  t,
   282  						Index: key,
   283  						Err:   err,
   284  					}
   285  				}
   286  				if m.IsSliceElementPtr {
   287  					items = append(items, u.Elem().Addr())
   288  				} else if u.Kind() == reflect.Ptr {
   289  					items = append(items, u.Elem())
   290  				} else {
   291  					items = append(items, u)
   292  				}
   293  			} else if item := conv(value); item.IsValid() {
   294  				if isPtrElem {
   295  					ptr := reflect.New(elemT)
   296  					ptr.Elem().Set(item)
   297  					item = ptr
   298  				}
   299  				if item.Type() != elemT && !isPtrElem {
   300  					item = item.Convert(elemT)
   301  				}
   302  				items = append(items, item)
   303  			} else {
   304  				if strings.Contains(value, ",") {
   305  					values := strings.Split(value, ",")
   306  					for _, value := range values {
   307  						if value == "" {
   308  							if d.zeroEmpty {
   309  								items = append(items, reflect.Zero(elemT))
   310  							}
   311  						} else if item := conv(value); item.IsValid() {
   312  							if isPtrElem {
   313  								ptr := reflect.New(elemT)
   314  								ptr.Elem().Set(item)
   315  								item = ptr
   316  							}
   317  							if item.Type() != elemT && !isPtrElem {
   318  								item = item.Convert(elemT)
   319  							}
   320  							items = append(items, item)
   321  						} else {
   322  							return ConversionError{
   323  								Key:   path,
   324  								Type:  elemT,
   325  								Index: key,
   326  							}
   327  						}
   328  					}
   329  				} else {
   330  					return ConversionError{
   331  						Key:   path,
   332  						Type:  elemT,
   333  						Index: key,
   334  					}
   335  				}
   336  			}
   337  		}
   338  		value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...)
   339  		v.Set(value)
   340  	} else {
   341  		val := ""
   342  		// Use the last value provided if any values were provided
   343  		if len(values) > 0 {
   344  			val = values[len(values)-1]
   345  		}
   346  
   347  		if conv != nil {
   348  			if value := conv(val); value.IsValid() {
   349  				v.Set(value.Convert(t))
   350  			} else {
   351  				return ConversionError{
   352  					Key:   path,
   353  					Type:  t,
   354  					Index: -1,
   355  				}
   356  			}
   357  		} else if m.IsValid {
   358  			if m.IsPtr {
   359  				u := reflect.New(v.Type())
   360  				if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
   361  					return ConversionError{
   362  						Key:   path,
   363  						Type:  t,
   364  						Index: -1,
   365  						Err:   err,
   366  					}
   367  				}
   368  				v.Set(reflect.Indirect(u))
   369  			} else {
   370  				// If the value implements the encoding.TextUnmarshaler interface
   371  				// apply UnmarshalText as the converter
   372  				if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
   373  					return ConversionError{
   374  						Key:   path,
   375  						Type:  t,
   376  						Index: -1,
   377  						Err:   err,
   378  					}
   379  				}
   380  			}
   381  		} else if val == "" {
   382  			if d.zeroEmpty {
   383  				v.Set(reflect.Zero(t))
   384  			}
   385  		} else if conv := builtinConverters[t.Kind()]; conv != nil {
   386  			if value := conv(val); value.IsValid() {
   387  				v.Set(value.Convert(t))
   388  			} else {
   389  				return ConversionError{
   390  					Key:   path,
   391  					Type:  t,
   392  					Index: -1,
   393  				}
   394  			}
   395  		} else {
   396  			return fmt.Errorf("schema: converter not found for %v", t)
   397  		}
   398  	}
   399  	return nil
   400  }
   401  
   402  func isTextUnmarshaler(v reflect.Value) unmarshaler {
   403  	// Create a new unmarshaller instance
   404  	m := unmarshaler{}
   405  	if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
   406  		return m
   407  	}
   408  	// As the UnmarshalText function should be applied to the pointer of the
   409  	// type, we check that type to see if it implements the necessary
   410  	// method.
   411  	if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
   412  		m.IsPtr = true
   413  		return m
   414  	}
   415  
   416  	// if v is []T or *[]T create new T
   417  	t := v.Type()
   418  	if t.Kind() == reflect.Ptr {
   419  		t = t.Elem()
   420  	}
   421  	if t.Kind() == reflect.Slice {
   422  		// Check if the slice implements encoding.TextUnmarshaller
   423  		if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
   424  			return m
   425  		}
   426  		// If t is a pointer slice, check if its elements implement
   427  		// encoding.TextUnmarshaler
   428  		m.IsSliceElement = true
   429  		if t = t.Elem(); t.Kind() == reflect.Ptr {
   430  			t = reflect.PtrTo(t.Elem())
   431  			v = reflect.Zero(t)
   432  			m.IsSliceElementPtr = true
   433  			m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
   434  			return m
   435  		}
   436  	}
   437  
   438  	v = reflect.New(t)
   439  	m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
   440  	return m
   441  }
   442  
   443  // TextUnmarshaler helpers ----------------------------------------------------
   444  // unmarshaller contains information about a TextUnmarshaler type
   445  type unmarshaler struct {
   446  	Unmarshaler encoding.TextUnmarshaler
   447  	// IsValid indicates whether the resolved type indicated by the other
   448  	// flags implements the encoding.TextUnmarshaler interface.
   449  	IsValid bool
   450  	// IsPtr indicates that the resolved type is the pointer of the original
   451  	// type.
   452  	IsPtr bool
   453  	// IsSliceElement indicates that the resolved type is a slice element of
   454  	// the original type.
   455  	IsSliceElement bool
   456  	// IsSliceElementPtr indicates that the resolved type is a pointer to a
   457  	// slice element of the original type.
   458  	IsSliceElementPtr bool
   459  }
   460  
   461  // Errors ---------------------------------------------------------------------
   462  
   463  // ConversionError stores information about a failed conversion.
   464  type ConversionError struct {
   465  	Key   string       // key from the source map.
   466  	Type  reflect.Type // expected type of elem
   467  	Index int          // index for multi-value fields; -1 for single-value fields.
   468  	Err   error        // low-level error (when it exists)
   469  }
   470  
   471  func (e ConversionError) Error() string {
   472  	var output string
   473  
   474  	if e.Index < 0 {
   475  		output = fmt.Sprintf("schema: error converting value for %q", e.Key)
   476  	} else {
   477  		output = fmt.Sprintf("schema: error converting value for index %d of %q",
   478  			e.Index, e.Key)
   479  	}
   480  
   481  	if e.Err != nil {
   482  		output = fmt.Sprintf("%s. Details: %s", output, e.Err)
   483  	}
   484  
   485  	return output
   486  }
   487  
   488  // UnknownKeyError stores information about an unknown key in the source map.
   489  type UnknownKeyError struct {
   490  	Key string // key from the source map.
   491  }
   492  
   493  func (e UnknownKeyError) Error() string {
   494  	return fmt.Sprintf("schema: invalid path %q", e.Key)
   495  }
   496  
   497  // EmptyFieldError stores information about an empty required field.
   498  type EmptyFieldError struct {
   499  	Key string // required key in the source map.
   500  }
   501  
   502  func (e EmptyFieldError) Error() string {
   503  	return fmt.Sprintf("%v is empty", e.Key)
   504  }
   505  
   506  // MultiError stores multiple decoding errors.
   507  //
   508  // Borrowed from the App Engine SDK.
   509  type MultiError map[string]error
   510  
   511  func (e MultiError) Error() string {
   512  	s := ""
   513  	for _, err := range e {
   514  		s = err.Error()
   515  		break
   516  	}
   517  	switch len(e) {
   518  	case 0:
   519  		return "(0 errors)"
   520  	case 1:
   521  		return s
   522  	case 2:
   523  		return s + " (and 1 other error)"
   524  	}
   525  	return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
   526  }
   527  
   528  func (e MultiError) merge(errors MultiError) {
   529  	for key, err := range errors {
   530  		if e[key] == nil {
   531  			e[key] = err
   532  		}
   533  	}
   534  }