github.com/nagyistzcons/migrate/v4@v4.14.5/database/sqlcipher/sqlcipher.go (about) 1 package sqlcipher 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strconv" 10 "strings" 11 12 "github.com/golang-migrate/migrate/v4" 13 "github.com/golang-migrate/migrate/v4/database" 14 "github.com/hashicorp/go-multierror" 15 _ "github.com/mutecomm/go-sqlcipher/v4" 16 ) 17 18 func init() { 19 database.Register("sqlcipher", &Sqlite{}) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 var ( 24 ErrDatabaseDirty = fmt.Errorf("database is dirty") 25 ErrNilConfig = fmt.Errorf("no config") 26 ErrNoDatabaseName = fmt.Errorf("no database name") 27 ) 28 29 type Config struct { 30 MigrationsTable string 31 DatabaseName string 32 NoTxWrap bool 33 } 34 35 type Sqlite struct { 36 db *sql.DB 37 isLocked bool 38 39 config *Config 40 } 41 42 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 43 if config == nil { 44 return nil, ErrNilConfig 45 } 46 47 if err := instance.Ping(); err != nil { 48 return nil, err 49 } 50 51 if len(config.MigrationsTable) == 0 { 52 config.MigrationsTable = DefaultMigrationsTable 53 } 54 55 mx := &Sqlite{ 56 db: instance, 57 config: config, 58 } 59 if err := mx.ensureVersionTable(); err != nil { 60 return nil, err 61 } 62 return mx, nil 63 } 64 65 // ensureVersionTable checks if versions table exists and, if not, creates it. 66 // Note that this function locks the database, which deviates from the usual 67 // convention of "caller locks" in the Sqlite type. 68 func (m *Sqlite) ensureVersionTable() (err error) { 69 if err = m.Lock(); err != nil { 70 return err 71 } 72 73 defer func() { 74 if e := m.Unlock(); e != nil { 75 if err == nil { 76 err = e 77 } else { 78 err = multierror.Append(err, e) 79 } 80 } 81 }() 82 83 query := fmt.Sprintf(` 84 CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); 85 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 86 `, m.config.MigrationsTable, m.config.MigrationsTable) 87 88 if _, err := m.db.Exec(query); err != nil { 89 return err 90 } 91 return nil 92 } 93 94 func (m *Sqlite) Open(url string) (database.Driver, error) { 95 purl, err := nurl.Parse(url) 96 if err != nil { 97 return nil, err 98 } 99 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1) 100 db, err := sql.Open("sqlite3", dbfile) 101 if err != nil { 102 return nil, err 103 } 104 105 qv := purl.Query() 106 107 migrationsTable := qv.Get("x-migrations-table") 108 if len(migrationsTable) == 0 { 109 migrationsTable = DefaultMigrationsTable 110 } 111 112 noTxWrap := false 113 if v := qv.Get("x-no-tx-wrap"); v != "" { 114 noTxWrap, err = strconv.ParseBool(v) 115 if err != nil { 116 return nil, fmt.Errorf("x-no-tx-wrap: %s", err) 117 } 118 } 119 120 mx, err := WithInstance(db, &Config{ 121 DatabaseName: purl.Path, 122 MigrationsTable: migrationsTable, 123 NoTxWrap: noTxWrap, 124 }) 125 if err != nil { 126 return nil, err 127 } 128 return mx, nil 129 } 130 131 func (m *Sqlite) Close() error { 132 return m.db.Close() 133 } 134 135 func (m *Sqlite) Drop() (err error) { 136 query := `SELECT name FROM sqlite_master WHERE type = 'table';` 137 tables, err := m.db.Query(query) 138 if err != nil { 139 return &database.Error{OrigErr: err, Query: []byte(query)} 140 } 141 defer func() { 142 if errClose := tables.Close(); errClose != nil { 143 err = multierror.Append(err, errClose) 144 } 145 }() 146 147 tableNames := make([]string, 0) 148 for tables.Next() { 149 var tableName string 150 if err := tables.Scan(&tableName); err != nil { 151 return err 152 } 153 if len(tableName) > 0 { 154 tableNames = append(tableNames, tableName) 155 } 156 } 157 if err := tables.Err(); err != nil { 158 return &database.Error{OrigErr: err, Query: []byte(query)} 159 } 160 161 if len(tableNames) > 0 { 162 for _, t := range tableNames { 163 query := "DROP TABLE " + t 164 err = m.executeQuery(query) 165 if err != nil { 166 return &database.Error{OrigErr: err, Query: []byte(query)} 167 } 168 } 169 query := "VACUUM" 170 _, err = m.db.Query(query) 171 if err != nil { 172 return &database.Error{OrigErr: err, Query: []byte(query)} 173 } 174 } 175 176 return nil 177 } 178 179 func (m *Sqlite) Lock() error { 180 if m.isLocked { 181 return database.ErrLocked 182 } 183 m.isLocked = true 184 return nil 185 } 186 187 func (m *Sqlite) Unlock() error { 188 if !m.isLocked { 189 return nil 190 } 191 m.isLocked = false 192 return nil 193 } 194 195 func (m *Sqlite) Run(migration io.Reader) error { 196 migr, err := ioutil.ReadAll(migration) 197 if err != nil { 198 return err 199 } 200 query := string(migr[:]) 201 202 if m.config.NoTxWrap { 203 return m.executeQueryNoTx(query) 204 } 205 return m.executeQuery(query) 206 } 207 208 func (m *Sqlite) executeQuery(query string) error { 209 tx, err := m.db.Begin() 210 if err != nil { 211 return &database.Error{OrigErr: err, Err: "transaction start failed"} 212 } 213 if _, err := tx.Exec(query); err != nil { 214 if errRollback := tx.Rollback(); errRollback != nil { 215 err = multierror.Append(err, errRollback) 216 } 217 return &database.Error{OrigErr: err, Query: []byte(query)} 218 } 219 if err := tx.Commit(); err != nil { 220 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 221 } 222 return nil 223 } 224 225 func (m *Sqlite) executeQueryNoTx(query string) error { 226 if _, err := m.db.Exec(query); err != nil { 227 return &database.Error{OrigErr: err, Query: []byte(query)} 228 } 229 return nil 230 } 231 232 func (m *Sqlite) SetVersion(version int, dirty bool) error { 233 tx, err := m.db.Begin() 234 if err != nil { 235 return &database.Error{OrigErr: err, Err: "transaction start failed"} 236 } 237 238 query := "DELETE FROM " + m.config.MigrationsTable 239 if _, err := tx.Exec(query); err != nil { 240 return &database.Error{OrigErr: err, Query: []byte(query)} 241 } 242 243 // Also re-write the schema version for nil dirty versions to prevent 244 // empty schema version for failed down migration on the first migration 245 // See: https://github.com/golang-migrate/migrate/issues/330 246 if version >= 0 || (version == database.NilVersion && dirty) { 247 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable) 248 if _, err := tx.Exec(query, version, dirty); err != nil { 249 if errRollback := tx.Rollback(); errRollback != nil { 250 err = multierror.Append(err, errRollback) 251 } 252 return &database.Error{OrigErr: err, Query: []byte(query)} 253 } 254 } 255 256 if err := tx.Commit(); err != nil { 257 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 258 } 259 260 return nil 261 } 262 263 func (m *Sqlite) Version() (version int, dirty bool, err error) { 264 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 265 err = m.db.QueryRow(query).Scan(&version, &dirty) 266 if err != nil { 267 return database.NilVersion, false, nil 268 } 269 return version, dirty, nil 270 }