github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/scan.go (about) 1 package taorm 2 3 import ( 4 "database/sql" 5 "reflect" 6 "unsafe" 7 ) 8 9 // ScanRows scans result rows into out. 10 // 11 // out can be either *primitive, *Struct, *[]Struct, or *[]*Struct. 12 func ScanRows(out interface{}, tx _SQLCommon, query string, args ...interface{}) (_err error) { 13 defer func() { _err = WrapError(_err) }() 14 15 rows, err := tx.Query(query, args...) 16 if err != nil { 17 return err 18 } 19 20 defer rows.Close() 21 22 columns, err := rows.Columns() 23 if err != nil { 24 return err 25 } 26 27 ty := reflect.TypeOf(out) 28 if ty.Kind() != reflect.Ptr { 29 return ErrInvalidOut 30 } 31 32 ty = ty.Elem() 33 switch ty.Kind() { 34 case reflect.Struct: 35 info, err := getRegistered(out) 36 if err != nil { 37 return err 38 } 39 if rows.Next() { 40 pointers, err := info.ptrsOf(out, columns) 41 if err != nil { 42 return err 43 } 44 return rows.Scan(pointers...) 45 } 46 err = rows.Err() 47 if err == nil { 48 err = sql.ErrNoRows 49 } 50 return err 51 case reflect.Slice: 52 slice := reflect.MakeSlice(ty, 0, 0) 53 ty = ty.Elem() 54 isPtr := ty.Kind() == reflect.Ptr 55 if isPtr { 56 ty = ty.Elem() 57 } 58 if ty.Kind() != reflect.Struct { 59 return ErrInvalidOut 60 } 61 info, err := getRegistered(reflect.NewAt(ty, unsafe.Pointer(nil)).Interface()) 62 if err != nil { 63 return err 64 } 65 if isPtr { 66 for rows.Next() { 67 elem := reflect.New(ty) 68 elemPtr := elem.Interface() 69 pointers, err := info.ptrsOf(elemPtr, columns) 70 if err != nil { 71 return err 72 } 73 if err := rows.Scan(pointers...); err != nil { 74 return err 75 } 76 slice = reflect.Append(slice, elem) 77 } 78 } else { 79 elem := reflect.New(ty) 80 elemPtr := elem.Interface() 81 pointers, err := info.ptrsOf(elemPtr, columns) 82 if err != nil { 83 return err 84 } 85 for rows.Next() { 86 if err := rows.Scan(pointers...); err != nil { 87 return err 88 } 89 slice = reflect.Append(slice, elem.Elem()) 90 } 91 } 92 reflect.ValueOf(out).Elem().Set(slice) 93 return rows.Err() 94 default: 95 if len(columns) != 1 { 96 return ErrInvalidOut 97 } 98 if rows.Next() { 99 return rows.Scan(out) 100 } 101 err = rows.Err() 102 if err == nil { 103 err = sql.ErrNoRows 104 } 105 return err 106 } 107 } 108 109 // MustScanRows ... 110 func MustScanRows(out interface{}, tx _SQLCommon, query string, args ...interface{}) { 111 if err := ScanRows(out, tx, query, args...); err != nil { 112 panic(err) 113 } 114 }