github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/sqlite.go (about)

     1  // Package sqlite3 wraps the C SQLite API.
     2  package sqlite3
     3  
     4  import (
     5  	"context"
     6  	"math"
     7  	"math/bits"
     8  	"os"
     9  	"sync"
    10  	"unsafe"
    11  
    12  	"github.com/ncruces/go-sqlite3/internal/util"
    13  	"github.com/ncruces/go-sqlite3/vfs"
    14  	"github.com/tetratelabs/wazero"
    15  	"github.com/tetratelabs/wazero/api"
    16  )
    17  
    18  // Configure SQLite Wasm.
    19  //
    20  // Importing package embed initializes [Binary]
    21  // with an appropriate build of SQLite:
    22  //
    23  //	import _ "github.com/ncruces/go-sqlite3/embed"
    24  var (
    25  	Binary []byte // Wasm binary to load.
    26  	Path   string // Path to load the binary from.
    27  
    28  	RuntimeConfig wazero.RuntimeConfig
    29  )
    30  
    31  // Initialize decodes and compiles the SQLite Wasm binary.
    32  // This is called implicitly when the first connection is openned,
    33  // but is potentially slow, so you may want to call it at a more convenient time.
    34  func Initialize() error {
    35  	instance.once.Do(compileSQLite)
    36  	return instance.err
    37  }
    38  
    39  var instance struct {
    40  	runtime  wazero.Runtime
    41  	compiled wazero.CompiledModule
    42  	err      error
    43  	once     sync.Once
    44  }
    45  
    46  func compileSQLite() {
    47  	if RuntimeConfig == nil {
    48  		RuntimeConfig = wazero.NewRuntimeConfig()
    49  	}
    50  
    51  	ctx := context.Background()
    52  	instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig)
    53  
    54  	env := instance.runtime.NewHostModuleBuilder("env")
    55  	env = vfs.ExportHostFunctions(env)
    56  	env = exportCallbacks(env)
    57  	_, instance.err = env.Instantiate(ctx)
    58  	if instance.err != nil {
    59  		return
    60  	}
    61  
    62  	bin := Binary
    63  	if bin == nil && Path != "" {
    64  		bin, instance.err = os.ReadFile(Path)
    65  		if instance.err != nil {
    66  			return
    67  		}
    68  	}
    69  	if bin == nil {
    70  		instance.err = util.NoBinaryErr
    71  		return
    72  	}
    73  
    74  	instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
    75  }
    76  
    77  type sqlite struct {
    78  	ctx   context.Context
    79  	mod   api.Module
    80  	funcs struct {
    81  		fn   [32]api.Function
    82  		id   [32]*byte
    83  		mask uint32
    84  	}
    85  	stack [8]uint64
    86  	freer uint32
    87  }
    88  
    89  func instantiateSQLite() (sqlt *sqlite, err error) {
    90  	if err := Initialize(); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	sqlt = new(sqlite)
    95  	sqlt.ctx = util.NewContext(context.Background())
    96  
    97  	sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx,
    98  		instance.compiled, wazero.NewModuleConfig().WithName(""))
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	global := sqlt.mod.ExportedGlobal("malloc_destructor")
   104  	if global == nil {
   105  		return nil, util.BadBinaryErr
   106  	}
   107  
   108  	sqlt.freer = util.ReadUint32(sqlt.mod, uint32(global.Get()))
   109  	if sqlt.freer == 0 {
   110  		return nil, util.BadBinaryErr
   111  	}
   112  	return sqlt, nil
   113  }
   114  
   115  func (sqlt *sqlite) close() error {
   116  	return sqlt.mod.Close(sqlt.ctx)
   117  }
   118  
   119  func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
   120  	if rc == _OK {
   121  		return nil
   122  	}
   123  
   124  	err := Error{code: rc}
   125  
   126  	if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
   127  		panic(util.OOMErr)
   128  	}
   129  
   130  	if r := sqlt.call("sqlite3_errstr", rc); r != 0 {
   131  		err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME)
   132  	}
   133  
   134  	if handle != 0 {
   135  		if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 {
   136  			err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH)
   137  		}
   138  
   139  		if sql != nil {
   140  			if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 {
   141  				err.sql = sql[0][r:]
   142  			}
   143  		}
   144  	}
   145  
   146  	switch err.msg {
   147  	case err.str, "not an error":
   148  		err.msg = ""
   149  	}
   150  	return &err
   151  }
   152  
   153  func (sqlt *sqlite) getfn(name string) api.Function {
   154  	c := &sqlt.funcs
   155  	p := unsafe.StringData(name)
   156  	for i := range c.id {
   157  		if c.id[i] == p {
   158  			c.id[i] = nil
   159  			c.mask &^= uint32(1) << i
   160  			return c.fn[i]
   161  		}
   162  	}
   163  	return sqlt.mod.ExportedFunction(name)
   164  }
   165  
   166  func (sqlt *sqlite) putfn(name string, fn api.Function) {
   167  	c := &sqlt.funcs
   168  	p := unsafe.StringData(name)
   169  	i := bits.TrailingZeros32(^c.mask)
   170  	if i < 32 {
   171  		c.id[i] = p
   172  		c.fn[i] = fn
   173  		c.mask |= uint32(1) << i
   174  	} else {
   175  		c.id[0] = p
   176  		c.fn[0] = fn
   177  		c.mask = uint32(1)
   178  	}
   179  }
   180  
   181  func (sqlt *sqlite) call(name string, params ...uint64) uint64 {
   182  	copy(sqlt.stack[:], params)
   183  	fn := sqlt.getfn(name)
   184  	err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
   185  	if err != nil {
   186  		panic(err)
   187  	}
   188  	sqlt.putfn(name, fn)
   189  	return sqlt.stack[0]
   190  }
   191  
   192  func (sqlt *sqlite) free(ptr uint32) {
   193  	if ptr == 0 {
   194  		return
   195  	}
   196  	sqlt.call("free", uint64(ptr))
   197  }
   198  
   199  func (sqlt *sqlite) new(size uint64) uint32 {
   200  	if size > _MAX_ALLOCATION_SIZE {
   201  		panic(util.OOMErr)
   202  	}
   203  	ptr := uint32(sqlt.call("malloc", size))
   204  	if ptr == 0 && size != 0 {
   205  		panic(util.OOMErr)
   206  	}
   207  	return ptr
   208  }
   209  
   210  func (sqlt *sqlite) newBytes(b []byte) uint32 {
   211  	if (*[0]byte)(b) == nil {
   212  		return 0
   213  	}
   214  	ptr := sqlt.new(uint64(len(b)))
   215  	util.WriteBytes(sqlt.mod, ptr, b)
   216  	return ptr
   217  }
   218  
   219  func (sqlt *sqlite) newString(s string) uint32 {
   220  	ptr := sqlt.new(uint64(len(s) + 1))
   221  	util.WriteString(sqlt.mod, ptr, s)
   222  	return ptr
   223  }
   224  
   225  func (sqlt *sqlite) newArena(size uint64) arena {
   226  	// Ensure the arena's size is a multiple of 8.
   227  	size = (size + 7) &^ 7
   228  	return arena{
   229  		sqlt: sqlt,
   230  		size: uint32(size),
   231  		base: sqlt.new(size),
   232  	}
   233  }
   234  
   235  type arena struct {
   236  	sqlt *sqlite
   237  	ptrs []uint32
   238  	base uint32
   239  	next uint32
   240  	size uint32
   241  }
   242  
   243  func (a *arena) free() {
   244  	if a.sqlt == nil {
   245  		return
   246  	}
   247  	for _, ptr := range a.ptrs {
   248  		a.sqlt.free(ptr)
   249  	}
   250  	a.sqlt.free(a.base)
   251  	a.sqlt = nil
   252  }
   253  
   254  func (a *arena) mark() (reset func()) {
   255  	ptrs := len(a.ptrs)
   256  	next := a.next
   257  	return func() {
   258  		for _, ptr := range a.ptrs[ptrs:] {
   259  			a.sqlt.free(ptr)
   260  		}
   261  		a.ptrs = a.ptrs[:ptrs]
   262  		a.next = next
   263  	}
   264  }
   265  
   266  func (a *arena) new(size uint64) uint32 {
   267  	// Align the next address, to 4 or 8 bytes.
   268  	if size&7 != 0 {
   269  		a.next = (a.next + 3) &^ 3
   270  	} else {
   271  		a.next = (a.next + 7) &^ 7
   272  	}
   273  	if size <= uint64(a.size-a.next) {
   274  		ptr := a.base + a.next
   275  		a.next += uint32(size)
   276  		return ptr
   277  	}
   278  	ptr := a.sqlt.new(size)
   279  	a.ptrs = append(a.ptrs, ptr)
   280  	return ptr
   281  }
   282  
   283  func (a *arena) bytes(b []byte) uint32 {
   284  	if (*[0]byte)(b) == nil {
   285  		return 0
   286  	}
   287  	ptr := a.new(uint64(len(b)))
   288  	util.WriteBytes(a.sqlt.mod, ptr, b)
   289  	return ptr
   290  }
   291  
   292  func (a *arena) string(s string) uint32 {
   293  	ptr := a.new(uint64(len(s) + 1))
   294  	util.WriteString(a.sqlt.mod, ptr, s)
   295  	return ptr
   296  }
   297  
// exportCallbacks registers, on the "env" host module builder, the Go
// callbacks exposed to the SQLite Wasm module: hooks and handlers, logging,
// user-defined functions, collations, virtual tables and their cursors.
// It returns env so it can be chained with other builder helpers.
//
// NOTE(review): the suffix of each util.ExportFunc* helper appears to encode
// the Wasm signature. The first letter would be the result (V for none) and
// each following letter a parameter (I for i32, J for i64); confirm in
// internal/util. Presumably every name and signature here must match an
// import of the SQLite Wasm binary.
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
	// Connection hooks and handlers.
	util.ExportFuncII(env, "go_progress_handler", progressCallback)
	util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback)
	util.ExportFuncIII(env, "go_busy_handler", busyCallback)
	util.ExportFuncII(env, "go_commit_hook", commitCallback)
	util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)
	util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback)
	util.ExportFuncIIIII(env, "go_wal_hook", walCallback)
	util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback)
	util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback)
	// Logging and destruction of values handed to SQLite.
	util.ExportFuncVIII(env, "go_log", logCallback)
	util.ExportFuncVI(env, "go_destroy", destroyCallback)
	// SQL functions: scalar (xFunc), aggregate (xStep/xFinal), window (xValue/xInverse).
	util.ExportFuncVIIII(env, "go_func", funcCallback)
	util.ExportFuncVIIIII(env, "go_step", stepCallback)
	util.ExportFuncVIII(env, "go_final", finalCallback)
	util.ExportFuncVII(env, "go_value", valueCallback)
	util.ExportFuncVIIII(env, "go_inverse", inverseCallback)
	// Collations.
	util.ExportFuncVIIII(env, "go_collation_needed", collationCallback)
	util.ExportFuncIIIIII(env, "go_compare", compareCallback)
	// Virtual table module methods.
	util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate))
	util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect))
	util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback)
	util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback)
	util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback)
	util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback)
	util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback)
	util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback)
	util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback)
	util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback)
	util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback)
	util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback)
	util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback)
	util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback)
	util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback)
	util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback)
	// Virtual table cursor methods.
	util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback)
	util.ExportFuncII(env, "go_cur_close", cursorCloseCallback)
	util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback)
	util.ExportFuncII(env, "go_cur_next", cursorNextCallback)
	util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback)
	util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback)
	util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback)
	return env
}