github.com/urso/go-structform@v0.0.2/gotype/fold_reflect.go (about)

     1  package gotype
     2  
     3  import (
     4  	"reflect"
     5  	"strings"
     6  	"unicode"
     7  	"unicode/utf8"
     8  
     9  	structform "github.com/urso/go-structform"
    10  )
    11  
    12  type typeFoldRegistry struct {
    13  	// mu sync.RWMutex
    14  	m map[typeFoldKey]reFoldFn
    15  }
    16  
    17  type typeFoldKey struct {
    18  	ty     reflect.Type
    19  	inline bool
    20  }
    21  
    22  var _foldRegistry = newTypeFoldRegistry()
    23  
    24  func getReflectFold(c *foldContext, t reflect.Type) (reFoldFn, error) {
    25  	var err error
    26  
    27  	f := c.reg.find(t)
    28  	if f != nil {
    29  		return f, nil
    30  	}
    31  
    32  	f = getReflectFoldPrimitive(t)
    33  	if f != nil {
    34  		c.reg.set(t, f)
    35  		return f, nil
    36  	}
    37  
    38  	if t.Implements(tFolder) {
    39  		f := reFoldFolderIfc
    40  		c.reg.set(t, f)
    41  		return f, nil
    42  	}
    43  
    44  	switch t.Kind() {
    45  	case reflect.Ptr:
    46  		f, err = getFoldPointer(c, t)
    47  	case reflect.Struct:
    48  		f, err = getReflectFoldStruct(c, t, false)
    49  	case reflect.Map:
    50  		f, err = getReflectFoldMap(c, t)
    51  	case reflect.Slice, reflect.Array:
    52  		f, err = getReflectFoldSlice(c, t)
    53  	case reflect.Interface:
    54  		f, err = getReflectFoldElem(c, t)
    55  	default:
    56  		f, err = getReflectFoldPrimitiveKind(t)
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  	}
    61  
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	c.reg.set(t, f)
    66  	return f, nil
    67  }
    68  
    69  func getReflectFoldMap(c *foldContext, t reflect.Type) (reFoldFn, error) {
    70  	iterVisitor, err := getReflectFoldMapKeys(c, t)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	return func(C *foldContext, rv reflect.Value) error {
    76  		if err := C.OnObjectStart(rv.Len(), structform.AnyType); err != nil {
    77  			return err
    78  		}
    79  		if err := iterVisitor(C, rv); err != nil {
    80  			return err
    81  		}
    82  		return C.OnObjectFinished()
    83  	}, nil
    84  }
    85  
    86  func getFoldPointer(c *foldContext, t reflect.Type) (reFoldFn, error) {
    87  	N, bt := baseType(t)
    88  	elemVisitor, err := getReflectFold(c, bt)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	return makePointerFold(N, elemVisitor), nil
    93  }
    94  
    95  func makePointerFold(N int, elemVisitor reFoldFn) reFoldFn {
    96  	if N == 0 {
    97  		return elemVisitor
    98  	}
    99  
   100  	return func(C *foldContext, v reflect.Value) error {
   101  		for i := 0; i < N; i++ {
   102  			if v.IsNil() {
   103  				return C.OnNil()
   104  			}
   105  			v = v.Elem()
   106  		}
   107  		return elemVisitor(C, v)
   108  	}
   109  }
   110  
   111  func getReflectFoldElem(c *foldContext, t reflect.Type) (reFoldFn, error) {
   112  	return foldInterfaceElem, nil
   113  }
   114  
   115  func foldInterfaceElem(C *foldContext, v reflect.Value) error {
   116  	if v.IsNil() {
   117  		return C.visitor.OnNil()
   118  	}
   119  	return foldAnyReflect(C, v.Elem())
   120  }
   121  
   122  func getReflectFoldStruct(c *foldContext, t reflect.Type, inline bool) (reFoldFn, error) {
   123  	fields, err := getStructFieldsFolds(c, t)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	if inline {
   129  		return makeFieldsFold(fields), nil
   130  	}
   131  	return makeStructFold(fields), nil
   132  }
   133  
   134  // TODO: benchmark field accessors based on pointer offsets
   135  func getStructFieldsFolds(c *foldContext, t reflect.Type) ([]reFoldFn, error) {
   136  	count := t.NumField()
   137  	fields := make([]reFoldFn, 0, count)
   138  
   139  	for i := 0; i < count; i++ {
   140  		fv, err := buildFieldFold(c, t, i)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  
   145  		if fv == nil {
   146  			continue
   147  		}
   148  
   149  		fields = append(fields, fv)
   150  	}
   151  
   152  	if len(fields) < cap(fields) {
   153  		tmp := make([]reFoldFn, len(fields))
   154  		copy(tmp, fields)
   155  		fields = tmp
   156  	}
   157  
   158  	return fields, nil
   159  }
   160  
   161  func makeStructFold(fields []reFoldFn) reFoldFn {
   162  	fieldsVisitor := makeFieldsFold(fields)
   163  	return func(C *foldContext, v reflect.Value) error {
   164  		if err := C.OnObjectStart(len(fields), structform.AnyType); err != nil {
   165  			return err
   166  		}
   167  		if err := fieldsVisitor(C, v); err != nil {
   168  			return err
   169  		}
   170  		return C.OnObjectFinished()
   171  	}
   172  }
   173  
   174  func makeFieldsFold(fields []reFoldFn) reFoldFn {
   175  	return func(C *foldContext, v reflect.Value) error {
   176  		for _, fv := range fields {
   177  			if err := fv(C, v); err != nil {
   178  				return err
   179  			}
   180  		}
   181  		return nil
   182  	}
   183  }
   184  
   185  func buildFieldFold(C *foldContext, t reflect.Type, idx int) (reFoldFn, error) {
   186  	st := t.Field(idx)
   187  
   188  	name := st.Name
   189  	rune, _ := utf8.DecodeRuneInString(name)
   190  	if !unicode.IsUpper(rune) {
   191  		// ignore non exported fields
   192  		return nil, nil
   193  	}
   194  
   195  	tagName, tagOpts := parseTags(st.Tag.Get(C.opts.tag))
   196  	if tagOpts.squash && tagOpts.omitEmpty {
   197  		return nil, errInlineAndOmitEmpty
   198  	}
   199  
   200  	if tagOpts.squash {
   201  		return buildFieldFoldInline(C, t, idx, tagOpts.omitEmpty)
   202  	}
   203  
   204  	foldT := st.Type
   205  	if tagOpts.omitEmpty {
   206  		_, foldT = baseType(st.Type)
   207  	}
   208  	valueVisitor, err := getReflectFold(C, foldT)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	if tagName != "" {
   214  		name = tagName
   215  	} else {
   216  		name = strings.ToLower(name)
   217  	}
   218  
   219  	if tagOpts.omitEmpty {
   220  		return makeNonEmptyFieldFold(name, idx, st.Type, valueVisitor)
   221  	}
   222  	return makeFieldFold(name, idx, valueVisitor)
   223  }
   224  
   225  func buildFieldFoldInline(
   226  	C *foldContext,
   227  	t reflect.Type,
   228  	idx int,
   229  	omitEmpty bool,
   230  ) (reFoldFn, error) {
   231  	var (
   232  		st          = t.Field(idx)
   233  		N, bt       = baseType(st.Type)
   234  		baseVisitor reFoldFn
   235  		err         error
   236  	)
   237  
   238  	f := C.reg.findInline(st.Type)
   239  	if f != nil {
   240  		return makeFieldInlineFold(idx, f), nil
   241  	}
   242  
   243  	baseVisitor = C.reg.findInline(bt)
   244  	if baseVisitor == nil {
   245  		baseVisitor, err = fieldFoldGenInline(C, bt)
   246  		if err != nil {
   247  			return nil, err
   248  		}
   249  		C.reg.setInline(bt, baseVisitor)
   250  	}
   251  
   252  	f = makePointerFold(N, baseVisitor)
   253  	C.reg.setInline(st.Type, f)
   254  
   255  	return makeFieldInlineFold(idx, f), nil
   256  }
   257  
   258  func fieldFoldGenInline(C *foldContext, t reflect.Type) (reFoldFn, error) {
   259  	if C.userReg != nil {
   260  		if f := C.userReg[t]; f != nil {
   261  			f = embeddObjReFold(C, f)
   262  		}
   263  	}
   264  
   265  	if t.Implements(tFolder) {
   266  		return embeddObjReFold(C, reFoldFolderIfc), nil
   267  	}
   268  
   269  	switch t.Kind() {
   270  	case reflect.Struct:
   271  		return getReflectFoldStruct(C, t, true)
   272  	case reflect.Map:
   273  		return getReflectFoldMapKeys(C, t)
   274  	case reflect.Interface:
   275  		return getReflectFoldInlineInterface(C, t)
   276  	}
   277  
   278  	return nil, errSquashNeedObject
   279  }
   280  
   281  func makeFieldFold(name string, idx int, fn reFoldFn) (reFoldFn, error) {
   282  	return func(C *foldContext, v reflect.Value) error {
   283  		if err := C.OnKey(name); err != nil {
   284  			return err
   285  		}
   286  		return fn(C, v.Field(idx))
   287  	}, nil
   288  }
   289  
   290  func makeFieldInlineFold(idx int, fn reFoldFn) reFoldFn {
   291  	return func(C *foldContext, v reflect.Value) error {
   292  		return fn(C, v.Field(idx))
   293  	}
   294  }
   295  
   296  func makeNonEmptyFieldFold(name string, idx int, t reflect.Type, fn reFoldFn) (reFoldFn, error) {
   297  	resolver := makeResolveValue(t)
   298  	if resolver == nil {
   299  		return makeFieldFold(name, idx, fn)
   300  	}
   301  
   302  	return func(C *foldContext, v reflect.Value) (err error) {
   303  		field, ok := resolver(v.Field(idx))
   304  		if ok {
   305  			if err = C.OnKey(name); err != nil {
   306  				return
   307  			}
   308  			err = fn(C, field)
   309  		}
   310  		return
   311  	}, nil
   312  }
   313  
   314  func makeResolveValue(st reflect.Type) func(reflect.Value) (reflect.Value, bool) {
   315  	type resolver func(reflect.Value) (reflect.Value, bool)
   316  
   317  	resolveBySize := func(v reflect.Value) (reflect.Value, bool) {
   318  		return v, v.Len() > 0
   319  	}
   320  
   321  	resolveNonNil := func(v reflect.Value) (reflect.Value, bool) {
   322  		return v, !v.IsNil()
   323  	}
   324  
   325  	var resolvers []resolver
   326  	for {
   327  		switch st.Kind() {
   328  		case reflect.Ptr:
   329  			var r resolver
   330  			st, r = makeResolvePointers(st)
   331  			resolvers = append(resolvers, r)
   332  			continue
   333  		case reflect.Interface:
   334  			resolvers = append(resolvers, resolveNonNil)
   335  		case reflect.Map, reflect.String, reflect.Slice, reflect.Array:
   336  			resolvers = append(resolvers, resolveBySize)
   337  		default:
   338  		}
   339  		break
   340  	}
   341  
   342  	if len(resolvers) == 0 {
   343  		return nil
   344  	}
   345  	if len(resolvers) == 1 {
   346  		return resolvers[0]
   347  	}
   348  
   349  	return func(v reflect.Value) (reflect.Value, bool) {
   350  		for _, r := range resolvers {
   351  			var ok bool
   352  			if v, ok = r(v); !ok {
   353  				return v, ok
   354  			}
   355  		}
   356  		return v, true
   357  	}
   358  }
   359  
   360  func makeResolvePointers(st reflect.Type) (reflect.Type, func(reflect.Value) (reflect.Value, bool)) {
   361  	N, bt := baseType(st)
   362  	return bt, func(v reflect.Value) (reflect.Value, bool) {
   363  		for i := 0; i < N; i++ {
   364  			if v.IsNil() {
   365  				return v, false
   366  			}
   367  			v = v.Elem()
   368  		}
   369  		return v, true
   370  	}
   371  }
   372  
   373  func getReflectFoldSlice(c *foldContext, t reflect.Type) (reFoldFn, error) {
   374  	elemVisitor, err := getReflectFold(c, t.Elem())
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  
   379  	return func(C *foldContext, rv reflect.Value) error {
   380  		count := rv.Len()
   381  
   382  		if err := C.OnArrayStart(count, structform.AnyType); err != nil {
   383  			return err
   384  		}
   385  		for i := 0; i < count; i++ {
   386  			if err := elemVisitor(C, rv.Index(i)); err != nil {
   387  				return err
   388  			}
   389  		}
   390  
   391  		return C.OnArrayFinished()
   392  	}, nil
   393  }
   394  
   395  /*
   396  // TODO: create visitors casting the actual values via reflection instead of
   397  //       golang type conversion:
   398  func getReflectFoldPrimitive(t reflect.Type) reFoldFn {
   399  	switch t.Kind() {
   400  	case reflect.Bool:
   401  		return reFoldBool
   402  	case reflect.Int:
   403  		return reFoldInt
   404  	case reflect.Int8:
   405  		return reFoldInt8
   406  	case reflect.Int16:
   407  		return reFoldInt16
   408  	case reflect.Int32:
   409  		return reFoldInt32
   410  	case reflect.Int64:
   411  		return reFoldInt64
   412  	case reflect.Uint:
   413  		return reFoldUint
   414  	case reflect.Uint8:
   415  		return reFoldUint8
   416  	case reflect.Uint16:
   417  		return reFoldUint16
   418  	case reflect.Uint32:
   419  		return reFoldUint32
   420  	case reflect.Uint64:
   421  		return reFoldUint64
   422  	case reflect.Float32:
   423  		return reFoldFloat32
   424  	case reflect.Float64:
   425  		return reFoldFloat64
   426  	case reflect.String:
   427  		return reFoldString
   428  
   429  	case reflect.Slice:
   430  		switch t.Elem().Kind() {
   431  		case reflect.Interface:
   432  			return reFoldArrAny
   433  		case reflect.Bool:
   434  			return reFoldArrBool
   435  		case reflect.Int:
   436  			return reFoldArrInt
   437  		case reflect.Int8:
   438  			return reFoldArrInt8
   439  		case reflect.Int16:
   440  			return reFoldArrInt16
   441  		case reflect.Int32:
   442  			return reFoldArrInt32
   443  		case reflect.Int64:
   444  			return reFoldArrInt64
   445  		case reflect.Uint:
   446  			return reFoldArrUint
   447  		case reflect.Uint8:
   448  			return reFoldArrUint8
   449  		case reflect.Uint16:
   450  			return reFoldArrUint16
   451  		case reflect.Uint32:
   452  			return reFoldArrUint32
   453  		case reflect.Uint64:
   454  			return reFoldArrUint64
   455  		case reflect.Float32:
   456  			return reFoldArrFloat32
   457  		case reflect.Float64:
   458  			return reFoldArrFloat64
   459  		case reflect.String:
   460  			return reFoldArrString
   461  		}
   462  
   463  	case reflect.Map:
   464  		if t.Key().Kind() != reflect.String {
   465  			return nil
   466  		}
   467  
   468  		switch t.Elem().Kind() {
   469  		case reflect.Interface:
   470  			return reflectMapAny
   471  		case reflect.Bool:
   472  			return reFoldMapBool
   473  		case reflect.Int:
   474  			return reFoldMapInt
   475  		case reflect.Int8:
   476  			return reFoldMapInt8
   477  		case reflect.Int16:
   478  			return reFoldMapInt16
   479  		case reflect.Int32:
   480  			return reFoldMapInt32
   481  		case reflect.Int64:
   482  			return reFoldMapInt64
   483  		case reflect.Uint:
   484  			return reFoldMapUint
   485  		case reflect.Uint8:
   486  			return reFoldMapUint8
   487  		case reflect.Uint16:
   488  			return reFoldMapUint16
   489  		case reflect.Uint32:
   490  			return reFoldMapUint32
   491  		case reflect.Uint64:
   492  			return reFoldMapUint64
   493  		case reflect.Float32:
   494  			return reFoldMapFloat32
   495  		case reflect.Float64:
   496  			return reFoldMapFloat64
   497  		case reflect.String:
   498  			return reFoldMapString
   499  		}
   500  	}
   501  
   502  	return nil
   503  }
   504  */
   505  
   506  func foldAnyReflect(C *foldContext, v reflect.Value) error {
   507  	f, err := getReflectFold(C, v.Type())
   508  	if err != nil {
   509  		return err
   510  	}
   511  	return f(C, v)
   512  }
   513  
   514  func newTypeFoldRegistry() *typeFoldRegistry {
   515  	return &typeFoldRegistry{m: map[typeFoldKey]reFoldFn{}}
   516  }
   517  
   518  func (r *typeFoldRegistry) find(t reflect.Type) reFoldFn {
   519  	// r.mu.RLock()
   520  	// defer r.mu.RUnlock()
   521  	return r.m[typeFoldKey{ty: t, inline: false}]
   522  }
   523  
   524  func (r *typeFoldRegistry) findInline(t reflect.Type) reFoldFn {
   525  	// r.mu.RLock()
   526  	// defer r.mu.RUnlock()
   527  	return r.m[typeFoldKey{ty: t, inline: true}]
   528  }
   529  
   530  func (r *typeFoldRegistry) set(t reflect.Type, f reFoldFn) {
   531  	// r.mu.Lock()
   532  	// defer r.mu.Unlock()
   533  	r.m[typeFoldKey{ty: t, inline: false}] = f
   534  }
   535  
   536  func (r *typeFoldRegistry) setInline(t reflect.Type, f reFoldFn) {
   537  	// r.mu.Lock()
   538  	// defer r.mu.Unlock()
   539  	r.m[typeFoldKey{ty: t, inline: true}] = f
   540  }
   541  
   542  func liftFold(sample interface{}, fn foldFn) reFoldFn {
   543  	t := reflect.TypeOf(sample)
   544  	return func(C *foldContext, v reflect.Value) error {
   545  		if v.Type().Name() != "" {
   546  			v = v.Convert(t)
   547  		}
   548  		return fn(C, v.Interface())
   549  	}
   550  }
   551  
   552  func baseType(t reflect.Type) (int, reflect.Type) {
   553  	i := 0
   554  	for t.Kind() == reflect.Ptr {
   555  		t = t.Elem()
   556  		i++
   557  	}
   558  	return i, t
   559  }