github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/ext/statement/stmt.go (about)

     1  // Package statement defines table-valued functions using SQL.
     2  //
     3  // It can be used to create "parametrized views":
     4  // pre-packaged queries that can be parametrized at query execution time.
     5  //
     6  // https://github.com/0x09/sqlite-statement-vtab
     7  package statement
     8  
     9  import (
    10  	"encoding/json"
    11  	"fmt"
    12  	"strconv"
    13  	"strings"
    14  	"unsafe"
    15  
    16  	"github.com/ncruces/go-sqlite3"
    17  )
    18  
    19  // Register registers the statement virtual table.
    20  func Register(db *sqlite3.Conn) {
    21  	sqlite3.CreateModule(db, "statement", declare, declare)
    22  }
    23  
    24  type table struct {
    25  	stmt  *sqlite3.Stmt
    26  	sql   string
    27  	inuse bool
    28  }
    29  
    30  func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) {
    31  	if len(arg) != 1 {
    32  		return nil, fmt.Errorf("statement: wrong number of arguments")
    33  	}
    34  
    35  	sql := "SELECT * FROM\n" + arg[0]
    36  
    37  	stmt, _, err := db.Prepare(sql)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  
    42  	var sep string
    43  	var str strings.Builder
    44  	str.WriteString("CREATE TABLE x(")
    45  	outputs := stmt.ColumnCount()
    46  	for i := 0; i < outputs; i++ {
    47  		name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
    48  		str.WriteString(sep)
    49  		str.WriteString(name)
    50  		str.WriteString(" ")
    51  		str.WriteString(stmt.ColumnDeclType(i))
    52  		sep = ","
    53  	}
    54  	inputs := stmt.BindCount()
    55  	for i := 1; i <= inputs; i++ {
    56  		str.WriteString(sep)
    57  		name := stmt.BindName(i)
    58  		if name == "" {
    59  			str.WriteString("[")
    60  			str.WriteString(strconv.Itoa(i))
    61  			str.WriteString("] HIDDEN")
    62  		} else {
    63  			str.WriteString(sqlite3.QuoteIdentifier(name[1:]))
    64  			str.WriteString(" HIDDEN")
    65  		}
    66  		sep = ","
    67  	}
    68  	str.WriteByte(')')
    69  
    70  	err = db.DeclareVTab(str.String())
    71  	if err != nil {
    72  		stmt.Close()
    73  		return nil, err
    74  	}
    75  
    76  	return &table{sql: sql, stmt: stmt}, nil
    77  }
    78  
    79  func (t *table) Close() error {
    80  	return t.stmt.Close()
    81  }
    82  
    83  func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
    84  	idx.EstimatedCost = 1000
    85  
    86  	var argvIndex = 1
    87  	var needIndex bool
    88  	var listIndex []int
    89  	outputs := t.stmt.ColumnCount()
    90  	for i, cst := range idx.Constraint {
    91  		// Skip if this is a constraint on one of our output columns.
    92  		if cst.Column < outputs {
    93  			continue
    94  		}
    95  
    96  		// A given query plan is only usable if all provided input columns
    97  		// are usable and have equal constraints only.
    98  		if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ {
    99  			return sqlite3.CONSTRAINT
   100  		}
   101  
   102  		// The non-zero argvIdx values must be contiguous.
   103  		// If they're not, build a list and serialize it through IdxStr.
   104  		nextIndex := cst.Column - outputs + 1
   105  		idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
   106  			ArgvIndex: argvIndex,
   107  			Omit:      true,
   108  		}
   109  		if nextIndex != argvIndex {
   110  			needIndex = true
   111  		}
   112  		listIndex = append(listIndex, nextIndex)
   113  		argvIndex++
   114  	}
   115  
   116  	if needIndex {
   117  		buf, err := json.Marshal(listIndex)
   118  		if err != nil {
   119  			return err
   120  		}
   121  		idx.IdxStr = unsafe.String(&buf[0], len(buf))
   122  	}
   123  	return nil
   124  }
   125  
   126  func (t *table) Open() (sqlite3.VTabCursor, error) {
   127  	stmt := t.stmt
   128  	if !t.inuse {
   129  		t.inuse = true
   130  	} else {
   131  		var err error
   132  		stmt, _, err = t.stmt.Conn().Prepare(t.sql)
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  	}
   137  	return &cursor{table: t, stmt: stmt}, nil
   138  }
   139  
   140  func (t *table) Rename(new string) error {
   141  	return nil
   142  }
   143  
   144  type cursor struct {
   145  	table *table
   146  	stmt  *sqlite3.Stmt
   147  	arg   []sqlite3.Value
   148  	rowID int64
   149  }
   150  
   151  func (c *cursor) Close() error {
   152  	if c.stmt == c.table.stmt {
   153  		c.table.inuse = false
   154  		c.stmt.ClearBindings()
   155  		return c.stmt.Reset()
   156  	}
   157  	return c.stmt.Close()
   158  }
   159  
   160  func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
   161  	c.arg = arg
   162  	c.rowID = 0
   163  	c.stmt.ClearBindings()
   164  	if err := c.stmt.Reset(); err != nil {
   165  		return err
   166  	}
   167  
   168  	var list []int
   169  	if idxStr != "" {
   170  		buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr))
   171  		err := json.Unmarshal(buf, &list)
   172  		if err != nil {
   173  			return err
   174  		}
   175  	}
   176  
   177  	for i, arg := range arg {
   178  		param := i + 1
   179  		if list != nil {
   180  			param = list[i]
   181  		}
   182  		err := c.stmt.BindValue(param, arg)
   183  		if err != nil {
   184  			return err
   185  		}
   186  	}
   187  	return c.Next()
   188  }
   189  
   190  func (c *cursor) Next() error {
   191  	if c.stmt.Step() {
   192  		c.rowID++
   193  	}
   194  	return c.stmt.Err()
   195  }
   196  
   197  func (c *cursor) EOF() bool {
   198  	return !c.stmt.Busy()
   199  }
   200  
   201  func (c *cursor) RowID() (int64, error) {
   202  	return c.rowID, nil
   203  }
   204  
   205  func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
   206  	switch outputs := c.stmt.ColumnCount(); {
   207  	case col < outputs:
   208  		ctx.ResultValue(c.stmt.ColumnValue(col))
   209  	case col-outputs < len(c.arg):
   210  		ctx.ResultValue(c.arg[col-outputs])
   211  	}
   212  	return nil
   213  }