github.com/fr-nvriep/migrate/v4@v4.3.2/database/sqlite3/sqlite3.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strings"
    10  
    11  	"github.com/fr-nvriep/migrate/v4"
    12  	"github.com/fr-nvriep/migrate/v4/database"
    13  	"github.com/hashicorp/go-multierror"
    14  	_ "github.com/mattn/go-sqlite3"
    15  )
    16  
    17  func init() {
    18  	database.Register("sqlite3", &Sqlite{})
    19  }
    20  
    21  var DefaultMigrationsTable = "schema_migrations"
    22  var (
    23  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    24  	ErrNilConfig      = fmt.Errorf("no config")
    25  	ErrNoDatabaseName = fmt.Errorf("no database name")
    26  )
    27  
// Config holds the settings used to construct a Sqlite driver.
type Config struct {
	// MigrationsTable names the table that stores the migration version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is filled by Open with the path component of the
	// connection URL.
	DatabaseName    string
}
    32  
// Sqlite is the driver type this package registers under the name "sqlite3".
type Sqlite struct {
	// db is the database handle all queries are issued against.
	db       *sql.DB
	// isLocked is the in-process flag toggled by Lock and Unlock.
	isLocked bool

	// config carries the migrations table name and database name.
	config *Config
}
    39  
    40  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    41  	if config == nil {
    42  		return nil, ErrNilConfig
    43  	}
    44  
    45  	if err := instance.Ping(); err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	if len(config.MigrationsTable) == 0 {
    50  		config.MigrationsTable = DefaultMigrationsTable
    51  	}
    52  
    53  	mx := &Sqlite{
    54  		db:     instance,
    55  		config: config,
    56  	}
    57  	if err := mx.ensureVersionTable(); err != nil {
    58  		return nil, err
    59  	}
    60  	return mx, nil
    61  }
    62  
    63  // ensureVersionTable checks if versions table exists and, if not, creates it.
    64  // Note that this function locks the database, which deviates from the usual
    65  // convention of "caller locks" in the Sqlite type.
    66  func (m *Sqlite) ensureVersionTable() (err error) {
    67  	if err = m.Lock(); err != nil {
    68  		return err
    69  	}
    70  
    71  	defer func() {
    72  		if e := m.Unlock(); e != nil {
    73  			if err == nil {
    74  				err = e
    75  			} else {
    76  				err = multierror.Append(err, e)
    77  			}
    78  		}
    79  	}()
    80  
    81  	query := fmt.Sprintf(`
    82  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    83    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    84    `, m.config.MigrationsTable, m.config.MigrationsTable)
    85  
    86  	if _, err := m.db.Exec(query); err != nil {
    87  		return err
    88  	}
    89  	return nil
    90  }
    91  
    92  func (m *Sqlite) Open(url string) (database.Driver, error) {
    93  	purl, err := nurl.Parse(url)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
    98  	db, err := sql.Open("sqlite3", dbfile)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	migrationsTable := purl.Query().Get("x-migrations-table")
   104  	if len(migrationsTable) == 0 {
   105  		migrationsTable = DefaultMigrationsTable
   106  	}
   107  	mx, err := WithInstance(db, &Config{
   108  		DatabaseName:    purl.Path,
   109  		MigrationsTable: migrationsTable,
   110  	})
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	return mx, nil
   115  }
   116  
   117  func (m *Sqlite) Close() error {
   118  	return m.db.Close()
   119  }
   120  
   121  func (m *Sqlite) Drop() (err error) {
   122  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   123  	tables, err := m.db.Query(query)
   124  	if err != nil {
   125  		return &database.Error{OrigErr: err, Query: []byte(query)}
   126  	}
   127  	defer func() {
   128  		if errClose := tables.Close(); errClose != nil {
   129  			err = multierror.Append(err, errClose)
   130  		}
   131  	}()
   132  	tableNames := make([]string, 0)
   133  	for tables.Next() {
   134  		var tableName string
   135  		if err := tables.Scan(&tableName); err != nil {
   136  			return err
   137  		}
   138  		if len(tableName) > 0 {
   139  			tableNames = append(tableNames, tableName)
   140  		}
   141  	}
   142  	if len(tableNames) > 0 {
   143  		for _, t := range tableNames {
   144  			query := "DROP TABLE " + t
   145  			err = m.executeQuery(query)
   146  			if err != nil {
   147  				return &database.Error{OrigErr: err, Query: []byte(query)}
   148  			}
   149  		}
   150  		query := "VACUUM"
   151  		_, err = m.db.Query(query)
   152  		if err != nil {
   153  			return &database.Error{OrigErr: err, Query: []byte(query)}
   154  		}
   155  	}
   156  
   157  	return nil
   158  }
   159  
   160  func (m *Sqlite) Lock() error {
   161  	if m.isLocked {
   162  		return database.ErrLocked
   163  	}
   164  	m.isLocked = true
   165  	return nil
   166  }
   167  
   168  func (m *Sqlite) Unlock() error {
   169  	if !m.isLocked {
   170  		return nil
   171  	}
   172  	m.isLocked = false
   173  	return nil
   174  }
   175  
   176  func (m *Sqlite) Run(migration io.Reader) error {
   177  	migr, err := ioutil.ReadAll(migration)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	query := string(migr[:])
   182  
   183  	return m.executeQuery(query)
   184  }
   185  
   186  func (m *Sqlite) executeQuery(query string) error {
   187  	tx, err := m.db.Begin()
   188  	if err != nil {
   189  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   190  	}
   191  	if _, err := tx.Exec(query); err != nil {
   192  		if errRollback := tx.Rollback(); errRollback != nil {
   193  			err = multierror.Append(err, errRollback)
   194  		}
   195  		return &database.Error{OrigErr: err, Query: []byte(query)}
   196  	}
   197  	if err := tx.Commit(); err != nil {
   198  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   199  	}
   200  	return nil
   201  }
   202  
   203  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   204  	tx, err := m.db.Begin()
   205  	if err != nil {
   206  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   207  	}
   208  
   209  	query := "DELETE FROM " + m.config.MigrationsTable
   210  	if _, err := tx.Exec(query); err != nil {
   211  		return &database.Error{OrigErr: err, Query: []byte(query)}
   212  	}
   213  
   214  	if version >= 0 {
   215  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (%d, '%t')`, m.config.MigrationsTable, version, dirty)
   216  		if _, err := tx.Exec(query); err != nil {
   217  			if errRollback := tx.Rollback(); errRollback != nil {
   218  				err = multierror.Append(err, errRollback)
   219  			}
   220  			return &database.Error{OrigErr: err, Query: []byte(query)}
   221  		}
   222  	}
   223  
   224  	if err := tx.Commit(); err != nil {
   225  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   226  	}
   227  
   228  	return nil
   229  }
   230  
   231  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   232  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   233  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   234  	if err != nil {
   235  		return database.NilVersion, false, nil
   236  	}
   237  	return version, dirty, nil
   238  }