git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/schema/decoder.go (about)

     1  // Copyright 2012 The Gorilla Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package schema
     6  
     7  import (
     8  	"encoding"
     9  	"errors"
    10  	"fmt"
    11  	"reflect"
    12  	"strings"
    13  )
    14  
    15  // NewDecoder returns a new Decoder.
    16  func NewDecoder() *Decoder {
    17  	return &Decoder{cache: newCache()}
    18  }
    19  
    20  // Decoder decodes values from a map[string][]string to a struct.
    21  type Decoder struct {
    22  	cache             *cache
    23  	zeroEmpty         bool
    24  	ignoreUnknownKeys bool
    25  }
    26  
    27  // SetAliasTag changes the tag used to locate custom field aliases.
    28  // The default tag is "schema".
    29  func (d *Decoder) SetAliasTag(tag string) {
    30  	d.cache.tag = tag
    31  }
    32  
    33  // ZeroEmpty controls the behaviour when the decoder encounters empty values
    34  // in a map.
    35  // If z is true and a key in the map has the empty string as a value
    36  // then the corresponding struct field is set to the zero value.
    37  // If z is false then empty strings are ignored.
    38  //
    39  // The default value is false, that is empty values do not change
    40  // the value of the struct field.
    41  func (d *Decoder) ZeroEmpty(z bool) {
    42  	d.zeroEmpty = z
    43  }
    44  
    45  // IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown
    46  // keys in the map.
    47  // If i is true and an unknown field is encountered, it is ignored. This is
    48  // similar to how unknown keys are handled by encoding/json.
    49  // If i is false then Decode will return an error. Note that any valid keys
    50  // will still be decoded in to the target struct.
    51  //
    52  // To preserve backwards compatibility, the default value is false.
    53  func (d *Decoder) IgnoreUnknownKeys(i bool) {
    54  	d.ignoreUnknownKeys = i
    55  }
    56  
    57  // RegisterConverter registers a converter function for a custom type.
    58  func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) {
    59  	d.cache.registerConverter(value, converterFunc)
    60  }
    61  
    62  // Decode decodes a map[string][]string to a struct.
    63  //
    64  // The first parameter must be a pointer to a struct.
    65  //
    66  // The second parameter is a map, typically url.Values from an HTTP request.
    67  // Keys are "paths" in dotted notation to the struct fields and nested structs.
    68  //
    69  // See the package documentation for a full explanation of the mechanics.
    70  func (d *Decoder) Decode(dst interface{}, src map[string][]string) error {
    71  	v := reflect.ValueOf(dst)
    72  	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
    73  		return errors.New("schema: interface must be a pointer to struct")
    74  	}
    75  	v = v.Elem()
    76  	t := v.Type()
    77  	errors := MultiError{}
    78  	for path, values := range src {
    79  		if parts, err := d.cache.parsePath(path, t); err == nil {
    80  			if err = d.decode(v, path, parts, values); err != nil {
    81  				errors[path] = err
    82  			}
    83  		} else if !d.ignoreUnknownKeys {
    84  			errors[path] = UnknownKeyError{Key: path}
    85  		}
    86  	}
    87  	errors.merge(d.checkRequired(t, src))
    88  	if len(errors) > 0 {
    89  		return errors
    90  	}
    91  	return nil
    92  }
    93  
    94  // checkRequired checks whether required fields are empty
    95  //
    96  // check type t recursively if t has struct fields.
    97  //
    98  // src is the source map for decoding, we use it here to see if those required fields are included in src
    99  func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError {
   100  	m, errs := d.findRequiredFields(t, "", "")
   101  	for key, fields := range m {
   102  		if isEmptyFields(fields, src) {
   103  			errs[key] = EmptyFieldError{Key: key}
   104  		}
   105  	}
   106  	return errs
   107  }
   108  
   109  // findRequiredFields recursively searches the struct type t for required fields.
   110  //
   111  // canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation
   112  // for nested struct fields. canonicalPrefix is a complete path which never omits
   113  // any embedded struct fields. searchPrefix is a user-friendly path which may omit
   114  // some embedded struct fields to point promoted fields.
   115  func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) {
   116  	struc := d.cache.get(t)
   117  	if struc == nil {
   118  		// unexpect, cache.get never return nil
   119  		return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")}
   120  	}
   121  
   122  	m := map[string][]fieldWithPrefix{}
   123  	errs := MultiError{}
   124  	for _, f := range struc.fields {
   125  		if f.typ.Kind() == reflect.Struct {
   126  			fcprefix := canonicalPrefix + f.canonicalAlias + "."
   127  			for _, fspath := range f.paths(searchPrefix) {
   128  				fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".")
   129  				for key, fields := range fm {
   130  					m[key] = append(m[key], fields...)
   131  				}
   132  				errs.merge(ferrs)
   133  			}
   134  		}
   135  		if f.isRequired {
   136  			key := canonicalPrefix + f.canonicalAlias
   137  			m[key] = append(m[key], fieldWithPrefix{
   138  				fieldInfo: f,
   139  				prefix:    searchPrefix,
   140  			})
   141  		}
   142  	}
   143  	return m, errs
   144  }
   145  
   146  type fieldWithPrefix struct {
   147  	*fieldInfo
   148  	prefix string
   149  }
   150  
   151  // isEmptyFields returns true if all of specified fields are empty.
   152  func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool {
   153  	for _, f := range fields {
   154  		for _, path := range f.paths(f.prefix) {
   155  			v, ok := src[path]
   156  			if ok && !isEmpty(f.typ, v) {
   157  				return false
   158  			}
   159  			for key := range src {
   160  				if !isEmpty(f.typ, src[key]) && strings.HasPrefix(key, path) {
   161  					return false
   162  				}
   163  			}
   164  		}
   165  	}
   166  	return true
   167  }
   168  
   169  // isEmpty returns true if value is empty for specific type
   170  func isEmpty(t reflect.Type, value []string) bool {
   171  	if len(value) == 0 {
   172  		return true
   173  	}
   174  	switch t.Kind() {
   175  	case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type:
   176  		return len(value[0]) == 0
   177  	}
   178  	return false
   179  }
   180  
   181  // decode fills a struct field using a parsed path.
   182  func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error {
   183  	// Get the field walking the struct fields by index.
   184  	for _, name := range parts[0].path {
   185  		if v.Type().Kind() == reflect.Ptr {
   186  			if v.IsNil() {
   187  				v.Set(reflect.New(v.Type().Elem()))
   188  			}
   189  			v = v.Elem()
   190  		}
   191  
   192  		// alloc embedded structs
   193  		if v.Type().Kind() == reflect.Struct {
   194  			for i := 0; i < v.NumField(); i++ {
   195  				field := v.Field(i)
   196  				if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous {
   197  					field.Set(reflect.New(field.Type().Elem()))
   198  				}
   199  			}
   200  		}
   201  
   202  		v = v.FieldByName(name)
   203  	}
   204  	// Don't even bother for unexported fields.
   205  	if !v.CanSet() {
   206  		return nil
   207  	}
   208  
   209  	// Dereference if needed.
   210  	t := v.Type()
   211  	if t.Kind() == reflect.Ptr {
   212  		t = t.Elem()
   213  		if v.IsNil() {
   214  			v.Set(reflect.New(t))
   215  		}
   216  		v = v.Elem()
   217  	}
   218  
   219  	// Slice of structs. Let's go recursive.
   220  	if len(parts) > 1 {
   221  		idx := parts[0].index
   222  		if v.IsNil() || v.Len() < idx+1 {
   223  			value := reflect.MakeSlice(t, idx+1, idx+1)
   224  			if v.Len() < idx+1 {
   225  				// Resize it.
   226  				reflect.Copy(value, v)
   227  			}
   228  			v.Set(value)
   229  		}
   230  		return d.decode(v.Index(idx), path, parts[1:], values)
   231  	}
   232  
   233  	// Get the converter early in case there is one for a slice type.
   234  	conv := d.cache.converter(t)
   235  	m := isTextUnmarshaler(v)
   236  	if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement {
   237  		var items []reflect.Value
   238  		elemT := t.Elem()
   239  		isPtrElem := elemT.Kind() == reflect.Ptr
   240  		if isPtrElem {
   241  			elemT = elemT.Elem()
   242  		}
   243  
   244  		// Try to get a converter for the element type.
   245  		conv := d.cache.converter(elemT)
   246  		if conv == nil {
   247  			conv = builtinConverters[elemT.Kind()]
   248  			if conv == nil {
   249  				// As we are not dealing with slice of structs here, we don't need to check if the type
   250  				// implements TextUnmarshaler interface
   251  				return fmt.Errorf("schema: converter not found for %v", elemT)
   252  			}
   253  		}
   254  
   255  		for key, value := range values {
   256  			if value == "" {
   257  				if d.zeroEmpty {
   258  					items = append(items, reflect.Zero(elemT))
   259  				}
   260  			} else if m.IsValid {
   261  				u := reflect.New(elemT)
   262  				if m.IsSliceElementPtr {
   263  					u = reflect.New(reflect.PtrTo(elemT).Elem())
   264  				}
   265  				if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil {
   266  					return ConversionError{
   267  						Key:   path,
   268  						Type:  t,
   269  						Index: key,
   270  						Err:   err,
   271  					}
   272  				}
   273  				if m.IsSliceElementPtr {
   274  					items = append(items, u.Elem().Addr())
   275  				} else if u.Kind() == reflect.Ptr {
   276  					items = append(items, u.Elem())
   277  				} else {
   278  					items = append(items, u)
   279  				}
   280  			} else if item := conv(value); item.IsValid() {
   281  				if isPtrElem {
   282  					ptr := reflect.New(elemT)
   283  					ptr.Elem().Set(item)
   284  					item = ptr
   285  				}
   286  				if item.Type() != elemT && !isPtrElem {
   287  					item = item.Convert(elemT)
   288  				}
   289  				items = append(items, item)
   290  			} else {
   291  				if strings.Contains(value, ",") {
   292  					values := strings.Split(value, ",")
   293  					for _, value := range values {
   294  						if value == "" {
   295  							if d.zeroEmpty {
   296  								items = append(items, reflect.Zero(elemT))
   297  							}
   298  						} else if item := conv(value); item.IsValid() {
   299  							if isPtrElem {
   300  								ptr := reflect.New(elemT)
   301  								ptr.Elem().Set(item)
   302  								item = ptr
   303  							}
   304  							if item.Type() != elemT && !isPtrElem {
   305  								item = item.Convert(elemT)
   306  							}
   307  							items = append(items, item)
   308  						} else {
   309  							return ConversionError{
   310  								Key:   path,
   311  								Type:  elemT,
   312  								Index: key,
   313  							}
   314  						}
   315  					}
   316  				} else {
   317  					return ConversionError{
   318  						Key:   path,
   319  						Type:  elemT,
   320  						Index: key,
   321  					}
   322  				}
   323  			}
   324  		}
   325  		value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...)
   326  		v.Set(value)
   327  	} else {
   328  		val := ""
   329  		// Use the last value provided if any values were provided
   330  		if len(values) > 0 {
   331  			val = values[len(values)-1]
   332  		}
   333  
   334  		if conv != nil {
   335  			if value := conv(val); value.IsValid() {
   336  				v.Set(value.Convert(t))
   337  			} else {
   338  				return ConversionError{
   339  					Key:   path,
   340  					Type:  t,
   341  					Index: -1,
   342  				}
   343  			}
   344  		} else if m.IsValid {
   345  			if m.IsPtr {
   346  				u := reflect.New(v.Type())
   347  				if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil {
   348  					return ConversionError{
   349  						Key:   path,
   350  						Type:  t,
   351  						Index: -1,
   352  						Err:   err,
   353  					}
   354  				}
   355  				v.Set(reflect.Indirect(u))
   356  			} else {
   357  				// If the value implements the encoding.TextUnmarshaler interface
   358  				// apply UnmarshalText as the converter
   359  				if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil {
   360  					return ConversionError{
   361  						Key:   path,
   362  						Type:  t,
   363  						Index: -1,
   364  						Err:   err,
   365  					}
   366  				}
   367  			}
   368  		} else if val == "" {
   369  			if d.zeroEmpty {
   370  				v.Set(reflect.Zero(t))
   371  			}
   372  		} else if conv := builtinConverters[t.Kind()]; conv != nil {
   373  			if value := conv(val); value.IsValid() {
   374  				v.Set(value.Convert(t))
   375  			} else {
   376  				return ConversionError{
   377  					Key:   path,
   378  					Type:  t,
   379  					Index: -1,
   380  				}
   381  			}
   382  		} else {
   383  			return fmt.Errorf("schema: converter not found for %v", t)
   384  		}
   385  	}
   386  	return nil
   387  }
   388  
   389  func isTextUnmarshaler(v reflect.Value) unmarshaler {
   390  	// Create a new unmarshaller instance
   391  	m := unmarshaler{}
   392  	if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
   393  		return m
   394  	}
   395  	// As the UnmarshalText function should be applied to the pointer of the
   396  	// type, we check that type to see if it implements the necessary
   397  	// method.
   398  	if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid {
   399  		m.IsPtr = true
   400  		return m
   401  	}
   402  
   403  	// if v is []T or *[]T create new T
   404  	t := v.Type()
   405  	if t.Kind() == reflect.Ptr {
   406  		t = t.Elem()
   407  	}
   408  	if t.Kind() == reflect.Slice {
   409  		// Check if the slice implements encoding.TextUnmarshaller
   410  		if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid {
   411  			return m
   412  		}
   413  		// If t is a pointer slice, check if its elements implement
   414  		// encoding.TextUnmarshaler
   415  		m.IsSliceElement = true
   416  		if t = t.Elem(); t.Kind() == reflect.Ptr {
   417  			t = reflect.PtrTo(t.Elem())
   418  			v = reflect.Zero(t)
   419  			m.IsSliceElementPtr = true
   420  			m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
   421  			return m
   422  		}
   423  	}
   424  
   425  	v = reflect.New(t)
   426  	m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler)
   427  	return m
   428  }
   429  
   430  // TextUnmarshaler helpers ----------------------------------------------------
   431  // unmarshaller contains information about a TextUnmarshaler type
   432  type unmarshaler struct {
   433  	Unmarshaler encoding.TextUnmarshaler
   434  	// IsValid indicates whether the resolved type indicated by the other
   435  	// flags implements the encoding.TextUnmarshaler interface.
   436  	IsValid bool
   437  	// IsPtr indicates that the resolved type is the pointer of the original
   438  	// type.
   439  	IsPtr bool
   440  	// IsSliceElement indicates that the resolved type is a slice element of
   441  	// the original type.
   442  	IsSliceElement bool
   443  	// IsSliceElementPtr indicates that the resolved type is a pointer to a
   444  	// slice element of the original type.
   445  	IsSliceElementPtr bool
   446  }
   447  
   448  // Errors ---------------------------------------------------------------------
   449  
   450  // ConversionError stores information about a failed conversion.
   451  type ConversionError struct {
   452  	Key   string       // key from the source map.
   453  	Type  reflect.Type // expected type of elem
   454  	Index int          // index for multi-value fields; -1 for single-value fields.
   455  	Err   error        // low-level error (when it exists)
   456  }
   457  
   458  func (e ConversionError) Error() string {
   459  	var output string
   460  
   461  	if e.Index < 0 {
   462  		output = fmt.Sprintf("schema: error converting value for %q", e.Key)
   463  	} else {
   464  		output = fmt.Sprintf("schema: error converting value for index %d of %q",
   465  			e.Index, e.Key)
   466  	}
   467  
   468  	if e.Err != nil {
   469  		output = fmt.Sprintf("%s. Details: %s", output, e.Err)
   470  	}
   471  
   472  	return output
   473  }
   474  
   475  // UnknownKeyError stores information about an unknown key in the source map.
   476  type UnknownKeyError struct {
   477  	Key string // key from the source map.
   478  }
   479  
   480  func (e UnknownKeyError) Error() string {
   481  	return fmt.Sprintf("schema: invalid path %q", e.Key)
   482  }
   483  
   484  // EmptyFieldError stores information about an empty required field.
   485  type EmptyFieldError struct {
   486  	Key string // required key in the source map.
   487  }
   488  
   489  func (e EmptyFieldError) Error() string {
   490  	return fmt.Sprintf("%v is empty", e.Key)
   491  }
   492  
   493  // MultiError stores multiple decoding errors.
   494  //
   495  // Borrowed from the App Engine SDK.
   496  type MultiError map[string]error
   497  
   498  func (e MultiError) Error() string {
   499  	s := ""
   500  	for _, err := range e {
   501  		s = err.Error()
   502  		break
   503  	}
   504  	switch len(e) {
   505  	case 0:
   506  		return "(0 errors)"
   507  	case 1:
   508  		return s
   509  	case 2:
   510  		return s + " (and 1 other error)"
   511  	}
   512  	return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1)
   513  }
   514  
   515  func (e MultiError) merge(errors MultiError) {
   516  	for key, err := range errors {
   517  		if e[key] == nil {
   518  			e[key] = err
   519  		}
   520  	}
   521  }