github.com/fr-nvriep/migrate/v4@v4.3.2/database/sqlite3/sqlite3.go (about) 1 package sqlite3 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strings" 10 11 "github.com/fr-nvriep/migrate/v4" 12 "github.com/fr-nvriep/migrate/v4/database" 13 "github.com/hashicorp/go-multierror" 14 _ "github.com/mattn/go-sqlite3" 15 ) 16 17 func init() { 18 database.Register("sqlite3", &Sqlite{}) 19 } 20 21 var DefaultMigrationsTable = "schema_migrations" 22 var ( 23 ErrDatabaseDirty = fmt.Errorf("database is dirty") 24 ErrNilConfig = fmt.Errorf("no config") 25 ErrNoDatabaseName = fmt.Errorf("no database name") 26 ) 27 28 type Config struct { 29 MigrationsTable string 30 DatabaseName string 31 } 32 33 type Sqlite struct { 34 db *sql.DB 35 isLocked bool 36 37 config *Config 38 } 39 40 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 41 if config == nil { 42 return nil, ErrNilConfig 43 } 44 45 if err := instance.Ping(); err != nil { 46 return nil, err 47 } 48 49 if len(config.MigrationsTable) == 0 { 50 config.MigrationsTable = DefaultMigrationsTable 51 } 52 53 mx := &Sqlite{ 54 db: instance, 55 config: config, 56 } 57 if err := mx.ensureVersionTable(); err != nil { 58 return nil, err 59 } 60 return mx, nil 61 } 62 63 // ensureVersionTable checks if versions table exists and, if not, creates it. 64 // Note that this function locks the database, which deviates from the usual 65 // convention of "caller locks" in the Sqlite type. 66 func (m *Sqlite) ensureVersionTable() (err error) { 67 if err = m.Lock(); err != nil { 68 return err 69 } 70 71 defer func() { 72 if e := m.Unlock(); e != nil { 73 if err == nil { 74 err = e 75 } else { 76 err = multierror.Append(err, e) 77 } 78 } 79 }() 80 81 query := fmt.Sprintf(` 82 CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); 83 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 84 `, m.config.MigrationsTable, m.config.MigrationsTable) 85 86 if _, err := m.db.Exec(query); err != nil { 87 return err 88 } 89 return nil 90 } 91 92 func (m *Sqlite) Open(url string) (database.Driver, error) { 93 purl, err := nurl.Parse(url) 94 if err != nil { 95 return nil, err 96 } 97 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1) 98 db, err := sql.Open("sqlite3", dbfile) 99 if err != nil { 100 return nil, err 101 } 102 103 migrationsTable := purl.Query().Get("x-migrations-table") 104 if len(migrationsTable) == 0 { 105 migrationsTable = DefaultMigrationsTable 106 } 107 mx, err := WithInstance(db, &Config{ 108 DatabaseName: purl.Path, 109 MigrationsTable: migrationsTable, 110 }) 111 if err != nil { 112 return nil, err 113 } 114 return mx, nil 115 } 116 117 func (m *Sqlite) Close() error { 118 return m.db.Close() 119 } 120 121 func (m *Sqlite) Drop() (err error) { 122 query := `SELECT name FROM sqlite_master WHERE type = 'table';` 123 tables, err := m.db.Query(query) 124 if err != nil { 125 return &database.Error{OrigErr: err, Query: []byte(query)} 126 } 127 defer func() { 128 if errClose := tables.Close(); errClose != nil { 129 err = multierror.Append(err, errClose) 130 } 131 }() 132 tableNames := make([]string, 0) 133 for tables.Next() { 134 var tableName string 135 if err := tables.Scan(&tableName); err != nil { 136 return err 137 } 138 if len(tableName) > 0 { 139 tableNames = append(tableNames, tableName) 140 } 141 } 142 if len(tableNames) > 0 { 143 for _, t := range tableNames { 144 query := "DROP TABLE " + t 145 err = m.executeQuery(query) 146 if err != nil { 147 return &database.Error{OrigErr: err, Query: []byte(query)} 148 } 149 } 150 query := "VACUUM" 151 _, err = m.db.Query(query) 152 if err != nil { 153 return &database.Error{OrigErr: err, Query: []byte(query)} 154 } 155 } 156 157 return nil 158 } 159 160 func (m *Sqlite) Lock() error { 161 if m.isLocked { 162 return database.ErrLocked 163 } 164 m.isLocked = true 165 return nil 166 } 167 168 func (m *Sqlite) Unlock() error { 169 if !m.isLocked { 170 return nil 171 } 172 m.isLocked = false 173 return nil 174 } 175 176 func (m *Sqlite) Run(migration io.Reader) error { 177 migr, err := ioutil.ReadAll(migration) 178 if err != nil { 179 return err 180 } 181 query := string(migr[:]) 182 183 return m.executeQuery(query) 184 } 185 186 func (m *Sqlite) executeQuery(query string) error { 187 tx, err := m.db.Begin() 188 if err != nil { 189 return &database.Error{OrigErr: err, Err: "transaction start failed"} 190 } 191 if _, err := tx.Exec(query); err != nil { 192 if errRollback := tx.Rollback(); errRollback != nil { 193 err = multierror.Append(err, errRollback) 194 } 195 return &database.Error{OrigErr: err, Query: []byte(query)} 196 } 197 if err := tx.Commit(); err != nil { 198 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 199 } 200 return nil 201 } 202 203 func (m *Sqlite) SetVersion(version int, dirty bool) error { 204 tx, err := m.db.Begin() 205 if err != nil { 206 return &database.Error{OrigErr: err, Err: "transaction start failed"} 207 } 208 209 query := "DELETE FROM " + m.config.MigrationsTable 210 if _, err := tx.Exec(query); err != nil { 211 return &database.Error{OrigErr: err, Query: []byte(query)} 212 } 213 214 if version >= 0 { 215 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, '%t')`, m.config.MigrationsTable, version, dirty) 216 if _, err := tx.Exec(query); err != nil { 217 if errRollback := tx.Rollback(); errRollback != nil { 218 err = multierror.Append(err, errRollback) 219 } 220 return &database.Error{OrigErr: err, Query: []byte(query)} 221 } 222 } 223 224 if err := tx.Commit(); err != nil { 225 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 226 } 227 228 return nil 229 } 230 231 func (m *Sqlite) Version() (version int, dirty bool, err error) { 232 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 233 err = m.db.QueryRow(query).Scan(&version, &dirty) 234 if err != nil { 235 return database.NilVersion, false, nil 236 } 237 return version, dirty, nil 238 }