github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/conn.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"net/url"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/ncruces/go-sqlite3/internal/util"
    12  	"github.com/ncruces/go-sqlite3/vfs"
    13  	"github.com/tetratelabs/wazero/api"
    14  )
    15  
// Conn is a database connection handle.
// A Conn is not safe for concurrent use by multiple goroutines.
//
// https://sqlite.org/c3ref/sqlite3.html
type Conn struct {
	// The per-connection SQLite module instance created by newConn;
	// it supplies call, mod, ctx, newArena, close and the shared error helper.
	*sqlite

	// interrupt is the context installed by SetInterrupt (nil if none);
	// once it is done, work on this connection is interrupted.
	interrupt  context.Context
	// pending is a helper statement that SetInterrupt keeps busy so SQLite
	// does not ignore an early interrupt; Close finalizes it.
	pending    *Stmt
	// busy is the callback registered with BusyHandler.
	busy       func(int) bool
	// The hooks below are registered by methods not shown in this section.
	// NOTE(review): presumably the error-log, collation-needed, authorizer,
	// update, commit, rollback and WAL hooks respectively — confirm.
	log        func(xErrorCode, string)
	collation  func(*Conn, string)
	authorizer func(AuthorizerActionCode, string, string, string, string) AuthorizerReturnCode
	update     func(AuthorizerActionCode, string, string, int64)
	commit     func() bool
	rollback   func()
	wal        func(*Conn, string, int) error
	// arena is scratch space for passing pointers/strings to SQLite calls
	// (created with newArena(1024) in newConn). Methods defer the function
	// returned by arena.mark(), presumably to reclaim what they allocated.
	arena      arena

	// handle is the connection pointer written by sqlite3_open_v2;
	// it is zero for a zero or closed Conn.
	handle uint32
}
    37  
    38  // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
    39  func Open(filename string) (*Conn, error) {
    40  	return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
    41  }
    42  
    43  // OpenFlags opens an SQLite database file as specified by the filename argument.
    44  //
    45  // If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
    46  // If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
    47  //
    48  //	sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
    49  //
    50  // https://sqlite.org/c3ref/open.html
    51  func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
    52  	if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
    53  		flags |= OPEN_READWRITE | OPEN_CREATE
    54  	}
    55  	return newConn(filename, flags)
    56  }
    57  
// connKey is the context key under which newConn stores the *Conn in the
// module context, letting host callbacks (progressCallback, timeoutCallback,
// busyCallback) recover the connection from their ctx argument.
type connKey struct{}
    59  
    60  func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
    61  	sqlite, err := instantiateSQLite()
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	defer func() {
    66  		if conn == nil {
    67  			sqlite.close()
    68  		}
    69  	}()
    70  
    71  	c := &Conn{sqlite: sqlite}
    72  	c.arena = c.newArena(1024)
    73  	c.ctx = context.WithValue(c.ctx, connKey{}, c)
    74  	c.handle, err = c.openDB(filename, flags)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	return c, nil
    79  }
    80  
    81  func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
    82  	defer c.arena.mark()()
    83  	connPtr := c.arena.new(ptrlen)
    84  	namePtr := c.arena.string(filename)
    85  
    86  	flags |= OPEN_EXRESCODE
    87  	r := c.call("sqlite3_open_v2", uint64(namePtr), uint64(connPtr), uint64(flags), 0)
    88  
    89  	handle := util.ReadUint32(c.mod, connPtr)
    90  	if err := c.sqlite.error(r, handle); err != nil {
    91  		c.closeDB(handle)
    92  		return 0, err
    93  	}
    94  
    95  	if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
    96  		var pragmas strings.Builder
    97  		if _, after, ok := strings.Cut(filename, "?"); ok {
    98  			query, _ := url.ParseQuery(after)
    99  			for _, p := range query["_pragma"] {
   100  				pragmas.WriteString(`PRAGMA `)
   101  				pragmas.WriteString(p)
   102  				pragmas.WriteString(`;`)
   103  			}
   104  		}
   105  		if pragmas.Len() != 0 {
   106  			pragmaPtr := c.arena.string(pragmas.String())
   107  			r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
   108  			if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
   109  				err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
   110  				c.closeDB(handle)
   111  				return 0, err
   112  			}
   113  		}
   114  	}
   115  	c.call("sqlite3_progress_handler_go", uint64(handle), 100)
   116  	return handle, nil
   117  }
   118  
   119  func (c *Conn) closeDB(handle uint32) {
   120  	r := c.call("sqlite3_close_v2", uint64(handle))
   121  	if err := c.sqlite.error(r, handle); err != nil {
   122  		panic(err)
   123  	}
   124  }
   125  
   126  // Close closes the database connection.
   127  //
   128  // If the database connection is associated with unfinalized prepared statements,
   129  // open blob handles, and/or unfinished backup objects,
   130  // Close will leave the database connection open and return [BUSY].
   131  //
   132  // It is safe to close a nil, zero or closed Conn.
   133  //
   134  // https://sqlite.org/c3ref/close.html
   135  func (c *Conn) Close() error {
   136  	if c == nil || c.handle == 0 {
   137  		return nil
   138  	}
   139  
   140  	c.pending.Close()
   141  	c.pending = nil
   142  
   143  	r := c.call("sqlite3_close", uint64(c.handle))
   144  	if err := c.error(r); err != nil {
   145  		return err
   146  	}
   147  
   148  	c.handle = 0
   149  	return c.close()
   150  }
   151  
   152  // Exec is a convenience function that allows an application to run
   153  // multiple statements of SQL without having to use a lot of code.
   154  //
   155  // https://sqlite.org/c3ref/exec.html
   156  func (c *Conn) Exec(sql string) error {
   157  	c.checkInterrupt()
   158  	defer c.arena.mark()()
   159  	sqlPtr := c.arena.string(sql)
   160  
   161  	r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
   162  	return c.error(r, sql)
   163  }
   164  
   165  // Prepare calls [Conn.PrepareFlags] with no flags.
   166  func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
   167  	return c.PrepareFlags(sql, 0)
   168  }
   169  
   170  // PrepareFlags compiles the first SQL statement in sql;
   171  // tail is left pointing to what remains uncompiled.
   172  // If the input text contains no SQL (if the input is an empty string or a comment),
   173  // both stmt and err will be nil.
   174  //
   175  // https://sqlite.org/c3ref/prepare.html
   176  func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
   177  	if len(sql) > _MAX_SQL_LENGTH {
   178  		return nil, "", TOOBIG
   179  	}
   180  
   181  	defer c.arena.mark()()
   182  	stmtPtr := c.arena.new(ptrlen)
   183  	tailPtr := c.arena.new(ptrlen)
   184  	sqlPtr := c.arena.string(sql)
   185  
   186  	r := c.call("sqlite3_prepare_v3", uint64(c.handle),
   187  		uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
   188  		uint64(stmtPtr), uint64(tailPtr))
   189  
   190  	stmt = &Stmt{c: c}
   191  	stmt.handle = util.ReadUint32(c.mod, stmtPtr)
   192  	if sql := sql[util.ReadUint32(c.mod, tailPtr)-sqlPtr:]; sql != "" {
   193  		tail = sql
   194  	}
   195  
   196  	if err := c.error(r, sql); err != nil {
   197  		return nil, "", err
   198  	}
   199  	if stmt.handle == 0 {
   200  		return nil, "", nil
   201  	}
   202  	return stmt, tail, nil
   203  }
   204  
   205  // DBName returns the schema name for n-th database on the database connection.
   206  //
   207  // https://sqlite.org/c3ref/db_name.html
   208  func (c *Conn) DBName(n int) string {
   209  	r := c.call("sqlite3_db_name", uint64(c.handle), uint64(n))
   210  
   211  	ptr := uint32(r)
   212  	if ptr == 0 {
   213  		return ""
   214  	}
   215  	return util.ReadString(c.mod, ptr, _MAX_NAME)
   216  }
   217  
   218  // Filename returns the filename for a database.
   219  //
   220  // https://sqlite.org/c3ref/db_filename.html
   221  func (c *Conn) Filename(schema string) *vfs.Filename {
   222  	var ptr uint32
   223  	if schema != "" {
   224  		defer c.arena.mark()()
   225  		ptr = c.arena.string(schema)
   226  	}
   227  
   228  	r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))
   229  	return vfs.OpenFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB)
   230  }
   231  
   232  // ReadOnly determines if a database is read-only.
   233  //
   234  // https://sqlite.org/c3ref/db_readonly.html
   235  func (c *Conn) ReadOnly(schema string) (ro bool, ok bool) {
   236  	var ptr uint32
   237  	if schema != "" {
   238  		defer c.arena.mark()()
   239  		ptr = c.arena.string(schema)
   240  	}
   241  	r := c.call("sqlite3_db_readonly", uint64(c.handle), uint64(ptr))
   242  	return int32(r) > 0, int32(r) < 0
   243  }
   244  
   245  // GetAutocommit tests the connection for auto-commit mode.
   246  //
   247  // https://sqlite.org/c3ref/get_autocommit.html
   248  func (c *Conn) GetAutocommit() bool {
   249  	r := c.call("sqlite3_get_autocommit", uint64(c.handle))
   250  	return r != 0
   251  }
   252  
   253  // LastInsertRowID returns the rowid of the most recent successful INSERT
   254  // on the database connection.
   255  //
   256  // https://sqlite.org/c3ref/last_insert_rowid.html
   257  func (c *Conn) LastInsertRowID() int64 {
   258  	r := c.call("sqlite3_last_insert_rowid", uint64(c.handle))
   259  	return int64(r)
   260  }
   261  
   262  // SetLastInsertRowID allows the application to set the value returned by
   263  // [Conn.LastInsertRowID].
   264  //
   265  // https://sqlite.org/c3ref/set_last_insert_rowid.html
   266  func (c *Conn) SetLastInsertRowID(id int64) {
   267  	c.call("sqlite3_set_last_insert_rowid", uint64(c.handle), uint64(id))
   268  }
   269  
   270  // Changes returns the number of rows modified, inserted or deleted
   271  // by the most recently completed INSERT, UPDATE or DELETE statement
   272  // on the database connection.
   273  //
   274  // https://sqlite.org/c3ref/changes.html
   275  func (c *Conn) Changes() int64 {
   276  	r := c.call("sqlite3_changes64", uint64(c.handle))
   277  	return int64(r)
   278  }
   279  
   280  // TotalChanges returns the number of rows modified, inserted or deleted
   281  // by all INSERT, UPDATE or DELETE statements completed
   282  // since the database connection was opened.
   283  //
   284  // https://sqlite.org/c3ref/total_changes.html
   285  func (c *Conn) TotalChanges() int64 {
   286  	r := c.call("sqlite3_total_changes64", uint64(c.handle))
   287  	return int64(r)
   288  }
   289  
   290  // ReleaseMemory frees memory used by a database connection.
   291  //
   292  // https://sqlite.org/c3ref/db_release_memory.html
   293  func (c *Conn) ReleaseMemory() error {
   294  	r := c.call("sqlite3_db_release_memory", uint64(c.handle))
   295  	return c.error(r)
   296  }
   297  
   298  // GetInterrupt gets the context set with [Conn.SetInterrupt],
   299  // or nil if none was set.
   300  func (c *Conn) GetInterrupt() context.Context {
   301  	return c.interrupt
   302  }
   303  
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until the context is reset by another call to SetInterrupt.
// Passing nil removes the interrupt context.
//
// To associate a timeout with a connection:
//
//	ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
//	conn.SetInterrupt(ctx)
//	defer cancel()
//
// SetInterrupt returns the old context assigned to the connection.
//
// https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
	// Is it the same context?
	// Re-installing the current context is a no-op.
	if ctx == c.interrupt {
		return ctx
	}

	// A busy SQL statement prevents SQLite from ignoring an interrupt
	// that comes before any other statements are started.
	// The recursive CTE below never terminates, so after one Step it stays
	// busy. Prepare errors are ignored here (c.pending may remain nil).
	if c.pending == nil {
		c.pending, _, _ = c.Prepare(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
	}

	old = c.interrupt
	c.interrupt = ctx

	// Leaving a cancellable context (Done() != nil) for one that is nil or
	// not yet done: reset the helper statement so it is no longer busy.
	if old != nil && old.Done() != nil && (ctx == nil || ctx.Err() == nil) {
		c.pending.Reset()
	}
	// The new context is cancellable: step the helper so it becomes busy.
	if ctx != nil && ctx.Done() != nil {
		c.pending.Step()
	}
	return old
}
   341  
   342  func (c *Conn) checkInterrupt() {
   343  	if c.interrupt != nil && c.interrupt.Err() != nil {
   344  		c.call("sqlite3_interrupt", uint64(c.handle))
   345  	}
   346  }
   347  
   348  func progressCallback(ctx context.Context, mod api.Module, pDB uint32) (interrupt uint32) {
   349  	if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB &&
   350  		c.interrupt != nil && c.interrupt.Err() != nil {
   351  		interrupt = 1
   352  	}
   353  	return interrupt
   354  }
   355  
   356  // BusyTimeout sets a busy timeout.
   357  //
   358  // https://sqlite.org/c3ref/busy_timeout.html
   359  func (c *Conn) BusyTimeout(timeout time.Duration) error {
   360  	ms := min((timeout+time.Millisecond-1)/time.Millisecond, math.MaxInt32)
   361  	r := c.call("sqlite3_busy_timeout", uint64(c.handle), uint64(ms))
   362  	return c.error(r)
   363  }
   364  
   365  func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmout int32) (retry uint32) {
   366  	if c, ok := ctx.Value(connKey{}).(*Conn); ok &&
   367  		(c.interrupt == nil || c.interrupt.Err() == nil) {
   368  		const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64"
   369  		const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4"
   370  		const ndelay = int32(len(delays) - 1)
   371  
   372  		var delay, prior int32
   373  		if count <= ndelay {
   374  			delay = int32(delays[count])
   375  			prior = int32(totals[count])
   376  		} else {
   377  			delay = int32(delays[ndelay])
   378  			prior = int32(totals[ndelay]) + delay*(count-ndelay)
   379  		}
   380  
   381  		if delay = min(delay, tmout-prior); delay > 0 {
   382  			time.Sleep(time.Duration(delay) * time.Millisecond)
   383  			retry = 1
   384  		}
   385  	}
   386  	return retry
   387  }
   388  
   389  // BusyHandler registers a callback to handle [BUSY] errors.
   390  //
   391  // https://sqlite.org/c3ref/busy_handler.html
   392  func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error {
   393  	var enable uint64
   394  	if cb != nil {
   395  		enable = 1
   396  	}
   397  	r := c.call("sqlite3_busy_handler_go", uint64(c.handle), enable)
   398  	if err := c.error(r); err != nil {
   399  		return err
   400  	}
   401  	c.busy = cb
   402  	return nil
   403  }
   404  
   405  func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
   406  	if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil &&
   407  		(c.interrupt == nil || c.interrupt.Err() == nil) {
   408  		if c.busy(int(count)) {
   409  			retry = 1
   410  		}
   411  	}
   412  	return retry
   413  }
   414  
   415  func (c *Conn) error(rc uint64, sql ...string) error {
   416  	return c.sqlite.error(rc, c.handle, sql...)
   417  }
   418  
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access SQLite features like [online backup].
// Raw exposes the underlying [*Conn], which the caller must not use
// concurrently (a Conn is not safe for concurrent use).
//
// [online backup]: https://sqlite.org/backup.html
type DriverConn interface {
	Raw() *Conn
}