github.com/mrqzzz/migrate@v5.1.7+incompatible/database/sqlite3/sqlite3.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strings"
    10  
    11  	"github.com/golang-migrate/migrate/v4"
    12  	"github.com/golang-migrate/migrate/v4/database"
    13  	_ "github.com/mattn/go-sqlite3"
    14  )
    15  
    16  func init() {
    17  	database.Register("sqlite3", &Sqlite{})
    18  }
    19  
    20  var DefaultMigrationsTable = "schema_migrations"
    21  var (
    22  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    23  	ErrNilConfig      = fmt.Errorf("no config")
    24  	ErrNoDatabaseName = fmt.Errorf("no database name")
    25  )
    26  
    27  type Config struct {
    28  	MigrationsTable string
    29  	DatabaseName    string
    30  }
    31  
    32  type Sqlite struct {
    33  	db       *sql.DB
    34  	isLocked bool
    35  
    36  	config *Config
    37  }
    38  
    39  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    40  	if config == nil {
    41  		return nil, ErrNilConfig
    42  	}
    43  
    44  	if err := instance.Ping(); err != nil {
    45  		return nil, err
    46  	}
    47  	if len(config.MigrationsTable) == 0 {
    48  		config.MigrationsTable = DefaultMigrationsTable
    49  	}
    50  
    51  	mx := &Sqlite{
    52  		db:     instance,
    53  		config: config,
    54  	}
    55  	if err := mx.ensureVersionTable(); err != nil {
    56  		return nil, err
    57  	}
    58  	return mx, nil
    59  }
    60  
    61  func (m *Sqlite) ensureVersionTable() error {
    62  
    63  	query := fmt.Sprintf(`
    64  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    65    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    66    `, m.config.MigrationsTable, m.config.MigrationsTable)
    67  
    68  	if _, err := m.db.Exec(query); err != nil {
    69  		return err
    70  	}
    71  	return nil
    72  }
    73  
    74  func (m *Sqlite) Open(url string) (database.Driver, error) {
    75  	purl, err := nurl.Parse(url)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
    80  	db, err := sql.Open("sqlite3", dbfile)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	migrationsTable := purl.Query().Get("x-migrations-table")
    86  	if len(migrationsTable) == 0 {
    87  		migrationsTable = DefaultMigrationsTable
    88  	}
    89  	mx, err := WithInstance(db, &Config{
    90  		DatabaseName:    purl.Path,
    91  		MigrationsTable: migrationsTable,
    92  	})
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	return mx, nil
    97  }
    98  
    99  func (m *Sqlite) Close() error {
   100  	return m.db.Close()
   101  }
   102  
   103  func (m *Sqlite) Drop() error {
   104  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   105  	tables, err := m.db.Query(query)
   106  	if err != nil {
   107  		return &database.Error{OrigErr: err, Query: []byte(query)}
   108  	}
   109  	defer tables.Close()
   110  	tableNames := make([]string, 0)
   111  	for tables.Next() {
   112  		var tableName string
   113  		if err := tables.Scan(&tableName); err != nil {
   114  			return err
   115  		}
   116  		if len(tableName) > 0 {
   117  			tableNames = append(tableNames, tableName)
   118  		}
   119  	}
   120  	if len(tableNames) > 0 {
   121  		for _, t := range tableNames {
   122  			query := "DROP TABLE " + t
   123  			err = m.executeQuery(query)
   124  			if err != nil {
   125  				return &database.Error{OrigErr: err, Query: []byte(query)}
   126  			}
   127  		}
   128  		if err := m.ensureVersionTable(); err != nil {
   129  			return err
   130  		}
   131  		query := "VACUUM"
   132  		_, err = m.db.Query(query)
   133  		if err != nil {
   134  			return &database.Error{OrigErr: err, Query: []byte(query)}
   135  		}
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  func (m *Sqlite) Lock() error {
   142  	if m.isLocked {
   143  		return database.ErrLocked
   144  	}
   145  	m.isLocked = true
   146  	return nil
   147  }
   148  
   149  func (m *Sqlite) Unlock() error {
   150  	if !m.isLocked {
   151  		return nil
   152  	}
   153  	m.isLocked = false
   154  	return nil
   155  }
   156  
   157  func (m *Sqlite) Run(migration io.Reader) error {
   158  	migr, err := ioutil.ReadAll(migration)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	query := string(migr[:])
   163  
   164  	return m.executeQuery(query)
   165  }
   166  
   167  func (m *Sqlite) executeQuery(query string) error {
   168  	tx, err := m.db.Begin()
   169  	if err != nil {
   170  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   171  	}
   172  	if _, err := tx.Exec(query); err != nil {
   173  		tx.Rollback()
   174  		return &database.Error{OrigErr: err, Query: []byte(query)}
   175  	}
   176  	if err := tx.Commit(); err != nil {
   177  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   178  	}
   179  	return nil
   180  }
   181  
   182  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   183  	tx, err := m.db.Begin()
   184  	if err != nil {
   185  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   186  	}
   187  
   188  	query := "DELETE FROM " + m.config.MigrationsTable
   189  	if _, err := tx.Exec(query); err != nil {
   190  		return &database.Error{OrigErr: err, Query: []byte(query)}
   191  	}
   192  
   193  	if version >= 0 {
   194  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, '%t')`, m.config.MigrationsTable, version, dirty)
   195  		if _, err := tx.Exec(query); err != nil {
   196  			tx.Rollback()
   197  			return &database.Error{OrigErr: err, Query: []byte(query)}
   198  		}
   199  	}
   200  
   201  	if err := tx.Commit(); err != nil {
   202  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   203  	}
   204  
   205  	return nil
   206  }
   207  
   208  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   209  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   210  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   211  	if err != nil {
   212  		return database.NilVersion, false, nil
   213  	}
   214  	return version, dirty, nil
   215  }