github.com/dynastymasra/migrate/v4@v4.11.0/database/sqlite3/sqlite3.go (about)

     1  package sqlite3
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/golang-migrate/migrate/v4"
    13  	"github.com/golang-migrate/migrate/v4/database"
    14  	"github.com/hashicorp/go-multierror"
    15  	_ "github.com/mattn/go-sqlite3"
    16  )
    17  
    18  func init() {
    19  	database.Register("sqlite3", &Sqlite{})
    20  }
    21  
    // Package-level settings and sentinel errors for the sqlite3 driver.
var (
	// DefaultMigrationsTable is the table name used to record the schema
	// version when neither x-migrations-table nor Config.MigrationsTable
	// supplies one.
	DefaultMigrationsTable = "schema_migrations"

	// ErrDatabaseDirty reports a database left in a dirty state. It is not
	// returned by any code in this file.
	ErrDatabaseDirty = fmt.Errorf("database is dirty")

	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = fmt.Errorf("no config")

	// ErrNoDatabaseName reports a missing database name. It is not returned
	// by any code in this file.
	ErrNoDatabaseName = fmt.Errorf("no database name")
)
    28  
    // Config holds the options WithInstance uses to build a sqlite3 driver.
type Config struct {
	// MigrationsTable names the table that stores the schema version.
	// WithInstance replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string

	// DatabaseName is filled from the URL path by Open; nothing else in this
	// file reads it.
	DatabaseName string

	// NoTxWrap makes Run execute a migration directly on the database handle
	// instead of wrapping it in a transaction.
	NoTxWrap bool
}
    34  
    35  type Sqlite struct {
    36  	db       *sql.DB
    37  	isLocked bool
    38  
    39  	config *Config
    40  }
    41  
    42  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    43  	if config == nil {
    44  		return nil, ErrNilConfig
    45  	}
    46  
    47  	if err := instance.Ping(); err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	if len(config.MigrationsTable) == 0 {
    52  		config.MigrationsTable = DefaultMigrationsTable
    53  	}
    54  
    55  	mx := &Sqlite{
    56  		db:     instance,
    57  		config: config,
    58  	}
    59  	if err := mx.ensureVersionTable(); err != nil {
    60  		return nil, err
    61  	}
    62  	return mx, nil
    63  }
    64  
    65  // ensureVersionTable checks if versions table exists and, if not, creates it.
    66  // Note that this function locks the database, which deviates from the usual
    67  // convention of "caller locks" in the Sqlite type.
    68  func (m *Sqlite) ensureVersionTable() (err error) {
    69  	if err = m.Lock(); err != nil {
    70  		return err
    71  	}
    72  
    73  	defer func() {
    74  		if e := m.Unlock(); e != nil {
    75  			if err == nil {
    76  				err = e
    77  			} else {
    78  				err = multierror.Append(err, e)
    79  			}
    80  		}
    81  	}()
    82  
    83  	query := fmt.Sprintf(`
    84  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    85    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    86    `, m.config.MigrationsTable, m.config.MigrationsTable)
    87  
    88  	if _, err := m.db.Exec(query); err != nil {
    89  		return err
    90  	}
    91  	return nil
    92  }
    93  
    94  func (m *Sqlite) Open(url string) (database.Driver, error) {
    95  	purl, err := nurl.Parse(url)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
   100  	db, err := sql.Open("sqlite3", dbfile)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	qv := purl.Query()
   106  
   107  	migrationsTable := qv.Get("x-migrations-table")
   108  	if len(migrationsTable) == 0 {
   109  		migrationsTable = DefaultMigrationsTable
   110  	}
   111  
   112  	noTxWrap := false
   113  	if v := qv.Get("x-no-tx-wrap"); v != "" {
   114  		noTxWrap, err = strconv.ParseBool(v)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
   117  		}
   118  	}
   119  
   120  	mx, err := WithInstance(db, &Config{
   121  		DatabaseName:    purl.Path,
   122  		MigrationsTable: migrationsTable,
   123  		NoTxWrap:        noTxWrap,
   124  	})
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	return mx, nil
   129  }
   130  
   131  func (m *Sqlite) Close() error {
   132  	return m.db.Close()
   133  }
   134  
   135  func (m *Sqlite) Drop() (err error) {
   136  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   137  	tables, err := m.db.Query(query)
   138  	if err != nil {
   139  		return &database.Error{OrigErr: err, Query: []byte(query)}
   140  	}
   141  	defer func() {
   142  		if errClose := tables.Close(); errClose != nil {
   143  			err = multierror.Append(err, errClose)
   144  		}
   145  	}()
   146  	tableNames := make([]string, 0)
   147  	for tables.Next() {
   148  		var tableName string
   149  		if err := tables.Scan(&tableName); err != nil {
   150  			return err
   151  		}
   152  		if len(tableName) > 0 {
   153  			tableNames = append(tableNames, tableName)
   154  		}
   155  	}
   156  	if len(tableNames) > 0 {
   157  		for _, t := range tableNames {
   158  			query := "DROP TABLE " + t
   159  			err = m.executeQuery(query)
   160  			if err != nil {
   161  				return &database.Error{OrigErr: err, Query: []byte(query)}
   162  			}
   163  		}
   164  		query := "VACUUM"
   165  		_, err = m.db.Query(query)
   166  		if err != nil {
   167  			return &database.Error{OrigErr: err, Query: []byte(query)}
   168  		}
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  func (m *Sqlite) Lock() error {
   175  	if m.isLocked {
   176  		return database.ErrLocked
   177  	}
   178  	m.isLocked = true
   179  	return nil
   180  }
   181  
   182  func (m *Sqlite) Unlock() error {
   183  	if !m.isLocked {
   184  		return nil
   185  	}
   186  	m.isLocked = false
   187  	return nil
   188  }
   189  
   190  func (m *Sqlite) Run(migration io.Reader) error {
   191  	migr, err := ioutil.ReadAll(migration)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	query := string(migr[:])
   196  
   197  	if m.config.NoTxWrap {
   198  		return m.executeQueryNoTx(query)
   199  	}
   200  	return m.executeQuery(query)
   201  }
   202  
   203  func (m *Sqlite) executeQuery(query string) error {
   204  	tx, err := m.db.Begin()
   205  	if err != nil {
   206  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   207  	}
   208  	if _, err := tx.Exec(query); err != nil {
   209  		if errRollback := tx.Rollback(); errRollback != nil {
   210  			err = multierror.Append(err, errRollback)
   211  		}
   212  		return &database.Error{OrigErr: err, Query: []byte(query)}
   213  	}
   214  	if err := tx.Commit(); err != nil {
   215  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   216  	}
   217  	return nil
   218  }
   219  
   220  func (m *Sqlite) executeQueryNoTx(query string) error {
   221  	if _, err := m.db.Exec(query); err != nil {
   222  		return &database.Error{OrigErr: err, Query: []byte(query)}
   223  	}
   224  	return nil
   225  }
   226  
   227  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   228  	tx, err := m.db.Begin()
   229  	if err != nil {
   230  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   231  	}
   232  
   233  	query := "DELETE FROM " + m.config.MigrationsTable
   234  	if _, err := tx.Exec(query); err != nil {
   235  		return &database.Error{OrigErr: err, Query: []byte(query)}
   236  	}
   237  
   238  	// Also re-write the schema version for nil dirty versions to prevent
   239  	// empty schema version for failed down migration on the first migration
   240  	// See: https://github.com/golang-migrate/migrate/issues/330
   241  	if version >= 0 || (version == database.NilVersion && dirty) {
   242  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
   243  		if _, err := tx.Exec(query, version, dirty); err != nil {
   244  			if errRollback := tx.Rollback(); errRollback != nil {
   245  				err = multierror.Append(err, errRollback)
   246  			}
   247  			return &database.Error{OrigErr: err, Query: []byte(query)}
   248  		}
   249  	}
   250  
   251  	if err := tx.Commit(); err != nil {
   252  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   253  	}
   254  
   255  	return nil
   256  }
   257  
   258  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   259  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   260  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   261  	if err != nil {
   262  		return database.NilVersion, false, nil
   263  	}
   264  	return version, dirty, nil
   265  }