github.com/nokia/migrate/v4@v4.16.0/database/firebird/firebird.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package firebird 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 "io/ioutil" 12 nurl "net/url" 13 14 "github.com/hashicorp/go-multierror" 15 _ "github.com/nakagami/firebirdsql" 16 "github.com/nokia/migrate/v4" 17 "github.com/nokia/migrate/v4/database" 18 "github.com/nokia/migrate/v4/source" 19 "go.uber.org/atomic" 20 ) 21 22 func init() { 23 db := Firebird{} 24 database.Register("firebird", &db) 25 database.Register("firebirdsql", &db) 26 } 27 28 var DefaultMigrationsTable = "schema_migrations" 29 30 var ErrNilConfig = fmt.Errorf("no config") 31 32 type Config struct { 33 DatabaseName string 34 MigrationsTable string 35 } 36 37 type Firebird struct { 38 // Locking and unlocking need to use the same connection 39 conn *sql.Conn 40 db *sql.DB 41 isLocked atomic.Bool 42 43 // Open and WithInstance need to guarantee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 if err := instance.Ping(); err != nil { 53 return nil, err 54 } 55 56 if len(config.MigrationsTable) == 0 { 57 config.MigrationsTable = DefaultMigrationsTable 58 } 59 60 conn, err := instance.Conn(context.Background()) 61 if err != nil { 62 return nil, err 63 } 64 65 fb := &Firebird{ 66 conn: conn, 67 db: instance, 68 config: config, 69 } 70 71 if err := fb.ensureVersionTable(); err != nil { 72 return nil, err 73 } 74 75 return fb, nil 76 } 77 78 func (f *Firebird) Open(dsn string) (database.Driver, error) { 79 purl, err := nurl.Parse(dsn) 80 if err != nil { 81 return nil, err 82 } 83 84 db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String()) 85 if err != nil { 86 return nil, err 87 } 88 89 px, err := WithInstance(db, &Config{ 90 MigrationsTable: purl.Query().Get("x-migrations-table"), 91 DatabaseName: purl.Path, 92 }) 93 if err != nil { 94 return nil, err 95 } 96 97 return px, nil 98 } 99 100 func (f *Firebird) Close() error { 101 connErr := f.conn.Close() 102 dbErr := f.db.Close() 103 if connErr != nil || dbErr != nil { 104 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 105 } 106 return nil 107 } 108 109 func (f *Firebird) Lock() error { 110 if !f.isLocked.CAS(false, true) { 111 return database.ErrLocked 112 } 113 return nil 114 } 115 116 func (f *Firebird) Unlock() error { 117 if !f.isLocked.CAS(true, false) { 118 return database.ErrNotLocked 119 } 120 return nil 121 } 122 123 func (f *Firebird) Run(migration io.Reader) error { 124 migr, err := ioutil.ReadAll(migration) 125 if err != nil { 126 return err 127 } 128 129 // run migration 130 query := string(migr[:]) 131 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 132 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 133 } 134 135 return nil 136 } 137 138 func (f *Firebird) RunFunctionMigration(fn source.MigrationFunc) error { 139 return database.ErrNotImpl 140 } 141 142 func (f *Firebird) SetVersion(version int, dirty bool) error { 143 // Always re-write the schema version to prevent empty schema version 144 // for failed down migration on the first migration 145 // See: https://github.com/nokia/migrate/issues/330 146 147 // TODO: parameterize this SQL statement 148 // https://firebirdsql.org/refdocs/langrefupd20-execblock.html 149 // VALUES (?, ?) doesn't work 150 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 151 DELETE FROM "%v"; 152 INSERT INTO "%v" (version, dirty) VALUES (%v, %v); 153 END;`, 154 f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty)) 155 156 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 157 return &database.Error{OrigErr: err, Query: []byte(query)} 158 } 159 160 return nil 161 } 162 163 func (f *Firebird) Version() (version int, dirty bool, err error) { 164 var d int 165 query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable) 166 err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d) 167 switch { 168 case err == sql.ErrNoRows: 169 return database.NilVersion, false, nil 170 case err != nil: 171 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 172 173 default: 174 return version, itob(d), nil 175 } 176 } 177 178 func (f *Firebird) Drop() (err error) { 179 // select all tables 180 query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);` 181 tables, err := f.conn.QueryContext(context.Background(), query) 182 if err != nil { 183 return &database.Error{OrigErr: err, Query: []byte(query)} 184 } 185 defer func() { 186 if errClose := tables.Close(); errClose != nil { 187 err = multierror.Append(err, errClose) 188 } 189 }() 190 191 // delete one table after another 192 tableNames := make([]string, 0) 193 for tables.Next() { 194 var tableName string 195 if err := tables.Scan(&tableName); err != nil { 196 return err 197 } 198 if len(tableName) > 0 { 199 tableNames = append(tableNames, tableName) 200 } 201 } 202 if err := tables.Err(); err != nil { 203 return &database.Error{OrigErr: err, Query: []byte(query)} 204 } 205 206 // delete one by one ... 207 for _, t := range tableNames { 208 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 209 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then 210 execute statement 'drop table "%v"'; 211 END;`, 212 t, t) 213 214 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 215 return &database.Error{OrigErr: err, Query: []byte(query)} 216 } 217 } 218 219 return nil 220 } 221 222 // ensureVersionTable checks if versions table exists and, if not, creates it. 223 func (f *Firebird) ensureVersionTable() (err error) { 224 if err = f.Lock(); err != nil { 225 return err 226 } 227 228 defer func() { 229 if e := f.Unlock(); e != nil { 230 if err == nil { 231 err = e 232 } else { 233 err = multierror.Append(err, e) 234 } 235 } 236 }() 237 238 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 239 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then 240 execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)'; 241 END;`, 242 f.config.MigrationsTable, f.config.MigrationsTable) 243 244 if _, err = f.conn.ExecContext(context.Background(), query); err != nil { 245 return &database.Error{OrigErr: err, Query: []byte(query)} 246 } 247 248 return nil 249 } 250 251 // btoi converts bool to int 252 func btoi(v bool) int { 253 if v { 254 return 1 255 } 256 return 0 257 } 258 259 // itob converts int to bool 260 func itob(v int) bool { 261 return v != 0 262 }