github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/ql/ql.go (about) 1 package ql 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "strings" 9 10 "github.com/hashicorp/go-multierror" 11 "go.uber.org/atomic" 12 13 nurl "net/url" 14 15 "github.com/getsynq/migrate/v4" 16 "github.com/getsynq/migrate/v4/database" 17 _ "modernc.org/ql/driver" 18 ) 19 20 func init() { 21 database.Register("ql", &Ql{}) 22 } 23 24 var DefaultMigrationsTable = "schema_migrations" 25 var ( 26 ErrDatabaseDirty = fmt.Errorf("database is dirty") 27 ErrNilConfig = fmt.Errorf("no config") 28 ErrNoDatabaseName = fmt.Errorf("no database name") 29 ErrAppendPEM = fmt.Errorf("failed to append PEM") 30 ) 31 32 type Config struct { 33 MigrationsTable string 34 DatabaseName string 35 } 36 37 type Ql struct { 38 db *sql.DB 39 isLocked atomic.Bool 40 41 config *Config 42 } 43 44 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 45 if config == nil { 46 return nil, ErrNilConfig 47 } 48 49 if err := instance.Ping(); err != nil { 50 return nil, err 51 } 52 53 if len(config.MigrationsTable) == 0 { 54 config.MigrationsTable = DefaultMigrationsTable 55 } 56 57 mx := &Ql{ 58 db: instance, 59 config: config, 60 } 61 if err := mx.ensureVersionTable(); err != nil { 62 return nil, err 63 } 64 return mx, nil 65 } 66 67 // ensureVersionTable checks if versions table exists and, if not, creates it. 68 // Note that this function locks the database, which deviates from the usual 69 // convention of "caller locks" in the Ql type. 70 func (m *Ql) ensureVersionTable() (err error) { 71 if err = m.Lock(); err != nil { 72 return err 73 } 74 75 defer func() { 76 if e := m.Unlock(); e != nil { 77 if err == nil { 78 err = e 79 } else { 80 err = multierror.Append(err, e) 81 } 82 } 83 }() 84 85 tx, err := m.db.Begin() 86 if err != nil { 87 return err 88 } 89 if _, err := tx.Exec(fmt.Sprintf(` 90 CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); 91 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 92 `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil { 93 if err := tx.Rollback(); err != nil { 94 return err 95 } 96 return err 97 } 98 if err := tx.Commit(); err != nil { 99 return err 100 } 101 return nil 102 } 103 104 func (m *Ql) Open(url string) (database.Driver, error) { 105 purl, err := nurl.Parse(url) 106 if err != nil { 107 return nil, err 108 } 109 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1) 110 db, err := sql.Open("ql", dbfile) 111 if err != nil { 112 return nil, err 113 } 114 migrationsTable := purl.Query().Get("x-migrations-table") 115 if len(migrationsTable) == 0 { 116 migrationsTable = DefaultMigrationsTable 117 } 118 mx, err := WithInstance(db, &Config{ 119 DatabaseName: purl.Path, 120 MigrationsTable: migrationsTable, 121 }) 122 if err != nil { 123 return nil, err 124 } 125 return mx, nil 126 } 127 func (m *Ql) Close() error { 128 return m.db.Close() 129 } 130 func (m *Ql) Drop() (err error) { 131 query := `SELECT Name FROM __Table` 132 tables, err := m.db.Query(query) 133 if err != nil { 134 return &database.Error{OrigErr: err, Query: []byte(query)} 135 } 136 defer func() { 137 if errClose := tables.Close(); errClose != nil { 138 err = multierror.Append(err, errClose) 139 } 140 }() 141 142 tableNames := make([]string, 0) 143 for tables.Next() { 144 var tableName string 145 if err := tables.Scan(&tableName); err != nil { 146 return err 147 } 148 if len(tableName) > 0 { 149 if !strings.HasPrefix(tableName, "__") { 150 tableNames = append(tableNames, tableName) 151 } 152 } 153 } 154 if err := tables.Err(); err != nil { 155 return &database.Error{OrigErr: err, Query: []byte(query)} 156 } 157 158 if len(tableNames) > 0 { 159 for _, t := range tableNames { 160 query := "DROP TABLE " + t 161 err = m.executeQuery(query) 162 if err != nil { 163 return &database.Error{OrigErr: err, Query: []byte(query)} 164 } 165 } 166 } 167 168 return nil 169 } 170 func (m *Ql) Lock() error { 171 if !m.isLocked.CAS(false, true) { 172 return database.ErrLocked 173 } 174 return nil 175 } 176 func (m *Ql) Unlock() error { 177 if !m.isLocked.CAS(true, false) { 178 return database.ErrNotLocked 179 } 180 return nil 181 } 182 func (m *Ql) Run(migration io.Reader) error { 183 migr, err := ioutil.ReadAll(migration) 184 if err != nil { 185 return err 186 } 187 query := string(migr[:]) 188 189 return m.executeQuery(query) 190 } 191 func (m *Ql) executeQuery(query string) error { 192 tx, err := m.db.Begin() 193 if err != nil { 194 return &database.Error{OrigErr: err, Err: "transaction start failed"} 195 } 196 if _, err := tx.Exec(query); err != nil { 197 if errRollback := tx.Rollback(); errRollback != nil { 198 err = multierror.Append(err, errRollback) 199 } 200 return &database.Error{OrigErr: err, Query: []byte(query)} 201 } 202 if err := tx.Commit(); err != nil { 203 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 204 } 205 return nil 206 } 207 func (m *Ql) SetVersion(version int, dirty bool) error { 208 tx, err := m.db.Begin() 209 if err != nil { 210 return &database.Error{OrigErr: err, Err: "transaction start failed"} 211 } 212 213 query := "TRUNCATE TABLE " + m.config.MigrationsTable 214 if _, err := tx.Exec(query); err != nil { 215 return &database.Error{OrigErr: err, Query: []byte(query)} 216 } 217 218 // Also re-write the schema version for nil dirty versions to prevent 219 // empty schema version for failed down migration on the first migration 220 // See: https://github.com/getsynq/migrate/issues/330 221 if version >= 0 || (version == database.NilVersion && dirty) { 222 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`, 223 m.config.MigrationsTable) 224 if _, err := tx.Exec(query, version, dirty); err != nil { 225 if errRollback := tx.Rollback(); errRollback != nil { 226 err = multierror.Append(err, errRollback) 227 } 228 return &database.Error{OrigErr: err, Query: []byte(query)} 229 } 230 } 231 232 if err := tx.Commit(); err != nil { 233 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 234 } 235 236 return nil 237 } 238 239 func (m *Ql) Version() (version int, dirty bool, err error) { 240 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 241 err = m.db.QueryRow(query).Scan(&version, &dirty) 242 if err != nil { 243 return database.NilVersion, false, nil 244 } 245 return version, dirty, nil 246 }