github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/firebird/firebird.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package firebird 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 14 "github.com/getsynq/migrate/v4" 15 "github.com/getsynq/migrate/v4/database" 16 "github.com/hashicorp/go-multierror" 17 _ "github.com/nakagami/firebirdsql" 18 "go.uber.org/atomic" 19 ) 20 21 func init() { 22 db := Firebird{} 23 database.Register("firebird", &db) 24 database.Register("firebirdsql", &db) 25 } 26 27 var DefaultMigrationsTable = "schema_migrations" 28 29 var ( 30 ErrNilConfig = fmt.Errorf("no config") 31 ) 32 33 type Config struct { 34 DatabaseName string 35 MigrationsTable string 36 } 37 38 type Firebird struct { 39 // Locking and unlocking need to use the same connection 40 conn *sql.Conn 41 db *sql.DB 42 isLocked atomic.Bool 43 44 // Open and WithInstance need to guarantee that config is never nil 45 config *Config 46 } 47 48 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 49 if config == nil { 50 return nil, ErrNilConfig 51 } 52 53 if err := instance.Ping(); err != nil { 54 return nil, err 55 } 56 57 if len(config.MigrationsTable) == 0 { 58 config.MigrationsTable = DefaultMigrationsTable 59 } 60 61 conn, err := instance.Conn(context.Background()) 62 if err != nil { 63 return nil, err 64 } 65 66 fb := &Firebird{ 67 conn: conn, 68 db: instance, 69 config: config, 70 } 71 72 if err := fb.ensureVersionTable(); err != nil { 73 return nil, err 74 } 75 76 return fb, nil 77 } 78 79 func (f *Firebird) Open(dsn string) (database.Driver, error) { 80 purl, err := nurl.Parse(dsn) 81 if err != nil { 82 return nil, err 83 } 84 85 db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String()) 86 if err != nil { 87 return nil, err 88 } 89 90 px, err := WithInstance(db, &Config{ 91 MigrationsTable: purl.Query().Get("x-migrations-table"), 92 DatabaseName: purl.Path, 93 }) 94 95 if err != nil { 96 return nil, err 97 } 98 99 return px, nil 100 } 101 102 func (f *Firebird) Close() error { 103 connErr := f.conn.Close() 104 dbErr := f.db.Close() 105 if connErr != nil || dbErr != nil { 106 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 107 } 108 return nil 109 } 110 111 func (f *Firebird) Lock() error { 112 if !f.isLocked.CAS(false, true) { 113 return database.ErrLocked 114 } 115 return nil 116 } 117 118 func (f *Firebird) Unlock() error { 119 if !f.isLocked.CAS(true, false) { 120 return database.ErrNotLocked 121 } 122 return nil 123 } 124 125 func (f *Firebird) Run(migration io.Reader) error { 126 migr, err := ioutil.ReadAll(migration) 127 if err != nil { 128 return err 129 } 130 131 // run migration 132 query := string(migr[:]) 133 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 134 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 135 } 136 137 return nil 138 } 139 140 func (f *Firebird) SetVersion(version int, dirty bool) error { 141 // Always re-write the schema version to prevent empty schema version 142 // for failed down migration on the first migration 143 // See: https://github.com/getsynq/migrate/issues/330 144 145 // TODO: parameterize this SQL statement 146 // https://firebirdsql.org/refdocs/langrefupd20-execblock.html 147 // VALUES (?, ?) doesn't work 148 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 149 DELETE FROM "%v"; 150 INSERT INTO "%v" (version, dirty) VALUES (%v, %v); 151 END;`, 152 f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty)) 153 154 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 155 return &database.Error{OrigErr: err, Query: []byte(query)} 156 } 157 158 return nil 159 } 160 161 func (f *Firebird) Version() (version int, dirty bool, err error) { 162 var d int 163 query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable) 164 err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d) 165 switch { 166 case err == sql.ErrNoRows: 167 return database.NilVersion, false, nil 168 case err != nil: 169 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 170 171 default: 172 return version, itob(d), nil 173 } 174 } 175 176 func (f *Firebird) Drop() (err error) { 177 // select all tables 178 query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);` 179 tables, err := f.conn.QueryContext(context.Background(), query) 180 if err != nil { 181 return &database.Error{OrigErr: err, Query: []byte(query)} 182 } 183 defer func() { 184 if errClose := tables.Close(); errClose != nil { 185 err = multierror.Append(err, errClose) 186 } 187 }() 188 189 // delete one table after another 190 tableNames := make([]string, 0) 191 for tables.Next() { 192 var tableName string 193 if err := tables.Scan(&tableName); err != nil { 194 return err 195 } 196 if len(tableName) > 0 { 197 tableNames = append(tableNames, tableName) 198 } 199 } 200 if err := tables.Err(); err != nil { 201 return &database.Error{OrigErr: err, Query: []byte(query)} 202 } 203 204 // delete one by one ... 205 for _, t := range tableNames { 206 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 207 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then 208 execute statement 'drop table "%v"'; 209 END;`, 210 t, t) 211 212 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 213 return &database.Error{OrigErr: err, Query: []byte(query)} 214 } 215 } 216 217 return nil 218 } 219 220 // ensureVersionTable checks if versions table exists and, if not, creates it. 221 func (f *Firebird) ensureVersionTable() (err error) { 222 if err = f.Lock(); err != nil { 223 return err 224 } 225 226 defer func() { 227 if e := f.Unlock(); e != nil { 228 if err == nil { 229 err = e 230 } else { 231 err = multierror.Append(err, e) 232 } 233 } 234 }() 235 236 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 237 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then 238 execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)'; 239 END;`, 240 f.config.MigrationsTable, f.config.MigrationsTable) 241 242 if _, err = f.conn.ExecContext(context.Background(), query); err != nil { 243 return &database.Error{OrigErr: err, Query: []byte(query)} 244 } 245 246 return nil 247 } 248 249 // btoi converts bool to int 250 func btoi(v bool) int { 251 if v { 252 return 1 253 } 254 return 0 255 } 256 257 // itob converts int to bool 258 func itob(v int) bool { 259 return v != 0 260 }