github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/ql/ql.go (about) 1 package ql 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io" 7 nurl "net/url" 8 "strings" 9 10 "github.com/hashicorp/go-multierror" 11 "go.uber.org/atomic" 12 13 "github.com/golang-migrate/migrate/v4" 14 "github.com/golang-migrate/migrate/v4/database" 15 _ "modernc.org/ql/driver" 16 ) 17 18 func init() { 19 database.Register("ql", &Ql{}) 20 } 21 22 var DefaultMigrationsTable = "schema_migrations" 23 var ( 24 ErrDatabaseDirty = fmt.Errorf("database is dirty") 25 ErrNilConfig = fmt.Errorf("no config") 26 ErrNoDatabaseName = fmt.Errorf("no database name") 27 ErrAppendPEM = fmt.Errorf("failed to append PEM") 28 ) 29 30 type Config struct { 31 MigrationsTable string 32 DatabaseName string 33 } 34 35 type Ql struct { 36 db *sql.DB 37 isLocked atomic.Bool 38 39 config *Config 40 } 41 42 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 43 if config == nil { 44 return nil, ErrNilConfig 45 } 46 47 if err := instance.Ping(); err != nil { 48 return nil, err 49 } 50 51 if len(config.MigrationsTable) == 0 { 52 config.MigrationsTable = DefaultMigrationsTable 53 } 54 55 mx := &Ql{ 56 db: instance, 57 config: config, 58 } 59 if err := mx.ensureVersionTable(); err != nil { 60 return nil, err 61 } 62 return mx, nil 63 } 64 65 // ensureVersionTable checks if versions table exists and, if not, creates it. 66 // Note that this function locks the database, which deviates from the usual 67 // convention of "caller locks" in the Ql type. 68 func (m *Ql) ensureVersionTable() (err error) { 69 if err = m.Lock(); err != nil { 70 return err 71 } 72 73 defer func() { 74 if e := m.Unlock(); e != nil { 75 if err == nil { 76 err = e 77 } else { 78 err = multierror.Append(err, e) 79 } 80 } 81 }() 82 83 tx, err := m.db.Begin() 84 if err != nil { 85 return err 86 } 87 if _, err := tx.Exec(fmt.Sprintf(` 88 CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); 89 CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); 90 `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil { 91 if err := tx.Rollback(); err != nil { 92 return err 93 } 94 return err 95 } 96 if err := tx.Commit(); err != nil { 97 return err 98 } 99 return nil 100 } 101 102 func (m *Ql) Open(url string) (database.Driver, error) { 103 purl, err := nurl.Parse(url) 104 if err != nil { 105 return nil, err 106 } 107 dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1) 108 db, err := sql.Open("ql", dbfile) 109 if err != nil { 110 return nil, err 111 } 112 migrationsTable := purl.Query().Get("x-migrations-table") 113 if len(migrationsTable) == 0 { 114 migrationsTable = DefaultMigrationsTable 115 } 116 mx, err := WithInstance(db, &Config{ 117 DatabaseName: purl.Path, 118 MigrationsTable: migrationsTable, 119 }) 120 if err != nil { 121 return nil, err 122 } 123 return mx, nil 124 } 125 func (m *Ql) Close() error { 126 return m.db.Close() 127 } 128 func (m *Ql) Drop() (err error) { 129 query := `SELECT Name FROM __Table` 130 tables, err := m.db.Query(query) 131 if err != nil { 132 return &database.Error{OrigErr: err, Query: []byte(query)} 133 } 134 defer func() { 135 if errClose := tables.Close(); errClose != nil { 136 err = multierror.Append(err, errClose) 137 } 138 }() 139 140 tableNames := make([]string, 0) 141 for tables.Next() { 142 var tableName string 143 if err := tables.Scan(&tableName); err != nil { 144 return err 145 } 146 if len(tableName) > 0 { 147 if !strings.HasPrefix(tableName, "__") { 148 tableNames = append(tableNames, tableName) 149 } 150 } 151 } 152 if err := tables.Err(); err != nil { 153 return &database.Error{OrigErr: err, Query: []byte(query)} 154 } 155 156 if len(tableNames) > 0 { 157 for _, t := range tableNames { 158 query := "DROP TABLE " + t 159 err = m.executeQuery(query) 160 if err != nil { 161 return &database.Error{OrigErr: err, Query: []byte(query)} 162 } 163 } 164 } 165 166 return nil 167 } 168 func (m *Ql) Lock() error { 169 if !m.isLocked.CAS(false, true) { 170 return database.ErrLocked 171 } 172 return nil 173 } 174 func (m *Ql) Unlock() error { 175 if !m.isLocked.CAS(true, false) { 176 return database.ErrNotLocked 177 } 178 return nil 179 } 180 func (m *Ql) Run(migration io.Reader) error { 181 migr, err := io.ReadAll(migration) 182 if err != nil { 183 return err 184 } 185 query := string(migr[:]) 186 187 return m.executeQuery(query) 188 } 189 func (m *Ql) executeQuery(query string) error { 190 tx, err := m.db.Begin() 191 if err != nil { 192 return &database.Error{OrigErr: err, Err: "transaction start failed"} 193 } 194 if _, err := tx.Exec(query); err != nil { 195 if errRollback := tx.Rollback(); errRollback != nil { 196 err = multierror.Append(err, errRollback) 197 } 198 return &database.Error{OrigErr: err, Query: []byte(query)} 199 } 200 if err := tx.Commit(); err != nil { 201 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 202 } 203 return nil 204 } 205 func (m *Ql) SetVersion(version int, dirty bool) error { 206 tx, err := m.db.Begin() 207 if err != nil { 208 return &database.Error{OrigErr: err, Err: "transaction start failed"} 209 } 210 211 query := "TRUNCATE TABLE " + m.config.MigrationsTable 212 if _, err := tx.Exec(query); err != nil { 213 return &database.Error{OrigErr: err, Query: []byte(query)} 214 } 215 216 // Also re-write the schema version for nil dirty versions to prevent 217 // empty schema version for failed down migration on the first migration 218 // See: https://github.com/golang-migrate/migrate/issues/330 219 if version >= 0 || (version == database.NilVersion && dirty) { 220 query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`, 221 m.config.MigrationsTable) 222 if _, err := tx.Exec(query, version, dirty); err != nil { 223 if errRollback := tx.Rollback(); errRollback != nil { 224 err = multierror.Append(err, errRollback) 225 } 226 return &database.Error{OrigErr: err, Query: []byte(query)} 227 } 228 } 229 230 if err := tx.Commit(); err != nil { 231 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 232 } 233 234 return nil 235 } 236 237 func (m *Ql) Version() (version int, dirty bool, err error) { 238 query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1" 239 err = m.db.QueryRow(query).Scan(&version, &dirty) 240 if err != nil { 241 return database.NilVersion, false, nil 242 } 243 return version, dirty, nil 244 }