github.com/tailscale/sqlite@v0.0.0-20240515181108-c667cbe57c66/sqlitepool/sqlitepool.go (about) 1 // Package sqlitepool implements a pool of SQLite database connections. 2 package sqlitepool 3 4 import ( 5 "context" 6 "errors" 7 "fmt" 8 "strings" 9 10 "github.com/tailscale/sqlite/cgosqlite" 11 "github.com/tailscale/sqlite/sqliteh" 12 ) 13 14 // A Pool is a fixed-size pool of SQLite database connections. 15 // One is reserved for writable transactions, the others are 16 // used for read-only transactions. 17 type Pool struct { 18 poolSize int 19 rwConnFree chan *conn // cap == 1 20 roConnsFree chan *conn // cap == poolSize-1 21 tracer sqliteh.Tracer 22 closed chan struct{} 23 } 24 25 type conn struct { 26 pool *Pool 27 db sqliteh.DB 28 stmts map[string]sqliteh.Stmt // persistent statements on db 29 id sqliteh.TraceConnID 30 } 31 32 // NewPool creates a Pool of poolSize database connections. 33 // 34 // For each connection, initFn is called to initialize the connection. 35 // Tracer is used to report statistics about the use of the Pool. 36 func NewPool(filename string, poolSize int, initFn func(sqliteh.DB) error, tracer sqliteh.Tracer) (_ *Pool, err error) { 37 p := &Pool{ 38 poolSize: poolSize, 39 rwConnFree: make(chan *conn, 1), 40 roConnsFree: make(chan *conn, poolSize-1), 41 tracer: tracer, 42 closed: make(chan struct{}), 43 } 44 defer func() { 45 if err != nil { 46 err = fmt.Errorf("sqlitepool.NewPool: %w", err) 47 select { 48 case conn := <-p.rwConnFree: 49 conn.db.Close() 50 default: 51 } 52 close(p.roConnsFree) 53 for conn := range p.roConnsFree { 54 conn.db.Close() 55 } 56 } 57 }() 58 if poolSize < 2 { 59 return nil, fmt.Errorf("poolSize=%d is too small", poolSize) 60 } 61 for i := 0; i < poolSize; i++ { 62 db, err := cgosqlite.Open(filename, sqliteh.OpenFlagsDefault, "") 63 if err != nil { 64 return nil, err 65 } 66 if err := initFn(db); err != nil { 67 return nil, err 68 } 69 c := &conn{ 70 pool: p, 71 db: db, 72 stmts: make(map[string]sqliteh.Stmt), 73 id: sqliteh.TraceConnID(i), 74 } 75 if i == 0 { 76 p.rwConnFree <- c 77 } else { 78 if err := ExecScript(c.db, "PRAGMA query_only=true"); err != nil { 79 return nil, err 80 } 81 p.roConnsFree <- c 82 } 83 } 84 85 return p, nil 86 } 87 88 func (c *conn) close() error { 89 if c.db == nil { 90 return errors.New("sqlitepool conn already closed") 91 } 92 for _, stmt := range c.stmts { 93 stmt.Finalize() 94 } 95 c.stmts = nil 96 err := c.db.Close() 97 c.db = nil 98 return err 99 } 100 101 func (p *Pool) Close() error { 102 select { 103 case <-p.closed: 104 return errors.New("pool already closed") 105 default: 106 } 107 close(p.closed) 108 109 c := <-p.rwConnFree 110 err := c.close() 111 112 for i := 0; i < p.poolSize-1; i++ { 113 c := <-p.roConnsFree 114 err2 := c.close() 115 if err == nil { 116 err = err2 117 } 118 } 119 return err 120 } 121 122 var errPoolClosed = fmt.Errorf("%w: sqlitepool closed", context.Canceled) 123 124 // BeginTx creates a writable transaction using BEGIN IMMEDIATE. 125 // The parameter why is passed to the Tracer for debugging. 126 func (p *Pool) BeginTx(ctx context.Context, why string) (*Tx, error) { 127 select { 128 case <-p.closed: 129 return nil, errPoolClosed 130 case <-ctx.Done(): 131 return nil, ctx.Err() 132 case conn := <-p.rwConnFree: 133 tx := &Tx{Rx: &Rx{conn: conn, inTx: true}} 134 err := tx.Exec("BEGIN IMMEDIATE;") 135 if p.tracer != nil { 136 p.tracer.BeginTx(ctx, conn.id, why, false, err) 137 } 138 if err != nil { 139 p.rwConnFree <- conn // can't block, buffer is big enough 140 return nil, err 141 } 142 return tx, nil 143 } 144 } 145 146 // BeginRx creates a read-only transaction. 147 // The parameter why is passed to the Tracer for debugging. 148 func (p *Pool) BeginRx(ctx context.Context, why string) (*Rx, error) { 149 select { 150 case <-p.closed: 151 return nil, errPoolClosed 152 case <-ctx.Done(): 153 return nil, ctx.Err() 154 case conn := <-p.roConnsFree: 155 rx := &Rx{conn: conn} 156 err := rx.Exec("BEGIN;") 157 if p.tracer != nil { 158 p.tracer.BeginTx(ctx, conn.id, why, true, err) 159 } 160 if err != nil { 161 p.roConnsFree <- conn // can't block, buffer is big enough 162 return nil, err 163 } 164 return &Rx{conn: conn}, nil 165 } 166 } 167 168 // Rx is a read-only transaction. 169 // 170 // It is *not* safe for concurrent use. 171 type Rx struct { 172 conn *conn 173 inTx bool // true if this Rx is embedded in a writable Tx 174 175 // OnRollback is an optional function called after rollback. 176 // If Rx is part of a Tx and it is committed, then OnRollback 177 // is not called. 178 OnRollback func() 179 } 180 181 // Exec executes an SQL statement with no result. 182 func (rx *Rx) Exec(sql string) error { 183 _, _, _, _, err := rx.Prepare(sql).StepResult() 184 if err != nil { 185 return fmt.Errorf("%w: %v", err, rx.conn.db.ErrMsg()) 186 } 187 return nil 188 } 189 190 // Prepare prepares an SQL statement. 191 // The Stmt is cached on the connection, so subsequent calls are fast. 192 func (rx *Rx) Prepare(sql string) sqliteh.Stmt { 193 stmt := rx.conn.stmts[sql] 194 if stmt != nil { 195 return stmt 196 } 197 stmt, _, err := rx.conn.db.Prepare(sql, sqliteh.SQLITE_PREPARE_PERSISTENT) 198 if err != nil { 199 // Persistent statements are constant strings hardcoded into 200 // programs. Failing to prepare one means the string is bad. 201 // Ideally we would detect this at compile time, but barring 202 // that, there is no point returning the error because this 203 // is not something the program can recover from or handle. 204 panic(fmt.Sprintf("%v: %v", err, rx.conn.db.ErrMsg())) 205 } 206 rx.conn.stmts[sql] = stmt 207 return stmt 208 } 209 210 // DB returns the underlying database connection. 211 // 212 // Be careful: a transaction is in progress. Any use of BEGIN/COMMIT/ROLLBACK 213 // should be modelled as a nested transaction, and when done the original 214 // outer transaction should be left in-progress. 215 func (rx *Rx) DB() sqliteh.DB { 216 return rx.conn.db 217 } 218 219 // ExecScript executes a series of SQL statements against a database connection. 220 // It is intended for one-off scripts, so the prepared Stmt objects are not 221 // cached for future calls. 222 func ExecScript(db sqliteh.DB, queries string) error { 223 for { 224 queries = strings.TrimSpace(queries) 225 if queries == "" { 226 return nil 227 } 228 stmt, rem, err := db.Prepare(queries, 0) 229 if err != nil { 230 return fmt.Errorf("ExecScript: %w: %v, in remaining script: %s", err, db.ErrMsg(), queries) 231 } 232 queries = rem 233 _, err = stmt.Step(nil) 234 if err != nil { 235 err = fmt.Errorf("ExecScript: %w: %s: %v", err, stmt.SQL(), db.ErrMsg()) 236 } 237 stmt.Finalize() 238 if err != nil { 239 return err 240 } 241 } 242 } 243 244 // Rollback executes ROLLBACK and cleans up the Rx. 245 // It is a no-op if Rx is already rolled back. 246 func (rx *Rx) Rollback() { 247 if rx.conn == nil { 248 return 249 } 250 if rx.inTx { 251 panic("Tx.Rx.Rollback called, only call Rollback on the Tx object") 252 } 253 err := rx.Exec("ROLLBACK;") 254 if rx.conn.pool.tracer != nil { 255 rx.conn.pool.tracer.Rollback(rx.conn.id, err) 256 } 257 rx.conn.pool.roConnsFree <- rx.conn 258 rx.conn = nil 259 if rx.OnRollback != nil { 260 rx.OnRollback() 261 rx.OnRollback = nil 262 } 263 if err != nil { 264 panic(err) 265 } 266 } 267 268 // Tx is a writable SQLite database transaction. 269 // 270 // It is *not* safe for concurrent use. 271 // 272 // A Tx contains an embedded Rx, which can be used to pass to functions 273 // that want to perform read-only queries on the writable Tx. 274 type Tx struct { 275 *Rx 276 277 // OnCommit is an optional function called after successful commit. 278 OnCommit func() 279 } 280 281 // Rollback executes ROLLBACK and cleans up the Tx. 282 // It is a no-op if the Tx is already rolled back or committed. 283 func (tx *Tx) Rollback() { 284 if tx.conn == nil { 285 return 286 } 287 err := tx.Exec("ROLLBACK;") 288 if tx.conn.pool.tracer != nil { 289 tx.conn.pool.tracer.Rollback(tx.conn.id, err) 290 } 291 tx.conn.pool.rwConnFree <- tx.conn 292 tx.conn = nil 293 if tx.OnRollback != nil { 294 tx.OnRollback() 295 tx.OnRollback = nil 296 tx.OnCommit = nil 297 } 298 if err != nil { 299 panic(err) 300 } 301 } 302 303 // Commit executes COMMIT and cleans up the Tx. 304 // It is an error to call if the Tx is already rolled back or committed. 305 func (tx *Tx) Commit() error { 306 if tx.conn == nil { 307 return errors.New("tx already done") 308 } 309 err := tx.Exec("COMMIT;") 310 if tx.conn.pool.tracer != nil { 311 tx.conn.pool.tracer.Commit(tx.conn.id, err) 312 } 313 tx.conn.pool.rwConnFree <- tx.conn 314 tx.conn = nil 315 if tx.OnCommit != nil { 316 tx.OnCommit() 317 tx.OnCommit = nil 318 tx.OnRollback = nil 319 } 320 return err 321 }