github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/ext/pivot/pivot.go (about)

     1  // Package pivot implements a pivot virtual table.
     2  //
     3  // https://github.com/jakethaw/pivot_vtab
     4  package pivot
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"strings"
    10  
    11  	"github.com/ncruces/go-sqlite3"
    12  )
    13  
    14  // Register registers the pivot virtual table.
    15  func Register(db *sqlite3.Conn) {
    16  	sqlite3.CreateModule(db, "pivot", declare, declare)
    17  }
    18  
    19  type table struct {
    20  	db   *sqlite3.Conn
    21  	scan string
    22  	cell string
    23  	keys []string
    24  	cols []*sqlite3.Value
    25  }
    26  
    27  func declare(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
    28  	if len(arg) != 3 {
    29  		return nil, fmt.Errorf("pivot: wrong number of arguments")
    30  	}
    31  
    32  	table := &table{db: db}
    33  	defer func() {
    34  		if err != nil {
    35  			table.Close()
    36  		}
    37  	}()
    38  
    39  	var sep string
    40  	var create strings.Builder
    41  	create.WriteString("CREATE TABLE x(")
    42  
    43  	// Row key query.
    44  	table.scan = "SELECT * FROM\n" + arg[0]
    45  	stmt, _, err := db.Prepare(table.scan)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	defer stmt.Close()
    50  
    51  	table.keys = make([]string, stmt.ColumnCount())
    52  	for i := range table.keys {
    53  		name := sqlite3.QuoteIdentifier(stmt.ColumnName(i))
    54  		table.keys[i] = name
    55  		create.WriteString(sep)
    56  		create.WriteString(name)
    57  		sep = ","
    58  	}
    59  	stmt.Close()
    60  
    61  	// Column definition query.
    62  	stmt, _, err = db.Prepare("SELECT * FROM\n" + arg[1])
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	if stmt.ColumnCount() != 2 {
    68  		return nil, fmt.Errorf("pivot: column definition query expects 2 result columns")
    69  	}
    70  	for stmt.Step() {
    71  		name := sqlite3.QuoteIdentifier(stmt.ColumnText(1))
    72  		table.cols = append(table.cols, stmt.ColumnValue(0).Dup())
    73  		create.WriteString(",")
    74  		create.WriteString(name)
    75  	}
    76  	stmt.Close()
    77  
    78  	// Pivot cell query.
    79  	table.cell = "SELECT * FROM\n" + arg[2]
    80  	stmt, _, err = db.Prepare(table.cell)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	if stmt.ColumnCount() != 1 {
    86  		return nil, fmt.Errorf("pivot: cell query expects 1 result columns")
    87  	}
    88  	if stmt.BindCount() != len(table.keys)+1 {
    89  		return nil, fmt.Errorf("pivot: cell query expects %d bound parameters", len(table.keys)+1)
    90  	}
    91  
    92  	create.WriteByte(')')
    93  	err = db.DeclareVTab(create.String())
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	return table, nil
    98  }
    99  
   100  func (t *table) Close() error {
   101  	for i := range t.cols {
   102  		t.cols[i].Close()
   103  	}
   104  	return nil
   105  }
   106  
   107  func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
   108  	var idxStr strings.Builder
   109  	idxStr.WriteString(t.scan)
   110  
   111  	argvIndex := 1
   112  	sep := " WHERE "
   113  	for i, cst := range idx.Constraint {
   114  		if !cst.Usable || !(0 <= cst.Column && cst.Column < len(t.keys)) {
   115  			continue
   116  		}
   117  		op := operator(cst.Op)
   118  		if op == "" {
   119  			continue
   120  		}
   121  
   122  		idxStr.WriteString(sep)
   123  		idxStr.WriteString(t.keys[cst.Column])
   124  		idxStr.WriteString(" ")
   125  		idxStr.WriteString(op)
   126  		idxStr.WriteString(" ?")
   127  
   128  		idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
   129  			ArgvIndex: argvIndex,
   130  			Omit:      true,
   131  		}
   132  		sep = " AND "
   133  		argvIndex++
   134  	}
   135  
   136  	sep = " ORDER BY "
   137  	idx.OrderByConsumed = true
   138  	for _, ord := range idx.OrderBy {
   139  		if !(0 <= ord.Column && ord.Column < len(t.keys)) {
   140  			idx.OrderByConsumed = false
   141  			continue
   142  		}
   143  		idxStr.WriteString(sep)
   144  		idxStr.WriteString(t.keys[ord.Column])
   145  		idxStr.WriteString(" COLLATE ")
   146  		idxStr.WriteString(idx.Collation(ord.Column))
   147  		if ord.Desc {
   148  			idxStr.WriteString(" DESC")
   149  		}
   150  		sep = ","
   151  	}
   152  
   153  	idx.EstimatedCost = 1e9 / float64(argvIndex)
   154  	idx.IdxStr = idxStr.String()
   155  	return nil
   156  }
   157  
   158  func (t *table) Open() (sqlite3.VTabCursor, error) {
   159  	return &cursor{table: t}, nil
   160  }
   161  
   162  func (t *table) Rename(new string) error {
   163  	return nil
   164  }
   165  
   166  type cursor struct {
   167  	table *table
   168  	scan  *sqlite3.Stmt
   169  	cell  *sqlite3.Stmt
   170  	rowID int64
   171  }
   172  
   173  func (c *cursor) Close() error {
   174  	return errors.Join(c.scan.Close(), c.cell.Close())
   175  }
   176  
   177  func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
   178  	err := c.scan.Close()
   179  	if err != nil {
   180  		return err
   181  	}
   182  
   183  	c.scan, _, err = c.table.db.Prepare(idxStr)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	for i, arg := range arg {
   188  		err := c.scan.BindValue(i+1, arg)
   189  		if err != nil {
   190  			return err
   191  		}
   192  	}
   193  
   194  	if c.cell == nil {
   195  		c.cell, _, err = c.table.db.Prepare(c.table.cell)
   196  		if err != nil {
   197  			return err
   198  		}
   199  	}
   200  
   201  	c.rowID = 0
   202  	return c.Next()
   203  }
   204  
   205  func (c *cursor) Next() error {
   206  	if c.scan.Step() {
   207  		count := c.scan.ColumnCount()
   208  		for i := 0; i < count; i++ {
   209  			err := c.cell.BindValue(i+1, c.scan.ColumnValue(i))
   210  			if err != nil {
   211  				return err
   212  			}
   213  		}
   214  		c.rowID++
   215  	}
   216  	return c.scan.Err()
   217  }
   218  
   219  func (c *cursor) EOF() bool {
   220  	return !c.scan.Busy()
   221  }
   222  
   223  func (c *cursor) RowID() (int64, error) {
   224  	return c.rowID, nil
   225  }
   226  
   227  func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
   228  	count := c.scan.ColumnCount()
   229  	if col < count {
   230  		ctx.ResultValue(c.scan.ColumnValue(col))
   231  		return nil
   232  	}
   233  
   234  	err := c.cell.BindValue(count+1, *c.table.cols[col-count])
   235  	if err != nil {
   236  		return err
   237  	}
   238  
   239  	if c.cell.Step() {
   240  		ctx.ResultValue(c.cell.ColumnValue(0))
   241  	}
   242  	return c.cell.Reset()
   243  }
   244  
   245  func operator(op sqlite3.IndexConstraintOp) string {
   246  	switch op {
   247  	case sqlite3.INDEX_CONSTRAINT_EQ:
   248  		return "="
   249  	case sqlite3.INDEX_CONSTRAINT_LT:
   250  		return "<"
   251  	case sqlite3.INDEX_CONSTRAINT_GT:
   252  		return ">"
   253  	case sqlite3.INDEX_CONSTRAINT_LE:
   254  		return "<="
   255  	case sqlite3.INDEX_CONSTRAINT_GE:
   256  		return ">="
   257  	case sqlite3.INDEX_CONSTRAINT_NE:
   258  		return "<>"
   259  	case sqlite3.INDEX_CONSTRAINT_MATCH:
   260  		return "MATCH"
   261  	case sqlite3.INDEX_CONSTRAINT_LIKE:
   262  		return "LIKE"
   263  	case sqlite3.INDEX_CONSTRAINT_GLOB:
   264  		return "GLOB"
   265  	case sqlite3.INDEX_CONSTRAINT_REGEXP:
   266  		return "REGEXP"
   267  	case sqlite3.INDEX_CONSTRAINT_IS, sqlite3.INDEX_CONSTRAINT_ISNULL:
   268  		return "IS"
   269  	case sqlite3.INDEX_CONSTRAINT_ISNOT, sqlite3.INDEX_CONSTRAINT_ISNOTNULL:
   270  		return "IS NOT"
   271  	default:
   272  		return ""
   273  	}
   274  }