github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/sqlitepool.go (about)

     1  // Package sqlitepool implements a pool of SQLite database connections.
     2  package sqlitepool
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"strings"
     9  
    10  	"github.com/tailscale/sqlite/cgosqlite"
    11  	"github.com/tailscale/sqlite/sqliteh"
    12  )
    13  
    14  // A Pool is a fixed-size pool of SQLite database connections.
    15  // One is reserved for writable transactions, the others are
    16  // used for read-only transactions.
    17  type Pool struct {
    18  	poolSize    int
    19  	rwConnFree  chan *conn // cap == 1
    20  	roConnsFree chan *conn // cap == poolSize-1
    21  	tracer      sqliteh.Tracer
    22  	closed      chan struct{}
    23  }
    24  
    25  type conn struct {
    26  	pool  *Pool
    27  	db    sqliteh.DB
    28  	stmts map[string]sqliteh.Stmt // persistent statements on db
    29  	id    sqliteh.TraceConnID
    30  }
    31  
    32  // NewPool creates a Pool of poolSize database connections.
    33  //
    34  // For each connection, initFn is called to initialize the connection.
    35  // Tracer is used to report statistics about the use of the Pool.
    36  func NewPool(filename string, poolSize int, initFn func(sqliteh.DB) error, tracer sqliteh.Tracer) (_ *Pool, err error) {
    37  	p := &Pool{
    38  		poolSize:    poolSize,
    39  		rwConnFree:  make(chan *conn, 1),
    40  		roConnsFree: make(chan *conn, poolSize-1),
    41  		tracer:      tracer,
    42  		closed:      make(chan struct{}),
    43  	}
    44  	defer func() {
    45  		if err != nil {
    46  			err = fmt.Errorf("sqlitepool.NewPool: %w", err)
    47  			select {
    48  			case conn := <-p.rwConnFree:
    49  				conn.db.Close()
    50  			default:
    51  			}
    52  			close(p.roConnsFree)
    53  			for conn := range p.roConnsFree {
    54  				conn.db.Close()
    55  			}
    56  		}
    57  	}()
    58  	if poolSize < 2 {
    59  		return nil, fmt.Errorf("poolSize=%d is too small", poolSize)
    60  	}
    61  	for i := 0; i < poolSize; i++ {
    62  		db, err := cgosqlite.Open(filename, sqliteh.OpenFlagsDefault, "")
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		if err := initFn(db); err != nil {
    67  			return nil, err
    68  		}
    69  		c := &conn{
    70  			pool:  p,
    71  			db:    db,
    72  			stmts: make(map[string]sqliteh.Stmt),
    73  			id:    sqliteh.TraceConnID(i),
    74  		}
    75  		if i == 0 {
    76  			p.rwConnFree <- c
    77  		} else {
    78  			if err := ExecScript(c.db, "PRAGMA query_only=true"); err != nil {
    79  				return nil, err
    80  			}
    81  			p.roConnsFree <- c
    82  		}
    83  	}
    84  
    85  	return p, nil
    86  }
    87  
    88  func (c *conn) close() error {
    89  	if c.db == nil {
    90  		return errors.New("sqlitepool conn already closed")
    91  	}
    92  	for _, stmt := range c.stmts {
    93  		stmt.Finalize()
    94  	}
    95  	c.stmts = nil
    96  	err := c.db.Close()
    97  	c.db = nil
    98  	return err
    99  }
   100  
   101  func (p *Pool) Close() error {
   102  	select {
   103  	case <-p.closed:
   104  		return errors.New("pool already closed")
   105  	default:
   106  	}
   107  	close(p.closed)
   108  
   109  	c := <-p.rwConnFree
   110  	err := c.close()
   111  
   112  	for i := 0; i < p.poolSize-1; i++ {
   113  		c := <-p.roConnsFree
   114  		err2 := c.close()
   115  		if err == nil {
   116  			err = err2
   117  		}
   118  	}
   119  	return err
   120  }
   121  
   122  var errPoolClosed = fmt.Errorf("%w: sqlitepool closed", context.Canceled)
   123  
   124  // BeginTx creates a writable transaction using BEGIN IMMEDIATE.
   125  // The parameter why is passed to the Tracer for debugging.
   126  func (p *Pool) BeginTx(ctx context.Context, why string) (*Tx, error) {
   127  	select {
   128  	case <-p.closed:
   129  		return nil, errPoolClosed
   130  	case <-ctx.Done():
   131  		return nil, ctx.Err()
   132  	case conn := <-p.rwConnFree:
   133  		tx := &Tx{Rx: &Rx{conn: conn, inTx: true}}
   134  		err := tx.Exec("BEGIN IMMEDIATE;")
   135  		if p.tracer != nil {
   136  			p.tracer.BeginTx(ctx, conn.id, why, false, err)
   137  		}
   138  		if err != nil {
   139  			p.rwConnFree <- conn // can't block, buffer is big enough
   140  			return nil, err
   141  		}
   142  		return tx, nil
   143  	}
   144  }
   145  
   146  // BeginRx creates a read-only transaction.
   147  // The parameter why is passed to the Tracer for debugging.
   148  func (p *Pool) BeginRx(ctx context.Context, why string) (*Rx, error) {
   149  	select {
   150  	case <-p.closed:
   151  		return nil, errPoolClosed
   152  	case <-ctx.Done():
   153  		return nil, ctx.Err()
   154  	case conn := <-p.roConnsFree:
   155  		rx := &Rx{conn: conn}
   156  		err := rx.Exec("BEGIN;")
   157  		if p.tracer != nil {
   158  			p.tracer.BeginTx(ctx, conn.id, why, true, err)
   159  		}
   160  		if err != nil {
   161  			p.roConnsFree <- conn // can't block, buffer is big enough
   162  			return nil, err
   163  		}
   164  		return &Rx{conn: conn}, nil
   165  	}
   166  }
   167  
   168  // Rx is a read-only transaction.
   169  //
   170  // It is *not* safe for concurrent use.
   171  type Rx struct {
   172  	conn *conn
   173  	inTx bool // true if this Rx is embedded in a writable Tx
   174  
   175  	// OnRollback is an optional function called after rollback.
   176  	// If Rx is part of a Tx and it is committed, then OnRollback
   177  	// is not called.
   178  	OnRollback func()
   179  }
   180  
   181  // Exec executes an SQL statement with no result.
   182  func (rx *Rx) Exec(sql string) error {
   183  	_, _, _, _, err := rx.Prepare(sql).StepResult()
   184  	if err != nil {
   185  		return fmt.Errorf("%w: %v", err, rx.conn.db.ErrMsg())
   186  	}
   187  	return nil
   188  }
   189  
   190  // Prepare prepares an SQL statement.
   191  // The Stmt is cached on the connection, so subsequent calls are fast.
   192  func (rx *Rx) Prepare(sql string) sqliteh.Stmt {
   193  	stmt := rx.conn.stmts[sql]
   194  	if stmt != nil {
   195  		return stmt
   196  	}
   197  	stmt, _, err := rx.conn.db.Prepare(sql, sqliteh.SQLITE_PREPARE_PERSISTENT)
   198  	if err != nil {
   199  		// Persistent statements are constant strings hardcoded into
   200  		// programs. Failing to prepare one means the string is bad.
   201  		// Ideally we would detect this at compile time, but barring
   202  		// that, there is no point returning the error because this
   203  		// is not something the program can recover from or handle.
   204  		panic(fmt.Sprintf("%v: %v", err, rx.conn.db.ErrMsg()))
   205  	}
   206  	rx.conn.stmts[sql] = stmt
   207  	return stmt
   208  }
   209  
   210  // DB returns the underlying database connection.
   211  //
   212  // Be careful: a transaction is in progress. Any use of BEGIN/COMMIT/ROLLBACK
   213  // should be modelled as a nested transaction, and when done the original
   214  // outer transaction should be left in-progress.
   215  func (rx *Rx) DB() sqliteh.DB {
   216  	return rx.conn.db
   217  }
   218  
   219  // ExecScript executes a series of SQL statements against a database connection.
   220  // It is intended for one-off scripts, so the prepared Stmt objects are not
   221  // cached for future calls.
   222  func ExecScript(db sqliteh.DB, queries string) error {
   223  	for {
   224  		queries = strings.TrimSpace(queries)
   225  		if queries == "" {
   226  			return nil
   227  		}
   228  		stmt, rem, err := db.Prepare(queries, 0)
   229  		if err != nil {
   230  			return fmt.Errorf("ExecScript: %w: %v, in remaining script: %s", err, db.ErrMsg(), queries)
   231  		}
   232  		queries = rem
   233  		_, err = stmt.Step(nil)
   234  		if err != nil {
   235  			err = fmt.Errorf("ExecScript: %w: %s: %v", err, stmt.SQL(), db.ErrMsg())
   236  		}
   237  		stmt.Finalize()
   238  		if err != nil {
   239  			return err
   240  		}
   241  	}
   242  }
   243  
   244  // Rollback executes ROLLBACK and cleans up the Rx.
   245  // It is a no-op if Rx is already rolled back.
   246  func (rx *Rx) Rollback() {
   247  	if rx.conn == nil {
   248  		return
   249  	}
   250  	if rx.inTx {
   251  		panic("Tx.Rx.Rollback called, only call Rollback on the Tx object")
   252  	}
   253  	err := rx.Exec("ROLLBACK;")
   254  	if rx.conn.pool.tracer != nil {
   255  		rx.conn.pool.tracer.Rollback(rx.conn.id, err)
   256  	}
   257  	rx.conn.pool.roConnsFree <- rx.conn
   258  	rx.conn = nil
   259  	if rx.OnRollback != nil {
   260  		rx.OnRollback()
   261  		rx.OnRollback = nil
   262  	}
   263  	if err != nil {
   264  		panic(err)
   265  	}
   266  }
   267  
   268  // Tx is a writable SQLite database transaction.
   269  //
   270  // It is *not* safe for concurrent use.
   271  //
   272  // A Tx contains an embedded Rx, which can be used to pass to functions
   273  // that want to perform read-only queries on the writable Tx.
   274  type Tx struct {
   275  	*Rx
   276  
   277  	// OnCommit is an optional function called after successful commit.
   278  	OnCommit func()
   279  }
   280  
   281  // Rollback executes ROLLBACK and cleans up the Tx.
   282  // It is a no-op if the Tx is already rolled back or committed.
   283  func (tx *Tx) Rollback() {
   284  	if tx.conn == nil {
   285  		return
   286  	}
   287  	err := tx.Exec("ROLLBACK;")
   288  	if tx.conn.pool.tracer != nil {
   289  		tx.conn.pool.tracer.Rollback(tx.conn.id, err)
   290  	}
   291  	tx.conn.pool.rwConnFree <- tx.conn
   292  	tx.conn = nil
   293  	if tx.OnRollback != nil {
   294  		tx.OnRollback()
   295  		tx.OnRollback = nil
   296  		tx.OnCommit = nil
   297  	}
   298  	if err != nil {
   299  		panic(err)
   300  	}
   301  }
   302  
   303  // Commit executes COMMIT and cleans up the Tx.
   304  // It is an error to call if the Tx is already rolled back or committed.
   305  func (tx *Tx) Commit() error {
   306  	if tx.conn == nil {
   307  		return errors.New("tx already done")
   308  	}
   309  	err := tx.Exec("COMMIT;")
   310  	if tx.conn.pool.tracer != nil {
   311  		tx.conn.pool.tracer.Commit(tx.conn.id, err)
   312  	}
   313  	tx.conn.pool.rwConnFree <- tx.conn
   314  	tx.conn = nil
   315  	if tx.OnCommit != nil {
   316  		tx.OnCommit()
   317  		tx.OnCommit = nil
   318  		tx.OnRollback = nil
   319  	}
   320  	return err
   321  }