github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/ql/ql.go (about)

     1  package ql
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"strings"
     9  
    10  	"github.com/hashicorp/go-multierror"
    11  	"go.uber.org/atomic"
    12  
    13  	nurl "net/url"
    14  
    15  	"github.com/getsynq/migrate/v4"
    16  	"github.com/getsynq/migrate/v4/database"
    17  	_ "modernc.org/ql/driver"
    18  )
    19  
    20  func init() {
    21  	database.Register("ql", &Ql{})
    22  }
    23  
    24  var DefaultMigrationsTable = "schema_migrations"
    25  var (
    26  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    27  	ErrNilConfig      = fmt.Errorf("no config")
    28  	ErrNoDatabaseName = fmt.Errorf("no database name")
    29  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    30  )
    31  
    32  type Config struct {
    33  	MigrationsTable string
    34  	DatabaseName    string
    35  }
    36  
    37  type Ql struct {
    38  	db       *sql.DB
    39  	isLocked atomic.Bool
    40  
    41  	config *Config
    42  }
    43  
    44  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    45  	if config == nil {
    46  		return nil, ErrNilConfig
    47  	}
    48  
    49  	if err := instance.Ping(); err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	if len(config.MigrationsTable) == 0 {
    54  		config.MigrationsTable = DefaultMigrationsTable
    55  	}
    56  
    57  	mx := &Ql{
    58  		db:     instance,
    59  		config: config,
    60  	}
    61  	if err := mx.ensureVersionTable(); err != nil {
    62  		return nil, err
    63  	}
    64  	return mx, nil
    65  }
    66  
    67  // ensureVersionTable checks if versions table exists and, if not, creates it.
    68  // Note that this function locks the database, which deviates from the usual
    69  // convention of "caller locks" in the Ql type.
    70  func (m *Ql) ensureVersionTable() (err error) {
    71  	if err = m.Lock(); err != nil {
    72  		return err
    73  	}
    74  
    75  	defer func() {
    76  		if e := m.Unlock(); e != nil {
    77  			if err == nil {
    78  				err = e
    79  			} else {
    80  				err = multierror.Append(err, e)
    81  			}
    82  		}
    83  	}()
    84  
    85  	tx, err := m.db.Begin()
    86  	if err != nil {
    87  		return err
    88  	}
    89  	if _, err := tx.Exec(fmt.Sprintf(`
    90  	CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool);
    91  	CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    92  `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil {
    93  		if err := tx.Rollback(); err != nil {
    94  			return err
    95  		}
    96  		return err
    97  	}
    98  	if err := tx.Commit(); err != nil {
    99  		return err
   100  	}
   101  	return nil
   102  }
   103  
   104  func (m *Ql) Open(url string) (database.Driver, error) {
   105  	purl, err := nurl.Parse(url)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1)
   110  	db, err := sql.Open("ql", dbfile)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	migrationsTable := purl.Query().Get("x-migrations-table")
   115  	if len(migrationsTable) == 0 {
   116  		migrationsTable = DefaultMigrationsTable
   117  	}
   118  	mx, err := WithInstance(db, &Config{
   119  		DatabaseName:    purl.Path,
   120  		MigrationsTable: migrationsTable,
   121  	})
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	return mx, nil
   126  }
   127  func (m *Ql) Close() error {
   128  	return m.db.Close()
   129  }
   130  func (m *Ql) Drop() (err error) {
   131  	query := `SELECT Name FROM __Table`
   132  	tables, err := m.db.Query(query)
   133  	if err != nil {
   134  		return &database.Error{OrigErr: err, Query: []byte(query)}
   135  	}
   136  	defer func() {
   137  		if errClose := tables.Close(); errClose != nil {
   138  			err = multierror.Append(err, errClose)
   139  		}
   140  	}()
   141  
   142  	tableNames := make([]string, 0)
   143  	for tables.Next() {
   144  		var tableName string
   145  		if err := tables.Scan(&tableName); err != nil {
   146  			return err
   147  		}
   148  		if len(tableName) > 0 {
   149  			if !strings.HasPrefix(tableName, "__") {
   150  				tableNames = append(tableNames, tableName)
   151  			}
   152  		}
   153  	}
   154  	if err := tables.Err(); err != nil {
   155  		return &database.Error{OrigErr: err, Query: []byte(query)}
   156  	}
   157  
   158  	if len(tableNames) > 0 {
   159  		for _, t := range tableNames {
   160  			query := "DROP TABLE " + t
   161  			err = m.executeQuery(query)
   162  			if err != nil {
   163  				return &database.Error{OrigErr: err, Query: []byte(query)}
   164  			}
   165  		}
   166  	}
   167  
   168  	return nil
   169  }
   170  func (m *Ql) Lock() error {
   171  	if !m.isLocked.CAS(false, true) {
   172  		return database.ErrLocked
   173  	}
   174  	return nil
   175  }
   176  func (m *Ql) Unlock() error {
   177  	if !m.isLocked.CAS(true, false) {
   178  		return database.ErrNotLocked
   179  	}
   180  	return nil
   181  }
   182  func (m *Ql) Run(migration io.Reader) error {
   183  	migr, err := ioutil.ReadAll(migration)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	query := string(migr[:])
   188  
   189  	return m.executeQuery(query)
   190  }
   191  func (m *Ql) executeQuery(query string) error {
   192  	tx, err := m.db.Begin()
   193  	if err != nil {
   194  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   195  	}
   196  	if _, err := tx.Exec(query); err != nil {
   197  		if errRollback := tx.Rollback(); errRollback != nil {
   198  			err = multierror.Append(err, errRollback)
   199  		}
   200  		return &database.Error{OrigErr: err, Query: []byte(query)}
   201  	}
   202  	if err := tx.Commit(); err != nil {
   203  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   204  	}
   205  	return nil
   206  }
   207  func (m *Ql) SetVersion(version int, dirty bool) error {
   208  	tx, err := m.db.Begin()
   209  	if err != nil {
   210  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   211  	}
   212  
   213  	query := "TRUNCATE TABLE " + m.config.MigrationsTable
   214  	if _, err := tx.Exec(query); err != nil {
   215  		return &database.Error{OrigErr: err, Query: []byte(query)}
   216  	}
   217  
   218  	// Also re-write the schema version for nil dirty versions to prevent
   219  	// empty schema version for failed down migration on the first migration
   220  	// See: https://github.com/getsynq/migrate/issues/330
   221  	if version >= 0 || (version == database.NilVersion && dirty) {
   222  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (uint64(?1), ?2)`,
   223  			m.config.MigrationsTable)
   224  		if _, err := tx.Exec(query, version, dirty); err != nil {
   225  			if errRollback := tx.Rollback(); errRollback != nil {
   226  				err = multierror.Append(err, errRollback)
   227  			}
   228  			return &database.Error{OrigErr: err, Query: []byte(query)}
   229  		}
   230  	}
   231  
   232  	if err := tx.Commit(); err != nil {
   233  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   234  	}
   235  
   236  	return nil
   237  }
   238  
   239  func (m *Ql) Version() (version int, dirty bool, err error) {
   240  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   241  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   242  	if err != nil {
   243  		return database.NilVersion, false, nil
   244  	}
   245  	return version, dirty, nil
   246  }