github.com/fr-nvriep/migrate/v4@v4.3.2/database/cockroachdb/cockroachdb.go (about)

     1  package cockroachdb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	nurl "net/url"
    10  	"regexp"
    11  	"strconv"
    12  )
    13  
    14  import (
    15  	"github.com/cockroachdb/cockroach-go/crdb"
    16  	"github.com/hashicorp/go-multierror"
    17  	"github.com/lib/pq"
    18  )
    19  
    20  import (
    21  	"github.com/fr-nvriep/migrate/v4"
    22  	"github.com/fr-nvriep/migrate/v4/database"
    23  )
    24  
    25  func init() {
    26  	db := CockroachDb{}
    27  	database.Register("cockroach", &db)
    28  	database.Register("cockroachdb", &db)
    29  	database.Register("crdb-postgres", &db)
    30  }
    31  
    32  var DefaultMigrationsTable = "schema_migrations"
    33  var DefaultLockTable = "schema_lock"
    34  
    35  var (
    36  	ErrNilConfig      = fmt.Errorf("no config")
    37  	ErrNoDatabaseName = fmt.Errorf("no database name")
    38  )
    39  
    40  type Config struct {
    41  	MigrationsTable string
    42  	LockTable       string
    43  	ForceLock       bool
    44  	DatabaseName    string
    45  }
    46  
    47  type CockroachDb struct {
    48  	db       *sql.DB
    49  	isLocked bool
    50  
    51  	// Open and WithInstance need to guarantee that config is never nil
    52  	config *Config
    53  }
    54  
    55  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    56  	if config == nil {
    57  		return nil, ErrNilConfig
    58  	}
    59  
    60  	if err := instance.Ping(); err != nil {
    61  		return nil, err
    62  	}
    63  
    64  	query := `SELECT current_database()`
    65  	var databaseName string
    66  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    67  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    68  	}
    69  
    70  	if len(databaseName) == 0 {
    71  		return nil, ErrNoDatabaseName
    72  	}
    73  
    74  	config.DatabaseName = databaseName
    75  
    76  	if len(config.MigrationsTable) == 0 {
    77  		config.MigrationsTable = DefaultMigrationsTable
    78  	}
    79  
    80  	if len(config.LockTable) == 0 {
    81  		config.LockTable = DefaultLockTable
    82  	}
    83  
    84  	px := &CockroachDb{
    85  		db:     instance,
    86  		config: config,
    87  	}
    88  
    89  	// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
    90  	if err := px.ensureLockTable(); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	if err := px.ensureVersionTable(); err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	return px, nil
    99  }
   100  
   101  func (c *CockroachDb) Open(url string) (database.Driver, error) {
   102  	purl, err := nurl.Parse(url)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
   108  	// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
   109  	re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
   110  	connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
   111  
   112  	db, err := sql.Open("postgres", connectString)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	migrationsTable := purl.Query().Get("x-migrations-table")
   118  	if len(migrationsTable) == 0 {
   119  		migrationsTable = DefaultMigrationsTable
   120  	}
   121  
   122  	lockTable := purl.Query().Get("x-lock-table")
   123  	if len(lockTable) == 0 {
   124  		lockTable = DefaultLockTable
   125  	}
   126  
   127  	forceLockQuery := purl.Query().Get("x-force-lock")
   128  	forceLock, err := strconv.ParseBool(forceLockQuery)
   129  	if err != nil {
   130  		forceLock = false
   131  	}
   132  
   133  	px, err := WithInstance(db, &Config{
   134  		DatabaseName:    purl.Path,
   135  		MigrationsTable: migrationsTable,
   136  		LockTable:       lockTable,
   137  		ForceLock:       forceLock,
   138  	})
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return px, nil
   144  }
   145  
   146  func (c *CockroachDb) Close() error {
   147  	return c.db.Close()
   148  }
   149  
   150  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   151  // See: https://github.com/cockroachdb/cockroach/issues/13546
   152  func (c *CockroachDb) Lock() error {
   153  	err := crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
   154  		aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   155  		if err != nil {
   156  			return err
   157  		}
   158  
   159  		query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
   160  		rows, err := tx.Query(query, aid)
   161  		if err != nil {
   162  			return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   163  		}
   164  		defer func() {
   165  			if errClose := rows.Close(); errClose != nil {
   166  				err = multierror.Append(err, errClose)
   167  			}
   168  		}()
   169  
   170  		// If row exists at all, lock is present
   171  		locked := rows.Next()
   172  		if locked && !c.config.ForceLock {
   173  			return database.ErrLocked
   174  		}
   175  
   176  		query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
   177  		if _, err := tx.Exec(query, aid); err != nil {
   178  			return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   179  		}
   180  
   181  		return nil
   182  	})
   183  
   184  	if err != nil {
   185  		return err
   186  	} else {
   187  		c.isLocked = true
   188  		return nil
   189  	}
   190  }
   191  
   192  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   193  // See: https://github.com/cockroachdb/cockroach/issues/13546
   194  func (c *CockroachDb) Unlock() error {
   195  	aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   196  	if err != nil {
   197  		return err
   198  	}
   199  
   200  	// In the event of an implementation (non-migration) error, it is possible for the lock to not be released.  Until
   201  	// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
   202  	query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
   203  	if _, err := c.db.Exec(query, aid); err != nil {
   204  		if e, ok := err.(*pq.Error); ok {
   205  			// 42P01 is "UndefinedTableError" in CockroachDB
   206  			// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   207  			if e.Code == "42P01" {
   208  				// On drops, the lock table is fully removed;  This is fine, and is a valid "unlocked" state for the schema
   209  				c.isLocked = false
   210  				return nil
   211  			}
   212  		}
   213  		return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   214  	}
   215  
   216  	c.isLocked = false
   217  	return nil
   218  }
   219  
   220  func (c *CockroachDb) Run(migration io.Reader) error {
   221  	migr, err := ioutil.ReadAll(migration)
   222  	if err != nil {
   223  		return err
   224  	}
   225  
   226  	// run migration
   227  	query := string(migr[:])
   228  	if _, err := c.db.Exec(query); err != nil {
   229  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   230  	}
   231  
   232  	return nil
   233  }
   234  
   235  func (c *CockroachDb) SetVersion(version int, dirty bool) error {
   236  	return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
   237  		if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
   238  			return err
   239  		}
   240  
   241  		if version >= 0 {
   242  			if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
   243  				return err
   244  			}
   245  		}
   246  
   247  		return nil
   248  	})
   249  }
   250  
   251  func (c *CockroachDb) Version() (version int, dirty bool, err error) {
   252  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   253  	err = c.db.QueryRow(query).Scan(&version, &dirty)
   254  
   255  	switch {
   256  	case err == sql.ErrNoRows:
   257  		return database.NilVersion, false, nil
   258  
   259  	case err != nil:
   260  		if e, ok := err.(*pq.Error); ok {
   261  			// 42P01 is "UndefinedTableError" in CockroachDB
   262  			// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   263  			if e.Code == "42P01" {
   264  				return database.NilVersion, false, nil
   265  			}
   266  		}
   267  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   268  
   269  	default:
   270  		return version, dirty, nil
   271  	}
   272  }
   273  
   274  func (c *CockroachDb) Drop() (err error) {
   275  	// select all tables in current schema
   276  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   277  	tables, err := c.db.Query(query)
   278  	if err != nil {
   279  		return &database.Error{OrigErr: err, Query: []byte(query)}
   280  	}
   281  	defer func() {
   282  		if errClose := tables.Close(); errClose != nil {
   283  			err = multierror.Append(err, errClose)
   284  		}
   285  	}()
   286  
   287  	// delete one table after another
   288  	tableNames := make([]string, 0)
   289  	for tables.Next() {
   290  		var tableName string
   291  		if err := tables.Scan(&tableName); err != nil {
   292  			return err
   293  		}
   294  		if len(tableName) > 0 {
   295  			tableNames = append(tableNames, tableName)
   296  		}
   297  	}
   298  
   299  	if len(tableNames) > 0 {
   300  		// delete one by one ...
   301  		for _, t := range tableNames {
   302  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   303  			if _, err := c.db.Exec(query); err != nil {
   304  				return &database.Error{OrigErr: err, Query: []byte(query)}
   305  			}
   306  		}
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  // ensureVersionTable checks if versions table exists and, if not, creates it.
   313  // Note that this function locks the database, which deviates from the usual
   314  // convention of "caller locks" in the CockroachDb type.
   315  func (c *CockroachDb) ensureVersionTable() (err error) {
   316  	if err = c.Lock(); err != nil {
   317  		return err
   318  	}
   319  
   320  	defer func() {
   321  		if e := c.Unlock(); e != nil {
   322  			if err == nil {
   323  				err = e
   324  			} else {
   325  				err = multierror.Append(err, e)
   326  			}
   327  		}
   328  	}()
   329  
   330  	// check if migration table exists
   331  	var count int
   332  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   333  	if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
   334  		return &database.Error{OrigErr: err, Query: []byte(query)}
   335  	}
   336  	if count == 1 {
   337  		return nil
   338  	}
   339  
   340  	// if not, create the empty migration table
   341  	query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
   342  	if _, err := c.db.Exec(query); err != nil {
   343  		return &database.Error{OrigErr: err, Query: []byte(query)}
   344  	}
   345  	return nil
   346  }
   347  
   348  func (c *CockroachDb) ensureLockTable() error {
   349  	// check if lock table exists
   350  	var count int
   351  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   352  	if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
   353  		return &database.Error{OrigErr: err, Query: []byte(query)}
   354  	}
   355  	if count == 1 {
   356  		return nil
   357  	}
   358  
   359  	// if not, create the empty lock table
   360  	query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
   361  	if _, err := c.db.Exec(query); err != nil {
   362  		return &database.Error{OrigErr: err, Query: []byte(query)}
   363  	}
   364  
   365  	return nil
   366  }