github.com/nagyist/migrate/v4@v4.14.6/database/ql/ql.go (about) 1 package ql 2 3 import ( 4 "database/sql" 5 "fmt" 6 "github.com/hashicorp/go-multierror" 7 "io" 8 "io/ioutil" 9 "strings" 10 11 nurl "net/url" 12 13 "github.com/golang-migrate/migrate/v4" 14 "github.com/golang-migrate/migrate/v4/database" 15 _ "modernc.org/ql/driver" 16 ) 17 18 func init() { 19 database.Register("ql", &Ql{}) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 var ( 24 ErrDatabaseDirty = fmt.Errorf("database is dirty") 25 ErrNilConfig = fmt.Errorf("no config") 26 ErrNoDatabaseName = fmt.Errorf("no database name") 27 ErrAppendPEM = fmt.Errorf("failed to append PEM") 28 ) 29 30 type Config struct { 31 MigrationsTable string 32 DatabaseName string 33 } 34 35 type Ql struct { 36 db *sql.DB 37 isLocked bool 38 39 config *Config 40 } 41 42 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 43 if config == nil { 44 return nil, ErrNilConfig 45 } 46 47 if err := instance.Ping(); err != nil { 48 return nil, err 49 } 50 51 if len(config.MigrationsTable) == 0 { 52 config.MigrationsTable = DefaultMigrationsTable 53 } 54 55 mx := &Ql{ 56 db: instance, 57 config: config, 58 } 59 if err := mx.ensureVersionTable(); err != nil { 60 return nil, err 61 } 62 return mx, nil 63 } 64 65 // ensureVersionTable checks if versions table exists and, if not, creates it. 66 // Note that this function locks the database, which deviates from the usual 67 // convention of "caller locks" in the Ql type. 68 func (m *Ql) ensureVersionTable() (err error) { 69 if err = m.Lock(); err != nil { 70 return err 71 } 72 73 defer func() { 74 if e := m.Unlock(); e != nil { 75 if err == nil { 76 err = e 77 } else { 78 err = multierror.Append(err, e) 79 } 80 } 81 }() 82 83 tx, err := m.db.Begin() 84 if err != nil { 85 return err 86 } 87 if _, err := tx.Exec(fmt.Sprintf(` 88 CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); 89 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 90 `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil { 91 if err := tx.Rollback(); err != nil { 92 return err 93 } 94 return err 95 } 96 if err := tx.Commit(); err != nil { 97 return err 98 } 99 return nil 100 } 101 102 func (m *Ql) Open(url string) (database.Driver, error) { 103 purl, err := nurl.Parse(url) 104 if err != nil { 105 return nil, err 106 } 107 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1) 108 db, err := sql.Open("ql", dbfile) 109 if err != nil { 110 return nil, err 111 } 112 migrationsTable := purl.Query().Get("x-migrations-table") 113 if len(migrationsTable) == 0 { 114 migrationsTable = DefaultMigrationsTable 115 } 116 mx, err := WithInstance(db, &Config{ 117 DatabaseName: purl.Path, 118 MigrationsTable: migrationsTable, 119 }) 120 if err != nil { 121 return nil, err 122 } 123 return mx, nil 124 } 125 func (m *Ql) Close() error { 126 return m.db.Close() 127 } 128 func (m *Ql) Drop() (err error) { 129 query := `SELECT Name FROM __Table` 130 tables, err := m.db.Query(query) 131 if err != nil { 132 return &database.Error{OrigErr: err, Query: []byte(query)} 133 } 134 defer func() { 135 if errClose := tables.Close(); errClose != nil { 136 err = multierror.Append(err, errClose) 137 } 138 }() 139 140 tableNames := make([]string, 0) 141 for tables.Next() { 142 var tableName string 143 if err := tables.Scan(&tableName); err != nil { 144 return err 145 } 146 if len(tableName) > 0 { 147 if !strings.HasPrefix(tableName, "__") { 148 tableNames = append(tableNames, tableName) 149 } 150 } 151 } 152 if err := tables.Err(); err != nil { 153 return &database.Error{OrigErr: err, Query: []byte(query)} 154 } 155 156 if len(tableNames) > 0 { 157 for _, t := range tableNames { 158 query := "DROP TABLE " + t 159 err = m.executeQuery(query) 160 if err != nil { 161 return &database.Error{OrigErr: err, Query: []byte(query)} 162 } 163 } 164 } 165 166 return nil 167 } 168 func (m *Ql) Lock() error { 169 if m.isLocked { 170 return database.ErrLocked 171 } 172 m.isLocked = true 173 return nil 174 } 175 func (m *Ql) Unlock() error { 176 if !m.isLocked { 177 return nil 178 } 179 m.isLocked = false 180 return nil 181 } 182 func (m *Ql) Run(migration io.Reader) error { 183 migr, err := ioutil.ReadAll(migration) 184 if err != nil { 185 return err 186 } 187 query := string(migr[:]) 188 189 return m.executeQuery(query) 190 } 191 func (m *Ql) executeQuery(query string) error { 192 tx, err := m.db.Begin() 193 if err != nil { 194 return &database.Error{OrigErr: err, Err: "transaction start failed"} 195 } 196 if _, err := tx.Exec(query); err != nil { 197 if errRollback := tx.Rollback(); errRollback != nil { 198 err = multierror.Append(err, errRollback) 199 } 200 return &database.Error{OrigErr: err, Query: []byte(query)} 201 } 202 if err := tx.Commit(); err != nil { 203 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 204 } 205 return nil 206 } 207 func (m *Ql) SetVersion(version int, dirty bool) error { 208 tx, err := m.db.Begin() 209 if err != nil { 210 return &database.Error{OrigErr: err, Err: "transaction start failed"} 211 } 212 213 query := "TRUNCATE TABLE " + m.config.MigrationsTable 214 if _, err := tx.Exec(query); err != nil { 215 return &database.Error{OrigErr: err, Query: []byte(query)} 216 } 217 218 // Also re-write the schema version for nil dirty versions to prevent 219 // empty schema version for failed down migration on the first migration 220 // See: https://github.com/golang-migrate/migrate/issues/330 221 if version >= 0 || (version == database.NilVersion && dirty) { 222 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`, 223 m.config.MigrationsTable) 224 if _, err := tx.Exec(query, version, dirty); err != nil { 225 if errRollback := tx.Rollback(); errRollback != nil { 226 err = multierror.Append(err, errRollback) 227 } 228 return &database.Error{OrigErr: err, Query: []byte(query)} 229 } 230 } 231 232 if err := tx.Commit(); err != nil { 233 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 234 } 235 236 return nil 237 } 238 239 func (m *Ql) Version() (version int, dirty bool, err error) { 240 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 241 err = m.db.QueryRow(query).Scan(&version, &dirty) 242 if err != nil { 243 return database.NilVersion, false, nil 244 } 245 return version, dirty, nil 246 }