github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/database/database.go (about)

     1  package database
     2  
     3  import (
     4  	"database/sql"
     5  	"log"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/isyscore/isc-gobase/isc"
    10  )
    11  
    12  type DatabaseType int
    13  
    14  const (
    15  	MySQL      DatabaseType = iota // import _ "github.com/go-sql-driver/mysql"
    16  	Oracle                         // import _ "github.com/mattn/go-oci8"
    17  	SqlServer                      // import _ "github.com/denisenkom/go-mssqldb"
    18  	PostgreSql                     // import _ "github.com/lib/pq"
    19  	Sqlite3                        // import _ "github.com/mattn/go-sqlite3"
    20  )
    21  
    22  const (
    23  	//CONNECTION_STRING user:password@tcp(host:port)/databaseName
    24  	CONNECTION_STRING = "%s:%s@tcp(%s:%d)/%s"
    25  )
    26  
    27  func Connect(dbType DatabaseType, connStr string) *sql.DB {
    28  	return CustomConnect(dbTypeToString(dbType), connStr)
    29  }
    30  
    31  func CustomConnect(dbType string, connStr string) *sql.DB {
    32  	// charset=utf8
    33  	// parseTime=true
    34  	innerParam := connStr
    35  	if strings.Contains(connStr, "?") {
    36  		// 已有参数
    37  		if !strings.Contains(connStr, "charset=utf8") {
    38  			innerParam += "&charset=utf8"
    39  		}
    40  		if !strings.Contains(connStr, "parseTime=true") {
    41  			innerParam += "&parseTime=true"
    42  		}
    43  	} else {
    44  		// 没有参数
    45  		innerParam += "?charset=utf8&parseTime=true"
    46  	}
    47  
    48  	db, err := sql.Open(dbType, innerParam)
    49  	if err != nil {
    50  		log.Printf("初始化数据库失败(%v)\n", err)
    51  		return nil
    52  	}
    53  	return db
    54  }
    55  
    56  func dbTypeToString(dbType DatabaseType) string {
    57  	switch dbType {
    58  	case MySQL:
    59  		return "mysql"
    60  	case Oracle:
    61  		return "oci8"
    62  	case SqlServer:
    63  		return "mssql"
    64  	case PostgreSql:
    65  		return "postgres"
    66  	case Sqlite3:
    67  		return "sqlite3"
    68  	default:
    69  		log.Printf("不支持的数据库类型\n")
    70  		return ""
    71  	}
    72  }
    73  
    74  func Insert(db *sql.DB, sql string, args ...any) (int64, error) {
    75  	var id int64
    76  	var err error
    77  	if strings.Contains(sql, " RETURNING ") {
    78  		row := db.QueryRow(sql, args...)
    79  		err = row.Scan(&id)
    80  	} else {
    81  		result, err1 := db.Exec(sql, args...)
    82  		err = err1
    83  		if err1 == nil {
    84  			id, _ = result.LastInsertId()
    85  		}
    86  	}
    87  	return id, err
    88  }
    89  
    90  func Update(db *sql.DB, sql string, args ...any) (int64, error) {
    91  	var n int64
    92  	var err error
    93  	result, err := db.Exec(sql, args...)
    94  	if err == nil {
    95  		n, _ = result.RowsAffected()
    96  	}
    97  	return n, err
    98  }
    99  
   100  func Delete(db *sql.DB, sql string, args ...any) (int64, error) {
   101  	return Update(db, sql, args...)
   102  }
   103  
   104  func Query(db *sql.DB, sql string, args ...any) ([]map[string]string, error) {
   105  	rows, err := db.Query(sql, args...)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	return fetchRows(rows, err)
   110  }
   111  
   112  func QueryRow(db *sql.DB, sql string, args ...any) (map[string]string, error) {
   113  	rows, err := Query(db, sql, args...)
   114  	if rows != nil && err == nil && len(rows) > 0 {
   115  		return rows[0], err
   116  	}
   117  	return nil, err
   118  }
   119  
   120  func QueryScalar(db *sql.DB, sql string, key string, args ...any) (string, error) {
   121  	rows, err := Query(db, sql, args...)
   122  	if rows != nil && err == nil && len(rows) > 0 {
   123  		row := rows[0]
   124  		if value, ok := row[key]; ok {
   125  			return value, err
   126  		}
   127  	}
   128  	return "", err
   129  }
   130  
   131  // stmt 缓存
   132  var stmtList = make(map[string]*sql.Stmt)
   133  
   134  func PrepareSql(db *sql.DB, name, sql string) (*sql.Stmt, error) {
   135  	stmt, bl := stmtList[name]
   136  	if !bl {
   137  		var err error
   138  		stmt, err = db.Prepare(sql)
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  		stmtList[name] = stmt
   143  	}
   144  	return stmt, nil
   145  }
   146  
   147  func PrepareQuery(db *sql.DB, name, sql string, args ...any) ([]map[string]string, error) {
   148  	stmt, err := PrepareSql(db, name, sql)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	rows, err1 := stmt.Query(args...)
   153  	return fetchRows(rows, err1)
   154  }
   155  
   156  func PrepareQueryRow(db *sql.DB, name, sql string, args ...any) (map[string]string, error) {
   157  	rows, err := PrepareQuery(db, name, sql, args...)
   158  	if rows != nil && err == nil && len(rows) > 0 {
   159  		return rows[0], err
   160  	}
   161  	return nil, err
   162  }
   163  
   164  func PrepareQueryScalar(db *sql.DB, name, sql string, args ...any) (string, error) {
   165  	stmt, err := PrepareSql(db, name, sql)
   166  	if err != nil {
   167  		return "", err
   168  	}
   169  	var value string
   170  	rows, err1 := stmt.Query(args...)
   171  	if err1 != nil {
   172  		return "", err1
   173  	}
   174  	if rows.Next() {
   175  		_ = rows.Scan(&value)
   176  	}
   177  	_ = rows.Close()
   178  	return value, err
   179  }
   180  
   181  func PrepareExec(db *sql.DB, name, sql string, args ...any) (int64, error) {
   182  	var n int64
   183  	stmt, err := PrepareSql(db, name, sql)
   184  	if err != nil {
   185  		return 0, err
   186  	}
   187  	if strings.Contains(sql, " RETURNING ") {
   188  		row, err1 := stmt.Query(args...)
   189  		if err1 != nil {
   190  			return n, err1
   191  		}
   192  		row.Next()
   193  		err = row.Scan(&n)
   194  		_ = row.Close()
   195  	} else {
   196  		result, err1 := stmt.Exec(args...)
   197  		if err1 != nil {
   198  			return n, err1
   199  		}
   200  		if "INSERT" == strings.ToUpper(sql[0:6]) {
   201  			// XXX: postgres不能用这个方法,处何处理待考虑
   202  			n, err = result.LastInsertId()
   203  		} else {
   204  			n, err = result.RowsAffected()
   205  		}
   206  	}
   207  	return n, err
   208  }
   209  
   210  type Rows struct {
   211  	*sql.Rows
   212  }
   213  
   214  type DBValue struct {
   215  	Value any
   216  }
   217  
   218  func (r *Rows) GetByName(fieldName string) *DBValue {
   219  	cs, _ := r.Columns()
   220  	index := isc.IndexOf(cs, fieldName)
   221  	if index == -1 {
   222  		return nil
   223  	}
   224  	count := len(cs)
   225  	vals := make([]any, count)
   226  	scans := make([]any, count)
   227  	for i := range scans {
   228  		scans[i] = &vals[i]
   229  	}
   230  	_ = r.Scan(scans...)
   231  	if *(scans[index].(*any)) == nil {
   232  		return nil
   233  	} else {
   234  		return &DBValue{Value: scans[index]}
   235  	}
   236  }
   237  
   238  func (r *Rows) GetByNameDef(fieldName string, def any) *DBValue {
   239  	v := r.GetByName(fieldName)
   240  	if v == nil {
   241  		i := def
   242  		return &DBValue{
   243  			Value: &i,
   244  		}
   245  	} else {
   246  		return v
   247  	}
   248  }
   249  
   250  func (r *Rows) GetByIndex(index int) *DBValue {
   251  	cs, _ := r.Columns()
   252  	count := len(cs)
   253  	if index < 0 || index > count-1 {
   254  		return nil
   255  	}
   256  	vals := make([]any, count)
   257  	scans := make([]any, count)
   258  	for i := range scans {
   259  		scans[i] = &vals[i]
   260  	}
   261  	_ = r.Scan(scans...)
   262  	if *(scans[index].(*any)) == nil {
   263  		return nil
   264  	} else {
   265  		return &DBValue{Value: scans[index]}
   266  	}
   267  }
   268  
   269  func (r *Rows) GetByIndexDef(index int, def any) *DBValue {
   270  	v := r.GetByIndex(index)
   271  	if v == nil {
   272  		i := def
   273  		return &DBValue{
   274  			Value: &i,
   275  		}
   276  	} else {
   277  		return v
   278  	}
   279  }
   280  
   281  func (v *DBValue) ToString() string {
   282  	return string((*(v.Value.(*any))).([]uint8))
   283  }
   284  
   285  func (v *DBValue) ToInt() int {
   286  	return int((*(v.Value.(*any))).(int64))
   287  }
   288  
   289  func (v *DBValue) ToInt64() int64 {
   290  	return (*(v.Value.(*any))).(int64)
   291  }
   292  
   293  func (v *DBValue) ToFloat() float32 {
   294  	return (*(v.Value.(*any))).(float32)
   295  }
   296  
   297  func (v *DBValue) ToDouble() float64 {
   298  	return float64((*(v.Value.(*any))).(float32))
   299  }
   300  
   301  func (v *DBValue) ToBoolean() bool {
   302  	return (*(v.Value.(*any))).([]uint8)[0] == 1
   303  }
   304  
   305  func (v *DBValue) ToBytes() []byte {
   306  	return (*(v.Value.(*any))).([]uint8)
   307  }
   308  
   309  func (v *DBValue) ToTime() time.Time {
   310  	return (*(v.Value.(*any))).(time.Time)
   311  }
   312  
   313  func DBBoolean(b bool) []uint8 {
   314  	if b {
   315  		return []uint8{1}
   316  	} else {
   317  		return []uint8{0}
   318  	}
   319  }
   320  
   321  func fetchRows(rows *sql.Rows, err error) ([]map[string]string, error) {
   322  	if rows == nil || err != nil {
   323  		return nil, err
   324  	}
   325  
   326  	fields, _ := rows.Columns()
   327  	for k, v := range fields {
   328  		fields[k] = camelCase(v)
   329  	}
   330  	columnsLength := len(fields)
   331  
   332  	values := make([]string, columnsLength)
   333  	args := make([]any, columnsLength)
   334  	for i := 0; i < columnsLength; i++ {
   335  		args[i] = &values[i]
   336  	}
   337  
   338  	index := 0
   339  	listLength := 100
   340  	lists := make([]map[string]string, listLength, listLength)
   341  	for rows.Next() {
   342  		if e := rows.Scan(args...); e == nil {
   343  			row := make(map[string]string, columnsLength)
   344  			for i, field := range fields {
   345  				row[field] = values[i]
   346  			}
   347  
   348  			if index < listLength {
   349  				lists[index] = row
   350  			} else {
   351  				lists = append(lists, row)
   352  			}
   353  			index++
   354  		}
   355  	}
   356  
   357  	_ = rows.Close()
   358  
   359  	return lists[0:index], nil
   360  }
   361  
   362  func camelCase(str string) string {
   363  	if strings.Contains(str, "_") {
   364  		items := strings.Split(str, "_")
   365  		arr := make([]string, len(items))
   366  		for k, v := range items {
   367  			if 0 == k {
   368  				arr[k] = v
   369  			} else {
   370  				arr[k] = strings.ToTitle(v)
   371  			}
   372  		}
   373  		str = strings.Join(arr, "")
   374  	}
   375  	return str
   376  }