github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/database/database.go (about) 1 package database 2 3 import ( 4 "database/sql" 5 "log" 6 "strings" 7 "time" 8 9 "github.com/isyscore/isc-gobase/isc" 10 ) 11 12 type DatabaseType int 13 14 const ( 15 MySQL DatabaseType = iota // import _ "github.com/go-sql-driver/mysql" 16 Oracle // import _ "github.com/mattn/go-oci8" 17 SqlServer // import _ "github.com/denisenkom/go-mssqldb" 18 PostgreSql // import _ "github.com/lib/pq" 19 Sqlite3 // import _ "github.com/mattn/go-sqlite3" 20 ) 21 22 const ( 23 //CONNECTION_STRING user:password@tcp(host:port)/databaseName 24 CONNECTION_STRING = "%s:%s@tcp(%s:%d)/%s" 25 ) 26 27 func Connect(dbType DatabaseType, connStr string) *sql.DB { 28 return CustomConnect(dbTypeToString(dbType), connStr) 29 } 30 31 func CustomConnect(dbType string, connStr string) *sql.DB { 32 // charset=utf8 33 // parseTime=true 34 innerParam := connStr 35 if strings.Contains(connStr, "?") { 36 // 已有参数 37 if !strings.Contains(connStr, "charset=utf8") { 38 innerParam += "&charset=utf8" 39 } 40 if !strings.Contains(connStr, "parseTime=true") { 41 innerParam += "&parseTime=true" 42 } 43 } else { 44 // 没有参数 45 innerParam += "?charset=utf8&parseTime=true" 46 } 47 48 db, err := sql.Open(dbType, innerParam) 49 if err != nil { 50 log.Printf("初始化数据库失败(%v)\n", err) 51 return nil 52 } 53 return db 54 } 55 56 func dbTypeToString(dbType DatabaseType) string { 57 switch dbType { 58 case MySQL: 59 return "mysql" 60 case Oracle: 61 return "oci8" 62 case SqlServer: 63 return "mssql" 64 case PostgreSql: 65 return "postgres" 66 case Sqlite3: 67 return "sqlite3" 68 default: 69 log.Printf("不支持的数据库类型\n") 70 return "" 71 } 72 } 73 74 func Insert(db *sql.DB, sql string, args ...any) (int64, error) { 75 var id int64 76 var err error 77 if strings.Contains(sql, " RETURNING ") { 78 row := db.QueryRow(sql, args...) 79 err = row.Scan(&id) 80 } else { 81 result, err1 := db.Exec(sql, args...) 82 err = err1 83 if err1 == nil { 84 id, _ = result.LastInsertId() 85 } 86 } 87 return id, err 88 } 89 90 func Update(db *sql.DB, sql string, args ...any) (int64, error) { 91 var n int64 92 var err error 93 result, err := db.Exec(sql, args...) 94 if err == nil { 95 n, _ = result.RowsAffected() 96 } 97 return n, err 98 } 99 100 func Delete(db *sql.DB, sql string, args ...any) (int64, error) { 101 return Update(db, sql, args...) 102 } 103 104 func Query(db *sql.DB, sql string, args ...any) ([]map[string]string, error) { 105 rows, err := db.Query(sql, args...) 106 if err != nil { 107 return nil, err 108 } 109 return fetchRows(rows, err) 110 } 111 112 func QueryRow(db *sql.DB, sql string, args ...any) (map[string]string, error) { 113 rows, err := Query(db, sql, args...) 114 if rows != nil && err == nil && len(rows) > 0 { 115 return rows[0], err 116 } 117 return nil, err 118 } 119 120 func QueryScalar(db *sql.DB, sql string, key string, args ...any) (string, error) { 121 rows, err := Query(db, sql, args...) 122 if rows != nil && err == nil && len(rows) > 0 { 123 row := rows[0] 124 if value, ok := row[key]; ok { 125 return value, err 126 } 127 } 128 return "", err 129 } 130 131 // stmt 缓存 132 var stmtList = make(map[string]*sql.Stmt) 133 134 func PrepareSql(db *sql.DB, name, sql string) (*sql.Stmt, error) { 135 stmt, bl := stmtList[name] 136 if !bl { 137 var err error 138 stmt, err = db.Prepare(sql) 139 if err != nil { 140 return nil, err 141 } 142 stmtList[name] = stmt 143 } 144 return stmt, nil 145 } 146 147 func PrepareQuery(db *sql.DB, name, sql string, args ...any) ([]map[string]string, error) { 148 stmt, err := PrepareSql(db, name, sql) 149 if err != nil { 150 return nil, err 151 } 152 rows, err1 := stmt.Query(args...) 153 return fetchRows(rows, err1) 154 } 155 156 func PrepareQueryRow(db *sql.DB, name, sql string, args ...any) (map[string]string, error) { 157 rows, err := PrepareQuery(db, name, sql, args...) 158 if rows != nil && err == nil && len(rows) > 0 { 159 return rows[0], err 160 } 161 return nil, err 162 } 163 164 func PrepareQueryScalar(db *sql.DB, name, sql string, args ...any) (string, error) { 165 stmt, err := PrepareSql(db, name, sql) 166 if err != nil { 167 return "", err 168 } 169 var value string 170 rows, err1 := stmt.Query(args...) 171 if err1 != nil { 172 return "", err1 173 } 174 if rows.Next() { 175 _ = rows.Scan(&value) 176 } 177 _ = rows.Close() 178 return value, err 179 } 180 181 func PrepareExec(db *sql.DB, name, sql string, args ...any) (int64, error) { 182 var n int64 183 stmt, err := PrepareSql(db, name, sql) 184 if err != nil { 185 return 0, err 186 } 187 if strings.Contains(sql, " RETURNING ") { 188 row, err1 := stmt.Query(args...) 189 if err1 != nil { 190 return n, err1 191 } 192 row.Next() 193 err = row.Scan(&n) 194 _ = row.Close() 195 } else { 196 result, err1 := stmt.Exec(args...) 197 if err1 != nil { 198 return n, err1 199 } 200 if "INSERT" == strings.ToUpper(sql[0:6]) { 201 // XXX: postgres不能用这个方法,处何处理待考虑 202 n, err = result.LastInsertId() 203 } else { 204 n, err = result.RowsAffected() 205 } 206 } 207 return n, err 208 } 209 210 type Rows struct { 211 *sql.Rows 212 } 213 214 type DBValue struct { 215 Value any 216 } 217 218 func (r *Rows) GetByName(fieldName string) *DBValue { 219 cs, _ := r.Columns() 220 index := isc.IndexOf(cs, fieldName) 221 if index == -1 { 222 return nil 223 } 224 count := len(cs) 225 vals := make([]any, count) 226 scans := make([]any, count) 227 for i := range scans { 228 scans[i] = &vals[i] 229 } 230 _ = r.Scan(scans...) 231 if *(scans[index].(*any)) == nil { 232 return nil 233 } else { 234 return &DBValue{Value: scans[index]} 235 } 236 } 237 238 func (r *Rows) GetByNameDef(fieldName string, def any) *DBValue { 239 v := r.GetByName(fieldName) 240 if v == nil { 241 i := def 242 return &DBValue{ 243 Value: &i, 244 } 245 } else { 246 return v 247 } 248 } 249 250 func (r *Rows) GetByIndex(index int) *DBValue { 251 cs, _ := r.Columns() 252 count := len(cs) 253 if index < 0 || index > count-1 { 254 return nil 255 } 256 vals := make([]any, count) 257 scans := make([]any, count) 258 for i := range scans { 259 scans[i] = &vals[i] 260 } 261 _ = r.Scan(scans...) 262 if *(scans[index].(*any)) == nil { 263 return nil 264 } else { 265 return &DBValue{Value: scans[index]} 266 } 267 } 268 269 func (r *Rows) GetByIndexDef(index int, def any) *DBValue { 270 v := r.GetByIndex(index) 271 if v == nil { 272 i := def 273 return &DBValue{ 274 Value: &i, 275 } 276 } else { 277 return v 278 } 279 } 280 281 func (v *DBValue) ToString() string { 282 return string((*(v.Value.(*any))).([]uint8)) 283 } 284 285 func (v *DBValue) ToInt() int { 286 return int((*(v.Value.(*any))).(int64)) 287 } 288 289 func (v *DBValue) ToInt64() int64 { 290 return (*(v.Value.(*any))).(int64) 291 } 292 293 func (v *DBValue) ToFloat() float32 { 294 return (*(v.Value.(*any))).(float32) 295 } 296 297 func (v *DBValue) ToDouble() float64 { 298 return float64((*(v.Value.(*any))).(float32)) 299 } 300 301 func (v *DBValue) ToBoolean() bool { 302 return (*(v.Value.(*any))).([]uint8)[0] == 1 303 } 304 305 func (v *DBValue) ToBytes() []byte { 306 return (*(v.Value.(*any))).([]uint8) 307 } 308 309 func (v *DBValue) ToTime() time.Time { 310 return (*(v.Value.(*any))).(time.Time) 311 } 312 313 func DBBoolean(b bool) []uint8 { 314 if b { 315 return []uint8{1} 316 } else { 317 return []uint8{0} 318 } 319 } 320 321 func fetchRows(rows *sql.Rows, err error) ([]map[string]string, error) { 322 if rows == nil || err != nil { 323 return nil, err 324 } 325 326 fields, _ := rows.Columns() 327 for k, v := range fields { 328 fields[k] = camelCase(v) 329 } 330 columnsLength := len(fields) 331 332 values := make([]string, columnsLength) 333 args := make([]any, columnsLength) 334 for i := 0; i < columnsLength; i++ { 335 args[i] = &values[i] 336 } 337 338 index := 0 339 listLength := 100 340 lists := make([]map[string]string, listLength, listLength) 341 for rows.Next() { 342 if e := rows.Scan(args...); e == nil { 343 row := make(map[string]string, columnsLength) 344 for i, field := range fields { 345 row[field] = values[i] 346 } 347 348 if index < listLength { 349 lists[index] = row 350 } else { 351 lists = append(lists, row) 352 } 353 index++ 354 } 355 } 356 357 _ = rows.Close() 358 359 return lists[0:index], nil 360 } 361 362 func camelCase(str string) string { 363 if strings.Contains(str, "_") { 364 items := strings.Split(str, "_") 365 arr := make([]string, len(items)) 366 for k, v := range items { 367 if 0 == k { 368 arr[k] = v 369 } else { 370 arr[k] = strings.ToTitle(v) 371 } 372 } 373 str = strings.Join(arr, "") 374 } 375 return str 376 }