github.com/systematiccaos/gorm@v1.22.6/scan.go (about)

     1  package gorm
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"reflect"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/systematiccaos/gorm/schema"
    11  )
    12  
    13  func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
    14  	if db.Statement.Schema != nil {
    15  		for idx, name := range columns {
    16  			if field := db.Statement.Schema.LookUpField(name); field != nil {
    17  				values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
    18  				continue
    19  			}
    20  			values[idx] = new(interface{})
    21  		}
    22  	} else if len(columnTypes) > 0 {
    23  		for idx, columnType := range columnTypes {
    24  			if columnType.ScanType() != nil {
    25  				values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
    26  			} else {
    27  				values[idx] = new(interface{})
    28  			}
    29  		}
    30  	} else {
    31  		for idx := range columns {
    32  			values[idx] = new(interface{})
    33  		}
    34  	}
    35  }
    36  
    // scanIntoMap copies one scanned row, held in values, into mapValue keyed by
    // the corresponding entry of columns.
    //
    // Each destination is dereferenced up to two pointer levels; a destination
    // that resolves to nothing (e.g. a nil inner pointer) is stored as nil.
    // Values implementing driver.Valuer are replaced by the result of Value()
    // (its error is discarded). sql.RawBytes values are copied into a string,
    // since per the database/sql docs RawBytes memory is only valid until the
    // next Next/Scan/Close.
    func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
    	for i := range columns {
    		column := columns[i]

    		rv := reflect.ValueOf(values[i])
    		rv = reflect.Indirect(rv)
    		rv = reflect.Indirect(rv)
    		if !rv.IsValid() {
    			mapValue[column] = nil
    			continue
    		}

    		converted := rv.Interface()
    		switch v := converted.(type) {
    		case driver.Valuer:
    			converted, _ = v.Value()
    		case sql.RawBytes:
    			converted = string(v)
    		}
    		mapValue[column] = converted
    	}
    }
    51  
    52  func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
    53  	for idx, column := range columns {
    54  		if sch == nil {
    55  			values[idx] = reflectValue.Interface()
    56  		} else if field := sch.LookUpField(column); field != nil && field.Readable {
    57  			values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
    58  		} else if names := strings.Split(column, "__"); len(names) > 1 {
    59  			if rel, ok := sch.Relationships.Relations[names[0]]; ok {
    60  				if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
    61  					values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
    62  					continue
    63  				}
    64  			}
    65  			values[idx] = &sql.RawBytes{}
    66  		} else if len(columns) == 1 {
    67  			sch = nil
    68  			values[idx] = reflectValue.Interface()
    69  		} else {
    70  			values[idx] = &sql.RawBytes{}
    71  		}
    72  	}
    73  
    74  	db.RowsAffected++
    75  	db.AddError(rows.Scan(values...))
    76  
    77  	if sch != nil {
    78  		for idx, column := range columns {
    79  			if field := sch.LookUpField(column); field != nil && field.Readable {
    80  				field.Set(reflectValue, values[idx])
    81  			} else if names := strings.Split(column, "__"); len(names) > 1 {
    82  				if rel, ok := sch.Relationships.Relations[names[0]]; ok {
    83  					if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
    84  						relValue := rel.Field.ReflectValueOf(reflectValue)
    85  						value := reflect.ValueOf(values[idx]).Elem()
    86  
    87  						if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
    88  							if value.IsNil() {
    89  								continue
    90  							}
    91  							relValue.Set(reflect.New(relValue.Type().Elem()))
    92  						}
    93  
    94  						field.Set(relValue, values[idx])
    95  					}
    96  				}
    97  			}
    98  		}
    99  	}
   100  }
   101  
   // ScanMode is a bit set of options controlling how Scan consumes rows.
   type ScanMode uint8

   // Scan mode flags; combine them with bitwise OR.
   const (
   	// ScanInitialized makes Scan read the first row without calling rows.Next
   	// first (presumably because the caller has already advanced rows).
   	ScanInitialized ScanMode = 1 << iota // 1
   	// ScanUpdate scans rows into the existing elements of a non-empty
   	// destination slice instead of building a new slice.
   	ScanUpdate // 2
   	// ScanOnConflictDoNothing, together with ScanUpdate, skips existing
   	// elements for which field.ValueOf reports a false second result for some
   	// resolved column (presumably elements that already hold a value).
   	ScanOnConflictDoNothing // 4
   )
   109  
   110  func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
   111  	var (
   112  		columns, _          = rows.Columns()
   113  		values              = make([]interface{}, len(columns))
   114  		initialized         = mode&ScanInitialized != 0
   115  		update              = mode&ScanUpdate != 0
   116  		onConflictDonothing = mode&ScanOnConflictDoNothing != 0
   117  	)
   118  
   119  	db.RowsAffected = 0
   120  
   121  	switch dest := db.Statement.Dest.(type) {
   122  	case map[string]interface{}, *map[string]interface{}:
   123  		if initialized || rows.Next() {
   124  			columnTypes, _ := rows.ColumnTypes()
   125  			prepareValues(values, db, columnTypes, columns)
   126  
   127  			db.RowsAffected++
   128  			db.AddError(rows.Scan(values...))
   129  
   130  			mapValue, ok := dest.(map[string]interface{})
   131  			if !ok {
   132  				if v, ok := dest.(*map[string]interface{}); ok {
   133  					if *v == nil {
   134  						*v = map[string]interface{}{}
   135  					}
   136  					mapValue = *v
   137  				}
   138  			}
   139  			scanIntoMap(mapValue, values, columns)
   140  		}
   141  	case *[]map[string]interface{}, []map[string]interface{}:
   142  		columnTypes, _ := rows.ColumnTypes()
   143  		for initialized || rows.Next() {
   144  			prepareValues(values, db, columnTypes, columns)
   145  
   146  			initialized = false
   147  			db.RowsAffected++
   148  			db.AddError(rows.Scan(values...))
   149  
   150  			mapValue := map[string]interface{}{}
   151  			scanIntoMap(mapValue, values, columns)
   152  			if values, ok := dest.([]map[string]interface{}); ok {
   153  				values = append(values, mapValue)
   154  			} else if values, ok := dest.(*[]map[string]interface{}); ok {
   155  				*values = append(*values, mapValue)
   156  			}
   157  		}
   158  	case *int, *int8, *int16, *int32, *int64,
   159  		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
   160  		*float32, *float64,
   161  		*bool, *string, *time.Time,
   162  		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
   163  		*sql.NullBool, *sql.NullString, *sql.NullTime:
   164  		for initialized || rows.Next() {
   165  			initialized = false
   166  			db.RowsAffected++
   167  			db.AddError(rows.Scan(dest))
   168  		}
   169  	default:
   170  		var (
   171  			fields       = make([]*schema.Field, len(columns))
   172  			joinFields   [][2]*schema.Field
   173  			sch          = db.Statement.Schema
   174  			reflectValue = db.Statement.ReflectValue
   175  		)
   176  
   177  		if reflectValue.Kind() == reflect.Interface {
   178  			reflectValue = reflectValue.Elem()
   179  		}
   180  
   181  		reflectValueType := reflectValue.Type()
   182  		switch reflectValueType.Kind() {
   183  		case reflect.Array, reflect.Slice:
   184  			reflectValueType = reflectValueType.Elem()
   185  		}
   186  		isPtr := reflectValueType.Kind() == reflect.Ptr
   187  		if isPtr {
   188  			reflectValueType = reflectValueType.Elem()
   189  		}
   190  
   191  		if sch != nil {
   192  			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
   193  				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
   194  			}
   195  
   196  			for idx, column := range columns {
   197  				if field := sch.LookUpField(column); field != nil && field.Readable {
   198  					fields[idx] = field
   199  				} else if names := strings.Split(column, "__"); len(names) > 1 {
   200  					if rel, ok := sch.Relationships.Relations[names[0]]; ok {
   201  						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
   202  							fields[idx] = field
   203  
   204  							if len(joinFields) == 0 {
   205  								joinFields = make([][2]*schema.Field, len(columns))
   206  							}
   207  							joinFields[idx] = [2]*schema.Field{rel.Field, field}
   208  							continue
   209  						}
   210  					}
   211  					values[idx] = &sql.RawBytes{}
   212  				} else {
   213  					values[idx] = &sql.RawBytes{}
   214  				}
   215  			}
   216  
   217  			if len(columns) == 1 {
   218  				// isPluck
   219  				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
   220  					reflectValueType.Kind() != reflect.Struct || // is not struct
   221  					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
   222  					sch = nil
   223  				}
   224  			}
   225  		}
   226  
   227  		switch reflectValue.Kind() {
   228  		case reflect.Slice, reflect.Array:
   229  			var elem reflect.Value
   230  
   231  			if !update || reflectValue.Len() == 0 {
   232  				update = false
   233  				db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
   234  			}
   235  
   236  			for initialized || rows.Next() {
   237  			BEGIN:
   238  				initialized = false
   239  
   240  				if update {
   241  					if int(db.RowsAffected) >= reflectValue.Len() {
   242  						return
   243  					}
   244  					elem = reflectValue.Index(int(db.RowsAffected))
   245  					if onConflictDonothing {
   246  						for _, field := range fields {
   247  							if _, ok := field.ValueOf(elem); !ok {
   248  								db.RowsAffected++
   249  								goto BEGIN
   250  							}
   251  						}
   252  					}
   253  				} else {
   254  					elem = reflect.New(reflectValueType)
   255  				}
   256  
   257  				db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
   258  
   259  				if !update {
   260  					if isPtr {
   261  						reflectValue = reflect.Append(reflectValue, elem)
   262  					} else {
   263  						reflectValue = reflect.Append(reflectValue, elem.Elem())
   264  					}
   265  				}
   266  			}
   267  
   268  			if !update {
   269  				db.Statement.ReflectValue.Set(reflectValue)
   270  			}
   271  		case reflect.Struct, reflect.Ptr:
   272  			if initialized || rows.Next() {
   273  				db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
   274  			}
   275  		default:
   276  			db.AddError(rows.Scan(dest))
   277  		}
   278  	}
   279  
   280  	if err := rows.Err(); err != nil && err != db.Error {
   281  		db.AddError(err)
   282  	}
   283  
   284  	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
   285  		db.AddError(ErrRecordNotFound)
   286  	}
   287  }