github.com/bishtawi/migrate/v4@v4.8.11/database/firebird/firebird.go (about)

     1  // +build go1.9
     2  
     3  package firebird
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"github.com/bishtawi/migrate/v4"
    10  	"github.com/bishtawi/migrate/v4/database"
    11  	"github.com/hashicorp/go-multierror"
    12  	_ "github.com/nakagami/firebirdsql"
    13  	"io"
    14  	"io/ioutil"
    15  	nurl "net/url"
    16  )
    17  
    18  func init() {
    19  	db := Firebird{}
    20  	database.Register("firebird", &db)
    21  	database.Register("firebirdsql", &db)
    22  }
    23  
    24  var DefaultMigrationsTable = "schema_migrations"
    25  
    26  var (
    27  	ErrNilConfig = fmt.Errorf("no config")
    28  )
    29  
    30  type Config struct {
    31  	DatabaseName    string
    32  	MigrationsTable string
    33  }
    34  
    35  type Firebird struct {
    36  	// Locking and unlocking need to use the same connection
    37  	conn     *sql.Conn
    38  	db       *sql.DB
    39  	isLocked bool
    40  
    41  	// Open and WithInstance need to guarantee that config is never nil
    42  	config *Config
    43  }
    44  
    45  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    46  	if config == nil {
    47  		return nil, ErrNilConfig
    48  	}
    49  
    50  	if err := instance.Ping(); err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	if len(config.MigrationsTable) == 0 {
    55  		config.MigrationsTable = DefaultMigrationsTable
    56  	}
    57  
    58  	conn, err := instance.Conn(context.Background())
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	fb := &Firebird{
    64  		conn:   conn,
    65  		db:     instance,
    66  		config: config,
    67  	}
    68  
    69  	if err := fb.ensureVersionTable(); err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	return fb, nil
    74  }
    75  
    76  func (f *Firebird) Open(dsn string) (database.Driver, error) {
    77  	purl, err := nurl.Parse(dsn)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	px, err := WithInstance(db, &Config{
    88  		MigrationsTable: purl.Query().Get("x-migrations-table"),
    89  		DatabaseName:    purl.Path,
    90  	})
    91  
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	return px, nil
    97  }
    98  
    99  func (f *Firebird) Close() error {
   100  	connErr := f.conn.Close()
   101  	dbErr := f.db.Close()
   102  	if connErr != nil || dbErr != nil {
   103  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   104  	}
   105  	return nil
   106  }
   107  
   108  func (f *Firebird) Lock() error {
   109  	if f.isLocked {
   110  		return database.ErrLocked
   111  	}
   112  	f.isLocked = true
   113  	return nil
   114  }
   115  
   116  func (f *Firebird) Unlock() error {
   117  	f.isLocked = false
   118  	return nil
   119  }
   120  
   121  func (f *Firebird) Run(migration io.Reader) error {
   122  	migr, err := ioutil.ReadAll(migration)
   123  	if err != nil {
   124  		return err
   125  	}
   126  
   127  	// run migration
   128  	query := string(migr[:])
   129  	if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   130  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   131  	}
   132  
   133  	return nil
   134  }
   135  
   136  func (f *Firebird) SetVersion(version int, dirty bool) error {
   137  	if version < 0 {
   138  		return nil
   139  	}
   140  
   141  	// TODO: parameterize this SQL statement
   142  	//       https://firebirdsql.org/refdocs/langrefupd20-execblock.html
   143  	//       VALUES (?, ?) doesn't work
   144  	query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   145  					DELETE FROM "%v";
   146  					INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
   147  				END;`,
   148  		f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
   149  
   150  	if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   151  		return &database.Error{OrigErr: err, Query: []byte(query)}
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func (f *Firebird) Version() (version int, dirty bool, err error) {
   158  	var d int
   159  	query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
   160  	err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
   161  	switch {
   162  	case err == sql.ErrNoRows:
   163  		return database.NilVersion, false, nil
   164  	case err != nil:
   165  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   166  
   167  	default:
   168  		return version, itob(d), nil
   169  	}
   170  }
   171  
   172  func (f *Firebird) Drop() (err error) {
   173  	// select all tables
   174  	query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
   175  	tables, err := f.conn.QueryContext(context.Background(), query)
   176  	if err != nil {
   177  		return &database.Error{OrigErr: err, Query: []byte(query)}
   178  	}
   179  	defer func() {
   180  		if errClose := tables.Close(); errClose != nil {
   181  			err = multierror.Append(err, errClose)
   182  		}
   183  	}()
   184  
   185  	// delete one table after another
   186  	tableNames := make([]string, 0)
   187  	for tables.Next() {
   188  		var tableName string
   189  		if err := tables.Scan(&tableName); err != nil {
   190  			return err
   191  		}
   192  		if len(tableName) > 0 {
   193  			tableNames = append(tableNames, tableName)
   194  		}
   195  	}
   196  
   197  	// delete one by one ...
   198  	for _, t := range tableNames {
   199  		query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   200  						if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
   201  						execute statement 'drop table "%v"';
   202  					END;`,
   203  			t, t)
   204  
   205  		if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
   206  			return &database.Error{OrigErr: err, Query: []byte(query)}
   207  		}
   208  	}
   209  
   210  	return nil
   211  }
   212  
   213  // ensureVersionTable checks if versions table exists and, if not, creates it.
   214  func (f *Firebird) ensureVersionTable() (err error) {
   215  	if err = f.Lock(); err != nil {
   216  		return err
   217  	}
   218  
   219  	defer func() {
   220  		if e := f.Unlock(); e != nil {
   221  			if err == nil {
   222  				err = e
   223  			} else {
   224  				err = multierror.Append(err, e)
   225  			}
   226  		}
   227  	}()
   228  
   229  	query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
   230  			if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
   231  			execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
   232  		END;`,
   233  		f.config.MigrationsTable, f.config.MigrationsTable)
   234  
   235  	if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
   236  		return &database.Error{OrigErr: err, Query: []byte(query)}
   237  	}
   238  
   239  	return nil
   240  }
   241  
   242  // btoi converts bool to int
   243  func btoi(v bool) int {
   244  	if v {
   245  		return 1
   246  	}
   247  	return 0
   248  }
   249  
   250  // itob converts int to bool
   251  func itob(v int) bool {
   252  	return v != 0
   253  }