github.com/team-ide/go-dialect@v1.9.20/worker/exec.go (about) 1 package worker 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "github.com/team-ide/go-dialect/dialect" 9 "reflect" 10 "strconv" 11 "strings" 12 "sync" 13 "time" 14 ) 15 16 func DoExec(db *sql.DB, sqlInfo string, args []interface{}) (result sql.Result, err error) { 17 if len(sqlInfo) == 0 { 18 return 19 } 20 resultList, _, _, err := DoExecs(db, []string{sqlInfo}, [][]interface{}{args}) 21 if err != nil { 22 return 23 } 24 if len(resultList) > 0 { 25 result = resultList[0] 26 } 27 return 28 } 29 30 type prepareFunc func(ctx context.Context, query string) (*sql.Stmt, error) 31 32 func ExecByPrepare(prepare prepareFunc, ctx context.Context, sqlInfo string, sqlArgs ...interface{}) (result sql.Result, err error) { 33 stmt, err := prepare(ctx, sqlInfo) 34 if err != nil { 35 return 36 } 37 defer func() { _ = stmt.Close() }() 38 result, err = stmt.Exec(sqlArgs...) 39 return 40 } 41 42 func DoOwnerExecs(dia dialect.Dialect, db *sql.DB, ownerName string, sqlList []string, argsList [][]interface{}) (resultList []sql.Result, errSql string, errArgs []interface{}, err error) { 43 sqlListSize := len(sqlList) 44 if sqlListSize == 0 { 45 return 46 } 47 if len(argsList) == 0 { 48 argsList = make([][]interface{}, sqlListSize) 49 } 50 argsListSize := len(argsList) 51 if sqlListSize != argsListSize { 52 err = errors.New(fmt.Sprintf("sqlList size is [%d] but argsList size is [%d]", sqlListSize, argsListSize)) 53 return 54 } 55 ctx := context.Background() 56 57 tx, err := db.BeginTx(ctx, nil) 58 if err != nil { 59 return 60 } 61 defer func() { 62 if err != nil { 63 _ = tx.Rollback() 64 } else { 65 err = tx.Commit() 66 if err != nil && strings.Contains(err.Error(), "Not in transaction") { 67 err = nil 68 } 69 } 70 }() 71 72 if ownerName != "" { 73 switch dia.DialectType() { 74 case dialect.TypeMysql: 75 _, _ = ExecByPrepare(tx.PrepareContext, ctx, " USE "+ownerName) 76 break 77 case dialect.TypeOracle: 78 _, _ = ExecByPrepare(tx.PrepareContext, ctx, "ALTER SESSION SET CURRENT_SCHEMA="+ownerName) 79 break 80 //case dialect.TypeGBase: // GBase 在 linux使用 database语句将会导致程序奔溃 属于 GBase驱动 so 库 问题 81 // _, _ = tx.Exec("database " + ownerName) 82 // break 83 } 84 } 85 var result sql.Result 86 for i := 0; i < sqlListSize; i++ { 87 sqlInfo := sqlList[i] 88 args := argsList[i] 89 if strings.TrimSpace(sqlInfo) == "" { 90 continue 91 } 92 result, err = ExecByPrepare(tx.PrepareContext, ctx, sqlInfo, args...) 93 if err != nil { 94 errSql = sqlInfo 95 errArgs = args 96 return 97 } 98 resultList = append(resultList, result) 99 } 100 101 return 102 } 103 104 func DoExecs(db *sql.DB, sqlList []string, argsList [][]interface{}) (resultList []sql.Result, errSql string, errArgs []interface{}, err error) { 105 sqlListSize := len(sqlList) 106 if sqlListSize == 0 { 107 return 108 } 109 if len(argsList) == 0 { 110 argsList = make([][]interface{}, sqlListSize) 111 } 112 argsListSize := len(argsList) 113 if sqlListSize != argsListSize { 114 err = errors.New(fmt.Sprintf("sqlList size is [%d] but argsList size is [%d]", sqlListSize, argsListSize)) 115 return 116 } 117 ctx := context.Background() 118 119 tx, err := db.BeginTx(ctx, nil) 120 if err != nil { 121 return 122 } 123 defer func() { 124 if err != nil { 125 _ = tx.Rollback() 126 } else { 127 err = tx.Commit() 128 if err != nil && strings.Contains(err.Error(), "Not in transaction") { 129 err = nil 130 } 131 } 132 }() 133 var result sql.Result 134 for i := 0; i < sqlListSize; i++ { 135 sqlInfo := sqlList[i] 136 args := argsList[i] 137 if strings.TrimSpace(sqlInfo) == "" { 138 continue 139 } 140 result, err = ExecByPrepare(tx.PrepareContext, ctx, sqlInfo, args...) 141 if err != nil { 142 errSql = sqlInfo 143 errArgs = args 144 return 145 } 146 resultList = append(resultList, result) 147 } 148 149 return 150 } 151 152 func DoQuery(db *sql.DB, sqlInfo string, args []interface{}) (list []map[string]interface{}, err error) { 153 _, _, list, err = DoQueryWithColumnTypes(db, sqlInfo, args) 154 if err != nil { 155 return 156 } 157 return 158 } 159 160 func DoQueryOne(db *sql.DB, sqlInfo string, args []interface{}) (data map[string]interface{}, err error) { 161 _, _, list, err := DoQueryWithColumnTypes(db, sqlInfo, args) 162 if err != nil { 163 return 164 } 165 if len(list) > 0 { 166 data = list[0] 167 if len(list) > 1 { 168 err = errors.New("has more rows by query one") 169 return 170 } 171 } 172 return 173 } 174 175 func DoQueryStructs(db *sql.DB, sqlInfo string, args []interface{}, list interface{}) (err error) { 176 ctx := context.Background() 177 178 stmt, err := db.PrepareContext(ctx, sqlInfo) 179 if err != nil { 180 return 181 } 182 defer func() { _ = stmt.Close() }() 183 184 rows, err := stmt.Query(args...) 185 if err != nil { 186 return 187 } 188 defer func() { _ = rows.Close() }() 189 columnTypes, err := rows.ColumnTypes() 190 if err != nil { 191 return 192 } 193 listVOf := reflect.ValueOf(list).Elem() 194 listStrType := GetListStructType(list) 195 for rows.Next() { 196 var values []interface{} 197 for range columnTypes { 198 values = append(values, new(interface{})) 199 } 200 err = rows.Scan(values...) 201 if err != nil { 202 return 203 } 204 205 item := make(map[string]interface{}) 206 for index, data := range values { 207 item[columnTypes[index].Name()] = GetSqlValue(columnTypes[index], data) 208 } 209 listStrValue := reflect.New(listStrType) 210 SetStructColumnValues(item, listStrValue.Elem()) 211 listVOf = reflect.Append(listVOf, listStrValue) 212 } 213 reflect.ValueOf(list).Elem().Set(listVOf) 214 return 215 } 216 217 func DoQueryStruct(db *sql.DB, sqlInfo string, args []interface{}, str interface{}) (find bool, err error) { 218 ctx := context.Background() 219 stmt, err := db.PrepareContext(ctx, sqlInfo) 220 if err != nil { 221 return 222 } 223 defer func() { _ = stmt.Close() }() 224 225 rows, err := stmt.Query(args...) 226 if err != nil { 227 return 228 } 229 defer func() { _ = rows.Close() }() 230 231 columnTypes, err := rows.ColumnTypes() 232 if err != nil { 233 return 234 } 235 strVOf := reflect.ValueOf(str) 236 237 var isBase bool 238 switch str.(type) { 239 case *int, *int8, *int16, *int32, *int64, *float32, *float64: 240 isBase = true 241 break 242 } 243 for rows.Next() { 244 if find { 245 err = errors.New("has more rows by query one") 246 return 247 } 248 find = true 249 var values []interface{} 250 if isBase { 251 values = []interface{}{str} 252 } else { 253 for range columnTypes { 254 values = append(values, new(interface{})) 255 } 256 } 257 err = rows.Scan(values...) 258 if err != nil { 259 return 260 } 261 if isBase { 262 continue 263 } 264 item := make(map[string]interface{}) 265 for index, data := range values { 266 item[columnTypes[index].Name()] = GetSqlValue(columnTypes[index], data) 267 } 268 SetStructColumnValues(item, strVOf.Elem()) 269 } 270 return 271 } 272 func DoQueryWithColumnTypes(db *sql.DB, sqlInfo string, args []interface{}) (columns []string, columnTypes []*sql.ColumnType, list []map[string]interface{}, err error) { 273 274 ctx := context.Background() 275 stmt, err := db.PrepareContext(ctx, sqlInfo) 276 if err != nil { 277 return 278 } 279 defer func() { _ = stmt.Close() }() 280 281 rows, err := stmt.Query(args...) 282 if err != nil { 283 return 284 } 285 defer func() { _ = rows.Close() }() 286 287 columns, err = rows.Columns() 288 if err != nil { 289 return 290 } 291 columnTypes, err = rows.ColumnTypes() 292 if err != nil { 293 return 294 } 295 for rows.Next() { 296 var values []interface{} 297 for range columnTypes { 298 values = append(values, new(interface{})) 299 } 300 err = rows.Scan(values...) 301 if err != nil { 302 return 303 } 304 item := make(map[string]interface{}) 305 for index, data := range values { 306 item[columns[index]] = GetSqlValue(columnTypes[index], data) 307 } 308 list = append(list, item) 309 } 310 311 return 312 } 313 314 var ( 315 structFieldMapCache = map[reflect.Type]map[string]reflect.StructField{} 316 structColumnMapCache = map[reflect.Type]map[string]reflect.StructField{} 317 structMapLock sync.Mutex 318 ) 319 320 func getStructColumn(tOf reflect.Type) (structFieldMap map[string]reflect.StructField, structColumnMap map[string]reflect.StructField) { 321 structMapLock.Lock() 322 defer structMapLock.Unlock() 323 structFieldMap, ok := structFieldMapCache[tOf] 324 structColumnMap = structColumnMapCache[tOf] 325 if ok { 326 //fmt.Println("find from cache") 327 return 328 } 329 structFieldMap = map[string]reflect.StructField{} 330 structColumnMap = map[string]reflect.StructField{} 331 for i := 0; i < tOf.NumField(); i++ { 332 field := tOf.Field(i) 333 structFieldMap[field.Name] = field 334 str := field.Tag.Get("column") 335 if str != "" && str != "-" { 336 ss := strings.Split(str, ",") 337 structColumnMap[ss[0]] = field 338 } else { 339 str = field.Tag.Get("json") 340 if str != "" && str != "-" { 341 ss := strings.Split(str, ",") 342 structColumnMap[ss[0]] = field 343 } 344 } 345 } 346 structFieldMapCache[tOf] = structFieldMap 347 structColumnMapCache[tOf] = structColumnMap 348 return 349 } 350 func SetStructColumnValues(columnValueMap map[string]interface{}, strValue reflect.Value) { 351 if len(columnValueMap) == 0 { 352 return 353 } 354 tOf := strValue.Type() 355 356 _, structColumnMap := getStructColumn(tOf) 357 358 for columnName, columnValue := range columnValueMap { 359 field, find := structColumnMap[columnName] 360 if !find { 361 field, find = structColumnMap[columnName] 362 } 363 if !find { 364 continue 365 } 366 valueTypeOf := reflect.TypeOf(columnValue) 367 columnValueType := "" 368 fieldType := field.Type.String() 369 if valueTypeOf != nil { 370 columnValueType = valueTypeOf.String() 371 } 372 if columnValueType != fieldType { 373 switch fieldType { 374 case "string": 375 columnValue = dialect.GetStringValue(columnValue) 376 break 377 case "int8", "int16", "int32", "int64", "int": 378 str := dialect.GetStringValue(columnValue) 379 var num int64 380 if str != "" { 381 num, _ = dialect.StringToInt64(str) 382 } 383 if fieldType == "int8" { 384 columnValue = int8(num) 385 } else if fieldType == "int16" { 386 columnValue = int16(num) 387 } else if fieldType == "int32" { 388 columnValue = int32(num) 389 } else if fieldType == "int64" { 390 columnValue = num 391 } else if fieldType == "int" { 392 columnValue = int(num) 393 } 394 break 395 case "uint8", "uint16", "uint32", "uint64", "uint": 396 str := dialect.GetStringValue(columnValue) 397 var num uint64 398 if str != "" { 399 num, _ = dialect.StringToUint64(str) 400 } 401 if fieldType == "uint8" { 402 columnValue = uint8(num) 403 } else if fieldType == "uint16" { 404 columnValue = uint16(num) 405 } else if fieldType == "uint32" { 406 columnValue = uint32(num) 407 } else if fieldType == "uint64" { 408 columnValue = num 409 } else if fieldType == "uint" { 410 columnValue = uint(num) 411 } 412 break 413 case "float32", "float64": 414 str := dialect.GetStringValue(columnValue) 415 var num float64 416 if str != "" { 417 num, _ = strconv.ParseFloat(str, 64) 418 } 419 if fieldType == "float32" { 420 columnValue = float32(num) 421 } else if fieldType == "float64" { 422 columnValue = num 423 } 424 break 425 case "time.Time": 426 if columnValue == nil || columnValue == 0 { 427 columnValue = time.Time{} 428 break 429 } 430 valueOf := reflect.ValueOf(columnValue) 431 if valueOf.IsNil() || valueOf.IsZero() { 432 columnValue = time.Time{} 433 } 434 break 435 } 436 } 437 438 valueOf := reflect.ValueOf(columnValue) 439 strValue.FieldByName(field.Name).Set(valueOf) 440 } 441 return 442 } 443 444 func GetListStructType(list interface{}) reflect.Type { 445 vOf := reflect.ValueOf(list) 446 if vOf.Kind() == reflect.Ptr { 447 return GetListStructType(vOf.Elem().Interface()) 448 } 449 tOf := reflect.TypeOf(list).Elem() 450 if tOf.Kind() == reflect.Ptr { //指针类型获取真正type需要调用Elem 451 tOf = tOf.Elem() 452 } 453 return tOf 454 } 455 456 func DoQueryCount(db *sql.DB, sqlInfo string, args []interface{}) (count int, err error) { 457 ctx := context.Background() 458 459 stmt, err := db.PrepareContext(ctx, sqlInfo) 460 if err != nil { 461 return 462 } 463 defer func() { _ = stmt.Close() }() 464 465 rows, err := stmt.Query(args...) 466 if err != nil { 467 return 468 } 469 defer func() { _ = rows.Close() }() 470 for rows.Next() { 471 err = rows.Scan(&count) 472 if err != nil { 473 return 474 } 475 } 476 477 return 478 } 479 480 func DoQueryPage(db *sql.DB, dia dialect.Dialect, sqlInfo string, args []interface{}, page *Page) (list []map[string]interface{}, err error) { 481 if page.PageSize < 1 { 482 page.PageSize = 1 483 } 484 if page.PageNo < 1 { 485 page.PageNo = 1 486 } 487 pageSize := page.PageSize 488 pageNo := page.PageNo 489 490 countSql, err := dialect.FormatCountSql(sqlInfo) 491 if err != nil { 492 return 493 } 494 page.TotalCount, err = DoQueryCount(db, countSql, args) 495 if err != nil { 496 return 497 } 498 page.TotalPage = (page.TotalCount + page.PageSize - 1) / page.PageSize 499 // 如果查询的页码 大于 总页码 则不查询 500 if pageNo > page.TotalPage { 501 return 502 } 503 pageSql := dia.PackPageSql(sqlInfo, pageSize, pageNo) 504 505 list, err = DoQuery(db, pageSql, args) 506 if err != nil { 507 return 508 } 509 510 return 511 } 512 513 func DoQueryPageStructs(db *sql.DB, dia dialect.Dialect, sqlInfo string, args []interface{}, page *Page, list interface{}) (err error) { 514 if page.PageSize < 1 { 515 page.PageSize = 1 516 } 517 if page.PageNo < 1 { 518 page.PageNo = 1 519 } 520 pageSize := page.PageSize 521 pageNo := page.PageNo 522 523 countSql, err := dialect.FormatCountSql(sqlInfo) 524 if err != nil { 525 return 526 } 527 page.TotalCount, err = DoQueryCount(db, countSql, args) 528 if err != nil { 529 return 530 } 531 page.TotalPage = (page.TotalCount + page.PageSize - 1) / page.PageSize 532 // 如果查询的页码 大于 总页码 则不查询 533 if pageNo > page.TotalPage { 534 return 535 } 536 pageSql := dia.PackPageSql(sqlInfo, pageSize, pageNo) 537 538 err = DoQueryStructs(db, pageSql, args, list) 539 if err != nil { 540 return 541 } 542 543 return 544 } 545 546 type Page struct { 547 PageSize int `json:"pageSize"` 548 PageNo int `json:"pageNo"` 549 TotalCount int `json:"totalCount"` 550 TotalPage int `json:"totalPage"` 551 } 552 553 func NewPage() *Page { 554 return &Page{ 555 PageSize: 1, 556 PageNo: 1, 557 } 558 }