github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/sqlcipher/sqlcipher.go (about) 1 package sqlcipher 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 nurl "net/url" 8 "strconv" 9 "strings" 10 11 "go.uber.org/atomic" 12 13 "github.com/golang-migrate/migrate/v4" 14 "github.com/golang-migrate/migrate/v4/database" 15 "github.com/hashicorp/go-multierror" 16 _ "github.com/mutecomm/go-sqlcipher/v4" 17 ) 18 19 func init() { 20 database.Register("sqlcipher", &Sqlite{}) 21 } 22 23 var DefaultMigrationsTable = "schema_migrations" 24 var ( 25 ErrDatabaseDirty = fmt.Errorf("database is dirty") 26 ErrNilConfig = fmt.Errorf("no config") 27 ErrNoDatabaseName = fmt.Errorf("no database name") 28 ) 29 30 type Config struct { 31 MigrationsTable string 32 DatabaseName string 33 NoTxWrap bool 34 } 35 36 type Sqlite struct { 37 db *sql.DB 38 isLocked atomic.Bool 39 40 config *Config 41 } 42 43 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 44 if config == nil { 45 return nil, ErrNilConfig 46 } 47 48 if err := instance.Ping(); err != nil { 49 return nil, err 50 } 51 52 if len(config.MigrationsTable) == 0 { 53 config.MigrationsTable = DefaultMigrationsTable 54 } 55 56 mx := &Sqlite{ 57 db: instance, 58 config: config, 59 } 60 if err := mx.ensureVersionTable(); err != nil { 61 return nil, err 62 } 63 return mx, nil 64 } 65 66 // ensureVersionTable checks if versions table exists and, if not, creates it. 67 // Note that this function locks the database, which deviates from the usual 68 // convention of "caller locks" in the Sqlite type. 69 func (m *Sqlite) ensureVersionTable() (err error) { 70 if err = m.Lock(); err != nil { 71 return err 72 } 73 74 defer func() { 75 if e := m.Unlock(); e != nil { 76 if err == nil { 77 err = e 78 } else { 79 err = multierror.Append(err, e) 80 } 81 } 82 }() 83 84 query := fmt.Sprintf(` 85 CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); 86 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 87 `, m.config.MigrationsTable, m.config.MigrationsTable) 88 89 if _, err := m.db.Exec(query); err != nil { 90 return err 91 } 92 return nil 93 } 94 95 func (m *Sqlite) Open(url string) (database.Driver, error) { 96 purl, err := nurl.Parse(url) 97 if err != nil { 98 return nil, err 99 } 100 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1) 101 db, err := sql.Open("sqlite3", dbfile) 102 if err != nil { 103 return nil, err 104 } 105 106 qv := purl.Query() 107 108 migrationsTable := qv.Get("x-migrations-table") 109 if len(migrationsTable) == 0 { 110 migrationsTable = DefaultMigrationsTable 111 } 112 113 noTxWrap := false 114 if v := qv.Get("x-no-tx-wrap"); v != "" { 115 noTxWrap, err = strconv.ParseBool(v) 116 if err != nil { 117 return nil, fmt.Errorf("x-no-tx-wrap: %s", err) 118 } 119 } 120 121 mx, err := WithInstance(db, &Config{ 122 DatabaseName: purl.Path, 123 MigrationsTable: migrationsTable, 124 NoTxWrap: noTxWrap, 125 }) 126 if err != nil { 127 return nil, err 128 } 129 return mx, nil 130 } 131 132 func (m *Sqlite) Close() error { 133 return m.db.Close() 134 } 135 136 func (m *Sqlite) Drop() (err error) { 137 query := `SELECT name FROM sqlite_master WHERE type = 'table';` 138 tables, err := m.db.Query(query) 139 if err != nil { 140 return &database.Error{OrigErr: err, Query: []byte(query)} 141 } 142 defer func() { 143 if errClose := tables.Close(); errClose != nil { 144 err = multierror.Append(err, errClose) 145 } 146 }() 147 148 tableNames := make([]string, 0) 149 for tables.Next() { 150 var tableName string 151 if err := tables.Scan(&tableName); err != nil { 152 return err 153 } 154 if len(tableName) > 0 { 155 tableNames = append(tableNames, tableName) 156 } 157 } 158 if err := tables.Err(); err != nil { 159 return &database.Error{OrigErr: err, Query: []byte(query)} 160 } 161 162 if len(tableNames) > 0 { 163 for _, t := range tableNames { 164 query := "DROP TABLE " + t 165 err = m.executeQuery(query) 166 if err != nil { 167 return &database.Error{OrigErr: err, Query: []byte(query)} 168 } 169 } 170 query := "VACUUM" 171 _, err = m.db.Query(query) 172 if err != nil { 173 return &database.Error{OrigErr: err, Query: []byte(query)} 174 } 175 } 176 177 return nil 178 } 179 180 func (m *Sqlite) Lock() error { 181 if !m.isLocked.CAS(false, true) { 182 return database.ErrLocked 183 } 184 return nil 185 } 186 187 func (m *Sqlite) Unlock() error { 188 if !m.isLocked.CAS(true, false) { 189 return database.ErrNotLocked 190 } 191 return nil 192 } 193 194 func (m *Sqlite) Run(migration io.Reader) error { 195 migr, err := io.ReadAll(migration) 196 if err != nil { 197 return err 198 } 199 query := string(migr[:]) 200 201 if m.config.NoTxWrap { 202 return m.executeQueryNoTx(query) 203 } 204 return m.executeQuery(query) 205 } 206 207 func (m *Sqlite) executeQuery(query string) error { 208 tx, err := m.db.Begin() 209 if err != nil { 210 return &database.Error{OrigErr: err, Err: "transaction start failed"} 211 } 212 if _, err := tx.Exec(query); err != nil { 213 if errRollback := tx.Rollback(); errRollback != nil { 214 err = multierror.Append(err, errRollback) 215 } 216 return &database.Error{OrigErr: err, Query: []byte(query)} 217 } 218 if err := tx.Commit(); err != nil { 219 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 220 } 221 return nil 222 } 223 224 func (m *Sqlite) executeQueryNoTx(query string) error { 225 if _, err := m.db.Exec(query); err != nil { 226 return &database.Error{OrigErr: err, Query: []byte(query)} 227 } 228 return nil 229 } 230 231 func (m *Sqlite) SetVersion(version int, dirty bool) error { 232 tx, err := m.db.Begin() 233 if err != nil { 234 return &database.Error{OrigErr: err, Err: "transaction start failed"} 235 } 236 237 query := "DELETE FROM " + m.config.MigrationsTable 238 if _, err := tx.Exec(query); err != nil { 239 return &database.Error{OrigErr: err, Query: []byte(query)} 240 } 241 242 // Also re-write the schema version for nil dirty versions to prevent 243 // empty schema version for failed down migration on the first migration 244 // See: https://github.com/golang-migrate/migrate/issues/330 245 if version >= 0 || (version == database.NilVersion && dirty) { 246 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable) 247 if _, err := tx.Exec(query, version, dirty); err != nil { 248 if errRollback := tx.Rollback(); errRollback != nil { 249 err = multierror.Append(err, errRollback) 250 } 251 return &database.Error{OrigErr: err, Query: []byte(query)} 252 } 253 } 254 255 if err := tx.Commit(); err != nil { 256 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 257 } 258 259 return nil 260 } 261 262 func (m *Sqlite) Version() (version int, dirty bool, err error) { 263 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 264 err = m.db.QueryRow(query).Scan(&version, &dirty) 265 if err != nil { 266 return database.NilVersion, false, nil 267 } 268 return version, dirty, nil 269 }