vitess.io/vitess@v0.16.2/go/vt/external/golib/sqlutils/sqlutils.go (about) 1 /* 2 Copyright 2014 Outbrain Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 /* 18 This file has been copied over from VTOrc package 19 */ 20 21 package sqlutils 22 23 import ( 24 "database/sql" 25 "encoding/json" 26 "fmt" 27 "strconv" 28 "strings" 29 "sync" 30 "time" 31 32 "vitess.io/vitess/go/vt/log" 33 ) 34 35 const DateTimeFormat = "2006-01-02 15:04:05.999999" 36 37 // RowMap represents one row in a result set. Its objective is to allow 38 // for easy, typed getters by column name. 39 type RowMap map[string]CellData 40 41 // CellData is the result of a single (atomic) column in a single row 42 type CellData sql.NullString 43 44 func (this *CellData) MarshalJSON() ([]byte, error) { 45 if this.Valid { 46 return json.Marshal(this.String) 47 } else { 48 return json.Marshal(nil) 49 } 50 } 51 52 // UnmarshalJSON reds this object from JSON 53 func (this *CellData) UnmarshalJSON(b []byte) error { 54 var s string 55 if err := json.Unmarshal(b, &s); err != nil { 56 return err 57 } 58 (*this).String = s 59 (*this).Valid = true 60 61 return nil 62 } 63 64 func (this *CellData) NullString() *sql.NullString { 65 return (*sql.NullString)(this) 66 } 67 68 // RowData is the result of a single row, in positioned array format 69 type RowData []CellData 70 71 // MarshalJSON will marshal this map as JSON 72 func (this *RowData) MarshalJSON() ([]byte, error) { 73 cells := make([](*CellData), len(*this)) 74 for i, val := range *this { 75 d := CellData(val) 76 cells[i] = &d 77 } 78 return json.Marshal(cells) 79 } 80 81 func (this *RowData) Args() []any { 82 result := make([]any, len(*this)) 83 for i := range *this { 84 result[i] = (*(*this)[i].NullString()) 85 } 86 return result 87 } 88 89 // ResultData is an ordered row set of RowData 90 type ResultData []RowData 91 type NamedResultData struct { 92 Columns []string 93 Data ResultData 94 } 95 96 var EmptyResultData = ResultData{} 97 98 func (this *RowMap) GetString(key string) string { 99 return (*this)[key].String 100 } 101 102 // GetStringD returns a string from the map, or a default value if the key does not exist 103 func (this *RowMap) GetStringD(key string, def string) string { 104 if cell, ok := (*this)[key]; ok { 105 return cell.String 106 } 107 return def 108 } 109 110 func (this *RowMap) GetInt64(key string) int64 { 111 res, _ := strconv.ParseInt(this.GetString(key), 10, 0) 112 return res 113 } 114 115 func (this *RowMap) GetNullInt64(key string) sql.NullInt64 { 116 i, err := strconv.ParseInt(this.GetString(key), 10, 0) 117 if err == nil { 118 return sql.NullInt64{Int64: i, Valid: true} 119 } else { 120 return sql.NullInt64{Valid: false} 121 } 122 } 123 124 func (this *RowMap) GetInt(key string) int { 125 res, _ := strconv.Atoi(this.GetString(key)) 126 return res 127 } 128 129 func (this *RowMap) GetIntD(key string, def int) int { 130 res, err := strconv.Atoi(this.GetString(key)) 131 if err != nil { 132 return def 133 } 134 return res 135 } 136 137 func (this *RowMap) GetUint(key string) uint { 138 res, _ := strconv.ParseUint(this.GetString(key), 10, 0) 139 return uint(res) 140 } 141 142 func (this *RowMap) GetUintD(key string, def uint) uint { 143 res, err := strconv.Atoi(this.GetString(key)) 144 if err != nil { 145 return def 146 } 147 return uint(res) 148 } 149 150 func (this *RowMap) GetUint64(key string) uint64 { 151 res, _ := strconv.ParseUint(this.GetString(key), 10, 0) 152 return res 153 } 154 155 func (this *RowMap) GetUint64D(key string, def uint64) uint64 { 156 res, err := strconv.ParseUint(this.GetString(key), 10, 0) 157 if err != nil { 158 return def 159 } 160 return uint64(res) 161 } 162 163 func (this *RowMap) GetBool(key string) bool { 164 return this.GetInt(key) != 0 165 } 166 167 func (this *RowMap) GetTime(key string) time.Time { 168 if t, err := time.Parse(DateTimeFormat, this.GetString(key)); err == nil { 169 return t 170 } 171 return time.Time{} 172 } 173 174 // knownDBs is a DB cache by uri 175 var knownDBs map[string]*sql.DB = make(map[string]*sql.DB) 176 var knownDBsMutex = &sync.Mutex{} 177 178 // GetDB returns a DB instance based on uri. 179 // bool result indicates whether the DB was returned from cache; err 180 func GetGenericDB(driverName, dataSourceName string) (*sql.DB, bool, error) { 181 knownDBsMutex.Lock() 182 defer func() { 183 knownDBsMutex.Unlock() 184 }() 185 186 var exists bool 187 if _, exists = knownDBs[dataSourceName]; !exists { 188 if db, err := sql.Open(driverName, dataSourceName); err == nil { 189 knownDBs[dataSourceName] = db 190 } else { 191 return db, exists, err 192 } 193 } 194 return knownDBs[dataSourceName], exists, nil 195 } 196 197 // GetDB returns a MySQL DB instance based on uri. 198 // bool result indicates whether the DB was returned from cache; err 199 func GetDB(mysql_uri string) (*sql.DB, bool, error) { 200 return GetGenericDB("mysql", mysql_uri) 201 } 202 203 // GetSQLiteDB returns a SQLite DB instance based on DB file name. 204 // bool result indicates whether the DB was returned from cache; err 205 func GetSQLiteDB(dbFile string) (*sql.DB, bool, error) { 206 return GetGenericDB("sqlite", dbFile) 207 } 208 209 // RowToArray is a convenience function, typically not called directly, which maps a 210 // single read database row into a NullString 211 func RowToArray(rows *sql.Rows, columns []string) ([]CellData, error) { 212 buff := make([]any, len(columns)) 213 data := make([]CellData, len(columns)) 214 for i := range buff { 215 buff[i] = data[i].NullString() 216 } 217 err := rows.Scan(buff...) 218 return data, err 219 } 220 221 // ScanRowsToArrays is a convenience function, typically not called directly, which maps rows 222 // already read from the databse into arrays of NullString 223 func ScanRowsToArrays(rows *sql.Rows, on_row func([]CellData) error) error { 224 columns, _ := rows.Columns() 225 for rows.Next() { 226 arr, err := RowToArray(rows, columns) 227 if err != nil { 228 return err 229 } 230 err = on_row(arr) 231 if err != nil { 232 return err 233 } 234 } 235 return nil 236 } 237 238 func rowToMap(row []CellData, columns []string) map[string]CellData { 239 m := make(map[string]CellData) 240 for k, data_col := range row { 241 m[columns[k]] = data_col 242 } 243 return m 244 } 245 246 // ScanRowsToMaps is a convenience function, typically not called directly, which maps rows 247 // already read from the databse into RowMap entries. 248 func ScanRowsToMaps(rows *sql.Rows, on_row func(RowMap) error) error { 249 columns, _ := rows.Columns() 250 err := ScanRowsToArrays(rows, func(arr []CellData) error { 251 m := rowToMap(arr, columns) 252 err := on_row(m) 253 if err != nil { 254 return err 255 } 256 return nil 257 }) 258 return err 259 } 260 261 // QueryRowsMap is a convenience function allowing querying a result set while poviding a callback 262 // function activated per read row. 263 func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...any) (err error) { 264 defer func() { 265 if derr := recover(); derr != nil { 266 err = fmt.Errorf("QueryRowsMap unexpected error: %+v", derr) 267 } 268 }() 269 270 var rows *sql.Rows 271 rows, err = db.Query(query, args...) 272 if rows != nil { 273 defer rows.Close() 274 } 275 if err != nil && err != sql.ErrNoRows { 276 log.Error(err) 277 return err 278 } 279 err = ScanRowsToMaps(rows, on_row) 280 return 281 } 282 283 // queryResultData returns a raw array of rows for a given query, optionally reading and returning column names 284 func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...any) (resultData ResultData, columns []string, err error) { 285 defer func() { 286 if derr := recover(); derr != nil { 287 err = fmt.Errorf("QueryRowsMap unexpected error: %+v", derr) 288 } 289 }() 290 291 var rows *sql.Rows 292 rows, err = db.Query(query, args...) 293 if err != nil && err != sql.ErrNoRows { 294 log.Error(err) 295 return EmptyResultData, columns, err 296 } 297 defer rows.Close() 298 299 if retrieveColumns { 300 // Don't pay if you don't want to 301 columns, _ = rows.Columns() 302 } 303 resultData = ResultData{} 304 err = ScanRowsToArrays(rows, func(rowData []CellData) error { 305 resultData = append(resultData, rowData) 306 return nil 307 }) 308 return resultData, columns, err 309 } 310 311 // QueryResultData returns a raw array of rows 312 func QueryResultData(db *sql.DB, query string, args ...any) (ResultData, error) { 313 resultData, _, err := queryResultData(db, query, false, args...) 314 return resultData, err 315 } 316 317 // QueryResultDataNamed returns a raw array of rows, with column names 318 func QueryNamedResultData(db *sql.DB, query string, args ...any) (NamedResultData, error) { 319 resultData, columns, err := queryResultData(db, query, true, args...) 320 return NamedResultData{Columns: columns, Data: resultData}, err 321 } 322 323 // QueryRowsMapBuffered reads data from the database into a buffer, and only then applies the given function per row. 324 // This allows the application to take its time with processing the data, albeit consuming as much memory as required by 325 // the result set. 326 func QueryRowsMapBuffered(db *sql.DB, query string, on_row func(RowMap) error, args ...any) error { 327 resultData, columns, err := queryResultData(db, query, true, args...) 328 if err != nil { 329 // Already logged 330 return err 331 } 332 for _, row := range resultData { 333 err = on_row(rowToMap(row, columns)) 334 if err != nil { 335 return err 336 } 337 } 338 return nil 339 } 340 341 // ExecNoPrepare executes given query using given args on given DB, without using prepared statements. 342 func ExecNoPrepare(db *sql.DB, query string, args ...any) (res sql.Result, err error) { 343 defer func() { 344 if derr := recover(); derr != nil { 345 err = fmt.Errorf("ExecNoPrepare unexpected error: %+v", derr) 346 } 347 }() 348 349 res, err = db.Exec(query, args...) 350 if err != nil { 351 log.Error(err) 352 } 353 return res, err 354 } 355 356 // ExecQuery executes given query using given args on given DB. It will safele prepare, execute and close 357 // the statement. 358 func execInternal(silent bool, db *sql.DB, query string, args ...any) (res sql.Result, err error) { 359 defer func() { 360 if derr := recover(); derr != nil { 361 err = fmt.Errorf("execInternal unexpected error: %+v", derr) 362 } 363 }() 364 var stmt *sql.Stmt 365 stmt, err = db.Prepare(query) 366 if err != nil { 367 return nil, err 368 } 369 defer stmt.Close() 370 res, err = stmt.Exec(args...) 371 if err != nil && !silent { 372 log.Error(err) 373 } 374 return res, err 375 } 376 377 // Exec executes given query using given args on given DB. It will safele prepare, execute and close 378 // the statement. 379 func Exec(db *sql.DB, query string, args ...any) (sql.Result, error) { 380 return execInternal(false, db, query, args...) 381 } 382 383 // ExecSilently acts like Exec but does not report any error 384 func ExecSilently(db *sql.DB, query string, args ...any) (sql.Result, error) { 385 return execInternal(true, db, query, args...) 386 } 387 388 func InClauseStringValues(terms []string) string { 389 quoted := []string{} 390 for _, s := range terms { 391 quoted = append(quoted, fmt.Sprintf("'%s'", strings.Replace(s, ",", "''", -1))) 392 } 393 return strings.Join(quoted, ", ") 394 } 395 396 // Convert variable length arguments into arguments array 397 func Args(args ...any) []any { 398 return args 399 } 400 401 func NilIfZero(i int64) any { 402 if i == 0 { 403 return nil 404 } 405 return i 406 } 407 408 func ScanTable(db *sql.DB, tableName string) (NamedResultData, error) { 409 query := fmt.Sprintf("select * from %s", tableName) 410 return QueryNamedResultData(db, query) 411 } 412 413 func WriteTable(db *sql.DB, tableName string, data NamedResultData) (err error) { 414 if len(data.Data) == 0 { 415 return nil 416 } 417 if len(data.Columns) == 0 { 418 return nil 419 } 420 placeholders := make([]string, len(data.Columns)) 421 for i := range placeholders { 422 placeholders[i] = "?" 423 } 424 query := fmt.Sprintf( 425 `replace into %s (%s) values (%s)`, 426 tableName, 427 strings.Join(data.Columns, ","), 428 strings.Join(placeholders, ","), 429 ) 430 for _, rowData := range data.Data { 431 if _, execErr := db.Exec(query, rowData.Args()...); execErr != nil { 432 err = execErr 433 } 434 } 435 return err 436 }