github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/sqlite.go (about) 1 // Package sqlite3 wraps the C SQLite API. 2 package sqlite3 3 4 import ( 5 "context" 6 "math" 7 "math/bits" 8 "os" 9 "sync" 10 "unsafe" 11 12 "github.com/ncruces/go-sqlite3/internal/util" 13 "github.com/ncruces/go-sqlite3/vfs" 14 "github.com/tetratelabs/wazero" 15 "github.com/tetratelabs/wazero/api" 16 ) 17 18 // Configure SQLite Wasm. 19 // 20 // Importing package embed initializes [Binary] 21 // with an appropriate build of SQLite: 22 // 23 // import _ "github.com/ncruces/go-sqlite3/embed" 24 var ( 25 Binary []byte // Wasm binary to load. 26 Path string // Path to load the binary from. 27 28 RuntimeConfig wazero.RuntimeConfig 29 ) 30 31 // Initialize decodes and compiles the SQLite Wasm binary. 32 // This is called implicitly when the first connection is openned, 33 // but is potentially slow, so you may want to call it at a more convenient time. 34 func Initialize() error { 35 instance.once.Do(compileSQLite) 36 return instance.err 37 } 38 39 var instance struct { 40 runtime wazero.Runtime 41 compiled wazero.CompiledModule 42 err error 43 once sync.Once 44 } 45 46 func compileSQLite() { 47 if RuntimeConfig == nil { 48 RuntimeConfig = wazero.NewRuntimeConfig() 49 } 50 51 ctx := context.Background() 52 instance.runtime = wazero.NewRuntimeWithConfig(ctx, RuntimeConfig) 53 54 env := instance.runtime.NewHostModuleBuilder("env") 55 env = vfs.ExportHostFunctions(env) 56 env = exportCallbacks(env) 57 _, instance.err = env.Instantiate(ctx) 58 if instance.err != nil { 59 return 60 } 61 62 bin := Binary 63 if bin == nil && Path != "" { 64 bin, instance.err = os.ReadFile(Path) 65 if instance.err != nil { 66 return 67 } 68 } 69 if bin == nil { 70 instance.err = util.NoBinaryErr 71 return 72 } 73 74 instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin) 75 } 76 77 type sqlite struct { 78 ctx context.Context 79 mod api.Module 80 funcs struct { 81 fn [32]api.Function 82 id [32]*byte 83 mask uint32 84 } 85 stack [8]uint64 86 freer uint32 87 } 88 89 func instantiateSQLite() (sqlt *sqlite, err error) { 90 if err := Initialize(); err != nil { 91 return nil, err 92 } 93 94 sqlt = new(sqlite) 95 sqlt.ctx = util.NewContext(context.Background()) 96 97 sqlt.mod, err = instance.runtime.InstantiateModule(sqlt.ctx, 98 instance.compiled, wazero.NewModuleConfig().WithName("")) 99 if err != nil { 100 return nil, err 101 } 102 103 global := sqlt.mod.ExportedGlobal("malloc_destructor") 104 if global == nil { 105 return nil, util.BadBinaryErr 106 } 107 108 sqlt.freer = util.ReadUint32(sqlt.mod, uint32(global.Get())) 109 if sqlt.freer == 0 { 110 return nil, util.BadBinaryErr 111 } 112 return sqlt, nil 113 } 114 115 func (sqlt *sqlite) close() error { 116 return sqlt.mod.Close(sqlt.ctx) 117 } 118 119 func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { 120 if rc == _OK { 121 return nil 122 } 123 124 err := Error{code: rc} 125 126 if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM { 127 panic(util.OOMErr) 128 } 129 130 if r := sqlt.call("sqlite3_errstr", rc); r != 0 { 131 err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_NAME) 132 } 133 134 if handle != 0 { 135 if r := sqlt.call("sqlite3_errmsg", uint64(handle)); r != 0 { 136 err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH) 137 } 138 139 if sql != nil { 140 if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 { 141 err.sql = sql[0][r:] 142 } 143 } 144 } 145 146 switch err.msg { 147 case err.str, "not an error": 148 err.msg = "" 149 } 150 return &err 151 } 152 153 func (sqlt *sqlite) getfn(name string) api.Function { 154 c := &sqlt.funcs 155 p := unsafe.StringData(name) 156 for i := range c.id { 157 if c.id[i] == p { 158 c.id[i] = nil 159 c.mask &^= uint32(1) << i 160 return c.fn[i] 161 } 162 } 163 return sqlt.mod.ExportedFunction(name) 164 } 165 166 func (sqlt *sqlite) putfn(name string, fn api.Function) { 167 c := &sqlt.funcs 168 p := unsafe.StringData(name) 169 i := bits.TrailingZeros32(^c.mask) 170 if i < 32 { 171 c.id[i] = p 172 c.fn[i] = fn 173 c.mask |= uint32(1) << i 174 } else { 175 c.id[0] = p 176 c.fn[0] = fn 177 c.mask = uint32(1) 178 } 179 } 180 181 func (sqlt *sqlite) call(name string, params ...uint64) uint64 { 182 copy(sqlt.stack[:], params) 183 fn := sqlt.getfn(name) 184 err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:]) 185 if err != nil { 186 panic(err) 187 } 188 sqlt.putfn(name, fn) 189 return sqlt.stack[0] 190 } 191 192 func (sqlt *sqlite) free(ptr uint32) { 193 if ptr == 0 { 194 return 195 } 196 sqlt.call("free", uint64(ptr)) 197 } 198 199 func (sqlt *sqlite) new(size uint64) uint32 { 200 if size > _MAX_ALLOCATION_SIZE { 201 panic(util.OOMErr) 202 } 203 ptr := uint32(sqlt.call("malloc", size)) 204 if ptr == 0 && size != 0 { 205 panic(util.OOMErr) 206 } 207 return ptr 208 } 209 210 func (sqlt *sqlite) newBytes(b []byte) uint32 { 211 if (*[0]byte)(b) == nil { 212 return 0 213 } 214 ptr := sqlt.new(uint64(len(b))) 215 util.WriteBytes(sqlt.mod, ptr, b) 216 return ptr 217 } 218 219 func (sqlt *sqlite) newString(s string) uint32 { 220 ptr := sqlt.new(uint64(len(s) + 1)) 221 util.WriteString(sqlt.mod, ptr, s) 222 return ptr 223 } 224 225 func (sqlt *sqlite) newArena(size uint64) arena { 226 // Ensure the arena's size is a multiple of 8. 227 size = (size + 7) &^ 7 228 return arena{ 229 sqlt: sqlt, 230 size: uint32(size), 231 base: sqlt.new(size), 232 } 233 } 234 235 type arena struct { 236 sqlt *sqlite 237 ptrs []uint32 238 base uint32 239 next uint32 240 size uint32 241 } 242 243 func (a *arena) free() { 244 if a.sqlt == nil { 245 return 246 } 247 for _, ptr := range a.ptrs { 248 a.sqlt.free(ptr) 249 } 250 a.sqlt.free(a.base) 251 a.sqlt = nil 252 } 253 254 func (a *arena) mark() (reset func()) { 255 ptrs := len(a.ptrs) 256 next := a.next 257 return func() { 258 for _, ptr := range a.ptrs[ptrs:] { 259 a.sqlt.free(ptr) 260 } 261 a.ptrs = a.ptrs[:ptrs] 262 a.next = next 263 } 264 } 265 266 func (a *arena) new(size uint64) uint32 { 267 // Align the next address, to 4 or 8 bytes. 268 if size&7 != 0 { 269 a.next = (a.next + 3) &^ 3 270 } else { 271 a.next = (a.next + 7) &^ 7 272 } 273 if size <= uint64(a.size-a.next) { 274 ptr := a.base + a.next 275 a.next += uint32(size) 276 return ptr 277 } 278 ptr := a.sqlt.new(size) 279 a.ptrs = append(a.ptrs, ptr) 280 return ptr 281 } 282 283 func (a *arena) bytes(b []byte) uint32 { 284 if (*[0]byte)(b) == nil { 285 return 0 286 } 287 ptr := a.new(uint64(len(b))) 288 util.WriteBytes(a.sqlt.mod, ptr, b) 289 return ptr 290 } 291 292 func (a *arena) string(s string) uint32 { 293 ptr := a.new(uint64(len(s) + 1)) 294 util.WriteString(a.sqlt.mod, ptr, s) 295 return ptr 296 } 297 298 func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { 299 util.ExportFuncII(env, "go_progress_handler", progressCallback) 300 util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback) 301 util.ExportFuncIII(env, "go_busy_handler", busyCallback) 302 util.ExportFuncII(env, "go_commit_hook", commitCallback) 303 util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) 304 util.ExportFuncVIIIIJ(env, "go_update_hook", updateCallback) 305 util.ExportFuncIIIII(env, "go_wal_hook", walCallback) 306 util.ExportFuncIIIIII(env, "go_autovacuum_pages", autoVacuumCallback) 307 util.ExportFuncIIIIIII(env, "go_authorizer", authorizerCallback) 308 util.ExportFuncVIII(env, "go_log", logCallback) 309 util.ExportFuncVI(env, "go_destroy", destroyCallback) 310 util.ExportFuncVIIII(env, "go_func", funcCallback) 311 util.ExportFuncVIIIII(env, "go_step", stepCallback) 312 util.ExportFuncVIII(env, "go_final", finalCallback) 313 util.ExportFuncVII(env, "go_value", valueCallback) 314 util.ExportFuncVIIII(env, "go_inverse", inverseCallback) 315 util.ExportFuncVIIII(env, "go_collation_needed", collationCallback) 316 util.ExportFuncIIIIII(env, "go_compare", compareCallback) 317 util.ExportFuncIIIIII(env, "go_vtab_create", vtabModuleCallback(xCreate)) 318 util.ExportFuncIIIIII(env, "go_vtab_connect", vtabModuleCallback(xConnect)) 319 util.ExportFuncII(env, "go_vtab_disconnect", vtabDisconnectCallback) 320 util.ExportFuncII(env, "go_vtab_destroy", vtabDestroyCallback) 321 util.ExportFuncIII(env, "go_vtab_best_index", vtabBestIndexCallback) 322 util.ExportFuncIIIII(env, "go_vtab_update", vtabUpdateCallback) 323 util.ExportFuncIII(env, "go_vtab_rename", vtabRenameCallback) 324 util.ExportFuncIIIII(env, "go_vtab_find_function", vtabFindFuncCallback) 325 util.ExportFuncII(env, "go_vtab_begin", vtabBeginCallback) 326 util.ExportFuncII(env, "go_vtab_sync", vtabSyncCallback) 327 util.ExportFuncII(env, "go_vtab_commit", vtabCommitCallback) 328 util.ExportFuncII(env, "go_vtab_rollback", vtabRollbackCallback) 329 util.ExportFuncIII(env, "go_vtab_savepoint", vtabSavepointCallback) 330 util.ExportFuncIII(env, "go_vtab_release", vtabReleaseCallback) 331 util.ExportFuncIII(env, "go_vtab_rollback_to", vtabRollbackToCallback) 332 util.ExportFuncIIIIII(env, "go_vtab_integrity", vtabIntegrityCallback) 333 util.ExportFuncIII(env, "go_cur_open", cursorOpenCallback) 334 util.ExportFuncII(env, "go_cur_close", cursorCloseCallback) 335 util.ExportFuncIIIIII(env, "go_cur_filter", cursorFilterCallback) 336 util.ExportFuncII(env, "go_cur_next", cursorNextCallback) 337 util.ExportFuncII(env, "go_cur_eof", cursorEOFCallback) 338 util.ExportFuncIIII(env, "go_cur_column", cursorColumnCallback) 339 util.ExportFuncIII(env, "go_cur_rowid", cursorRowIDCallback) 340 return env 341 }