github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/ext/statement/stmt.go (about) 1 // Package statement defines table-valued functions using SQL. 2 // 3 // It can be used to create "parametrized views": 4 // pre-packaged queries that can be parametrized at query execution time. 5 // 6 // https://github.com/0x09/sqlite-statement-vtab 7 package statement 8 9 import ( 10 "encoding/json" 11 "fmt" 12 "strconv" 13 "strings" 14 "unsafe" 15 16 "github.com/ncruces/go-sqlite3" 17 ) 18 19 // Register registers the statement virtual table. 20 func Register(db *sqlite3.Conn) { 21 sqlite3.CreateModule(db, "statement", declare, declare) 22 } 23 24 type table struct { 25 stmt *sqlite3.Stmt 26 sql string 27 inuse bool 28 } 29 30 func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (*table, error) { 31 if len(arg) != 1 { 32 return nil, fmt.Errorf("statement: wrong number of arguments") 33 } 34 35 sql := "SELECT * FROM\n" + arg[0] 36 37 stmt, _, err := db.Prepare(sql) 38 if err != nil { 39 return nil, err 40 } 41 42 var sep string 43 var str strings.Builder 44 str.WriteString("CREATE TABLE x(") 45 outputs := stmt.ColumnCount() 46 for i := 0; i < outputs; i++ { 47 name := sqlite3.QuoteIdentifier(stmt.ColumnName(i)) 48 str.WriteString(sep) 49 str.WriteString(name) 50 str.WriteString(" ") 51 str.WriteString(stmt.ColumnDeclType(i)) 52 sep = "," 53 } 54 inputs := stmt.BindCount() 55 for i := 1; i <= inputs; i++ { 56 str.WriteString(sep) 57 name := stmt.BindName(i) 58 if name == "" { 59 str.WriteString("[") 60 str.WriteString(strconv.Itoa(i)) 61 str.WriteString("] HIDDEN") 62 } else { 63 str.WriteString(sqlite3.QuoteIdentifier(name[1:])) 64 str.WriteString(" HIDDEN") 65 } 66 sep = "," 67 } 68 str.WriteByte(')') 69 70 err = db.DeclareVTab(str.String()) 71 if err != nil { 72 stmt.Close() 73 return nil, err 74 } 75 76 return &table{sql: sql, stmt: stmt}, nil 77 } 78 79 func (t *table) Close() error { 80 return t.stmt.Close() 81 } 82 83 func (t *table) BestIndex(idx *sqlite3.IndexInfo) error { 84 idx.EstimatedCost = 1000 85 86 var argvIndex = 1 87 var needIndex bool 88 var listIndex []int 89 outputs := t.stmt.ColumnCount() 90 for i, cst := range idx.Constraint { 91 // Skip if this is a constraint on one of our output columns. 92 if cst.Column < outputs { 93 continue 94 } 95 96 // A given query plan is only usable if all provided input columns 97 // are usable and have equal constraints only. 98 if !cst.Usable || cst.Op != sqlite3.INDEX_CONSTRAINT_EQ { 99 return sqlite3.CONSTRAINT 100 } 101 102 // The non-zero argvIdx values must be contiguous. 103 // If they're not, build a list and serialize it through IdxStr. 104 nextIndex := cst.Column - outputs + 1 105 idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ 106 ArgvIndex: argvIndex, 107 Omit: true, 108 } 109 if nextIndex != argvIndex { 110 needIndex = true 111 } 112 listIndex = append(listIndex, nextIndex) 113 argvIndex++ 114 } 115 116 if needIndex { 117 buf, err := json.Marshal(listIndex) 118 if err != nil { 119 return err 120 } 121 idx.IdxStr = unsafe.String(&buf[0], len(buf)) 122 } 123 return nil 124 } 125 126 func (t *table) Open() (sqlite3.VTabCursor, error) { 127 stmt := t.stmt 128 if !t.inuse { 129 t.inuse = true 130 } else { 131 var err error 132 stmt, _, err = t.stmt.Conn().Prepare(t.sql) 133 if err != nil { 134 return nil, err 135 } 136 } 137 return &cursor{table: t, stmt: stmt}, nil 138 } 139 140 func (t *table) Rename(new string) error { 141 return nil 142 } 143 144 type cursor struct { 145 table *table 146 stmt *sqlite3.Stmt 147 arg []sqlite3.Value 148 rowID int64 149 } 150 151 func (c *cursor) Close() error { 152 if c.stmt == c.table.stmt { 153 c.table.inuse = false 154 c.stmt.ClearBindings() 155 return c.stmt.Reset() 156 } 157 return c.stmt.Close() 158 } 159 160 func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { 161 c.arg = arg 162 c.rowID = 0 163 c.stmt.ClearBindings() 164 if err := c.stmt.Reset(); err != nil { 165 return err 166 } 167 168 var list []int 169 if idxStr != "" { 170 buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr)) 171 err := json.Unmarshal(buf, &list) 172 if err != nil { 173 return err 174 } 175 } 176 177 for i, arg := range arg { 178 param := i + 1 179 if list != nil { 180 param = list[i] 181 } 182 err := c.stmt.BindValue(param, arg) 183 if err != nil { 184 return err 185 } 186 } 187 return c.Next() 188 } 189 190 func (c *cursor) Next() error { 191 if c.stmt.Step() { 192 c.rowID++ 193 } 194 return c.stmt.Err() 195 } 196 197 func (c *cursor) EOF() bool { 198 return !c.stmt.Busy() 199 } 200 201 func (c *cursor) RowID() (int64, error) { 202 return c.rowID, nil 203 } 204 205 func (c *cursor) Column(ctx *sqlite3.Context, col int) error { 206 switch outputs := c.stmt.ColumnCount(); { 207 case col < outputs: 208 ctx.ResultValue(c.stmt.ColumnValue(col)) 209 case col-outputs < len(c.arg): 210 ctx.ResultValue(c.arg[col-outputs]) 211 } 212 return nil 213 }