// Source listing: github.com/nagyistzcons/migrate/v4@v4.14.5/database/sqlcipher/sqlcipher.go

     1  package sqlcipher
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/golang-migrate/migrate/v4"
    13  	"github.com/golang-migrate/migrate/v4/database"
    14  	"github.com/hashicorp/go-multierror"
    15  	_ "github.com/mutecomm/go-sqlcipher/v4"
    16  )
    17  
    18  func init() {
    19  	database.Register("sqlcipher", &Sqlite{})
    20  }
    21  
    22  var DefaultMigrationsTable = "schema_migrations"
    23  var (
    24  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    25  	ErrNilConfig      = fmt.Errorf("no config")
    26  	ErrNoDatabaseName = fmt.Errorf("no database name")
    27  )
    28  
    29  type Config struct {
    30  	MigrationsTable string
    31  	DatabaseName    string
    32  	NoTxWrap        bool
    33  }
    34  
    35  type Sqlite struct {
    36  	db       *sql.DB
    37  	isLocked bool
    38  
    39  	config *Config
    40  }
    41  
    42  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    43  	if config == nil {
    44  		return nil, ErrNilConfig
    45  	}
    46  
    47  	if err := instance.Ping(); err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	if len(config.MigrationsTable) == 0 {
    52  		config.MigrationsTable = DefaultMigrationsTable
    53  	}
    54  
    55  	mx := &Sqlite{
    56  		db:     instance,
    57  		config: config,
    58  	}
    59  	if err := mx.ensureVersionTable(); err != nil {
    60  		return nil, err
    61  	}
    62  	return mx, nil
    63  }
    64  
    65  // ensureVersionTable checks if versions table exists and, if not, creates it.
    66  // Note that this function locks the database, which deviates from the usual
    67  // convention of "caller locks" in the Sqlite type.
    68  func (m *Sqlite) ensureVersionTable() (err error) {
    69  	if err = m.Lock(); err != nil {
    70  		return err
    71  	}
    72  
    73  	defer func() {
    74  		if e := m.Unlock(); e != nil {
    75  			if err == nil {
    76  				err = e
    77  			} else {
    78  				err = multierror.Append(err, e)
    79  			}
    80  		}
    81  	}()
    82  
    83  	query := fmt.Sprintf(`
    84  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    85    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    86    `, m.config.MigrationsTable, m.config.MigrationsTable)
    87  
    88  	if _, err := m.db.Exec(query); err != nil {
    89  		return err
    90  	}
    91  	return nil
    92  }
    93  
    94  func (m *Sqlite) Open(url string) (database.Driver, error) {
    95  	purl, err := nurl.Parse(url)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
   100  	db, err := sql.Open("sqlite3", dbfile)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	qv := purl.Query()
   106  
   107  	migrationsTable := qv.Get("x-migrations-table")
   108  	if len(migrationsTable) == 0 {
   109  		migrationsTable = DefaultMigrationsTable
   110  	}
   111  
   112  	noTxWrap := false
   113  	if v := qv.Get("x-no-tx-wrap"); v != "" {
   114  		noTxWrap, err = strconv.ParseBool(v)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
   117  		}
   118  	}
   119  
   120  	mx, err := WithInstance(db, &Config{
   121  		DatabaseName:    purl.Path,
   122  		MigrationsTable: migrationsTable,
   123  		NoTxWrap:        noTxWrap,
   124  	})
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	return mx, nil
   129  }
   130  
   131  func (m *Sqlite) Close() error {
   132  	return m.db.Close()
   133  }
   134  
   135  func (m *Sqlite) Drop() (err error) {
   136  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   137  	tables, err := m.db.Query(query)
   138  	if err != nil {
   139  		return &database.Error{OrigErr: err, Query: []byte(query)}
   140  	}
   141  	defer func() {
   142  		if errClose := tables.Close(); errClose != nil {
   143  			err = multierror.Append(err, errClose)
   144  		}
   145  	}()
   146  
   147  	tableNames := make([]string, 0)
   148  	for tables.Next() {
   149  		var tableName string
   150  		if err := tables.Scan(&tableName); err != nil {
   151  			return err
   152  		}
   153  		if len(tableName) > 0 {
   154  			tableNames = append(tableNames, tableName)
   155  		}
   156  	}
   157  	if err := tables.Err(); err != nil {
   158  		return &database.Error{OrigErr: err, Query: []byte(query)}
   159  	}
   160  
   161  	if len(tableNames) > 0 {
   162  		for _, t := range tableNames {
   163  			query := "DROP TABLE " + t
   164  			err = m.executeQuery(query)
   165  			if err != nil {
   166  				return &database.Error{OrigErr: err, Query: []byte(query)}
   167  			}
   168  		}
   169  		query := "VACUUM"
   170  		_, err = m.db.Query(query)
   171  		if err != nil {
   172  			return &database.Error{OrigErr: err, Query: []byte(query)}
   173  		}
   174  	}
   175  
   176  	return nil
   177  }
   178  
   179  func (m *Sqlite) Lock() error {
   180  	if m.isLocked {
   181  		return database.ErrLocked
   182  	}
   183  	m.isLocked = true
   184  	return nil
   185  }
   186  
   187  func (m *Sqlite) Unlock() error {
   188  	if !m.isLocked {
   189  		return nil
   190  	}
   191  	m.isLocked = false
   192  	return nil
   193  }
   194  
   195  func (m *Sqlite) Run(migration io.Reader) error {
   196  	migr, err := ioutil.ReadAll(migration)
   197  	if err != nil {
   198  		return err
   199  	}
   200  	query := string(migr[:])
   201  
   202  	if m.config.NoTxWrap {
   203  		return m.executeQueryNoTx(query)
   204  	}
   205  	return m.executeQuery(query)
   206  }
   207  
   208  func (m *Sqlite) executeQuery(query string) error {
   209  	tx, err := m.db.Begin()
   210  	if err != nil {
   211  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   212  	}
   213  	if _, err := tx.Exec(query); err != nil {
   214  		if errRollback := tx.Rollback(); errRollback != nil {
   215  			err = multierror.Append(err, errRollback)
   216  		}
   217  		return &database.Error{OrigErr: err, Query: []byte(query)}
   218  	}
   219  	if err := tx.Commit(); err != nil {
   220  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   221  	}
   222  	return nil
   223  }
   224  
   225  func (m *Sqlite) executeQueryNoTx(query string) error {
   226  	if _, err := m.db.Exec(query); err != nil {
   227  		return &database.Error{OrigErr: err, Query: []byte(query)}
   228  	}
   229  	return nil
   230  }
   231  
   232  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   233  	tx, err := m.db.Begin()
   234  	if err != nil {
   235  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   236  	}
   237  
   238  	query := "DELETE FROM " + m.config.MigrationsTable
   239  	if _, err := tx.Exec(query); err != nil {
   240  		return &database.Error{OrigErr: err, Query: []byte(query)}
   241  	}
   242  
   243  	// Also re-write the schema version for nil dirty versions to prevent
   244  	// empty schema version for failed down migration on the first migration
   245  	// See: https://github.com/golang-migrate/migrate/issues/330
   246  	if version >= 0 || (version == database.NilVersion && dirty) {
   247  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
   248  		if _, err := tx.Exec(query, version, dirty); err != nil {
   249  			if errRollback := tx.Rollback(); errRollback != nil {
   250  				err = multierror.Append(err, errRollback)
   251  			}
   252  			return &database.Error{OrigErr: err, Query: []byte(query)}
   253  		}
   254  	}
   255  
   256  	if err := tx.Commit(); err != nil {
   257  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   258  	}
   259  
   260  	return nil
   261  }
   262  
   263  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   264  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   265  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   266  	if err != nil {
   267  		return database.NilVersion, false, nil
   268  	}
   269  	return version, dirty, nil
   270  }