github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/cgosqlite/cgosqlite.go (about)

     1  package cgosqlite
     2  
     3  // This list of compiler options is heavily influenced by:
     4  //
     5  // https://www.sqlite.org/compile.html#recommended_compile_time_options
     6  //
     7  // One exception is we do not use SQLITE_OMIT_DECLTYPE, as the design
     8  // of the database/sql driver seems to require it.
     9  
    10  // #cgo CFLAGS: -DSQLITE_THREADSAFE=2
    11  // #cgo CFLAGS: -DSQLITE_DQS=0
    12  // #cgo CFLAGS: -DSQLITE_DEFAULT_MEMSTATUS=0
    13  // #cgo CFLAGS: -DSQLITE_DEFAULT_WAL_SYNCHRONOUS=1
    14  // #cgo CFLAGS: -DSQLITE_LIKE_DOESNT_MATCH_BLOBS
    15  // #cgo CFLAGS: -DSQLITE_MAX_EXPR_DEPTH=0
    16  // #cgo CFLAGS: -DSQLITE_OMIT_DEPRECATED
    17  // #cgo CFLAGS: -DSQLITE_OMIT_PROGRESS_CALLBACK
    18  // #cgo CFLAGS: -DSQLITE_OMIT_SHARED_CACHE
    19  // #cgo CFLAGS: -DSQLITE_USE_ALLOCA
    20  // #cgo CFLAGS: -DSQLITE_OMIT_AUTOINIT
    21  // #cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
    22  // #cgo CFLAGS: -DSQLITE_ENABLE_FTS5
    23  // #cgo CFLAGS: -DSQLITE_ENABLE_RTREE
    24  // #cgo CFLAGS: -DSQLITE_ENABLE_JSON1
    25  // #cgo CFLAGS: -DSQLITE_ENABLE_SESSION
    26  // #cgo CFLAGS: -DSQLITE_ENABLE_SNAPSHOT
    27  // #cgo CFLAGS: -DSQLITE_ENABLE_PREUPDATE_HOOK
    28  // #cgo CFLAGS: -DSQLITE_ENABLE_COLUMN_METADATA
    29  // #cgo CFLAGS: -DSQLITE_ENABLE_STAT4
    30  // #cgo CFLAGS: -DSQLITE_ENABLE_DBSTAT_VTAB=1
    31  // #cgo CFLAGS: -DHAVE_USLEEP=1
    32  //
    33  // // Select at least X/Open SUSv3 (POSIX.1-2001, 2004 edition) for clock_gettime.
    34  // #cgo CFLAGS: -D_XOPEN_SOURCE=600
    35  // #cgo linux CFLAGS: -std=c99
    36  //
    37  // // On Android, unlike Linux, there are no separate libpthread or librt
    38  // // libraries. That functionality is included directly in libc, which does not
    39  // // need to be explicitly linked against. See
    40  // // https://developer.android.com/ndk/guides/stable_apis#c_library.
    41  // #cgo android LDFLAGS: -ldl -lm
    42  // #cgo linux,!android LDFLAGS: -ldl -lm -lrt
    43  //
    44  // #include <stdint.h>
    45  // #include <stdlib.h>
    46  // #include <string.h>
    47  // #include <pthread.h>
    48  // #include <sqlite3.h>
    49  // #include <time.h>
    50  // #include "cgosqlite.h"
    51  import "C"
    52  import (
    53  	"sync"
    54  	"time"
    55  	"unsafe"
    56  
    57  	"github.com/tailscale/sqlite/sqliteh"
    58  )
    59  
    60  func init() {
    61  	C.sqlite3_initialize()
    62  }
    63  
// DB implements sqliteh.DB.
//
// It wraps one C connection handle, obtained either from sqlite3_open_v2
// (see Open) or from sqlite3_db_handle (see Stmt.DBHandle).
type DB struct {
	// db is the underlying C connection handle.
	db *C.sqlite3

	// declTypes caches the strings returned by ColumnDeclType (key == value)
	// so each distinct declared type is copied into Go memory only once per
	// DB value. It is allocated lazily by ColumnDeclType.
	// NOTE(review): not guarded by any lock; concurrent ColumnDeclType calls
	// on statements sharing this DB would race — confirm callers serialize.
	declTypes map[string]string
}
    70  
// cStmt is a wrapper around an sqlite3 *sqlite3_stmt. Rather than storing
// it as a pointer, it is stored as an integer handle (built from a uintptr)
// to avoid allocations caused by poor interactions between cgo's pointer
// checker and Go's escape analysis.
//
// The ptr method returns the value as a pointer, for call sites that haven't
// yet been optimized or don't need the optimization. This lets us migrate
// incrementally. The int method returns the raw handle for the C helpers
// that accept it directly.
//
// See http://go/corp/9919.
type cStmt struct {
	// v is the statement's address as an integer; zero means "no statement"
	// (e.g. Prepare succeeded but SQLite produced no statement).
	v C.handle_sqlite3_stmt
}
    83  
    84  // cStmtFromPtr returns a cStmt from a C pointer.
    85  func cStmtFromPtr(p *C.sqlite3_stmt) cStmt {
    86  	return cStmt{v: C.handle_sqlite3_stmt(uintptr(unsafe.Pointer(p)))}
    87  }
    88  
    89  func (h cStmt) int() C.handle_sqlite3_stmt { return h.v }
    90  func (h cStmt) ptr() *C.sqlite3_stmt       { return (*C.sqlite3_stmt)(unsafe.Pointer(uintptr(h.v))) }
    91  
// Stmt implements sqliteh.Stmt.
type Stmt struct {
	// db is the connection that prepared this statement (set by Prepare).
	db    *DB
	// stmt is the C statement handle; its zero value means SQLite produced
	// no statement (for example, the query contained only comments).
	stmt  cStmt
	// start is written by StartTimer; when non-zero, ResetAndClear passes it
	// to the C helper and returns the duration the helper reports.
	start C.struct_timespec

	// used as scratch space when calling into cgo (output parameters of
	// step_result / reset_and_clear). NOTE(review): presumably kept on the
	// Stmt so that taking their address does not cost a heap allocation per
	// call — confirm.
	rowid, changes C.sqlite3_int64
	duration       C.int64_t
}
   102  
   103  // Open implements sqliteh.OpenFunc.
   104  func Open(filename string, flags sqliteh.OpenFlags, vfs string) (sqliteh.DB, error) {
   105  	cfilename := C.CString(filename)
   106  	defer C.free(unsafe.Pointer(cfilename))
   107  
   108  	cvfs := (*C.char)(nil)
   109  	if vfs != "" {
   110  		cvfs = C.CString(vfs)
   111  		defer C.free(unsafe.Pointer(cvfs))
   112  	}
   113  
   114  	var cdb *C.sqlite3
   115  	res := C.sqlite3_open_v2(cfilename, &cdb, C.int(flags), cvfs)
   116  	var db *DB
   117  	if cdb != nil {
   118  		db = &DB{db: cdb}
   119  	}
   120  	return db, errCode(res)
   121  }
   122  
   123  func (db *DB) Close() error {
   124  	// TODO(crawshaw): consider using sqlite3_close_v2, if we are going to use finalizers for cleanup.
   125  	walHookFunc.Delete(db.db)
   126  	res := C.sqlite3_close(db.db)
   127  	return errCode(res)
   128  }
   129  
   130  func (db *DB) Interrupt() {
   131  	C.sqlite3_interrupt(db.db)
   132  }
   133  
   134  func (db *DB) ErrMsg() string {
   135  	return C.GoString(C.sqlite3_errmsg(db.db))
   136  }
   137  
   138  func (db *DB) Changes() int {
   139  	return int(C.sqlite3_changes(db.db))
   140  }
   141  
   142  func (db *DB) TotalChanges() int {
   143  	return int(C.sqlite3_total_changes(db.db))
   144  }
   145  
   146  func (db *DB) ExtendedErrCode() sqliteh.Code {
   147  	return sqliteh.Code(C.sqlite3_extended_errcode(db.db))
   148  }
   149  
   150  func (db *DB) LastInsertRowid() int64 {
   151  	return int64(C.sqlite3_last_insert_rowid(db.db))
   152  }
   153  
   154  func (db *DB) BusyTimeout(d time.Duration) {
   155  	C.sqlite3_busy_timeout(db.db, C.int(d/1e6))
   156  }
   157  
   158  func (db *DB) Checkpoint(dbName string, mode sqliteh.Checkpoint) (int, int, error) {
   159  	var cDB *C.char
   160  	if dbName != "" {
   161  		// Docs say: "If parameter zDb is NULL or points to a zero length string",
   162  		// so they are equivalent here.
   163  		cDB = C.CString(dbName)
   164  		defer C.free(unsafe.Pointer(cDB))
   165  	}
   166  	var nLog, nCkpt C.int
   167  	res := C.sqlite3_wal_checkpoint_v2(db.db, cDB, C.int(mode), &nLog, &nCkpt)
   168  	return int(nLog), int(nCkpt), errCode(res)
   169  }
   170  
   171  func (db *DB) AutoCheckpoint(n int) error {
   172  	res := C.sqlite3_wal_autocheckpoint(db.db, C.int(n))
   173  	return errCode(res)
   174  }
   175  
   176  func (db *DB) SetWALHook(f func(dbName string, pages int)) {
   177  	if f != nil {
   178  		walHookFunc.Store(db.db, walHookCb(f))
   179  	} else {
   180  		walHookFunc.Delete(db.db)
   181  	}
   182  	C.ts_sqlite3_wal_hook_go(db.db)
   183  }
   184  
   185  func (db *DB) TxnState(schema string) sqliteh.TxnState {
   186  	var cSchema *C.char
   187  	if schema != "" {
   188  		cSchema = C.CString(schema)
   189  		defer C.free(unsafe.Pointer(cSchema))
   190  	}
   191  	return sqliteh.TxnState(C.sqlite3_txn_state(db.db, cSchema))
   192  }
   193  
   194  func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqliteh.Stmt, remainingQuery string, err error) {
   195  	csql := C.CString(query)
   196  	defer C.free(unsafe.Pointer(csql))
   197  
   198  	var cstmt *C.sqlite3_stmt
   199  	var csqlTail *C.char
   200  	res := C.sqlite3_prepare_v3(db.db, csql, C.int(len(query))+1, C.uint(prepFlags), &cstmt, &csqlTail)
   201  	if err := errCode(res); err != nil {
   202  		return nil, "", err
   203  	}
   204  	remainingQuery = query[len(query)-int(C.strlen(csqlTail)):]
   205  	return &Stmt{db: db, stmt: cStmtFromPtr(cstmt)}, remainingQuery, nil
   206  }
   207  
   208  func (stmt *Stmt) DBHandle() sqliteh.DB {
   209  	cdb := C.sqlite3_db_handle(stmt.stmt.ptr())
   210  	if cdb != nil {
   211  		return &DB{db: cdb}
   212  	}
   213  	return nil
   214  }
   215  
   216  func (stmt *Stmt) SQL() string {
   217  	return C.GoString(C.sqlite3_sql(stmt.stmt.ptr()))
   218  }
   219  
   220  func (stmt *Stmt) ExpandedSQL() string {
   221  	return C.GoString(C.sqlite3_expanded_sql(stmt.stmt.ptr()))
   222  }
   223  
   224  func (stmt *Stmt) Reset() error {
   225  	return errCode(C.sqlite3_reset(stmt.stmt.ptr()))
   226  }
   227  
   228  func (stmt *Stmt) Finalize() error {
   229  	return errCode(C.sqlite3_finalize(stmt.stmt.ptr()))
   230  }
   231  
   232  func (stmt *Stmt) ClearBindings() error {
   233  	return errCode(C.sqlite3_clear_bindings(stmt.stmt.ptr()))
   234  }
   235  
   236  func (stmt *Stmt) ResetAndClear() (time.Duration, error) {
   237  	if stmt.start != (C.struct_timespec{}) {
   238  		stmt.duration = 0
   239  		err := errCode(C.reset_and_clear(stmt.stmt.int(), &stmt.start, &stmt.duration))
   240  		return time.Duration(stmt.duration), err
   241  	}
   242  	if sp := stmt.stmt.int(); sp != 0 {
   243  		return 0, errCode(C.reset_and_clear(stmt.stmt.int(), nil, nil))
   244  	}
   245  	// The statement was never initialized. This can happen if, for example, the
   246  	// parser found only comments (so the statement was not empty, but did not
   247  	// yield any instructions).
   248  	return 0, nil
   249  }
   250  
   251  func (stmt *Stmt) StartTimer() {
   252  	C.monotonic_clock_gettime(&stmt.start)
   253  }
   254  
   255  func (stmt *Stmt) ColumnDatabaseName(col int) string {
   256  	return internStringFromCString(C.sqlite3_column_database_name(stmt.stmt.ptr(), C.int(col)))
   257  }
   258  
   259  func (stmt *Stmt) ColumnTableName(col int) string {
   260  	return internStringFromCString(C.sqlite3_column_table_name(stmt.stmt.ptr(), C.int(col)))
   261  }
   262  
   263  func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) {
   264  	var ptr *C.char
   265  	if len(colType) > 0 {
   266  		ptr = (*C.char)(unsafe.Pointer(&colType[0]))
   267  	}
   268  	res := C.ts_sqlite3_step(stmt.stmt.int(), ptr, C.int(len(colType)))
   269  	switch res {
   270  	case C.SQLITE_ROW:
   271  		return true, nil
   272  	case C.SQLITE_DONE:
   273  		return false, nil
   274  	default:
   275  		return false, errCode(res)
   276  	}
   277  }
   278  
   279  func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time.Duration, err error) {
   280  	stmt.rowid, stmt.changes, stmt.duration = 0, 0, 0
   281  	res := C.step_result(stmt.stmt.int(), &stmt.rowid, &stmt.changes, &stmt.duration)
   282  	lastInsertRowID = int64(stmt.rowid)
   283  	changes = int64(stmt.changes)
   284  	d = time.Duration(stmt.duration)
   285  
   286  	switch res {
   287  	case C.SQLITE_ROW:
   288  		return true, lastInsertRowID, changes, d, nil
   289  	case C.SQLITE_DONE:
   290  		return false, lastInsertRowID, changes, d, nil
   291  	default:
   292  		return false, lastInsertRowID, changes, d, errCode(res)
   293  	}
   294  }
   295  
   296  func (stmt *Stmt) BindDouble(col int, val float64) error {
   297  	return errCode(C.ts_sqlite3_bind_double(stmt.stmt.int(), C.int(col), C.double(val)))
   298  }
   299  
   300  func (stmt *Stmt) BindInt64(col int, val int64) error {
   301  	return errCode(C.ts_sqlite3_bind_int64(stmt.stmt.int(), C.int(col), C.sqlite3_int64(val)))
   302  }
   303  
   304  func (stmt *Stmt) BindNull(col int) error {
   305  	return errCode(C.ts_sqlite3_bind_null(stmt.stmt.int(), C.int(col)))
   306  }
   307  
   308  func (stmt *Stmt) BindText64(col int, val string) error {
   309  	if len(val) == 0 {
   310  		return errCode(C.bind_text64_empty(stmt.stmt.int(), C.int(col)))
   311  	}
   312  	v := C.CString(val) // freed by sqlite
   313  	return errCode(C.bind_text64(stmt.stmt.int(), C.int(col), v, C.sqlite3_uint64(len(val))))
   314  }
   315  
   316  func (stmt *Stmt) BindZeroBlob64(col int, n uint64) error {
   317  	return errCode(C.sqlite3_bind_zeroblob64(stmt.stmt.ptr(), C.int(col), C.sqlite3_uint64(n)))
   318  }
   319  
   320  func (stmt *Stmt) BindBlob64(col int, val []byte) error {
   321  	var str *C.char
   322  	if len(val) > 0 {
   323  		str = (*C.char)(unsafe.Pointer(&val[0]))
   324  	}
   325  	return errCode(C.bind_blob64(stmt.stmt.int(), C.int(col), str, C.sqlite3_uint64(len(val))))
   326  }
   327  
   328  func (stmt *Stmt) BindParameterCount() int {
   329  	return int(C.sqlite3_bind_parameter_count(stmt.stmt.ptr()))
   330  }
   331  
   332  func (stmt *Stmt) BindParameterName(col int) string {
   333  	return internStringFromCString(C.sqlite3_bind_parameter_name(stmt.stmt.ptr(), C.int(col)))
   334  }
   335  
   336  func (stmt *Stmt) BindParameterIndex(name string) int {
   337  	return int(C.bind_parameter_index(stmt.stmt.int(), name))
   338  }
   339  
   340  func (stmt *Stmt) BindParameterIndexSearch(name string) int {
   341  	// TODO: do prepend in C to save allocation
   342  	if i := stmt.BindParameterIndex(":" + name); i > 0 {
   343  		return i
   344  	}
   345  	if i := stmt.BindParameterIndex("@" + name); i > 0 {
   346  		return i
   347  	}
   348  	return stmt.BindParameterIndex("?" + name)
   349  }
   350  
   351  func (stmt *Stmt) ColumnCount() int {
   352  	return int(C.sqlite3_column_count(stmt.stmt.ptr()))
   353  }
   354  
   355  func (stmt *Stmt) ColumnName(col int) string {
   356  	return internStringFromCString(C.sqlite3_column_name(stmt.stmt.ptr(), C.int(col)))
   357  }
   358  
   359  func (stmt *Stmt) ColumnText(col int) string {
   360  	str := (*C.char)(unsafe.Pointer(C.ts_sqlite3_column_text(stmt.stmt.int(), C.int(col))))
   361  	n := C.ts_sqlite3_column_bytes(stmt.stmt.int(), C.int(col))
   362  	if str == nil || n == 0 {
   363  		return ""
   364  	}
   365  	return C.GoStringN(str, n)
   366  }
   367  
   368  func (stmt *Stmt) ColumnBlob(col int) []byte {
   369  	res := C.ts_sqlite3_column_blob(stmt.stmt.int(), C.int(col))
   370  	if res == nil {
   371  		return nil
   372  	}
   373  	n := int(C.ts_sqlite3_column_bytes(stmt.stmt.int(), C.int(col)))
   374  	return unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
   375  }
   376  
   377  func (stmt *Stmt) ColumnDouble(col int) float64 {
   378  	return float64(C.ts_sqlite3_column_double(stmt.stmt.int(), C.int(col)))
   379  }
   380  
   381  func (stmt *Stmt) ColumnInt64(col int) int64 {
   382  	return int64(C.ts_sqlite3_column_int64(stmt.stmt.int(), C.int(col)))
   383  }
   384  
   385  func (stmt *Stmt) ColumnType(col int) sqliteh.ColumnType {
   386  	return sqliteh.ColumnType(C.ts_sqlite3_column_type(stmt.stmt.int(), C.int(col)))
   387  }
   388  
   389  func (stmt *Stmt) ColumnDeclType(col int) string {
   390  	cstr := C.sqlite3_column_decltype(stmt.stmt.ptr(), C.int(col))
   391  	if cstr == nil {
   392  		return ""
   393  	}
   394  	bstr := (*byte)(unsafe.Pointer(cstr))
   395  	clen := findnull(bstr)
   396  	if stmt.db.declTypes == nil {
   397  		stmt.db.declTypes = make(map[string]string)
   398  	}
   399  	if res, found := stmt.db.declTypes[unsafe.String(bstr, clen)]; found {
   400  		return res
   401  	}
   402  	res := C.GoStringN(cstr, C.int(clen))
   403  	stmt.db.declTypes[res] = res
   404  	return res
   405  }
   406  
   407  var emptyCStr = C.CString("")
   408  
   409  func errCode(code C.int) error { return sqliteh.CodeAsError(sqliteh.Code(code)) }
   410  
// internCache contains interned strings. Each key is identical to its
// value; entries are added by internStringFromPtr and never removed.
var internCache sync.Map // string => string (key == value)
   413  
   414  // internStringFromBytes returns string(b), interned into a map forever. It's meant
   415  // for use on hot, small strings from closed set (like database or table or
   416  // column names) where it doesn't matter if it leaks forever.
   417  func internStringFromBytes(b []byte) string {
   418  	if len(b) == 0 {
   419  		return ""
   420  	}
   421  	return internStringFromPtr((*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
   422  }
   423  
   424  func internStringFromPtr(p *C.char, n C.int) string {
   425  	if n == 0 {
   426  		return ""
   427  	}
   428  	v, _ := internCache.Load(unsafe.String((*byte)(unsafe.Pointer(p)), int(n)))
   429  	if s, ok := v.(string); ok {
   430  		return s
   431  	}
   432  	s := C.GoStringN(p, n)
   433  	internCache.Store(s, s)
   434  	return s
   435  }
   436  
   437  func internStringFromCString(p *C.char) string {
   438  	return internStringFromPtr(p, C.int(findnull((*byte)(unsafe.Pointer(p)))))
   439  }