vitess.io/vitess@v0.16.2/go/vt/external/golib/sqlutils/sqlutils.go (about)

     1  /*
     2     Copyright 2014 Outbrain Inc.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  /*
    18  	This file has been copied over from VTOrc package
    19  */
    20  
    21  package sqlutils
    22  
    23  import (
    24  	"database/sql"
    25  	"encoding/json"
    26  	"fmt"
    27  	"strconv"
    28  	"strings"
    29  	"sync"
    30  	"time"
    31  
    32  	"vitess.io/vitess/go/vt/log"
    33  )
    34  
// DateTimeFormat is the time layout (Go reference-time notation) that
// RowMap.GetTime uses to parse datetime column values: a date, a space, then
// a time of day with up to six optional fractional-second digits.
const DateTimeFormat = "2006-01-02 15:04:05.999999"

// RowMap represents one row in a result set. Its objective is to allow
// for easy, typed getters by column name.
type RowMap map[string]CellData
    40  
    41  // CellData is the result of a single (atomic) column in a single row
    42  type CellData sql.NullString
    43  
    44  func (this *CellData) MarshalJSON() ([]byte, error) {
    45  	if this.Valid {
    46  		return json.Marshal(this.String)
    47  	} else {
    48  		return json.Marshal(nil)
    49  	}
    50  }
    51  
    52  // UnmarshalJSON reds this object from JSON
    53  func (this *CellData) UnmarshalJSON(b []byte) error {
    54  	var s string
    55  	if err := json.Unmarshal(b, &s); err != nil {
    56  		return err
    57  	}
    58  	(*this).String = s
    59  	(*this).Valid = true
    60  
    61  	return nil
    62  }
    63  
    64  func (this *CellData) NullString() *sql.NullString {
    65  	return (*sql.NullString)(this)
    66  }
    67  
    68  // RowData is the result of a single row, in positioned array format
    69  type RowData []CellData
    70  
    71  // MarshalJSON will marshal this map as JSON
    72  func (this *RowData) MarshalJSON() ([]byte, error) {
    73  	cells := make([](*CellData), len(*this))
    74  	for i, val := range *this {
    75  		d := CellData(val)
    76  		cells[i] = &d
    77  	}
    78  	return json.Marshal(cells)
    79  }
    80  
    81  func (this *RowData) Args() []any {
    82  	result := make([]any, len(*this))
    83  	for i := range *this {
    84  		result[i] = (*(*this)[i].NullString())
    85  	}
    86  	return result
    87  }
    88  
// ResultData is an ordered row set of RowData
type ResultData []RowData

// NamedResultData pairs a result set with its column names: Columns[i] names
// the i-th cell of every row in Data.
type NamedResultData struct {
	Columns []string
	Data    ResultData
}

// EmptyResultData is the empty, non-nil ResultData that queryResultData
// returns when its query fails.
var EmptyResultData = ResultData{}
    97  
    98  func (this *RowMap) GetString(key string) string {
    99  	return (*this)[key].String
   100  }
   101  
   102  // GetStringD returns a string from the map, or a default value if the key does not exist
   103  func (this *RowMap) GetStringD(key string, def string) string {
   104  	if cell, ok := (*this)[key]; ok {
   105  		return cell.String
   106  	}
   107  	return def
   108  }
   109  
   110  func (this *RowMap) GetInt64(key string) int64 {
   111  	res, _ := strconv.ParseInt(this.GetString(key), 10, 0)
   112  	return res
   113  }
   114  
   115  func (this *RowMap) GetNullInt64(key string) sql.NullInt64 {
   116  	i, err := strconv.ParseInt(this.GetString(key), 10, 0)
   117  	if err == nil {
   118  		return sql.NullInt64{Int64: i, Valid: true}
   119  	} else {
   120  		return sql.NullInt64{Valid: false}
   121  	}
   122  }
   123  
   124  func (this *RowMap) GetInt(key string) int {
   125  	res, _ := strconv.Atoi(this.GetString(key))
   126  	return res
   127  }
   128  
   129  func (this *RowMap) GetIntD(key string, def int) int {
   130  	res, err := strconv.Atoi(this.GetString(key))
   131  	if err != nil {
   132  		return def
   133  	}
   134  	return res
   135  }
   136  
   137  func (this *RowMap) GetUint(key string) uint {
   138  	res, _ := strconv.ParseUint(this.GetString(key), 10, 0)
   139  	return uint(res)
   140  }
   141  
   142  func (this *RowMap) GetUintD(key string, def uint) uint {
   143  	res, err := strconv.Atoi(this.GetString(key))
   144  	if err != nil {
   145  		return def
   146  	}
   147  	return uint(res)
   148  }
   149  
   150  func (this *RowMap) GetUint64(key string) uint64 {
   151  	res, _ := strconv.ParseUint(this.GetString(key), 10, 0)
   152  	return res
   153  }
   154  
   155  func (this *RowMap) GetUint64D(key string, def uint64) uint64 {
   156  	res, err := strconv.ParseUint(this.GetString(key), 10, 0)
   157  	if err != nil {
   158  		return def
   159  	}
   160  	return uint64(res)
   161  }
   162  
   163  func (this *RowMap) GetBool(key string) bool {
   164  	return this.GetInt(key) != 0
   165  }
   166  
   167  func (this *RowMap) GetTime(key string) time.Time {
   168  	if t, err := time.Parse(DateTimeFormat, this.GetString(key)); err == nil {
   169  		return t
   170  	}
   171  	return time.Time{}
   172  }
   173  
   174  // knownDBs is a DB cache by uri
   175  var knownDBs map[string]*sql.DB = make(map[string]*sql.DB)
   176  var knownDBsMutex = &sync.Mutex{}
   177  
   178  // GetDB returns a DB instance based on uri.
   179  // bool result indicates whether the DB was returned from cache; err
   180  func GetGenericDB(driverName, dataSourceName string) (*sql.DB, bool, error) {
   181  	knownDBsMutex.Lock()
   182  	defer func() {
   183  		knownDBsMutex.Unlock()
   184  	}()
   185  
   186  	var exists bool
   187  	if _, exists = knownDBs[dataSourceName]; !exists {
   188  		if db, err := sql.Open(driverName, dataSourceName); err == nil {
   189  			knownDBs[dataSourceName] = db
   190  		} else {
   191  			return db, exists, err
   192  		}
   193  	}
   194  	return knownDBs[dataSourceName], exists, nil
   195  }
   196  
   197  // GetDB returns a MySQL DB instance based on uri.
   198  // bool result indicates whether the DB was returned from cache; err
   199  func GetDB(mysql_uri string) (*sql.DB, bool, error) {
   200  	return GetGenericDB("mysql", mysql_uri)
   201  }
   202  
   203  // GetSQLiteDB returns a SQLite DB instance based on DB file name.
   204  // bool result indicates whether the DB was returned from cache; err
   205  func GetSQLiteDB(dbFile string) (*sql.DB, bool, error) {
   206  	return GetGenericDB("sqlite", dbFile)
   207  }
   208  
   209  // RowToArray is a convenience function, typically not called directly, which maps a
   210  // single read database row into a NullString
   211  func RowToArray(rows *sql.Rows, columns []string) ([]CellData, error) {
   212  	buff := make([]any, len(columns))
   213  	data := make([]CellData, len(columns))
   214  	for i := range buff {
   215  		buff[i] = data[i].NullString()
   216  	}
   217  	err := rows.Scan(buff...)
   218  	return data, err
   219  }
   220  
   221  // ScanRowsToArrays is a convenience function, typically not called directly, which maps rows
   222  // already read from the databse into arrays of NullString
   223  func ScanRowsToArrays(rows *sql.Rows, on_row func([]CellData) error) error {
   224  	columns, _ := rows.Columns()
   225  	for rows.Next() {
   226  		arr, err := RowToArray(rows, columns)
   227  		if err != nil {
   228  			return err
   229  		}
   230  		err = on_row(arr)
   231  		if err != nil {
   232  			return err
   233  		}
   234  	}
   235  	return nil
   236  }
   237  
   238  func rowToMap(row []CellData, columns []string) map[string]CellData {
   239  	m := make(map[string]CellData)
   240  	for k, data_col := range row {
   241  		m[columns[k]] = data_col
   242  	}
   243  	return m
   244  }
   245  
   246  // ScanRowsToMaps is a convenience function, typically not called directly, which maps rows
   247  // already read from the databse into RowMap entries.
   248  func ScanRowsToMaps(rows *sql.Rows, on_row func(RowMap) error) error {
   249  	columns, _ := rows.Columns()
   250  	err := ScanRowsToArrays(rows, func(arr []CellData) error {
   251  		m := rowToMap(arr, columns)
   252  		err := on_row(m)
   253  		if err != nil {
   254  			return err
   255  		}
   256  		return nil
   257  	})
   258  	return err
   259  }
   260  
   261  // QueryRowsMap is a convenience function allowing querying a result set while poviding a callback
   262  // function activated per read row.
   263  func QueryRowsMap(db *sql.DB, query string, on_row func(RowMap) error, args ...any) (err error) {
   264  	defer func() {
   265  		if derr := recover(); derr != nil {
   266  			err = fmt.Errorf("QueryRowsMap unexpected error: %+v", derr)
   267  		}
   268  	}()
   269  
   270  	var rows *sql.Rows
   271  	rows, err = db.Query(query, args...)
   272  	if rows != nil {
   273  		defer rows.Close()
   274  	}
   275  	if err != nil && err != sql.ErrNoRows {
   276  		log.Error(err)
   277  		return err
   278  	}
   279  	err = ScanRowsToMaps(rows, on_row)
   280  	return
   281  }
   282  
   283  // queryResultData returns a raw array of rows for a given query, optionally reading and returning column names
   284  func queryResultData(db *sql.DB, query string, retrieveColumns bool, args ...any) (resultData ResultData, columns []string, err error) {
   285  	defer func() {
   286  		if derr := recover(); derr != nil {
   287  			err = fmt.Errorf("QueryRowsMap unexpected error: %+v", derr)
   288  		}
   289  	}()
   290  
   291  	var rows *sql.Rows
   292  	rows, err = db.Query(query, args...)
   293  	if err != nil && err != sql.ErrNoRows {
   294  		log.Error(err)
   295  		return EmptyResultData, columns, err
   296  	}
   297  	defer rows.Close()
   298  
   299  	if retrieveColumns {
   300  		// Don't pay if you don't want to
   301  		columns, _ = rows.Columns()
   302  	}
   303  	resultData = ResultData{}
   304  	err = ScanRowsToArrays(rows, func(rowData []CellData) error {
   305  		resultData = append(resultData, rowData)
   306  		return nil
   307  	})
   308  	return resultData, columns, err
   309  }
   310  
   311  // QueryResultData returns a raw array of rows
   312  func QueryResultData(db *sql.DB, query string, args ...any) (ResultData, error) {
   313  	resultData, _, err := queryResultData(db, query, false, args...)
   314  	return resultData, err
   315  }
   316  
   317  // QueryResultDataNamed returns a raw array of rows, with column names
   318  func QueryNamedResultData(db *sql.DB, query string, args ...any) (NamedResultData, error) {
   319  	resultData, columns, err := queryResultData(db, query, true, args...)
   320  	return NamedResultData{Columns: columns, Data: resultData}, err
   321  }
   322  
   323  // QueryRowsMapBuffered reads data from the database into a buffer, and only then applies the given function per row.
   324  // This allows the application to take its time with processing the data, albeit consuming as much memory as required by
   325  // the result set.
   326  func QueryRowsMapBuffered(db *sql.DB, query string, on_row func(RowMap) error, args ...any) error {
   327  	resultData, columns, err := queryResultData(db, query, true, args...)
   328  	if err != nil {
   329  		// Already logged
   330  		return err
   331  	}
   332  	for _, row := range resultData {
   333  		err = on_row(rowToMap(row, columns))
   334  		if err != nil {
   335  			return err
   336  		}
   337  	}
   338  	return nil
   339  }
   340  
   341  // ExecNoPrepare executes given query using given args on given DB, without using prepared statements.
   342  func ExecNoPrepare(db *sql.DB, query string, args ...any) (res sql.Result, err error) {
   343  	defer func() {
   344  		if derr := recover(); derr != nil {
   345  			err = fmt.Errorf("ExecNoPrepare unexpected error: %+v", derr)
   346  		}
   347  	}()
   348  
   349  	res, err = db.Exec(query, args...)
   350  	if err != nil {
   351  		log.Error(err)
   352  	}
   353  	return res, err
   354  }
   355  
   356  // ExecQuery executes given query using given args on given DB. It will safele prepare, execute and close
   357  // the statement.
   358  func execInternal(silent bool, db *sql.DB, query string, args ...any) (res sql.Result, err error) {
   359  	defer func() {
   360  		if derr := recover(); derr != nil {
   361  			err = fmt.Errorf("execInternal unexpected error: %+v", derr)
   362  		}
   363  	}()
   364  	var stmt *sql.Stmt
   365  	stmt, err = db.Prepare(query)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  	defer stmt.Close()
   370  	res, err = stmt.Exec(args...)
   371  	if err != nil && !silent {
   372  		log.Error(err)
   373  	}
   374  	return res, err
   375  }
   376  
   377  // Exec executes given query using given args on given DB. It will safele prepare, execute and close
   378  // the statement.
   379  func Exec(db *sql.DB, query string, args ...any) (sql.Result, error) {
   380  	return execInternal(false, db, query, args...)
   381  }
   382  
   383  // ExecSilently acts like Exec but does not report any error
   384  func ExecSilently(db *sql.DB, query string, args ...any) (sql.Result, error) {
   385  	return execInternal(true, db, query, args...)
   386  }
   387  
   388  func InClauseStringValues(terms []string) string {
   389  	quoted := []string{}
   390  	for _, s := range terms {
   391  		quoted = append(quoted, fmt.Sprintf("'%s'", strings.Replace(s, ",", "''", -1)))
   392  	}
   393  	return strings.Join(quoted, ", ")
   394  }
   395  
   396  // Convert variable length arguments into arguments array
   397  func Args(args ...any) []any {
   398  	return args
   399  }
   400  
   401  func NilIfZero(i int64) any {
   402  	if i == 0 {
   403  		return nil
   404  	}
   405  	return i
   406  }
   407  
   408  func ScanTable(db *sql.DB, tableName string) (NamedResultData, error) {
   409  	query := fmt.Sprintf("select * from %s", tableName)
   410  	return QueryNamedResultData(db, query)
   411  }
   412  
   413  func WriteTable(db *sql.DB, tableName string, data NamedResultData) (err error) {
   414  	if len(data.Data) == 0 {
   415  		return nil
   416  	}
   417  	if len(data.Columns) == 0 {
   418  		return nil
   419  	}
   420  	placeholders := make([]string, len(data.Columns))
   421  	for i := range placeholders {
   422  		placeholders[i] = "?"
   423  	}
   424  	query := fmt.Sprintf(
   425  		`replace into %s (%s) values (%s)`,
   426  		tableName,
   427  		strings.Join(data.Columns, ","),
   428  		strings.Join(placeholders, ","),
   429  	)
   430  	for _, rowData := range data.Data {
   431  		if _, execErr := db.Exec(query, rowData.Args()...); execErr != nil {
   432  			err = execErr
   433  		}
   434  	}
   435  	return err
   436  }