github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlite_test.go (about) 1 // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sqlite 6 7 import ( 8 "context" 9 "database/sql" 10 "database/sql/driver" 11 "expvar" 12 "fmt" 13 "os" 14 "reflect" 15 "runtime" 16 "strings" 17 "sync" 18 "testing" 19 "time" 20 21 "github.com/tailscale/sqlite/sqliteh" 22 ) 23 24 func TestOpenDB(t *testing.T) { 25 db := openTestDB(t) 26 var journalMode string 27 if err := db.QueryRow("PRAGMA journal_mode;").Scan(&journalMode); err != nil { 28 t.Fatal(err) 29 } 30 if want := "wal"; journalMode != want { 31 t.Errorf("journal_mode=%q, want %q", journalMode, want) 32 } 33 var synchronous string 34 if err := db.QueryRow("PRAGMA synchronous;").Scan(&synchronous); err != nil { 35 t.Fatal(err) 36 } 37 if want := "0"; synchronous != want { 38 t.Errorf("synchronous=%q, want %q", synchronous, want) 39 } 40 if err := db.Close(); err != nil { 41 t.Fatal(err) 42 } 43 } 44 45 func configDB(t testing.TB, db *sql.DB) { 46 if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { 47 t.Fatal(err) 48 } 49 if _, err := db.Exec("PRAGMA synchronous=OFF"); err != nil { 50 t.Fatal(err) 51 } 52 numConns := runtime.GOMAXPROCS(0) 53 db.SetMaxOpenConns(numConns) 54 db.SetMaxIdleConns(numConns) 55 db.SetConnMaxLifetime(0) 56 db.SetConnMaxIdleTime(0) 57 t.Cleanup(func() { db.Close() }) 58 } 59 60 func getUsesAfterClose() (ret int64) { 61 UsesAfterClose.Do(func(kv expvar.KeyValue) { 62 ret += kv.Value.(*expvar.Int).Value() 63 }) 64 return ret 65 } 66 67 func checkBadUsageDB(t testing.TB, db *sql.DB) { 68 initial := getUsesAfterClose() 69 t.Cleanup(func() { 70 final := getUsesAfterClose() 71 if initial != final { 72 t.Errorf("%d uses after finalization != %d final value", initial, final) 73 } 74 }) 75 } 76 77 func openTestDB(t testing.TB) *sql.DB { 78 t.Helper() 79 db, err := sql.Open("sqlite3", "file:"+t.TempDir()+"/test.db") 80 if err != nil { 81 t.Fatal(err) 82 } 83 configDB(t, db) 84 checkBadUsageDB(t, db) 85 return db 86 } 87 88 func openTestDBTrace(t testing.TB, tracer sqliteh.Tracer) *sql.DB { 89 t.Helper() 90 db := sql.OpenDB(Connector("file:"+t.TempDir()+"/test.db", nil, tracer)) 91 configDB(t, db) 92 return db 93 } 94 95 // execContexter is an *sql.DB or an *sql.Tx. 96 type execContexter interface { 97 ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 98 } 99 100 func exec(t *testing.T, db execContexter, query string, args ...any) sql.Result { 101 t.Helper() 102 ctx := context.Background() 103 res, err := db.ExecContext(ctx, query, args...) 104 if err != nil { 105 t.Fatal(err) 106 } 107 return res 108 } 109 110 func TestTrailingTextError(t *testing.T) { 111 db := openTestDB(t) 112 _, err := db.Exec("PRAGMA journal_mode=WAL; PRAGMA synchronous=OFF;") 113 if err == nil { 114 t.Error("missing error from trailing command") 115 } 116 if !strings.Contains(err.Error(), "trailing text") { 117 t.Errorf("error does not mention 'trailing text': %v", err) 118 } 119 } 120 121 func TestInsertResults(t *testing.T) { 122 db := openTestDB(t) 123 exec(t, db, "CREATE TABLE t (c)") 124 res := exec(t, db, "INSERT INTO t VALUES ('a')") 125 if id, err := res.LastInsertId(); err != nil { 126 t.Fatal(err) 127 } else if id != 1 { 128 t.Errorf("LastInsertId=%d, want 1", id) 129 } 130 if rows, err := res.RowsAffected(); err != nil { 131 t.Fatal(err) 132 } else if rows != 1 { 133 t.Errorf("RowsAffected=%d, want 1", rows) 134 } 135 136 res = exec(t, db, "INSERT INTO t VALUES ('b')") 137 if id, err := res.LastInsertId(); err != nil { 138 t.Fatal(err) 139 } else if id != 2 { 140 t.Errorf("LastInsertId=%d, want 1", id) 141 } 142 143 exec(t, db, "INSERT INTO t VALUES ('c')") 144 exec(t, db, "CREATE TABLE t2 (c)") 145 res = exec(t, db, "INSERT INTO t2 SELECT c from t;") 146 if id, err := res.LastInsertId(); err != nil { 147 t.Fatal(err) 148 } else if id != 3 { 149 t.Errorf("LastInsertId=%d, want 1", id) 150 } 151 if rows, err := res.RowsAffected(); err != nil { 152 t.Fatal(err) 153 } else if rows != 3 { 154 t.Errorf("RowsAffected=%d, want 1", rows) 155 } 156 } 157 158 func TestExecAndScanSequence(t *testing.T) { 159 db := openTestDB(t) 160 exec(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)") 161 exec(t, db, "INSERT INTO t VALUES (?, ?)", 10, "skip") 162 exec(t, db, "INSERT INTO t VALUES (?, ?)", 100, "a") 163 exec(t, db, "INSERT INTO t VALUES (?, ?)", 200, "b") 164 exec(t, db, "INSERT INTO t VALUES (?, ?)", 300, "c") 165 exec(t, db, "INSERT INTO t VALUES (?, ?)", 400, "d") 166 exec(t, db, "INSERT INTO t VALUES (?, ?)", 401, "skip") 167 168 rows, err := db.Query("SELECT * FROM t WHERE id >= ? AND id <= :max", 100, sql.Named("max", 400)) 169 if err != nil { 170 t.Fatal(err) 171 } 172 for i := 0; i < 4; i++ { 173 if !rows.Next() { 174 t.Fatalf("pass %d: Next=false", i) 175 } 176 var id int64 177 var val string 178 if err := rows.Scan(&id, &val); err != nil { 179 t.Fatalf("pass %d: Scan: %v", i, err) 180 } 181 if want := int64(i+1) * 100; id != want { 182 t.Fatalf("pass %d: id=%d, want %d", i, id, want) 183 } 184 if want := string([]byte{'a' + byte(i)}); val != want { 185 t.Fatalf("pass %d: val=%q want %q", i, val, want) 186 } 187 } 188 if rows.Next() { 189 t.Fatal("too many rows") 190 } 191 if err := rows.Err(); err != nil { 192 t.Fatal(err) 193 } 194 if err := rows.Close(); err != nil { 195 t.Fatal(err) 196 } 197 } 198 199 func TestTx(t *testing.T) { 200 ctx := context.Background() 201 db := openTestDB(t) 202 203 tx, err := db.BeginTx(ctx, nil) 204 if err != nil { 205 t.Fatal(err) 206 } 207 exec(t, tx, "CREATE TABLE t (c);") 208 exec(t, tx, "INSERT INTO t VALUES (1);") 209 if err := tx.Commit(); err != nil { 210 t.Fatal(err) 211 } 212 if err := tx.Rollback(); err == nil { 213 t.Fatal("rollback of committed Tx did not error") 214 } 215 216 tx, err = db.BeginTx(ctx, nil) 217 if err != nil { 218 t.Fatal(err) 219 } 220 exec(t, tx, "INSERT INTO t VALUES (2);") 221 if err := tx.Rollback(); err != nil { 222 t.Fatal(err) 223 } 224 225 var count int 226 if err := db.QueryRowContext(ctx, "SELECT count(*) FROM t").Scan(&count); err != nil { 227 t.Fatal(err) 228 } 229 if count != 1 { 230 t.Errorf("count=%d, want 1", count) 231 } 232 } 233 234 func TestValueConversion(t *testing.T) { 235 db := openTestDB(t) 236 var cInt int64 237 var cFloat float64 238 var cText string 239 var cBlob []byte 240 var cNull *string 241 err := db.QueryRowContext(context.Background(), `SELECT 242 CAST(4 AS INTEGER), 243 CAST(4.0 AS FLOAT), 244 CAST('txt' AS TEXT), 245 CAST('txt' AS BLOB), 246 NULL`).Scan(&cInt, &cFloat, &cText, &cBlob, &cNull) 247 if err != nil { 248 t.Fatal(err) 249 } 250 if cInt != 4 { 251 t.Errorf("cInt=%d, want 4", cInt) 252 } 253 if cFloat != 4.0 { 254 t.Errorf("cFloat=%v, want 4.0", cFloat) 255 } 256 if cText != "txt" { 257 t.Errorf("cText=%v, want 'txt'", cText) 258 } 259 if string(cBlob) != "txt" { 260 t.Errorf("cBlob=%v, want 'txt'", cBlob) 261 } 262 if cNull != nil { 263 t.Errorf("cNull=%v, want nil", cNull) 264 } 265 } 266 267 func TestTime(t *testing.T) { 268 t1Str := "2021-06-08 11:36:52.444-0700" 269 t1, err := time.Parse(TimeFormat, t1Str) 270 if err != nil { 271 t.Fatal(err) 272 } 273 var t2 time.Time 274 275 db := openTestDB(t) 276 exec(t, db, "CREATE TABLE t (c DATETIME)") 277 exec(t, db, "INSERT INTO t VALUES (?)", t1) 278 err = db.QueryRowContext(context.Background(), "SELECT c FROM t").Scan(&t2) 279 if err != nil { 280 t.Fatal(err) 281 } 282 var txt string 283 err = db.QueryRowContext(context.Background(), "SELECT CAST(c AS TEXT) FROM t").Scan(&txt) 284 if err != nil { 285 t.Fatal(err) 286 } 287 if want := t1Str; txt != want { 288 t.Errorf("time stored as %q, want %q", txt, want) 289 } 290 291 exec(t, db, "CREATE TABLE t2 (c FOOD)") 292 exec(t, db, "INSERT INTO t2 VALUES (?)", t1) 293 if err := db.QueryRowContext(context.Background(), "SELECT c FROM t2").Scan(&t2); err == nil { 294 t.Fatal("expect an error trying to interpet FOOD as Time") 295 } 296 } 297 298 func TestShortTimes(t *testing.T) { 299 var tests = []struct { 300 in string 301 want time.Time 302 }{ 303 {in: "2021-06-08 11:36:52.128+0000", want: time.Date(2021, 6, 8, 11, 36, 52, 128*1e6, time.UTC)}, 304 {in: "2021-06-08 11:36:52", want: time.Date(2021, 6, 8, 11, 36, 52, 0, time.UTC)}, 305 {in: "2021-06-08 11:36", want: time.Date(2021, 6, 8, 11, 36, 0, 0, time.UTC)}, 306 } 307 308 for _, tt := range tests { 309 db := openTestDB(t) 310 exec(t, db, "CREATE TABLE t (c DATETIME)") 311 exec(t, db, "INSERT INTO t VALUES (?)", tt.in) 312 var got time.Time 313 if err := db.QueryRowContext(context.Background(), "SELECT c FROM t").Scan(&got); err != nil { 314 t.Fatal(err) 315 } 316 if !got.Equal(tt.want) { 317 t.Errorf("in=%v, want=%v, got=%v", tt.in, tt.want, got) 318 } 319 } 320 } 321 322 func TestEmptyString(t *testing.T) { 323 db := openTestDB(t) 324 exec(t, db, "CREATE TABLE t (c)") 325 exec(t, db, "INSERT INTO t VALUES (?)", "") 326 exec(t, db, "INSERT INTO t VALUES (?)", "") 327 var count int 328 if err := db.QueryRowContext(context.Background(), "SELECT count(*) FROM t").Scan(&count); err != nil { 329 t.Fatal(err) 330 } 331 if count != 2 { 332 t.Fatalf("count=%d, want 2", count) 333 } 334 } 335 336 func TestExecScript(t *testing.T) { 337 db := openTestDB(t) 338 conn, err := db.Conn(context.Background()) 339 if err != nil { 340 t.Fatal(err) 341 } 342 defer conn.Close() 343 err = ExecScript(conn, `BEGIN; 344 CREATE TABLE t (c); 345 INSERT INTO t VALUES ('a'); 346 INSERT INTO t VALUES ('b'); 347 COMMIT;`) 348 if err != nil { 349 t.Fatal(err) 350 } 351 var count int 352 if err := db.QueryRowContext(context.Background(), "SELECT count(*) FROM t").Scan(&count); err != nil { 353 t.Fatal(err) 354 } 355 if count != 2 { 356 t.Fatalf("count=%d, want 2", count) 357 } 358 } 359 360 func TestWithPersist(t *testing.T) { 361 db := openTestDB(t) 362 exec(t, db, "CREATE TABLE t (c)") 363 364 ctx := context.Background() 365 sqlConn, err := db.Conn(ctx) 366 if err != nil { 367 t.Fatal(err) 368 } 369 defer sqlConn.Close() 370 371 ins := "INSERT INTO t VALUES (?)" 372 if _, err := sqlConn.ExecContext(ctx, ins, 1); err != nil { 373 t.Fatal(err) 374 } 375 376 err = sqlConn.Raw(func(driverConn any) error { 377 c := driverConn.(*conn) 378 if c.stmts[ins] != nil { 379 return fmt.Errorf("query %q was persisted", ins) 380 } 381 return nil 382 }) 383 if err != nil { 384 t.Fatal(err) 385 } 386 387 if _, err := sqlConn.ExecContext(WithPersist(ctx), ins, 2); err != nil { 388 t.Fatal(err) 389 } 390 err = sqlConn.Raw(func(driverConn any) error { 391 c := driverConn.(*conn) 392 if c.stmts[ins] == nil { 393 return fmt.Errorf("query %q was not persisted", ins) 394 } 395 return nil 396 }) 397 if err != nil { 398 t.Fatal(err) 399 } 400 } 401 402 func TestWithQueryCancel(t *testing.T) { 403 // This test query runs forever until interrupted. 404 const testQuery = `WITH RECURSIVE inf(n) AS ( 405 SELECT 1 406 UNION ALL 407 SELECT n+1 FROM inf 408 ) SELECT * FROM inf WHERE n = 0` 409 410 db := openTestDB(t) 411 412 done := make(chan struct{}) 413 go func() { 414 defer close(done) 415 416 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 417 defer cancel() 418 419 rows, err := db.QueryContext(WithQueryCancel(ctx), testQuery) 420 if err != nil { 421 t.Fatalf("QueryContext: unexpected error: %v", err) 422 } 423 for rows.Next() { 424 t.Error("Next result available before timeout") 425 } 426 if err := rows.Err(); err == nil { 427 t.Error("Rows did not report an error") 428 } else if !strings.Contains(err.Error(), "SQLITE_INTERRUPT") { 429 t.Errorf("Rows err=%v, want SQLITE_INTERRUPT", err) 430 } 431 }() 432 433 select { 434 case <-done: 435 // OK 436 case <-time.After(30 * time.Second): 437 t.Fatal("Timeout waiting for query to end") 438 } 439 } 440 441 func TestErrors(t *testing.T) { 442 db := openTestDB(t) 443 exec(t, db, "CREATE TABLE t (c)") 444 exec(t, db, "INSERT INTO t (c) VALUES (1)") 445 exec(t, db, "INSERT INTO t (c) VALUES (2)") 446 447 ctx := context.Background() 448 rows, err := db.QueryContext(ctx, "SELECT c FROM t;") 449 if err != nil { 450 t.Fatal(err) 451 } 452 exec(t, db, "DROP TABLE t") 453 if rows.Next() { 454 t.Error("rows") 455 } 456 err = rows.Err() 457 if err == nil { 458 t.Fatal("no error") 459 } 460 // Test use of ErrMsg to elaborate on the error. 461 if want := "no such table: t"; !strings.Contains(err.Error(), want) { 462 t.Errorf("err=%v, want %q", err, want) 463 } 464 465 conn, err := db.Conn(ctx) 466 if err != nil { 467 t.Fatal(err) 468 } 469 defer conn.Close() 470 err = ExecScript(conn, `BEGIN; NOT VALID SQL;`) 471 if err == nil { 472 t.Fatal("no error") 473 } 474 if want := `near "NOT": syntax error`; !strings.Contains(err.Error(), want) { 475 t.Errorf("err=%v, want %q", err, want) 476 } 477 if err := ExecScript(conn, "ROLLBACK;"); err != nil { // TODO: make unnecessary? 478 t.Fatal(err) 479 } 480 481 err = ExecScript(conn, `CREATE TABLE t (c INTEGER PRIMARY KEY); 482 INSERT INTO t (c) VALUES (1); 483 INSERT INTO t (c) VALUES (1);`) 484 if err == nil { 485 t.Fatal("no error") 486 } 487 if want := `UNIQUE constraint failed: t.c`; !strings.Contains(err.Error(), want) { 488 t.Errorf("err=%v, want %q", err, want) 489 } 490 491 _, err = conn.ExecContext(ctx, "INSERT INTO t (c) VALUES (1);") 492 if err == nil { 493 t.Fatal("no error") 494 } 495 if want := `Stmt.Exec: SQLITE_CONSTRAINT: UNIQUE constraint failed: t.c`; !strings.Contains(err.Error(), want) { 496 t.Errorf("err=%v, want %q", err, want) 497 } 498 } 499 500 func TestCheckpoint(t *testing.T) { 501 dbFile := t.TempDir() + "/test.db" 502 db, err := sql.Open("sqlite3", "file:"+dbFile) 503 if err != nil { 504 t.Fatal(err) 505 } 506 defer db.Close() 507 if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { 508 t.Fatal(err) 509 } 510 511 ctx := context.Background() 512 conn, err := db.Conn(ctx) 513 if err != nil { 514 t.Fatal(err) 515 } 516 defer conn.Close() 517 err = ExecScript(conn, `CREATE TABLE t (c); 518 INSERT INTO t (c) VALUES (1); 519 INSERT INTO t (c) VALUES (1);`) 520 if err != nil { 521 t.Fatal(err) 522 } 523 524 if fi, err := os.Stat(dbFile + "-wal"); err != nil { 525 t.Fatal(err) 526 } else if fi.Size() == 0 { 527 t.Fatal("WAL is empty") 528 } else { 529 t.Logf("WAL is %d bytes", fi.Size()) 530 } 531 532 if _, _, err := Checkpoint(conn, "", sqliteh.SQLITE_CHECKPOINT_TRUNCATE); err != nil { 533 t.Fatal(err) 534 } 535 536 if fi, err := os.Stat(dbFile + "-wal"); err != nil { 537 t.Fatal(err) 538 } else if fi.Size() != 0 { 539 t.Fatal("WAL is not empty") 540 } 541 } 542 543 type queryTraceEvent struct { 544 prepCtx context.Context 545 query string 546 duration time.Duration 547 err error 548 } 549 550 type queryTracer struct { 551 evCh chan queryTraceEvent 552 } 553 554 func (t *queryTracer) Query(prepCtx context.Context, id sqliteh.TraceConnID, query string, duration time.Duration, err error) { 555 t.evCh <- queryTraceEvent{prepCtx, query, duration, err} 556 } 557 func (t *queryTracer) BeginTx(_ context.Context, _ sqliteh.TraceConnID, _ string, _ bool, _ error) {} 558 func (t *queryTracer) Commit(_ sqliteh.TraceConnID, _ error) { 559 } 560 func (t *queryTracer) Rollback(_ sqliteh.TraceConnID, _ error) { 561 } 562 563 func TestTraceQuery(t *testing.T) { 564 tracer := &queryTracer{ 565 evCh: make(chan queryTraceEvent, 16), 566 } 567 type ctxKey struct{} 568 expectEv := func(srcCtx context.Context, query string, errSubstr string) { 569 t.Helper() 570 ev, ok := <-tracer.evCh 571 if !ok { 572 t.Fatal("trace: no event") 573 } 574 if ev.prepCtx == nil { 575 t.Errorf("trace: prepCtx==nil") 576 } else if want, got := srcCtx.Value(ctxKey{}), ev.prepCtx.Value(ctxKey{}); want != got { 577 t.Errorf("trace: prepCtx value=%v, want %v", got, want) 578 } 579 if ev.query != query { 580 t.Errorf("trace: query=%q, want %q", ev.query, query) 581 } 582 switch { 583 case ev.err == nil && errSubstr != "": 584 t.Errorf("trace: err=nil, want %q", errSubstr) 585 case ev.err != nil && errSubstr == "": 586 t.Errorf("trace: err=%v, want nil", ev.err) 587 case ev.err != nil && !strings.Contains(ev.err.Error(), errSubstr): 588 t.Errorf("trace: err=%v, want %v", ev.err, errSubstr) 589 } 590 if ev.duration <= 0 || ev.duration > 10*time.Minute { 591 // The macOS clock appears to low resolution and so 592 // it's common to get a duration of exactly 0s. 593 if runtime.GOOS != "darwin" || ev.duration != 0 { 594 t.Errorf("trace: improbable duration: %v", ev.duration) 595 } 596 } 597 } 598 db := openTestDBTrace(t, tracer) 599 noErr := "" 600 expectEv(context.Background(), "PRAGMA journal_mode=WAL", noErr) // from configDB 601 expectEv(context.Background(), "PRAGMA synchronous=OFF", noErr) 602 603 execCtx := func(ctx context.Context, query string, args ...any) { 604 t.Helper() 605 if _, err := db.ExecContext(ctx, query, args...); err != nil { 606 t.Fatal(err) 607 } 608 expectEv(ctx, query, noErr) 609 } 610 ctx := WithPersist(context.Background()) 611 ctx = context.WithValue(ctx, ctxKey{}, 7) 612 execCtx(ctx, "CREATE TABLE t (c)") 613 614 ins := "INSERT INTO t VALUES (?)" 615 execCtx(ctx, ins, 1) 616 execCtx(WithPersist(ctx), ins, 2) 617 618 execCtx(ctx, "SELECT null LIMIT 0;") 619 620 rows, err := db.QueryContext(ctx, "SELECT * FROM t") 621 if err != nil { 622 t.Fatal(err) 623 } 624 for rows.Next() { 625 var val int64 626 if err := rows.Scan(&val); err != nil { 627 t.Fatal(err) 628 } 629 } 630 if err := rows.Err(); err != nil { 631 t.Fatal(err) 632 } 633 if err := rows.Close(); err != nil { 634 t.Fatal(err) 635 } 636 expectEv(ctx, "SELECT * FROM t", noErr) 637 638 _, err = db.ExecContext(ctx, "DELETOR") 639 if err == nil { 640 t.Fatal(err) 641 } 642 expectEv(ctx, "DELETOR", err.Error()) 643 644 execCtx(context.WithValue(ctx, ctxKey{}, 9), "CREATE TABLE t2 (c INTEGER PRIMARY KEY)") 645 execCtx(ctx, "INSERT INTO t2 (c) VALUES (1)") 646 _, err = db.ExecContext(ctx, "INSERT INTO t2 (c) VALUES (1)") 647 if err == nil { 648 t.Fatal(err) 649 } 650 expectEv(ctx, "INSERT INTO t2 (c) VALUES (1)", "UNIQUE constraint failed") 651 } 652 653 func TestTxnState(t *testing.T) { 654 dbFile := t.TempDir() + "/test.db" 655 db, err := sql.Open("sqlite3", "file:"+dbFile) 656 if err != nil { 657 t.Fatal(err) 658 } 659 defer db.Close() 660 if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { 661 t.Fatal(err) 662 } 663 664 ctx := context.Background() 665 sqlConn, err := db.Conn(ctx) 666 if err != nil { 667 t.Fatal(err) 668 } 669 defer sqlConn.Close() 670 if state, err := TxnState(sqlConn, ""); err != nil { 671 t.Fatal(err) 672 } else if state != sqliteh.SQLITE_TXN_NONE { 673 t.Errorf("state=%v, want SQLITE_TXN_NONE", state) 674 } 675 if err := ExecScript(sqlConn, "BEGIN; CREATE TABLE t (c);"); err != nil { 676 t.Fatal(err) 677 } 678 if state, err := TxnState(sqlConn, ""); err != nil { 679 t.Fatal(err) 680 } else if state != sqliteh.SQLITE_TXN_WRITE { 681 t.Errorf("state=%v, want SQLITE_TXN_WRITE", state) 682 } 683 if err := ExecScript(sqlConn, "COMMIT; BEGIN; SELECT * FROM (t);"); err != nil { 684 t.Fatal(err) 685 } 686 if state, err := TxnState(sqlConn, ""); err != nil { 687 t.Fatal(err) 688 } else if state != sqliteh.SQLITE_TXN_READ { 689 t.Errorf("state=%v, want SQLITE_TXN_READ", state) 690 } 691 } 692 693 func TestConnInit(t *testing.T) { 694 called := 0 695 uri := "file:" + t.TempDir() + "/test.db" 696 connInitFunc := func(ctx context.Context, conn driver.ConnPrepareContext) error { 697 called++ 698 return ExecScript(conn.(SQLConn), "PRAGMA journal_mode=WAL;") 699 } 700 db := sql.OpenDB(Connector(uri, connInitFunc, nil)) 701 conn, err := db.Conn(context.Background()) 702 if err != nil { 703 t.Fatal(err) 704 } 705 if called == 0 { 706 t.Fatal("called=0, want non-zero") 707 } 708 conn.Close() 709 db.Close() 710 } 711 712 func TestTxReadOnly(t *testing.T) { 713 ctx := context.Background() 714 db := openTestDB(t) 715 716 tx, err := db.BeginTx(ctx, nil) 717 if err != nil { 718 t.Fatal(err) 719 } 720 if _, err := tx.Exec("create table t (c)"); err != nil { 721 t.Fatal(err) 722 } 723 if _, err := tx.Exec("insert into t (c) values (1)"); err != nil { 724 t.Fatal(err) 725 } 726 if err := tx.Commit(); err != nil { 727 t.Fatal(err) 728 } 729 730 tx, err = db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) 731 if err != nil { 732 t.Fatal(err) 733 } 734 var count int 735 if err := tx.QueryRow("select count(*) from t").Scan(&count); err != nil { 736 t.Fatal(err) 737 } 738 if count != 1 { 739 t.Errorf("count=%d, want 1", count) 740 } 741 if _, err := tx.Exec("insert into t (c) values (1)"); err == nil { 742 t.Fatal("no error on read-only insert") 743 } else if !strings.Contains(err.Error(), "SQLITE_READONLY") { 744 t.Errorf("err does not reference SQLITE_READONLY: %v", err) 745 } 746 if err := tx.Rollback(); err != nil { 747 t.Fatal(err) 748 } 749 750 tx, err = db.BeginTx(ctx, nil) 751 if err != nil { 752 t.Fatal(err) 753 } 754 if _, err := tx.Exec("insert into t (c) values (1)"); err != nil { 755 t.Fatal(err) 756 } 757 if err := tx.Commit(); err != nil { 758 t.Fatal(err) 759 } 760 } 761 762 // TestAttachOrderingDeadlock fails if transactions use SQLite's default 763 // BEGIN DEFERRED semantics, as the two databases locks are acquired in 764 // the wrong order. This tests that BEGIN IMMEDIATE resolves this. 765 func TestAttachOrderingDeadlock(t *testing.T) { 766 ctx := context.Background() 767 tmpdir := t.TempDir() 768 db := sql.OpenDB(Connector("file:"+tmpdir+"/test.db", func(ctx context.Context, conn driver.ConnPrepareContext) error { 769 c := conn.(SQLConn) 770 err := ExecScript(c, ` 771 ATTACH DATABASE "file:`+tmpdir+`/test2.db" AS attached; 772 PRAGMA busy_timeout=10000; 773 PRAGMA main.journal_mode=WAL; 774 PRAGMA attached.journal_mode=WAL; 775 CREATE TABLE IF NOT EXISTS main.m1 (c); 776 CREATE TABLE IF NOT EXISTS main.m2 (c); 777 CREATE TABLE IF NOT EXISTS attached.a1 (c); 778 CREATE TABLE IF NOT EXISTS attached.a2 (c); 779 780 `) 781 if err != nil { 782 return err 783 } 784 return nil 785 }, nil)) 786 defer db.Close() 787 788 // Prime the connections. 789 const numConcurrent = 10 790 db.SetMaxOpenConns(numConcurrent) 791 db.SetMaxIdleConns(numConcurrent) 792 db.SetConnMaxLifetime(0) 793 db.SetConnMaxIdleTime(0) 794 var conns []*sql.Conn 795 for i := 0; i < numConcurrent; i++ { 796 c, err := db.Conn(ctx) 797 if err != nil { 798 t.Fatal(err) 799 } 800 conns = append(conns, c) 801 } 802 for _, c := range conns { 803 c.Close() 804 } 805 806 lockTables := func(name string, tables ...string) { 807 tx, err := db.BeginTx(ctx, nil) 808 if err != nil { 809 t.Error(err) 810 return 811 } 812 defer tx.Rollback() 813 814 for _, table := range tables { 815 // Read from and write to the same table to lock it. 816 _, err := tx.ExecContext(WithPersist(ctx), `INSERT INTO `+table+` SELECT * FROM `+table+` LIMIT 0`) 817 if err != nil { 818 t.Error(err) 819 return 820 } 821 } 822 } 823 824 var wg sync.WaitGroup 825 defer wg.Wait() 826 827 // The following goroutines write to the main and noise databases in 828 // a different order. This should not result in a deadlock. 829 for i := 0; i < numConcurrent; i++ { 830 wg.Add(2) 831 go func() { 832 defer wg.Done() 833 lockTables("main-then-attached", "main.m1", "attached.a1") 834 }() 835 go func() { 836 defer wg.Done() 837 lockTables("attached-then-main", "attached.a2", "main.m2") 838 }() 839 } 840 } 841 842 func TestSetWALHook(t *testing.T) { 843 ctx := context.Background() 844 db := openTestDB(t) 845 846 var conns []*sql.Conn 847 for i := 1; i <= 2; i++ { 848 conn, err := db.Conn(ctx) 849 if err != nil { 850 t.Fatal(err) 851 } 852 defer conn.Close() 853 conns = append(conns, conn) 854 } 855 856 got := []string{} 857 for i := 1; i <= 2; i++ { 858 hookGen := i 859 for connNum, conn := range conns { 860 connNum := connNum 861 err := SetWALHook(conn, func(dbName string, pages int) { 862 s := fmt.Sprintf("conn=%d, db=%s, pages=%v", connNum, dbName, pages) 863 if hookGen == 2 { // verify our hook replacement worked 864 got = append(got, s) 865 } 866 }) 867 if err != nil { 868 t.Fatal(err) 869 } 870 } 871 } 872 873 if _, err := conns[0].ExecContext(ctx, "CREATE TABLE foo (k INT, v INT)"); err != nil { 874 t.Fatal(err) 875 } 876 if _, err := conns[1].ExecContext(ctx, "INSERT INTO foo VALUES (1, 2)"); err != nil { 877 t.Fatal(err) 878 } 879 880 // Disable the hook. 881 for _, conn := range conns { 882 if err := SetWALHook(conn, nil); err != nil { 883 t.Fatal(err) 884 } 885 } 886 // And do another write that we shouldn't get a callback for. 887 if _, err := conns[1].ExecContext(ctx, "INSERT INTO foo VALUES (2, 3)"); err != nil { 888 t.Fatal(err) 889 } 890 891 want := []string{ 892 "conn=0, db=main, pages=2", 893 "conn=1, db=main, pages=3", 894 } 895 if !reflect.DeepEqual(got, want) { 896 t.Errorf("wrong\n got: %q\nwant: %q", got, want) 897 } 898 899 // Check allocs 900 if err := SetWALHook(conns[0], func(dbName string, pages int) {}); err != nil { 901 t.Fatal(err) 902 } 903 904 stmt, err := conns[0].PrepareContext(WithPersist(ctx), "UPDATE foo SET v = v + 1 WHERE k in (SELECT k FROM foo LIMIT 1)") 905 if err != nil { 906 t.Fatal(err) 907 } 908 n := testing.AllocsPerRun(10000, func() { 909 if _, err := stmt.Exec(); err != nil { 910 t.Fatal(err) 911 } 912 }) 913 const maxAllocs = 3 // as of Go 1.20 914 if n > maxAllocs { 915 t.Errorf("allocs = %v; want no more than %v", n, maxAllocs) 916 } 917 } 918 919 // Tests that we don't remember the SQLite column types of the first row of the 920 // result set (notably the "NULL" type) and re-use it for all subsequent rows 921 // like we used to. 922 func TestNoStickyColumnTypes(t *testing.T) { 923 db := openTestDB(t) 924 exec(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, v1 ANY, v2 ANY)") 925 926 type row []any 927 r := func(v ...any) row { return v } 928 rs := func(v ...row) []row { return v } 929 tests := []struct { 930 name string 931 rows []row 932 }{ 933 {"no-null", rs( 934 r("a", "b"), 935 r("foo", "bar"))}, 936 {"only-null", rs( 937 r(nil, nil))}, 938 {"null-after-string", rs( 939 r("a", "b"), 940 r(nil, "bar"))}, 941 {"string-after-null", rs( 942 r(nil, "b"), 943 r("foo", "bar"))}, 944 {"null-after-int", rs( 945 r(101, 102), 946 r(nil, 202))}, 947 {"int-after-null", rs( 948 r(nil, 102), 949 r(201, 202))}, 950 {"changing-types-within-a-column-between-rows", rs( 951 r("foo", nil), 952 r(nil, 2), 953 r(3, "bar"))}, 954 } 955 956 // canonical maps from types we get back out from sqlite 957 // to the types we provided in the test cases above. 958 canonical := func(v any) any { 959 switch v := v.(type) { 960 default: 961 return v 962 case []byte: 963 return string(v) 964 case int64: 965 return int(v) 966 } 967 } 968 969 for _, tt := range tests { 970 t.Run(tt.name, func(t *testing.T) { 971 exec(t, db, "DELETE FROM t") 972 for primaryKey, r := range tt.rows { 973 exec(t, db, "INSERT INTO t VALUES (?, ?, ?)", append([]any{primaryKey}, r...)...) 974 } 975 rows, err := db.Query("SELECT id, v1, v2 FROM t ORDER BY id") 976 if err != nil { 977 t.Fatal(err) 978 } 979 for rows.Next() { 980 var id int 981 var v1, v2 any 982 err := rows.Scan(&id, &v1, &v2) 983 if err != nil { 984 t.Fatal(err) 985 } 986 v1, v2 = canonical(v1), canonical(v2) 987 want := tt.rows[id] 988 got := row{v1, v2} 989 t.Logf("[%v]: %T, %T", id, v1, v2) 990 if !reflect.DeepEqual(got, want) { 991 t.Errorf("row %d got %v; want %v", id, got, want) 992 } 993 } 994 if err := rows.Err(); err != nil { 995 t.Fatal(err) 996 } 997 }) 998 } 999 } 1000 1001 func TestUsesAfterClose(t *testing.T) { 1002 ctx := context.Background() 1003 1004 // Clean up metric after we're done testing it. 1005 t.Cleanup(func() { 1006 UsesAfterClose.Init() 1007 }) 1008 1009 connector := Connector("file:"+t.TempDir()+"/test.db", nil, nil) 1010 sqlConn, err := connector.Connect(ctx) 1011 if err != nil { 1012 t.Fatal(err) 1013 } 1014 conn := sqlConn.(*conn) 1015 1016 initial := getUsesAfterClose() 1017 1018 // Close the conn, then use something from the conn which triggers our 1019 // "used after close" logic. 1020 conn.Close() 1021 if _, err = conn.PrepareContext(ctx, "SELECT 1;"); err == nil { 1022 t.Error("expected error, got nil") 1023 } 1024 1025 final := getUsesAfterClose() 1026 if final != initial+1 { 1027 t.Errorf("got UsesAfterClose=%d, want %d", final, initial+1) 1028 } 1029 } 1030 1031 func BenchmarkWALHookAndExec(b *testing.B) { 1032 ctx := context.Background() 1033 db := openTestDB(b) 1034 conn, err := db.Conn(ctx) 1035 if err != nil { 1036 b.Fatal(err) 1037 } 1038 defer conn.Close() 1039 if err := SetWALHook(conn, func(dbName string, pages int) {}); err != nil { 1040 b.Fatal(err) 1041 } 1042 if _, err := conn.ExecContext(ctx, "CREATE TABLE foo (k INT, v INT)"); err != nil { 1043 b.Fatal(err) 1044 } 1045 1046 b.ReportAllocs() 1047 b.ResetTimer() 1048 1049 stmt, err := conn.PrepareContext(WithPersist(ctx), "UPDATE foo SET v=123") // will match no rows 1050 if err != nil { 1051 b.Fatal(err) 1052 } 1053 for i := 0; i < b.N; i++ { 1054 if _, err := stmt.Exec(); err != nil { 1055 b.Fatal(err) 1056 } 1057 } 1058 } 1059 1060 func BenchmarkPersist(b *testing.B) { 1061 b.ReportAllocs() 1062 ctx := context.Background() 1063 db := openTestDB(b) 1064 conn, err := db.Conn(ctx) 1065 if err != nil { 1066 b.Fatal(err) 1067 } 1068 err = ExecScript(conn, `BEGIN; 1069 CREATE TABLE t (c); 1070 INSERT INTO t VALUES ('a'); 1071 INSERT INTO t VALUES ('b'); 1072 COMMIT;`) 1073 if err != nil { 1074 b.Fatal(err) 1075 } 1076 1077 for i := 0; i < b.N; i++ { 1078 var str string 1079 if err := db.QueryRowContext(WithPersist(ctx), "SELECT c FROM t LIMIT 1").Scan(&str); err != nil { 1080 b.Fatal(err) 1081 } 1082 } 1083 } 1084 1085 func BenchmarkQueryRows100MixedTypes(b *testing.B) { 1086 b.ReportAllocs() 1087 ctx := context.Background() 1088 db := openTestDB(b) 1089 conn, err := db.Conn(ctx) 1090 if err != nil { 1091 b.Fatal(err) 1092 } 1093 err = ExecScript(conn, `BEGIN; 1094 CREATE TABLE t (id INTEGER); 1095 COMMIT;`) 1096 if err != nil { 1097 b.Fatal(err) 1098 } 1099 for i := 0; i < 100; i++ { 1100 if _, err := db.Exec("INSERT INTO t (id) VALUES (?)", i); err != nil { 1101 b.Fatal(err) 1102 } 1103 } 1104 b.ResetTimer() 1105 1106 ctx = WithPersist(ctx) 1107 1108 var id int 1109 var raw sql.RawBytes 1110 for i := 0; i < b.N; i++ { 1111 rows, err := db.QueryContext(ctx, "SELECT id, 'hello world some string' FROM t") 1112 if err != nil { 1113 b.Fatal(err) 1114 } 1115 for rows.Next() { 1116 if err := rows.Scan(&id, &raw); err != nil { 1117 b.Fatal(err) 1118 } 1119 } 1120 if err := rows.Err(); err != nil { 1121 b.Fatal(err) 1122 } 1123 } 1124 } 1125 1126 func BenchmarkEmptyExec(b *testing.B) { 1127 b.ReportAllocs() 1128 ctx := context.Background() 1129 db := openTestDB(b) 1130 ctx = WithPersist(ctx) 1131 for i := 0; i < b.N; i++ { 1132 if _, err := db.ExecContext(ctx, "SELECT null LIMIT 0;"); err != nil { 1133 b.Fatal(err) 1134 } 1135 } 1136 } 1137 1138 func BenchmarkBeginTxNoop(b *testing.B) { 1139 b.ReportAllocs() 1140 ctx := context.Background() 1141 db := openTestDB(b) 1142 for i := 0; i < b.N; i++ { 1143 tx, err := db.BeginTx(ctx, nil) 1144 if err != nil { 1145 b.Fatal(err) 1146 } 1147 if err := tx.Rollback(); err != nil { 1148 b.Fatal(err) 1149 } 1150 1151 tx, err = db.BeginTx(ctx, nil) 1152 if err != nil { 1153 b.Fatal(err) 1154 } 1155 if err := tx.Commit(); err != nil { 1156 b.Fatal(err) 1157 } 1158 } 1159 } 1160 1161 // TODO(crawshaw): test TextMarshaler 1162 // TODO(crawshaw): test named types 1163 // TODO(crawshaw): check coverage 1164 1165 // This tests that we don't give the same *stmt to two different callers that 1166 // prepare the same persistent query. See: 1167 // 1168 // https://github.com/tailscale/sqlite/issues/73 1169 func TestPrepareReuse(t *testing.T) { 1170 db := openTestDB(t) 1171 ctx := context.Background() 1172 sqlConn, err := db.Conn(ctx) 1173 if err != nil { 1174 t.Fatal(err) 1175 } 1176 defer sqlConn.Close() 1177 1178 // Insert a bunch of values into a table that we'll query to get 1179 // multiple rows back. 1180 err = ExecScript(sqlConn, 1181 `BEGIN; 1182 CREATE TABLE t (c); 1183 INSERT INTO t VALUES (1), (2), (3), (4); 1184 COMMIT;`) 1185 if err != nil { 1186 t.Fatal(err) 1187 } 1188 1189 ctx = WithPersist(ctx) 1190 1191 // Calling PrepareContext twice in a row used to return the same 1192 // statement to both callers. 1193 const query = "SELECT c FROM t ORDER BY c;" 1194 stmt1, err := sqlConn.PrepareContext(ctx, query) 1195 if err != nil { 1196 t.Fatal(err) 1197 } 1198 defer stmt1.Close() 1199 stmt2, err := sqlConn.PrepareContext(ctx, query) 1200 if err != nil { 1201 t.Fatal(err) 1202 } 1203 defer stmt2.Close() 1204 1205 rows1, err := stmt1.QueryContext(ctx) 1206 if err != nil { 1207 t.Fatal(err) 1208 } 1209 defer rows1.Close() 1210 rows2, err := stmt2.QueryContext(ctx) 1211 if err != nil { 1212 t.Fatal(err) 1213 } 1214 defer rows2.Close() 1215 1216 assertResult := func(rows *sql.Rows, want int) { 1217 t.Helper() 1218 var num int 1219 if err := rows.Scan(&num); err != nil { 1220 t.Fatalf("Scan: %v", err) 1221 } 1222 if num != want { 1223 t.Fatalf("num=%d, want %d", num, want) 1224 } 1225 } 1226 1227 // Each set of rows should get a full copy of the query results; if 1228 // these are incorrectly shared, then advancing one Rows will change 1229 // the results from the other. 1230 for i := 0; i < 4; i++ { 1231 if !rows1.Next() { 1232 t.Fatalf("[1] pass %d: Next=false", i) 1233 } 1234 if !rows2.Next() { 1235 t.Fatalf("[2] pass %d: Next=false", i) 1236 } 1237 1238 // rows2 should be different from row1 and should return a 1239 // different set of values. 1240 assertResult(rows1, i+1) 1241 assertResult(rows2, i+1) 1242 } 1243 } 1244 1245 func TestRegression(t *testing.T) { 1246 // A regression test for a query comprising only comments, which caused the 1247 // driver to panic in statement cleanup. 1248 t.Run("CommentOnlyQuery", func(t *testing.T) { 1249 db := openTestDB(t) 1250 1251 rows, err := db.Query("-- comments only\n-- nothing else") 1252 if err != nil { 1253 t.Fatalf("Query failed: %v", err) 1254 } 1255 rows.Next() 1256 t.Log("OK") // Reaching here at all means we didn't panic. 1257 }) 1258 }