github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/driver/driver.go (about) 1 // Package driver provides a database/sql driver for SQLite. 2 // 3 // Importing package driver registers a [database/sql] driver named "sqlite3". 4 // You may also need to import package embed. 5 // 6 // import _ "github.com/ncruces/go-sqlite3/driver" 7 // import _ "github.com/ncruces/go-sqlite3/embed" 8 // 9 // The data source name for "sqlite3" databases can be a filename or a "file:" [URI]. 10 // 11 // The [TRANSACTION] mode can be specified using "_txlock": 12 // 13 // sql.Open("sqlite3", "file:demo.db?_txlock=immediate") 14 // 15 // Possible values are: "deferred", "immediate", "exclusive". 16 // A [read-only] transaction is always "deferred", regardless of "_txlock". 17 // 18 // The time encoding/decoding format can be specified using "_timefmt": 19 // 20 // sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite") 21 // 22 // Possible values are: "auto" (the default), "sqlite", "rfc3339"; 23 // "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite; 24 // "sqlite" encodes as SQLite and decodes any [format] supported by SQLite; 25 // "rfc3339" encodes and decodes RFC 3339 only. 26 // 27 // [PRAGMA] statements can be specified using "_pragma": 28 // 29 // sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)") 30 // 31 // If no PRAGMAs are specified, a busy timeout of 1 minute is set. 32 // 33 // Order matters: 34 // busy timeout and locking mode should be the first PRAGMAs set, in that order. 35 // 36 // [URI]: https://sqlite.org/uri.html 37 // [PRAGMA]: https://sqlite.org/pragma.html 38 // [format]: https://sqlite.org/lang_datefunc.html#time_values 39 // [TRANSACTION]: https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions 40 // [read-only]: https://pkg.go.dev/database/sql#TxOptions 41 package driver 42 43 import ( 44 "context" 45 "database/sql" 46 "database/sql/driver" 47 "errors" 48 "fmt" 49 "io" 50 "net/url" 51 "strings" 52 "time" 53 "unsafe" 54 55 "github.com/ncruces/go-sqlite3" 56 "github.com/ncruces/go-sqlite3/internal/util" 57 ) 58 59 // This variable can be replaced with -ldflags: 60 // 61 // go build -ldflags="-X github.com/ncruces/go-sqlite3/driver.driverName=sqlite" 62 var driverName = "sqlite3" 63 64 func init() { 65 if driverName != "" { 66 sql.Register(driverName, &SQLite{}) 67 } 68 } 69 70 // Open opens the SQLite database specified by dataSourceName as a [database/sql.DB]. 71 // 72 // The init function is called by the driver on new connections. 73 // The [sqlite3.Conn] can be used to execute queries, register functions, etc. 74 // Any error returned closes the connection and is returned to [database/sql]. 75 func Open(dataSourceName string, init func(*sqlite3.Conn) error) (*sql.DB, error) { 76 c, err := (&SQLite{Init: init}).OpenConnector(dataSourceName) 77 if err != nil { 78 return nil, err 79 } 80 return sql.OpenDB(c), nil 81 } 82 83 // SQLite implements [database/sql/driver.Driver]. 84 type SQLite struct { 85 // Init function is called by the driver on new connections. 86 // The [sqlite3.Conn] can be used to execute queries, register functions, etc. 87 // Any error returned closes the connection and is returned to [database/sql]. 88 Init func(*sqlite3.Conn) error 89 } 90 91 // Open implements [database/sql/driver.Driver]. 92 func (d *SQLite) Open(name string) (driver.Conn, error) { 93 c, err := d.newConnector(name) 94 if err != nil { 95 return nil, err 96 } 97 return c.Connect(context.Background()) 98 } 99 100 // OpenConnector implements [database/sql/driver.DriverContext]. 101 func (d *SQLite) OpenConnector(name string) (driver.Connector, error) { 102 return d.newConnector(name) 103 } 104 105 func (d *SQLite) newConnector(name string) (*connector, error) { 106 c := connector{driver: d, name: name} 107 108 var txlock, timefmt string 109 if strings.HasPrefix(name, "file:") { 110 if _, after, ok := strings.Cut(name, "?"); ok { 111 query, err := url.ParseQuery(after) 112 if err != nil { 113 return nil, err 114 } 115 txlock = query.Get("_txlock") 116 timefmt = query.Get("_timefmt") 117 c.pragmas = query.Has("_pragma") 118 } 119 } 120 121 switch txlock { 122 case "": 123 c.txBegin = "BEGIN" 124 case "deferred", "immediate", "exclusive": 125 c.txBegin = "BEGIN " + txlock 126 default: 127 return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", txlock) 128 } 129 130 switch timefmt { 131 case "": 132 c.tmRead = sqlite3.TimeFormatAuto 133 c.tmWrite = sqlite3.TimeFormatDefault 134 case "sqlite": 135 c.tmRead = sqlite3.TimeFormatAuto 136 c.tmWrite = sqlite3.TimeFormat3 137 case "rfc3339": 138 c.tmRead = sqlite3.TimeFormatDefault 139 c.tmWrite = sqlite3.TimeFormatDefault 140 default: 141 c.tmRead = sqlite3.TimeFormat(timefmt) 142 c.tmWrite = sqlite3.TimeFormat(timefmt) 143 } 144 return &c, nil 145 } 146 147 type connector struct { 148 driver *SQLite 149 name string 150 txBegin string 151 tmRead sqlite3.TimeFormat 152 tmWrite sqlite3.TimeFormat 153 pragmas bool 154 } 155 156 func (n *connector) Driver() driver.Driver { 157 return n.driver 158 } 159 160 func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { 161 c := &conn{ 162 txBegin: n.txBegin, 163 tmRead: n.tmRead, 164 tmWrite: n.tmWrite, 165 } 166 167 c.Conn, err = sqlite3.Open(n.name) 168 if err != nil { 169 return nil, err 170 } 171 defer func() { 172 if err != nil { 173 c.Close() 174 } 175 }() 176 177 old := c.Conn.SetInterrupt(ctx) 178 defer c.Conn.SetInterrupt(old) 179 180 if !n.pragmas { 181 err = c.Conn.BusyTimeout(60 * time.Second) 182 if err != nil { 183 return nil, err 184 } 185 } 186 if n.driver.Init != nil { 187 err = n.driver.Init(c.Conn) 188 if err != nil { 189 return nil, err 190 } 191 } 192 if n.pragmas || n.driver.Init != nil { 193 s, _, err := c.Conn.Prepare(`PRAGMA query_only`) 194 if err != nil { 195 return nil, err 196 } 197 if s.Step() && s.ColumnBool(0) { 198 c.readOnly = '1' 199 } else { 200 c.readOnly = '0' 201 } 202 err = s.Close() 203 if err != nil { 204 return nil, err 205 } 206 } 207 return c, nil 208 } 209 210 type conn struct { 211 *sqlite3.Conn 212 txBegin string 213 txCommit string 214 txRollback string 215 tmRead sqlite3.TimeFormat 216 tmWrite sqlite3.TimeFormat 217 readOnly byte 218 } 219 220 var ( 221 // Ensure these interfaces are implemented: 222 _ driver.ConnPrepareContext = &conn{} 223 _ driver.ExecerContext = &conn{} 224 _ driver.ConnBeginTx = &conn{} 225 _ sqlite3.DriverConn = &conn{} 226 ) 227 228 func (c *conn) Raw() *sqlite3.Conn { 229 return c.Conn 230 } 231 232 func (c *conn) Begin() (driver.Tx, error) { 233 return c.BeginTx(context.Background(), driver.TxOptions{}) 234 } 235 236 func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 237 txBegin := c.txBegin 238 c.txCommit = `COMMIT` 239 c.txRollback = `ROLLBACK` 240 241 if opts.ReadOnly { 242 txBegin = ` 243 BEGIN deferred; 244 PRAGMA query_only=on` 245 c.txRollback = ` 246 ROLLBACK; 247 PRAGMA query_only=` + string(c.readOnly) 248 c.txCommit = c.txRollback 249 } 250 251 switch opts.Isolation { 252 default: 253 return nil, util.IsolationErr 254 case 255 driver.IsolationLevel(sql.LevelDefault), 256 driver.IsolationLevel(sql.LevelSerializable): 257 break 258 } 259 260 old := c.Conn.SetInterrupt(ctx) 261 defer c.Conn.SetInterrupt(old) 262 263 err := c.Conn.Exec(txBegin) 264 if err != nil { 265 return nil, err 266 } 267 return c, nil 268 } 269 270 func (c *conn) Commit() error { 271 err := c.Conn.Exec(c.txCommit) 272 if err != nil && !c.Conn.GetAutocommit() { 273 c.Rollback() 274 } 275 return err 276 } 277 278 func (c *conn) Rollback() error { 279 err := c.Conn.Exec(c.txRollback) 280 if errors.Is(err, sqlite3.INTERRUPT) { 281 old := c.Conn.SetInterrupt(context.Background()) 282 defer c.Conn.SetInterrupt(old) 283 err = c.Conn.Exec(c.txRollback) 284 } 285 return err 286 } 287 288 func (c *conn) Prepare(query string) (driver.Stmt, error) { 289 return c.PrepareContext(context.Background(), query) 290 } 291 292 func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 293 old := c.Conn.SetInterrupt(ctx) 294 defer c.Conn.SetInterrupt(old) 295 296 s, tail, err := c.Conn.Prepare(query) 297 if err != nil { 298 return nil, err 299 } 300 if tail != "" { 301 s.Close() 302 return nil, util.TailErr 303 } 304 return &stmt{Stmt: s, tmRead: c.tmRead, tmWrite: c.tmWrite}, nil 305 } 306 307 func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 308 if len(args) != 0 { 309 // Slow path. 310 return nil, driver.ErrSkip 311 } 312 313 if savept, ok := ctx.(*saveptCtx); ok { 314 // Called from driver.Savepoint. 315 savept.Savepoint = c.Conn.Savepoint() 316 return resultRowsAffected(0), nil 317 } 318 319 old := c.Conn.SetInterrupt(ctx) 320 defer c.Conn.SetInterrupt(old) 321 322 err := c.Conn.Exec(query) 323 if err != nil { 324 return nil, err 325 } 326 327 return newResult(c.Conn), nil 328 } 329 330 func (c *conn) CheckNamedValue(arg *driver.NamedValue) error { 331 return nil 332 } 333 334 type stmt struct { 335 *sqlite3.Stmt 336 tmWrite sqlite3.TimeFormat 337 tmRead sqlite3.TimeFormat 338 } 339 340 var ( 341 // Ensure these interfaces are implemented: 342 _ driver.StmtExecContext = &stmt{} 343 _ driver.StmtQueryContext = &stmt{} 344 _ driver.NamedValueChecker = &stmt{} 345 ) 346 347 func (s *stmt) NumInput() int { 348 n := s.Stmt.BindCount() 349 for i := 1; i <= n; i++ { 350 if s.Stmt.BindName(i) != "" { 351 return -1 352 } 353 } 354 return n 355 } 356 357 // Deprecated: use ExecContext instead. 358 func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 359 return s.ExecContext(context.Background(), namedValues(args)) 360 } 361 362 // Deprecated: use QueryContext instead. 363 func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { 364 return s.QueryContext(context.Background(), namedValues(args)) 365 } 366 367 func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 368 err := s.setupBindings(args) 369 if err != nil { 370 return nil, err 371 } 372 373 old := s.Stmt.Conn().SetInterrupt(ctx) 374 defer s.Stmt.Conn().SetInterrupt(old) 375 376 err = s.Stmt.Exec() 377 if err != nil { 378 return nil, err 379 } 380 381 return newResult(s.Stmt.Conn()), nil 382 } 383 384 func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 385 err := s.setupBindings(args) 386 if err != nil { 387 return nil, err 388 } 389 return &rows{ctx: ctx, stmt: s}, nil 390 } 391 392 func (s *stmt) setupBindings(args []driver.NamedValue) error { 393 err := s.Stmt.ClearBindings() 394 if err != nil { 395 return err 396 } 397 398 var ids [3]int 399 for _, arg := range args { 400 ids := ids[:0] 401 if arg.Name == "" { 402 ids = append(ids, arg.Ordinal) 403 } else { 404 for _, prefix := range []string{":", "@", "$"} { 405 if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 { 406 ids = append(ids, id) 407 } 408 } 409 } 410 411 for _, id := range ids { 412 switch a := arg.Value.(type) { 413 case bool: 414 err = s.Stmt.BindBool(id, a) 415 case int: 416 err = s.Stmt.BindInt(id, a) 417 case int64: 418 err = s.Stmt.BindInt64(id, a) 419 case float64: 420 err = s.Stmt.BindFloat(id, a) 421 case string: 422 err = s.Stmt.BindText(id, a) 423 case []byte: 424 err = s.Stmt.BindBlob(id, a) 425 case sqlite3.ZeroBlob: 426 err = s.Stmt.BindZeroBlob(id, int64(a)) 427 case time.Time: 428 err = s.Stmt.BindTime(id, a, s.tmWrite) 429 case util.JSON: 430 err = s.Stmt.BindJSON(id, a.Value) 431 case util.PointerUnwrap: 432 err = s.Stmt.BindPointer(id, util.UnwrapPointer(a)) 433 case nil: 434 err = s.Stmt.BindNull(id) 435 default: 436 panic(util.AssertErr()) 437 } 438 } 439 if err != nil { 440 return err 441 } 442 } 443 return nil 444 } 445 446 func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error { 447 switch arg.Value.(type) { 448 case bool, int, int64, float64, string, []byte, 449 time.Time, sqlite3.ZeroBlob, 450 util.JSON, util.PointerUnwrap, 451 nil: 452 return nil 453 default: 454 return driver.ErrSkip 455 } 456 } 457 458 func newResult(c *sqlite3.Conn) driver.Result { 459 rows := c.Changes() 460 if rows != 0 { 461 id := c.LastInsertRowID() 462 if id != 0 { 463 return result{id, rows} 464 } 465 } 466 return resultRowsAffected(rows) 467 } 468 469 type result struct{ lastInsertId, rowsAffected int64 } 470 471 func (r result) LastInsertId() (int64, error) { 472 return r.lastInsertId, nil 473 } 474 475 func (r result) RowsAffected() (int64, error) { 476 return r.rowsAffected, nil 477 } 478 479 type resultRowsAffected int64 480 481 func (r resultRowsAffected) LastInsertId() (int64, error) { 482 return 0, nil 483 } 484 485 func (r resultRowsAffected) RowsAffected() (int64, error) { 486 return int64(r), nil 487 } 488 489 type rows struct { 490 ctx context.Context 491 *stmt 492 names []string 493 types []string 494 } 495 496 func (r *rows) Close() error { 497 r.Stmt.ClearBindings() 498 return r.Stmt.Reset() 499 } 500 501 func (r *rows) Columns() []string { 502 if r.names == nil { 503 count := r.Stmt.ColumnCount() 504 r.names = make([]string, count) 505 for i := range r.names { 506 r.names[i] = r.Stmt.ColumnName(i) 507 } 508 } 509 return r.names 510 } 511 512 func (r *rows) declType(index int) string { 513 if r.types == nil { 514 count := r.Stmt.ColumnCount() 515 r.types = make([]string, count) 516 for i := range r.types { 517 r.types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i)) 518 } 519 } 520 return r.types[index] 521 } 522 523 func (r *rows) ColumnTypeDatabaseTypeName(index int) string { 524 decltype := r.declType(index) 525 if len := len(decltype); len > 0 && decltype[len-1] == ')' { 526 if i := strings.LastIndexByte(decltype, '('); i >= 0 { 527 decltype = decltype[:i] 528 } 529 } 530 return strings.TrimSpace(decltype) 531 } 532 533 func (r *rows) Next(dest []driver.Value) error { 534 old := r.Stmt.Conn().SetInterrupt(r.ctx) 535 defer r.Stmt.Conn().SetInterrupt(old) 536 537 if !r.Stmt.Step() { 538 if err := r.Stmt.Err(); err != nil { 539 return err 540 } 541 return io.EOF 542 } 543 544 data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest)) 545 err := r.Stmt.Columns(data) 546 for i := range dest { 547 if t, ok := r.decodeTime(i, dest[i]); ok { 548 dest[i] = t 549 continue 550 } 551 if s, ok := dest[i].(string); ok { 552 t, ok := maybeTime(s) 553 if ok { 554 dest[i] = t 555 } 556 } 557 } 558 return err 559 } 560 561 func (r *rows) decodeTime(i int, v any) (_ time.Time, _ bool) { 562 if r.tmRead == sqlite3.TimeFormatDefault { 563 return 564 } 565 switch r.declType(i) { 566 case "DATE", "TIME", "DATETIME", "TIMESTAMP": 567 // maybe 568 default: 569 return 570 } 571 switch v.(type) { 572 case int64, float64, string: 573 // maybe 574 default: 575 return 576 } 577 t, err := r.tmRead.Decode(v) 578 return t, err == nil 579 }