github.com/nokia/migrate/v4@v4.16.0/database/ql/ql.go (about)

     1  package ql
     2  
import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	nurl "net/url"
	"strings"

	"github.com/hashicorp/go-multierror"
	"go.uber.org/atomic"

	"github.com/nokia/migrate/v4"
	"github.com/nokia/migrate/v4/database"
	"github.com/nokia/migrate/v4/source"
	_ "modernc.org/ql/driver"
)
    20  
    21  func init() {
    22  	database.Register("ql", &Ql{})
    23  }
    24  
    25  var DefaultMigrationsTable = "schema_migrations"
    26  var (
    27  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    28  	ErrNilConfig      = fmt.Errorf("no config")
    29  	ErrNoDatabaseName = fmt.Errorf("no database name")
    30  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    31  )
    32  
// Config holds the settings WithInstance uses to build a Ql driver.
type Config struct {
	MigrationsTable string // table recording the schema version; WithInstance substitutes DefaultMigrationsTable when empty
	DatabaseName    string // set by Open from the URL path; not otherwise read by the code in this file
}
    37  
    38  type Ql struct {
    39  	db       *sql.DB
    40  	isLocked atomic.Bool
    41  
    42  	config *Config
    43  }
    44  
    45  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    46  	if config == nil {
    47  		return nil, ErrNilConfig
    48  	}
    49  
    50  	if err := instance.Ping(); err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	if len(config.MigrationsTable) == 0 {
    55  		config.MigrationsTable = DefaultMigrationsTable
    56  	}
    57  
    58  	mx := &Ql{
    59  		db:     instance,
    60  		config: config,
    61  	}
    62  	if err := mx.ensureVersionTable(); err != nil {
    63  		return nil, err
    64  	}
    65  	return mx, nil
    66  }
    67  
    68  // ensureVersionTable checks if versions table exists and, if not, creates it.
    69  // Note that this function locks the database, which deviates from the usual
    70  // convention of "caller locks" in the Ql type.
    71  func (m *Ql) ensureVersionTable() (err error) {
    72  	if err = m.Lock(); err != nil {
    73  		return err
    74  	}
    75  
    76  	defer func() {
    77  		if e := m.Unlock(); e != nil {
    78  			if err == nil {
    79  				err = e
    80  			} else {
    81  				err = multierror.Append(err, e)
    82  			}
    83  		}
    84  	}()
    85  
    86  	tx, err := m.db.Begin()
    87  	if err != nil {
    88  		return err
    89  	}
    90  	if _, err := tx.Exec(fmt.Sprintf(`
    91  	CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool);
    92  	CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    93  `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil {
    94  		if err := tx.Rollback(); err != nil {
    95  			return err
    96  		}
    97  		return err
    98  	}
    99  	if err := tx.Commit(); err != nil {
   100  		return err
   101  	}
   102  	return nil
   103  }
   104  
   105  func (m *Ql) Open(url string) (database.Driver, error) {
   106  	purl, err := nurl.Parse(url)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1)
   111  	db, err := sql.Open("ql", dbfile)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	migrationsTable := purl.Query().Get("x-migrations-table")
   116  	if len(migrationsTable) == 0 {
   117  		migrationsTable = DefaultMigrationsTable
   118  	}
   119  	mx, err := WithInstance(db, &Config{
   120  		DatabaseName:    purl.Path,
   121  		MigrationsTable: migrationsTable,
   122  	})
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	return mx, nil
   127  }
   128  
   129  func (m *Ql) Close() error {
   130  	return m.db.Close()
   131  }
   132  
   133  func (m *Ql) Drop() (err error) {
   134  	query := `SELECT Name FROM __Table`
   135  	tables, err := m.db.Query(query)
   136  	if err != nil {
   137  		return &database.Error{OrigErr: err, Query: []byte(query)}
   138  	}
   139  	defer func() {
   140  		if errClose := tables.Close(); errClose != nil {
   141  			err = multierror.Append(err, errClose)
   142  		}
   143  	}()
   144  
   145  	tableNames := make([]string, 0)
   146  	for tables.Next() {
   147  		var tableName string
   148  		if err := tables.Scan(&tableName); err != nil {
   149  			return err
   150  		}
   151  		if len(tableName) > 0 {
   152  			if !strings.HasPrefix(tableName, "__") {
   153  				tableNames = append(tableNames, tableName)
   154  			}
   155  		}
   156  	}
   157  	if err := tables.Err(); err != nil {
   158  		return &database.Error{OrigErr: err, Query: []byte(query)}
   159  	}
   160  
   161  	if len(tableNames) > 0 {
   162  		for _, t := range tableNames {
   163  			query := "DROP TABLE " + t
   164  			err = m.executeQuery(query)
   165  			if err != nil {
   166  				return &database.Error{OrigErr: err, Query: []byte(query)}
   167  			}
   168  		}
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  func (m *Ql) Lock() error {
   175  	if !m.isLocked.CAS(false, true) {
   176  		return database.ErrLocked
   177  	}
   178  	return nil
   179  }
   180  
   181  func (m *Ql) Unlock() error {
   182  	if !m.isLocked.CAS(true, false) {
   183  		return database.ErrNotLocked
   184  	}
   185  	return nil
   186  }
   187  
   188  func (m *Ql) Run(migration io.Reader) error {
   189  	migr, err := ioutil.ReadAll(migration)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	query := string(migr[:])
   194  
   195  	return m.executeQuery(query)
   196  }
   197  
   198  func (m *Ql) RunFunctionMigration(fn source.MigrationFunc) error {
   199  	return database.ErrNotImpl
   200  }
   201  
   202  func (m *Ql) executeQuery(query string) error {
   203  	tx, err := m.db.Begin()
   204  	if err != nil {
   205  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   206  	}
   207  	if _, err := tx.Exec(query); err != nil {
   208  		if errRollback := tx.Rollback(); errRollback != nil {
   209  			err = multierror.Append(err, errRollback)
   210  		}
   211  		return &database.Error{OrigErr: err, Query: []byte(query)}
   212  	}
   213  	if err := tx.Commit(); err != nil {
   214  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   215  	}
   216  	return nil
   217  }
   218  
   219  func (m *Ql) SetVersion(version int, dirty bool) error {
   220  	tx, err := m.db.Begin()
   221  	if err != nil {
   222  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   223  	}
   224  
   225  	query := "TRUNCATE TABLE " + m.config.MigrationsTable
   226  	if _, err := tx.Exec(query); err != nil {
   227  		return &database.Error{OrigErr: err, Query: []byte(query)}
   228  	}
   229  
   230  	// Also re-write the schema version for nil dirty versions to prevent
   231  	// empty schema version for failed down migration on the first migration
   232  	// See: https://github.com/nokia/migrate/issues/330
   233  	if version >= 0 || (version == database.NilVersion && dirty) {
   234  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`,
   235  			m.config.MigrationsTable)
   236  		if _, err := tx.Exec(query, version, dirty); err != nil {
   237  			if errRollback := tx.Rollback(); errRollback != nil {
   238  				err = multierror.Append(err, errRollback)
   239  			}
   240  			return &database.Error{OrigErr: err, Query: []byte(query)}
   241  		}
   242  	}
   243  
   244  	if err := tx.Commit(); err != nil {
   245  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   246  	}
   247  
   248  	return nil
   249  }
   250  
   251  func (m *Ql) Version() (version int, dirty bool, err error) {
   252  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   253  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   254  	if err != nil {
   255  		return database.NilVersion, false, nil
   256  	}
   257  	return version, dirty, nil
   258  }