github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/orm.go (about) 1 package sqlx 2 3 import ( 4 "errors" 5 "reflect" 6 "strings" 7 8 "github.com/lingyao2333/mo-zero/core/mapping" 9 ) 10 11 const tagName = "db" 12 13 var ( 14 // ErrNotMatchDestination is an error that indicates not matching destination to scan. 15 ErrNotMatchDestination = errors.New("not matching destination to scan") 16 // ErrNotReadableValue is an error that indicates value is not addressable or interfaceable. 17 ErrNotReadableValue = errors.New("value not addressable or interfaceable") 18 // ErrNotSettable is an error that indicates the passed in variable is not settable. 19 ErrNotSettable = errors.New("passed in variable is not settable") 20 // ErrUnsupportedValueType is an error that indicates unsupported unmarshal type. 21 ErrUnsupportedValueType = errors.New("unsupported unmarshal type") 22 ) 23 24 type rowsScanner interface { 25 Columns() ([]string, error) 26 Err() error 27 Next() bool 28 Scan(v ...interface{}) error 29 } 30 31 func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) { 32 rt := mapping.Deref(v.Type()) 33 size := rt.NumField() 34 result := make(map[string]interface{}, size) 35 36 for i := 0; i < size; i++ { 37 key := parseTagName(rt.Field(i)) 38 if len(key) == 0 { 39 return nil, nil 40 } 41 42 valueField := reflect.Indirect(v).Field(i) 43 switch valueField.Kind() { 44 case reflect.Ptr: 45 if !valueField.CanInterface() { 46 return nil, ErrNotReadableValue 47 } 48 if valueField.IsNil() { 49 baseValueType := mapping.Deref(valueField.Type()) 50 valueField.Set(reflect.New(baseValueType)) 51 } 52 result[key] = valueField.Interface() 53 default: 54 if !valueField.CanAddr() || !valueField.Addr().CanInterface() { 55 return nil, ErrNotReadableValue 56 } 57 result[key] = valueField.Addr().Interface() 58 } 59 } 60 61 return result, nil 62 } 63 64 func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) { 65 fields := unwrapFields(v) 66 if strict && len(columns) < len(fields) { 67 return nil, ErrNotMatchDestination 68 } 69 70 taggedMap, err := getTaggedFieldValueMap(v) 71 if err != nil { 72 return nil, err 73 } 74 75 values := make([]interface{}, len(columns)) 76 if len(taggedMap) == 0 { 77 for i := 0; i < len(values); i++ { 78 valueField := fields[i] 79 switch valueField.Kind() { 80 case reflect.Ptr: 81 if !valueField.CanInterface() { 82 return nil, ErrNotReadableValue 83 } 84 if valueField.IsNil() { 85 baseValueType := mapping.Deref(valueField.Type()) 86 valueField.Set(reflect.New(baseValueType)) 87 } 88 values[i] = valueField.Interface() 89 default: 90 if !valueField.CanAddr() || !valueField.Addr().CanInterface() { 91 return nil, ErrNotReadableValue 92 } 93 values[i] = valueField.Addr().Interface() 94 } 95 } 96 } else { 97 for i, column := range columns { 98 if tagged, ok := taggedMap[column]; ok { 99 values[i] = tagged 100 } else { 101 var anonymous interface{} 102 values[i] = &anonymous 103 } 104 } 105 } 106 107 return values, nil 108 } 109 110 func parseTagName(field reflect.StructField) string { 111 key := field.Tag.Get(tagName) 112 if len(key) == 0 { 113 return "" 114 } 115 116 options := strings.Split(key, ",") 117 return options[0] 118 } 119 120 func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error { 121 if !scanner.Next() { 122 if err := scanner.Err(); err != nil { 123 return err 124 } 125 return ErrNotFound 126 } 127 128 rv := reflect.ValueOf(v) 129 if err := mapping.ValidatePtr(&rv); err != nil { 130 return err 131 } 132 133 rte := reflect.TypeOf(v).Elem() 134 rve := rv.Elem() 135 switch rte.Kind() { 136 case reflect.Bool, 137 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 138 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 139 reflect.Float32, reflect.Float64, 140 reflect.String: 141 if rve.CanSet() { 142 return scanner.Scan(v) 143 } 144 145 return ErrNotSettable 146 case reflect.Struct: 147 columns, err := scanner.Columns() 148 if err != nil { 149 return err 150 } 151 152 values, err := mapStructFieldsIntoSlice(rve, columns, strict) 153 if err != nil { 154 return err 155 } 156 157 return scanner.Scan(values...) 158 default: 159 return ErrUnsupportedValueType 160 } 161 } 162 163 func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error { 164 rv := reflect.ValueOf(v) 165 if err := mapping.ValidatePtr(&rv); err != nil { 166 return err 167 } 168 169 rt := reflect.TypeOf(v) 170 rte := rt.Elem() 171 rve := rv.Elem() 172 switch rte.Kind() { 173 case reflect.Slice: 174 if rve.CanSet() { 175 ptr := rte.Elem().Kind() == reflect.Ptr 176 appendFn := func(item reflect.Value) { 177 if ptr { 178 rve.Set(reflect.Append(rve, item)) 179 } else { 180 rve.Set(reflect.Append(rve, reflect.Indirect(item))) 181 } 182 } 183 fillFn := func(value interface{}) error { 184 if rve.CanSet() { 185 if err := scanner.Scan(value); err != nil { 186 return err 187 } 188 189 appendFn(reflect.ValueOf(value)) 190 return nil 191 } 192 return ErrNotSettable 193 } 194 195 base := mapping.Deref(rte.Elem()) 196 switch base.Kind() { 197 case reflect.Bool, 198 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 199 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 200 reflect.Float32, reflect.Float64, 201 reflect.String: 202 for scanner.Next() { 203 value := reflect.New(base) 204 if err := fillFn(value.Interface()); err != nil { 205 return err 206 } 207 } 208 case reflect.Struct: 209 columns, err := scanner.Columns() 210 if err != nil { 211 return err 212 } 213 214 for scanner.Next() { 215 value := reflect.New(base) 216 values, err := mapStructFieldsIntoSlice(value, columns, strict) 217 if err != nil { 218 return err 219 } 220 221 if err := scanner.Scan(values...); err != nil { 222 return err 223 } 224 225 appendFn(value) 226 } 227 default: 228 return ErrUnsupportedValueType 229 } 230 231 return nil 232 } 233 234 return ErrNotSettable 235 default: 236 return ErrUnsupportedValueType 237 } 238 } 239 240 func unwrapFields(v reflect.Value) []reflect.Value { 241 var fields []reflect.Value 242 indirect := reflect.Indirect(v) 243 244 for i := 0; i < indirect.NumField(); i++ { 245 child := indirect.Field(i) 246 if child.Kind() == reflect.Ptr && child.IsNil() { 247 baseValueType := mapping.Deref(child.Type()) 248 child.Set(reflect.New(baseValueType)) 249 } 250 251 child = reflect.Indirect(child) 252 childType := indirect.Type().Field(i) 253 if child.Kind() == reflect.Struct && childType.Anonymous { 254 fields = append(fields, unwrapFields(child)...) 255 } else { 256 fields = append(fields, child) 257 } 258 } 259 260 return fields 261 }