github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/queryglue.go (about) 1 package sqlitepool 2 3 // This file contains bridging functions designed to let users of 4 // database/sql move to sqlitepool without changing the semantics 5 // of their code. 6 // 7 // Eventually users should piece-wise migrate to another interface. 8 // (Or we should invest in this interface? Seems suboptimal.) 9 10 import ( 11 sqlpkg "database/sql" 12 "database/sql/driver" 13 "encoding" 14 "fmt" 15 "reflect" 16 "strings" 17 "time" 18 19 "github.com/tailscale/sqlite/sqliteh" 20 ) 21 22 // Exec is like database/sql.Tx.Exec. 23 // Only use this for one-off/rare queries. 24 // For normal queries, see the Exec method on Tx. 25 func Exec(db sqliteh.DB, sql string, args ...any) error { 26 stmt, _, err := db.Prepare(sql, 0) 27 if err != nil { 28 return err 29 } 30 if err := bindAll(db, stmt, args...); err != nil { 31 return fmt.Errorf("Exec: %w", err) 32 } 33 _, _, _, _, err = stmt.StepResult() 34 if err != nil { 35 err = fmt.Errorf("%w: %v", err, db.ErrMsg()) 36 } 37 stmt.Finalize() 38 return err 39 } 40 41 // QueryRow is like database/sql.Tx.QueryRow. 42 // Only use this for one-off/rare queries. 43 // For normal queries, see the methods on Rx. 44 func QueryRow(db sqliteh.DB, sql string, args ...any) *Row { 45 stmt, _, err := db.Prepare(sql, 0) 46 if err != nil { 47 return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, db.ErrMsg())} 48 } 49 if err := bindAll(db, stmt, args...); err != nil { 50 return &Row{err: fmt.Errorf("QueryRow: %w", err)} 51 } 52 row, err := stmt.Step(nil) 53 if err != nil { 54 msg := db.ErrMsg() 55 stmt.Finalize() 56 return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)} 57 } 58 if !row { 59 stmt.Finalize() 60 return &Row{err: sqlpkg.ErrNoRows} 61 } 62 return &Row{stmt: stmt, oneOff: true} 63 } 64 65 // Query is like database/sql.Tx.Query. 66 // Only use this for one-off/rare queries. 67 // For normal queries, see the methods on Rx. 68 func Query(db sqliteh.DB, sql string, args ...any) (*Rows, error) { 69 stmt, _, err := db.Prepare(sql, 0) 70 if err != nil { 71 return nil, fmt.Errorf("Query: %w: %v", err, db.ErrMsg()) 72 } 73 if err := bindAll(db, stmt, args...); err != nil { 74 return nil, err 75 } 76 return &Rows{stmt: stmt, oneOff: true}, nil 77 } 78 79 // Exec is like database/sql.Tx.Exec. 80 func (tx *Tx) Exec(sql string, args ...any) error { 81 stmt := tx.Prepare(sql) 82 if err := bindAll(tx.conn.db, stmt, args...); err != nil { 83 return err 84 } 85 _, _, _, _, err := stmt.StepResult() 86 if err != nil { 87 return fmt.Errorf("%w: %v", err, tx.conn.db.ErrMsg()) 88 } 89 return nil 90 } 91 92 func (tx *Tx) ExecRes(sql string, args ...any) (rowsAffected int64, err error) { 93 stmt := tx.Prepare(sql) 94 if err := bindAll(tx.conn.db, stmt, args...); err != nil { 95 return 0, err 96 } 97 _, _, rowsAffected, _, err = stmt.StepResult() 98 return rowsAffected, err 99 } 100 101 // QueryRow is like database/sql.Tx.QueryRow. 102 func (rx *Rx) QueryRow(sql string, args ...any) *Row { 103 stmt := rx.Prepare(sql) 104 if err := bindAll(rx.conn.db, stmt, args...); err != nil { 105 return &Row{err: fmt.Errorf("QueryRow: %w", err)} 106 } 107 row, err := stmt.Step(nil) 108 if err != nil { 109 msg := rx.DB().ErrMsg() 110 stmt.ResetAndClear() 111 return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)} 112 } 113 if !row { 114 stmt.ResetAndClear() 115 return &Row{err: sqlpkg.ErrNoRows} 116 } 117 return &Row{stmt: stmt} 118 } 119 120 // Query is like database/sql.Tx.Query. 121 func (rx *Rx) Query(sql string, args ...any) (*Rows, error) { 122 stmt := rx.Prepare(sql) 123 if err := bindAll(rx.conn.db, stmt, args...); err != nil { 124 return nil, fmt.Errorf("Query: %w", err) 125 } 126 return &Rows{stmt: stmt}, nil 127 } 128 129 // Rows is like database/sql.Tx.Rows. 130 type Rows struct { 131 stmt sqliteh.Stmt 132 err error 133 oneOff bool 134 } 135 136 func (rs *Rows) Next() bool { 137 if rs.err != nil { 138 return false 139 } 140 row, err := rs.stmt.Step(nil) 141 if err != nil { 142 rs.err = fmt.Errorf("QueryRow.Next: %w: %v", err, rs.stmt.DBHandle().ErrMsg()) 143 return false 144 } 145 if !row { 146 rs.stmt.ResetAndClear() 147 } 148 return row 149 } 150 151 func (rs *Rows) Err() error { 152 return rs.err 153 } 154 155 func (rs *Rows) Scan(dest ...any) error { 156 if rs.err != nil { 157 return rs.err 158 } 159 return scanAll(rs.stmt, dest...) 160 } 161 162 func (rs *Rows) Close() error { 163 if rs.stmt == nil { 164 return nil 165 } 166 _, err := rs.stmt.ResetAndClear() 167 msg := rs.stmt.DBHandle().ErrMsg() 168 var err2 error 169 if rs.oneOff { 170 err2 = rs.stmt.Finalize() 171 } 172 rs.stmt = nil 173 if err != nil { 174 return fmt.Errorf("Rows.ResetAndClear: %w: %v", err, msg) 175 } 176 if err2 != nil { 177 return fmt.Errorf("Rows.ResetAndClear: %w: %v", err2, rs.stmt.DBHandle().ErrMsg()) 178 } 179 return nil 180 } 181 182 // Row is like database/sql.Tx.Row. 183 type Row struct { 184 stmt sqliteh.Stmt 185 err error 186 oneOff bool 187 } 188 189 func (r *Row) Err() error { 190 return r.err 191 } 192 193 func (r *Row) Scan(dest ...any) error { 194 if r.err != nil { 195 return r.err 196 } 197 err := scanAll(r.stmt, dest...) 198 r.stmt.ResetAndClear() 199 if r.oneOff { 200 r.stmt.Finalize() 201 } 202 return err 203 } 204 205 type scanner interface { 206 Scan(value any) error 207 } 208 209 // scanAll mimics (some of) the sqlite driver's scanning logic, which is 210 // split across the driver and the database/sql package. 211 func scanAll(stmt sqliteh.Stmt, dest ...any) error { 212 for i := 0; i < len(dest); i++ { 213 if s, ok := dest[i].(scanner); ok { 214 // We have a handful of *sql.NullInt64 objects in 215 // our tree, so we implement minimal support for 216 // them here. TODO: remove some time. 217 var v any 218 switch stmt.ColumnType(i) { 219 case sqliteh.SQLITE_INTEGER: 220 v = stmt.ColumnInt64(i) 221 case sqliteh.SQLITE_FLOAT: 222 v = stmt.ColumnDouble(i) 223 case sqliteh.SQLITE_TEXT: 224 v = stmt.ColumnText(i) 225 case sqliteh.SQLITE_BLOB: 226 v = stmt.ColumnText(i) 227 case sqliteh.SQLITE_NULL: 228 v = nil 229 } 230 if err := s.Scan(v); err != nil { 231 return err 232 } 233 continue 234 } 235 v := reflect.ValueOf(dest[i]) 236 if v.Elem().Kind() == reflect.Slice && v.Elem().Type().Elem().Kind() == reflect.Uint8 { 237 b := append([]byte(nil), stmt.ColumnBlob(i)...) 238 v.Elem().SetBytes(b) 239 continue 240 } 241 switch v.Elem().Kind() { 242 case reflect.Bool: 243 v.Elem().SetBool(stmt.ColumnInt64(i) != 0) 244 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 245 v.Elem().SetInt(stmt.ColumnInt64(i)) 246 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 247 v.Elem().SetUint(uint64(stmt.ColumnInt64(i))) 248 case reflect.Float32, reflect.Float64: 249 v.Elem().SetFloat(stmt.ColumnDouble(i)) 250 case reflect.String: 251 v.Elem().SetString(stmt.ColumnText(i)) 252 default: 253 return fmt.Errorf("sqlitepool.scan:%d: cannot handle destination kind %v (%T)", i, v.Kind(), dest[i]) 254 } 255 } 256 return nil 257 } 258 259 func bindAll(db sqliteh.DB, stmt sqliteh.Stmt, args ...any) error { 260 for i, arg := range args { 261 if err := bind(db, stmt, i+1, arg); err != nil { 262 stmt.ResetAndClear() 263 return fmt.Errorf("bind: %d, %q: %w", i, arg, err) 264 } 265 } 266 return nil 267 } 268 269 type driverValue interface { 270 Value() (driver.Value, error) 271 } 272 273 // bind, from the driver in sqlite.go. 274 func bind(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) error { 275 // Start with obvious types, including time.Time before TextMarshaler. 276 found, err := bindBasic(db, s, ordinal, v) 277 if err != nil { 278 return err 279 } else if found { 280 return nil 281 } 282 283 if m, _ := v.(driverValue); m != nil { 284 // We have a few NullInt64s we need to handle. 285 // TODO: remove or rethink in the future. 286 var err error 287 v, err = m.Value() 288 if err != nil { 289 return fmt.Errorf("sqlitepool.bind:%d: bad driver.Value: %w", ordinal, err) 290 } 291 if v == nil { 292 _, err := bindBasic(db, s, ordinal, nil) 293 return err 294 } 295 } 296 297 if m, _ := v.(encoding.TextMarshaler); m != nil { 298 b, err := m.MarshalText() 299 if err != nil { 300 return fmt.Errorf("sqlitepool.bind:%d: cannot marshal %T: %w", ordinal, v, err) 301 } 302 _, err = bindBasic(db, s, ordinal, b) 303 return err 304 } 305 306 // Look for named basic types or other convertible types. 307 val := reflect.ValueOf(v) 308 if val.Kind() == reflect.Pointer { 309 if val.IsNil() { 310 _, err := bindBasic(db, s, ordinal, nil) 311 return err 312 } 313 val = val.Elem() 314 } 315 typ := reflect.TypeOf(v) 316 if typ.Kind() == reflect.Pointer { 317 typ = typ.Elem() 318 } 319 switch typ.Kind() { 320 case reflect.Bool: 321 b := int64(0) 322 if val.Bool() { 323 b = 1 324 } 325 _, err := bindBasic(db, s, ordinal, b) 326 return err 327 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 328 var i int64 329 if !val.IsZero() { 330 i = val.Int() 331 } 332 _, err := bindBasic(db, s, ordinal, i) 333 return err 334 case reflect.Uint, reflect.Uint64: 335 return fmt.Errorf("sqlitepool.bind:%d: sqlite does not support uint64 (try a string or TextMarshaler)", ordinal) 336 case reflect.Uint8, reflect.Uint16, reflect.Uint32: 337 _, err := bindBasic(db, s, ordinal, int64(val.Uint())) 338 return err 339 case reflect.Float32, reflect.Float64: 340 _, err := bindBasic(db, s, ordinal, val.Float()) 341 return err 342 case reflect.String: 343 _, err := bindBasic(db, s, ordinal, val.String()) 344 return err 345 } 346 347 return fmt.Errorf("sqlitepool.bind:%d: unknown value type %T (try a string or TextMarshaler)", ordinal, v) 348 } 349 350 // bindBasic, from the driver in sqlite.go. 351 func bindBasic(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) (found bool, err error) { 352 defer func() { 353 if err != nil { 354 err = fmt.Errorf("sqlitepool.bind:%d:%T: %w: %v", ordinal, v, err, db.ErrMsg()) 355 } 356 }() 357 switch v := v.(type) { 358 case nil: 359 return true, s.BindNull(ordinal) 360 case string: 361 return true, s.BindText64(ordinal, v) 362 case int: 363 return true, s.BindInt64(ordinal, int64(v)) 364 case int64: 365 return true, s.BindInt64(ordinal, v) 366 case float64: 367 return true, s.BindDouble(ordinal, v) 368 case []byte: 369 if len(v) == 0 { 370 return true, s.BindZeroBlob64(ordinal, 0) 371 } else { 372 return true, s.BindBlob64(ordinal, v) 373 } 374 case time.Time: 375 // Shortest of: 376 // YYYY-MM-DD HH:MM 377 // YYYY-MM-DD HH:MM:SS 378 // YYYY-MM-DD HH:MM:SS.SSS 379 str := v.Format(timeFormat) 380 str = strings.TrimSuffix(str, "-0000") 381 str = strings.TrimSuffix(str, ".000") 382 str = strings.TrimSuffix(str, ":00") 383 return true, s.BindText64(ordinal, str) 384 default: 385 return false, nil 386 } 387 } 388 389 // timeFormat from the driver in sqlite.go. 390 const timeFormat = "2006-01-02 15:04:05.000-0700"