github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/txn.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"runtime"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/ncruces/go-sqlite3/internal/util"
    13  	"github.com/tetratelabs/wazero/api"
    14  )
    15  
    16  // Txn is an in-progress database transaction.
    17  //
    18  // https://sqlite.org/lang_transaction.html
    19  type Txn struct {
    20  	c *Conn
    21  }
    22  
    23  // Begin starts a deferred transaction.
    24  //
    25  // https://sqlite.org/lang_transaction.html
    26  func (c *Conn) Begin() Txn {
    27  	// BEGIN even if interrupted.
    28  	err := c.txnExecInterrupted(`BEGIN DEFERRED`)
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  	return Txn{c}
    33  }
    34  
    35  // BeginImmediate starts an immediate transaction.
    36  //
    37  // https://sqlite.org/lang_transaction.html
    38  func (c *Conn) BeginImmediate() (Txn, error) {
    39  	err := c.Exec(`BEGIN IMMEDIATE`)
    40  	if err != nil {
    41  		return Txn{}, err
    42  	}
    43  	return Txn{c}, nil
    44  }
    45  
    46  // BeginExclusive starts an exclusive transaction.
    47  //
    48  // https://sqlite.org/lang_transaction.html
    49  func (c *Conn) BeginExclusive() (Txn, error) {
    50  	err := c.Exec(`BEGIN EXCLUSIVE`)
    51  	if err != nil {
    52  		return Txn{}, err
    53  	}
    54  	return Txn{c}, nil
    55  }
    56  
    57  // End calls either [Txn.Commit] or [Txn.Rollback]
    58  // depending on whether *error points to a nil or non-nil error.
    59  //
    60  // This is meant to be deferred:
    61  //
    62  //	func doWork(db *sqlite3.Conn) (err error) {
    63  //		tx := db.Begin()
    64  //		defer tx.End(&err)
    65  //
    66  //		// ... do work in the transaction
    67  //	}
    68  //
    69  // https://sqlite.org/lang_transaction.html
    70  func (tx Txn) End(errp *error) {
    71  	recovered := recover()
    72  	if recovered != nil {
    73  		defer panic(recovered)
    74  	}
    75  
    76  	if *errp == nil && recovered == nil {
    77  		// Success path.
    78  		if tx.c.GetAutocommit() { // There is nothing to commit.
    79  			return
    80  		}
    81  		*errp = tx.Commit()
    82  		if *errp == nil {
    83  			return
    84  		}
    85  		// Fall through to the error path.
    86  	}
    87  
    88  	// Error path.
    89  	if tx.c.GetAutocommit() { // There is nothing to rollback.
    90  		return
    91  	}
    92  	err := tx.Rollback()
    93  	if err != nil {
    94  		panic(err)
    95  	}
    96  }
    97  
    98  // Commit commits the transaction.
    99  //
   100  // https://sqlite.org/lang_transaction.html
   101  func (tx Txn) Commit() error {
   102  	return tx.c.Exec(`COMMIT`)
   103  }
   104  
   105  // Rollback rolls back the transaction,
   106  // even if the connection has been interrupted.
   107  //
   108  // https://sqlite.org/lang_transaction.html
   109  func (tx Txn) Rollback() error {
   110  	return tx.c.txnExecInterrupted(`ROLLBACK`)
   111  }
   112  
   113  // Savepoint is a marker within a transaction
   114  // that allows for partial rollback.
   115  //
   116  // https://sqlite.org/lang_savepoint.html
   117  type Savepoint struct {
   118  	c    *Conn
   119  	name string
   120  }
   121  
   122  // Savepoint establishes a new transaction savepoint.
   123  //
   124  // https://sqlite.org/lang_savepoint.html
   125  func (c *Conn) Savepoint() Savepoint {
   126  	// Names can be reused; this makes catching bugs more likely.
   127  	name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
   128  
   129  	err := c.txnExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
   130  	if err != nil {
   131  		panic(err)
   132  	}
   133  	return Savepoint{c: c, name: name}
   134  }
   135  
   136  func saveptName() (name string) {
   137  	defer func() {
   138  		if name == "" {
   139  			name = "sqlite3.Savepoint"
   140  		}
   141  	}()
   142  
   143  	var pc [8]uintptr
   144  	n := runtime.Callers(3, pc[:])
   145  	if n <= 0 {
   146  		return ""
   147  	}
   148  	frames := runtime.CallersFrames(pc[:n])
   149  	frame, more := frames.Next()
   150  	for more && (strings.HasPrefix(frame.Function, "database/sql.") ||
   151  		strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) {
   152  		frame, more = frames.Next()
   153  	}
   154  	return frame.Function
   155  }
   156  
   157  // Release releases the savepoint rolling back any changes
   158  // if *error points to a non-nil error.
   159  //
   160  // This is meant to be deferred:
   161  //
   162  //	func doWork(db *sqlite3.Conn) (err error) {
   163  //		savept := db.Savepoint()
   164  //		defer savept.Release(&err)
   165  //
   166  //		// ... do work in the transaction
   167  //	}
   168  func (s Savepoint) Release(errp *error) {
   169  	recovered := recover()
   170  	if recovered != nil {
   171  		defer panic(recovered)
   172  	}
   173  
   174  	if *errp == nil && recovered == nil {
   175  		// Success path.
   176  		if s.c.GetAutocommit() { // There is nothing to commit.
   177  			return
   178  		}
   179  		*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
   180  		if *errp == nil {
   181  			return
   182  		}
   183  		// Fall through to the error path.
   184  	}
   185  
   186  	// Error path.
   187  	if s.c.GetAutocommit() { // There is nothing to rollback.
   188  		return
   189  	}
   190  	// ROLLBACK and RELEASE even if interrupted.
   191  	err := s.c.txnExecInterrupted(fmt.Sprintf(`
   192  		ROLLBACK TO %[1]q;
   193  		RELEASE %[1]q;
   194  	`, s.name))
   195  	if err != nil {
   196  		panic(err)
   197  	}
   198  }
   199  
   200  // Rollback rolls the transaction back to the savepoint,
   201  // even if the connection has been interrupted.
   202  // Rollback does not release the savepoint.
   203  //
   204  // https://sqlite.org/lang_transaction.html
   205  func (s Savepoint) Rollback() error {
   206  	// ROLLBACK even if interrupted.
   207  	return s.c.txnExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
   208  }
   209  
   210  func (c *Conn) txnExecInterrupted(sql string) error {
   211  	err := c.Exec(sql)
   212  	if errors.Is(err, INTERRUPT) {
   213  		old := c.SetInterrupt(context.Background())
   214  		defer c.SetInterrupt(old)
   215  		err = c.Exec(sql)
   216  	}
   217  	return err
   218  }
   219  
   220  // TxnState starts a deferred transaction.
   221  //
   222  // https://sqlite.org/c3ref/txn_state.html
   223  func (c *Conn) TxnState(schema string) TxnState {
   224  	var ptr uint32
   225  	if schema != "" {
   226  		defer c.arena.mark()()
   227  		ptr = c.arena.string(schema)
   228  	}
   229  	r := c.call("sqlite3_txn_state", uint64(c.handle), uint64(ptr))
   230  	return TxnState(r)
   231  }
   232  
   233  // CommitHook registers a callback function to be invoked
   234  // whenever a transaction is committed.
   235  // Return true to allow the commit operation to continue normally.
   236  //
   237  // https://sqlite.org/c3ref/commit_hook.html
   238  func (c *Conn) CommitHook(cb func() (ok bool)) {
   239  	var enable uint64
   240  	if cb != nil {
   241  		enable = 1
   242  	}
   243  	c.call("sqlite3_commit_hook_go", uint64(c.handle), enable)
   244  	c.commit = cb
   245  }
   246  
   247  // RollbackHook registers a callback function to be invoked
   248  // whenever a transaction is rolled back.
   249  //
   250  // https://sqlite.org/c3ref/commit_hook.html
   251  func (c *Conn) RollbackHook(cb func()) {
   252  	var enable uint64
   253  	if cb != nil {
   254  		enable = 1
   255  	}
   256  	c.call("sqlite3_rollback_hook_go", uint64(c.handle), enable)
   257  	c.rollback = cb
   258  }
   259  
   260  // UpdateHook registers a callback function to be invoked
   261  // whenever a row is updated, inserted or deleted in a rowid table.
   262  //
   263  // https://sqlite.org/c3ref/update_hook.html
   264  func (c *Conn) UpdateHook(cb func(action AuthorizerActionCode, schema, table string, rowid int64)) {
   265  	var enable uint64
   266  	if cb != nil {
   267  		enable = 1
   268  	}
   269  	c.call("sqlite3_update_hook_go", uint64(c.handle), enable)
   270  	c.update = cb
   271  }
   272  
   273  func commitCallback(ctx context.Context, mod api.Module, pDB uint32) (rollback uint32) {
   274  	if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.commit != nil {
   275  		if !c.commit() {
   276  			rollback = 1
   277  		}
   278  	}
   279  	return rollback
   280  }
   281  
   282  func rollbackCallback(ctx context.Context, mod api.Module, pDB uint32) {
   283  	if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.rollback != nil {
   284  		c.rollback()
   285  	}
   286  }
   287  
   288  func updateCallback(ctx context.Context, mod api.Module, pDB uint32, action AuthorizerActionCode, zSchema, zTabName uint32, rowid uint64) {
   289  	if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.update != nil {
   290  		schema := util.ReadString(mod, zSchema, _MAX_NAME)
   291  		table := util.ReadString(mod, zTabName, _MAX_NAME)
   292  		c.update(action, schema, table, int64(rowid))
   293  	}
   294  }