github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/orm.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"errors"
     5  	"reflect"
     6  	"strings"
     7  
     8  	"github.com/lingyao2333/mo-zero/core/mapping"
     9  )
    10  
    11  const tagName = "db"
    12  
    13  var (
    14  	// ErrNotMatchDestination is an error that indicates not matching destination to scan.
    15  	ErrNotMatchDestination = errors.New("not matching destination to scan")
    16  	// ErrNotReadableValue is an error that indicates value is not addressable or interfaceable.
    17  	ErrNotReadableValue = errors.New("value not addressable or interfaceable")
    18  	// ErrNotSettable is an error that indicates the passed in variable is not settable.
    19  	ErrNotSettable = errors.New("passed in variable is not settable")
    20  	// ErrUnsupportedValueType is an error that indicates unsupported unmarshal type.
    21  	ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
    22  )
    23  
    24  type rowsScanner interface {
    25  	Columns() ([]string, error)
    26  	Err() error
    27  	Next() bool
    28  	Scan(v ...interface{}) error
    29  }
    30  
    31  func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
    32  	rt := mapping.Deref(v.Type())
    33  	size := rt.NumField()
    34  	result := make(map[string]interface{}, size)
    35  
    36  	for i := 0; i < size; i++ {
    37  		key := parseTagName(rt.Field(i))
    38  		if len(key) == 0 {
    39  			return nil, nil
    40  		}
    41  
    42  		valueField := reflect.Indirect(v).Field(i)
    43  		switch valueField.Kind() {
    44  		case reflect.Ptr:
    45  			if !valueField.CanInterface() {
    46  				return nil, ErrNotReadableValue
    47  			}
    48  			if valueField.IsNil() {
    49  				baseValueType := mapping.Deref(valueField.Type())
    50  				valueField.Set(reflect.New(baseValueType))
    51  			}
    52  			result[key] = valueField.Interface()
    53  		default:
    54  			if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
    55  				return nil, ErrNotReadableValue
    56  			}
    57  			result[key] = valueField.Addr().Interface()
    58  		}
    59  	}
    60  
    61  	return result, nil
    62  }
    63  
    64  func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
    65  	fields := unwrapFields(v)
    66  	if strict && len(columns) < len(fields) {
    67  		return nil, ErrNotMatchDestination
    68  	}
    69  
    70  	taggedMap, err := getTaggedFieldValueMap(v)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	values := make([]interface{}, len(columns))
    76  	if len(taggedMap) == 0 {
    77  		for i := 0; i < len(values); i++ {
    78  			valueField := fields[i]
    79  			switch valueField.Kind() {
    80  			case reflect.Ptr:
    81  				if !valueField.CanInterface() {
    82  					return nil, ErrNotReadableValue
    83  				}
    84  				if valueField.IsNil() {
    85  					baseValueType := mapping.Deref(valueField.Type())
    86  					valueField.Set(reflect.New(baseValueType))
    87  				}
    88  				values[i] = valueField.Interface()
    89  			default:
    90  				if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
    91  					return nil, ErrNotReadableValue
    92  				}
    93  				values[i] = valueField.Addr().Interface()
    94  			}
    95  		}
    96  	} else {
    97  		for i, column := range columns {
    98  			if tagged, ok := taggedMap[column]; ok {
    99  				values[i] = tagged
   100  			} else {
   101  				var anonymous interface{}
   102  				values[i] = &anonymous
   103  			}
   104  		}
   105  	}
   106  
   107  	return values, nil
   108  }
   109  
   110  func parseTagName(field reflect.StructField) string {
   111  	key := field.Tag.Get(tagName)
   112  	if len(key) == 0 {
   113  		return ""
   114  	}
   115  
   116  	options := strings.Split(key, ",")
   117  	return options[0]
   118  }
   119  
   120  func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
   121  	if !scanner.Next() {
   122  		if err := scanner.Err(); err != nil {
   123  			return err
   124  		}
   125  		return ErrNotFound
   126  	}
   127  
   128  	rv := reflect.ValueOf(v)
   129  	if err := mapping.ValidatePtr(&rv); err != nil {
   130  		return err
   131  	}
   132  
   133  	rte := reflect.TypeOf(v).Elem()
   134  	rve := rv.Elem()
   135  	switch rte.Kind() {
   136  	case reflect.Bool,
   137  		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   138  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   139  		reflect.Float32, reflect.Float64,
   140  		reflect.String:
   141  		if rve.CanSet() {
   142  			return scanner.Scan(v)
   143  		}
   144  
   145  		return ErrNotSettable
   146  	case reflect.Struct:
   147  		columns, err := scanner.Columns()
   148  		if err != nil {
   149  			return err
   150  		}
   151  
   152  		values, err := mapStructFieldsIntoSlice(rve, columns, strict)
   153  		if err != nil {
   154  			return err
   155  		}
   156  
   157  		return scanner.Scan(values...)
   158  	default:
   159  		return ErrUnsupportedValueType
   160  	}
   161  }
   162  
   163  func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
   164  	rv := reflect.ValueOf(v)
   165  	if err := mapping.ValidatePtr(&rv); err != nil {
   166  		return err
   167  	}
   168  
   169  	rt := reflect.TypeOf(v)
   170  	rte := rt.Elem()
   171  	rve := rv.Elem()
   172  	switch rte.Kind() {
   173  	case reflect.Slice:
   174  		if rve.CanSet() {
   175  			ptr := rte.Elem().Kind() == reflect.Ptr
   176  			appendFn := func(item reflect.Value) {
   177  				if ptr {
   178  					rve.Set(reflect.Append(rve, item))
   179  				} else {
   180  					rve.Set(reflect.Append(rve, reflect.Indirect(item)))
   181  				}
   182  			}
   183  			fillFn := func(value interface{}) error {
   184  				if rve.CanSet() {
   185  					if err := scanner.Scan(value); err != nil {
   186  						return err
   187  					}
   188  
   189  					appendFn(reflect.ValueOf(value))
   190  					return nil
   191  				}
   192  				return ErrNotSettable
   193  			}
   194  
   195  			base := mapping.Deref(rte.Elem())
   196  			switch base.Kind() {
   197  			case reflect.Bool,
   198  				reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   199  				reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   200  				reflect.Float32, reflect.Float64,
   201  				reflect.String:
   202  				for scanner.Next() {
   203  					value := reflect.New(base)
   204  					if err := fillFn(value.Interface()); err != nil {
   205  						return err
   206  					}
   207  				}
   208  			case reflect.Struct:
   209  				columns, err := scanner.Columns()
   210  				if err != nil {
   211  					return err
   212  				}
   213  
   214  				for scanner.Next() {
   215  					value := reflect.New(base)
   216  					values, err := mapStructFieldsIntoSlice(value, columns, strict)
   217  					if err != nil {
   218  						return err
   219  					}
   220  
   221  					if err := scanner.Scan(values...); err != nil {
   222  						return err
   223  					}
   224  
   225  					appendFn(value)
   226  				}
   227  			default:
   228  				return ErrUnsupportedValueType
   229  			}
   230  
   231  			return nil
   232  		}
   233  
   234  		return ErrNotSettable
   235  	default:
   236  		return ErrUnsupportedValueType
   237  	}
   238  }
   239  
   240  func unwrapFields(v reflect.Value) []reflect.Value {
   241  	var fields []reflect.Value
   242  	indirect := reflect.Indirect(v)
   243  
   244  	for i := 0; i < indirect.NumField(); i++ {
   245  		child := indirect.Field(i)
   246  		if child.Kind() == reflect.Ptr && child.IsNil() {
   247  			baseValueType := mapping.Deref(child.Type())
   248  			child.Set(reflect.New(baseValueType))
   249  		}
   250  
   251  		child = reflect.Indirect(child)
   252  		childType := indirect.Type().Field(i)
   253  		if child.Kind() == reflect.Struct && childType.Anonymous {
   254  			fields = append(fields, unwrapFields(child)...)
   255  		} else {
   256  			fields = append(fields, child)
   257  		}
   258  	}
   259  
   260  	return fields
   261  }