github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/func.go (about) 1 package sqlite3 2 3 import ( 4 "context" 5 "sync" 6 7 "github.com/ncruces/go-sqlite3/internal/util" 8 "github.com/tetratelabs/wazero/api" 9 ) 10 11 // CollationNeeded registers a callback to be invoked 12 // whenever an unknown collation sequence is required. 13 // 14 // https://sqlite.org/c3ref/collation_needed.html 15 func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { 16 var enable uint64 17 if cb != nil { 18 enable = 1 19 } 20 r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable) 21 if err := c.error(r); err != nil { 22 return err 23 } 24 c.collation = cb 25 return nil 26 } 27 28 // AnyCollationNeeded uses [Conn.CollationNeeded] to register 29 // a fake collating function for any unknown collating sequence. 30 // The fake collating function works like BINARY. 31 // 32 // This can be used to load schemas that contain 33 // one or more unknown collating sequences. 34 func (c *Conn) AnyCollationNeeded() { 35 c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) 36 } 37 38 // CreateCollation defines a new collating sequence. 39 // 40 // https://sqlite.org/c3ref/create_collation.html 41 func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { 42 defer c.arena.mark()() 43 namePtr := c.arena.string(name) 44 funcPtr := util.AddHandle(c.ctx, fn) 45 r := c.call("sqlite3_create_collation_go", 46 uint64(c.handle), uint64(namePtr), uint64(funcPtr)) 47 return c.error(r) 48 } 49 50 // CreateFunction defines a new scalar SQL function. 51 // 52 // https://sqlite.org/c3ref/create_function.html 53 func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { 54 defer c.arena.mark()() 55 namePtr := c.arena.string(name) 56 funcPtr := util.AddHandle(c.ctx, fn) 57 r := c.call("sqlite3_create_function_go", 58 uint64(c.handle), uint64(namePtr), uint64(nArg), 59 uint64(flag), uint64(funcPtr)) 60 return c.error(r) 61 } 62 63 // ScalarFunction is the type of a scalar SQL function. 64 // Implementations must not retain arg. 65 type ScalarFunction func(ctx Context, arg ...Value) 66 67 // CreateWindowFunction defines a new aggregate or aggregate window SQL function. 68 // If fn returns a [WindowFunction], then an aggregate window function is created. 69 // If fn returns an [io.Closer], it will be called to free resources. 70 // 71 // https://sqlite.org/c3ref/create_function.html 72 func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { 73 defer c.arena.mark()() 74 call := "sqlite3_create_aggregate_function_go" 75 namePtr := c.arena.string(name) 76 funcPtr := util.AddHandle(c.ctx, fn) 77 if _, ok := fn().(WindowFunction); ok { 78 call = "sqlite3_create_window_function_go" 79 } 80 r := c.call(call, 81 uint64(c.handle), uint64(namePtr), uint64(nArg), 82 uint64(flag), uint64(funcPtr)) 83 return c.error(r) 84 } 85 86 // AggregateFunction is the interface an aggregate function should implement. 87 // 88 // https://sqlite.org/appfunc.html 89 type AggregateFunction interface { 90 // Step is invoked to add a row to the current window. 91 // The function arguments, if any, corresponding to the row being added, are passed to Step. 92 // Implementations must not retain arg. 93 Step(ctx Context, arg ...Value) 94 95 // Value is invoked to return the current (or final) value of the aggregate. 96 Value(ctx Context) 97 } 98 99 // WindowFunction is the interface an aggregate window function should implement. 100 // 101 // https://sqlite.org/windowfunctions.html 102 type WindowFunction interface { 103 AggregateFunction 104 105 // Inverse is invoked to remove the oldest presently aggregated result of Step from the current window. 106 // The function arguments, if any, are those passed to Step for the row being removed. 107 // Implementations must not retain arg. 108 Inverse(ctx Context, arg ...Value) 109 } 110 111 // OverloadFunction overloads a function for a virtual table. 112 // 113 // https://sqlite.org/c3ref/overload_function.html 114 func (c *Conn) OverloadFunction(name string, nArg int) error { 115 defer c.arena.mark()() 116 namePtr := c.arena.string(name) 117 r := c.call("sqlite3_overload_function", 118 uint64(c.handle), uint64(namePtr), uint64(nArg)) 119 return c.error(r) 120 } 121 122 func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) { 123 util.DelHandle(ctx, pApp) 124 } 125 126 func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) { 127 if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil { 128 name := util.ReadString(mod, zName, _MAX_NAME) 129 c.collation(c, name) 130 } 131 } 132 133 func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 { 134 fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) 135 return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2)))) 136 } 137 138 func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) { 139 args := getFuncArgs() 140 defer putFuncArgs(args) 141 db := ctx.Value(connKey{}).(*Conn) 142 fn := util.GetHandle(db.ctx, pApp).(ScalarFunction) 143 callbackArgs(db, args[:nArg], pArg) 144 fn(Context{db, pCtx}, args[:nArg]...) 145 } 146 147 func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) { 148 args := getFuncArgs() 149 defer putFuncArgs(args) 150 db := ctx.Value(connKey{}).(*Conn) 151 callbackArgs(db, args[:nArg], pArg) 152 fn, _ := callbackAggregate(db, pAgg, pApp) 153 fn.Step(Context{db, pCtx}, args[:nArg]...) 154 } 155 156 func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) { 157 db := ctx.Value(connKey{}).(*Conn) 158 fn, handle := callbackAggregate(db, pAgg, pApp) 159 fn.Value(Context{db, pCtx}) 160 util.DelHandle(ctx, handle) 161 } 162 163 func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) { 164 db := ctx.Value(connKey{}).(*Conn) 165 fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction) 166 fn.Value(Context{db, pCtx}) 167 } 168 169 func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) { 170 args := getFuncArgs() 171 defer putFuncArgs(args) 172 db := ctx.Value(connKey{}).(*Conn) 173 callbackArgs(db, args[:nArg], pArg) 174 fn := util.GetHandle(db.ctx, pAgg).(WindowFunction) 175 fn.Inverse(Context{db, pCtx}, args[:nArg]...) 176 } 177 178 func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) { 179 if pApp == 0 { 180 handle := util.ReadUint32(db.mod, pAgg) 181 return util.GetHandle(db.ctx, handle).(AggregateFunction), handle 182 } 183 184 // We need to create the aggregate. 185 fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() 186 handle := util.AddHandle(db.ctx, fn) 187 if pAgg != 0 { 188 util.WriteUint32(db.mod, pAgg, handle) 189 } 190 return fn, handle 191 } 192 193 func callbackArgs(db *Conn, arg []Value, pArg uint32) { 194 for i := range arg { 195 arg[i] = Value{ 196 c: db, 197 handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)), 198 } 199 } 200 } 201 202 var funcArgsPool sync.Pool 203 204 func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { 205 funcArgsPool.Put(p) 206 } 207 208 func getFuncArgs() *[_MAX_FUNCTION_ARG]Value { 209 if p := funcArgsPool.Get(); p == nil { 210 return new([_MAX_FUNCTION_ARG]Value) 211 } else { 212 return p.(*[_MAX_FUNCTION_ARG]Value) 213 } 214 }