github.com/team-ide/go-dialect@v1.9.20/worker/exec.go (about)

     1  package worker
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"github.com/team-ide/go-dialect/dialect"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  )
    15  
    16  func DoExec(db *sql.DB, sqlInfo string, args []interface{}) (result sql.Result, err error) {
    17  	if len(sqlInfo) == 0 {
    18  		return
    19  	}
    20  	resultList, _, _, err := DoExecs(db, []string{sqlInfo}, [][]interface{}{args})
    21  	if err != nil {
    22  		return
    23  	}
    24  	if len(resultList) > 0 {
    25  		result = resultList[0]
    26  	}
    27  	return
    28  }
    29  
    30  type prepareFunc func(ctx context.Context, query string) (*sql.Stmt, error)
    31  
    32  func ExecByPrepare(prepare prepareFunc, ctx context.Context, sqlInfo string, sqlArgs ...interface{}) (result sql.Result, err error) {
    33  	stmt, err := prepare(ctx, sqlInfo)
    34  	if err != nil {
    35  		return
    36  	}
    37  	defer func() { _ = stmt.Close() }()
    38  	result, err = stmt.Exec(sqlArgs...)
    39  	return
    40  }
    41  
    42  func DoOwnerExecs(dia dialect.Dialect, db *sql.DB, ownerName string, sqlList []string, argsList [][]interface{}) (resultList []sql.Result, errSql string, errArgs []interface{}, err error) {
    43  	sqlListSize := len(sqlList)
    44  	if sqlListSize == 0 {
    45  		return
    46  	}
    47  	if len(argsList) == 0 {
    48  		argsList = make([][]interface{}, sqlListSize)
    49  	}
    50  	argsListSize := len(argsList)
    51  	if sqlListSize != argsListSize {
    52  		err = errors.New(fmt.Sprintf("sqlList size is [%d] but argsList size is [%d]", sqlListSize, argsListSize))
    53  		return
    54  	}
    55  	ctx := context.Background()
    56  
    57  	tx, err := db.BeginTx(ctx, nil)
    58  	if err != nil {
    59  		return
    60  	}
    61  	defer func() {
    62  		if err != nil {
    63  			_ = tx.Rollback()
    64  		} else {
    65  			err = tx.Commit()
    66  			if err != nil && strings.Contains(err.Error(), "Not in transaction") {
    67  				err = nil
    68  			}
    69  		}
    70  	}()
    71  
    72  	if ownerName != "" {
    73  		switch dia.DialectType() {
    74  		case dialect.TypeMysql:
    75  			_, _ = ExecByPrepare(tx.PrepareContext, ctx, " USE "+ownerName)
    76  			break
    77  		case dialect.TypeOracle:
    78  			_, _ = ExecByPrepare(tx.PrepareContext, ctx, "ALTER SESSION SET CURRENT_SCHEMA="+ownerName)
    79  			break
    80  			//case dialect.TypeGBase:  // GBase 在 linux使用 database语句将会导致程序奔溃  属于 GBase驱动 so 库 问题
    81  			//	_, _ = tx.Exec("database " + ownerName)
    82  			//	break
    83  		}
    84  	}
    85  	var result sql.Result
    86  	for i := 0; i < sqlListSize; i++ {
    87  		sqlInfo := sqlList[i]
    88  		args := argsList[i]
    89  		if strings.TrimSpace(sqlInfo) == "" {
    90  			continue
    91  		}
    92  		result, err = ExecByPrepare(tx.PrepareContext, ctx, sqlInfo, args...)
    93  		if err != nil {
    94  			errSql = sqlInfo
    95  			errArgs = args
    96  			return
    97  		}
    98  		resultList = append(resultList, result)
    99  	}
   100  
   101  	return
   102  }
   103  
   104  func DoExecs(db *sql.DB, sqlList []string, argsList [][]interface{}) (resultList []sql.Result, errSql string, errArgs []interface{}, err error) {
   105  	sqlListSize := len(sqlList)
   106  	if sqlListSize == 0 {
   107  		return
   108  	}
   109  	if len(argsList) == 0 {
   110  		argsList = make([][]interface{}, sqlListSize)
   111  	}
   112  	argsListSize := len(argsList)
   113  	if sqlListSize != argsListSize {
   114  		err = errors.New(fmt.Sprintf("sqlList size is [%d] but argsList size is [%d]", sqlListSize, argsListSize))
   115  		return
   116  	}
   117  	ctx := context.Background()
   118  
   119  	tx, err := db.BeginTx(ctx, nil)
   120  	if err != nil {
   121  		return
   122  	}
   123  	defer func() {
   124  		if err != nil {
   125  			_ = tx.Rollback()
   126  		} else {
   127  			err = tx.Commit()
   128  			if err != nil && strings.Contains(err.Error(), "Not in transaction") {
   129  				err = nil
   130  			}
   131  		}
   132  	}()
   133  	var result sql.Result
   134  	for i := 0; i < sqlListSize; i++ {
   135  		sqlInfo := sqlList[i]
   136  		args := argsList[i]
   137  		if strings.TrimSpace(sqlInfo) == "" {
   138  			continue
   139  		}
   140  		result, err = ExecByPrepare(tx.PrepareContext, ctx, sqlInfo, args...)
   141  		if err != nil {
   142  			errSql = sqlInfo
   143  			errArgs = args
   144  			return
   145  		}
   146  		resultList = append(resultList, result)
   147  	}
   148  
   149  	return
   150  }
   151  
   152  func DoQuery(db *sql.DB, sqlInfo string, args []interface{}) (list []map[string]interface{}, err error) {
   153  	_, _, list, err = DoQueryWithColumnTypes(db, sqlInfo, args)
   154  	if err != nil {
   155  		return
   156  	}
   157  	return
   158  }
   159  
   160  func DoQueryOne(db *sql.DB, sqlInfo string, args []interface{}) (data map[string]interface{}, err error) {
   161  	_, _, list, err := DoQueryWithColumnTypes(db, sqlInfo, args)
   162  	if err != nil {
   163  		return
   164  	}
   165  	if len(list) > 0 {
   166  		data = list[0]
   167  		if len(list) > 1 {
   168  			err = errors.New("has more rows by query one")
   169  			return
   170  		}
   171  	}
   172  	return
   173  }
   174  
   175  func DoQueryStructs(db *sql.DB, sqlInfo string, args []interface{}, list interface{}) (err error) {
   176  	ctx := context.Background()
   177  
   178  	stmt, err := db.PrepareContext(ctx, sqlInfo)
   179  	if err != nil {
   180  		return
   181  	}
   182  	defer func() { _ = stmt.Close() }()
   183  
   184  	rows, err := stmt.Query(args...)
   185  	if err != nil {
   186  		return
   187  	}
   188  	defer func() { _ = rows.Close() }()
   189  	columnTypes, err := rows.ColumnTypes()
   190  	if err != nil {
   191  		return
   192  	}
   193  	listVOf := reflect.ValueOf(list).Elem()
   194  	listStrType := GetListStructType(list)
   195  	for rows.Next() {
   196  		var values []interface{}
   197  		for range columnTypes {
   198  			values = append(values, new(interface{}))
   199  		}
   200  		err = rows.Scan(values...)
   201  		if err != nil {
   202  			return
   203  		}
   204  
   205  		item := make(map[string]interface{})
   206  		for index, data := range values {
   207  			item[columnTypes[index].Name()] = GetSqlValue(columnTypes[index], data)
   208  		}
   209  		listStrValue := reflect.New(listStrType)
   210  		SetStructColumnValues(item, listStrValue.Elem())
   211  		listVOf = reflect.Append(listVOf, listStrValue)
   212  	}
   213  	reflect.ValueOf(list).Elem().Set(listVOf)
   214  	return
   215  }
   216  
   217  func DoQueryStruct(db *sql.DB, sqlInfo string, args []interface{}, str interface{}) (find bool, err error) {
   218  	ctx := context.Background()
   219  	stmt, err := db.PrepareContext(ctx, sqlInfo)
   220  	if err != nil {
   221  		return
   222  	}
   223  	defer func() { _ = stmt.Close() }()
   224  
   225  	rows, err := stmt.Query(args...)
   226  	if err != nil {
   227  		return
   228  	}
   229  	defer func() { _ = rows.Close() }()
   230  
   231  	columnTypes, err := rows.ColumnTypes()
   232  	if err != nil {
   233  		return
   234  	}
   235  	strVOf := reflect.ValueOf(str)
   236  
   237  	var isBase bool
   238  	switch str.(type) {
   239  	case *int, *int8, *int16, *int32, *int64, *float32, *float64:
   240  		isBase = true
   241  		break
   242  	}
   243  	for rows.Next() {
   244  		if find {
   245  			err = errors.New("has more rows by query one")
   246  			return
   247  		}
   248  		find = true
   249  		var values []interface{}
   250  		if isBase {
   251  			values = []interface{}{str}
   252  		} else {
   253  			for range columnTypes {
   254  				values = append(values, new(interface{}))
   255  			}
   256  		}
   257  		err = rows.Scan(values...)
   258  		if err != nil {
   259  			return
   260  		}
   261  		if isBase {
   262  			continue
   263  		}
   264  		item := make(map[string]interface{})
   265  		for index, data := range values {
   266  			item[columnTypes[index].Name()] = GetSqlValue(columnTypes[index], data)
   267  		}
   268  		SetStructColumnValues(item, strVOf.Elem())
   269  	}
   270  	return
   271  }
   272  func DoQueryWithColumnTypes(db *sql.DB, sqlInfo string, args []interface{}) (columns []string, columnTypes []*sql.ColumnType, list []map[string]interface{}, err error) {
   273  
   274  	ctx := context.Background()
   275  	stmt, err := db.PrepareContext(ctx, sqlInfo)
   276  	if err != nil {
   277  		return
   278  	}
   279  	defer func() { _ = stmt.Close() }()
   280  
   281  	rows, err := stmt.Query(args...)
   282  	if err != nil {
   283  		return
   284  	}
   285  	defer func() { _ = rows.Close() }()
   286  
   287  	columns, err = rows.Columns()
   288  	if err != nil {
   289  		return
   290  	}
   291  	columnTypes, err = rows.ColumnTypes()
   292  	if err != nil {
   293  		return
   294  	}
   295  	for rows.Next() {
   296  		var values []interface{}
   297  		for range columnTypes {
   298  			values = append(values, new(interface{}))
   299  		}
   300  		err = rows.Scan(values...)
   301  		if err != nil {
   302  			return
   303  		}
   304  		item := make(map[string]interface{})
   305  		for index, data := range values {
   306  			item[columns[index]] = GetSqlValue(columnTypes[index], data)
   307  		}
   308  		list = append(list, item)
   309  	}
   310  
   311  	return
   312  }
   313  
   314  var (
   315  	structFieldMapCache  = map[reflect.Type]map[string]reflect.StructField{}
   316  	structColumnMapCache = map[reflect.Type]map[string]reflect.StructField{}
   317  	structMapLock        sync.Mutex
   318  )
   319  
   320  func getStructColumn(tOf reflect.Type) (structFieldMap map[string]reflect.StructField, structColumnMap map[string]reflect.StructField) {
   321  	structMapLock.Lock()
   322  	defer structMapLock.Unlock()
   323  	structFieldMap, ok := structFieldMapCache[tOf]
   324  	structColumnMap = structColumnMapCache[tOf]
   325  	if ok {
   326  		//fmt.Println("find from cache")
   327  		return
   328  	}
   329  	structFieldMap = map[string]reflect.StructField{}
   330  	structColumnMap = map[string]reflect.StructField{}
   331  	for i := 0; i < tOf.NumField(); i++ {
   332  		field := tOf.Field(i)
   333  		structFieldMap[field.Name] = field
   334  		str := field.Tag.Get("column")
   335  		if str != "" && str != "-" {
   336  			ss := strings.Split(str, ",")
   337  			structColumnMap[ss[0]] = field
   338  		} else {
   339  			str = field.Tag.Get("json")
   340  			if str != "" && str != "-" {
   341  				ss := strings.Split(str, ",")
   342  				structColumnMap[ss[0]] = field
   343  			}
   344  		}
   345  	}
   346  	structFieldMapCache[tOf] = structFieldMap
   347  	structColumnMapCache[tOf] = structColumnMap
   348  	return
   349  }
   350  func SetStructColumnValues(columnValueMap map[string]interface{}, strValue reflect.Value) {
   351  	if len(columnValueMap) == 0 {
   352  		return
   353  	}
   354  	tOf := strValue.Type()
   355  
   356  	_, structColumnMap := getStructColumn(tOf)
   357  
   358  	for columnName, columnValue := range columnValueMap {
   359  		field, find := structColumnMap[columnName]
   360  		if !find {
   361  			field, find = structColumnMap[columnName]
   362  		}
   363  		if !find {
   364  			continue
   365  		}
   366  		valueTypeOf := reflect.TypeOf(columnValue)
   367  		columnValueType := ""
   368  		fieldType := field.Type.String()
   369  		if valueTypeOf != nil {
   370  			columnValueType = valueTypeOf.String()
   371  		}
   372  		if columnValueType != fieldType {
   373  			switch fieldType {
   374  			case "string":
   375  				columnValue = dialect.GetStringValue(columnValue)
   376  				break
   377  			case "int8", "int16", "int32", "int64", "int":
   378  				str := dialect.GetStringValue(columnValue)
   379  				var num int64
   380  				if str != "" {
   381  					num, _ = dialect.StringToInt64(str)
   382  				}
   383  				if fieldType == "int8" {
   384  					columnValue = int8(num)
   385  				} else if fieldType == "int16" {
   386  					columnValue = int16(num)
   387  				} else if fieldType == "int32" {
   388  					columnValue = int32(num)
   389  				} else if fieldType == "int64" {
   390  					columnValue = num
   391  				} else if fieldType == "int" {
   392  					columnValue = int(num)
   393  				}
   394  				break
   395  			case "uint8", "uint16", "uint32", "uint64", "uint":
   396  				str := dialect.GetStringValue(columnValue)
   397  				var num uint64
   398  				if str != "" {
   399  					num, _ = dialect.StringToUint64(str)
   400  				}
   401  				if fieldType == "uint8" {
   402  					columnValue = uint8(num)
   403  				} else if fieldType == "uint16" {
   404  					columnValue = uint16(num)
   405  				} else if fieldType == "uint32" {
   406  					columnValue = uint32(num)
   407  				} else if fieldType == "uint64" {
   408  					columnValue = num
   409  				} else if fieldType == "uint" {
   410  					columnValue = uint(num)
   411  				}
   412  				break
   413  			case "float32", "float64":
   414  				str := dialect.GetStringValue(columnValue)
   415  				var num float64
   416  				if str != "" {
   417  					num, _ = strconv.ParseFloat(str, 64)
   418  				}
   419  				if fieldType == "float32" {
   420  					columnValue = float32(num)
   421  				} else if fieldType == "float64" {
   422  					columnValue = num
   423  				}
   424  				break
   425  			case "time.Time":
   426  				if columnValue == nil || columnValue == 0 {
   427  					columnValue = time.Time{}
   428  					break
   429  				}
   430  				valueOf := reflect.ValueOf(columnValue)
   431  				if valueOf.IsNil() || valueOf.IsZero() {
   432  					columnValue = time.Time{}
   433  				}
   434  				break
   435  			}
   436  		}
   437  
   438  		valueOf := reflect.ValueOf(columnValue)
   439  		strValue.FieldByName(field.Name).Set(valueOf)
   440  	}
   441  	return
   442  }
   443  
   444  func GetListStructType(list interface{}) reflect.Type {
   445  	vOf := reflect.ValueOf(list)
   446  	if vOf.Kind() == reflect.Ptr {
   447  		return GetListStructType(vOf.Elem().Interface())
   448  	}
   449  	tOf := reflect.TypeOf(list).Elem()
   450  	if tOf.Kind() == reflect.Ptr { //指针类型获取真正type需要调用Elem
   451  		tOf = tOf.Elem()
   452  	}
   453  	return tOf
   454  }
   455  
   456  func DoQueryCount(db *sql.DB, sqlInfo string, args []interface{}) (count int, err error) {
   457  	ctx := context.Background()
   458  
   459  	stmt, err := db.PrepareContext(ctx, sqlInfo)
   460  	if err != nil {
   461  		return
   462  	}
   463  	defer func() { _ = stmt.Close() }()
   464  
   465  	rows, err := stmt.Query(args...)
   466  	if err != nil {
   467  		return
   468  	}
   469  	defer func() { _ = rows.Close() }()
   470  	for rows.Next() {
   471  		err = rows.Scan(&count)
   472  		if err != nil {
   473  			return
   474  		}
   475  	}
   476  
   477  	return
   478  }
   479  
   480  func DoQueryPage(db *sql.DB, dia dialect.Dialect, sqlInfo string, args []interface{}, page *Page) (list []map[string]interface{}, err error) {
   481  	if page.PageSize < 1 {
   482  		page.PageSize = 1
   483  	}
   484  	if page.PageNo < 1 {
   485  		page.PageNo = 1
   486  	}
   487  	pageSize := page.PageSize
   488  	pageNo := page.PageNo
   489  
   490  	countSql, err := dialect.FormatCountSql(sqlInfo)
   491  	if err != nil {
   492  		return
   493  	}
   494  	page.TotalCount, err = DoQueryCount(db, countSql, args)
   495  	if err != nil {
   496  		return
   497  	}
   498  	page.TotalPage = (page.TotalCount + page.PageSize - 1) / page.PageSize
   499  	// 如果查询的页码 大于 总页码 则不查询
   500  	if pageNo > page.TotalPage {
   501  		return
   502  	}
   503  	pageSql := dia.PackPageSql(sqlInfo, pageSize, pageNo)
   504  
   505  	list, err = DoQuery(db, pageSql, args)
   506  	if err != nil {
   507  		return
   508  	}
   509  
   510  	return
   511  }
   512  
   513  func DoQueryPageStructs(db *sql.DB, dia dialect.Dialect, sqlInfo string, args []interface{}, page *Page, list interface{}) (err error) {
   514  	if page.PageSize < 1 {
   515  		page.PageSize = 1
   516  	}
   517  	if page.PageNo < 1 {
   518  		page.PageNo = 1
   519  	}
   520  	pageSize := page.PageSize
   521  	pageNo := page.PageNo
   522  
   523  	countSql, err := dialect.FormatCountSql(sqlInfo)
   524  	if err != nil {
   525  		return
   526  	}
   527  	page.TotalCount, err = DoQueryCount(db, countSql, args)
   528  	if err != nil {
   529  		return
   530  	}
   531  	page.TotalPage = (page.TotalCount + page.PageSize - 1) / page.PageSize
   532  	// 如果查询的页码 大于 总页码 则不查询
   533  	if pageNo > page.TotalPage {
   534  		return
   535  	}
   536  	pageSql := dia.PackPageSql(sqlInfo, pageSize, pageNo)
   537  
   538  	err = DoQueryStructs(db, pageSql, args, list)
   539  	if err != nil {
   540  		return
   541  	}
   542  
   543  	return
   544  }
   545  
   546  type Page struct {
   547  	PageSize   int `json:"pageSize"`
   548  	PageNo     int `json:"pageNo"`
   549  	TotalCount int `json:"totalCount"`
   550  	TotalPage  int `json:"totalPage"`
   551  }
   552  
   553  func NewPage() *Page {
   554  	return &Page{
   555  		PageSize: 1,
   556  		PageNo:   1,
   557  	}
   558  }