github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/driver/driver.go (about)

     1  // Package driver provides a database/sql driver for SQLite.
     2  //
     3  // Importing package driver registers a [database/sql] driver named "sqlite3".
     4  // You may also need to import package embed.
     5  //
     6  //	import _ "github.com/ncruces/go-sqlite3/driver"
     7  //	import _ "github.com/ncruces/go-sqlite3/embed"
     8  //
     9  // The data source name for "sqlite3" databases can be a filename or a "file:" [URI].
    10  //
    11  // The [TRANSACTION] mode can be specified using "_txlock":
    12  //
    13  //	sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
    14  //
    15  // Possible values are: "deferred", "immediate", "exclusive".
    16  // A [read-only] transaction is always "deferred", regardless of "_txlock".
    17  //
    18  // The time encoding/decoding format can be specified using "_timefmt":
    19  //
    20  //	sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
    21  //
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// "rfc3339" encodes and decodes RFC 3339 only.
// Any other value is used as-is as a [sqlite3.TimeFormat],
// both to encode and to decode times.
    26  //
    27  // [PRAGMA] statements can be specified using "_pragma":
    28  //
    29  //	sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
    30  //
    31  // If no PRAGMAs are specified, a busy timeout of 1 minute is set.
    32  //
    33  // Order matters:
    34  // busy timeout and locking mode should be the first PRAGMAs set, in that order.
    35  //
    36  // [URI]: https://sqlite.org/uri.html
    37  // [PRAGMA]: https://sqlite.org/pragma.html
    38  // [format]: https://sqlite.org/lang_datefunc.html#time_values
    39  // [TRANSACTION]: https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
    40  // [read-only]: https://pkg.go.dev/database/sql#TxOptions
    41  package driver
    42  
    43  import (
    44  	"context"
    45  	"database/sql"
    46  	"database/sql/driver"
    47  	"errors"
    48  	"fmt"
    49  	"io"
    50  	"net/url"
    51  	"strings"
    52  	"time"
    53  	"unsafe"
    54  
    55  	"github.com/ncruces/go-sqlite3"
    56  	"github.com/ncruces/go-sqlite3/internal/util"
    57  )
    58  
    59  // This variable can be replaced with -ldflags:
    60  //
    61  //	go build -ldflags="-X github.com/ncruces/go-sqlite3/driver.driverName=sqlite"
    62  var driverName = "sqlite3"
    63  
    64  func init() {
    65  	if driverName != "" {
    66  		sql.Register(driverName, &SQLite{})
    67  	}
    68  }
    69  
    70  // Open opens the SQLite database specified by dataSourceName as a [database/sql.DB].
    71  //
    72  // The init function is called by the driver on new connections.
    73  // The [sqlite3.Conn] can be used to execute queries, register functions, etc.
    74  // Any error returned closes the connection and is returned to [database/sql].
    75  func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) {
    76  	c, err := (&SQLite{Init: init}).OpenConnector(dataSourceName)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	return sql.OpenDB(c), nil
    81  }
    82  
    83  // SQLite implements [database/sql/driver.Driver].
    84  type SQLite struct {
    85  	// Init function is called by the driver on new connections.
    86  	// The [sqlite3.Conn] can be used to execute queries, register functions, etc.
    87  	// Any error returned closes the connection and is returned to [database/sql].
    88  	Init func(*sqlite3.Conn) error
    89  }
    90  
    91  // Open implements [database/sql/driver.Driver].
    92  func (d *SQLite) Open(name string) (driver.Conn, error) {
    93  	c, err := d.newConnector(name)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	return c.Connect(context.Background())
    98  }
    99  
   100  // OpenConnector implements [database/sql/driver.DriverContext].
   101  func (d *SQLite) OpenConnector(name string) (driver.Connector, error) {
   102  	return d.newConnector(name)
   103  }
   104  
   105  func (d *SQLite) newConnector(name string) (*connector, error) {
   106  	c := connector{driver: d, name: name}
   107  
   108  	var txlock, timefmt string
   109  	if strings.HasPrefix(name, "file:") {
   110  		if _, after, ok := strings.Cut(name, "?"); ok {
   111  			query, err := url.ParseQuery(after)
   112  			if err != nil {
   113  				return nil, err
   114  			}
   115  			txlock = query.Get("_txlock")
   116  			timefmt = query.Get("_timefmt")
   117  			c.pragmas = query.Has("_pragma")
   118  		}
   119  	}
   120  
   121  	switch txlock {
   122  	case "":
   123  		c.txBegin = "BEGIN"
   124  	case "deferred", "immediate", "exclusive":
   125  		c.txBegin = "BEGIN " + txlock
   126  	default:
   127  		return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", txlock)
   128  	}
   129  
   130  	switch timefmt {
   131  	case "":
   132  		c.tmRead = sqlite3.TimeFormatAuto
   133  		c.tmWrite = sqlite3.TimeFormatDefault
   134  	case "sqlite":
   135  		c.tmRead = sqlite3.TimeFormatAuto
   136  		c.tmWrite = sqlite3.TimeFormat3
   137  	case "rfc3339":
   138  		c.tmRead = sqlite3.TimeFormatDefault
   139  		c.tmWrite = sqlite3.TimeFormatDefault
   140  	default:
   141  		c.tmRead = sqlite3.TimeFormat(timefmt)
   142  		c.tmWrite = sqlite3.TimeFormat(timefmt)
   143  	}
   144  	return &c, nil
   145  }
   146  
   147  type connector struct {
   148  	driver  *SQLite
   149  	name    string
   150  	txBegin string
   151  	tmRead  sqlite3.TimeFormat
   152  	tmWrite sqlite3.TimeFormat
   153  	pragmas bool
   154  }
   155  
   156  func (n *connector) Driver() driver.Driver {
   157  	return n.driver
   158  }
   159  
   160  func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
   161  	c := &conn{
   162  		txBegin: n.txBegin,
   163  		tmRead:  n.tmRead,
   164  		tmWrite: n.tmWrite,
   165  	}
   166  
   167  	c.Conn, err = sqlite3.Open(n.name)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	defer func() {
   172  		if err != nil {
   173  			c.Close()
   174  		}
   175  	}()
   176  
   177  	old := c.Conn.SetInterrupt(ctx)
   178  	defer c.Conn.SetInterrupt(old)
   179  
   180  	if !n.pragmas {
   181  		err = c.Conn.BusyTimeout(60 * time.Second)
   182  		if err != nil {
   183  			return nil, err
   184  		}
   185  	}
   186  	if n.driver.Init != nil {
   187  		err = n.driver.Init(c.Conn)
   188  		if err != nil {
   189  			return nil, err
   190  		}
   191  	}
   192  	if n.pragmas || n.driver.Init != nil {
   193  		s, _, err := c.Conn.Prepare(`PRAGMA query_only`)
   194  		if err != nil {
   195  			return nil, err
   196  		}
   197  		if s.Step() && s.ColumnBool(0) {
   198  			c.readOnly = '1'
   199  		} else {
   200  			c.readOnly = '0'
   201  		}
   202  		err = s.Close()
   203  		if err != nil {
   204  			return nil, err
   205  		}
   206  	}
   207  	return c, nil
   208  }
   209  
   210  type conn struct {
   211  	*sqlite3.Conn
   212  	txBegin    string
   213  	txCommit   string
   214  	txRollback string
   215  	tmRead     sqlite3.TimeFormat
   216  	tmWrite    sqlite3.TimeFormat
   217  	readOnly   byte
   218  }
   219  
   220  var (
   221  	// Ensure these interfaces are implemented:
   222  	_ driver.ConnPrepareContext = &conn{}
   223  	_ driver.ExecerContext      = &conn{}
   224  	_ driver.ConnBeginTx        = &conn{}
   225  	_ sqlite3.DriverConn        = &conn{}
   226  )
   227  
   228  func (c *conn) Raw() *sqlite3.Conn {
   229  	return c.Conn
   230  }
   231  
   232  func (c *conn) Begin() (driver.Tx, error) {
   233  	return c.BeginTx(context.Background(), driver.TxOptions{})
   234  }
   235  
   236  func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   237  	txBegin := c.txBegin
   238  	c.txCommit = `COMMIT`
   239  	c.txRollback = `ROLLBACK`
   240  
   241  	if opts.ReadOnly {
   242  		txBegin = `
   243  			BEGIN deferred;
   244  			PRAGMA query_only=on`
   245  		c.txRollback = `
   246  			ROLLBACK;
   247  			PRAGMA query_only=` + string(c.readOnly)
   248  		c.txCommit = c.txRollback
   249  	}
   250  
   251  	switch opts.Isolation {
   252  	default:
   253  		return nil, util.IsolationErr
   254  	case
   255  		driver.IsolationLevel(sql.LevelDefault),
   256  		driver.IsolationLevel(sql.LevelSerializable):
   257  		break
   258  	}
   259  
   260  	old := c.Conn.SetInterrupt(ctx)
   261  	defer c.Conn.SetInterrupt(old)
   262  
   263  	err := c.Conn.Exec(txBegin)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	return c, nil
   268  }
   269  
   270  func (c *conn) Commit() error {
   271  	err := c.Conn.Exec(c.txCommit)
   272  	if err != nil && !c.Conn.GetAutocommit() {
   273  		c.Rollback()
   274  	}
   275  	return err
   276  }
   277  
   278  func (c *conn) Rollback() error {
   279  	err := c.Conn.Exec(c.txRollback)
   280  	if errors.Is(err, sqlite3.INTERRUPT) {
   281  		old := c.Conn.SetInterrupt(context.Background())
   282  		defer c.Conn.SetInterrupt(old)
   283  		err = c.Conn.Exec(c.txRollback)
   284  	}
   285  	return err
   286  }
   287  
   288  func (c *conn) Prepare(query string) (driver.Stmt, error) {
   289  	return c.PrepareContext(context.Background(), query)
   290  }
   291  
   292  func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   293  	old := c.Conn.SetInterrupt(ctx)
   294  	defer c.Conn.SetInterrupt(old)
   295  
   296  	s, tail, err := c.Conn.Prepare(query)
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  	if tail != "" {
   301  		s.Close()
   302  		return nil, util.TailErr
   303  	}
   304  	return &stmt{Stmt: s, tmRead: c.tmRead, tmWrite: c.tmWrite}, nil
   305  }
   306  
   307  func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   308  	if len(args) != 0 {
   309  		// Slow path.
   310  		return nil, driver.ErrSkip
   311  	}
   312  
   313  	if savept, ok := ctx.(*saveptCtx); ok {
   314  		// Called from driver.Savepoint.
   315  		savept.Savepoint = c.Conn.Savepoint()
   316  		return resultRowsAffected(0), nil
   317  	}
   318  
   319  	old := c.Conn.SetInterrupt(ctx)
   320  	defer c.Conn.SetInterrupt(old)
   321  
   322  	err := c.Conn.Exec(query)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  
   327  	return newResult(c.Conn), nil
   328  }
   329  
   330  func (c *conn) CheckNamedValue(arg *driver.NamedValue) error {
   331  	return nil
   332  }
   333  
   334  type stmt struct {
   335  	*sqlite3.Stmt
   336  	tmWrite sqlite3.TimeFormat
   337  	tmRead  sqlite3.TimeFormat
   338  }
   339  
   340  var (
   341  	// Ensure these interfaces are implemented:
   342  	_ driver.StmtExecContext   = &stmt{}
   343  	_ driver.StmtQueryContext  = &stmt{}
   344  	_ driver.NamedValueChecker = &stmt{}
   345  )
   346  
   347  func (s *stmt) NumInput() int {
   348  	n := s.Stmt.BindCount()
   349  	for i := 1; i <= n; i++ {
   350  		if s.Stmt.BindName(i) != "" {
   351  			return -1
   352  		}
   353  	}
   354  	return n
   355  }
   356  
   357  // Deprecated: use ExecContext instead.
   358  func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
   359  	return s.ExecContext(context.Background(), namedValues(args))
   360  }
   361  
   362  // Deprecated: use QueryContext instead.
   363  func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
   364  	return s.QueryContext(context.Background(), namedValues(args))
   365  }
   366  
   367  func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   368  	err := s.setupBindings(args)
   369  	if err != nil {
   370  		return nil, err
   371  	}
   372  
   373  	old := s.Stmt.Conn().SetInterrupt(ctx)
   374  	defer s.Stmt.Conn().SetInterrupt(old)
   375  
   376  	err = s.Stmt.Exec()
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	return newResult(s.Stmt.Conn()), nil
   382  }
   383  
   384  func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   385  	err := s.setupBindings(args)
   386  	if err != nil {
   387  		return nil, err
   388  	}
   389  	return &rows{ctx: ctx, stmt: s}, nil
   390  }
   391  
   392  func (s *stmt) setupBindings(args []driver.NamedValue) error {
   393  	err := s.Stmt.ClearBindings()
   394  	if err != nil {
   395  		return err
   396  	}
   397  
   398  	var ids [3]int
   399  	for _, arg := range args {
   400  		ids := ids[:0]
   401  		if arg.Name == "" {
   402  			ids = append(ids, arg.Ordinal)
   403  		} else {
   404  			for _, prefix := range []string{":", "@", "$"} {
   405  				if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
   406  					ids = append(ids, id)
   407  				}
   408  			}
   409  		}
   410  
   411  		for _, id := range ids {
   412  			switch a := arg.Value.(type) {
   413  			case bool:
   414  				err = s.Stmt.BindBool(id, a)
   415  			case int:
   416  				err = s.Stmt.BindInt(id, a)
   417  			case int64:
   418  				err = s.Stmt.BindInt64(id, a)
   419  			case float64:
   420  				err = s.Stmt.BindFloat(id, a)
   421  			case string:
   422  				err = s.Stmt.BindText(id, a)
   423  			case []byte:
   424  				err = s.Stmt.BindBlob(id, a)
   425  			case sqlite3.ZeroBlob:
   426  				err = s.Stmt.BindZeroBlob(id, int64(a))
   427  			case time.Time:
   428  				err = s.Stmt.BindTime(id, a, s.tmWrite)
   429  			case util.JSON:
   430  				err = s.Stmt.BindJSON(id, a.Value)
   431  			case util.PointerUnwrap:
   432  				err = s.Stmt.BindPointer(id, util.UnwrapPointer(a))
   433  			case nil:
   434  				err = s.Stmt.BindNull(id)
   435  			default:
   436  				panic(util.AssertErr())
   437  			}
   438  		}
   439  		if err != nil {
   440  			return err
   441  		}
   442  	}
   443  	return nil
   444  }
   445  
   446  func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
   447  	switch arg.Value.(type) {
   448  	case bool, int, int64, float64, string, []byte,
   449  		time.Time, sqlite3.ZeroBlob,
   450  		util.JSON, util.PointerUnwrap,
   451  		nil:
   452  		return nil
   453  	default:
   454  		return driver.ErrSkip
   455  	}
   456  }
   457  
   458  func newResult(c *sqlite3.Conn) driver.Result {
   459  	rows := c.Changes()
   460  	if rows != 0 {
   461  		id := c.LastInsertRowID()
   462  		if id != 0 {
   463  			return result{id, rows}
   464  		}
   465  	}
   466  	return resultRowsAffected(rows)
   467  }
   468  
   469  type result struct{ lastInsertId, rowsAffected int64 }
   470  
   471  func (r result) LastInsertId() (int64, error) {
   472  	return r.lastInsertId, nil
   473  }
   474  
   475  func (r result) RowsAffected() (int64, error) {
   476  	return r.rowsAffected, nil
   477  }
   478  
   479  type resultRowsAffected int64
   480  
   481  func (r resultRowsAffected) LastInsertId() (int64, error) {
   482  	return 0, nil
   483  }
   484  
   485  func (r resultRowsAffected) RowsAffected() (int64, error) {
   486  	return int64(r), nil
   487  }
   488  
   489  type rows struct {
   490  	ctx context.Context
   491  	*stmt
   492  	names []string
   493  	types []string
   494  }
   495  
   496  func (r *rows) Close() error {
   497  	r.Stmt.ClearBindings()
   498  	return r.Stmt.Reset()
   499  }
   500  
   501  func (r *rows) Columns() []string {
   502  	if r.names == nil {
   503  		count := r.Stmt.ColumnCount()
   504  		r.names = make([]string, count)
   505  		for i := range r.names {
   506  			r.names[i] = r.Stmt.ColumnName(i)
   507  		}
   508  	}
   509  	return r.names
   510  }
   511  
   512  func (r *rows) declType(index int) string {
   513  	if r.types == nil {
   514  		count := r.Stmt.ColumnCount()
   515  		r.types = make([]string, count)
   516  		for i := range r.types {
   517  			r.types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
   518  		}
   519  	}
   520  	return r.types[index]
   521  }
   522  
   523  func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
   524  	decltype := r.declType(index)
   525  	if len := len(decltype); len > 0 && decltype[len-1] == ')' {
   526  		if i := strings.LastIndexByte(decltype, '('); i >= 0 {
   527  			decltype = decltype[:i]
   528  		}
   529  	}
   530  	return strings.TrimSpace(decltype)
   531  }
   532  
   533  func (r *rows) Next(dest []driver.Value) error {
   534  	old := r.Stmt.Conn().SetInterrupt(r.ctx)
   535  	defer r.Stmt.Conn().SetInterrupt(old)
   536  
   537  	if !r.Stmt.Step() {
   538  		if err := r.Stmt.Err(); err != nil {
   539  			return err
   540  		}
   541  		return io.EOF
   542  	}
   543  
   544  	data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
   545  	err := r.Stmt.Columns(data)
   546  	for i := range dest {
   547  		if t, ok := r.decodeTime(i, dest[i]); ok {
   548  			dest[i] = t
   549  			continue
   550  		}
   551  		if s, ok := dest[i].(string); ok {
   552  			t, ok := maybeTime(s)
   553  			if ok {
   554  				dest[i] = t
   555  			}
   556  		}
   557  	}
   558  	return err
   559  }
   560  
   561  func (r *rows) decodeTime(i int, v any) (_ time.Time, _ bool) {
   562  	if r.tmRead == sqlite3.TimeFormatDefault {
   563  		return
   564  	}
   565  	switch r.declType(i) {
   566  	case "DATE", "TIME", "DATETIME", "TIMESTAMP":
   567  		// maybe
   568  	default:
   569  		return
   570  	}
   571  	switch v.(type) {
   572  	case int64, float64, string:
   573  		// maybe
   574  	default:
   575  		return
   576  	}
   577  	t, err := r.tmRead.Decode(v)
   578  	return t, err == nil
   579  }