github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/scan.go (about)

     1  package taorm
     2  
     3  import (
     4  	"database/sql"
     5  	"reflect"
     6  	"unsafe"
     7  )
     8  
     9  // ScanRows scans result rows into out.
    10  //
    11  // out can be either *primitive, *Struct, *[]Struct, or *[]*Struct.
    12  func ScanRows(out interface{}, tx _SQLCommon, query string, args ...interface{}) (_err error) {
    13  	defer func() { _err = WrapError(_err) }()
    14  
    15  	rows, err := tx.Query(query, args...)
    16  	if err != nil {
    17  		return err
    18  	}
    19  
    20  	defer rows.Close()
    21  
    22  	columns, err := rows.Columns()
    23  	if err != nil {
    24  		return err
    25  	}
    26  
    27  	ty := reflect.TypeOf(out)
    28  	if ty.Kind() != reflect.Ptr {
    29  		return ErrInvalidOut
    30  	}
    31  
    32  	ty = ty.Elem()
    33  	switch ty.Kind() {
    34  	case reflect.Struct:
    35  		info, err := getRegistered(out)
    36  		if err != nil {
    37  			return err
    38  		}
    39  		if rows.Next() {
    40  			pointers, err := info.ptrsOf(out, columns)
    41  			if err != nil {
    42  				return err
    43  			}
    44  			return rows.Scan(pointers...)
    45  		}
    46  		err = rows.Err()
    47  		if err == nil {
    48  			err = sql.ErrNoRows
    49  		}
    50  		return err
    51  	case reflect.Slice:
    52  		slice := reflect.MakeSlice(ty, 0, 0)
    53  		ty = ty.Elem()
    54  		isPtr := ty.Kind() == reflect.Ptr
    55  		if isPtr {
    56  			ty = ty.Elem()
    57  		}
    58  		if ty.Kind() != reflect.Struct {
    59  			return ErrInvalidOut
    60  		}
    61  		info, err := getRegistered(reflect.NewAt(ty, unsafe.Pointer(nil)).Interface())
    62  		if err != nil {
    63  			return err
    64  		}
    65  		if isPtr {
    66  			for rows.Next() {
    67  				elem := reflect.New(ty)
    68  				elemPtr := elem.Interface()
    69  				pointers, err := info.ptrsOf(elemPtr, columns)
    70  				if err != nil {
    71  					return err
    72  				}
    73  				if err := rows.Scan(pointers...); err != nil {
    74  					return err
    75  				}
    76  				slice = reflect.Append(slice, elem)
    77  			}
    78  		} else {
    79  			elem := reflect.New(ty)
    80  			elemPtr := elem.Interface()
    81  			pointers, err := info.ptrsOf(elemPtr, columns)
    82  			if err != nil {
    83  				return err
    84  			}
    85  			for rows.Next() {
    86  				if err := rows.Scan(pointers...); err != nil {
    87  					return err
    88  				}
    89  				slice = reflect.Append(slice, elem.Elem())
    90  			}
    91  		}
    92  		reflect.ValueOf(out).Elem().Set(slice)
    93  		return rows.Err()
    94  	default:
    95  		if len(columns) != 1 {
    96  			return ErrInvalidOut
    97  		}
    98  		if rows.Next() {
    99  			return rows.Scan(out)
   100  		}
   101  		err = rows.Err()
   102  		if err == nil {
   103  			err = sql.ErrNoRows
   104  		}
   105  		return err
   106  	}
   107  }
   108  
   109  // MustScanRows ...
   110  func MustScanRows(out interface{}, tx _SQLCommon, query string, args ...interface{}) {
   111  	if err := ScanRows(out, tx, query, args...); err != nil {
   112  		panic(err)
   113  	}
   114  }