github.com/dranikpg/go-dto@v1.0.0/godto.go (about)

     1  // Package dto is an easy-to-use library for data mapping
     2  //
     3  // go-dto maps primitives, structs, slices, maps, pointers
     4  // and supports custom functions and error mapping.
     5  //
     6  // Contrary to other struct mappers it uses only name based field resolution
     7  // and maps its values recursively. This means that go-dto tries to map struct fields
     8  // with the same names.
     9  //
    10  // Conversion functions can be used to overwrite mapping behaviour.
    11  // Inspection functions allow to modify a value after it has been mapped.
    12  //
    13  // See the tests and GitHub page for more examples.
    14  package dto
    15  
    16  import (
    17  	"errors"
    18  	"fmt"
    19  	"reflect"
    20  )
    21  
    22  type structValueMap = map[string]reflect.Value
    23  
    24  // Marker type for functions with no receiver
    25  type nilRecvT struct{}
    26  
    27  var nilRecvRfType = reflect.TypeOf(nilRecvT{})
    28  var errorRfType = reflect.TypeOf((*error)(nil)).Elem()
    29  var mapperPtrRfType = reflect.TypeOf((*Mapper)(nil))
    30  
    31  type convertFuncClosure = func(reflect.Value, *Mapper) (error, reflect.Value)
    32  type inspectFuncClosure = func(reflect.Value, reflect.Value, *Mapper) error
    33  
    34  // ErrNoValidMapping indicates that no valid mapping was found
    35  type ErrNoValidMapping struct {
    36  	ToType   reflect.Type
    37  	FromType reflect.Type
    38  }
    39  
    40  func (nvme ErrNoValidMapping) Error() string {
    41  	return fmt.Sprintf("No valid mapping found for %v from %v", nvme.ToType, nvme.FromType)
    42  }
    43  
    44  // Mapper contains conversion and inspect functions
    45  type Mapper struct {
    46  	// linear search might be faster than nested maps
    47  	convFunc map[reflect.Type]map[reflect.Type]convertFuncClosure
    48  	postFunc map[reflect.Type]map[reflect.Type][]inspectFuncClosure
    49  }
    50  
    51  // ==================================== utils =================================
    52  
    53  // Collect all struct fields (including anonymous) into a structValueMap
    54  func collectStructFields(rfValue reflect.Value, rfType reflect.Type, fields structValueMap) {
    55  	for i := 0; i < rfType.NumField(); i++ {
    56  		fieldValue := rfValue.Field(i)
    57  		fieldType := rfType.Field(i)
    58  		if fieldType.Anonymous {
    59  			collectStructFields(fieldValue, fieldType.Type, fields)
    60  		} else {
    61  			fields[fieldType.Name] = fieldValue
    62  		}
    63  	}
    64  }
    65  
    66  // Return reflect.Value with pointer removed (first layer only)
    67  func reflectValueRemovePtr(v interface{}) reflect.Value {
    68  	rv := reflect.ValueOf(v)
    69  	if rv.Type().Kind() == reflect.Ptr {
    70  		return rv.Elem()
    71  	}
    72  	return rv
    73  }
    74  
    75  // Maps an error from a reflect value
    76  // Panics if the value is non nill and not an error
    77  func errorFromReflectValue(rv reflect.Value) error {
    78  	if rv.IsNil() {
    79  		return nil
    80  	}
    81  	err, ok := rv.Interface().(error)
    82  	if !ok {
    83  		panic("Failed to map error from reflect.Value")
    84  	}
    85  	return err
    86  }
    87  
    88  // ==================================== Conversion and inspection functions ===
    89  
    90  // Run inspect functions for (to-from) pair
    91  func (m *Mapper) runInspectFuncs(toRv, fromRv reflect.Value) error {
    92  	toMap, ok := m.postFunc[toRv.Type()]
    93  	if !ok {
    94  		return nil
    95  	}
    96  	for _, recvType := range []reflect.Type{fromRv.Type(), nilRecvRfType} {
    97  		funcs, ok := toMap[recvType]
    98  		if !ok {
    99  			continue
   100  		}
   101  		for _, fun := range funcs {
   102  			if err := fun(toRv.Addr(), fromRv, m); err != nil {
   103  				return err
   104  			}
   105  		}
   106  	}
   107  	return nil
   108  }
   109  
   110  // Run convert function for (to-from) pair
   111  // Returns (error, true) if a valid function was found, (nil, false) otherwise
   112  func (m *Mapper) runConvFuncs(toRv, fromRv reflect.Value) (bool, error) {
   113  	toMap, ok := m.convFunc[fromRv.Type()]
   114  	if !ok {
   115  		return false, nil
   116  	}
   117  	if convertFunc, ok := toMap[toRv.Type()]; ok {
   118  		err, val := convertFunc(fromRv, m)
   119  		if err != nil {
   120  			return true, err
   121  		}
   122  		toRv.Set(val)
   123  		return true, nil
   124  	}
   125  	return false, nil
   126  }
   127  
   128  // AddConvFunc adds a conversion function to the Mapper
   129  //
   130  // Panics if f is not a valid conversion function
   131  // Overwrites previous functions with the same type pair
   132  func (m *Mapper) AddConvFunc(f interface{}) {
   133  	rt := reflect.TypeOf(f)
   134  
   135  	// check basic argument invariant
   136  	if rt.NumOut() < 1 || rt.NumIn() < 1 {
   137  		panic("Bad conversion function")
   138  	}
   139  
   140  	// check if to inject mapper
   141  	takesMapper := false
   142  	if rt.NumIn() > 1 && rt.In(1) == mapperPtrRfType {
   143  		takesMapper = true
   144  	}
   145  
   146  	// check if returns an error
   147  	returnsError := false
   148  	outType := rt.Out(0)
   149  	if rt.NumOut() > 1 && rt.Out(1).Implements(errorRfType) {
   150  		returnsError = true
   151  	}
   152  
   153  	inType := rt.In(0)
   154  
   155  	// create maps
   156  	if len(m.convFunc) == 0 {
   157  		m.convFunc = make(map[reflect.Type]map[reflect.Type]convertFuncClosure)
   158  	}
   159  	if len(m.convFunc[inType]) == 0 {
   160  		m.convFunc[inType] = make(map[reflect.Type]convertFuncClosure)
   161  	}
   162  
   163  	// register closure
   164  	m.convFunc[inType][outType] = func(from reflect.Value, m *Mapper) (error, reflect.Value) {
   165  		args := []reflect.Value{from}
   166  		if takesMapper {
   167  			args = append(args, reflect.ValueOf(m))
   168  		}
   169  		out := reflect.ValueOf(f).Call(args)
   170  		if returnsError {
   171  			return errorFromReflectValue(out[0]), out[1]
   172  		}
   173  		return nil, out[0]
   174  	}
   175  }
   176  
   177  // AddInspectFunc adds an inspection function to the Mapper
   178  //
   179  // Panics if f is not a valid inspection function
   180  func (m *Mapper) AddInspectFunc(f interface{}) {
   181  	ft := reflect.TypeOf(f)
   182  	inType := ft.In(0).Elem()
   183  
   184  	// check if takes from
   185  	fromType := nilRecvRfType
   186  	if ft.NumIn() > 1 {
   187  		fromType = ft.In(1)
   188  	}
   189  
   190  	// check if takes mapper
   191  	takesMapper := false
   192  	if ft.NumIn() > 2 && ft.In(2) == reflect.TypeOf(m) {
   193  		takesMapper = true
   194  	}
   195  
   196  	// check if returns error
   197  	returnsError := false
   198  	if ft.NumOut() > 0 && ft.Out(0).Implements(errorRfType) {
   199  		returnsError = true
   200  	}
   201  
   202  	// create map path
   203  	if len(m.postFunc) == 0 {
   204  		m.postFunc = make(map[reflect.Type]map[reflect.Type][]inspectFuncClosure)
   205  	}
   206  	if len(m.postFunc[inType]) == 0 {
   207  		m.postFunc[inType] = make(map[reflect.Type][]inspectFuncClosure)
   208  	}
   209  
   210  	// register closure
   211  	m.postFunc[inType][fromType] = append(m.postFunc[inType][fromType],
   212  		func(v1, v2 reflect.Value, m *Mapper) error {
   213  			args := []reflect.Value{v1}
   214  			if fromType != nilRecvRfType {
   215  				args = append(args, v2)
   216  			}
   217  			if takesMapper {
   218  				args = append(args, reflect.ValueOf(m))
   219  			}
   220  
   221  			out := reflect.ValueOf(f).Call(args)
   222  			if returnsError {
   223  				return errorFromReflectValue(out[0])
   224  			}
   225  			return nil
   226  		},
   227  	)
   228  }
   229  
   230  // ==================================== Mapping functions =====================
   231  
   232  // Map slices
   233  // Panics if arguments are not slices
   234  func (m *Mapper) mapSlice(toRv, fromRv reflect.Value) error {
   235  	toRv.Set(reflect.MakeSlice(toRv.Type(), fromRv.Len(), fromRv.Len()))
   236  	for i := 0; i < fromRv.Len(); i++ {
   237  		if err := m.mapValue(toRv.Index(i), fromRv.Index(i)); err != nil {
   238  			return err
   239  		}
   240  	}
   241  	return nil
   242  }
   243  
   244  // Map maps
   245  // Panics if arguments are not maps
   246  func (m *Mapper) mapMap(toRv, fromRv reflect.Value) error {
   247  	toRv.Set(reflect.MakeMapWithSize(toRv.Type(), fromRv.Len()))
   248  	// Map values
   249  	mapIt := fromRv.MapRange()
   250  	for mapIt.Next() {
   251  		toKey := reflect.New(toRv.Type().Key()).Elem()
   252  		toValue := reflect.New(toRv.Type().Elem()).Elem()
   253  		if err := m.mapValue(toKey, mapIt.Key()); err != nil {
   254  			return err
   255  		}
   256  		if err := m.mapValue(toValue, mapIt.Value()); err != nil {
   257  			return err
   258  		}
   259  		toRv.SetMapIndex(toKey, toValue)
   260  	}
   261  	return nil
   262  }
   263  
   264  // Map structs
   265  // Panics if arguments are not structs
   266  func (m *Mapper) mapStructs(toRv, fromRv reflect.Value) error {
   267  	toFields := make(structValueMap)
   268  	collectStructFields(toRv, toRv.Type(), toFields)
   269  
   270  	fromFields := make(structValueMap)
   271  	collectStructFields(fromRv, fromRv.Type(), fromFields)
   272  
   273  	for fieldName, toValue := range toFields {
   274  		fromValue, ok := fromFields[fieldName]
   275  		if !ok {
   276  			continue
   277  		}
   278  		err := m.mapValue(toValue, fromValue)
   279  		if err != nil {
   280  			return err
   281  		}
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  // Map map values to slice
   288  // Panics if arguments are not slice and map accordingly
   289  func (m *Mapper) mapMapToSlice(toRv, fromRv reflect.Value) error {
   290  	toRv.Set(reflect.MakeSlice(toRv.Type(), fromRv.Len(), fromRv.Len()))
   291  	i := 0
   292  	mapIt := fromRv.MapRange()
   293  	for mapIt.Next() {
   294  		if err := m.mapValue(toRv.Index(i), mapIt.Value()); err != nil {
   295  			return err
   296  		}
   297  		i++
   298  	}
   299  	return nil
   300  }
   301  
   302  // Map a map of slices to slice
   303  // Panics of arguments are not a map of slices and a slice accordingly
   304  func (m *Mapper) mapMapSlicesToSlice(toRv, fromRv reflect.Value) error {
   305  	// calculate lenght
   306  	sumLen := 0
   307  	mapIt := fromRv.MapRange()
   308  	for mapIt.Next() {
   309  		sumLen += mapIt.Value().Len()
   310  	}
   311  
   312  	toRv.Set(reflect.MakeSlice(toRv.Type(), sumLen, sumLen))
   313  
   314  	i := 0
   315  	mapIt = fromRv.MapRange()
   316  	for mapIt.Next() {
   317  		mapSlice := mapIt.Value()
   318  		for j := 0; j < mapSlice.Len(); i, j = i+1, j+1 {
   319  			if err := m.mapValue(toRv.Index(i), mapSlice.Index(j)); err != nil {
   320  				return err
   321  			}
   322  		}
   323  	}
   324  
   325  	return nil
   326  }
   327  
   328  // Try to map any value
   329  func (m *Mapper) mapValue(toRv, fromRv reflect.Value) (returnError error) {
   330  	tk, fk := toRv.Type().Kind(), fromRv.Type().Kind()
   331  
   332  	// Defer inspect functions
   333  	defer func() {
   334  		if returnError != nil {
   335  			return
   336  		}
   337  		returnError = m.runInspectFuncs(toRv, fromRv)
   338  	}()
   339  
   340  	// 1. Check conversion functions
   341  	converted, err := m.runConvFuncs(toRv, fromRv)
   342  	if converted {
   343  		return err
   344  	}
   345  
   346  	// 2. Check direct assignment
   347  	if fromRv.Type().AssignableTo(toRv.Type()) {
   348  		toRv.Set(fromRv)
   349  		return
   350  	}
   351  
   352  	// 3. Check conversion
   353  	if fromRv.Type().ConvertibleTo(toRv.Type()) {
   354  		toRv.Set(fromRv.Convert(toRv.Type()))
   355  		return
   356  	}
   357  
   358  	// 4. Handle pointers by dereferencing from
   359  	if fk == reflect.Ptr {
   360  		return m.mapValue(toRv, fromRv.Elem())
   361  	}
   362  
   363  	// 5. Handle sructs
   364  	if tk == reflect.Struct && fk == reflect.Struct {
   365  		return m.mapStructs(toRv, fromRv)
   366  	}
   367  
   368  	// 6. Handle slices
   369  	if tk == reflect.Slice && fk == reflect.Slice {
   370  		return m.mapSlice(toRv, fromRv)
   371  	}
   372  
   373  	// 7. Handle maps
   374  	if tk == reflect.Map && fk == reflect.Map {
   375  		return m.mapMap(toRv, fromRv)
   376  	}
   377  
   378  	// 8. Handle map to slice
   379  	if tk == reflect.Slice && fk == reflect.Map {
   380  		err := m.mapMapToSlice(toRv, fromRv)
   381  
   382  		// 9. Handle map of slices to slice
   383  		mapElemK := fromRv.Type().Elem().Kind()
   384  		if errors.As(err, &ErrNoValidMapping{}) && mapElemK == reflect.Slice {
   385  			// dont propagate errors
   386  			if errFlatten := m.mapMapSlicesToSlice(toRv, fromRv); errFlatten == nil {
   387  				return
   388  			}
   389  		}
   390  
   391  		return err
   392  	}
   393  
   394  	return ErrNoValidMapping{
   395  		ToType:   toRv.Type(),
   396  		FromType: fromRv.Type(),
   397  	}
   398  }
   399  
   400  // ==================================== Public helpers ========================
   401  
   402  // Map transfers values from `from` to `to`
   403  func (m *Mapper) Map(to, from interface{}) error {
   404  	return m.mapValue(reflectValueRemovePtr(to), reflectValueRemovePtr(from))
   405  }
   406  
   407  // Map transfers values from `from` to `to` with a new Mapper
   408  func Map(to, from interface{}) error {
   409  	m := Mapper{}
   410  	return m.Map(to, from)
   411  }