github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/sqlitepool_test.go (about)

     1  package sqlitepool
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  
     8  	"github.com/tailscale/sqlite/sqliteh"
     9  	"github.com/tailscale/sqlite/sqlstats"
    10  )
    11  
    12  func TestPool(t *testing.T) {
    13  	ctx := context.Background()
    14  	initFn := func(db sqliteh.DB) error {
    15  		err := ExecScript(db, `
    16  			PRAGMA synchronous=OFF;
    17  			PRAGMA journal_mode=WAL;
    18  			`)
    19  		return err
    20  	}
    21  	tracer := &sqlstats.Tracer{}
    22  	tempDir := t.TempDir()
    23  	p, err := NewPool("file:"+tempDir+"/sqlitepool_test", 3, initFn, tracer)
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  
    28  	tx, err := p.BeginTx(ctx, "insert-1")
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  	if err := tx.Exec("CREATE TABLE t (c);"); err != nil {
    33  		t.Fatal(err)
    34  	}
    35  	stmt := tx.Prepare("INSERT INTO t (c) VALUES (?);")
    36  	stmt.BindInt64(1, 1)
    37  	if _, _, _, _, err := stmt.StepResult(); err != nil {
    38  		t.Fatal(err)
    39  	}
    40  	var onCommitCalled, onRollbackCalled bool
    41  	tx.OnCommit = func() { onCommitCalled = true }
    42  	tx.OnRollback = func() { onRollbackCalled = true }
    43  	if err := tx.Commit(); err != nil {
    44  		t.Fatal(err)
    45  	}
    46  	tx.Rollback() // no-op, does not call OnRollback
    47  	if !onCommitCalled {
    48  		t.Fatal("onCommit not called")
    49  	}
    50  	if onRollbackCalled {
    51  		t.Fatal("onRollback called")
    52  	}
    53  	if err := tx.Commit(); err == nil {
    54  		t.Fatalf("want error on second commit, got: %v", err)
    55  	}
    56  
    57  	tx, err = p.BeginTx(ctx, "insert-2")
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	stmt2 := tx.Prepare("INSERT INTO t (c) VALUES (?);")
    62  	if stmt != stmt2 {
    63  		t.Fatalf("second call to prepare returned a different stmt: %p vs. %p", stmt, stmt2)
    64  	}
    65  	stmt = stmt2
    66  	stmt.BindInt64(1, 2)
    67  	if _, _, _, _, err := stmt.StepResult(); err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	func() {
    71  		defer func() {
    72  			const want = `SQLITE_ERROR: near "INVALID": syntax error`
    73  			if r := recover(); r == nil {
    74  				t.Fatal("no panic from invalid prepare")
    75  			} else if r != want {
    76  				t.Fatalf("invalid sql recover: %q, want %q", r, want)
    77  			}
    78  		}()
    79  		tx.Prepare("INVALID SQL")
    80  	}()
    81  	onCommitCalled = false
    82  	onRollbackCalled = false
    83  	tx.OnCommit = func() { onCommitCalled = true }
    84  	tx.OnRollback = func() { onRollbackCalled = true }
    85  	tx.Rollback()
    86  	if onCommitCalled {
    87  		t.Fatal("onCommit called")
    88  	}
    89  	if !onRollbackCalled {
    90  		t.Fatal("onRollback not called")
    91  	}
    92  	if err := tx.Commit(); err == nil {
    93  		t.Fatalf("want error on commit after rollback, got: %v", err)
    94  	}
    95  	tx.Rollback() // no-op
    96  
    97  	rx1, err := p.BeginRx(ctx, "read-1")
    98  	if err != nil {
    99  		t.Fatal(err)
   100  	}
   101  	defer rx1.Rollback()
   102  	rx2, err := p.BeginRx(ctx, "read-2")
   103  	if err != nil {
   104  		t.Fatal(err)
   105  	}
   106  	defer rx2.Rollback()
   107  
   108  	ctxCancel, cancel := context.WithCancel(ctx)
   109  	rx3Err := make(chan error, 1)
   110  	go func() {
   111  		rx3, err := p.BeginRx(ctxCancel, "read-3")
   112  		if err != nil {
   113  			rx3Err <- err
   114  			return
   115  		}
   116  		rx3.Rollback()
   117  		rx3Err <- errors.New("BeginRx(read-3) did not fail")
   118  	}()
   119  	cancel()
   120  	if err := <-rx3Err; err != context.Canceled {
   121  		t.Fatalf("read-3, not context canceled: %v", err)
   122  	}
   123  
   124  	stmt = rx1.Prepare("SELECT count(*) FROM t")
   125  	if row, err := stmt.Step(nil); err != nil {
   126  		t.Fatal(err)
   127  	} else if !row {
   128  		t.Fatal("no row from select count")
   129  	}
   130  	if got, want := int(stmt.ColumnInt64(0)), 1; got != want {
   131  		t.Fatalf("got=%d, want %d", got, want)
   132  	}
   133  	rx1.Rollback()
   134  	rx1.Rollback() // no-op
   135  
   136  	rx1, err = p.BeginRx(ctx, "read-1") // now another rx is available
   137  	if err != nil {
   138  		t.Fatal(err)
   139  	}
   140  	rx1.Rollback()
   141  	rx2.Rollback()
   142  
   143  	tx, err = p.BeginTx(ctx, "insert-3")
   144  	if err != nil {
   145  		t.Fatal(err)
   146  	}
   147  	if err := ExecScript(tx.DB(), "PRAGMA user_version=5"); err != nil {
   148  		t.Fatal(err)
   149  	}
   150  	func() {
   151  		defer func() {
   152  			if r := recover(); r != "Tx.Rx.Rollback called, only call Rollback on the Tx object" {
   153  				t.Fatalf("expected panic from Tx.Rx.Rollback, got: %q", r)
   154  			}
   155  		}()
   156  		tx.Rx.Rollback()
   157  	}()
   158  	if err := tx.Commit(); err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	if err := tx.Commit(); err == nil {
   162  		t.Fatalf("second commit did not fail, want 'already done'")
   163  	}
   164  
   165  	if err := p.Close(); err != nil {
   166  		t.Fatal(err)
   167  	}
   168  	p.Close() // no-op
   169  
   170  	if _, err := p.BeginTx(ctx, "after-close"); err == nil {
   171  		t.Fatal("tx-after-close did not fail")
   172  	}
   173  	if _, err := p.BeginRx(ctx, "after-close"); err == nil {
   174  		t.Fatal("rx-after-close did not fail")
   175  	}
   176  }