github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/sqlite/sqlite.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	nurl "net/url"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"go.uber.org/atomic"
    12  
    13  	"github.com/golang-migrate/migrate/v4"
    14  	"github.com/golang-migrate/migrate/v4/database"
    15  	"github.com/hashicorp/go-multierror"
    16  	_ "modernc.org/sqlite"
    17  )
    18  
    19  func init() {
    20  	database.Register("sqlite", &Sqlite{})
    21  }
    22  
    23  var DefaultMigrationsTable = "schema_migrations"
    24  var (
    25  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    26  	ErrNilConfig      = fmt.Errorf("no config")
    27  	ErrNoDatabaseName = fmt.Errorf("no database name")
    28  )
    29  
    30  type Config struct {
    31  	MigrationsTable string
    32  	DatabaseName    string
    33  	NoTxWrap        bool
    34  }
    35  
    36  type Sqlite struct {
    37  	db       *sql.DB
    38  	isLocked atomic.Bool
    39  
    40  	config *Config
    41  }
    42  
    43  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    44  	if config == nil {
    45  		return nil, ErrNilConfig
    46  	}
    47  
    48  	if err := instance.Ping(); err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	if len(config.MigrationsTable) == 0 {
    53  		config.MigrationsTable = DefaultMigrationsTable
    54  	}
    55  
    56  	mx := &Sqlite{
    57  		db:     instance,
    58  		config: config,
    59  	}
    60  	if err := mx.ensureVersionTable(); err != nil {
    61  		return nil, err
    62  	}
    63  	return mx, nil
    64  }
    65  
    66  // ensureVersionTable checks if versions table exists and, if not, creates it.
    67  // Note that this function locks the database, which deviates from the usual
    68  // convention of "caller locks" in the Sqlite type.
    69  func (m *Sqlite) ensureVersionTable() (err error) {
    70  	if err = m.Lock(); err != nil {
    71  		return err
    72  	}
    73  
    74  	defer func() {
    75  		if e := m.Unlock(); e != nil {
    76  			if err == nil {
    77  				err = e
    78  			} else {
    79  				err = multierror.Append(err, e)
    80  			}
    81  		}
    82  	}()
    83  
    84  	query := fmt.Sprintf(`
    85  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    86    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    87    `, m.config.MigrationsTable, m.config.MigrationsTable)
    88  
    89  	if _, err := m.db.Exec(query); err != nil {
    90  		return err
    91  	}
    92  	return nil
    93  }
    94  
    95  func (m *Sqlite) Open(url string) (database.Driver, error) {
    96  	purl, err := nurl.Parse(url)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite://", "", 1)
   101  	db, err := sql.Open("sqlite", dbfile)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	qv := purl.Query()
   107  
   108  	migrationsTable := qv.Get("x-migrations-table")
   109  	if len(migrationsTable) == 0 {
   110  		migrationsTable = DefaultMigrationsTable
   111  	}
   112  
   113  	noTxWrap := false
   114  	if v := qv.Get("x-no-tx-wrap"); v != "" {
   115  		noTxWrap, err = strconv.ParseBool(v)
   116  		if err != nil {
   117  			return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
   118  		}
   119  	}
   120  
   121  	mx, err := WithInstance(db, &Config{
   122  		DatabaseName:    purl.Path,
   123  		MigrationsTable: migrationsTable,
   124  		NoTxWrap:        noTxWrap,
   125  	})
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	return mx, nil
   130  }
   131  
   132  func (m *Sqlite) Close() error {
   133  	return m.db.Close()
   134  }
   135  
   136  func (m *Sqlite) Drop() (err error) {
   137  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   138  	tables, err := m.db.Query(query)
   139  	if err != nil {
   140  		return &database.Error{OrigErr: err, Query: []byte(query)}
   141  	}
   142  	defer func() {
   143  		if errClose := tables.Close(); errClose != nil {
   144  			err = multierror.Append(err, errClose)
   145  		}
   146  	}()
   147  
   148  	tableNames := make([]string, 0)
   149  	for tables.Next() {
   150  		var tableName string
   151  		if err := tables.Scan(&tableName); err != nil {
   152  			return err
   153  		}
   154  		if len(tableName) > 0 {
   155  			tableNames = append(tableNames, tableName)
   156  		}
   157  	}
   158  	if err := tables.Err(); err != nil {
   159  		return &database.Error{OrigErr: err, Query: []byte(query)}
   160  	}
   161  
   162  	if len(tableNames) > 0 {
   163  		for _, t := range tableNames {
   164  			query := "DROP TABLE " + t
   165  			err = m.executeQuery(query)
   166  			if err != nil {
   167  				return &database.Error{OrigErr: err, Query: []byte(query)}
   168  			}
   169  		}
   170  		query := "VACUUM"
   171  		_, err = m.db.Query(query)
   172  		if err != nil {
   173  			return &database.Error{OrigErr: err, Query: []byte(query)}
   174  		}
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  func (m *Sqlite) Lock() error {
   181  	if !m.isLocked.CAS(false, true) {
   182  		return database.ErrLocked
   183  	}
   184  	return nil
   185  }
   186  
   187  func (m *Sqlite) Unlock() error {
   188  	if !m.isLocked.CAS(true, false) {
   189  		return database.ErrNotLocked
   190  	}
   191  	return nil
   192  }
   193  
   194  func (m *Sqlite) Run(migration io.Reader) error {
   195  	migr, err := io.ReadAll(migration)
   196  	if err != nil {
   197  		return err
   198  	}
   199  	query := string(migr[:])
   200  
   201  	if m.config.NoTxWrap {
   202  		return m.executeQueryNoTx(query)
   203  	}
   204  	return m.executeQuery(query)
   205  }
   206  
   207  func (m *Sqlite) executeQuery(query string) error {
   208  	tx, err := m.db.Begin()
   209  	if err != nil {
   210  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   211  	}
   212  	if _, err := tx.Exec(query); err != nil {
   213  		if errRollback := tx.Rollback(); errRollback != nil {
   214  			err = multierror.Append(err, errRollback)
   215  		}
   216  		return &database.Error{OrigErr: err, Query: []byte(query)}
   217  	}
   218  	if err := tx.Commit(); err != nil {
   219  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   220  	}
   221  	return nil
   222  }
   223  
   224  func (m *Sqlite) executeQueryNoTx(query string) error {
   225  	if _, err := m.db.Exec(query); err != nil {
   226  		return &database.Error{OrigErr: err, Query: []byte(query)}
   227  	}
   228  	return nil
   229  }
   230  
   231  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   232  	tx, err := m.db.Begin()
   233  	if err != nil {
   234  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   235  	}
   236  
   237  	query := "DELETE FROM " + m.config.MigrationsTable
   238  	if _, err := tx.Exec(query); err != nil {
   239  		return &database.Error{OrigErr: err, Query: []byte(query)}
   240  	}
   241  
   242  	// Also re-write the schema version for nil dirty versions to prevent
   243  	// empty schema version for failed down migration on the first migration
   244  	// See: https://github.com/golang-migrate/migrate/issues/330
   245  	if version >= 0 || (version == database.NilVersion && dirty) {
   246  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
   247  		if _, err := tx.Exec(query, version, dirty); err != nil {
   248  			if errRollback := tx.Rollback(); errRollback != nil {
   249  				err = multierror.Append(err, errRollback)
   250  			}
   251  			return &database.Error{OrigErr: err, Query: []byte(query)}
   252  		}
   253  	}
   254  
   255  	if err := tx.Commit(); err != nil {
   256  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   263  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   264  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   265  	if err != nil {
   266  		return database.NilVersion, false, nil
   267  	}
   268  	return version, dirty, nil
   269  }