github.com/neugram/ng@v0.0.0-20180309130942-d472ff93d872/frame/sqlframe/sqlframe.go (about)

     1  // Copyright 2015 The Neugram Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sqlframe
     6  
     7  import (
     8  	"bytes"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"strings"
    13  
    14  	"neugram.io/ng/frame"
    15  )
    16  
    17  /*
    18  TODO composition of filters
    19  
    20  Slice, Filter, and Accumulate interact oddly.
    21  
    22  Given a filter, f.Filter("term1 < 1808"), we may get the query
    23  	select name, term1 from presidents where term1 < 1808;
    24  which gives us
    25  	{1, "George Washington", 1789, 1792},
    26  	{2, "John Adams", 1797, 0},
    27  	{3, "Thomas Jefferson", 1800, 1804},
    28  this could be sliced: f.Filter("term1 < 1808").Slice(0, 2, 0, 2) into
    29  	{1, "George Washington", 1789, 1792},
    30  	{2, "John Adams", 1797, 0},
    31  by adding to the query:
    32  	select name, term1 from presidents where term1 < 1808 limit 2;
    33  so far so good.
    34  
    35  However, if we first applied an offset slice, then the filter cannot
    36  simply be added. That is,
    37  	f.Slice(0, 2, 2, 5).Filter("term1 < 1808")
    38  needs to produce
    39  	{3, "Thomas Jefferson", 1800, 1804},
    40  which is the query:
    41  	select name, term1 from (
    42  		select name, term1 from presidents offset 2 limit 5;
    43  	) where term1 < 1808;
    44  .
    45  
    46  So we need to introduce a new kind of subFrame that can correctly
    47  compose these restrictions. Or at the very least realize when they
    48  don't compose, and punt to the default impl.
    49  */
    50  
    51  // TODO: Set always returns an error on an accumulation
    52  
    53  func Load(db *sql.DB, table string) (frame.Frame, error) {
    54  	// TODO: if sqlite. find out by lookiing at db.Driver()?
    55  	return sqliteLoad(db, table)
    56  }
    57  
    58  func NewFromFrame(db *sql.DB, table string, src frame.Frame) (frame.Frame, error) {
    59  	f := &sqlFrame{
    60  		db:        db,
    61  		table:     table,
    62  		sliceCols: append([]string{}, src.Cols()...),
    63  	}
    64  	if _, err := db.Exec(f.createStmt()); err != nil {
    65  		return nil, err
    66  	}
    67  	return f, nil
    68  }
    69  
    70  type sqlFrame struct {
    71  	db         *sql.DB
    72  	table      string
    73  	sliceCols  []string // table columns that are part of the frame
    74  	primaryKey []string // primary key columns
    75  
    76  	// TODO colExpr    []parser.Expr
    77  	// TODO where      []parser.Expr
    78  	// TODO groupBy    []string
    79  	offset int
    80  	limit  int // -1 for no limit
    81  
    82  	// TODO colType
    83  
    84  	insert   *sql.Stmt
    85  	count    *sql.Stmt
    86  	rowForPK *sql.Stmt
    87  
    88  	cache struct {
    89  		rowPKs [][]interface{} // rowPKs[i], primary key for row i
    90  		curGet *sql.Rows       // current forward cursor, call Next for row len(rowPKs)
    91  	}
    92  }
    93  
    94  func (f *sqlFrame) Get(x, y int, dst ...interface{}) (err error) {
    95  	// Frame argument types don't quite line up with sql database types,
    96  	// so we do a per-driver transformation. In particular, a *big.Int
    97  	// and *big.Float are perfectly valid dst arguments, which
    98  	// database/sql does not understand.
    99  	frameDst := dst
   100  	sqlDst, err := sqliteScanBegin(frameDst)
   101  	if err != nil {
   102  		return err
   103  	}
   104  	dst = sqlDst
   105  	defer func() {
   106  		if err == nil {
   107  			sqliteScanEnd(frameDst, sqlDst)
   108  		}
   109  	}()
   110  
   111  	// Pad dst to handle slicing.
   112  	var empty interface{}
   113  	if x > 0 {
   114  		dst = append(make([]interface{}, x), dst...)
   115  		for i := 0; i < x; i++ {
   116  			dst[i] = &empty
   117  		}
   118  	}
   119  	if w := len(dst); w < len(f.sliceCols) {
   120  		dst = append(dst, make([]interface{}, len(f.sliceCols)-len(dst))...)
   121  		for i := w; i < len(dst); i++ {
   122  			dst[i] = &empty
   123  		}
   124  	}
   125  
   126  	if y < len(f.cache.rowPKs) {
   127  		// Previously visited row.
   128  		// Extract it from the DB using the primary key.
   129  		if f.rowForPK == nil {
   130  			buf := new(bytes.Buffer)
   131  			fmt.Fprint(buf, "SELECT ")
   132  			fmt.Fprint(buf, strings.Join(f.sliceCols, ", "))
   133  			fmt.Fprintf(buf, " FROM %s WHERE ", f.table)
   134  			for i, key := range f.primaryKey {
   135  				if i > 0 {
   136  					fmt.Fprintf(buf, " AND ")
   137  				}
   138  				fmt.Fprintf(buf, "%s=?", key)
   139  			}
   140  			fmt.Fprintf(buf, ";")
   141  			f.rowForPK, err = f.db.Prepare(buf.String())
   142  			if err != nil {
   143  				return fmt.Errorf("sqlframe: %v", err)
   144  			}
   145  		}
   146  		row := f.rowForPK.QueryRow(f.cache.rowPKs[y]...)
   147  		return row.Scan(dst...)
   148  	}
   149  	if f.cache.curGet == nil {
   150  		f.cache.rowPKs = nil
   151  		f.cache.curGet, err = f.db.Query(f.queryForGet())
   152  		if err != nil {
   153  			return fmt.Errorf("sqlframe: %v", err)
   154  		}
   155  	}
   156  	for y >= len(f.cache.rowPKs) {
   157  		if !f.cache.curGet.Next() {
   158  			f.cache.curGet = nil
   159  			return io.EOF
   160  		}
   161  		pk := make([]interface{}, len(f.primaryKey))
   162  		pkp := make([]interface{}, len(f.primaryKey))
   163  		for i := range pk {
   164  			pkp[i] = &pk[i]
   165  		}
   166  		err = f.cache.curGet.Scan(append(dst, pkp...)...)
   167  		if err != nil {
   168  			f.cache.curGet = nil
   169  			return fmt.Errorf("sqlframe: %v", err)
   170  		}
   171  		f.cache.rowPKs = append(f.cache.rowPKs, pk)
   172  	}
   173  	return nil
   174  }
   175  
   176  func (f *sqlFrame) Len() (int, error) {
   177  	if f.count == nil {
   178  		var err error
   179  		f.count, err = f.db.Prepare("SELECT COUNT(*) FROM " + f.table + ";")
   180  		if err != nil {
   181  			return 0, fmt.Errorf("sqlframe: %v", err)
   182  		}
   183  	}
   184  	row := f.count.QueryRow()
   185  	count := 0
   186  	if err := row.Scan(&count); err != nil {
   187  		return 0, fmt.Errorf("sqlframe: count %v", err)
   188  	}
   189  	count -= f.offset
   190  	if f.limit >= 0 && count > f.limit {
   191  		count = f.limit
   192  	}
   193  	return count, nil
   194  }
   195  
   196  func (f *sqlFrame) CopyFrom(src frame.Frame) (n int, err error) {
   197  	if f.insert == nil {
   198  		buf := new(bytes.Buffer)
   199  		fmt.Fprintf(buf, "INSERT INTO %s (", f.table)
   200  		fmt.Fprintf(buf, strings.Join(f.sliceCols, ", "))
   201  		fmt.Fprintf(buf, ") VALUES (")
   202  		for i := range f.sliceCols {
   203  			if i > 0 {
   204  				fmt.Fprintf(buf, ", ")
   205  			}
   206  			fmt.Fprintf(buf, "?")
   207  		}
   208  		fmt.Fprintf(buf, ");")
   209  		var err error
   210  		f.insert, err = f.db.Prepare(buf.String())
   211  		if err != nil {
   212  			return 0, fmt.Errorf("sqlframe: %v", err)
   213  		}
   214  	}
   215  
   216  	// TODO: fast path for src.(*sqlFrame): insert from select
   217  
   218  	row := make([]interface{}, len(f.sliceCols))
   219  	rowp := make([]interface{}, len(row))
   220  	for i := range row {
   221  		rowp[i] = &row[i]
   222  	}
   223  	y := 0
   224  	for {
   225  		err := src.Get(0, y, rowp...)
   226  		if err == io.EOF {
   227  			break // last row, all is good
   228  		}
   229  		if err != nil {
   230  			return y, err
   231  		}
   232  		if _, err := f.insert.Exec(row...); err != nil {
   233  			return y, fmt.Errorf("sqlframe: %v", err)
   234  		}
   235  		y++
   236  	}
   237  	return y, nil
   238  }
   239  
   240  func (f *sqlFrame) Cols() []string { return f.sliceCols }
   241  
   242  func (d *sqlFrame) Slice(x, xlen, y, ylen int) frame.Frame {
   243  	n := &sqlFrame{
   244  		db:         d.db,
   245  		table:      d.table,
   246  		sliceCols:  d.sliceCols[x : x+xlen],
   247  		primaryKey: d.primaryKey,
   248  		count:      d.count,
   249  		offset:     d.offset + y,
   250  		limit:      ylen,
   251  	}
   252  	if len(d.cache.rowPKs) > y {
   253  		n.cache.rowPKs = d.cache.rowPKs[y:]
   254  		if len(n.cache.rowPKs) > ylen {
   255  			n.cache.rowPKs = n.cache.rowPKs[:ylen]
   256  		}
   257  	}
   258  	return n
   259  }
   260  
   261  func (f *sqlFrame) Accumulate(g frame.Grouping) (frame.Frame, error) {
   262  	panic("TODO")
   263  }
   264  
   265  func (f *sqlFrame) validate() {
   266  	// TODO: check names match a strict format, mostly to avoid SQL injection
   267  }
   268  
   269  func (f *sqlFrame) createStmt() string {
   270  	f.validate()
   271  	buf := new(bytes.Buffer)
   272  	fmt.Fprintf(buf, "CREATE TABLE %s (\n", f.table)
   273  	for _, name := range f.sliceCols {
   274  		fmt.Fprintf(buf, "\t%s TODO_type,\n", name)
   275  	}
   276  	fmt.Fprintf(buf, ");")
   277  	return buf.String()
   278  }
   279  
   280  func (f *sqlFrame) queryForGet() string {
   281  	f.validate()
   282  	buf := new(bytes.Buffer)
   283  	fmt.Fprintf(buf, "SELECT ")
   284  	col := 0
   285  	for _, c := range f.sliceCols {
   286  		if col > 0 {
   287  			fmt.Fprintf(buf, ", ")
   288  		}
   289  		col++
   290  		fmt.Fprintf(buf, c)
   291  	}
   292  	for i, c := range f.primaryKey {
   293  		if col > 0 {
   294  			fmt.Fprintf(buf, ", ")
   295  		}
   296  		col++
   297  		fmt.Fprintf(buf, "%s as _pk%d", c, i)
   298  	}
   299  	fmt.Fprintf(buf, " FROM %s", f.table)
   300  	if f.limit >= 0 {
   301  		fmt.Fprintf(buf, " LIMIT %d", f.limit)
   302  	}
   303  	if f.offset > 0 {
   304  		fmt.Fprintf(buf, " OFFSET %d", f.offset)
   305  	}
   306  	fmt.Fprintf(buf, ";")
   307  	// TODO where
   308  	// TODO groupBy
   309  	// TODO offset
   310  	// TODO limit
   311  	// TODO colExpr
   312  	return buf.String()
   313  }