gitee.com/eden-framework/sqlx@v0.0.3/super_scan.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"reflect"
     6  	"strings"
     7  
     8  	"gitee.com/eden-framework/sqlx/builder"
     9  	"gitee.com/eden-framework/sqlx/nullable"
    10  )
    11  
    12  type ScanIterator interface {
    13  	// new a ptr value for scan
    14  	New() interface{}
    15  	// for receive scanned value
    16  	Next(v interface{}) error
    17  }
    18  
    19  func Scan(rows *sql.Rows, v interface{}) error {
    20  	if rows == nil {
    21  		return nil
    22  	}
    23  
    24  	defer rows.Close()
    25  
    26  	modelScanner, err := newModelScanner(v)
    27  	if err != nil {
    28  		return err
    29  	}
    30  
    31  	// simple value scan
    32  	for rows.Next() {
    33  		if modelScanner.direct {
    34  			if scanErr := rows.Scan(modelScanner.v); scanErr != nil {
    35  				return scanErr
    36  			}
    37  			continue
    38  		}
    39  
    40  		rv, err := modelScanner.New()
    41  		if err != nil {
    42  			return err
    43  		}
    44  
    45  		if scanErr := scanStruct(rows, rv); scanErr != nil {
    46  			return scanErr
    47  		}
    48  
    49  		if err := modelScanner.Next(rv); err != nil {
    50  			return err
    51  		}
    52  	}
    53  
    54  	if !modelScanner.direct && !modelScanner.isSlice && modelScanner.scanIterator == nil && modelScanner.count == 0 {
    55  		return NewSqlError(sqlErrTypeNotFound, "record is not found")
    56  	}
    57  
    58  	if err := rows.Err(); err != nil {
    59  		return err
    60  	}
    61  
    62  	// Make sure the query can be processed to completion with no errors.
    63  	if err := rows.Close(); err != nil {
    64  		return err
    65  	}
    66  
    67  	return nil
    68  }
    69  
    70  func newModelScanner(v interface{}) (*modelScanner, error) {
    71  	si := &modelScanner{v: v}
    72  
    73  	if _, ok := v.(sql.Scanner); ok {
    74  		si.direct = true
    75  	} else if scanIterator, ok := v.(ScanIterator); ok {
    76  		si.scanIterator = scanIterator
    77  	} else {
    78  		modelType := reflect.TypeOf(v)
    79  
    80  		if modelType.Kind() != reflect.Ptr {
    81  			return nil, NewSqlError(sqlErrTypeInvalidScanTarget, "can not scan to a none pointer variable")
    82  		}
    83  
    84  		si.modelType = modelType.Elem()
    85  
    86  		if si.modelType.Kind() == reflect.Slice {
    87  			si.modelType = si.modelType.Elem()
    88  			si.isSlice = true
    89  		}
    90  
    91  		si.rv = reflect.Indirect(reflect.ValueOf(v))
    92  		si.direct = si.modelType.Kind() != reflect.Struct
    93  	}
    94  
    95  	return si, nil
    96  }
    97  
    98  type modelScanner struct {
    99  	v            interface{}
   100  	rv           reflect.Value
   101  	direct       bool
   102  	isSlice      bool
   103  	count        int
   104  	modelType    reflect.Type
   105  	scanIterator ScanIterator
   106  }
   107  
   108  func (s *modelScanner) New() (reflect.Value, error) {
   109  	if s.scanIterator != nil {
   110  		rv := reflect.ValueOf(s.scanIterator.New())
   111  		if rv.Kind() != reflect.Ptr {
   112  			return reflect.Value{}, NewSqlError(sqlErrTypeInvalidScanTarget, "can not scan to a none pointer variable")
   113  		}
   114  		return rv.Elem(), nil
   115  	}
   116  	if s.isSlice {
   117  		return reflect.New(s.modelType).Elem(), nil
   118  	}
   119  	return s.rv, nil
   120  }
   121  
   122  func (s *modelScanner) Next(rv reflect.Value) error {
   123  	s.count++
   124  
   125  	if s.scanIterator != nil {
   126  		return s.scanIterator.Next(rv.Addr().Interface())
   127  	}
   128  
   129  	if s.isSlice {
   130  		s.rv.Set(reflect.Append(s.rv, rv))
   131  		return nil
   132  	}
   133  
   134  	s.rv.Set(rv)
   135  	return nil
   136  }
   137  
   138  func scanStruct(rows *sql.Rows, rv reflect.Value) error {
   139  	columns, err := rows.Columns()
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	n := len(columns)
   145  	dest := make([]interface{}, n)
   146  
   147  	columnIndexes := map[string]int{}
   148  	p := placeholder()
   149  
   150  	for i, name := range columns {
   151  		columnIndexes[strings.ToLower(name)] = i
   152  		dest[i] = p
   153  	}
   154  
   155  	builder.ForEachStructFieldValue(rv, func(structFieldValue reflect.Value, structField reflect.StructField, columnName string, tagValue string) {
   156  		if i, ok := columnIndexes[columnName]; ok && i > -1 {
   157  			dest[i] = nullable.NewNullIgnoreScanner(structFieldValue.Addr().Interface())
   158  		}
   159  	})
   160  
   161  	return rows.Scan(dest...)
   162  }
   163  
   164  func placeholder() *emptyScanner {
   165  	p := emptyScanner(0)
   166  	return &p
   167  }
   168  
   169  type emptyScanner int
   170  
   171  func (e *emptyScanner) Scan(value interface{}) error {
   172  	return nil
   173  }