github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/func.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/ncruces/go-sqlite3/internal/util"
     8  	"github.com/tetratelabs/wazero/api"
     9  )
    10  
    11  // CollationNeeded registers a callback to be invoked
    12  // whenever an unknown collation sequence is required.
    13  //
    14  // https://sqlite.org/c3ref/collation_needed.html
    15  func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
    16  	var enable uint64
    17  	if cb != nil {
    18  		enable = 1
    19  	}
    20  	r := c.call("sqlite3_collation_needed_go", uint64(c.handle), enable)
    21  	if err := c.error(r); err != nil {
    22  		return err
    23  	}
    24  	c.collation = cb
    25  	return nil
    26  }
    27  
    28  // AnyCollationNeeded uses [Conn.CollationNeeded] to register
    29  // a fake collating function for any unknown collating sequence.
    30  // The fake collating function works like BINARY.
    31  //
    32  // This can be used to load schemas that contain
    33  // one or more unknown collating sequences.
    34  func (c *Conn) AnyCollationNeeded() {
    35  	c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
    36  }
    37  
    38  // CreateCollation defines a new collating sequence.
    39  //
    40  // https://sqlite.org/c3ref/create_collation.html
    41  func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
    42  	defer c.arena.mark()()
    43  	namePtr := c.arena.string(name)
    44  	funcPtr := util.AddHandle(c.ctx, fn)
    45  	r := c.call("sqlite3_create_collation_go",
    46  		uint64(c.handle), uint64(namePtr), uint64(funcPtr))
    47  	return c.error(r)
    48  }
    49  
    50  // CreateFunction defines a new scalar SQL function.
    51  //
    52  // https://sqlite.org/c3ref/create_function.html
    53  func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
    54  	defer c.arena.mark()()
    55  	namePtr := c.arena.string(name)
    56  	funcPtr := util.AddHandle(c.ctx, fn)
    57  	r := c.call("sqlite3_create_function_go",
    58  		uint64(c.handle), uint64(namePtr), uint64(nArg),
    59  		uint64(flag), uint64(funcPtr))
    60  	return c.error(r)
    61  }
    62  
// ScalarFunction is the type of a scalar SQL function.
// It receives one [Value] per argument of the SQL call.
// Implementations must not retain arg:
// it is backed by a pooled array that is reused once the call returns.
type ScalarFunction func(ctx Context, arg ...Value)
    66  
    67  // CreateWindowFunction defines a new aggregate or aggregate window SQL function.
    68  // If fn returns a [WindowFunction], then an aggregate window function is created.
    69  // If fn returns an [io.Closer], it will be called to free resources.
    70  //
    71  // https://sqlite.org/c3ref/create_function.html
    72  func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
    73  	defer c.arena.mark()()
    74  	call := "sqlite3_create_aggregate_function_go"
    75  	namePtr := c.arena.string(name)
    76  	funcPtr := util.AddHandle(c.ctx, fn)
    77  	if _, ok := fn().(WindowFunction); ok {
    78  		call = "sqlite3_create_window_function_go"
    79  	}
    80  	r := c.call(call,
    81  		uint64(c.handle), uint64(namePtr), uint64(nArg),
    82  		uint64(flag), uint64(funcPtr))
    83  	return c.error(r)
    84  }
    85  
// AggregateFunction is the interface an aggregate function should implement.
// Instances are produced by the factory passed to [Conn.CreateWindowFunction].
//
// https://sqlite.org/appfunc.html
type AggregateFunction interface {
	// Step is invoked to add a row to the current window.
	// The function arguments, if any, corresponding to the row being added, are passed to Step.
	// Implementations must not retain arg:
	// it is backed by a pooled array that is reused once the call returns.
	Step(ctx Context, arg ...Value)

	// Value is invoked to return the current (or final) value of the aggregate.
	// After the final invocation, the handle that refers to the aggregate is deleted.
	Value(ctx Context)
}
    98  
// WindowFunction is the interface an aggregate window function should implement.
// It extends [AggregateFunction] with the ability to remove rows from the window.
//
// https://sqlite.org/windowfunctions.html
type WindowFunction interface {
	AggregateFunction

	// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
	// The function arguments, if any, are those passed to Step for the row being removed.
	// Implementations must not retain arg:
	// it is backed by a pooled array that is reused once the call returns.
	Inverse(ctx Context, arg ...Value)
}
   110  
   111  // OverloadFunction overloads a function for a virtual table.
   112  //
   113  // https://sqlite.org/c3ref/overload_function.html
   114  func (c *Conn) OverloadFunction(name string, nArg int) error {
   115  	defer c.arena.mark()()
   116  	namePtr := c.arena.string(name)
   117  	r := c.call("sqlite3_overload_function",
   118  		uint64(c.handle), uint64(namePtr), uint64(nArg))
   119  	return c.error(r)
   120  }
   121  
   122  func destroyCallback(ctx context.Context, mod api.Module, pApp uint32) {
   123  	util.DelHandle(ctx, pApp)
   124  }
   125  
   126  func collationCallback(ctx context.Context, mod api.Module, pArg, pDB, eTextRep, zName uint32) {
   127  	if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.collation != nil {
   128  		name := util.ReadString(mod, zName, _MAX_NAME)
   129  		c.collation(c, name)
   130  	}
   131  }
   132  
   133  func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
   134  	fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
   135  	return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
   136  }
   137  
   138  func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp, nArg, pArg uint32) {
   139  	args := getFuncArgs()
   140  	defer putFuncArgs(args)
   141  	db := ctx.Value(connKey{}).(*Conn)
   142  	fn := util.GetHandle(db.ctx, pApp).(ScalarFunction)
   143  	callbackArgs(db, args[:nArg], pArg)
   144  	fn(Context{db, pCtx}, args[:nArg]...)
   145  }
   146  
   147  func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp, nArg, pArg uint32) {
   148  	args := getFuncArgs()
   149  	defer putFuncArgs(args)
   150  	db := ctx.Value(connKey{}).(*Conn)
   151  	callbackArgs(db, args[:nArg], pArg)
   152  	fn, _ := callbackAggregate(db, pAgg, pApp)
   153  	fn.Step(Context{db, pCtx}, args[:nArg]...)
   154  }
   155  
   156  func finalCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp uint32) {
   157  	db := ctx.Value(connKey{}).(*Conn)
   158  	fn, handle := callbackAggregate(db, pAgg, pApp)
   159  	fn.Value(Context{db, pCtx})
   160  	util.DelHandle(ctx, handle)
   161  }
   162  
   163  func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg uint32) {
   164  	db := ctx.Value(connKey{}).(*Conn)
   165  	fn := util.GetHandle(db.ctx, pAgg).(AggregateFunction)
   166  	fn.Value(Context{db, pCtx})
   167  }
   168  
   169  func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg, nArg, pArg uint32) {
   170  	args := getFuncArgs()
   171  	defer putFuncArgs(args)
   172  	db := ctx.Value(connKey{}).(*Conn)
   173  	callbackArgs(db, args[:nArg], pArg)
   174  	fn := util.GetHandle(db.ctx, pAgg).(WindowFunction)
   175  	fn.Inverse(Context{db, pCtx}, args[:nArg]...)
   176  }
   177  
   178  func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32) {
   179  	if pApp == 0 {
   180  		handle := util.ReadUint32(db.mod, pAgg)
   181  		return util.GetHandle(db.ctx, handle).(AggregateFunction), handle
   182  	}
   183  
   184  	// We need to create the aggregate.
   185  	fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
   186  	handle := util.AddHandle(db.ctx, fn)
   187  	if pAgg != 0 {
   188  		util.WriteUint32(db.mod, pAgg, handle)
   189  	}
   190  	return fn, handle
   191  }
   192  
   193  func callbackArgs(db *Conn, arg []Value, pArg uint32) {
   194  	for i := range arg {
   195  		arg[i] = Value{
   196  			c:      db,
   197  			handle: util.ReadUint32(db.mod, pArg+ptrlen*uint32(i)),
   198  		}
   199  	}
   200  }
   201  
// funcArgsPool recycles the fixed-size argument arrays used by the
// function callbacks; see getFuncArgs and putFuncArgs.
var funcArgsPool sync.Pool
   203  
   204  func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
   205  	funcArgsPool.Put(p)
   206  }
   207  
   208  func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
   209  	if p := funcArgsPool.Get(); p == nil {
   210  		return new([_MAX_FUNCTION_ARG]Value)
   211  	} else {
   212  		return p.(*[_MAX_FUNCTION_ARG]Value)
   213  	}
   214  }