github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlite.go (about) 1 // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package sqlite implements a database/sql driver for SQLite3. 6 // 7 // This driver requires a file: URI always be used to open a database. 8 // For details see https://sqlite.org/c3ref/open.html#urifilenames. 9 // 10 // # Initializing connections or tracing 11 // 12 // If you want to do initial configuration of a connection, or enable 13 // tracing, use the Connector function: 14 // 15 // connInitFunc := func(ctx context.Context, conn driver.ConnPrepareContext) error { 16 // return sqlite.ExecScript(conn.(sqlite.SQLConn), "PRAGMA journal_mode=WAL;") 17 // } 18 // db, err = sql.OpenDB(sqlite.Connector(sqliteURI, connInitFunc, nil)) 19 // 20 // # Memory Mode 21 // 22 // In-memory databases are popular for tests. 23 // Use the "memdb" VFS (*not* the legacy in-memory modes) to be compatible 24 // with the database/sql connection pool: 25 // 26 // file:/dbname?vfs=memdb 27 // 28 // Use a different dbname for each memory database opened. 29 // 30 // # Binding Types 31 // 32 // SQLite is flexible about type conversions, and so is this driver. 33 // Almost all "basic" Go types (int, float64, string) are accepted and 34 // directly mapped into SQLite, even if they are named Go types. 35 // The time.Time type is also accepted (described below). 36 // Values that implement encoding.TextMarshaler or json.Marshaler are 37 // stored in SQLite in their marshaled form. 38 // 39 // # Binding Time 40 // 41 // While SQLite3 has no strict time datatype, it does have a series of built-in 42 // functions that operate on timestamps that expect columns to be in one of many 43 // formats: https://sqlite.org/lang_datefunc.html 44 // 45 // When encoding a time.Time into one of SQLite's preferred formats, we use the 46 // shortest timestamp format that can accurately represent the time.Time. 47 // The supported formats are: 48 // 49 // 2. YYYY-MM-DD HH:MM 50 // 3. YYYY-MM-DD HH:MM:SS 51 // 4. YYYY-MM-DD HH:MM:SS.SSS 52 // 53 // If the time.Time is not UTC (strongly consider storing times in UTC!), 54 // we follow SQLite's norm of appending "[+-]HH:MM" to the above formats. 55 // 56 // It is common in SQLite to store "Unix time", seconds-since-epoch in an 57 // INTEGER column. This is understood by the date and time functions documented 58 // in the link above. If you want to do that, pass the result of time.Time.Unix 59 // to the driver. 60 // 61 // # Reading Time 62 // 63 // In general, time is hard to extract from SQLite as a time.Time. 64 // If a column is defined as DATE or DATETIME, then text data is parsed 65 // as TimeFormat and returned as a time.Time. Integer data is parsed as 66 // seconds since epoch and returned as a time.Time. 67 package sqlite 68 69 import ( 70 "context" 71 "database/sql" 72 "database/sql/driver" 73 "encoding" 74 "errors" 75 "expvar" 76 "fmt" 77 "io" 78 "reflect" 79 "strings" 80 "sync/atomic" 81 "time" 82 83 "github.com/tailscale/sqlite/sqliteh" 84 ) 85 86 var Open sqliteh.OpenFunc = func(string, sqliteh.OpenFlags, string) (sqliteh.DB, error) { 87 return nil, fmt.Errorf("cgosqlite.Open is missing") 88 } 89 90 // ConnInitFunc is a function called by the driver on new connections. 91 // 92 // The conn can be used to execute queries, and implements SQLConn. 93 // Any error return closes the conn and passes the error to database/sql. 94 type ConnInitFunc func(ctx context.Context, conn driver.ConnPrepareContext) error 95 96 // TimeFormat is the string format this driver uses to store 97 // microsecond-precision time in SQLite in text format. 98 const TimeFormat = "2006-01-02 15:04:05.000-0700" 99 100 func init() { 101 sql.Register("sqlite3", drv{}) 102 } 103 104 var maxConnID atomic.Int32 105 106 // UsesAfterClose is a metric that is incremented every time an operation is 107 // attempted on a connection after Close has already been called. The keys are 108 // internal identifiers for the code path that incremented a counter. 109 var UsesAfterClose expvar.Map 110 111 // ErrClosed is returned when an operation is attempted on a connection after 112 // Close has already been called. 113 var ErrClosed = errors.New("sqlite3: already closed") 114 115 type drv struct{} 116 117 func (drv) Open(name string) (driver.Conn, error) { panic("deprecated, unused") } 118 func (drv) OpenConnector(name string) (driver.Connector, error) { 119 return &connector{name: name}, nil 120 } 121 122 func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer) driver.Connector { 123 return &connector{ 124 name: sqliteURI, 125 tracer: tracer, 126 connInitFunc: connInitFunc, 127 } 128 } 129 130 type connector struct { 131 name string 132 tracer sqliteh.Tracer 133 connInitFunc ConnInitFunc 134 } 135 136 func (p *connector) Driver() driver.Driver { return drv{} } 137 func (p *connector) Connect(ctx context.Context) (driver.Conn, error) { 138 db, err := Open(p.name, sqliteh.OpenFlagsDefault, "") 139 if err != nil { 140 if ec, ok := err.(sqliteh.ErrCode); ok { 141 e := &Error{ 142 Code: sqliteh.Code(ec), 143 Loc: "Open", 144 } 145 if db != nil { 146 e.Msg = db.ErrMsg() 147 } 148 err = e 149 } 150 if db != nil { 151 db.Close() 152 } 153 return nil, err 154 } 155 156 c := &conn{ 157 db: db, 158 tracer: p.tracer, 159 id: sqliteh.TraceConnID(maxConnID.Add(1)), 160 } 161 if p.connInitFunc != nil { 162 if err := p.connInitFunc(ctx, c); err != nil { 163 db.Close() 164 return nil, fmt.Errorf("sqlite.ConnInitFunc: %w", err) 165 } 166 } 167 return c, nil 168 } 169 170 type txState int 171 172 const ( 173 txStateNone = txState(0) // connection is not connected to a Tx 174 txStateInit = txState(1) // BeginTx called, but "BEGIN;" not yet executed 175 txStateBegun = txState(2) // "BEGIN;" has been executed 176 ) 177 178 type conn struct { 179 db sqliteh.DB 180 id sqliteh.TraceConnID 181 tracer sqliteh.Tracer 182 stmts map[string]*stmt // persisted statements 183 txState txState 184 readOnly bool 185 closed atomic.Bool 186 } 187 188 func (c *conn) Prepare(query string) (driver.Stmt, error) { panic("deprecated, unused") } 189 func (c *conn) Begin() (driver.Tx, error) { panic("deprecated, unused") } 190 func (c *conn) Close() error { 191 // Don't double-close 192 if !c.closed.CompareAndSwap(false, true) { 193 UsesAfterClose.Add("Close", 1) 194 return nil 195 } 196 197 for q, s := range c.stmts { 198 s.stmt.Finalize() 199 s.closed.Store(true) 200 delete(c.stmts, q) 201 } 202 err := reserr(c.db, "Conn.Close", "", c.db.Close()) 203 return err 204 } 205 func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 206 persist := ctx.Value(persistQuery{}) != nil 207 return c.prepare(ctx, query, persist) 208 } 209 210 func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt, err error) { 211 if c.closed.Load() { 212 UsesAfterClose.Add("prepare", 1) 213 return nil, ErrClosed 214 } 215 216 query = strings.TrimSpace(query) 217 if s := c.stmts[query]; s != nil { 218 // don't hand the same statement out twice; this is re-added on s.Close 219 delete(c.stmts, query) 220 221 s.prepCtx = ctx 222 if !s.closed.CompareAndSwap(true, false) { 223 // We'd previously set this to 'false', indicating that 224 // this stmt is in-use. Return an error instead of 225 // reusing the stmt. 226 return nil, ErrClosed 227 } 228 229 return s, nil 230 } 231 if c.tracer != nil { 232 // Not a hot path. Any high-load environment should use 233 // WithPersist so this is rare. 234 start := time.Now() 235 defer func() { 236 if err != nil { 237 c.tracer.Query(ctx, c.id, query, time.Since(start), err) 238 } 239 }() 240 } 241 var flags sqliteh.PrepareFlags 242 if persist { 243 flags = sqliteh.SQLITE_PREPARE_PERSISTENT 244 } 245 cstmt, rem, err := c.db.Prepare(query, flags) 246 if err != nil { 247 return nil, reserr(c.db, "Prepare", query, err) 248 } 249 if rem != "" { 250 cstmt.Finalize() 251 return nil, &Error{ 252 Code: sqliteh.SQLITE_MISUSE, 253 Loc: "Prepare", 254 Query: query, 255 Msg: fmt.Sprintf("query has trailing text: %q", rem), 256 } 257 } 258 s = &stmt{ 259 conn: c, 260 stmt: cstmt, 261 query: query, 262 persist: persist, 263 numInput: -1, 264 prepCtx: ctx, 265 } 266 267 if !persist { 268 return s, nil 269 } 270 271 // NOTE: don't add the statement to c.stmts here, since we could return 272 // it to another caller before Close is called; it's added to the 273 // c.stmts map on Close. 274 if c.stmts == nil { 275 c.stmts = make(map[string]*stmt) 276 } 277 return s, nil 278 } 279 280 func (c *conn) execInternal(ctx context.Context, query string) error { 281 s, err := c.prepare(ctx, query, true) 282 if err != nil { 283 if e, _ := err.(*Error); e != nil { 284 e.Loc = "internal:" + e.Loc 285 } 286 return err 287 } 288 if _, err := s.ExecContext(ctx, nil); err != nil { 289 return err 290 } 291 s.Close() 292 return nil 293 } 294 295 func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 296 if c.closed.Load() { 297 UsesAfterClose.Add("BeginTx", 1) 298 return nil, ErrClosed 299 } 300 301 const LevelSerializable = 6 // matches the sql package constant 302 if opts.Isolation != 0 && opts.Isolation != LevelSerializable { 303 return nil, errors.New("github.com/tailscale/sqlite driver only supports serializable isolation level") 304 } 305 c.readOnly = opts.ReadOnly 306 c.txState = txStateInit 307 if c.tracer != nil { 308 c.tracer.BeginTx(ctx, c.id, "", c.readOnly, nil) 309 } 310 if err := c.txInit(ctx); err != nil { 311 return nil, err 312 } 313 return &connTx{conn: c}, nil 314 } 315 316 // Raw is so ConnInitFunc can cast to SQLConn. 317 func (c *conn) Raw(fn func(any) error) error { return fn(c) } 318 319 type readOnlyKey struct{} 320 321 // ReadOnly applies the query_only pragma to the connection. 322 func ReadOnly(ctx context.Context) context.Context { 323 return context.WithValue(ctx, readOnlyKey{}, true) 324 } 325 326 // IsReadOnly reports whether the context has the ReadOnly key. 327 func IsReadOnly(ctx context.Context) bool { 328 return ctx.Value(readOnlyKey{}) != nil 329 } 330 331 func (c *conn) txInit(ctx context.Context) error { 332 if c.txState != txStateInit { 333 return nil 334 } 335 c.txState = txStateBegun 336 if c.readOnly || IsReadOnly(ctx) { 337 if err := c.execInternal(ctx, "BEGIN"); err != nil { 338 return err 339 } 340 if err := c.execInternal(ctx, "PRAGMA query_only=true"); err != nil { 341 return err 342 } 343 } else { 344 // TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?) 345 // semantics via a context annotation function. 346 if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil { 347 return err 348 } 349 } 350 return nil 351 } 352 353 func (c *conn) txEnd(ctx context.Context, endStmt string) error { 354 state, readOnly := c.txState, c.readOnly 355 c.txState = txStateNone 356 c.readOnly = false 357 if state != txStateBegun { 358 return nil 359 } 360 361 err := c.execInternal(context.Background(), endStmt) 362 if readOnly { 363 if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil { 364 err = err2 365 } 366 } 367 return err 368 } 369 370 type connTx struct { 371 conn *conn 372 } 373 374 func (tx *connTx) Commit() error { 375 if tx.conn.closed.Load() { 376 UsesAfterClose.Add("tx.Commit", 1) 377 return ErrClosed 378 } 379 380 err := tx.conn.txEnd(context.Background(), "COMMIT") 381 if tx.conn.tracer != nil { 382 tx.conn.tracer.Commit(tx.conn.id, err) 383 } 384 return err 385 } 386 387 func (tx *connTx) Rollback() error { 388 if tx.conn.closed.Load() { 389 UsesAfterClose.Add("tx.Rollback", 1) 390 return ErrClosed 391 } 392 393 err := tx.conn.txEnd(context.Background(), "ROLLBACK") 394 if tx.conn.tracer != nil { 395 tx.conn.tracer.Rollback(tx.conn.id, err) 396 } 397 return err 398 } 399 400 func reserr(db sqliteh.DB, loc, query string, err error) error { 401 if err == nil { 402 return nil 403 } 404 e := &Error{ 405 Code: sqliteh.Code(err.(sqliteh.ErrCode)), 406 Loc: loc, 407 Query: query, 408 } 409 // TODO(crawshaw): consider an API to expose this. sqlite.DebugErrMsg(db)? 410 if true { 411 e.Msg = db.ErrMsg() 412 } 413 return e 414 } 415 416 type stmt struct { 417 conn *conn 418 stmt sqliteh.Stmt 419 query string 420 persist bool // true if stmt is cached and lives beyond Close 421 bound bool // true if stmt has parameters bound 422 closed atomic.Bool // true after Close if persist==false 423 424 numInput int // filled on first NumInput only if persist==true 425 426 prepCtx context.Context // the context provided to prepare, for tracing 427 428 // filled on first step only if persist==true 429 colDeclTypes []colDeclType 430 colNames []string 431 } 432 433 func (s *stmt) reserr(loc string, err error) error { return reserr(s.conn.db, loc, s.query, err) } 434 435 func (s *stmt) NumInput() int { 436 if s.closed.Load() { 437 UsesAfterClose.Add("stmt.NumInput", 1) 438 return 0 439 } 440 if s.persist { 441 if s.numInput == -1 { 442 s.numInput = s.stmt.BindParameterCount() 443 } 444 return s.numInput 445 } 446 return s.stmt.BindParameterCount() 447 } 448 449 func (s *stmt) Close() error { 450 // Always set the 'closed' boolean, even for a persisted query; this is 451 // set from false -> true in prepare(), above. 452 if s.conn.closed.Load() { 453 UsesAfterClose.Add("Stmt.Close_conn", 1) 454 return nil 455 } 456 if !s.closed.CompareAndSwap(false, true) { 457 UsesAfterClose.Add("Stmt.Close", 1) 458 return nil 459 } 460 461 // We return this statement to the conn only if it's persistent, and 462 // only if there's not already a statement with the same query already 463 // cached there. 464 shouldPersist := s.persist 465 if shouldPersist { 466 if _, alreadyPersisted := s.conn.stmts[s.query]; alreadyPersisted { 467 shouldPersist = false 468 } 469 } 470 if shouldPersist { 471 err := s.reserr("Stmt.Close", s.resetAndClear()) 472 if err == nil { 473 s.conn.stmts[s.query] = s 474 } 475 return err 476 } 477 return s.reserr("Stmt.Close", s.stmt.Finalize()) 478 } 479 func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { panic("deprecated, unused") } 480 func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { panic("deprecated, unused") } 481 482 func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 483 if s.closed.Load() { 484 UsesAfterClose.Add("stmt.ExecContext", 1) 485 return nil, ErrClosed 486 } 487 if err := s.resetAndClear(); err != nil { 488 return nil, s.reserr("Stmt.Exec(Reset)", err) 489 } 490 if err := s.bindAll(args); err != nil { 491 return nil, s.reserr("Stmt.Exec(Bind)", err) 492 } 493 if ctx.Value(queryCancelKey{}) != nil { 494 var cancel context.CancelFunc 495 ctx, cancel = context.WithCancel(ctx) 496 defer cancel() 497 498 db := s.stmt.DBHandle() 499 go func() { <-ctx.Done(); db.Interrupt() }() 500 } 501 row, lastInsertRowID, changes, duration, err := s.stmt.StepResult() 502 s.bound = false // StepResult resets the query 503 err = s.reserr("Stmt.Exec", err) 504 if s.conn.tracer != nil { 505 s.conn.tracer.Query(s.prepCtx, s.conn.id, s.query, duration, err) 506 } 507 if err != nil { 508 return nil, err 509 } 510 _ = row // TODO: return error if exec on query which returns rows? 511 return getStmtResult(lastInsertRowID, changes), nil 512 } 513 514 var ( 515 stmtResultZeroRows = &stmtResult{} 516 stmtResultOneRow = &stmtResult{rowsAffected: 1} 517 ) 518 519 func getStmtResult(lastInsertID int64, rowsAffected int64) *stmtResult { 520 // Some common cases to avoid allocs: 521 if lastInsertID == 0 { 522 switch rowsAffected { 523 case 0: 524 return stmtResultZeroRows 525 case 1: 526 return stmtResultOneRow 527 } 528 } 529 return &stmtResult{lastInsertID: lastInsertID, rowsAffected: rowsAffected} 530 } 531 532 type stmtResult struct { 533 lastInsertID int64 534 rowsAffected int64 535 } 536 537 func (res *stmtResult) LastInsertId() (int64, error) { return res.lastInsertID, nil } 538 func (res *stmtResult) RowsAffected() (int64, error) { return res.rowsAffected, nil } 539 540 func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 541 if s.closed.Load() { 542 UsesAfterClose.Add("stmt.QueryContext", 1) 543 return nil, ErrClosed 544 } 545 if err := s.resetAndClear(); err != nil { 546 return nil, s.reserr("Stmt.Query(Reset)", err) 547 } 548 if err := s.bindAll(args); err != nil { 549 return nil, err 550 } 551 cancel := func() {} 552 if ctx.Value(queryCancelKey{}) != nil { 553 ctx, cancel = context.WithCancel(ctx) 554 db := s.stmt.DBHandle() 555 go func() { <-ctx.Done(); db.Interrupt() }() 556 } 557 return &rows{stmt: s, cancel: cancel}, nil 558 } 559 560 func (s *stmt) resetAndClear() error { 561 if !s.bound { 562 return nil 563 } 564 s.bound = false 565 duration, err := s.stmt.ResetAndClear() 566 if s.conn.tracer != nil { 567 s.conn.tracer.Query(s.prepCtx, s.conn.id, s.query, duration, err) 568 } 569 return err 570 } 571 572 func (s *stmt) bindAll(args []driver.NamedValue) error { 573 if s.bound { 574 panic("sqlite: impossible state, query already running: " + s.query) 575 } 576 s.bound = true 577 if s.conn.tracer != nil { 578 s.stmt.StartTimer() 579 } 580 for _, arg := range args { 581 if err := s.bind(arg); err != nil { 582 return err 583 } 584 } 585 return nil 586 } 587 588 func (s *stmt) bind(arg driver.NamedValue) error { 589 // TODO(crawshaw): could use a union-ish data type for debugName 590 // to avoid the allocation. 591 var debugName any 592 if arg.Name == "" { 593 debugName = arg.Ordinal 594 } else { 595 debugName = arg.Name 596 index := s.stmt.BindParameterIndexSearch(arg.Name) 597 if index == 0 { 598 return &Error{ 599 Code: sqliteh.SQLITE_MISUSE, 600 Loc: "Bind", 601 Query: s.query, 602 Msg: fmt.Sprintf("unknown parameter name %q", arg.Name), 603 } 604 } 605 arg.Ordinal = index 606 } 607 608 // Start with obvious types, including time.Time before TextMarshaler. 609 found, err := s.bindBasic(debugName, arg.Ordinal, arg.Value) 610 if err != nil { 611 return err 612 } else if found { 613 return nil 614 } 615 616 if m, _ := arg.Value.(encoding.TextMarshaler); m != nil { 617 b, err := m.MarshalText() 618 if err != nil { 619 // TODO: modify Error to carry an error so we can %w? 620 return &Error{ 621 Code: sqliteh.SQLITE_MISUSE, 622 Loc: "Bind", 623 Query: s.query, 624 Msg: fmt.Sprintf("Bind:%v: cannot marshal %T: %v", debugName, arg.Value, err), 625 } 626 } 627 _, err = s.bindBasic(debugName, arg.Ordinal, b) 628 return err 629 } 630 631 // Look for named basic types or other convertible types. 632 val := reflect.ValueOf(arg.Value) 633 typ := reflect.TypeOf(arg.Value) 634 switch typ.Kind() { 635 case reflect.Bool: 636 b := int64(0) 637 if val.Bool() { 638 b = 1 639 } 640 _, err := s.bindBasic(debugName, arg.Ordinal, b) 641 return err 642 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 643 _, err := s.bindBasic(debugName, arg.Ordinal, val.Int()) 644 return err 645 case reflect.Uint, reflect.Uint64: 646 return &Error{ 647 Code: sqliteh.SQLITE_MISUSE, 648 Loc: "Bind", 649 Query: s.query, 650 Msg: fmt.Sprintf("Bind:%v: sqlite does not support uint64 (try a string or TextMarshaler)", debugName), 651 } 652 case reflect.Uint8, reflect.Uint16, reflect.Uint32: 653 _, err := s.bindBasic(debugName, arg.Ordinal, int64(val.Uint())) 654 return err 655 case reflect.Float32, reflect.Float64: 656 _, err := s.bindBasic(debugName, arg.Ordinal, val.Float()) 657 return err 658 case reflect.String: 659 // TODO(crawshaw): decompose bindBasic somehow. 660 // But first: more tests that the errors make sense for each type. 661 _, err := s.bindBasic(debugName, arg.Ordinal, val.String()) 662 return err 663 } 664 665 return &Error{ 666 Code: sqliteh.SQLITE_MISUSE, 667 Loc: "Bind", 668 Query: s.query, 669 Msg: fmt.Sprintf("Bind:%v: unknown value type %T (try a string or TextMarshaler)", debugName, arg.Value), 670 } 671 } 672 673 func (s *stmt) bindBasic(debugName any, ordinal int, v any) (found bool, err error) { 674 defer func() { 675 if err != nil { 676 err = s.reserr(fmt.Sprintf("Bind:%v:%T", debugName, v), err) 677 } 678 }() 679 switch v := v.(type) { 680 case nil: 681 return true, s.stmt.BindNull(ordinal) 682 case string: 683 return true, s.stmt.BindText64(ordinal, v) 684 case int: 685 return true, s.stmt.BindInt64(ordinal, int64(v)) 686 case int64: 687 return true, s.stmt.BindInt64(ordinal, v) 688 case float64: 689 return true, s.stmt.BindDouble(ordinal, v) 690 case []byte: 691 if len(v) == 0 { 692 return true, s.stmt.BindZeroBlob64(ordinal, 0) 693 } else { 694 return true, s.stmt.BindBlob64(ordinal, v) 695 } 696 case time.Time: 697 // Shortest of: 698 // YYYY-MM-DD HH:MM 699 // YYYY-MM-DD HH:MM:SS 700 // YYYY-MM-DD HH:MM:SS.SSS 701 str := v.Format(TimeFormat) 702 str = strings.TrimSuffix(str, "-0000") 703 str = strings.TrimSuffix(str, ".000") 704 str = strings.TrimSuffix(str, ":00") 705 return true, s.stmt.BindText64(ordinal, str) 706 default: 707 return false, nil 708 } 709 } 710 711 // colDeclType is whether and how the declared SQLite column type should 712 // map to any special handling (as a date, or as a boolean, etc). 713 type colDeclType byte 714 715 const ( 716 declTypeUnknown colDeclType = iota 717 declTypeDateOrTime 718 declTypeBoolean 719 ) 720 721 func colDeclTypeFromString(s string) colDeclType { 722 if strings.EqualFold(s, "DATETIME") || strings.EqualFold(s, "DATE") { 723 return declTypeDateOrTime 724 } 725 if strings.EqualFold(s, "BOOLEAN") { 726 return declTypeBoolean 727 } 728 return declTypeUnknown 729 } 730 731 type rows struct { 732 stmt *stmt 733 closed bool 734 cancel context.CancelFunc // call when query ends 735 736 // colType is the column types for Step to fill on each row. We only use 23 737 // as it packs well with the closed bool byte above (24 bytes total, same as 738 // a slice) and it's uncommon for queries to select so many columns. But if 739 // they do, we still work: we just query the column type via cgo on each 740 // row. So a bit slower, but fine. 741 colType [23]sqliteh.ColumnType 742 743 colNames []string // filled on call to Columns 744 745 // Filled on first call to Next. 746 colDeclTypes []colDeclType 747 } 748 749 func (r *rows) Columns() []string { 750 if r.closed { 751 panic("Columns called after Rows was closed") 752 } 753 if r.stmt.closed.Load() { 754 UsesAfterClose.Add("rows.Columns", 1) 755 return nil 756 } 757 if r.colNames == nil { 758 if r.stmt.colNames != nil { 759 r.colNames = r.stmt.colNames 760 } else { 761 r.colNames = make([]string, r.stmt.stmt.ColumnCount()) 762 for i := range r.colNames { 763 r.colNames[i] = r.stmt.stmt.ColumnName(i) 764 } 765 if r.stmt.persist { 766 r.stmt.colNames = r.colNames 767 } 768 } 769 } 770 return append([]string{}, r.colNames...) 771 } 772 773 func (r *rows) Close() error { 774 if r.closed { 775 return errors.New("sqlite rows result already closed") 776 } 777 if r.stmt.closed.Load() { 778 UsesAfterClose.Add("rows.Close", 1) 779 return ErrClosed 780 } 781 r.closed = true 782 defer r.cancel() 783 if err := r.stmt.resetAndClear(); err != nil { 784 return r.stmt.reserr("Rows.Close(Reset)", err) 785 } 786 return nil 787 } 788 789 func (r *rows) Next(dest []driver.Value) error { 790 if r.closed { 791 return errors.New("sqlite rows result already closed") 792 } 793 if r.stmt.closed.Load() { 794 UsesAfterClose.Add("rows.Next", 1) 795 return ErrClosed 796 } 797 hasRow, err := r.stmt.stmt.Step(r.colType[:]) 798 if err != nil { 799 return r.stmt.reserr("Rows.Next", err) 800 } 801 if !hasRow { 802 return io.EOF 803 } 804 805 if r.colDeclTypes == nil { 806 r.colDeclTypes = r.stmt.colDeclTypes 807 } 808 if r.colDeclTypes == nil { 809 colCount := r.stmt.stmt.ColumnCount() 810 r.colDeclTypes = make([]colDeclType, colCount) 811 for i := range r.colDeclTypes { 812 r.colDeclTypes[i] = colDeclTypeFromString(r.stmt.stmt.ColumnDeclType(i)) 813 } 814 if r.stmt.persist { 815 r.stmt.colDeclTypes = r.colDeclTypes 816 } 817 } 818 819 for i := range dest { 820 var colType sqliteh.ColumnType 821 if i < len(r.colType) { 822 // Common case, for the first couple dozen columns. 823 colType = r.colType[i] 824 } else { 825 // If it's a really wide query, then call into 826 // cgo for columns past the length of 827 // r.colType. 828 colType = r.stmt.stmt.ColumnType(i) 829 } 830 831 if r.colDeclTypes[i] == declTypeDateOrTime { 832 switch colType { 833 case sqliteh.SQLITE_INTEGER: 834 v := r.stmt.stmt.ColumnInt64(i) 835 dest[i] = time.Unix(v, 0) 836 case sqliteh.SQLITE_FLOAT: 837 dest[i] = r.stmt.stmt.ColumnDouble(i) 838 // TODO: treat as time? 839 case sqliteh.SQLITE_TEXT: 840 v := r.stmt.stmt.ColumnText(i) 841 format := TimeFormat 842 if len(format) > len(v) { 843 format = strings.TrimSuffix(format, "-0700") 844 } 845 if len(format) > len(v) { 846 format = strings.TrimSuffix(format, ".000") 847 } 848 if len(format) > len(v) { 849 format = strings.TrimSuffix(format, ":05") 850 } 851 t, err := time.Parse(format, v) 852 if err != nil { 853 return fmt.Errorf("cannot parse time from column %d: %v", i, err) 854 } 855 dest[i] = t 856 } 857 continue 858 } 859 switch colType { 860 case sqliteh.SQLITE_INTEGER: 861 val := r.stmt.stmt.ColumnInt64(i) 862 if r.colDeclTypes[i] == declTypeBoolean { 863 dest[i] = val > 0 864 } else { 865 dest[i] = val 866 } 867 case sqliteh.SQLITE_FLOAT: 868 dest[i] = r.stmt.stmt.ColumnDouble(i) 869 case sqliteh.SQLITE_BLOB, sqliteh.SQLITE_TEXT: 870 dest[i] = r.stmt.stmt.ColumnBlob(i) 871 case sqliteh.SQLITE_NULL: 872 dest[i] = nil 873 } 874 } 875 return nil 876 } 877 878 // Error is an error produced by SQLite. 879 type Error struct { 880 Code sqliteh.Code // SQLite extended error code (SQLITE_OK is an invalid value) 881 Loc string // method name that generated the error 882 Query string // original SQL query text 883 Msg string // value of sqlite3_errmsg, set sqlite.ErrMsg = true 884 } 885 886 func (err Error) Error() string { 887 b := new(strings.Builder) 888 b.WriteString("sqlite") 889 if err.Loc != "" { 890 b.WriteByte('.') 891 b.WriteString(err.Loc) 892 } 893 b.WriteString(": ") 894 b.WriteString(err.Code.String()) 895 if err.Msg != "" { 896 b.WriteString(": ") 897 b.WriteString(err.Msg) 898 } 899 if err.Query != "" { 900 b.WriteString(" (") 901 b.WriteString(err.Query) 902 b.WriteByte(')') 903 } 904 return b.String() 905 } 906 907 // SQLConn is a database/sql.Conn. 908 // (We cannot create a circular package dependency here.) 909 type SQLConn interface { 910 Raw(func(driverConn any) error) error 911 } 912 913 // ExecScript executes a set of SQL queries on an sql.Conn. 914 // It stops on the first error. 915 // It is recommended you wrap your script in a BEGIN; ... COMMIT; block. 916 // 917 // Usage: 918 // 919 // c, err := db.Conn(ctx) 920 // if err != nil { 921 // // handle err 922 // } 923 // if err := sqlite.ExecScript(c, queries); err != nil { 924 // // handle err 925 // } 926 // c.Close() // return sql.Conn to pool 927 func ExecScript(sqlconn SQLConn, queries string) error { 928 return sqlconn.Raw(func(driverConn any) error { 929 c, ok := driverConn.(*conn) 930 if !ok { 931 return fmt.Errorf("sqlite.ExecScript: sql.Conn is not the sqlite driver: %T", driverConn) 932 } 933 934 for { 935 queries = strings.TrimSpace(queries) 936 if queries == "" { 937 return nil 938 } 939 cstmt, rem, err := c.db.Prepare(queries, 0) 940 if err != nil { 941 return reserr(c.db, "ExecScript", queries, err) 942 } 943 queries = rem 944 _, err = cstmt.Step(nil) 945 cstmt.Finalize() 946 if err != nil { 947 // TODO(crawshaw): consider checking sqlite3_txn_state 948 // here and issuing a rollback, incase this script was: 949 // BEGIN; BAD-SQL; COMMIT; 950 // So we don't leave the connection open. 951 return reserr(c.db, "ExecScript", queries, err) 952 } 953 } 954 }) 955 } 956 957 // BusyTimeout calls sqlite3_busy_timeout on the underlying connection. 958 func BusyTimeout(sqlconn SQLConn, d time.Duration) error { 959 return sqlconn.Raw(func(driverConn any) error { 960 c, ok := driverConn.(*conn) 961 if !ok { 962 return fmt.Errorf("sqlite.BusyTimeout: sql.Conn is not the sqlite driver: %T", driverConn) 963 } 964 c.db.BusyTimeout(d) 965 return nil 966 }) 967 } 968 969 // SetWALHook calls sqlite3_wal_hook. 970 // 971 // If hook is nil, the hook is removed. 972 func SetWALHook(sqlconn SQLConn, hook func(dbName string, pages int)) error { 973 return sqlconn.Raw(func(driverConn any) error { 974 c, ok := driverConn.(*conn) 975 if !ok { 976 return fmt.Errorf("sqlite.TxnState: sql.Conn is not the sqlite driver: %T", driverConn) 977 } 978 c.db.SetWALHook(hook) 979 return nil 980 }) 981 } 982 983 // TxnState calls sqlite3_txn_state on the underlying connection. 984 func TxnState(sqlconn SQLConn, schema string) (state sqliteh.TxnState, err error) { 985 return state, sqlconn.Raw(func(driverConn any) error { 986 c, ok := driverConn.(*conn) 987 if !ok { 988 return fmt.Errorf("sqlite.TxnState: sql.Conn is not the sqlite driver: %T", driverConn) 989 } 990 state = c.db.TxnState(schema) 991 return nil 992 }) 993 } 994 995 // Checkpoint calls sqlite3_wal_checkpoint_v2 on the underlying connection. 996 func Checkpoint(sqlconn SQLConn, dbName string, mode sqliteh.Checkpoint) (numFrames, numFramesCheckpointed int, err error) { 997 err = sqlconn.Raw(func(driverConn any) error { 998 c, ok := driverConn.(*conn) 999 if !ok { 1000 return fmt.Errorf("sqlite.Checkpoint: sql.Conn is not the sqlite driver: %T", driverConn) 1001 } 1002 numFrames, numFramesCheckpointed, err = c.db.Checkpoint(dbName, mode) 1003 return reserr(c.db, "Checkpoint", dbName, err) 1004 }) 1005 return numFrames, numFramesCheckpointed, err 1006 } 1007 1008 // WithPersist makes a ctx instruct the sqlite driver to persist a prepared query. 1009 // 1010 // This should be used with recurring queries to avoid constant parsing and 1011 // planning of the query by SQLite. 1012 func WithPersist(ctx context.Context) context.Context { 1013 return context.WithValue(ctx, persistQuery{}, persistQuery{}) 1014 } 1015 1016 // persistQuery is used as a context value. 1017 type persistQuery struct{} 1018 1019 // WithQueryCancel makes a ctx that instructs the sqlite driver to explicitly 1020 // interrupt a running query if its argument context ends. By default, without 1021 // this option, queries will only check the context between steps. 1022 func WithQueryCancel(ctx context.Context) context.Context { 1023 return context.WithValue(ctx, queryCancelKey{}, queryCancelKey{}) 1024 } 1025 1026 // queryCancelKey is a context key for query context enforcement. 1027 type queryCancelKey struct{}