github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/execute.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"reflect"
     7  
     8  	"github.com/go-sql-driver/mysql"
     9  	"github.com/sirupsen/logrus"
    10  
    11  	"github.com/johnnyeven/libtools/sqlx/builder"
    12  )
    13  
    14  func Do(db *DB, stmt builder.Statement) (result *Result) {
    15  	result = &Result{}
    16  
    17  	e := stmt.Expr()
    18  	if e == nil {
    19  		result.err = NewSqlError(sqlErrTypeInvalidSql, "")
    20  		return
    21  	}
    22  	if e.Err != nil {
    23  		result.err = NewSqlError(sqlErrTypeInvalidSql, e.Err.Error())
    24  		logrus.Errorf("%s", result.err)
    25  		return
    26  	}
    27  
    28  	result.stmtType = stmt.Type()
    29  
    30  	switch result.stmtType {
    31  	case builder.STMT_SELECT:
    32  		rows, queryErr := db.Query(e.Query, e.Args...)
    33  		if queryErr != nil {
    34  			result.err = queryErr
    35  			return
    36  		}
    37  		result.Rows = rows
    38  	case builder.STMT_INSERT, builder.STMT_UPDATE:
    39  		sqlResult, execErr := db.Exec(e.Query, e.Args...)
    40  		if execErr != nil {
    41  			if mysqlErr, ok := execErr.(*mysql.MySQLError); ok && mysqlErr.Number == DuplicateEntryErrNumber {
    42  				result.err = NewSqlError(sqlErrTypeConflict, mysqlErr.Error())
    43  			} else {
    44  				result.err = execErr
    45  			}
    46  			return
    47  		}
    48  		result.Result = sqlResult
    49  	case builder.STMT_DELETE, builder.STMT_RAW:
    50  		sqlResult, execErr := db.Exec(e.Query, e.Args...)
    51  		if execErr != nil {
    52  			result.err = execErr
    53  			return
    54  		}
    55  		result.Result = sqlResult
    56  	}
    57  	return
    58  }
    59  
// Result is the value returned by Do for one executed statement.
//
// For SELECT statements Do sets the embedded *sql.Rows; for INSERT, UPDATE,
// DELETE and raw statements it sets the embedded sql.Result. Any failure is
// kept in err and exposed through Err. Note that Result's own Err and Scan
// methods shadow the Err and Scan methods promoted from *sql.Rows.
type Result struct {
	// stmtType is the builder statement type this result was produced for.
	stmtType builder.StmtType
	// err holds the execution or scanning error, if any.
	err      error
	*sql.Rows
	sql.Result
}

// Err returns the error recorded on r during execution or scanning, or nil
// if none was recorded.
func (r *Result) Err() error {
	return r.err
}
    70  
    71  func (r *Result) Scan(v interface{}) *Result {
    72  	if r.err != nil {
    73  		return r
    74  	}
    75  
    76  	if r.Rows != nil {
    77  		defer r.Rows.Close()
    78  
    79  		if scanner, ok := v.(sql.Scanner); ok {
    80  			for r.Rows.Next() {
    81  				if scanErr := r.Rows.Scan(scanner); scanErr != nil {
    82  					r.err = scanErr
    83  					return r
    84  				}
    85  			}
    86  		} else {
    87  
    88  			modelType := reflect.TypeOf(v)
    89  			if modelType.Kind() != reflect.Ptr {
    90  				r.err = NewSqlError(sqlErrTypeInvalidScanTarget, "can not scan to a none pointer variable")
    91  				return r
    92  			}
    93  
    94  			modelType = modelType.Elem()
    95  
    96  			isSlice := false
    97  			if modelType.Kind() == reflect.Slice {
    98  				modelType = modelType.Elem()
    99  				isSlice = true
   100  			}
   101  
   102  			if modelType.Kind() == reflect.Struct || isSlice {
   103  				columns, getErr := r.Rows.Columns()
   104  				if getErr != nil {
   105  					r.err = getErr
   106  					return r
   107  				}
   108  
   109  				rv := reflect.Indirect(reflect.ValueOf(v))
   110  
   111  				rowLength := 0
   112  
   113  				for r.Rows.Next() {
   114  					if !isSlice && rowLength > 1 {
   115  						r.err = NewSqlError(sqlErrTypeSelectShouldOne, "more than one records found, but only one")
   116  						return r
   117  					}
   118  
   119  					rowLength++
   120  					length := len(columns)
   121  					dest := make([]interface{}, length)
   122  					itemRv := rv
   123  
   124  					if isSlice {
   125  						itemRv = reflect.New(modelType).Elem()
   126  					}
   127  
   128  					destIndexes := make(map[int]bool, length)
   129  
   130  					ForEachStructFieldValue(itemRv, func(structFieldValue reflect.Value, structField reflect.StructField, columnName string) {
   131  						idx := stringIndexOf(columns, columnName)
   132  						if idx >= 0 {
   133  							dest[idx] = structFieldValue.Addr().Interface()
   134  							destIndexes[idx] = true
   135  						}
   136  					})
   137  
   138  					for index := range dest {
   139  						if !destIndexes[index] {
   140  							placeholder := emptyScanner(0)
   141  							dest[index] = &placeholder
   142  						} else {
   143  							// todo null ignore
   144  							dest[index] = newNullableScanner(dest[index])
   145  						}
   146  					}
   147  
   148  					if scanErr := r.Rows.Scan(dest...); scanErr != nil {
   149  						r.err = scanErr
   150  						return r
   151  					}
   152  
   153  					if isSlice {
   154  						rv.Set(reflect.Append(rv, itemRv))
   155  					}
   156  				}
   157  
   158  				if !isSlice && rowLength == 0 {
   159  					r.err = NewSqlError(sqlErrTypeNotFound, "record is not found")
   160  					return r
   161  				}
   162  			} else {
   163  				for r.Rows.Next() {
   164  					if scanErr := r.Rows.Scan(v); scanErr != nil {
   165  						r.err = scanErr
   166  						return r
   167  					}
   168  				}
   169  			}
   170  		}
   171  		if err := r.Rows.Err(); err != nil {
   172  			r.err = err
   173  			return r
   174  		}
   175  
   176  		// Make sure the query can be processed to completion with no errors.
   177  		if err := r.Rows.Close(); err != nil {
   178  			r.err = err
   179  			return r
   180  		}
   181  	}
   182  
   183  	return r
   184  }
   185  
   186  type emptyScanner int
   187  
   188  var _ interface {
   189  	sql.Scanner
   190  	driver.Valuer
   191  } = (*emptyScanner)(nil)
   192  
   193  func (e *emptyScanner) Scan(value interface{}) error {
   194  	return nil
   195  }
   196  
   197  func (e emptyScanner) Value() (driver.Value, error) {
   198  	return 0, nil
   199  }
   200  
   201  func newNullableScanner(dest interface{}) *nullableScanner {
   202  	return &nullableScanner{
   203  		dest: dest,
   204  	}
   205  }
   206  
   207  type nullableScanner struct {
   208  	dest interface{}
   209  }
   210  
   211  var _ interface {
   212  	sql.Scanner
   213  } = (*nullableScanner)(nil)
   214  
   215  func (scanner *nullableScanner) Scan(src interface{}) error {
   216  	if scanner, ok := scanner.dest.(sql.Scanner); ok {
   217  		return scanner.Scan(src)
   218  	}
   219  	if src == nil {
   220  		if zeroSetter, ok := scanner.dest.(ZeroSetter); ok {
   221  			zeroSetter.SetToZero()
   222  			return nil
   223  		}
   224  		return nil
   225  	}
   226  	return convertAssign(scanner.dest, src)
   227  }