github.com/ldej/migrate@v3.5.4+incompatible/database/ql/ql.go (about)

     1  package ql
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"strings"
     9  
    10  	nurl "net/url"
    11  
    12  	_ "github.com/cznic/ql/driver"
    13  	"github.com/golang-migrate/migrate"
    14  	"github.com/golang-migrate/migrate/database"
    15  )
    16  
    17  func init() {
    18  	database.Register("ql", &Ql{})
    19  }
    20  
    21  var DefaultMigrationsTable = "schema_migrations"
    22  var (
    23  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    24  	ErrNilConfig      = fmt.Errorf("no config")
    25  	ErrNoDatabaseName = fmt.Errorf("no database name")
    26  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    27  )
    28  
    29  type Config struct {
    30  	MigrationsTable string
    31  	DatabaseName    string
    32  }
    33  
    34  type Ql struct {
    35  	db       *sql.DB
    36  	isLocked bool
    37  
    38  	config *Config
    39  }
    40  
    41  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    42  	if config == nil {
    43  		return nil, ErrNilConfig
    44  	}
    45  
    46  	if err := instance.Ping(); err != nil {
    47  		return nil, err
    48  	}
    49  	if len(config.MigrationsTable) == 0 {
    50  		config.MigrationsTable = DefaultMigrationsTable
    51  	}
    52  
    53  	mx := &Ql{
    54  		db:     instance,
    55  		config: config,
    56  	}
    57  	if err := mx.ensureVersionTable(); err != nil {
    58  		return nil, err
    59  	}
    60  	return mx, nil
    61  }
    62  func (m *Ql) ensureVersionTable() error {
    63  	tx, err := m.db.Begin()
    64  	if err != nil {
    65  		return err
    66  	}
    67  	if _, err := tx.Exec(fmt.Sprintf(`
    68  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    69  	CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    70  `, m.config.MigrationsTable, m.config.MigrationsTable)); err != nil {
    71  		if err := tx.Rollback(); err != nil {
    72  			return err
    73  		}
    74  		return err
    75  	}
    76  	if err := tx.Commit(); err != nil {
    77  		return err
    78  	}
    79  	return nil
    80  }
    81  
    82  func (m *Ql) Open(url string) (database.Driver, error) {
    83  	purl, err := nurl.Parse(url)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "ql://", "", 1)
    88  	db, err := sql.Open("ql", dbfile)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	migrationsTable := purl.Query().Get("x-migrations-table")
    93  	if len(migrationsTable) == 0 {
    94  		migrationsTable = DefaultMigrationsTable
    95  	}
    96  	mx, err := WithInstance(db, &Config{
    97  		DatabaseName:    purl.Path,
    98  		MigrationsTable: migrationsTable,
    99  	})
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	return mx, nil
   104  }
   105  func (m *Ql) Close() error {
   106  	return m.db.Close()
   107  }
   108  func (m *Ql) Drop() error {
   109  	query := `SELECT Name FROM __Table`
   110  	tables, err := m.db.Query(query)
   111  	if err != nil {
   112  		return &database.Error{OrigErr: err, Query: []byte(query)}
   113  	}
   114  	defer tables.Close()
   115  	tableNames := make([]string, 0)
   116  	for tables.Next() {
   117  		var tableName string
   118  		if err := tables.Scan(&tableName); err != nil {
   119  			return err
   120  		}
   121  		if len(tableName) > 0 {
   122  			if strings.HasPrefix(tableName, "__") == false {
   123  				tableNames = append(tableNames, tableName)
   124  			}
   125  		}
   126  	}
   127  	if len(tableNames) > 0 {
   128  		for _, t := range tableNames {
   129  			query := "DROP TABLE " + t
   130  			err = m.executeQuery(query)
   131  			if err != nil {
   132  				return &database.Error{OrigErr: err, Query: []byte(query)}
   133  			}
   134  		}
   135  		if err := m.ensureVersionTable(); err != nil {
   136  			return err
   137  		}
   138  	}
   139  
   140  	return nil
   141  }
   142  func (m *Ql) Lock() error {
   143  	if m.isLocked {
   144  		return database.ErrLocked
   145  	}
   146  	m.isLocked = true
   147  	return nil
   148  }
   149  func (m *Ql) Unlock() error {
   150  	if !m.isLocked {
   151  		return nil
   152  	}
   153  	m.isLocked = false
   154  	return nil
   155  }
   156  func (m *Ql) Run(migration io.Reader) error {
   157  	migr, err := ioutil.ReadAll(migration)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	query := string(migr[:])
   162  
   163  	return m.executeQuery(query)
   164  }
   165  func (m *Ql) executeQuery(query string) error {
   166  	tx, err := m.db.Begin()
   167  	if err != nil {
   168  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   169  	}
   170  	if _, err := tx.Exec(query); err != nil {
   171  		tx.Rollback()
   172  		return &database.Error{OrigErr: err, Query: []byte(query)}
   173  	}
   174  	if err := tx.Commit(); err != nil {
   175  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   176  	}
   177  	return nil
   178  }
   179  func (m *Ql) SetVersion(version int, dirty bool) error {
   180  	tx, err := m.db.Begin()
   181  	if err != nil {
   182  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   183  	}
   184  
   185  	query := "TRUNCATE TABLE " + m.config.MigrationsTable
   186  	if _, err := tx.Exec(query); err != nil {
   187  		return &database.Error{OrigErr: err, Query: []byte(query)}
   188  	}
   189  
   190  	if version >= 0 {
   191  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, %t)`, m.config.MigrationsTable, version, dirty)
   192  		if _, err := tx.Exec(query); err != nil {
   193  			tx.Rollback()
   194  			return &database.Error{OrigErr: err, Query: []byte(query)}
   195  		}
   196  	}
   197  
   198  	if err := tx.Commit(); err != nil {
   199  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   200  	}
   201  
   202  	return nil
   203  }
   204  
   205  func (m *Ql) Version() (version int, dirty bool, err error) {
   206  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   207  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   208  	if err != nil {
   209  		return database.NilVersion, false, nil
   210  	}
   211  	return version, dirty, nil
   212  }