github.com/dhui/migrate@v3.4.0+incompatible/database/sqlite3/sqlite3.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"github.com/golang-migrate/migrate"
     7  	"github.com/golang-migrate/migrate/database"
     8  	_ "github.com/mattn/go-sqlite3"
     9  	"io"
    10  	"io/ioutil"
    11  	nurl "net/url"
    12  	"strings"
    13  )
    14  
    15  func init() {
    16  	database.Register("sqlite3", &Sqlite{})
    17  }
    18  
    19  var DefaultMigrationsTable = "schema_migrations"
    20  var (
    21  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    22  	ErrNilConfig      = fmt.Errorf("no config")
    23  	ErrNoDatabaseName = fmt.Errorf("no database name")
    24  )
    25  
    26  type Config struct {
    27  	MigrationsTable string
    28  	DatabaseName    string
    29  }
    30  
    31  type Sqlite struct {
    32  	db       *sql.DB
    33  	isLocked bool
    34  
    35  	config *Config
    36  }
    37  
    38  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    39  	if config == nil {
    40  		return nil, ErrNilConfig
    41  	}
    42  
    43  	if err := instance.Ping(); err != nil {
    44  		return nil, err
    45  	}
    46  	if len(config.MigrationsTable) == 0 {
    47  		config.MigrationsTable = DefaultMigrationsTable
    48  	}
    49  
    50  	mx := &Sqlite{
    51  		db:     instance,
    52  		config: config,
    53  	}
    54  	if err := mx.ensureVersionTable(); err != nil {
    55  		return nil, err
    56  	}
    57  	return mx, nil
    58  }
    59  
    60  func (m *Sqlite) ensureVersionTable() error {
    61  
    62  	query := fmt.Sprintf(`
    63  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    64    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    65    `, DefaultMigrationsTable, DefaultMigrationsTable)
    66  
    67  	if _, err := m.db.Exec(query); err != nil {
    68  		return err
    69  	}
    70  	return nil
    71  }
    72  
    73  func (m *Sqlite) Open(url string) (database.Driver, error) {
    74  	purl, err := nurl.Parse(url)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
    79  	db, err := sql.Open("sqlite3", dbfile)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	migrationsTable := purl.Query().Get("x-migrations-table")
    85  	if len(migrationsTable) == 0 {
    86  		migrationsTable = DefaultMigrationsTable
    87  	}
    88  	mx, err := WithInstance(db, &Config{
    89  		DatabaseName:    purl.Path,
    90  		MigrationsTable: migrationsTable,
    91  	})
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	return mx, nil
    96  }
    97  
    98  func (m *Sqlite) Close() error {
    99  	return m.db.Close()
   100  }
   101  
   102  func (m *Sqlite) Drop() error {
   103  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   104  	tables, err := m.db.Query(query)
   105  	if err != nil {
   106  		return &database.Error{OrigErr: err, Query: []byte(query)}
   107  	}
   108  	defer tables.Close()
   109  	tableNames := make([]string, 0)
   110  	for tables.Next() {
   111  		var tableName string
   112  		if err := tables.Scan(&tableName); err != nil {
   113  			return err
   114  		}
   115  		if len(tableName) > 0 {
   116  			tableNames = append(tableNames, tableName)
   117  		}
   118  	}
   119  	if len(tableNames) > 0 {
   120  		for _, t := range tableNames {
   121  			query := "DROP TABLE " + t
   122  			err = m.executeQuery(query)
   123  			if err != nil {
   124  				return &database.Error{OrigErr: err, Query: []byte(query)}
   125  			}
   126  		}
   127  		if err := m.ensureVersionTable(); err != nil {
   128  			return err
   129  		}
   130  		query := "VACUUM"
   131  		_, err = m.db.Query(query)
   132  		if err != nil {
   133  			return &database.Error{OrigErr: err, Query: []byte(query)}
   134  		}
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  func (m *Sqlite) Lock() error {
   141  	if m.isLocked {
   142  		return database.ErrLocked
   143  	}
   144  	m.isLocked = true
   145  	return nil
   146  }
   147  
   148  func (m *Sqlite) Unlock() error {
   149  	if !m.isLocked {
   150  		return nil
   151  	}
   152  	m.isLocked = false
   153  	return nil
   154  }
   155  
   156  func (m *Sqlite) Run(migration io.Reader) error {
   157  	migr, err := ioutil.ReadAll(migration)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	query := string(migr[:])
   162  
   163  	return m.executeQuery(query)
   164  }
   165  
   166  func (m *Sqlite) executeQuery(query string) error {
   167  	tx, err := m.db.Begin()
   168  	if err != nil {
   169  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   170  	}
   171  	if _, err := tx.Exec(query); err != nil {
   172  		tx.Rollback()
   173  		return &database.Error{OrigErr: err, Query: []byte(query)}
   174  	}
   175  	if err := tx.Commit(); err != nil {
   176  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   177  	}
   178  	return nil
   179  }
   180  
   181  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   182  	tx, err := m.db.Begin()
   183  	if err != nil {
   184  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   185  	}
   186  
   187  	query := "DELETE FROM " + m.config.MigrationsTable
   188  	if _, err := tx.Exec(query); err != nil {
   189  		return &database.Error{OrigErr: err, Query: []byte(query)}
   190  	}
   191  
   192  	if version >= 0 {
   193  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, '%t')`, m.config.MigrationsTable, version, dirty)
   194  		if _, err := tx.Exec(query); err != nil {
   195  			tx.Rollback()
   196  			return &database.Error{OrigErr: err, Query: []byte(query)}
   197  		}
   198  	}
   199  
   200  	if err := tx.Commit(); err != nil {
   201  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   208  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   209  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   210  	if err != nil {
   211  		return database.NilVersion, false, nil
   212  	}
   213  	return version, dirty, nil
   214  }