github.com/nokia/migrate/v4@v4.16.0/database/cockroachdb/cockroachdb.go (about)

     1  package cockroachdb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	nurl "net/url"
    10  	"regexp"
    11  	"strconv"
    12  
    13  	"github.com/cockroachdb/cockroach-go/v2/crdb"
    14  	"github.com/hashicorp/go-multierror"
    15  	"github.com/lib/pq"
    16  	"github.com/nokia/migrate/v4"
    17  	"github.com/nokia/migrate/v4/database"
    18  	"github.com/nokia/migrate/v4/source"
    19  	"go.uber.org/atomic"
    20  )
    21  
    22  func init() {
    23  	db := CockroachDb{}
    24  	database.Register("cockroach", &db)
    25  	database.Register("cockroachdb", &db)
    26  	database.Register("crdb-postgres", &db)
    27  }
    28  
    29  var (
    30  	DefaultMigrationsTable = "schema_migrations"
    31  	DefaultLockTable       = "schema_lock"
    32  )
    33  
    34  var (
    35  	ErrNilConfig      = fmt.Errorf("no config")
    36  	ErrNoDatabaseName = fmt.Errorf("no database name")
    37  )
    38  
// Config configures a CockroachDb driver. WithInstance fills in any empty
// fields in place before the driver is used.
type Config struct {
	// MigrationsTable is the table holding the current (version, dirty) row.
	// Empty means DefaultMigrationsTable.
	MigrationsTable string
	// LockTable is the table used to emulate the migration lock.
	// Empty means DefaultLockTable.
	LockTable string
	// ForceLock makes Lock proceed even when a lock row for this database
	// already exists, instead of failing with database.ErrLocked.
	ForceLock bool
	// DatabaseName is the input from which the lock id is derived. When empty,
	// WithInstance sets it from SELECT current_database().
	DatabaseName string
}
    45  
// CockroachDb is the database.Driver implementation for CockroachDB. It talks to
// the cluster over the postgres wire protocol via lib/pq.
type CockroachDb struct {
	// db is the connection pool all statements are executed on.
	db *sql.DB
	// isLocked records whether this instance currently holds the migration
	// lock; Lock and Unlock flip it through database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    53  
    54  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    55  	if config == nil {
    56  		return nil, ErrNilConfig
    57  	}
    58  
    59  	if err := instance.Ping(); err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	if config.DatabaseName == "" {
    64  		query := `SELECT current_database()`
    65  		var databaseName string
    66  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    67  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    68  		}
    69  
    70  		if len(databaseName) == 0 {
    71  			return nil, ErrNoDatabaseName
    72  		}
    73  
    74  		config.DatabaseName = databaseName
    75  	}
    76  
    77  	if len(config.MigrationsTable) == 0 {
    78  		config.MigrationsTable = DefaultMigrationsTable
    79  	}
    80  
    81  	if len(config.LockTable) == 0 {
    82  		config.LockTable = DefaultLockTable
    83  	}
    84  
    85  	px := &CockroachDb{
    86  		db:     instance,
    87  		config: config,
    88  	}
    89  
    90  	// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
    91  	if err := px.ensureLockTable(); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	if err := px.ensureVersionTable(); err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	return px, nil
   100  }
   101  
   102  func (c *CockroachDb) Open(url string) (database.Driver, error) {
   103  	purl, err := nurl.Parse(url)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
   109  	// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
   110  	re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
   111  	connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
   112  
   113  	db, err := sql.Open("postgres", connectString)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	migrationsTable := purl.Query().Get("x-migrations-table")
   119  	if len(migrationsTable) == 0 {
   120  		migrationsTable = DefaultMigrationsTable
   121  	}
   122  
   123  	lockTable := purl.Query().Get("x-lock-table")
   124  	if len(lockTable) == 0 {
   125  		lockTable = DefaultLockTable
   126  	}
   127  
   128  	forceLockQuery := purl.Query().Get("x-force-lock")
   129  	forceLock, err := strconv.ParseBool(forceLockQuery)
   130  	if err != nil {
   131  		forceLock = false
   132  	}
   133  
   134  	px, err := WithInstance(db, &Config{
   135  		DatabaseName:    purl.Path,
   136  		MigrationsTable: migrationsTable,
   137  		LockTable:       lockTable,
   138  		ForceLock:       forceLock,
   139  	})
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	return px, nil
   145  }
   146  
   147  func (c *CockroachDb) Close() error {
   148  	return c.db.Close()
   149  }
   150  
   151  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   152  // See: https://github.com/cockroachdb/cockroach/issues/13546
   153  func (c *CockroachDb) Lock() error {
   154  	return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
   155  		return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
   156  			aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   157  			if err != nil {
   158  				return err
   159  			}
   160  
   161  			query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
   162  			rows, err := tx.Query(query, aid)
   163  			if err != nil {
   164  				return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   165  			}
   166  			defer func() {
   167  				if errClose := rows.Close(); errClose != nil {
   168  					err = multierror.Append(err, errClose)
   169  				}
   170  			}()
   171  
   172  			// If row exists at all, lock is present
   173  			locked := rows.Next()
   174  			if locked && !c.config.ForceLock {
   175  				return database.ErrLocked
   176  			}
   177  
   178  			query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
   179  			if _, err := tx.Exec(query, aid); err != nil {
   180  				return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   181  			}
   182  
   183  			return nil
   184  		})
   185  	})
   186  }
   187  
   188  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   189  // See: https://github.com/cockroachdb/cockroach/issues/13546
   190  func (c *CockroachDb) Unlock() error {
   191  	return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
   192  		aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   193  		if err != nil {
   194  			return err
   195  		}
   196  
   197  		// In the event of an implementation (non-migration) error, it is possible for the lock to not be released.  Until
   198  		// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
   199  		query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
   200  		if _, err := c.db.Exec(query, aid); err != nil {
   201  			if e, ok := err.(*pq.Error); ok {
   202  				// 42P01 is "UndefinedTableError" in CockroachDB
   203  				// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   204  				if e.Code == "42P01" {
   205  					// On drops, the lock table is fully removed;  This is fine, and is a valid "unlocked" state for the schema
   206  					return nil
   207  				}
   208  			}
   209  
   210  			return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   211  		}
   212  
   213  		return nil
   214  	})
   215  }
   216  
   217  func (c *CockroachDb) Run(migration io.Reader) error {
   218  	migr, err := ioutil.ReadAll(migration)
   219  	if err != nil {
   220  		return err
   221  	}
   222  
   223  	// run migration
   224  	query := string(migr[:])
   225  	if _, err := c.db.Exec(query); err != nil {
   226  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   227  	}
   228  
   229  	return nil
   230  }
   231  
   232  func (c *CockroachDb) RunFunctionMigration(fn source.MigrationFunc) error {
   233  	return database.ErrNotImpl
   234  }
   235  
   236  func (c *CockroachDb) SetVersion(version int, dirty bool) error {
   237  	return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
   238  		if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
   239  			return err
   240  		}
   241  
   242  		// Also re-write the schema version for nil dirty versions to prevent
   243  		// empty schema version for failed down migration on the first migration
   244  		// See: https://github.com/nokia/migrate/issues/330
   245  		if version >= 0 || (version == database.NilVersion && dirty) {
   246  			if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
   247  				return err
   248  			}
   249  		}
   250  
   251  		return nil
   252  	})
   253  }
   254  
   255  func (c *CockroachDb) Version() (version int, dirty bool, err error) {
   256  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   257  	err = c.db.QueryRow(query).Scan(&version, &dirty)
   258  
   259  	switch {
   260  	case err == sql.ErrNoRows:
   261  		return database.NilVersion, false, nil
   262  
   263  	case err != nil:
   264  		if e, ok := err.(*pq.Error); ok {
   265  			// 42P01 is "UndefinedTableError" in CockroachDB
   266  			// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   267  			if e.Code == "42P01" {
   268  				return database.NilVersion, false, nil
   269  			}
   270  		}
   271  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   272  
   273  	default:
   274  		return version, dirty, nil
   275  	}
   276  }
   277  
   278  func (c *CockroachDb) Drop() (err error) {
   279  	// select all tables in current schema
   280  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   281  	tables, err := c.db.Query(query)
   282  	if err != nil {
   283  		return &database.Error{OrigErr: err, Query: []byte(query)}
   284  	}
   285  	defer func() {
   286  		if errClose := tables.Close(); errClose != nil {
   287  			err = multierror.Append(err, errClose)
   288  		}
   289  	}()
   290  
   291  	// delete one table after another
   292  	tableNames := make([]string, 0)
   293  	for tables.Next() {
   294  		var tableName string
   295  		if err := tables.Scan(&tableName); err != nil {
   296  			return err
   297  		}
   298  		if len(tableName) > 0 {
   299  			tableNames = append(tableNames, tableName)
   300  		}
   301  	}
   302  	if err := tables.Err(); err != nil {
   303  		return &database.Error{OrigErr: err, Query: []byte(query)}
   304  	}
   305  
   306  	if len(tableNames) > 0 {
   307  		// delete one by one ...
   308  		for _, t := range tableNames {
   309  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   310  			if _, err := c.db.Exec(query); err != nil {
   311  				return &database.Error{OrigErr: err, Query: []byte(query)}
   312  			}
   313  		}
   314  	}
   315  
   316  	return nil
   317  }
   318  
   319  // ensureVersionTable checks if versions table exists and, if not, creates it.
   320  // Note that this function locks the database, which deviates from the usual
   321  // convention of "caller locks" in the CockroachDb type.
   322  func (c *CockroachDb) ensureVersionTable() (err error) {
   323  	if err = c.Lock(); err != nil {
   324  		return err
   325  	}
   326  
   327  	defer func() {
   328  		if e := c.Unlock(); e != nil {
   329  			if err == nil {
   330  				err = e
   331  			} else {
   332  				err = multierror.Append(err, e)
   333  			}
   334  		}
   335  	}()
   336  
   337  	// check if migration table exists
   338  	var count int
   339  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   340  	if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
   341  		return &database.Error{OrigErr: err, Query: []byte(query)}
   342  	}
   343  	if count == 1 {
   344  		return nil
   345  	}
   346  
   347  	// if not, create the empty migration table
   348  	query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
   349  	if _, err := c.db.Exec(query); err != nil {
   350  		return &database.Error{OrigErr: err, Query: []byte(query)}
   351  	}
   352  	return nil
   353  }
   354  
   355  func (c *CockroachDb) ensureLockTable() error {
   356  	// check if lock table exists
   357  	var count int
   358  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   359  	if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
   360  		return &database.Error{OrigErr: err, Query: []byte(query)}
   361  	}
   362  	if count == 1 {
   363  		return nil
   364  	}
   365  
   366  	// if not, create the empty lock table
   367  	query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
   368  	if _, err := c.db.Exec(query); err != nil {
   369  		return &database.Error{OrigErr: err, Query: []byte(query)}
   370  	}
   371  
   372  	return nil
   373  }