github.com/nokia/migrate/v4@v4.16.0/database/ql/ql.go (about) 1 package ql 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "strings" 9 10 "github.com/hashicorp/go-multierror" 11 "go.uber.org/atomic" 12 13 nurl "net/url" 14 15 "github.com/nokia/migrate/v4" 16 "github.com/nokia/migrate/v4/database" 17 "github.com/nokia/migrate/v4/source" 18 _ "modernc.org/ql/driver" 19 ) 20 21 func init() { 22 database.Register("ql", &Ql{}) 23 } 24 25 var DefaultMigrationsTable = "schema_migrations" 26 var ( 27 ErrDatabaseDirty = fmt.Errorf("database is dirty") 28 ErrNilConfig = fmt.Errorf("no config") 29 ErrNoDatabaseName = fmt.Errorf("no database name") 30 ErrAppendPEM = fmt.Errorf("failed to append PEM") 31 ) 32 33 type Config struct { 34 MigrationsTable string 35 DatabaseName string 36 } 37 38 type Ql struct { 39 db *sql.DB 40 isLocked atomic.Bool 41 42 config *Config 43 } 44 45 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 46 if config == nil { 47 return nil, ErrNilConfig 48 } 49 50 if err := instance.Ping(); err != nil { 51 return nil, err 52 } 53 54 if len(config.MigrationsTable) == 0 { 55 config.MigrationsTable = DefaultMigrationsTable 56 } 57 58 mx := &Ql{ 59 db: instance, 60 config: config, 61 } 62 if err := mx.ensureVersionTable(); err != nil { 63 return nil, err 64 } 65 return mx, nil 66 } 67 68 // ensureVersionTable checks if versions table exists and, if not, creates it. 69 // Note that this function locks the database, which deviates from the usual 70 // convention of "caller locks" in the Ql type. 71 func (m *Ql) ensureVersionTable() (err error) { 72 if err = m.Lock(); err != nil { 73 return err 74 } 75 76 defer func() { 77 if e := m.Unlock(); e != nil { 78 if err == nil { 79 err = e 80 } else { 81 err = multierror.Append(err, e) 82 } 83 } 84 }() 85 86 tx, err := m.db.Begin() 87 if err != nil { 88 return err 89 } 90 if _, err := tx.Exec(fmt.Sprintf(` 91 CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); 92 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 93 `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil { 94 if err := tx.Rollback(); err != nil { 95 return err 96 } 97 return err 98 } 99 if err := tx.Commit(); err != nil { 100 return err 101 } 102 return nil 103 } 104 105 func (m *Ql) Open(url string) (database.Driver, error) { 106 purl, err := nurl.Parse(url) 107 if err != nil { 108 return nil, err 109 } 110 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1) 111 db, err := sql.Open("ql", dbfile) 112 if err != nil { 113 return nil, err 114 } 115 migrationsTable := purl.Query().Get("x-migrations-table") 116 if len(migrationsTable) == 0 { 117 migrationsTable = DefaultMigrationsTable 118 } 119 mx, err := WithInstance(db, &Config{ 120 DatabaseName: purl.Path, 121 MigrationsTable: migrationsTable, 122 }) 123 if err != nil { 124 return nil, err 125 } 126 return mx, nil 127 } 128 129 func (m *Ql) Close() error { 130 return m.db.Close() 131 } 132 133 func (m *Ql) Drop() (err error) { 134 query := `SELECT Name FROM __Table` 135 tables, err := m.db.Query(query) 136 if err != nil { 137 return &database.Error{OrigErr: err, Query: []byte(query)} 138 } 139 defer func() { 140 if errClose := tables.Close(); errClose != nil { 141 err = multierror.Append(err, errClose) 142 } 143 }() 144 145 tableNames := make([]string, 0) 146 for tables.Next() { 147 var tableName string 148 if err := tables.Scan(&tableName); err != nil { 149 return err 150 } 151 if len(tableName) > 0 { 152 if !strings.HasPrefix(tableName, "__") { 153 tableNames = append(tableNames, tableName) 154 } 155 } 156 } 157 if err := tables.Err(); err != nil { 158 return &database.Error{OrigErr: err, Query: []byte(query)} 159 } 160 161 if len(tableNames) > 0 { 162 for _, t := range tableNames { 163 query := "DROP TABLE " + t 164 err = m.executeQuery(query) 165 if err != nil { 166 return &database.Error{OrigErr: err, Query: []byte(query)} 167 } 168 } 169 } 170 171 return nil 172 } 173 174 func (m *Ql) Lock() error { 175 if !m.isLocked.CAS(false, true) { 176 return database.ErrLocked 177 } 178 return nil 179 } 180 181 func (m *Ql) Unlock() error { 182 if !m.isLocked.CAS(true, false) { 183 return database.ErrNotLocked 184 } 185 return nil 186 } 187 188 func (m *Ql) Run(migration io.Reader) error { 189 migr, err := ioutil.ReadAll(migration) 190 if err != nil { 191 return err 192 } 193 query := string(migr[:]) 194 195 return m.executeQuery(query) 196 } 197 198 func (m *Ql) RunFunctionMigration(fn source.MigrationFunc) error { 199 return database.ErrNotImpl 200 } 201 202 func (m *Ql) executeQuery(query string) error { 203 tx, err := m.db.Begin() 204 if err != nil { 205 return &database.Error{OrigErr: err, Err: "transaction start failed"} 206 } 207 if _, err := tx.Exec(query); err != nil { 208 if errRollback := tx.Rollback(); errRollback != nil { 209 err = multierror.Append(err, errRollback) 210 } 211 return &database.Error{OrigErr: err, Query: []byte(query)} 212 } 213 if err := tx.Commit(); err != nil { 214 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 215 } 216 return nil 217 } 218 219 func (m *Ql) SetVersion(version int, dirty bool) error { 220 tx, err := m.db.Begin() 221 if err != nil { 222 return &database.Error{OrigErr: err, Err: "transaction start failed"} 223 } 224 225 query := "TRUNCATE TABLE " + m.config.MigrationsTable 226 if _, err := tx.Exec(query); err != nil { 227 return &database.Error{OrigErr: err, Query: []byte(query)} 228 } 229 230 // Also re-write the schema version for nil dirty versions to prevent 231 // empty schema version for failed down migration on the first migration 232 // See: https://github.com/nokia/migrate/issues/330 233 if version >= 0 || (version == database.NilVersion && dirty) { 234 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`, 235 m.config.MigrationsTable) 236 if _, err := tx.Exec(query, version, dirty); err != nil { 237 if errRollback := tx.Rollback(); errRollback != nil { 238 err = multierror.Append(err, errRollback) 239 } 240 return &database.Error{OrigErr: err, Query: []byte(query)} 241 } 242 } 243 244 if err := tx.Commit(); err != nil { 245 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 246 } 247 248 return nil 249 } 250 251 func (m *Ql) Version() (version int, dirty bool, err error) { 252 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 253 err = m.db.QueryRow(query).Scan(&version, &dirty) 254 if err != nil { 255 return database.NilVersion, false, nil 256 } 257 return version, dirty, nil 258 }