github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/sqlcipher/sqlcipher.go (about)

     1  package sqlcipher
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	nurl "net/url"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"go.uber.org/atomic"
    13  
    14  	"github.com/hashicorp/go-multierror"
    15  	_ "github.com/mutecomm/go-sqlcipher/v4"
    16  	migrate "github.com/seashell-org/golang-migrate/v4"
    17  	"github.com/seashell-org/golang-migrate/v4/database"
    18  )
    19  
    20  func init() {
    21  	database.Register("sqlcipher", &Sqlite{})
    22  }
    23  
    24  var DefaultMigrationsTable = "schema_migrations"
    25  var (
    26  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    27  	ErrNilConfig      = fmt.Errorf("no config")
    28  	ErrNoDatabaseName = fmt.Errorf("no database name")
    29  )
    30  
    31  type Config struct {
    32  	MigrationsTable string
    33  	DatabaseName    string
    34  	NoTxWrap        bool
    35  }
    36  
    37  type Sqlite struct {
    38  	db       *sql.DB
    39  	isLocked atomic.Bool
    40  
    41  	config *Config
    42  }
    43  
    44  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    45  	if config == nil {
    46  		return nil, ErrNilConfig
    47  	}
    48  
    49  	if err := instance.Ping(); err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	if len(config.MigrationsTable) == 0 {
    54  		config.MigrationsTable = DefaultMigrationsTable
    55  	}
    56  
    57  	mx := &Sqlite{
    58  		db:     instance,
    59  		config: config,
    60  	}
    61  	if err := mx.ensureVersionTable(); err != nil {
    62  		return nil, err
    63  	}
    64  	return mx, nil
    65  }
    66  
    67  // ensureVersionTable checks if versions table exists and, if not, creates it.
    68  // Note that this function locks the database, which deviates from the usual
    69  // convention of "caller locks" in the Sqlite type.
    70  func (m *Sqlite) ensureVersionTable() (err error) {
    71  	if err = m.Lock(); err != nil {
    72  		return err
    73  	}
    74  
    75  	defer func() {
    76  		if e := m.Unlock(); e != nil {
    77  			if err == nil {
    78  				err = e
    79  			} else {
    80  				err = multierror.Append(err, e)
    81  			}
    82  		}
    83  	}()
    84  
    85  	query := fmt.Sprintf(`
    86  	CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool);
    87    CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version);
    88    `, m.config.MigrationsTable, m.config.MigrationsTable)
    89  
    90  	if _, err := m.db.Exec(query); err != nil {
    91  		return err
    92  	}
    93  	return nil
    94  }
    95  
    96  func (m *Sqlite) Open(url string) (database.Driver, error) {
    97  	purl, err := nurl.Parse(url)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	dbfile := strings.Replace(migrate.FilterCustomQuery(purl).String(), "sqlite3://", "", 1)
   102  	db, err := sql.Open("sqlite3", dbfile)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	qv := purl.Query()
   108  
   109  	migrationsTable := qv.Get("x-migrations-table")
   110  	if len(migrationsTable) == 0 {
   111  		migrationsTable = DefaultMigrationsTable
   112  	}
   113  
   114  	noTxWrap := false
   115  	if v := qv.Get("x-no-tx-wrap"); v != "" {
   116  		noTxWrap, err = strconv.ParseBool(v)
   117  		if err != nil {
   118  			return nil, fmt.Errorf("x-no-tx-wrap: %s", err)
   119  		}
   120  	}
   121  
   122  	mx, err := WithInstance(db, &Config{
   123  		DatabaseName:    purl.Path,
   124  		MigrationsTable: migrationsTable,
   125  		NoTxWrap:        noTxWrap,
   126  	})
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	return mx, nil
   131  }
   132  
   133  func (m *Sqlite) Close() error {
   134  	return m.db.Close()
   135  }
   136  
   137  func (m *Sqlite) Drop() (err error) {
   138  	query := `SELECT name FROM sqlite_master WHERE type = 'table';`
   139  	tables, err := m.db.Query(query)
   140  	if err != nil {
   141  		return &database.Error{OrigErr: err, Query: []byte(query)}
   142  	}
   143  	defer func() {
   144  		if errClose := tables.Close(); errClose != nil {
   145  			err = multierror.Append(err, errClose)
   146  		}
   147  	}()
   148  
   149  	tableNames := make([]string, 0)
   150  	for tables.Next() {
   151  		var tableName string
   152  		if err := tables.Scan(&tableName); err != nil {
   153  			return err
   154  		}
   155  		if len(tableName) > 0 {
   156  			tableNames = append(tableNames, tableName)
   157  		}
   158  	}
   159  	if err := tables.Err(); err != nil {
   160  		return &database.Error{OrigErr: err, Query: []byte(query)}
   161  	}
   162  
   163  	if len(tableNames) > 0 {
   164  		for _, t := range tableNames {
   165  			query := "DROP TABLE " + t
   166  			err = m.executeQuery(query)
   167  			if err != nil {
   168  				return &database.Error{OrigErr: err, Query: []byte(query)}
   169  			}
   170  		}
   171  		query := "VACUUM"
   172  		_, err = m.db.Query(query)
   173  		if err != nil {
   174  			return &database.Error{OrigErr: err, Query: []byte(query)}
   175  		}
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  func (m *Sqlite) Lock() error {
   182  	if !m.isLocked.CAS(false, true) {
   183  		return database.ErrLocked
   184  	}
   185  	return nil
   186  }
   187  
   188  func (m *Sqlite) Unlock() error {
   189  	if !m.isLocked.CAS(true, false) {
   190  		return database.ErrNotLocked
   191  	}
   192  	return nil
   193  }
   194  
   195  func (m *Sqlite) Run(migration io.Reader) error {
   196  	migr, err := ioutil.ReadAll(migration)
   197  	if err != nil {
   198  		return err
   199  	}
   200  	query := string(migr[:])
   201  
   202  	if m.config.NoTxWrap {
   203  		return m.executeQueryNoTx(query)
   204  	}
   205  	return m.executeQuery(query)
   206  }
   207  
   208  func (m *Sqlite) executeQuery(query string) error {
   209  	tx, err := m.db.Begin()
   210  	if err != nil {
   211  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   212  	}
   213  	if _, err := tx.Exec(query); err != nil {
   214  		if errRollback := tx.Rollback(); errRollback != nil {
   215  			err = multierror.Append(err, errRollback)
   216  		}
   217  		return &database.Error{OrigErr: err, Query: []byte(query)}
   218  	}
   219  	if err := tx.Commit(); err != nil {
   220  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   221  	}
   222  	return nil
   223  }
   224  
   225  func (m *Sqlite) executeQueryNoTx(query string) error {
   226  	if _, err := m.db.Exec(query); err != nil {
   227  		return &database.Error{OrigErr: err, Query: []byte(query)}
   228  	}
   229  	return nil
   230  }
   231  
   232  func (m *Sqlite) SetVersion(version int, dirty bool) error {
   233  	tx, err := m.db.Begin()
   234  	if err != nil {
   235  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   236  	}
   237  
   238  	query := "DELETE FROM " + m.config.MigrationsTable
   239  	if _, err := tx.Exec(query); err != nil {
   240  		return &database.Error{OrigErr: err, Query: []byte(query)}
   241  	}
   242  
   243  	// Also re-write the schema version for nil dirty versions to prevent
   244  	// empty schema version for failed down migration on the first migration
   245  	// See: https://github.com/seashell-org/golang-migrate/issues/330
   246  	if version >= 0 || (version == database.NilVersion && dirty) {
   247  		query := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, m.config.MigrationsTable)
   248  		if _, err := tx.Exec(query, version, dirty); err != nil {
   249  			if errRollback := tx.Rollback(); errRollback != nil {
   250  				err = multierror.Append(err, errRollback)
   251  			}
   252  			return &database.Error{OrigErr: err, Query: []byte(query)}
   253  		}
   254  	}
   255  
   256  	if err := tx.Commit(); err != nil {
   257  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   258  	}
   259  
   260  	return nil
   261  }
   262  
   263  func (m *Sqlite) Version() (version int, dirty bool, err error) {
   264  	query := "SELECT version, dirty FROM " + m.config.MigrationsTable + " LIMIT 1"
   265  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   266  	if err != nil {
   267  		return database.NilVersion, false, nil
   268  	}
   269  	return version, dirty, nil
   270  }