github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/execute.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "reflect" 7 8 "github.com/go-sql-driver/mysql" 9 "github.com/sirupsen/logrus" 10 11 "github.com/johnnyeven/libtools/sqlx/builder" 12 ) 13 14 func Do(db *DB, stmt builder.Statement) (result *Result) { 15 result = &Result{} 16 17 e := stmt.Expr() 18 if e == nil { 19 result.err = NewSqlError(sqlErrTypeInvalidSql, "") 20 return 21 } 22 if e.Err != nil { 23 result.err = NewSqlError(sqlErrTypeInvalidSql, e.Err.Error()) 24 logrus.Errorf("%s", result.err) 25 return 26 } 27 28 result.stmtType = stmt.Type() 29 30 switch result.stmtType { 31 case builder.STMT_SELECT: 32 rows, queryErr := db.Query(e.Query, e.Args...) 33 if queryErr != nil { 34 result.err = queryErr 35 return 36 } 37 result.Rows = rows 38 case builder.STMT_INSERT, builder.STMT_UPDATE: 39 sqlResult, execErr := db.Exec(e.Query, e.Args...) 40 if execErr != nil { 41 if mysqlErr, ok := execErr.(*mysql.MySQLError); ok && mysqlErr.Number == DuplicateEntryErrNumber { 42 result.err = NewSqlError(sqlErrTypeConflict, mysqlErr.Error()) 43 } else { 44 result.err = execErr 45 } 46 return 47 } 48 result.Result = sqlResult 49 case builder.STMT_DELETE, builder.STMT_RAW: 50 sqlResult, execErr := db.Exec(e.Query, e.Args...) 51 if execErr != nil { 52 result.err = execErr 53 return 54 } 55 result.Result = sqlResult 56 } 57 return 58 } 59 60 type Result struct { 61 stmtType builder.StmtType 62 err error 63 *sql.Rows 64 sql.Result 65 } 66 67 func (r *Result) Err() error { 68 return r.err 69 } 70 71 func (r *Result) Scan(v interface{}) *Result { 72 if r.err != nil { 73 return r 74 } 75 76 if r.Rows != nil { 77 defer r.Rows.Close() 78 79 if scanner, ok := v.(sql.Scanner); ok { 80 for r.Rows.Next() { 81 if scanErr := r.Rows.Scan(scanner); scanErr != nil { 82 r.err = scanErr 83 return r 84 } 85 } 86 } else { 87 88 modelType := reflect.TypeOf(v) 89 if modelType.Kind() != reflect.Ptr { 90 r.err = NewSqlError(sqlErrTypeInvalidScanTarget, "can not scan to a none pointer variable") 91 return r 92 } 93 94 modelType = modelType.Elem() 95 96 isSlice := false 97 if modelType.Kind() == reflect.Slice { 98 modelType = modelType.Elem() 99 isSlice = true 100 } 101 102 if modelType.Kind() == reflect.Struct || isSlice { 103 columns, getErr := r.Rows.Columns() 104 if getErr != nil { 105 r.err = getErr 106 return r 107 } 108 109 rv := reflect.Indirect(reflect.ValueOf(v)) 110 111 rowLength := 0 112 113 for r.Rows.Next() { 114 if !isSlice && rowLength > 1 { 115 r.err = NewSqlError(sqlErrTypeSelectShouldOne, "more than one records found, but only one") 116 return r 117 } 118 119 rowLength++ 120 length := len(columns) 121 dest := make([]interface{}, length) 122 itemRv := rv 123 124 if isSlice { 125 itemRv = reflect.New(modelType).Elem() 126 } 127 128 destIndexes := make(map[int]bool, length) 129 130 ForEachStructFieldValue(itemRv, func(structFieldValue reflect.Value, structField reflect.StructField, columnName string) { 131 idx := stringIndexOf(columns, columnName) 132 if idx >= 0 { 133 dest[idx] = structFieldValue.Addr().Interface() 134 destIndexes[idx] = true 135 } 136 }) 137 138 for index := range dest { 139 if !destIndexes[index] { 140 placeholder := emptyScanner(0) 141 dest[index] = &placeholder 142 } else { 143 // todo null ignore 144 dest[index] = newNullableScanner(dest[index]) 145 } 146 } 147 148 if scanErr := r.Rows.Scan(dest...); scanErr != nil { 149 r.err = scanErr 150 return r 151 } 152 153 if isSlice { 154 rv.Set(reflect.Append(rv, itemRv)) 155 } 156 } 157 158 if !isSlice && rowLength == 0 { 159 r.err = NewSqlError(sqlErrTypeNotFound, "record is not found") 160 return r 161 } 162 } else { 163 for r.Rows.Next() { 164 if scanErr := r.Rows.Scan(v); scanErr != nil { 165 r.err = scanErr 166 return r 167 } 168 } 169 } 170 } 171 if err := r.Rows.Err(); err != nil { 172 r.err = err 173 return r 174 } 175 176 // Make sure the query can be processed to completion with no errors. 177 if err := r.Rows.Close(); err != nil { 178 r.err = err 179 return r 180 } 181 } 182 183 return r 184 } 185 186 type emptyScanner int 187 188 var _ interface { 189 sql.Scanner 190 driver.Valuer 191 } = (*emptyScanner)(nil) 192 193 func (e *emptyScanner) Scan(value interface{}) error { 194 return nil 195 } 196 197 func (e emptyScanner) Value() (driver.Value, error) { 198 return 0, nil 199 } 200 201 func newNullableScanner(dest interface{}) *nullableScanner { 202 return &nullableScanner{ 203 dest: dest, 204 } 205 } 206 207 type nullableScanner struct { 208 dest interface{} 209 } 210 211 var _ interface { 212 sql.Scanner 213 } = (*nullableScanner)(nil) 214 215 func (scanner *nullableScanner) Scan(src interface{}) error { 216 if scanner, ok := scanner.dest.(sql.Scanner); ok { 217 return scanner.Scan(src) 218 } 219 if src == nil { 220 if zeroSetter, ok := scanner.dest.(ZeroSetter); ok { 221 zeroSetter.SetToZero() 222 return nil 223 } 224 return nil 225 } 226 return convertAssign(scanner.dest, src) 227 }