github.com/mrqzzz/migrate@v5.1.7+incompatible/database/sqlite3/sqlite3.go (about) 1 package sqlite3 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "io/ioutil" 8 nurl "net/url" 9 "strings" 10 11 "github.com/golang-migrate/migrate/v4" 12 "github.com/golang-migrate/migrate/v4/database" 13 _ "github.com/mattn/go-sqlite3" 14 ) 15 16 func init() { 17 database.Register("sqlite3", &Sqlite{}) 18 } 19 20 var DefaultMigrationsTable = "schema_migrations" 21 var ( 22 ErrDatabaseDirty = fmt.Errorf("database is dirty") 23 ErrNilConfig = fmt.Errorf("no config") 24 ErrNoDatabaseName = fmt.Errorf("no database name") 25 ) 26 27 type Config struct { 28 MigrationsTable string 29 DatabaseName string 30 } 31 32 type Sqlite struct { 33 db *sql.DB 34 isLocked bool 35 36 config *Config 37 } 38 39 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 40 if config == nil { 41 return nil, ErrNilConfig 42 } 43 44 if err := instance.Ping(); err != nil { 45 return nil, err 46 } 47 if len(config.MigrationsTable) == 0 { 48 config.MigrationsTable = DefaultMigrationsTable 49 } 50 51 mx := &Sqlite{ 52 db: instance, 53 config: config, 54 } 55 if err := mx.ensureVersionTable(); err != nil { 56 return nil, err 57 } 58 return mx, nil 59 } 60 61 func (m *Sqlite) ensureVersionTable() error { 62 63 query := fmt.Sprintf(` 64 CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); 65 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 66 `, m.config.MigrationsTable, m.config.MigrationsTable) 67 68 if _, err := m.db.Exec(query); err != nil { 69 return err 70 } 71 return nil 72 } 73 74 func (m *Sqlite) Open(url string) (database.Driver, error) { 75 purl, err := nurl.Parse(url) 76 if err != nil { 77 return nil, err 78 } 79 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1) 80 db, err := sql.Open("sqlite3", dbfile) 81 if err != nil { 82 return nil, err 83 } 84 85 migrationsTable := purl.Query().Get("x-migrations-table") 86 if len(migrationsTable) == 0 { 87 migrationsTable = DefaultMigrationsTable 88 } 89 mx, err := WithInstance(db, &Config{ 90 DatabaseName: purl.Path, 91 MigrationsTable: migrationsTable, 92 }) 93 if err != nil { 94 return nil, err 95 } 96 return mx, nil 97 } 98 99 func (m *Sqlite) Close() error { 100 return m.db.Close() 101 } 102 103 func (m *Sqlite) Drop() error { 104 query := `SELECT name FROM sqlite_master WHERE type = 'table';` 105 tables, err := m.db.Query(query) 106 if err != nil { 107 return &database.Error{OrigErr: err, Query: []byte(query)} 108 } 109 defer tables.Close() 110 tableNames := make([]string, 0) 111 for tables.Next() { 112 var tableName string 113 if err := tables.Scan(&tableName); err != nil { 114 return err 115 } 116 if len(tableName) > 0 { 117 tableNames = append(tableNames, tableName) 118 } 119 } 120 if len(tableNames) > 0 { 121 for _, t := range tableNames { 122 query := "DROP TABLE " + t 123 err = m.executeQuery(query) 124 if err != nil { 125 return &database.Error{OrigErr: err, Query: []byte(query)} 126 } 127 } 128 if err := m.ensureVersionTable(); err != nil { 129 return err 130 } 131 query := "VACUUM" 132 _, err = m.db.Query(query) 133 if err != nil { 134 return &database.Error{OrigErr: err, Query: []byte(query)} 135 } 136 } 137 138 return nil 139 } 140 141 func (m *Sqlite) Lock() error { 142 if m.isLocked { 143 return database.ErrLocked 144 } 145 m.isLocked = true 146 return nil 147 } 148 149 func (m *Sqlite) Unlock() error { 150 if !m.isLocked { 151 return nil 152 } 153 m.isLocked = false 154 return nil 155 } 156 157 func (m *Sqlite) Run(migration io.Reader) error { 158 migr, err := ioutil.ReadAll(migration) 159 if err != nil { 160 return err 161 } 162 query := string(migr[:]) 163 164 return m.executeQuery(query) 165 } 166 167 func (m *Sqlite) executeQuery(query string) error { 168 tx, err := m.db.Begin() 169 if err != nil { 170 return &database.Error{OrigErr: err, Err: "transaction start failed"} 171 } 172 if _, err := tx.Exec(query); err != nil { 173 tx.Rollback() 174 return &database.Error{OrigErr: err, Query: []byte(query)} 175 } 176 if err := tx.Commit(); err != nil { 177 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 178 } 179 return nil 180 } 181 182 func (m *Sqlite) SetVersion(version int, dirty bool) error { 183 tx, err := m.db.Begin() 184 if err != nil { 185 return &database.Error{OrigErr: err, Err: "transaction start failed"} 186 } 187 188 query := "DELETE FROM " + m.config.MigrationsTable 189 if _, err := tx.Exec(query); err != nil { 190 return &database.Error{OrigErr: err, Query: []byte(query)} 191 } 192 193 if version >= 0 { 194 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, '%t')`, m.config.MigrationsTable, version, dirty) 195 if _, err := tx.Exec(query); err != nil { 196 tx.Rollback() 197 return &database.Error{OrigErr: err, Query: []byte(query)} 198 } 199 } 200 201 if err := tx.Commit(); err != nil { 202 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 203 } 204 205 return nil 206 } 207 208 func (m *Sqlite) Version() (version int, dirty bool, err error) { 209 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 210 err = m.db.QueryRow(query).Scan(&version, &dirty) 211 if err != nil { 212 return database.NilVersion, false, nil 213 } 214 return version, dirty, nil 215 }