github.com/systematiccaos/gorm@v1.22.6/scan.go (about) 1 package gorm 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "reflect" 7 "strings" 8 "time" 9 10 "github.com/systematiccaos/gorm/schema" 11 ) 12 13 func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { 14 if db.Statement.Schema != nil { 15 for idx, name := range columns { 16 if field := db.Statement.Schema.LookUpField(name); field != nil { 17 values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() 18 continue 19 } 20 values[idx] = new(interface{}) 21 } 22 } else if len(columnTypes) > 0 { 23 for idx, columnType := range columnTypes { 24 if columnType.ScanType() != nil { 25 values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() 26 } else { 27 values[idx] = new(interface{}) 28 } 29 } 30 } else { 31 for idx := range columns { 32 values[idx] = new(interface{}) 33 } 34 } 35 } 36 37 func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { 38 for idx, column := range columns { 39 if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { 40 mapValue[column] = reflectValue.Interface() 41 if valuer, ok := mapValue[column].(driver.Valuer); ok { 42 mapValue[column], _ = valuer.Value() 43 } else if b, ok := mapValue[column].(sql.RawBytes); ok { 44 mapValue[column] = string(b) 45 } 46 } else { 47 mapValue[column] = nil 48 } 49 } 50 } 51 52 func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) { 53 for idx, column := range columns { 54 if sch == nil { 55 values[idx] = reflectValue.Interface() 56 } else if field := sch.LookUpField(column); field != nil && field.Readable { 57 values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() 58 } else if names := strings.Split(column, "__"); len(names) > 1 { 59 if rel, ok := sch.Relationships.Relations[names[0]]; ok { 60 if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { 61 values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() 62 continue 63 } 64 } 65 values[idx] = &sql.RawBytes{} 66 } else if len(columns) == 1 { 67 sch = nil 68 values[idx] = reflectValue.Interface() 69 } else { 70 values[idx] = &sql.RawBytes{} 71 } 72 } 73 74 db.RowsAffected++ 75 db.AddError(rows.Scan(values...)) 76 77 if sch != nil { 78 for idx, column := range columns { 79 if field := sch.LookUpField(column); field != nil && field.Readable { 80 field.Set(reflectValue, values[idx]) 81 } else if names := strings.Split(column, "__"); len(names) > 1 { 82 if rel, ok := sch.Relationships.Relations[names[0]]; ok { 83 if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { 84 relValue := rel.Field.ReflectValueOf(reflectValue) 85 value := reflect.ValueOf(values[idx]).Elem() 86 87 if relValue.Kind() == reflect.Ptr && relValue.IsNil() { 88 if value.IsNil() { 89 continue 90 } 91 relValue.Set(reflect.New(relValue.Type().Elem())) 92 } 93 94 field.Set(relValue, values[idx]) 95 } 96 } 97 } 98 } 99 } 100 } 101 102 type ScanMode uint8 103 104 const ( 105 ScanInitialized ScanMode = 1 << 0 // 1 106 ScanUpdate ScanMode = 1 << 1 // 2 107 ScanOnConflictDoNothing ScanMode = 1 << 2 // 4 108 ) 109 110 func Scan(rows *sql.Rows, db *DB, mode ScanMode) { 111 var ( 112 columns, _ = rows.Columns() 113 values = make([]interface{}, len(columns)) 114 initialized = mode&ScanInitialized != 0 115 update = mode&ScanUpdate != 0 116 onConflictDonothing = mode&ScanOnConflictDoNothing != 0 117 ) 118 119 db.RowsAffected = 0 120 121 switch dest := db.Statement.Dest.(type) { 122 case map[string]interface{}, *map[string]interface{}: 123 if initialized || rows.Next() { 124 columnTypes, _ := rows.ColumnTypes() 125 prepareValues(values, db, columnTypes, columns) 126 127 db.RowsAffected++ 128 db.AddError(rows.Scan(values...)) 129 130 mapValue, ok := dest.(map[string]interface{}) 131 if !ok { 132 if v, ok := dest.(*map[string]interface{}); ok { 133 if *v == nil { 134 *v = map[string]interface{}{} 135 } 136 mapValue = *v 137 } 138 } 139 scanIntoMap(mapValue, values, columns) 140 } 141 case *[]map[string]interface{}, []map[string]interface{}: 142 columnTypes, _ := rows.ColumnTypes() 143 for initialized || rows.Next() { 144 prepareValues(values, db, columnTypes, columns) 145 146 initialized = false 147 db.RowsAffected++ 148 db.AddError(rows.Scan(values...)) 149 150 mapValue := map[string]interface{}{} 151 scanIntoMap(mapValue, values, columns) 152 if values, ok := dest.([]map[string]interface{}); ok { 153 values = append(values, mapValue) 154 } else if values, ok := dest.(*[]map[string]interface{}); ok { 155 *values = append(*values, mapValue) 156 } 157 } 158 case *int, *int8, *int16, *int32, *int64, 159 *uint, *uint8, *uint16, *uint32, *uint64, *uintptr, 160 *float32, *float64, 161 *bool, *string, *time.Time, 162 *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64, 163 *sql.NullBool, *sql.NullString, *sql.NullTime: 164 for initialized || rows.Next() { 165 initialized = false 166 db.RowsAffected++ 167 db.AddError(rows.Scan(dest)) 168 } 169 default: 170 var ( 171 fields = make([]*schema.Field, len(columns)) 172 joinFields [][2]*schema.Field 173 sch = db.Statement.Schema 174 reflectValue = db.Statement.ReflectValue 175 ) 176 177 if reflectValue.Kind() == reflect.Interface { 178 reflectValue = reflectValue.Elem() 179 } 180 181 reflectValueType := reflectValue.Type() 182 switch reflectValueType.Kind() { 183 case reflect.Array, reflect.Slice: 184 reflectValueType = reflectValueType.Elem() 185 } 186 isPtr := reflectValueType.Kind() == reflect.Ptr 187 if isPtr { 188 reflectValueType = reflectValueType.Elem() 189 } 190 191 if sch != nil { 192 if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct { 193 sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) 194 } 195 196 for idx, column := range columns { 197 if field := sch.LookUpField(column); field != nil && field.Readable { 198 fields[idx] = field 199 } else if names := strings.Split(column, "__"); len(names) > 1 { 200 if rel, ok := sch.Relationships.Relations[names[0]]; ok { 201 if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { 202 fields[idx] = field 203 204 if len(joinFields) == 0 { 205 joinFields = make([][2]*schema.Field, len(columns)) 206 } 207 joinFields[idx] = [2]*schema.Field{rel.Field, field} 208 continue 209 } 210 } 211 values[idx] = &sql.RawBytes{} 212 } else { 213 values[idx] = &sql.RawBytes{} 214 } 215 } 216 217 if len(columns) == 1 { 218 // isPluck 219 if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner 220 reflectValueType.Kind() != reflect.Struct || // is not struct 221 sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time 222 sch = nil 223 } 224 } 225 } 226 227 switch reflectValue.Kind() { 228 case reflect.Slice, reflect.Array: 229 var elem reflect.Value 230 231 if !update || reflectValue.Len() == 0 { 232 update = false 233 db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) 234 } 235 236 for initialized || rows.Next() { 237 BEGIN: 238 initialized = false 239 240 if update { 241 if int(db.RowsAffected) >= reflectValue.Len() { 242 return 243 } 244 elem = reflectValue.Index(int(db.RowsAffected)) 245 if onConflictDonothing { 246 for _, field := range fields { 247 if _, ok := field.ValueOf(elem); !ok { 248 db.RowsAffected++ 249 goto BEGIN 250 } 251 } 252 } 253 } else { 254 elem = reflect.New(reflectValueType) 255 } 256 257 db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields) 258 259 if !update { 260 if isPtr { 261 reflectValue = reflect.Append(reflectValue, elem) 262 } else { 263 reflectValue = reflect.Append(reflectValue, elem.Elem()) 264 } 265 } 266 } 267 268 if !update { 269 db.Statement.ReflectValue.Set(reflectValue) 270 } 271 case reflect.Struct, reflect.Ptr: 272 if initialized || rows.Next() { 273 db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields) 274 } 275 default: 276 db.AddError(rows.Scan(dest)) 277 } 278 } 279 280 if err := rows.Err(); err != nil && err != db.Error { 281 db.AddError(err) 282 } 283 284 if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { 285 db.AddError(ErrRecordNotFound) 286 } 287 }