github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlite.go (about)

     1  // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package sqlite implements a database/sql driver for SQLite3.
     6  //
// This driver requires that a file: URI always be used to open a database.
     8  // For details see https://sqlite.org/c3ref/open.html#urifilenames.
     9  //
    10  // # Initializing connections or tracing
    11  //
    12  // If you want to do initial configuration of a connection, or enable
    13  // tracing, use the Connector function:
    14  //
    15  //	connInitFunc := func(ctx context.Context, conn driver.ConnPrepareContext) error {
    16  //		return sqlite.ExecScript(conn.(sqlite.SQLConn), "PRAGMA journal_mode=WAL;")
    17  //	}
//	db := sql.OpenDB(sqlite.Connector(sqliteURI, connInitFunc, nil))
    19  //
    20  // # Memory Mode
    21  //
    22  // In-memory databases are popular for tests.
    23  // Use the "memdb" VFS (*not* the legacy in-memory modes) to be compatible
    24  // with the database/sql connection pool:
    25  //
    26  //	file:/dbname?vfs=memdb
    27  //
    28  // Use a different dbname for each memory database opened.
    29  //
    30  // # Binding Types
    31  //
    32  // SQLite is flexible about type conversions, and so is this driver.
    33  // Almost all "basic" Go types (int, float64, string) are accepted and
    34  // directly mapped into SQLite, even if they are named Go types.
    35  // The time.Time type is also accepted (described below).
    36  // Values that implement encoding.TextMarshaler or json.Marshaler are
    37  // stored in SQLite in their marshaled form.
    38  //
    39  // # Binding Time
    40  //
    41  // While SQLite3 has no strict time datatype, it does have a series of built-in
    42  // functions that operate on timestamps that expect columns to be in one of many
    43  // formats: https://sqlite.org/lang_datefunc.html
    44  //
    45  // When encoding a time.Time into one of SQLite's preferred formats, we use the
    46  // shortest timestamp format that can accurately represent the time.Time.
    47  // The supported formats are:
    48  //
//  1. YYYY-MM-DD HH:MM
//  2. YYYY-MM-DD HH:MM:SS
//  3. YYYY-MM-DD HH:MM:SS.SSS
    52  //
// If the time.Time is not UTC (strongly consider storing times in UTC!),
// its zone offset is appended to the above formats as "[+-]HHMM". Note that
// there is no colon: SQLite's own date functions document the "[+-]HH:MM" form.
    55  //
    56  // It is common in SQLite to store "Unix time", seconds-since-epoch in an
    57  // INTEGER column. This is understood by the date and time functions documented
    58  // in the link above. If you want to do that, pass the result of time.Time.Unix
    59  // to the driver.
    60  //
    61  // # Reading Time
    62  //
    63  // In general, time is hard to extract from SQLite as a time.Time.
// If a column is defined as DATE or DATETIME, then text data is parsed
// using TimeFormat (or one of the shorter forms listed above) and returned
// as a time.Time. Integer data is parsed as seconds since epoch and
// returned as a time.Time.
    67  package sqlite
    68  
    69  import (
    70  	"context"
    71  	"database/sql"
    72  	"database/sql/driver"
    73  	"encoding"
    74  	"errors"
    75  	"expvar"
    76  	"fmt"
    77  	"io"
    78  	"reflect"
    79  	"strings"
    80  	"sync/atomic"
    81  	"time"
    82  
    83  	"github.com/tailscale/sqlite/sqliteh"
    84  )
    85  
    86  var Open sqliteh.OpenFunc = func(string, sqliteh.OpenFlags, string) (sqliteh.DB, error) {
    87  	return nil, fmt.Errorf("cgosqlite.Open is missing")
    88  }
    89  
    90  // ConnInitFunc is a function called by the driver on new connections.
    91  //
    92  // The conn can be used to execute queries, and implements SQLConn.
    93  // Any error return closes the conn and passes the error to database/sql.
    94  type ConnInitFunc func(ctx context.Context, conn driver.ConnPrepareContext) error
    95  
    96  // TimeFormat is the string format this driver uses to store
    97  // microsecond-precision time in SQLite in text format.
    98  const TimeFormat = "2006-01-02 15:04:05.000-0700"
    99  
   100  func init() {
   101  	sql.Register("sqlite3", drv{})
   102  }
   103  
   104  var maxConnID atomic.Int32
   105  
   106  // UsesAfterClose is a metric that is incremented every time an operation is
   107  // attempted on a connection after Close has already been called. The keys are
   108  // internal identifiers for the code path that incremented a counter.
   109  var UsesAfterClose expvar.Map
   110  
   111  // ErrClosed is returned when an operation is attempted on a connection after
   112  // Close has already been called.
   113  var ErrClosed = errors.New("sqlite3: already closed")
   114  
   115  type drv struct{}
   116  
   117  func (drv) Open(name string) (driver.Conn, error) { panic("deprecated, unused") }
   118  func (drv) OpenConnector(name string) (driver.Connector, error) {
   119  	return &connector{name: name}, nil
   120  }
   121  
   122  func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer) driver.Connector {
   123  	return &connector{
   124  		name:         sqliteURI,
   125  		tracer:       tracer,
   126  		connInitFunc: connInitFunc,
   127  	}
   128  }
   129  
   130  type connector struct {
   131  	name         string
   132  	tracer       sqliteh.Tracer
   133  	connInitFunc ConnInitFunc
   134  }
   135  
   136  func (p *connector) Driver() driver.Driver { return drv{} }
   137  func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
   138  	db, err := Open(p.name, sqliteh.OpenFlagsDefault, "")
   139  	if err != nil {
   140  		if ec, ok := err.(sqliteh.ErrCode); ok {
   141  			e := &Error{
   142  				Code: sqliteh.Code(ec),
   143  				Loc:  "Open",
   144  			}
   145  			if db != nil {
   146  				e.Msg = db.ErrMsg()
   147  			}
   148  			err = e
   149  		}
   150  		if db != nil {
   151  			db.Close()
   152  		}
   153  		return nil, err
   154  	}
   155  
   156  	c := &conn{
   157  		db:     db,
   158  		tracer: p.tracer,
   159  		id:     sqliteh.TraceConnID(maxConnID.Add(1)),
   160  	}
   161  	if p.connInitFunc != nil {
   162  		if err := p.connInitFunc(ctx, c); err != nil {
   163  			db.Close()
   164  			return nil, fmt.Errorf("sqlite.ConnInitFunc: %w", err)
   165  		}
   166  	}
   167  	return c, nil
   168  }
   169  
   170  type txState int
   171  
   172  const (
   173  	txStateNone  = txState(0) // connection is not connected to a Tx
   174  	txStateInit  = txState(1) // BeginTx called, but "BEGIN;" not yet executed
   175  	txStateBegun = txState(2) // "BEGIN;" has been executed
   176  )
   177  
   178  type conn struct {
   179  	db       sqliteh.DB
   180  	id       sqliteh.TraceConnID
   181  	tracer   sqliteh.Tracer
   182  	stmts    map[string]*stmt // persisted statements
   183  	txState  txState
   184  	readOnly bool
   185  	closed   atomic.Bool
   186  }
   187  
   188  func (c *conn) Prepare(query string) (driver.Stmt, error) { panic("deprecated, unused") }
   189  func (c *conn) Begin() (driver.Tx, error)                 { panic("deprecated, unused") }
   190  func (c *conn) Close() error {
   191  	// Don't double-close
   192  	if !c.closed.CompareAndSwap(false, true) {
   193  		UsesAfterClose.Add("Close", 1)
   194  		return nil
   195  	}
   196  
   197  	for q, s := range c.stmts {
   198  		s.stmt.Finalize()
   199  		s.closed.Store(true)
   200  		delete(c.stmts, q)
   201  	}
   202  	err := reserr(c.db, "Conn.Close", "", c.db.Close())
   203  	return err
   204  }
   205  func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   206  	persist := ctx.Value(persistQuery{}) != nil
   207  	return c.prepare(ctx, query, persist)
   208  }
   209  
   210  func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt, err error) {
   211  	if c.closed.Load() {
   212  		UsesAfterClose.Add("prepare", 1)
   213  		return nil, ErrClosed
   214  	}
   215  
   216  	query = strings.TrimSpace(query)
   217  	if s := c.stmts[query]; s != nil {
   218  		// don't hand the same statement out twice; this is re-added on s.Close
   219  		delete(c.stmts, query)
   220  
   221  		s.prepCtx = ctx
   222  		if !s.closed.CompareAndSwap(true, false) {
   223  			// We'd previously set this to 'false', indicating that
   224  			// this stmt is in-use. Return an error instead of
   225  			// reusing the stmt.
   226  			return nil, ErrClosed
   227  		}
   228  
   229  		return s, nil
   230  	}
   231  	if c.tracer != nil {
   232  		// Not a hot path. Any high-load environment should use
   233  		// WithPersist so this is rare.
   234  		start := time.Now()
   235  		defer func() {
   236  			if err != nil {
   237  				c.tracer.Query(ctx, c.id, query, time.Since(start), err)
   238  			}
   239  		}()
   240  	}
   241  	var flags sqliteh.PrepareFlags
   242  	if persist {
   243  		flags = sqliteh.SQLITE_PREPARE_PERSISTENT
   244  	}
   245  	cstmt, rem, err := c.db.Prepare(query, flags)
   246  	if err != nil {
   247  		return nil, reserr(c.db, "Prepare", query, err)
   248  	}
   249  	if rem != "" {
   250  		cstmt.Finalize()
   251  		return nil, &Error{
   252  			Code:  sqliteh.SQLITE_MISUSE,
   253  			Loc:   "Prepare",
   254  			Query: query,
   255  			Msg:   fmt.Sprintf("query has trailing text: %q", rem),
   256  		}
   257  	}
   258  	s = &stmt{
   259  		conn:     c,
   260  		stmt:     cstmt,
   261  		query:    query,
   262  		persist:  persist,
   263  		numInput: -1,
   264  		prepCtx:  ctx,
   265  	}
   266  
   267  	if !persist {
   268  		return s, nil
   269  	}
   270  
   271  	// NOTE: don't add the statement to c.stmts here, since we could return
   272  	// it to another caller before Close is called; it's added to the
   273  	// c.stmts map on Close.
   274  	if c.stmts == nil {
   275  		c.stmts = make(map[string]*stmt)
   276  	}
   277  	return s, nil
   278  }
   279  
   280  func (c *conn) execInternal(ctx context.Context, query string) error {
   281  	s, err := c.prepare(ctx, query, true)
   282  	if err != nil {
   283  		if e, _ := err.(*Error); e != nil {
   284  			e.Loc = "internal:" + e.Loc
   285  		}
   286  		return err
   287  	}
   288  	if _, err := s.ExecContext(ctx, nil); err != nil {
   289  		return err
   290  	}
   291  	s.Close()
   292  	return nil
   293  }
   294  
   295  func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   296  	if c.closed.Load() {
   297  		UsesAfterClose.Add("BeginTx", 1)
   298  		return nil, ErrClosed
   299  	}
   300  
   301  	const LevelSerializable = 6 // matches the sql package constant
   302  	if opts.Isolation != 0 && opts.Isolation != LevelSerializable {
   303  		return nil, errors.New("github.com/tailscale/sqlite driver only supports serializable isolation level")
   304  	}
   305  	c.readOnly = opts.ReadOnly
   306  	c.txState = txStateInit
   307  	if c.tracer != nil {
   308  		c.tracer.BeginTx(ctx, c.id, "", c.readOnly, nil)
   309  	}
   310  	if err := c.txInit(ctx); err != nil {
   311  		return nil, err
   312  	}
   313  	return &connTx{conn: c}, nil
   314  }
   315  
   316  // Raw is so ConnInitFunc can cast to SQLConn.
   317  func (c *conn) Raw(fn func(any) error) error { return fn(c) }
   318  
   319  type readOnlyKey struct{}
   320  
   321  // ReadOnly applies the query_only pragma to the connection.
   322  func ReadOnly(ctx context.Context) context.Context {
   323  	return context.WithValue(ctx, readOnlyKey{}, true)
   324  }
   325  
   326  // IsReadOnly reports whether the context has the ReadOnly key.
   327  func IsReadOnly(ctx context.Context) bool {
   328  	return ctx.Value(readOnlyKey{}) != nil
   329  }
   330  
   331  func (c *conn) txInit(ctx context.Context) error {
   332  	if c.txState != txStateInit {
   333  		return nil
   334  	}
   335  	c.txState = txStateBegun
   336  	if c.readOnly || IsReadOnly(ctx) {
   337  		if err := c.execInternal(ctx, "BEGIN"); err != nil {
   338  			return err
   339  		}
   340  		if err := c.execInternal(ctx, "PRAGMA query_only=true"); err != nil {
   341  			return err
   342  		}
   343  	} else {
   344  		// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
   345  		// semantics via a context annotation function.
   346  		if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
   347  			return err
   348  		}
   349  	}
   350  	return nil
   351  }
   352  
   353  func (c *conn) txEnd(ctx context.Context, endStmt string) error {
   354  	state, readOnly := c.txState, c.readOnly
   355  	c.txState = txStateNone
   356  	c.readOnly = false
   357  	if state != txStateBegun {
   358  		return nil
   359  	}
   360  
   361  	err := c.execInternal(context.Background(), endStmt)
   362  	if readOnly {
   363  		if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil {
   364  			err = err2
   365  		}
   366  	}
   367  	return err
   368  }
   369  
   370  type connTx struct {
   371  	conn *conn
   372  }
   373  
   374  func (tx *connTx) Commit() error {
   375  	if tx.conn.closed.Load() {
   376  		UsesAfterClose.Add("tx.Commit", 1)
   377  		return ErrClosed
   378  	}
   379  
   380  	err := tx.conn.txEnd(context.Background(), "COMMIT")
   381  	if tx.conn.tracer != nil {
   382  		tx.conn.tracer.Commit(tx.conn.id, err)
   383  	}
   384  	return err
   385  }
   386  
   387  func (tx *connTx) Rollback() error {
   388  	if tx.conn.closed.Load() {
   389  		UsesAfterClose.Add("tx.Rollback", 1)
   390  		return ErrClosed
   391  	}
   392  
   393  	err := tx.conn.txEnd(context.Background(), "ROLLBACK")
   394  	if tx.conn.tracer != nil {
   395  		tx.conn.tracer.Rollback(tx.conn.id, err)
   396  	}
   397  	return err
   398  }
   399  
   400  func reserr(db sqliteh.DB, loc, query string, err error) error {
   401  	if err == nil {
   402  		return nil
   403  	}
   404  	e := &Error{
   405  		Code:  sqliteh.Code(err.(sqliteh.ErrCode)),
   406  		Loc:   loc,
   407  		Query: query,
   408  	}
   409  	// TODO(crawshaw): consider an API to expose this. sqlite.DebugErrMsg(db)?
   410  	if true {
   411  		e.Msg = db.ErrMsg()
   412  	}
   413  	return e
   414  }
   415  
   416  type stmt struct {
   417  	conn    *conn
   418  	stmt    sqliteh.Stmt
   419  	query   string
   420  	persist bool        // true if stmt is cached and lives beyond Close
   421  	bound   bool        // true if stmt has parameters bound
   422  	closed  atomic.Bool // true after Close if persist==false
   423  
   424  	numInput int // filled on first NumInput only if persist==true
   425  
   426  	prepCtx context.Context // the context provided to prepare, for tracing
   427  
   428  	// filled on first step only if persist==true
   429  	colDeclTypes []colDeclType
   430  	colNames     []string
   431  }
   432  
   433  func (s *stmt) reserr(loc string, err error) error { return reserr(s.conn.db, loc, s.query, err) }
   434  
   435  func (s *stmt) NumInput() int {
   436  	if s.closed.Load() {
   437  		UsesAfterClose.Add("stmt.NumInput", 1)
   438  		return 0
   439  	}
   440  	if s.persist {
   441  		if s.numInput == -1 {
   442  			s.numInput = s.stmt.BindParameterCount()
   443  		}
   444  		return s.numInput
   445  	}
   446  	return s.stmt.BindParameterCount()
   447  }
   448  
   449  func (s *stmt) Close() error {
   450  	// Always set the 'closed' boolean, even for a persisted query; this is
   451  	// set from false -> true in prepare(), above.
   452  	if s.conn.closed.Load() {
   453  		UsesAfterClose.Add("Stmt.Close_conn", 1)
   454  		return nil
   455  	}
   456  	if !s.closed.CompareAndSwap(false, true) {
   457  		UsesAfterClose.Add("Stmt.Close", 1)
   458  		return nil
   459  	}
   460  
   461  	// We return this statement to the conn only if it's persistent, and
   462  	// only if there's not already a statement with the same query already
   463  	// cached there.
   464  	shouldPersist := s.persist
   465  	if shouldPersist {
   466  		if _, alreadyPersisted := s.conn.stmts[s.query]; alreadyPersisted {
   467  			shouldPersist = false
   468  		}
   469  	}
   470  	if shouldPersist {
   471  		err := s.reserr("Stmt.Close", s.resetAndClear())
   472  		if err == nil {
   473  			s.conn.stmts[s.query] = s
   474  		}
   475  		return err
   476  	}
   477  	return s.reserr("Stmt.Close", s.stmt.Finalize())
   478  }
   479  func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { panic("deprecated, unused") }
   480  func (s *stmt) Query(args []driver.Value) (driver.Rows, error)  { panic("deprecated, unused") }
   481  
   482  func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   483  	if s.closed.Load() {
   484  		UsesAfterClose.Add("stmt.ExecContext", 1)
   485  		return nil, ErrClosed
   486  	}
   487  	if err := s.resetAndClear(); err != nil {
   488  		return nil, s.reserr("Stmt.Exec(Reset)", err)
   489  	}
   490  	if err := s.bindAll(args); err != nil {
   491  		return nil, s.reserr("Stmt.Exec(Bind)", err)
   492  	}
   493  	if ctx.Value(queryCancelKey{}) != nil {
   494  		var cancel context.CancelFunc
   495  		ctx, cancel = context.WithCancel(ctx)
   496  		defer cancel()
   497  
   498  		db := s.stmt.DBHandle()
   499  		go func() { <-ctx.Done(); db.Interrupt() }()
   500  	}
   501  	row, lastInsertRowID, changes, duration, err := s.stmt.StepResult()
   502  	s.bound = false // StepResult resets the query
   503  	err = s.reserr("Stmt.Exec", err)
   504  	if s.conn.tracer != nil {
   505  		s.conn.tracer.Query(s.prepCtx, s.conn.id, s.query, duration, err)
   506  	}
   507  	if err != nil {
   508  		return nil, err
   509  	}
   510  	_ = row // TODO: return error if exec on query which returns rows?
   511  	return getStmtResult(lastInsertRowID, changes), nil
   512  }
   513  
   514  var (
   515  	stmtResultZeroRows = &stmtResult{}
   516  	stmtResultOneRow   = &stmtResult{rowsAffected: 1}
   517  )
   518  
   519  func getStmtResult(lastInsertID int64, rowsAffected int64) *stmtResult {
   520  	// Some common cases to avoid allocs:
   521  	if lastInsertID == 0 {
   522  		switch rowsAffected {
   523  		case 0:
   524  			return stmtResultZeroRows
   525  		case 1:
   526  			return stmtResultOneRow
   527  		}
   528  	}
   529  	return &stmtResult{lastInsertID: lastInsertID, rowsAffected: rowsAffected}
   530  }
   531  
   532  type stmtResult struct {
   533  	lastInsertID int64
   534  	rowsAffected int64
   535  }
   536  
   537  func (res *stmtResult) LastInsertId() (int64, error) { return res.lastInsertID, nil }
   538  func (res *stmtResult) RowsAffected() (int64, error) { return res.rowsAffected, nil }
   539  
   540  func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   541  	if s.closed.Load() {
   542  		UsesAfterClose.Add("stmt.QueryContext", 1)
   543  		return nil, ErrClosed
   544  	}
   545  	if err := s.resetAndClear(); err != nil {
   546  		return nil, s.reserr("Stmt.Query(Reset)", err)
   547  	}
   548  	if err := s.bindAll(args); err != nil {
   549  		return nil, err
   550  	}
   551  	cancel := func() {}
   552  	if ctx.Value(queryCancelKey{}) != nil {
   553  		ctx, cancel = context.WithCancel(ctx)
   554  		db := s.stmt.DBHandle()
   555  		go func() { <-ctx.Done(); db.Interrupt() }()
   556  	}
   557  	return &rows{stmt: s, cancel: cancel}, nil
   558  }
   559  
   560  func (s *stmt) resetAndClear() error {
   561  	if !s.bound {
   562  		return nil
   563  	}
   564  	s.bound = false
   565  	duration, err := s.stmt.ResetAndClear()
   566  	if s.conn.tracer != nil {
   567  		s.conn.tracer.Query(s.prepCtx, s.conn.id, s.query, duration, err)
   568  	}
   569  	return err
   570  }
   571  
   572  func (s *stmt) bindAll(args []driver.NamedValue) error {
   573  	if s.bound {
   574  		panic("sqlite: impossible state, query already running: " + s.query)
   575  	}
   576  	s.bound = true
   577  	if s.conn.tracer != nil {
   578  		s.stmt.StartTimer()
   579  	}
   580  	for _, arg := range args {
   581  		if err := s.bind(arg); err != nil {
   582  			return err
   583  		}
   584  	}
   585  	return nil
   586  }
   587  
   588  func (s *stmt) bind(arg driver.NamedValue) error {
   589  	// TODO(crawshaw): could use a union-ish data type for debugName
   590  	// to avoid the allocation.
   591  	var debugName any
   592  	if arg.Name == "" {
   593  		debugName = arg.Ordinal
   594  	} else {
   595  		debugName = arg.Name
   596  		index := s.stmt.BindParameterIndexSearch(arg.Name)
   597  		if index == 0 {
   598  			return &Error{
   599  				Code:  sqliteh.SQLITE_MISUSE,
   600  				Loc:   "Bind",
   601  				Query: s.query,
   602  				Msg:   fmt.Sprintf("unknown parameter name %q", arg.Name),
   603  			}
   604  		}
   605  		arg.Ordinal = index
   606  	}
   607  
   608  	// Start with obvious types, including time.Time before TextMarshaler.
   609  	found, err := s.bindBasic(debugName, arg.Ordinal, arg.Value)
   610  	if err != nil {
   611  		return err
   612  	} else if found {
   613  		return nil
   614  	}
   615  
   616  	if m, _ := arg.Value.(encoding.TextMarshaler); m != nil {
   617  		b, err := m.MarshalText()
   618  		if err != nil {
   619  			// TODO: modify Error to carry an error so we can %w?
   620  			return &Error{
   621  				Code:  sqliteh.SQLITE_MISUSE,
   622  				Loc:   "Bind",
   623  				Query: s.query,
   624  				Msg:   fmt.Sprintf("Bind:%v: cannot marshal %T: %v", debugName, arg.Value, err),
   625  			}
   626  		}
   627  		_, err = s.bindBasic(debugName, arg.Ordinal, b)
   628  		return err
   629  	}
   630  
   631  	// Look for named basic types or other convertible types.
   632  	val := reflect.ValueOf(arg.Value)
   633  	typ := reflect.TypeOf(arg.Value)
   634  	switch typ.Kind() {
   635  	case reflect.Bool:
   636  		b := int64(0)
   637  		if val.Bool() {
   638  			b = 1
   639  		}
   640  		_, err := s.bindBasic(debugName, arg.Ordinal, b)
   641  		return err
   642  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   643  		_, err := s.bindBasic(debugName, arg.Ordinal, val.Int())
   644  		return err
   645  	case reflect.Uint, reflect.Uint64:
   646  		return &Error{
   647  			Code:  sqliteh.SQLITE_MISUSE,
   648  			Loc:   "Bind",
   649  			Query: s.query,
   650  			Msg:   fmt.Sprintf("Bind:%v: sqlite does not support uint64 (try a string or TextMarshaler)", debugName),
   651  		}
   652  	case reflect.Uint8, reflect.Uint16, reflect.Uint32:
   653  		_, err := s.bindBasic(debugName, arg.Ordinal, int64(val.Uint()))
   654  		return err
   655  	case reflect.Float32, reflect.Float64:
   656  		_, err := s.bindBasic(debugName, arg.Ordinal, val.Float())
   657  		return err
   658  	case reflect.String:
   659  		// TODO(crawshaw): decompose bindBasic somehow.
   660  		// But first: more tests that the errors make sense for each type.
   661  		_, err := s.bindBasic(debugName, arg.Ordinal, val.String())
   662  		return err
   663  	}
   664  
   665  	return &Error{
   666  		Code:  sqliteh.SQLITE_MISUSE,
   667  		Loc:   "Bind",
   668  		Query: s.query,
   669  		Msg:   fmt.Sprintf("Bind:%v: unknown value type %T (try a string or TextMarshaler)", debugName, arg.Value),
   670  	}
   671  }
   672  
   673  func (s *stmt) bindBasic(debugName any, ordinal int, v any) (found bool, err error) {
   674  	defer func() {
   675  		if err != nil {
   676  			err = s.reserr(fmt.Sprintf("Bind:%v:%T", debugName, v), err)
   677  		}
   678  	}()
   679  	switch v := v.(type) {
   680  	case nil:
   681  		return true, s.stmt.BindNull(ordinal)
   682  	case string:
   683  		return true, s.stmt.BindText64(ordinal, v)
   684  	case int:
   685  		return true, s.stmt.BindInt64(ordinal, int64(v))
   686  	case int64:
   687  		return true, s.stmt.BindInt64(ordinal, v)
   688  	case float64:
   689  		return true, s.stmt.BindDouble(ordinal, v)
   690  	case []byte:
   691  		if len(v) == 0 {
   692  			return true, s.stmt.BindZeroBlob64(ordinal, 0)
   693  		} else {
   694  			return true, s.stmt.BindBlob64(ordinal, v)
   695  		}
   696  	case time.Time:
   697  		// Shortest of:
   698  		//	YYYY-MM-DD HH:MM
   699  		// 	YYYY-MM-DD HH:MM:SS
   700  		//	YYYY-MM-DD HH:MM:SS.SSS
   701  		str := v.Format(TimeFormat)
   702  		str = strings.TrimSuffix(str, "-0000")
   703  		str = strings.TrimSuffix(str, ".000")
   704  		str = strings.TrimSuffix(str, ":00")
   705  		return true, s.stmt.BindText64(ordinal, str)
   706  	default:
   707  		return false, nil
   708  	}
   709  }
   710  
   711  // colDeclType is whether and how the declared SQLite column type should
   712  // map to any special handling (as a date, or as a boolean, etc).
   713  type colDeclType byte
   714  
   715  const (
   716  	declTypeUnknown colDeclType = iota
   717  	declTypeDateOrTime
   718  	declTypeBoolean
   719  )
   720  
   721  func colDeclTypeFromString(s string) colDeclType {
   722  	if strings.EqualFold(s, "DATETIME") || strings.EqualFold(s, "DATE") {
   723  		return declTypeDateOrTime
   724  	}
   725  	if strings.EqualFold(s, "BOOLEAN") {
   726  		return declTypeBoolean
   727  	}
   728  	return declTypeUnknown
   729  }
   730  
   731  type rows struct {
   732  	stmt   *stmt
   733  	closed bool
   734  	cancel context.CancelFunc // call when query ends
   735  
   736  	// colType is the column types for Step to fill on each row. We only use 23
   737  	// as it packs well with the closed bool byte above (24 bytes total, same as
   738  	// a slice) and it's uncommon for queries to select so many columns. But if
   739  	// they do, we still work: we just query the column type via cgo on each
   740  	// row. So a bit slower, but fine.
   741  	colType [23]sqliteh.ColumnType
   742  
   743  	colNames []string // filled on call to Columns
   744  
   745  	// Filled on first call to Next.
   746  	colDeclTypes []colDeclType
   747  }
   748  
   749  func (r *rows) Columns() []string {
   750  	if r.closed {
   751  		panic("Columns called after Rows was closed")
   752  	}
   753  	if r.stmt.closed.Load() {
   754  		UsesAfterClose.Add("rows.Columns", 1)
   755  		return nil
   756  	}
   757  	if r.colNames == nil {
   758  		if r.stmt.colNames != nil {
   759  			r.colNames = r.stmt.colNames
   760  		} else {
   761  			r.colNames = make([]string, r.stmt.stmt.ColumnCount())
   762  			for i := range r.colNames {
   763  				r.colNames[i] = r.stmt.stmt.ColumnName(i)
   764  			}
   765  			if r.stmt.persist {
   766  				r.stmt.colNames = r.colNames
   767  			}
   768  		}
   769  	}
   770  	return append([]string{}, r.colNames...)
   771  }
   772  
   773  func (r *rows) Close() error {
   774  	if r.closed {
   775  		return errors.New("sqlite rows result already closed")
   776  	}
   777  	if r.stmt.closed.Load() {
   778  		UsesAfterClose.Add("rows.Close", 1)
   779  		return ErrClosed
   780  	}
   781  	r.closed = true
   782  	defer r.cancel()
   783  	if err := r.stmt.resetAndClear(); err != nil {
   784  		return r.stmt.reserr("Rows.Close(Reset)", err)
   785  	}
   786  	return nil
   787  }
   788  
   789  func (r *rows) Next(dest []driver.Value) error {
   790  	if r.closed {
   791  		return errors.New("sqlite rows result already closed")
   792  	}
   793  	if r.stmt.closed.Load() {
   794  		UsesAfterClose.Add("rows.Next", 1)
   795  		return ErrClosed
   796  	}
   797  	hasRow, err := r.stmt.stmt.Step(r.colType[:])
   798  	if err != nil {
   799  		return r.stmt.reserr("Rows.Next", err)
   800  	}
   801  	if !hasRow {
   802  		return io.EOF
   803  	}
   804  
   805  	if r.colDeclTypes == nil {
   806  		r.colDeclTypes = r.stmt.colDeclTypes
   807  	}
   808  	if r.colDeclTypes == nil {
   809  		colCount := r.stmt.stmt.ColumnCount()
   810  		r.colDeclTypes = make([]colDeclType, colCount)
   811  		for i := range r.colDeclTypes {
   812  			r.colDeclTypes[i] = colDeclTypeFromString(r.stmt.stmt.ColumnDeclType(i))
   813  		}
   814  		if r.stmt.persist {
   815  			r.stmt.colDeclTypes = r.colDeclTypes
   816  		}
   817  	}
   818  
   819  	for i := range dest {
   820  		var colType sqliteh.ColumnType
   821  		if i < len(r.colType) {
   822  			// Common case, for the first couple dozen columns.
   823  			colType = r.colType[i]
   824  		} else {
   825  			// If it's a really wide query, then call into
   826  			// cgo for columns past the length of
   827  			// r.colType.
   828  			colType = r.stmt.stmt.ColumnType(i)
   829  		}
   830  
   831  		if r.colDeclTypes[i] == declTypeDateOrTime {
   832  			switch colType {
   833  			case sqliteh.SQLITE_INTEGER:
   834  				v := r.stmt.stmt.ColumnInt64(i)
   835  				dest[i] = time.Unix(v, 0)
   836  			case sqliteh.SQLITE_FLOAT:
   837  				dest[i] = r.stmt.stmt.ColumnDouble(i)
   838  				// TODO: treat as time?
   839  			case sqliteh.SQLITE_TEXT:
   840  				v := r.stmt.stmt.ColumnText(i)
   841  				format := TimeFormat
   842  				if len(format) > len(v) {
   843  					format = strings.TrimSuffix(format, "-0700")
   844  				}
   845  				if len(format) > len(v) {
   846  					format = strings.TrimSuffix(format, ".000")
   847  				}
   848  				if len(format) > len(v) {
   849  					format = strings.TrimSuffix(format, ":05")
   850  				}
   851  				t, err := time.Parse(format, v)
   852  				if err != nil {
   853  					return fmt.Errorf("cannot parse time from column %d: %v", i, err)
   854  				}
   855  				dest[i] = t
   856  			}
   857  			continue
   858  		}
   859  		switch colType {
   860  		case sqliteh.SQLITE_INTEGER:
   861  			val := r.stmt.stmt.ColumnInt64(i)
   862  			if r.colDeclTypes[i] == declTypeBoolean {
   863  				dest[i] = val > 0
   864  			} else {
   865  				dest[i] = val
   866  			}
   867  		case sqliteh.SQLITE_FLOAT:
   868  			dest[i] = r.stmt.stmt.ColumnDouble(i)
   869  		case sqliteh.SQLITE_BLOB, sqliteh.SQLITE_TEXT:
   870  			dest[i] = r.stmt.stmt.ColumnBlob(i)
   871  		case sqliteh.SQLITE_NULL:
   872  			dest[i] = nil
   873  		}
   874  	}
   875  	return nil
   876  }
   877  
   878  // Error is an error produced by SQLite.
   879  type Error struct {
   880  	Code  sqliteh.Code // SQLite extended error code (SQLITE_OK is an invalid value)
   881  	Loc   string       // method name that generated the error
   882  	Query string       // original SQL query text
   883  	Msg   string       // value of sqlite3_errmsg, set sqlite.ErrMsg = true
   884  }
   885  
   886  func (err Error) Error() string {
   887  	b := new(strings.Builder)
   888  	b.WriteString("sqlite")
   889  	if err.Loc != "" {
   890  		b.WriteByte('.')
   891  		b.WriteString(err.Loc)
   892  	}
   893  	b.WriteString(": ")
   894  	b.WriteString(err.Code.String())
   895  	if err.Msg != "" {
   896  		b.WriteString(": ")
   897  		b.WriteString(err.Msg)
   898  	}
   899  	if err.Query != "" {
   900  		b.WriteString(" (")
   901  		b.WriteString(err.Query)
   902  		b.WriteByte(')')
   903  	}
   904  	return b.String()
   905  }
   906  
   907  // SQLConn is a database/sql.Conn.
   908  // (We cannot create a circular package dependency here.)
   909  type SQLConn interface {
   910  	Raw(func(driverConn any) error) error
   911  }
   912  
   913  // ExecScript executes a set of SQL queries on an sql.Conn.
   914  // It stops on the first error.
   915  // It is recommended you wrap your script in a BEGIN; ... COMMIT; block.
   916  //
   917  // Usage:
   918  //
   919  //	c, err := db.Conn(ctx)
   920  //	if err != nil {
   921  //		// handle err
   922  //	}
   923  //	if err := sqlite.ExecScript(c, queries); err != nil {
   924  //		// handle err
   925  //	}
   926  //	c.Close() // return sql.Conn to pool
   927  func ExecScript(sqlconn SQLConn, queries string) error {
   928  	return sqlconn.Raw(func(driverConn any) error {
   929  		c, ok := driverConn.(*conn)
   930  		if !ok {
   931  			return fmt.Errorf("sqlite.ExecScript: sql.Conn is not the sqlite driver: %T", driverConn)
   932  		}
   933  
   934  		for {
   935  			queries = strings.TrimSpace(queries)
   936  			if queries == "" {
   937  				return nil
   938  			}
   939  			cstmt, rem, err := c.db.Prepare(queries, 0)
   940  			if err != nil {
   941  				return reserr(c.db, "ExecScript", queries, err)
   942  			}
   943  			queries = rem
   944  			_, err = cstmt.Step(nil)
   945  			cstmt.Finalize()
   946  			if err != nil {
   947  				// TODO(crawshaw): consider checking sqlite3_txn_state
   948  				// here and issuing a rollback, incase this script was:
   949  				//	BEGIN; BAD-SQL; COMMIT;
   950  				// So we don't leave the connection open.
   951  				return reserr(c.db, "ExecScript", queries, err)
   952  			}
   953  		}
   954  	})
   955  }
   956  
   957  // BusyTimeout calls sqlite3_busy_timeout on the underlying connection.
   958  func BusyTimeout(sqlconn SQLConn, d time.Duration) error {
   959  	return sqlconn.Raw(func(driverConn any) error {
   960  		c, ok := driverConn.(*conn)
   961  		if !ok {
   962  			return fmt.Errorf("sqlite.BusyTimeout: sql.Conn is not the sqlite driver: %T", driverConn)
   963  		}
   964  		c.db.BusyTimeout(d)
   965  		return nil
   966  	})
   967  }
   968  
   969  // SetWALHook calls sqlite3_wal_hook.
   970  //
   971  // If hook is nil, the hook is removed.
   972  func SetWALHook(sqlconn SQLConn, hook func(dbName string, pages int)) error {
   973  	return sqlconn.Raw(func(driverConn any) error {
   974  		c, ok := driverConn.(*conn)
   975  		if !ok {
   976  			return fmt.Errorf("sqlite.TxnState: sql.Conn is not the sqlite driver: %T", driverConn)
   977  		}
   978  		c.db.SetWALHook(hook)
   979  		return nil
   980  	})
   981  }
   982  
   983  // TxnState calls sqlite3_txn_state on the underlying connection.
   984  func TxnState(sqlconn SQLConn, schema string) (state sqliteh.TxnState, err error) {
   985  	return state, sqlconn.Raw(func(driverConn any) error {
   986  		c, ok := driverConn.(*conn)
   987  		if !ok {
   988  			return fmt.Errorf("sqlite.TxnState: sql.Conn is not the sqlite driver: %T", driverConn)
   989  		}
   990  		state = c.db.TxnState(schema)
   991  		return nil
   992  	})
   993  }
   994  
   995  // Checkpoint calls sqlite3_wal_checkpoint_v2 on the underlying connection.
   996  func Checkpoint(sqlconn SQLConn, dbName string, mode sqliteh.Checkpoint) (numFrames, numFramesCheckpointed int, err error) {
   997  	err = sqlconn.Raw(func(driverConn any) error {
   998  		c, ok := driverConn.(*conn)
   999  		if !ok {
  1000  			return fmt.Errorf("sqlite.Checkpoint: sql.Conn is not the sqlite driver: %T", driverConn)
  1001  		}
  1002  		numFrames, numFramesCheckpointed, err = c.db.Checkpoint(dbName, mode)
  1003  		return reserr(c.db, "Checkpoint", dbName, err)
  1004  	})
  1005  	return numFrames, numFramesCheckpointed, err
  1006  }
  1007  
  1008  // WithPersist makes a ctx instruct the sqlite driver to persist a prepared query.
  1009  //
  1010  // This should be used with recurring queries to avoid constant parsing and
  1011  // planning of the query by SQLite.
  1012  func WithPersist(ctx context.Context) context.Context {
  1013  	return context.WithValue(ctx, persistQuery{}, persistQuery{})
  1014  }
  1015  
  1016  // persistQuery is used as a context value.
  1017  type persistQuery struct{}
  1018  
  1019  // WithQueryCancel makes a ctx that instructs the sqlite driver to explicitly
  1020  // interrupt a running query if its argument context ends.  By default, without
  1021  // this option, queries will only check the context between steps.
  1022  func WithQueryCancel(ctx context.Context) context.Context {
  1023  	return context.WithValue(ctx, queryCancelKey{}, queryCancelKey{})
  1024  }
  1025  
  1026  // queryCancelKey is a context key for query context enforcement.
  1027  type queryCancelKey struct{}