github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/queryglue.go (about)

     1  package sqlitepool
     2  
     3  // This file contains bridging functions designed to let users of
     4  // database/sql move to sqlitepool without changing the semantics
     5  // of their code.
     6  //
     7  // Eventually users should piece-wise migrate to another interface.
     8  // (Or we should invest in this interface? Seems suboptimal.)
     9  
    10  import (
    11  	sqlpkg "database/sql"
    12  	"database/sql/driver"
    13  	"encoding"
    14  	"fmt"
    15  	"reflect"
    16  	"strings"
    17  	"time"
    18  
    19  	"github.com/tailscale/sqlite/sqliteh"
    20  )
    21  
    22  // Exec is like database/sql.Tx.Exec.
    23  // Only use this for one-off/rare queries.
    24  // For normal queries, see the Exec method on Tx.
    25  func Exec(db sqliteh.DB, sql string, args ...any) error {
    26  	stmt, _, err := db.Prepare(sql, 0)
    27  	if err != nil {
    28  		return err
    29  	}
    30  	if err := bindAll(db, stmt, args...); err != nil {
    31  		return fmt.Errorf("Exec: %w", err)
    32  	}
    33  	_, _, _, _, err = stmt.StepResult()
    34  	if err != nil {
    35  		err = fmt.Errorf("%w: %v", err, db.ErrMsg())
    36  	}
    37  	stmt.Finalize()
    38  	return err
    39  }
    40  
    41  // QueryRow is like database/sql.Tx.QueryRow.
    42  // Only use this for one-off/rare queries.
    43  // For normal queries, see the methods on Rx.
    44  func QueryRow(db sqliteh.DB, sql string, args ...any) *Row {
    45  	stmt, _, err := db.Prepare(sql, 0)
    46  	if err != nil {
    47  		return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, db.ErrMsg())}
    48  	}
    49  	if err := bindAll(db, stmt, args...); err != nil {
    50  		return &Row{err: fmt.Errorf("QueryRow: %w", err)}
    51  	}
    52  	row, err := stmt.Step(nil)
    53  	if err != nil {
    54  		msg := db.ErrMsg()
    55  		stmt.Finalize()
    56  		return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)}
    57  	}
    58  	if !row {
    59  		stmt.Finalize()
    60  		return &Row{err: sqlpkg.ErrNoRows}
    61  	}
    62  	return &Row{stmt: stmt, oneOff: true}
    63  }
    64  
    65  // Query is like database/sql.Tx.Query.
    66  // Only use this for one-off/rare queries.
    67  // For normal queries, see the methods on Rx.
    68  func Query(db sqliteh.DB, sql string, args ...any) (*Rows, error) {
    69  	stmt, _, err := db.Prepare(sql, 0)
    70  	if err != nil {
    71  		return nil, fmt.Errorf("Query: %w: %v", err, db.ErrMsg())
    72  	}
    73  	if err := bindAll(db, stmt, args...); err != nil {
    74  		return nil, err
    75  	}
    76  	return &Rows{stmt: stmt, oneOff: true}, nil
    77  }
    78  
    79  // Exec is like database/sql.Tx.Exec.
    80  func (tx *Tx) Exec(sql string, args ...any) error {
    81  	stmt := tx.Prepare(sql)
    82  	if err := bindAll(tx.conn.db, stmt, args...); err != nil {
    83  		return err
    84  	}
    85  	_, _, _, _, err := stmt.StepResult()
    86  	if err != nil {
    87  		return fmt.Errorf("%w: %v", err, tx.conn.db.ErrMsg())
    88  	}
    89  	return nil
    90  }
    91  
    92  func (tx *Tx) ExecRes(sql string, args ...any) (rowsAffected int64, err error) {
    93  	stmt := tx.Prepare(sql)
    94  	if err := bindAll(tx.conn.db, stmt, args...); err != nil {
    95  		return 0, err
    96  	}
    97  	_, _, rowsAffected, _, err = stmt.StepResult()
    98  	return rowsAffected, err
    99  }
   100  
   101  // QueryRow is like database/sql.Tx.QueryRow.
   102  func (rx *Rx) QueryRow(sql string, args ...any) *Row {
   103  	stmt := rx.Prepare(sql)
   104  	if err := bindAll(rx.conn.db, stmt, args...); err != nil {
   105  		return &Row{err: fmt.Errorf("QueryRow: %w", err)}
   106  	}
   107  	row, err := stmt.Step(nil)
   108  	if err != nil {
   109  		msg := rx.DB().ErrMsg()
   110  		stmt.ResetAndClear()
   111  		return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)}
   112  	}
   113  	if !row {
   114  		stmt.ResetAndClear()
   115  		return &Row{err: sqlpkg.ErrNoRows}
   116  	}
   117  	return &Row{stmt: stmt}
   118  }
   119  
   120  // Query is like database/sql.Tx.Query.
   121  func (rx *Rx) Query(sql string, args ...any) (*Rows, error) {
   122  	stmt := rx.Prepare(sql)
   123  	if err := bindAll(rx.conn.db, stmt, args...); err != nil {
   124  		return nil, fmt.Errorf("Query: %w", err)
   125  	}
   126  	return &Rows{stmt: stmt}, nil
   127  }
   128  
   129  // Rows is like database/sql.Tx.Rows.
   130  type Rows struct {
   131  	stmt   sqliteh.Stmt
   132  	err    error
   133  	oneOff bool
   134  }
   135  
   136  func (rs *Rows) Next() bool {
   137  	if rs.err != nil {
   138  		return false
   139  	}
   140  	row, err := rs.stmt.Step(nil)
   141  	if err != nil {
   142  		rs.err = fmt.Errorf("QueryRow.Next: %w: %v", err, rs.stmt.DBHandle().ErrMsg())
   143  		return false
   144  	}
   145  	if !row {
   146  		rs.stmt.ResetAndClear()
   147  	}
   148  	return row
   149  }
   150  
   151  func (rs *Rows) Err() error {
   152  	return rs.err
   153  }
   154  
   155  func (rs *Rows) Scan(dest ...any) error {
   156  	if rs.err != nil {
   157  		return rs.err
   158  	}
   159  	return scanAll(rs.stmt, dest...)
   160  }
   161  
   162  func (rs *Rows) Close() error {
   163  	if rs.stmt == nil {
   164  		return nil
   165  	}
   166  	_, err := rs.stmt.ResetAndClear()
   167  	msg := rs.stmt.DBHandle().ErrMsg()
   168  	var err2 error
   169  	if rs.oneOff {
   170  		err2 = rs.stmt.Finalize()
   171  	}
   172  	rs.stmt = nil
   173  	if err != nil {
   174  		return fmt.Errorf("Rows.ResetAndClear: %w: %v", err, msg)
   175  	}
   176  	if err2 != nil {
   177  		return fmt.Errorf("Rows.ResetAndClear: %w: %v", err2, rs.stmt.DBHandle().ErrMsg())
   178  	}
   179  	return nil
   180  }
   181  
   182  // Row is like database/sql.Tx.Row.
   183  type Row struct {
   184  	stmt   sqliteh.Stmt
   185  	err    error
   186  	oneOff bool
   187  }
   188  
   189  func (r *Row) Err() error {
   190  	return r.err
   191  }
   192  
   193  func (r *Row) Scan(dest ...any) error {
   194  	if r.err != nil {
   195  		return r.err
   196  	}
   197  	err := scanAll(r.stmt, dest...)
   198  	r.stmt.ResetAndClear()
   199  	if r.oneOff {
   200  		r.stmt.Finalize()
   201  	}
   202  	return err
   203  }
   204  
   205  type scanner interface {
   206  	Scan(value any) error
   207  }
   208  
   209  // scanAll mimics (some of) the sqlite driver's scanning logic, which is
   210  // split across the driver and the database/sql package.
   211  func scanAll(stmt sqliteh.Stmt, dest ...any) error {
   212  	for i := 0; i < len(dest); i++ {
   213  		if s, ok := dest[i].(scanner); ok {
   214  			// We have a handful of *sql.NullInt64 objects in
   215  			// our tree, so we implement minimal support for
   216  			// them here. TODO: remove some time.
   217  			var v any
   218  			switch stmt.ColumnType(i) {
   219  			case sqliteh.SQLITE_INTEGER:
   220  				v = stmt.ColumnInt64(i)
   221  			case sqliteh.SQLITE_FLOAT:
   222  				v = stmt.ColumnDouble(i)
   223  			case sqliteh.SQLITE_TEXT:
   224  				v = stmt.ColumnText(i)
   225  			case sqliteh.SQLITE_BLOB:
   226  				v = stmt.ColumnText(i)
   227  			case sqliteh.SQLITE_NULL:
   228  				v = nil
   229  			}
   230  			if err := s.Scan(v); err != nil {
   231  				return err
   232  			}
   233  			continue
   234  		}
   235  		v := reflect.ValueOf(dest[i])
   236  		if v.Elem().Kind() == reflect.Slice && v.Elem().Type().Elem().Kind() == reflect.Uint8 {
   237  			b := append([]byte(nil), stmt.ColumnBlob(i)...)
   238  			v.Elem().SetBytes(b)
   239  			continue
   240  		}
   241  		switch v.Elem().Kind() {
   242  		case reflect.Bool:
   243  			v.Elem().SetBool(stmt.ColumnInt64(i) != 0)
   244  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   245  			v.Elem().SetInt(stmt.ColumnInt64(i))
   246  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   247  			v.Elem().SetUint(uint64(stmt.ColumnInt64(i)))
   248  		case reflect.Float32, reflect.Float64:
   249  			v.Elem().SetFloat(stmt.ColumnDouble(i))
   250  		case reflect.String:
   251  			v.Elem().SetString(stmt.ColumnText(i))
   252  		default:
   253  			return fmt.Errorf("sqlitepool.scan:%d: cannot handle destination kind %v (%T)", i, v.Kind(), dest[i])
   254  		}
   255  	}
   256  	return nil
   257  }
   258  
   259  func bindAll(db sqliteh.DB, stmt sqliteh.Stmt, args ...any) error {
   260  	for i, arg := range args {
   261  		if err := bind(db, stmt, i+1, arg); err != nil {
   262  			stmt.ResetAndClear()
   263  			return fmt.Errorf("bind: %d, %q: %w", i, arg, err)
   264  		}
   265  	}
   266  	return nil
   267  }
   268  
   269  type driverValue interface {
   270  	Value() (driver.Value, error)
   271  }
   272  
   273  // bind, from the driver in sqlite.go.
   274  func bind(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) error {
   275  	// Start with obvious types, including time.Time before TextMarshaler.
   276  	found, err := bindBasic(db, s, ordinal, v)
   277  	if err != nil {
   278  		return err
   279  	} else if found {
   280  		return nil
   281  	}
   282  
   283  	if m, _ := v.(driverValue); m != nil {
   284  		// We have a few NullInt64s we need to handle.
   285  		// TODO: remove or rethink in the future.
   286  		var err error
   287  		v, err = m.Value()
   288  		if err != nil {
   289  			return fmt.Errorf("sqlitepool.bind:%d: bad driver.Value: %w", ordinal, err)
   290  		}
   291  		if v == nil {
   292  			_, err := bindBasic(db, s, ordinal, nil)
   293  			return err
   294  		}
   295  	}
   296  
   297  	if m, _ := v.(encoding.TextMarshaler); m != nil {
   298  		b, err := m.MarshalText()
   299  		if err != nil {
   300  			return fmt.Errorf("sqlitepool.bind:%d: cannot marshal %T: %w", ordinal, v, err)
   301  		}
   302  		_, err = bindBasic(db, s, ordinal, b)
   303  		return err
   304  	}
   305  
   306  	// Look for named basic types or other convertible types.
   307  	val := reflect.ValueOf(v)
   308  	if val.Kind() == reflect.Pointer {
   309  		if val.IsNil() {
   310  			_, err := bindBasic(db, s, ordinal, nil)
   311  			return err
   312  		}
   313  		val = val.Elem()
   314  	}
   315  	typ := reflect.TypeOf(v)
   316  	if typ.Kind() == reflect.Pointer {
   317  		typ = typ.Elem()
   318  	}
   319  	switch typ.Kind() {
   320  	case reflect.Bool:
   321  		b := int64(0)
   322  		if val.Bool() {
   323  			b = 1
   324  		}
   325  		_, err := bindBasic(db, s, ordinal, b)
   326  		return err
   327  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   328  		var i int64
   329  		if !val.IsZero() {
   330  			i = val.Int()
   331  		}
   332  		_, err := bindBasic(db, s, ordinal, i)
   333  		return err
   334  	case reflect.Uint, reflect.Uint64:
   335  		return fmt.Errorf("sqlitepool.bind:%d: sqlite does not support uint64 (try a string or TextMarshaler)", ordinal)
   336  	case reflect.Uint8, reflect.Uint16, reflect.Uint32:
   337  		_, err := bindBasic(db, s, ordinal, int64(val.Uint()))
   338  		return err
   339  	case reflect.Float32, reflect.Float64:
   340  		_, err := bindBasic(db, s, ordinal, val.Float())
   341  		return err
   342  	case reflect.String:
   343  		_, err := bindBasic(db, s, ordinal, val.String())
   344  		return err
   345  	}
   346  
   347  	return fmt.Errorf("sqlitepool.bind:%d: unknown value type %T (try a string or TextMarshaler)", ordinal, v)
   348  }
   349  
   350  // bindBasic, from the driver in sqlite.go.
   351  func bindBasic(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) (found bool, err error) {
   352  	defer func() {
   353  		if err != nil {
   354  			err = fmt.Errorf("sqlitepool.bind:%d:%T: %w: %v", ordinal, v, err, db.ErrMsg())
   355  		}
   356  	}()
   357  	switch v := v.(type) {
   358  	case nil:
   359  		return true, s.BindNull(ordinal)
   360  	case string:
   361  		return true, s.BindText64(ordinal, v)
   362  	case int:
   363  		return true, s.BindInt64(ordinal, int64(v))
   364  	case int64:
   365  		return true, s.BindInt64(ordinal, v)
   366  	case float64:
   367  		return true, s.BindDouble(ordinal, v)
   368  	case []byte:
   369  		if len(v) == 0 {
   370  			return true, s.BindZeroBlob64(ordinal, 0)
   371  		} else {
   372  			return true, s.BindBlob64(ordinal, v)
   373  		}
   374  	case time.Time:
   375  		// Shortest of:
   376  		//	YYYY-MM-DD HH:MM
   377  		// 	YYYY-MM-DD HH:MM:SS
   378  		//	YYYY-MM-DD HH:MM:SS.SSS
   379  		str := v.Format(timeFormat)
   380  		str = strings.TrimSuffix(str, "-0000")
   381  		str = strings.TrimSuffix(str, ".000")
   382  		str = strings.TrimSuffix(str, ":00")
   383  		return true, s.BindText64(ordinal, str)
   384  	default:
   385  		return false, nil
   386  	}
   387  }
   388  
   389  // timeFormat from the driver in sqlite.go.
   390  const timeFormat = "2006-01-02 15:04:05.000-0700"