github.com/nokia/migrate/v4@v4.16.0/database/firebird/firebird.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package firebird
     5  
import (
	"context"
	"database/sql"
	"fmt"
	"io"
	"io/ioutil"
	nurl "net/url"
	"strings"

	"github.com/hashicorp/go-multierror"
	_ "github.com/nakagami/firebirdsql"
	"github.com/nokia/migrate/v4"
	"github.com/nokia/migrate/v4/database"
	"github.com/nokia/migrate/v4/source"
	"go.uber.org/atomic"
)
    21  
    22  func init() {
    23  	db := Firebird{}
    24  	database.Register("firebird", &db)
    25  	database.Register("firebirdsql", &db)
    26  }
    27  
    28  var DefaultMigrationsTable = "schema_migrations"
    29  
    30  var ErrNilConfig = fmt.Errorf("no config")
    31  
    32  type Config struct {
    33  	DatabaseName    string
    34  	MigrationsTable string
    35  }
    36  
    37  type Firebird struct {
    38  	// Locking and unlocking need to use the same connection
    39  	conn     *sql.Conn
    40  	db       *sql.DB
    41  	isLocked atomic.Bool
    42  
    43  	// Open and WithInstance need to guarantee that config is never nil
    44  	config *Config
    45  }
    46  
    47  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    48  	if config == nil {
    49  		return nil, ErrNilConfig
    50  	}
    51  
    52  	if err := instance.Ping(); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	if len(config.MigrationsTable) == 0 {
    57  		config.MigrationsTable = DefaultMigrationsTable
    58  	}
    59  
    60  	conn, err := instance.Conn(context.Background())
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	fb := &Firebird{
    66  		conn:   conn,
    67  		db:     instance,
    68  		config: config,
    69  	}
    70  
    71  	if err := fb.ensureVersionTable(); err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	return fb, nil
    76  }
    77  
    78  func (f *Firebird) Open(dsn string) (database.Driver, error) {
    79  	purl, err := nurl.Parse(dsn)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	px, err := WithInstance(db, &Config{
    90  		MigrationsTable: purl.Query().Get("x-migrations-table"),
    91  		DatabaseName:    purl.Path,
    92  	})
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return px, nil
    98  }
    99  
   100  func (f *Firebird) Close() error {
   101  	connErr := f.conn.Close()
   102  	dbErr := f.db.Close()
   103  	if connErr != nil || dbErr != nil {
   104  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   105  	}
   106  	return nil
   107  }
   108  
   109  func (f *Firebird) Lock() error {
   110  	if !f.isLocked.CAS(false, true) {
   111  		return database.ErrLocked
   112  	}
   113  	return nil
   114  }
   115  
   116  func (f *Firebird) Unlock() error {
   117  	if !f.isLocked.CAS(true, false) {
   118  		return database.ErrNotLocked
   119  	}
   120  	return nil
   121  }
   122  
   123  func (f *Firebird) Run(migration io.Reader) error {
   124  	migr, err := ioutil.ReadAll(migration)
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	// run migration
   130  	query := string(migr[:])
   131  	if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   132  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   133  	}
   134  
   135  	return nil
   136  }
   137  
   138  func (f *Firebird) RunFunctionMigration(fn source.MigrationFunc) error {
   139  	return database.ErrNotImpl
   140  }
   141  
   142  func (f *Firebird) SetVersion(version int, dirty bool) error {
   143  	// Always re-write the schema version to prevent empty schema version
   144  	// for failed down migration on the first migration
   145  	// See: https://github.com/nokia/migrate/issues/330
   146  
   147  	// TODO: parameterize this SQL statement
   148  	//       https://firebirdsql.org/refdocs/langrefupd20-execblock.html
   149  	//       VALUES (?, ?) doesn't work
   150  	query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   151  					DELETE FROM "%v";
   152  					INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
   153  				END;`,
   154  		f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
   155  
   156  	if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   157  		return &database.Error{OrigErr: err, Query: []byte(query)}
   158  	}
   159  
   160  	return nil
   161  }
   162  
   163  func (f *Firebird) Version() (version int, dirty bool, err error) {
   164  	var d int
   165  	query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
   166  	err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
   167  	switch {
   168  	case err == sql.ErrNoRows:
   169  		return database.NilVersion, false, nil
   170  	case err != nil:
   171  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   172  
   173  	default:
   174  		return version, itob(d), nil
   175  	}
   176  }
   177  
   178  func (f *Firebird) Drop() (err error) {
   179  	// select all tables
   180  	query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
   181  	tables, err := f.conn.QueryContext(context.Background(), query)
   182  	if err != nil {
   183  		return &database.Error{OrigErr: err, Query: []byte(query)}
   184  	}
   185  	defer func() {
   186  		if errClose := tables.Close(); errClose != nil {
   187  			err = multierror.Append(err, errClose)
   188  		}
   189  	}()
   190  
   191  	// delete one table after another
   192  	tableNames := make([]string, 0)
   193  	for tables.Next() {
   194  		var tableName string
   195  		if err := tables.Scan(&tableName); err != nil {
   196  			return err
   197  		}
   198  		if len(tableName) > 0 {
   199  			tableNames = append(tableNames, tableName)
   200  		}
   201  	}
   202  	if err := tables.Err(); err != nil {
   203  		return &database.Error{OrigErr: err, Query: []byte(query)}
   204  	}
   205  
   206  	// delete one by one ...
   207  	for _, t := range tableNames {
   208  		query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   209  						if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
   210  						execute statement 'drop table "%v"';
   211  					END;`,
   212  			t, t)
   213  
   214  		if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   215  			return &database.Error{OrigErr: err, Query: []byte(query)}
   216  		}
   217  	}
   218  
   219  	return nil
   220  }
   221  
   222  // ensureVersionTable checks if versions table exists and, if not, creates it.
   223  func (f *Firebird) ensureVersionTable() (err error) {
   224  	if err = f.Lock(); err != nil {
   225  		return err
   226  	}
   227  
   228  	defer func() {
   229  		if e := f.Unlock(); e != nil {
   230  			if err == nil {
   231  				err = e
   232  			} else {
   233  				err = multierror.Append(err, e)
   234  			}
   235  		}
   236  	}()
   237  
   238  	query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   239  			if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
   240  			execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
   241  		END;`,
   242  		f.config.MigrationsTable, f.config.MigrationsTable)
   243  
   244  	if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
   245  		return &database.Error{OrigErr: err, Query: []byte(query)}
   246  	}
   247  
   248  	return nil
   249  }
   250  
   251  // btoi converts bool to int
   252  func btoi(v bool) int {
   253  	if v {
   254  		return 1
   255  	}
   256  	return 0
   257  }
   258  
   259  // itob converts int to bool
   260  func itob(v int) bool {
   261  	return v != 0
   262  }