github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/firebird/firebird.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package firebird 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "io" 11 nurl "net/url" 12 13 "github.com/golang-migrate/migrate/v4" 14 "github.com/golang-migrate/migrate/v4/database" 15 "github.com/hashicorp/go-multierror" 16 _ "github.com/nakagami/firebirdsql" 17 "go.uber.org/atomic" 18 ) 19 20 func init() { 21 db := Firebird{} 22 database.Register("firebird", &db) 23 database.Register("firebirdsql", &db) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrNilConfig = fmt.Errorf("no config") 30 ) 31 32 type Config struct { 33 DatabaseName string 34 MigrationsTable string 35 } 36 37 type Firebird struct { 38 // Locking and unlocking need to use the same connection 39 conn *sql.Conn 40 db *sql.DB 41 isLocked atomic.Bool 42 43 // Open and WithInstance need to guarantee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 if err := instance.Ping(); err != nil { 53 return nil, err 54 } 55 56 if len(config.MigrationsTable) == 0 { 57 config.MigrationsTable = DefaultMigrationsTable 58 } 59 60 conn, err := instance.Conn(context.Background()) 61 if err != nil { 62 return nil, err 63 } 64 65 fb := &Firebird{ 66 conn: conn, 67 db: instance, 68 config: config, 69 } 70 71 if err := fb.ensureVersionTable(); err != nil { 72 return nil, err 73 } 74 75 return fb, nil 76 } 77 78 func (f *Firebird) Open(dsn string) (database.Driver, error) { 79 purl, err := nurl.Parse(dsn) 80 if err != nil { 81 return nil, err 82 } 83 84 db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String()) 85 if err != nil { 86 return nil, err 87 } 88 89 px, err := WithInstance(db, &Config{ 90 MigrationsTable: purl.Query().Get("x-migrations-table"), 91 DatabaseName: purl.Path, 92 }) 93 94 if err != nil { 95 return nil, err 96 } 97 98 return px, nil 99 } 100 101 func (f *Firebird) Close() error { 102 connErr := f.conn.Close() 103 dbErr := f.db.Close() 104 if connErr != nil || dbErr != nil { 105 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 106 } 107 return nil 108 } 109 110 func (f *Firebird) Lock() error { 111 if !f.isLocked.CAS(false, true) { 112 return database.ErrLocked 113 } 114 return nil 115 } 116 117 func (f *Firebird) Unlock() error { 118 if !f.isLocked.CAS(true, false) { 119 return database.ErrNotLocked 120 } 121 return nil 122 } 123 124 func (f *Firebird) Run(migration io.Reader) error { 125 migr, err := io.ReadAll(migration) 126 if err != nil { 127 return err 128 } 129 130 // run migration 131 query := string(migr[:]) 132 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 133 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 134 } 135 136 return nil 137 } 138 139 func (f *Firebird) SetVersion(version int, dirty bool) error { 140 // Always re-write the schema version to prevent empty schema version 141 // for failed down migration on the first migration 142 // See: https://github.com/golang-migrate/migrate/issues/330 143 144 // TODO: parameterize this SQL statement 145 // https://firebirdsql.org/refdocs/langrefupd20-execblock.html 146 // VALUES (?, ?) doesn't work 147 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 148 DELETE FROM "%v"; 149 INSERT INTO "%v" (version, dirty) VALUES (%v, %v); 150 END;`, 151 f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty)) 152 153 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 154 return &database.Error{OrigErr: err, Query: []byte(query)} 155 } 156 157 return nil 158 } 159 160 func (f *Firebird) Version() (version int, dirty bool, err error) { 161 var d int 162 query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable) 163 err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d) 164 switch { 165 case err == sql.ErrNoRows: 166 return database.NilVersion, false, nil 167 case err != nil: 168 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 169 170 default: 171 return version, itob(d), nil 172 } 173 } 174 175 func (f *Firebird) Drop() (err error) { 176 // select all tables 177 query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);` 178 tables, err := f.conn.QueryContext(context.Background(), query) 179 if err != nil { 180 return &database.Error{OrigErr: err, Query: []byte(query)} 181 } 182 defer func() { 183 if errClose := tables.Close(); errClose != nil { 184 err = multierror.Append(err, errClose) 185 } 186 }() 187 188 // delete one table after another 189 tableNames := make([]string, 0) 190 for tables.Next() { 191 var tableName string 192 if err := tables.Scan(&tableName); err != nil { 193 return err 194 } 195 if len(tableName) > 0 { 196 tableNames = append(tableNames, tableName) 197 } 198 } 199 if err := tables.Err(); err != nil { 200 return &database.Error{OrigErr: err, Query: []byte(query)} 201 } 202 203 // delete one by one ... 204 for _, t := range tableNames { 205 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 206 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then 207 execute statement 'drop table "%v"'; 208 END;`, 209 t, t) 210 211 if _, err := f.conn.ExecContext(context.Background(), query); err != nil { 212 return &database.Error{OrigErr: err, Query: []byte(query)} 213 } 214 } 215 216 return nil 217 } 218 219 // ensureVersionTable checks if versions table exists and, if not, creates it. 220 func (f *Firebird) ensureVersionTable() (err error) { 221 if err = f.Lock(); err != nil { 222 return err 223 } 224 225 defer func() { 226 if e := f.Unlock(); e != nil { 227 if err == nil { 228 err = e 229 } else { 230 err = multierror.Append(err, e) 231 } 232 } 233 }() 234 235 query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN 236 if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then 237 execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)'; 238 END;`, 239 f.config.MigrationsTable, f.config.MigrationsTable) 240 241 if _, err = f.conn.ExecContext(context.Background(), query); err != nil { 242 return &database.Error{OrigErr: err, Query: []byte(query)} 243 } 244 245 return nil 246 } 247 248 // btoi converts bool to int 249 func btoi(v bool) int { 250 if v { 251 return 1 252 } 253 return 0 254 } 255 256 // itob converts int to bool 257 func itob(v int) bool { 258 return v != 0 259 }