github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/tests/txn_test.go (about)

     1  package tests
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  
     8  	"github.com/ncruces/go-sqlite3"
     9  	_ "github.com/ncruces/go-sqlite3/embed"
    10  	_ "github.com/ncruces/go-sqlite3/tests/testcfg"
    11  	_ "github.com/ncruces/go-sqlite3/vfs/memdb"
    12  )
    13  
    14  func TestConn_Transaction_exec(t *testing.T) {
    15  	t.Parallel()
    16  
    17  	db, err := sqlite3.Open(":memory:")
    18  	if err != nil {
    19  		t.Fatal(err)
    20  	}
    21  	defer db.Close()
    22  
    23  	db.RollbackHook(func() {})
    24  	db.CommitHook(func() bool { return true })
    25  	db.UpdateHook(func(sqlite3.AuthorizerActionCode, string, string, int64) {})
    26  
    27  	err = db.Exec(`CREATE TABLE test (col)`)
    28  	if err != nil {
    29  		t.Fatal(err)
    30  	}
    31  
    32  	errFailed := errors.New("failed")
    33  
    34  	count := func() int {
    35  		stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
    36  		if err != nil {
    37  			t.Fatal(err)
    38  		}
    39  		defer stmt.Close()
    40  		if stmt.Step() {
    41  			return stmt.ColumnInt(0)
    42  		}
    43  		t.Fatal(stmt.Err())
    44  		return 0
    45  	}
    46  
    47  	insert := func(succeed bool) (err error) {
    48  		tx := db.Begin()
    49  		defer tx.End(&err)
    50  
    51  		err = db.Exec(`INSERT INTO test VALUES ('hello')`)
    52  		if err != nil {
    53  			t.Fatal(err)
    54  		}
    55  
    56  		if s := db.TxnState("main"); s != sqlite3.TXN_WRITE {
    57  			t.Errorf("got %d", s)
    58  		}
    59  
    60  		if succeed {
    61  			return nil
    62  		}
    63  		return errFailed
    64  	}
    65  
    66  	err = insert(true)
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	if got := count(); got != 1 {
    71  		t.Errorf("got %d, want 1", got)
    72  	}
    73  
    74  	err = insert(true)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	if got := count(); got != 2 {
    79  		t.Errorf("got %d, want 2", got)
    80  	}
    81  
    82  	err = insert(false)
    83  	if err != errFailed {
    84  		t.Errorf("got %v, want errFailed", err)
    85  	}
    86  	if got := count(); got != 2 {
    87  		t.Errorf("got %d, want 2", got)
    88  	}
    89  }
    90  
    91  func TestConn_Transaction_panic(t *testing.T) {
    92  	t.Parallel()
    93  
    94  	db, err := sqlite3.Open(":memory:")
    95  	if err != nil {
    96  		t.Fatal(err)
    97  	}
    98  	defer db.Close()
    99  
   100  	err = db.Exec(`CREATE TABLE test (col)`)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  
   105  	err = db.Exec(`INSERT INTO test VALUES ('one');`)
   106  	if err != nil {
   107  		t.Fatal(err)
   108  	}
   109  
   110  	panics := func() (err error) {
   111  		tx := db.Begin()
   112  		defer tx.End(&err)
   113  
   114  		err = db.Exec(`INSERT INTO test VALUES ('hello')`)
   115  		if err != nil {
   116  			return err
   117  		}
   118  
   119  		panic("omg!")
   120  	}
   121  
   122  	defer func() {
   123  		p := recover()
   124  		if p != "omg!" {
   125  			t.Errorf("got %v, want panic", p)
   126  		}
   127  
   128  		stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   129  		if err != nil {
   130  			t.Fatal(err)
   131  		}
   132  		defer stmt.Close()
   133  		if stmt.Step() {
   134  			got := stmt.ColumnInt(0)
   135  			if got != 1 {
   136  				t.Errorf("got %d, want 1", got)
   137  			}
   138  			return
   139  		}
   140  		t.Fatal(stmt.Err())
   141  	}()
   142  
   143  	err = panics()
   144  	if err != nil {
   145  		t.Error(err)
   146  	}
   147  }
   148  
   149  func TestConn_Transaction_interrupt(t *testing.T) {
   150  	t.Parallel()
   151  
   152  	db, err := sqlite3.Open(":memory:")
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	defer db.Close()
   157  
   158  	err = db.Exec(`CREATE TABLE test (col)`)
   159  	if err != nil {
   160  		t.Fatal(err)
   161  	}
   162  
   163  	tx, err := db.BeginImmediate()
   164  	if err != nil {
   165  		t.Fatal(err)
   166  	}
   167  	err = db.Exec(`INSERT INTO test VALUES (1)`)
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  	tx.End(&err)
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  
   176  	ctx, cancel := context.WithCancel(context.Background())
   177  	db.SetInterrupt(ctx)
   178  
   179  	tx, err = db.BeginExclusive()
   180  	if err != nil {
   181  		t.Fatal(err)
   182  	}
   183  	err = db.Exec(`INSERT INTO test VALUES (2)`)
   184  	if err != nil {
   185  		t.Fatal(err)
   186  	}
   187  
   188  	cancel()
   189  	_, err = db.BeginImmediate()
   190  	if !errors.Is(err, sqlite3.INTERRUPT) {
   191  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   192  	}
   193  
   194  	err = db.Exec(`INSERT INTO test VALUES (3)`)
   195  	if !errors.Is(err, sqlite3.INTERRUPT) {
   196  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   197  	}
   198  
   199  	err = nil
   200  	tx.End(&err)
   201  	if !errors.Is(err, sqlite3.INTERRUPT) {
   202  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   203  	}
   204  
   205  	db.SetInterrupt(context.Background())
   206  	stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   207  	if err != nil {
   208  		t.Fatal(err)
   209  	}
   210  	defer stmt.Close()
   211  
   212  	if stmt.Step() {
   213  		got := stmt.ColumnInt(0)
   214  		if got != 1 {
   215  			t.Errorf("got %d, want 1", got)
   216  		}
   217  	}
   218  	err = stmt.Err()
   219  	if err != nil {
   220  		t.Error(err)
   221  	}
   222  }
   223  
   224  func TestConn_Transaction_interrupted(t *testing.T) {
   225  	t.Parallel()
   226  
   227  	db, err := sqlite3.Open(":memory:")
   228  	if err != nil {
   229  		t.Fatal(err)
   230  	}
   231  	defer db.Close()
   232  
   233  	ctx, cancel := context.WithCancel(context.Background())
   234  	db.SetInterrupt(ctx)
   235  	cancel()
   236  
   237  	tx := db.Begin()
   238  
   239  	err = tx.Commit()
   240  	if !errors.Is(err, sqlite3.INTERRUPT) {
   241  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   242  	}
   243  
   244  	err = nil
   245  	tx.End(&err)
   246  	if !errors.Is(err, sqlite3.INTERRUPT) {
   247  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   248  	}
   249  }
   250  
   251  func TestConn_Transaction_busy(t *testing.T) {
   252  	t.Parallel()
   253  
   254  	db1, err := sqlite3.Open("file:/test.db?vfs=memdb")
   255  	if err != nil {
   256  		t.Fatal(err)
   257  	}
   258  	defer db1.Close()
   259  
   260  	db2, err := sqlite3.Open("file:/test.db?vfs=memdb&_pragma=busy_timeout(10000)")
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	defer db2.Close()
   265  
   266  	err = db1.Exec(`CREATE TABLE test (col)`)
   267  	if err != nil {
   268  		t.Fatal(err)
   269  	}
   270  
   271  	tx, err := db1.BeginImmediate()
   272  	if err != nil {
   273  		t.Fatal(err)
   274  	}
   275  	err = db1.Exec(`INSERT INTO test VALUES (1)`)
   276  	if err != nil {
   277  		t.Fatal(err)
   278  	}
   279  
   280  	ctx, cancel := context.WithCancel(context.Background())
   281  	db2.SetInterrupt(ctx)
   282  	go cancel()
   283  
   284  	_, err = db2.BeginExclusive()
   285  	if !errors.Is(err, sqlite3.BUSY) && !errors.Is(err, sqlite3.INTERRUPT) {
   286  		t.Errorf("got %v, want sqlite3.BUSY or sqlite3.INTERRUPT", err)
   287  	}
   288  
   289  	err = nil
   290  	tx.End(&err)
   291  	if err != nil {
   292  		t.Fatal(err)
   293  	}
   294  }
   295  
   296  func TestConn_Transaction_rollback(t *testing.T) {
   297  	t.Parallel()
   298  
   299  	db, err := sqlite3.Open(":memory:")
   300  	if err != nil {
   301  		t.Fatal(err)
   302  	}
   303  	defer db.Close()
   304  
   305  	err = db.Exec(`CREATE TABLE test (col)`)
   306  	if err != nil {
   307  		t.Fatal(err)
   308  	}
   309  
   310  	tx := db.Begin()
   311  	err = db.Exec(`INSERT INTO test VALUES (1)`)
   312  	if err != nil {
   313  		t.Fatal(err)
   314  	}
   315  	err = db.Exec(`COMMIT`)
   316  	if err != nil {
   317  		t.Fatal(err)
   318  	}
   319  	tx.End(&err)
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  
   324  	stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   325  	if err != nil {
   326  		t.Fatal(err)
   327  	}
   328  	defer stmt.Close()
   329  
   330  	if stmt.Step() {
   331  		got := stmt.ColumnInt(0)
   332  		if got != 1 {
   333  			t.Errorf("got %d, want 1", got)
   334  		}
   335  	}
   336  	err = stmt.Err()
   337  	if err != nil {
   338  		t.Error(err)
   339  	}
   340  }
   341  
   342  func TestConn_Savepoint_exec(t *testing.T) {
   343  	t.Parallel()
   344  
   345  	db, err := sqlite3.Open(":memory:")
   346  	if err != nil {
   347  		t.Fatal(err)
   348  	}
   349  	defer db.Close()
   350  
   351  	err = db.Exec(`CREATE TABLE test (col)`)
   352  	if err != nil {
   353  		t.Fatal(err)
   354  	}
   355  
   356  	errFailed := errors.New("failed")
   357  
   358  	count := func() int {
   359  		stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   360  		if err != nil {
   361  			t.Fatal(err)
   362  		}
   363  		defer stmt.Close()
   364  		if stmt.Step() {
   365  			return stmt.ColumnInt(0)
   366  		}
   367  		t.Fatal(stmt.Err())
   368  		return 0
   369  	}
   370  
   371  	insert := func(succeed bool) (err error) {
   372  		defer db.Savepoint().Release(&err)
   373  
   374  		err = db.Exec(`INSERT INTO test VALUES ('hello')`)
   375  		if err != nil {
   376  			t.Fatal(err)
   377  		}
   378  
   379  		if succeed {
   380  			return nil
   381  		}
   382  		return errFailed
   383  	}
   384  
   385  	err = insert(true)
   386  	if err != nil {
   387  		t.Fatal(err)
   388  	}
   389  	if got := count(); got != 1 {
   390  		t.Errorf("got %d, want 1", got)
   391  	}
   392  
   393  	err = insert(true)
   394  	if err != nil {
   395  		t.Fatal(err)
   396  	}
   397  	if got := count(); got != 2 {
   398  		t.Errorf("got %d, want 2", got)
   399  	}
   400  
   401  	err = insert(false)
   402  	if err != errFailed {
   403  		t.Errorf("got %v, want errFailed", err)
   404  	}
   405  	if got := count(); got != 2 {
   406  		t.Errorf("got %d, want 2", got)
   407  	}
   408  }
   409  
   410  func TestConn_Savepoint_panic(t *testing.T) {
   411  	t.Parallel()
   412  
   413  	db, err := sqlite3.Open(":memory:")
   414  	if err != nil {
   415  		t.Fatal(err)
   416  	}
   417  	defer db.Close()
   418  
   419  	err = db.Exec(`CREATE TABLE test (col)`)
   420  	if err != nil {
   421  		t.Fatal(err)
   422  	}
   423  
   424  	err = db.Exec(`INSERT INTO test VALUES ('one');`)
   425  	if err != nil {
   426  		t.Fatal(err)
   427  	}
   428  
   429  	panics := func() (err error) {
   430  		defer db.Savepoint().Release(&err)
   431  
   432  		err = db.Exec(`INSERT INTO test VALUES ('hello')`)
   433  		if err != nil {
   434  			return err
   435  		}
   436  
   437  		panic("omg!")
   438  	}
   439  
   440  	defer func() {
   441  		p := recover()
   442  		if p != "omg!" {
   443  			t.Errorf("got %v, want panic", p)
   444  		}
   445  
   446  		stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   447  		if err != nil {
   448  			t.Fatal(err)
   449  		}
   450  		defer stmt.Close()
   451  		if stmt.Step() {
   452  			got := stmt.ColumnInt(0)
   453  			if got != 1 {
   454  				t.Errorf("got %d, want 1", got)
   455  			}
   456  			return
   457  		}
   458  		t.Fatal(stmt.Err())
   459  	}()
   460  
   461  	err = panics()
   462  	if err != nil {
   463  		t.Error(err)
   464  	}
   465  }
   466  
   467  func TestConn_Savepoint_interrupt(t *testing.T) {
   468  	t.Parallel()
   469  
   470  	db, err := sqlite3.Open(":memory:")
   471  	if err != nil {
   472  		t.Fatal(err)
   473  	}
   474  	defer db.Close()
   475  
   476  	err = db.Exec(`CREATE TABLE test (col)`)
   477  	if err != nil {
   478  		t.Fatal(err)
   479  	}
   480  
   481  	savept := db.Savepoint()
   482  	err = db.Exec(`INSERT INTO test VALUES (1)`)
   483  	if err != nil {
   484  		t.Fatal(err)
   485  	}
   486  	savept.Release(&err)
   487  	if err != nil {
   488  		t.Fatal(err)
   489  	}
   490  
   491  	ctx, cancel := context.WithCancel(context.Background())
   492  	db.SetInterrupt(ctx)
   493  
   494  	savept1 := db.Savepoint()
   495  	err = db.Exec(`INSERT INTO test VALUES (2)`)
   496  	if err != nil {
   497  		t.Fatal(err)
   498  	}
   499  	savept2 := db.Savepoint()
   500  	err = db.Exec(`INSERT INTO test VALUES (3)`)
   501  	if err != nil {
   502  		t.Fatal(err)
   503  	}
   504  
   505  	cancel()
   506  	db.Savepoint().Release(&err)
   507  	if !errors.Is(err, sqlite3.INTERRUPT) {
   508  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   509  	}
   510  
   511  	err = db.Exec(`INSERT INTO test VALUES (4)`)
   512  	if !errors.Is(err, sqlite3.INTERRUPT) {
   513  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   514  	}
   515  
   516  	err = context.Canceled
   517  	savept2.Release(&err)
   518  	if err != context.Canceled {
   519  		t.Fatal(err)
   520  	}
   521  
   522  	err = nil
   523  	savept1.Release(&err)
   524  	if !errors.Is(err, sqlite3.INTERRUPT) {
   525  		t.Errorf("got %v, want sqlite3.INTERRUPT", err)
   526  	}
   527  
   528  	db.SetInterrupt(context.Background())
   529  	stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   530  	if err != nil {
   531  		t.Fatal(err)
   532  	}
   533  	defer stmt.Close()
   534  
   535  	if stmt.Step() {
   536  		got := stmt.ColumnInt(0)
   537  		if got != 1 {
   538  			t.Errorf("got %d, want 1", got)
   539  		}
   540  	}
   541  	err = stmt.Err()
   542  	if err != nil {
   543  		t.Error(err)
   544  	}
   545  }
   546  
   547  func TestConn_Savepoint_rollback(t *testing.T) {
   548  	t.Parallel()
   549  
   550  	db, err := sqlite3.Open(":memory:")
   551  	if err != nil {
   552  		t.Fatal(err)
   553  	}
   554  	defer db.Close()
   555  
   556  	err = db.Exec(`CREATE TABLE test (col)`)
   557  	if err != nil {
   558  		t.Fatal(err)
   559  	}
   560  
   561  	savept := db.Savepoint()
   562  	err = db.Exec(`INSERT INTO test VALUES (1)`)
   563  	if err != nil {
   564  		t.Fatal(err)
   565  	}
   566  	err = db.Exec(`COMMIT`)
   567  	if err != nil {
   568  		t.Fatal(err)
   569  	}
   570  	savept.Release(&err)
   571  	if err != nil {
   572  		t.Fatal(err)
   573  	}
   574  
   575  	stmt, _, err := db.Prepare(`SELECT count(*) FROM test`)
   576  	if err != nil {
   577  		t.Fatal(err)
   578  	}
   579  	defer stmt.Close()
   580  
   581  	if stmt.Step() {
   582  		got := stmt.ColumnInt(0)
   583  		if got != 1 {
   584  			t.Errorf("got %d, want 1", got)
   585  		}
   586  	}
   587  	err = stmt.Err()
   588  	if err != nil {
   589  		t.Error(err)
   590  	}
   591  }