github.com/Elate-DevOps/migrate/v4@v4.0.12/database/ql/ql.go (about) 1 package ql 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 nurl "net/url" 8 "strings" 9 10 "github.com/hashicorp/go-multierror" 11 "go.uber.org/atomic" 12 13 "github.com/Elate-DevOps/migrate/v4" 14 "github.com/Elate-DevOps/migrate/v4/database" 15 _ "modernc.org/ql/driver" 16 ) 17 18 func init() { 19 database.Register("ql", &Ql{}) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 var ( 24 ErrDatabaseDirty = fmt.Errorf("database is dirty") 25 ErrNilConfig = fmt.Errorf("no config") 26 ErrNoDatabaseName = fmt.Errorf("no database name") 27 ErrAppendPEM = fmt.Errorf("failed to append PEM") 28 ) 29 30 type Config struct { 31 MigrationsTable string 32 DatabaseName string 33 } 34 35 type Ql struct { 36 db *sql.DB 37 isLocked atomic.Bool 38 39 config *Config 40 } 41 42 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 43 if config == nil { 44 return nil, ErrNilConfig 45 } 46 47 if err := instance.Ping(); err != nil { 48 return nil, err 49 } 50 51 if len(config.MigrationsTable) == 0 { 52 config.MigrationsTable = DefaultMigrationsTable 53 } 54 55 mx := &Ql{ 56 db: instance, 57 config: config, 58 } 59 if err := mx.ensureVersionTable(); err != nil { 60 return nil, err 61 } 62 return mx, nil 63 } 64 65 // ensureVersionTable checks if versions table exists and, if not, creates it. 66 // Note that this function locks the database, which deviates from the usual 67 // convention of "caller locks" in the Ql type. 68 func (m *Ql) ensureVersionTable() (err error) { 69 if err = m.Lock(); err != nil { 70 return err 71 } 72 73 defer func() { 74 if e := m.Unlock(); e != nil { 75 if err == nil { 76 err = e 77 } else { 78 err = multierror.Append(err, e) 79 } 80 } 81 }() 82 83 tx, err := m.db.Begin() 84 if err != nil { 85 return err 86 } 87 if _, err := tx.Exec(fmt.Sprintf(` 88 CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); 89 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 90 `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil { 91 if err := tx.Rollback(); err != nil { 92 return err 93 } 94 return err 95 } 96 if err := tx.Commit(); err != nil { 97 return err 98 } 99 return nil 100 } 101 102 func (m *Ql) Open(url string) (database.Driver, error) { 103 purl, err := nurl.Parse(url) 104 if err != nil { 105 return nil, err 106 } 107 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1) 108 db, err := sql.Open("ql", dbfile) 109 if err != nil { 110 return nil, err 111 } 112 migrationsTable := purl.Query().Get("x-migrations-table") 113 if len(migrationsTable) == 0 { 114 migrationsTable = DefaultMigrationsTable 115 } 116 mx, err := WithInstance(db, &Config{ 117 DatabaseName: purl.Path, 118 MigrationsTable: migrationsTable, 119 }) 120 if err != nil { 121 return nil, err 122 } 123 return mx, nil 124 } 125 126 func (m *Ql) Close() error { 127 return m.db.Close() 128 } 129 130 func (m *Ql) Drop() (err error) { 131 query := `SELECT Name FROM __Table` 132 tables, err := m.db.Query(query) 133 if err != nil { 134 return &database.Error{OrigErr: err, Query: []byte(query)} 135 } 136 defer func() { 137 if errClose := tables.Close(); errClose != nil { 138 err = multierror.Append(err, errClose) 139 } 140 }() 141 142 tableNames := make([]string, 0) 143 for tables.Next() { 144 var tableName string 145 if err := tables.Scan(&tableName); err != nil { 146 return err 147 } 148 if len(tableName) > 0 { 149 if !strings.HasPrefix(tableName, "__") { 150 tableNames = append(tableNames, tableName) 151 } 152 } 153 } 154 if err := tables.Err(); err != nil { 155 return &database.Error{OrigErr: err, Query: []byte(query)} 156 } 157 158 if len(tableNames) > 0 { 159 for _, t := range tableNames { 160 query := "DROP TABLE " + t 161 err = m.executeQuery(query) 162 if err != nil { 163 return &database.Error{OrigErr: err, Query: []byte(query)} 164 } 165 } 166 } 167 168 return nil 169 } 170 171 func (m *Ql) Lock() error { 172 if !m.isLocked.CAS(false, true) { 173 return database.ErrLocked 174 } 175 return nil 176 } 177 178 func (m *Ql) Unlock() error { 179 if !m.isLocked.CAS(true, false) { 180 return database.ErrNotLocked 181 } 182 return nil 183 } 184 185 func (m *Ql) Run(migration io.Reader) error { 186 migr, err := io.ReadAll(migration) 187 if err != nil { 188 return err 189 } 190 query := string(migr[:]) 191 192 return m.executeQuery(query) 193 } 194 195 func (m *Ql) executeQuery(query string) error { 196 tx, err := m.db.Begin() 197 if err != nil { 198 return &database.Error{OrigErr: err, Err: "transaction start failed"} 199 } 200 if _, err := tx.Exec(query); err != nil { 201 if errRollback := tx.Rollback(); errRollback != nil { 202 err = multierror.Append(err, errRollback) 203 } 204 return &database.Error{OrigErr: err, Query: []byte(query)} 205 } 206 if err := tx.Commit(); err != nil { 207 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 208 } 209 return nil 210 } 211 212 func (m *Ql) SetVersion(version int, dirty bool) error { 213 tx, err := m.db.Begin() 214 if err != nil { 215 return &database.Error{OrigErr: err, Err: "transaction start failed"} 216 } 217 218 query := "TRUNCATE TABLE " + m.config.MigrationsTable 219 if _, err := tx.Exec(query); err != nil { 220 return &database.Error{OrigErr: err, Query: []byte(query)} 221 } 222 223 // Also re-write the schema version for nil dirty versions to prevent 224 // empty schema version for failed down migration on the first migration 225 // See: https://github.com/golang-migrate/migrate/issues/330 226 if version >= 0 || (version == database.NilVersion && dirty) { 227 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`, 228 m.config.MigrationsTable) 229 if _, err := tx.Exec(query, version, dirty); err != nil { 230 if errRollback := tx.Rollback(); errRollback != nil { 231 err = multierror.Append(err, errRollback) 232 } 233 return &database.Error{OrigErr: err, Query: []byte(query)} 234 } 235 } 236 237 if err := tx.Commit(); err != nil { 238 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 239 } 240 241 return nil 242 } 243 244 func (m *Ql) Version() (version int, dirty bool, err error) { 245 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 246 err = m.db.QueryRow(query).Scan(&version, &dirty) 247 if err != nil { 248 return database.NilVersion, false, nil 249 } 250 return version, dirty, nil 251 }