github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/database/rqlite/rqlite.go (about)

     1  package rqlite
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	nurl "net/url"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"go.uber.org/atomic"
    11  
    12  	"github.com/bdollma-te/migrate/v4"
    13  	"github.com/bdollma-te/migrate/v4/database"
    14  	"github.com/hashicorp/go-multierror"
    15  	"github.com/pkg/errors"
    16  	"github.com/rqlite/gorqlite"
    17  )
    18  
    19  func init() {
    20  	database.Register("rqlite", &Rqlite{})
    21  }
    22  
    23  const (
    24  	// DefaultMigrationsTable defines the default rqlite migrations table
    25  	DefaultMigrationsTable = "schema_migrations"
    26  
    27  	// DefaultConnectInsecure defines the default setting for connect insecure
    28  	DefaultConnectInsecure = false
    29  )
    30  
    31  // ErrNilConfig is returned if no configuration was passed to WithInstance
    32  var ErrNilConfig = fmt.Errorf("no config")
    33  
    34  // ErrBadConfig is returned if configuration was invalid
    35  var ErrBadConfig = fmt.Errorf("bad parameter")
    36  
// Config defines the driver configuration.
type Config struct {
	// ConnectInsecure selects plain HTTP instead of HTTPS when the driver
	// builds its own connection from a rqlite:// URL (true => http,
	// false => https). It has no effect with WithInstance, where the caller
	// supplies an already-established connection.
	ConnectInsecure bool
	// MigrationsTable configures the migrations table name. WithInstance
	// replaces an empty value with DefaultMigrationsTable.
	MigrationsTable string
}
    44  
// Rqlite is the rqlite implementation of database.Driver; it is the value
// registered in init and returned by WithInstance and Open.
//
// Fields:
//   - db: the gorqlite connection every statement is sent through.
//   - isLocked: in-memory flag behind Lock/Unlock. It is not a lock stored
//     in the database, so it only serializes users of this driver value.
//   - config: migrations table name and connection options.
type Rqlite struct {
	db       *gorqlite.Connection
	isLocked atomic.Bool

	config *Config
}
    51  
    52  // WithInstance creates a rqlite database driver with an existing gorqlite database connection
    53  // and a Config struct
    54  func WithInstance(instance *gorqlite.Connection, config *Config) (database.Driver, error) {
    55  	if config == nil {
    56  		return nil, ErrNilConfig
    57  	}
    58  
    59  	// we use the consistency level check as a database ping
    60  	if _, err := instance.ConsistencyLevel(); err != nil {
    61  		return nil, err
    62  	}
    63  
    64  	if len(config.MigrationsTable) == 0 {
    65  		config.MigrationsTable = DefaultMigrationsTable
    66  	}
    67  
    68  	driver := &Rqlite{
    69  		db:     instance,
    70  		config: config,
    71  	}
    72  
    73  	if err := driver.ensureVersionTable(); err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	return driver, nil
    78  }
    79  
    80  // OpenURL creates a rqlite database driver from a connect URL
    81  func OpenURL(url string) (database.Driver, error) {
    82  	d := &Rqlite{}
    83  	return d.Open(url)
    84  }
    85  
    86  func (r *Rqlite) ensureVersionTable() (err error) {
    87  	if err = r.Lock(); err != nil {
    88  		return err
    89  	}
    90  
    91  	defer func() {
    92  		if e := r.Unlock(); e != nil {
    93  			if err == nil {
    94  				err = e
    95  			} else {
    96  				err = multierror.Append(err, e)
    97  			}
    98  		}
    99  	}()
   100  
   101  	stmts := []string{
   102  		fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool)`, r.config.MigrationsTable),
   103  		fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version)`, r.config.MigrationsTable),
   104  	}
   105  
   106  	if _, err := r.db.Write(stmts); err != nil {
   107  		return err
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  // Open returns a new driver instance configured with parameters
   114  // coming from the URL string. Migrate will call this function
   115  // only once per instance.
   116  func (r *Rqlite) Open(url string) (database.Driver, error) {
   117  	dburl, config, err := parseUrl(url)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	r.config = config
   122  
   123  	r.db, err = gorqlite.Open(dburl.String())
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	if err := r.ensureVersionTable(); err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	return r, nil
   133  }
   134  
   135  // Close closes the underlying database instance managed by the driver.
   136  // Migrate will call this function only once per instance.
   137  func (r *Rqlite) Close() error {
   138  	r.db.Close()
   139  	return nil
   140  }
   141  
   142  // Lock should acquire a database lock so that only one migration process
   143  // can run at a time. Migrate will call this function before Run is called.
   144  // If the implementation can't provide this functionality, return nil.
   145  // Return database.ErrLocked if database is already locked.
   146  func (r *Rqlite) Lock() error {
   147  	if !r.isLocked.CAS(false, true) {
   148  		return database.ErrLocked
   149  	}
   150  	return nil
   151  }
   152  
   153  // Unlock should release the lock. Migrate will call this function after
   154  // all migrations have been run.
   155  func (r *Rqlite) Unlock() error {
   156  	if !r.isLocked.CAS(true, false) {
   157  		return database.ErrNotLocked
   158  	}
   159  	return nil
   160  }
   161  
   162  // Run applies a migration to the database. migration is guaranteed to be not nil.
   163  func (r *Rqlite) Run(migration io.Reader) error {
   164  	migr, err := io.ReadAll(migration)
   165  	if err != nil {
   166  		return err
   167  	}
   168  
   169  	query := string(migr[:])
   170  	if _, err := r.db.WriteOne(query); err != nil {
   171  		return &database.Error{OrigErr: err, Query: []byte(query)}
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  // SetVersion saves version and dirty state.
   178  // Migrate will call this function before and after each call to Run.
   179  // version must be >= -1. -1 means NilVersion.
   180  func (r *Rqlite) SetVersion(version int, dirty bool) error {
   181  	deleteQuery := fmt.Sprintf(`DELETE FROM %s`, r.config.MigrationsTable)
   182  	statements := []gorqlite.ParameterizedStatement{
   183  		{
   184  			Query: deleteQuery,
   185  		},
   186  	}
   187  
   188  	// Also re-write the schema version for nil dirty versions to prevent
   189  	// empty schema version for failed down migration on the first migration
   190  	// See: https://github.com/golang-migrate/migrate/issues/330
   191  	insertQuery := fmt.Sprintf(`INSERT INTO %s (version, dirty) VALUES (?, ?)`, r.config.MigrationsTable)
   192  	if version >= 0 || (version == database.NilVersion && dirty) {
   193  		statements = append(statements, gorqlite.ParameterizedStatement{
   194  			Query: insertQuery,
   195  			Arguments: []interface{}{
   196  				version,
   197  				dirty,
   198  			},
   199  		})
   200  	}
   201  
   202  	wr, err := r.db.WriteParameterized(statements)
   203  	if err != nil {
   204  		for i, res := range wr {
   205  			if res.Err != nil {
   206  				return &database.Error{OrigErr: err, Query: []byte(statements[i].Query)}
   207  			}
   208  		}
   209  
   210  		// if somehow we're still here, return the original error with combined queries
   211  		return &database.Error{OrigErr: err, Query: []byte(deleteQuery + "\n" + insertQuery)}
   212  	}
   213  
   214  	return nil
   215  }
   216  
   217  // Version returns the currently active version and if the database is dirty.
   218  // When no migration has been applied, it must return version -1.
   219  // Dirty means, a previous migration failed and user interaction is required.
   220  func (r *Rqlite) Version() (version int, dirty bool, err error) {
   221  	query := "SELECT version, dirty FROM " + r.config.MigrationsTable + " LIMIT 1"
   222  
   223  	qr, err := r.db.QueryOne(query)
   224  	if err != nil {
   225  		return database.NilVersion, false, nil
   226  	}
   227  
   228  	if !qr.Next() {
   229  		return database.NilVersion, false, nil
   230  	}
   231  
   232  	if err := qr.Scan(&version, &dirty); err != nil {
   233  		return database.NilVersion, false, &database.Error{OrigErr: err, Query: []byte(query)}
   234  	}
   235  
   236  	return version, dirty, nil
   237  }
   238  
   239  // Drop deletes everything in the database.
   240  // Note that this is a breaking action, a new call to Open() is necessary to
   241  // ensure subsequent calls work as expected.
   242  func (r *Rqlite) Drop() error {
   243  	query := `SELECT name FROM sqlite_master WHERE type = 'table'`
   244  
   245  	tables, err := r.db.QueryOne(query)
   246  	if err != nil {
   247  		return &database.Error{OrigErr: err, Query: []byte(query)}
   248  	}
   249  
   250  	statements := make([]string, 0)
   251  	for tables.Next() {
   252  		var tableName string
   253  		if err := tables.Scan(&tableName); err != nil {
   254  			return err
   255  		}
   256  
   257  		if len(tableName) > 0 {
   258  			statement := fmt.Sprintf(`DROP TABLE %s`, tableName)
   259  			statements = append(statements, statement)
   260  		}
   261  	}
   262  
   263  	// return if nothing to do
   264  	if len(statements) <= 0 {
   265  		return nil
   266  	}
   267  
   268  	wr, err := r.db.Write(statements)
   269  	if err != nil {
   270  		for i, res := range wr {
   271  			if res.Err != nil {
   272  				return &database.Error{OrigErr: err, Query: []byte(statements[i])}
   273  			}
   274  		}
   275  
   276  		// if somehow we're still here, return the original error with combined queries
   277  		return &database.Error{OrigErr: err, Query: []byte(strings.Join(statements, "\n"))}
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  func parseUrl(url string) (*nurl.URL, *Config, error) {
   284  	parsedUrl, err := nurl.Parse(url)
   285  	if err != nil {
   286  		return nil, nil, err
   287  	}
   288  
   289  	config, err := parseConfigFromQuery(parsedUrl.Query())
   290  	if err != nil {
   291  		return nil, nil, err
   292  	}
   293  
   294  	if parsedUrl.Scheme != "rqlite" {
   295  		return nil, nil, errors.Wrap(ErrBadConfig, "bad scheme")
   296  	}
   297  
   298  	// adapt from rqlite to http/https schemes
   299  	if config.ConnectInsecure {
   300  		parsedUrl.Scheme = "http"
   301  	} else {
   302  		parsedUrl.Scheme = "https"
   303  	}
   304  
   305  	filteredUrl := migrate.FilterCustomQuery(parsedUrl)
   306  
   307  	return filteredUrl, config, nil
   308  }
   309  
   310  func parseConfigFromQuery(queryVals nurl.Values) (*Config, error) {
   311  	c := Config{
   312  		ConnectInsecure: DefaultConnectInsecure,
   313  		MigrationsTable: DefaultMigrationsTable,
   314  	}
   315  
   316  	migrationsTable := queryVals.Get("x-migrations-table")
   317  	if migrationsTable != "" {
   318  		if strings.HasPrefix(migrationsTable, "sqlite_") {
   319  			return nil, errors.Wrap(ErrBadConfig, "invalid value for x-migrations-table")
   320  		}
   321  		c.MigrationsTable = migrationsTable
   322  	}
   323  
   324  	connectInsecureStr := queryVals.Get("x-connect-insecure")
   325  	if connectInsecureStr != "" {
   326  		connectInsecure, err := strconv.ParseBool(connectInsecureStr)
   327  		if err != nil {
   328  			return nil, errors.Wrap(ErrBadConfig, "invalid value for x-connect-insecure")
   329  		}
   330  		c.ConnectInsecure = connectInsecure
   331  	}
   332  
   333  	return &c, nil
   334  }