github.com/neugram/ng@v0.0.0-20180309130942-d472ff93d872/frame/sqlframe/sqlframe.go (about) 1 // Copyright 2015 The Neugram Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sqlframe 6 7 import ( 8 "bytes" 9 "database/sql" 10 "fmt" 11 "io" 12 "strings" 13 14 "neugram.io/ng/frame" 15 ) 16 17 /* 18 TODO composition of filters 19 20 Slice, Filter, and Accumulate interact oddly. 21 22 Given a filter, f.Filter("term1 < 1808"), we may get the query 23 select name, term1 from presidents where term1 < 1808; 24 which gives us 25 {1, "George Washington", 1789, 1792}, 26 {2, "John Adams", 1797, 0}, 27 {3, "Thomas Jefferson", 1800, 1804}, 28 this could be sliced: f.Filter("term1 < 1808").Slice(0, 2, 0, 2) into 29 {1, "George Washington", 1789, 1792}, 30 {2, "John Adams", 1797, 0}, 31 by adding to the query: 32 select name, term1 from presidents where term1 < 1808 limit 2; 33 so far so good. 34 35 However, if we first applied an offset slice, then the filter cannot 36 simply be added. That is, 37 f.Slice(0, 2, 2, 5).Filter("term1 < 1808") 38 needs to produce 39 {3, "Thomas Jefferson", 1800, 1804}, 40 which is the query: 41 select name, term1 from ( 42 select name, term1 from presidents offset 2 limit 5; 43 ) where term1 < 1808; 44 . 45 46 So we need to introduce a new kind of subFrame that can correctly 47 compose these restrictions. Or at the very least realize when they 48 don't compose, and punt to the default impl. 49 */ 50 51 // TODO: Set always returns an error on an accumulation 52 53 func Load(db *sql.DB, table string) (frame.Frame, error) { 54 // TODO: if sqlite. find out by lookiing at db.Driver()? 55 return sqliteLoad(db, table) 56 } 57 58 func NewFromFrame(db *sql.DB, table string, src frame.Frame) (frame.Frame, error) { 59 f := &sqlFrame{ 60 db: db, 61 table: table, 62 sliceCols: append([]string{}, src.Cols()...), 63 } 64 if _, err := db.Exec(f.createStmt()); err != nil { 65 return nil, err 66 } 67 return f, nil 68 } 69 70 type sqlFrame struct { 71 db *sql.DB 72 table string 73 sliceCols []string // table columns that are part of the frame 74 primaryKey []string // primary key columns 75 76 // TODO colExpr []parser.Expr 77 // TODO where []parser.Expr 78 // TODO groupBy []string 79 offset int 80 limit int // -1 for no limit 81 82 // TODO colType 83 84 insert *sql.Stmt 85 count *sql.Stmt 86 rowForPK *sql.Stmt 87 88 cache struct { 89 rowPKs [][]interface{} // rowPKs[i], primary key for row i 90 curGet *sql.Rows // current forward cursor, call Next for row len(rowPKs) 91 } 92 } 93 94 func (f *sqlFrame) Get(x, y int, dst ...interface{}) (err error) { 95 // Frame argument types don't quite line up with sql database types, 96 // so we do a per-driver transformation. In particular, a *big.Int 97 // and *big.Float are perfectly valid dst arguments, which 98 // database/sql does not understand. 99 frameDst := dst 100 sqlDst, err := sqliteScanBegin(frameDst) 101 if err != nil { 102 return err 103 } 104 dst = sqlDst 105 defer func() { 106 if err == nil { 107 sqliteScanEnd(frameDst, sqlDst) 108 } 109 }() 110 111 // Pad dst to handle slicing. 112 var empty interface{} 113 if x > 0 { 114 dst = append(make([]interface{}, x), dst...) 115 for i := 0; i < x; i++ { 116 dst[i] = &empty 117 } 118 } 119 if w := len(dst); w < len(f.sliceCols) { 120 dst = append(dst, make([]interface{}, len(f.sliceCols)-len(dst))...) 121 for i := w; i < len(dst); i++ { 122 dst[i] = &empty 123 } 124 } 125 126 if y < len(f.cache.rowPKs) { 127 // Previously visited row. 128 // Extract it from the DB using the primary key. 129 if f.rowForPK == nil { 130 buf := new(bytes.Buffer) 131 fmt.Fprint(buf, "SELECT ") 132 fmt.Fprint(buf, strings.Join(f.sliceCols, ", ")) 133 fmt.Fprintf(buf, " FROM %s WHERE ", f.table) 134 for i, key := range f.primaryKey { 135 if i > 0 { 136 fmt.Fprintf(buf, " AND ") 137 } 138 fmt.Fprintf(buf, "%s=?", key) 139 } 140 fmt.Fprintf(buf, ";") 141 f.rowForPK, err = f.db.Prepare(buf.String()) 142 if err != nil { 143 return fmt.Errorf("sqlframe: %v", err) 144 } 145 } 146 row := f.rowForPK.QueryRow(f.cache.rowPKs[y]...) 147 return row.Scan(dst...) 148 } 149 if f.cache.curGet == nil { 150 f.cache.rowPKs = nil 151 f.cache.curGet, err = f.db.Query(f.queryForGet()) 152 if err != nil { 153 return fmt.Errorf("sqlframe: %v", err) 154 } 155 } 156 for y >= len(f.cache.rowPKs) { 157 if !f.cache.curGet.Next() { 158 f.cache.curGet = nil 159 return io.EOF 160 } 161 pk := make([]interface{}, len(f.primaryKey)) 162 pkp := make([]interface{}, len(f.primaryKey)) 163 for i := range pk { 164 pkp[i] = &pk[i] 165 } 166 err = f.cache.curGet.Scan(append(dst, pkp...)...) 167 if err != nil { 168 f.cache.curGet = nil 169 return fmt.Errorf("sqlframe: %v", err) 170 } 171 f.cache.rowPKs = append(f.cache.rowPKs, pk) 172 } 173 return nil 174 } 175 176 func (f *sqlFrame) Len() (int, error) { 177 if f.count == nil { 178 var err error 179 f.count, err = f.db.Prepare("SELECT COUNT(*) FROM " + f.table + ";") 180 if err != nil { 181 return 0, fmt.Errorf("sqlframe: %v", err) 182 } 183 } 184 row := f.count.QueryRow() 185 count := 0 186 if err := row.Scan(&count); err != nil { 187 return 0, fmt.Errorf("sqlframe: count %v", err) 188 } 189 count -= f.offset 190 if f.limit >= 0 && count > f.limit { 191 count = f.limit 192 } 193 return count, nil 194 } 195 196 func (f *sqlFrame) CopyFrom(src frame.Frame) (n int, err error) { 197 if f.insert == nil { 198 buf := new(bytes.Buffer) 199 fmt.Fprintf(buf, "INSERT INTO %s (", f.table) 200 fmt.Fprintf(buf, strings.Join(f.sliceCols, ", ")) 201 fmt.Fprintf(buf, ") VALUES (") 202 for i := range f.sliceCols { 203 if i > 0 { 204 fmt.Fprintf(buf, ", ") 205 } 206 fmt.Fprintf(buf, "?") 207 } 208 fmt.Fprintf(buf, ");") 209 var err error 210 f.insert, err = f.db.Prepare(buf.String()) 211 if err != nil { 212 return 0, fmt.Errorf("sqlframe: %v", err) 213 } 214 } 215 216 // TODO: fast path for src.(*sqlFrame): insert from select 217 218 row := make([]interface{}, len(f.sliceCols)) 219 rowp := make([]interface{}, len(row)) 220 for i := range row { 221 rowp[i] = &row[i] 222 } 223 y := 0 224 for { 225 err := src.Get(0, y, rowp...) 226 if err == io.EOF { 227 break // last row, all is good 228 } 229 if err != nil { 230 return y, err 231 } 232 if _, err := f.insert.Exec(row...); err != nil { 233 return y, fmt.Errorf("sqlframe: %v", err) 234 } 235 y++ 236 } 237 return y, nil 238 } 239 240 func (f *sqlFrame) Cols() []string { return f.sliceCols } 241 242 func (d *sqlFrame) Slice(x, xlen, y, ylen int) frame.Frame { 243 n := &sqlFrame{ 244 db: d.db, 245 table: d.table, 246 sliceCols: d.sliceCols[x : x+xlen], 247 primaryKey: d.primaryKey, 248 count: d.count, 249 offset: d.offset + y, 250 limit: ylen, 251 } 252 if len(d.cache.rowPKs) > y { 253 n.cache.rowPKs = d.cache.rowPKs[y:] 254 if len(n.cache.rowPKs) > ylen { 255 n.cache.rowPKs = n.cache.rowPKs[:ylen] 256 } 257 } 258 return n 259 } 260 261 func (f *sqlFrame) Accumulate(g frame.Grouping) (frame.Frame, error) { 262 panic("TODO") 263 } 264 265 func (f *sqlFrame) validate() { 266 // TODO: check names match a strict format, mostly to avoid SQL injection 267 } 268 269 func (f *sqlFrame) createStmt() string { 270 f.validate() 271 buf := new(bytes.Buffer) 272 fmt.Fprintf(buf, "CREATE TABLE %s (\n", f.table) 273 for _, name := range f.sliceCols { 274 fmt.Fprintf(buf, "\t%s TODO_type,\n", name) 275 } 276 fmt.Fprintf(buf, ");") 277 return buf.String() 278 } 279 280 func (f *sqlFrame) queryForGet() string { 281 f.validate() 282 buf := new(bytes.Buffer) 283 fmt.Fprintf(buf, "SELECT ") 284 col := 0 285 for _, c := range f.sliceCols { 286 if col > 0 { 287 fmt.Fprintf(buf, ", ") 288 } 289 col++ 290 fmt.Fprintf(buf, c) 291 } 292 for i, c := range f.primaryKey { 293 if col > 0 { 294 fmt.Fprintf(buf, ", ") 295 } 296 col++ 297 fmt.Fprintf(buf, "%s as _pk%d", c, i) 298 } 299 fmt.Fprintf(buf, " FROM %s", f.table) 300 if f.limit >= 0 { 301 fmt.Fprintf(buf, " LIMIT %d", f.limit) 302 } 303 if f.offset > 0 { 304 fmt.Fprintf(buf, " OFFSET %d", f.offset) 305 } 306 fmt.Fprintf(buf, ";") 307 // TODO where 308 // TODO groupBy 309 // TODO offset 310 // TODO limit 311 // TODO colExpr 312 return buf.String() 313 }