github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/firebird/firebird.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package firebird
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	nurl "net/url"
    13  
    14  	"github.com/getsynq/migrate/v4"
    15  	"github.com/getsynq/migrate/v4/database"
    16  	"github.com/hashicorp/go-multierror"
    17  	_ "github.com/nakagami/firebirdsql"
    18  	"go.uber.org/atomic"
    19  )
    20  
    21  func init() {
    22  	db := Firebird{}
    23  	database.Register("firebird", &db)
    24  	database.Register("firebirdsql", &db)
    25  }
    26  
    27  var DefaultMigrationsTable = "schema_migrations"
    28  
    29  var (
    30  	ErrNilConfig = fmt.Errorf("no config")
    31  )
    32  
    33  type Config struct {
    34  	DatabaseName    string
    35  	MigrationsTable string
    36  }
    37  
// Firebird is the migrate driver for Firebird databases. init registers a
// zero-value Firebird only so that Open can be called on it; usable instances
// are built by WithInstance (directly, or via Open).
type Firebird struct {
	// Locking and unlocking need to use the same connection
	// Every statement this driver issues runs on this one connection, not on
	// the pool.
	conn     *sql.Conn
	// db is the pool conn was taken from; Close closes both.
	db       *sql.DB
	// isLocked is flipped by Lock/Unlock with compare-and-swap. It is purely
	// in-process: no database-level lock is taken.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    47  
    48  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    49  	if config == nil {
    50  		return nil, ErrNilConfig
    51  	}
    52  
    53  	if err := instance.Ping(); err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	if len(config.MigrationsTable) == 0 {
    58  		config.MigrationsTable = DefaultMigrationsTable
    59  	}
    60  
    61  	conn, err := instance.Conn(context.Background())
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	fb := &Firebird{
    67  		conn:   conn,
    68  		db:     instance,
    69  		config: config,
    70  	}
    71  
    72  	if err := fb.ensureVersionTable(); err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return fb, nil
    77  }
    78  
    79  func (f *Firebird) Open(dsn string) (database.Driver, error) {
    80  	purl, err := nurl.Parse(dsn)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	px, err := WithInstance(db, &Config{
    91  		MigrationsTable: purl.Query().Get("x-migrations-table"),
    92  		DatabaseName:    purl.Path,
    93  	})
    94  
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	return px, nil
   100  }
   101  
   102  func (f *Firebird) Close() error {
   103  	connErr := f.conn.Close()
   104  	dbErr := f.db.Close()
   105  	if connErr != nil || dbErr != nil {
   106  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   107  	}
   108  	return nil
   109  }
   110  
   111  func (f *Firebird) Lock() error {
   112  	if !f.isLocked.CAS(false, true) {
   113  		return database.ErrLocked
   114  	}
   115  	return nil
   116  }
   117  
   118  func (f *Firebird) Unlock() error {
   119  	if !f.isLocked.CAS(true, false) {
   120  		return database.ErrNotLocked
   121  	}
   122  	return nil
   123  }
   124  
   125  func (f *Firebird) Run(migration io.Reader) error {
   126  	migr, err := ioutil.ReadAll(migration)
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	// run migration
   132  	query := string(migr[:])
   133  	if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   134  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  func (f *Firebird) SetVersion(version int, dirty bool) error {
   141  	// Always re-write the schema version to prevent empty schema version
   142  	// for failed down migration on the first migration
   143  	// See: https://github.com/getsynq/migrate/issues/330
   144  
   145  	// TODO: parameterize this SQL statement
   146  	//       https://firebirdsql.org/refdocs/langrefupd20-execblock.html
   147  	//       VALUES (?, ?) doesn't work
   148  	query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   149  					DELETE FROM "%v";
   150  					INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
   151  				END;`,
   152  		f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
   153  
   154  	if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   155  		return &database.Error{OrigErr: err, Query: []byte(query)}
   156  	}
   157  
   158  	return nil
   159  }
   160  
   161  func (f *Firebird) Version() (version int, dirty bool, err error) {
   162  	var d int
   163  	query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
   164  	err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
   165  	switch {
   166  	case err == sql.ErrNoRows:
   167  		return database.NilVersion, false, nil
   168  	case err != nil:
   169  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   170  
   171  	default:
   172  		return version, itob(d), nil
   173  	}
   174  }
   175  
   176  func (f *Firebird) Drop() (err error) {
   177  	// select all tables
   178  	query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
   179  	tables, err := f.conn.QueryContext(context.Background(), query)
   180  	if err != nil {
   181  		return &database.Error{OrigErr: err, Query: []byte(query)}
   182  	}
   183  	defer func() {
   184  		if errClose := tables.Close(); errClose != nil {
   185  			err = multierror.Append(err, errClose)
   186  		}
   187  	}()
   188  
   189  	// delete one table after another
   190  	tableNames := make([]string, 0)
   191  	for tables.Next() {
   192  		var tableName string
   193  		if err := tables.Scan(&tableName); err != nil {
   194  			return err
   195  		}
   196  		if len(tableName) > 0 {
   197  			tableNames = append(tableNames, tableName)
   198  		}
   199  	}
   200  	if err := tables.Err(); err != nil {
   201  		return &database.Error{OrigErr: err, Query: []byte(query)}
   202  	}
   203  
   204  	// delete one by one ...
   205  	for _, t := range tableNames {
   206  		query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   207  						if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
   208  						execute statement 'drop table "%v"';
   209  					END;`,
   210  			t, t)
   211  
   212  		if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   213  			return &database.Error{OrigErr: err, Query: []byte(query)}
   214  		}
   215  	}
   216  
   217  	return nil
   218  }
   219  
   220  // ensureVersionTable checks if versions table exists and, if not, creates it.
   221  func (f *Firebird) ensureVersionTable() (err error) {
   222  	if err = f.Lock(); err != nil {
   223  		return err
   224  	}
   225  
   226  	defer func() {
   227  		if e := f.Unlock(); e != nil {
   228  			if err == nil {
   229  				err = e
   230  			} else {
   231  				err = multierror.Append(err, e)
   232  			}
   233  		}
   234  	}()
   235  
   236  	query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   237  			if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
   238  			execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
   239  		END;`,
   240  		f.config.MigrationsTable, f.config.MigrationsTable)
   241  
   242  	if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
   243  		return &database.Error{OrigErr: err, Query: []byte(query)}
   244  	}
   245  
   246  	return nil
   247  }
   248  
   249  // btoi converts bool to int
   250  func btoi(v bool) int {
   251  	if v {
   252  		return 1
   253  	}
   254  	return 0
   255  }
   256  
   257  // itob converts int to bool
   258  func itob(v int) bool {
   259  	return v != 0
   260  }