github.com/kotovmak/go-admin@v1.1.1/modules/db/statement.go (about) 1 // Copyright 2019 GoAdmin Core Team. All rights reserved. 2 // Use of this source code is governed by a Apache-2.0 style 3 // license that can be found in the LICENSE file. 4 5 package db 6 7 import ( 8 dbsql "database/sql" 9 "errors" 10 "regexp" 11 "strconv" 12 "strings" 13 "sync" 14 15 "github.com/kotovmak/go-admin/modules/db/dialect" 16 "github.com/kotovmak/go-admin/modules/logger" 17 ) 18 19 // SQL wraps the Connection and driver dialect methods. 20 type SQL struct { 21 dialect.SQLComponent 22 diver Connection 23 dialect dialect.Dialect 24 conn string 25 tx *dbsql.Tx 26 } 27 28 // SQLPool is a object pool of SQL. 29 var SQLPool = sync.Pool{ 30 New: func() interface{} { 31 return &SQL{ 32 SQLComponent: dialect.SQLComponent{ 33 Fields: make([]string, 0), 34 TableName: "", 35 Args: make([]interface{}, 0), 36 Wheres: make([]dialect.Where, 0), 37 Leftjoins: make([]dialect.Join, 0), 38 UpdateRaws: make([]dialect.RawUpdate, 0), 39 WhereRaws: "", 40 Order: "", 41 Group: "", 42 Limit: "", 43 }, 44 diver: nil, 45 dialect: nil, 46 } 47 }, 48 } 49 50 // H is a shorthand of map. 51 type H map[string]interface{} 52 53 // newSQL get a new SQL from SQLPool. 54 func newSQL() *SQL { 55 return SQLPool.Get().(*SQL) 56 } 57 58 // ******************************* 59 // process method 60 // ******************************* 61 62 // TableName return a SQL with given table and default connection. 63 func Table(table string) *SQL { 64 sql := newSQL() 65 sql.TableName = table 66 sql.conn = "default" 67 return sql 68 } 69 70 // WithDriver return a SQL with given driver. 71 func WithDriver(conn Connection) *SQL { 72 sql := newSQL() 73 sql.diver = conn 74 sql.dialect = dialect.GetDialectByDriver(conn.Name()) 75 sql.conn = "default" 76 return sql 77 } 78 79 // WithDriverAndConnection return a SQL with given driver and connection name. 80 func WithDriverAndConnection(connName string, conn Connection) *SQL { 81 sql := newSQL() 82 sql.diver = conn 83 sql.dialect = dialect.GetDialectByDriver(conn.Name()) 84 sql.conn = connName 85 return sql 86 } 87 88 // WithDriver return a SQL with given driver. 89 func (sql *SQL) WithDriver(conn Connection) *SQL { 90 sql.diver = conn 91 sql.dialect = dialect.GetDialectByDriver(conn.Name()) 92 return sql 93 } 94 95 // WithConnection set the connection name of SQL. 96 func (sql *SQL) WithConnection(conn string) *SQL { 97 sql.conn = conn 98 return sql 99 } 100 101 // WithTx set the database transaction object of SQL. 102 func (sql *SQL) WithTx(tx *dbsql.Tx) *SQL { 103 sql.tx = tx 104 return sql 105 } 106 107 // TableName set table of SQL. 108 func (sql *SQL) Table(table string) *SQL { 109 sql.clean() 110 sql.TableName = table 111 return sql 112 } 113 114 // Select set select fields. 115 func (sql *SQL) Select(fields ...string) *SQL { 116 sql.Fields = fields 117 sql.Functions = make([]string, len(fields)) 118 reg, _ := regexp.Compile(`(.*?)\((.*?)\)`) 119 for k, field := range fields { 120 res := reg.FindAllStringSubmatch(field, -1) 121 if len(res) > 0 && len(res[0]) > 2 { 122 sql.Functions[k] = res[0][1] 123 sql.Fields[k] = res[0][2] 124 } 125 } 126 return sql 127 } 128 129 // OrderBy set order fields. 130 func (sql *SQL) OrderBy(fields ...string) *SQL { 131 if len(fields) == 0 { 132 panic("wrong order field") 133 } 134 for i := 0; i < len(fields); i++ { 135 if i == len(fields)-2 { 136 sql.Order += " " + sql.wrap(fields[i]) + " " + fields[i+1] 137 return sql 138 } 139 sql.Order += " " + sql.wrap(fields[i]) + " and " 140 } 141 return sql 142 } 143 144 // OrderByRaw set order by. 145 func (sql *SQL) OrderByRaw(order string) *SQL { 146 if order != "" { 147 sql.Order += " " + order 148 } 149 return sql 150 } 151 152 func (sql *SQL) GroupBy(fields ...string) *SQL { 153 if len(fields) == 0 { 154 panic("wrong group by field") 155 } 156 for i := 0; i < len(fields); i++ { 157 if i == len(fields)-1 { 158 sql.Group += " " + sql.wrap(fields[i]) 159 } else { 160 sql.Group += " " + sql.wrap(fields[i]) + "," 161 } 162 } 163 return sql 164 } 165 166 // GroupByRaw set group by. 167 func (sql *SQL) GroupByRaw(group string) *SQL { 168 if group != "" { 169 sql.Group += " " + group 170 } 171 return sql 172 } 173 174 // Skip set offset value. 175 func (sql *SQL) Skip(offset int) *SQL { 176 sql.Offset = strconv.Itoa(offset) 177 return sql 178 } 179 180 // Take set limit value. 181 func (sql *SQL) Take(take int) *SQL { 182 sql.Limit = strconv.Itoa(take) 183 return sql 184 } 185 186 // Where add the where operation and argument value. 187 func (sql *SQL) Where(field string, operation string, arg interface{}) *SQL { 188 sql.Wheres = append(sql.Wheres, dialect.Where{ 189 Field: field, 190 Operation: operation, 191 Qmark: "?", 192 }) 193 sql.Args = append(sql.Args, arg) 194 return sql 195 } 196 197 // WhereIn add the where operation of "in" and argument values. 198 func (sql *SQL) WhereIn(field string, arg []interface{}) *SQL { 199 if len(arg) == 0 { 200 panic("wrong parameter") 201 } 202 sql.Wheres = append(sql.Wheres, dialect.Where{ 203 Field: field, 204 Operation: "in", 205 Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", 206 }) 207 sql.Args = append(sql.Args, arg...) 208 return sql 209 } 210 211 // WhereNotIn add the where operation of "not in" and argument values. 212 func (sql *SQL) WhereNotIn(field string, arg []interface{}) *SQL { 213 if len(arg) == 0 { 214 panic("wrong parameter") 215 } 216 sql.Wheres = append(sql.Wheres, dialect.Where{ 217 Field: field, 218 Operation: "not in", 219 Qmark: "(" + strings.Repeat("?,", len(arg)-1) + "?)", 220 }) 221 sql.Args = append(sql.Args, arg...) 222 return sql 223 } 224 225 // Find query the sql result with given id assuming that primary key name is "id". 226 func (sql *SQL) Find(arg interface{}) (map[string]interface{}, error) { 227 return sql.Where("id", "=", arg).First() 228 } 229 230 // Count query the count of query results. 231 func (sql *SQL) Count() (int64, error) { 232 var ( 233 res map[string]interface{} 234 err error 235 driver = sql.diver.Name() 236 ) 237 238 if res, err = sql.Select("count(*)").First(); err != nil { 239 return 0, err 240 } 241 242 if driver == DriverPostgresql { 243 return res["count"].(int64), nil 244 } else if driver == DriverMssql { 245 return res[""].(int64), nil 246 } 247 248 return res["count(*)"].(int64), nil 249 } 250 251 // Sum sum the value of given field. 252 func (sql *SQL) Sum(field string) (float64, error) { 253 var ( 254 res map[string]interface{} 255 err error 256 key = "sum(" + sql.wrap(field) + ")" 257 ) 258 if res, err = sql.Select("sum(" + field + ")").First(); err != nil { 259 return 0, err 260 } 261 262 if res == nil { 263 return 0, nil 264 } 265 266 if r, ok := res[key].(float64); ok { 267 return r, nil 268 } else if r, ok := res[key].([]uint8); ok { 269 return strconv.ParseFloat(string(r), 64) 270 } else { 271 return 0, nil 272 } 273 } 274 275 // Max find the maximal value of given field. 276 func (sql *SQL) Max(field string) (interface{}, error) { 277 var ( 278 res map[string]interface{} 279 err error 280 key = "max(" + sql.wrap(field) + ")" 281 ) 282 if res, err = sql.Select("max(" + field + ")").First(); err != nil { 283 return 0, err 284 } 285 286 if res == nil { 287 return 0, nil 288 } 289 290 return res[key], nil 291 } 292 293 // Min find the minimal value of given field. 294 func (sql *SQL) Min(field string) (interface{}, error) { 295 var ( 296 res map[string]interface{} 297 err error 298 key = "min(" + sql.wrap(field) + ")" 299 ) 300 if res, err = sql.Select("min(" + field + ")").First(); err != nil { 301 return 0, err 302 } 303 304 if res == nil { 305 return 0, nil 306 } 307 308 return res[key], nil 309 } 310 311 // Avg find the average value of given field. 312 func (sql *SQL) Avg(field string) (interface{}, error) { 313 var ( 314 res map[string]interface{} 315 err error 316 key = "avg(" + sql.wrap(field) + ")" 317 ) 318 if res, err = sql.Select("avg(" + field + ")").First(); err != nil { 319 return 0, err 320 } 321 322 if res == nil { 323 return 0, nil 324 } 325 326 return res[key], nil 327 } 328 329 // WhereRaw set WhereRaws and arguments. 330 func (sql *SQL) WhereRaw(raw string, args ...interface{}) *SQL { 331 sql.WhereRaws = raw 332 sql.Args = append(sql.Args, args...) 333 return sql 334 } 335 336 // UpdateRaw set UpdateRaw. 337 func (sql *SQL) UpdateRaw(raw string, args ...interface{}) *SQL { 338 sql.UpdateRaws = append(sql.UpdateRaws, dialect.RawUpdate{ 339 Expression: raw, 340 Args: args, 341 }) 342 return sql 343 } 344 345 // LeftJoin add a left join info. 346 func (sql *SQL) LeftJoin(table string, fieldA string, operation string, fieldB string) *SQL { 347 sql.Leftjoins = append(sql.Leftjoins, dialect.Join{ 348 FieldA: fieldA, 349 FieldB: fieldB, 350 Table: table, 351 Operation: operation, 352 }) 353 return sql 354 } 355 356 // ******************************* 357 // Transaction method 358 // ******************************* 359 360 // TxFn is the transaction callback function. 361 type TxFn func(tx *dbsql.Tx) (error, map[string]interface{}) 362 363 // WithTransaction call the callback function within the transaction and 364 // catch the error. 365 func (sql *SQL) WithTransaction(fn TxFn) (res map[string]interface{}, err error) { 366 367 tx := sql.diver.BeginTxAndConnection(sql.conn) 368 369 defer func() { 370 if p := recover(); p != nil { 371 // a panic occurred, rollback and repanic 372 _ = tx.Rollback() 373 panic(p) 374 } else if err != nil { 375 // something went wrong, rollback 376 _ = tx.Rollback() 377 } else { 378 // all good, commit 379 err = tx.Commit() 380 } 381 }() 382 383 err, res = fn(tx) 384 return 385 } 386 387 // WithTransactionByLevel call the callback function within the transaction 388 // of given transaction level and catch the error. 389 func (sql *SQL) WithTransactionByLevel(level dbsql.IsolationLevel, fn TxFn) (res map[string]interface{}, err error) { 390 391 tx := sql.diver.BeginTxWithLevelAndConnection(sql.conn, level) 392 393 defer func() { 394 if p := recover(); p != nil { 395 // a panic occurred, rollback and repanic 396 _ = tx.Rollback() 397 panic(p) 398 } else if err != nil { 399 // something went wrong, rollback 400 _ = tx.Rollback() 401 } else { 402 // all good, commit 403 err = tx.Commit() 404 } 405 }() 406 407 err, res = fn(tx) 408 return 409 } 410 411 // ******************************* 412 // terminal method 413 // ------------------------------- 414 // sql args order: 415 // update ... => where ... 416 // ******************************* 417 418 // First query the result and return the first row. 419 func (sql *SQL) First() (map[string]interface{}, error) { 420 defer RecycleSQL(sql) 421 422 sql.dialect.Select(&sql.SQLComponent) 423 424 res, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 425 426 if err != nil { 427 return nil, err 428 } 429 430 if len(res) < 1 { 431 return nil, errors.New("out of index") 432 } 433 return res[0], nil 434 } 435 436 // All query all the result and return. 437 func (sql *SQL) All() ([]map[string]interface{}, error) { 438 defer RecycleSQL(sql) 439 440 sql.dialect.Select(&sql.SQLComponent) 441 442 return sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 443 } 444 445 // ShowColumns show columns info. 446 func (sql *SQL) ShowColumns() ([]map[string]interface{}, error) { 447 defer RecycleSQL(sql) 448 449 return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumns(sql.TableName)) 450 } 451 452 // ShowTables show table info. 453 func (sql *SQL) ShowTables() ([]string, error) { 454 defer RecycleSQL(sql) 455 456 models, err := sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowTables()) 457 458 if err != nil { 459 return []string{}, err 460 } 461 462 tables := make([]string, 0) 463 if len(models) == 0 { 464 return tables, nil 465 } 466 467 key := "Tables_in_" + sql.TableName 468 if sql.diver.Name() == DriverPostgresql || sql.diver.Name() == DriverSqlite { 469 key = "tablename" 470 } else if sql.diver.Name() == DriverMssql { 471 key = "TABLE_NAME" 472 } else if _, ok := models[0][key].(string); !ok { 473 key = "Tables_in_" + strings.ToLower(sql.TableName) 474 } 475 476 for i := 0; i < len(models); i++ { 477 // skip sqlite system tables 478 if sql.diver.Name() == DriverSqlite && models[i][key].(string) == "sqlite_sequence" { 479 continue 480 } 481 482 tables = append(tables, models[i][key].(string)) 483 } 484 485 return tables, nil 486 } 487 488 // Update exec the update method of given key/value pairs. 489 func (sql *SQL) Update(values dialect.H) (int64, error) { 490 defer RecycleSQL(sql) 491 492 sql.Values = values 493 494 sql.dialect.Update(&sql.SQLComponent) 495 496 res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 497 498 if err != nil { 499 return 0, err 500 } 501 502 if affectRow, _ := res.RowsAffected(); affectRow < 1 { 503 return 0, errors.New("no affect row") 504 } 505 506 return res.LastInsertId() 507 } 508 509 // Delete exec the delete method. 510 func (sql *SQL) Delete() error { 511 defer RecycleSQL(sql) 512 513 sql.dialect.Delete(&sql.SQLComponent) 514 515 res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 516 517 if err != nil { 518 return err 519 } 520 521 if affectRow, _ := res.RowsAffected(); affectRow < 1 { 522 return errors.New("no affect row") 523 } 524 525 return nil 526 } 527 528 // Exec exec the exec method. 529 func (sql *SQL) Exec() (int64, error) { 530 defer RecycleSQL(sql) 531 532 sql.dialect.Update(&sql.SQLComponent) 533 534 res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 535 536 if err != nil { 537 return 0, err 538 } 539 540 if affectRow, _ := res.RowsAffected(); affectRow < 1 { 541 return 0, errors.New("no affect row") 542 } 543 544 return res.LastInsertId() 545 } 546 547 const postgresInsertCheckTableName = "goadmin_menu|goadmin_permissions|goadmin_roles|goadmin_users" 548 549 // Insert exec the insert method of given key/value pairs. 550 func (sql *SQL) Insert(values dialect.H) (int64, error) { 551 defer RecycleSQL(sql) 552 553 sql.Values = values 554 555 sql.dialect.Insert(&sql.SQLComponent) 556 557 if sql.diver.Name() == DriverPostgresql && (strings.Contains(postgresInsertCheckTableName, sql.TableName)) { 558 559 resMap, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement+" RETURNING id", sql.Args...) 560 561 if err != nil { 562 563 // Fixed java h2 database postgresql mode 564 _, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 565 566 if err != nil { 567 return 0, err 568 } 569 570 res, err := sql.diver.QueryWithConnection(sql.conn, `SELECT max("id") as "id" FROM "`+sql.TableName+`"`) 571 572 if err != nil { 573 return 0, err 574 } 575 576 if len(res) != 0 { 577 return res[0]["id"].(int64), nil 578 } 579 580 return 0, err 581 } 582 583 if len(resMap) == 0 { 584 return 0, errors.New("no affect row") 585 } 586 587 return resMap[0]["id"].(int64), nil 588 } 589 590 res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...) 591 592 if err != nil { 593 return 0, err 594 } 595 596 if affectRow, _ := res.RowsAffected(); affectRow < 1 { 597 return 0, errors.New("no affect row") 598 } 599 600 return res.LastInsertId() 601 } 602 603 func (sql *SQL) wrap(field string) string { 604 return sql.diver.GetDelimiter() + field + sql.diver.GetDelimiter2() 605 } 606 607 func (sql *SQL) clean() { 608 sql.Functions = make([]string, 0) 609 sql.Group = "" 610 sql.Values = make(map[string]interface{}) 611 sql.Fields = make([]string, 0) 612 sql.TableName = "" 613 sql.Wheres = make([]dialect.Where, 0) 614 sql.Leftjoins = make([]dialect.Join, 0) 615 sql.Args = make([]interface{}, 0) 616 sql.Order = "" 617 sql.Offset = "" 618 sql.Limit = "" 619 sql.WhereRaws = "" 620 sql.UpdateRaws = make([]dialect.RawUpdate, 0) 621 sql.Statement = "" 622 } 623 624 // RecycleSQL clear the SQL and put into the pool. 625 func RecycleSQL(sql *SQL) { 626 627 logger.LogSQL(sql.Statement, sql.Args) 628 629 sql.clean() 630 631 sql.conn = "" 632 sql.diver = nil 633 sql.tx = nil 634 sql.dialect = nil 635 636 SQLPool.Put(sql) 637 }