github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/cockroachdb/cockroachdb.go (about)

     1  package cockroachdb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	nurl "net/url"
    10  	"regexp"
    11  	"strconv"
    12  
    13  	"github.com/cockroachdb/cockroach-go/v2/crdb"
    14  	"github.com/getsynq/migrate/v4"
    15  	"github.com/getsynq/migrate/v4/database"
    16  	"github.com/hashicorp/go-multierror"
    17  	"github.com/lib/pq"
    18  	"go.uber.org/atomic"
    19  )
    20  
    21  func init() {
    22  	db := CockroachDb{}
    23  	database.Register("cockroach", &db)
    24  	database.Register("cockroachdb", &db)
    25  	database.Register("crdb-postgres", &db)
    26  }
    27  
    28  var DefaultMigrationsTable = "schema_migrations"
    29  var DefaultLockTable = "schema_lock"
    30  
    31  var (
    32  	ErrNilConfig      = fmt.Errorf("no config")
    33  	ErrNoDatabaseName = fmt.Errorf("no database name")
    34  )
    35  
// Config holds the settings for a CockroachDb driver instance. WithInstance
// fills in empty table names and an empty database name before use.
type Config struct {
	// MigrationsTable is the table holding the (version, dirty) row.
	// WithInstance sets it to DefaultMigrationsTable when empty.
	MigrationsTable string
	// LockTable is the table whose lock_id rows mark the schema as locked.
	// WithInstance sets it to DefaultLockTable when empty.
	LockTable string
	// ForceLock makes Lock skip its ErrLocked check when a lock row for this
	// database already exists.
	ForceLock bool
	// DatabaseName is the input to database.GenerateAdvisoryLockId, which
	// yields the lock_id stored in LockTable. Open passes the URL path here;
	// when empty, WithInstance uses the result of SELECT current_database().
	DatabaseName string
}
    42  
// CockroachDb is the migrate database.Driver for CockroachDB. Locking is
// emulated with rows in Config.LockTable instead of advisory locks.
type CockroachDb struct {
	// db is the *sql.DB every statement is issued through.
	db *sql.DB
	// isLocked records whether this instance currently holds the lock; Lock
	// and Unlock flip it via database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    50  
    51  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    52  	if config == nil {
    53  		return nil, ErrNilConfig
    54  	}
    55  
    56  	if err := instance.Ping(); err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	if config.DatabaseName == "" {
    61  		query := `SELECT current_database()`
    62  		var databaseName string
    63  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    64  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    65  		}
    66  
    67  		if len(databaseName) == 0 {
    68  			return nil, ErrNoDatabaseName
    69  		}
    70  
    71  		config.DatabaseName = databaseName
    72  	}
    73  
    74  	if len(config.MigrationsTable) == 0 {
    75  		config.MigrationsTable = DefaultMigrationsTable
    76  	}
    77  
    78  	if len(config.LockTable) == 0 {
    79  		config.LockTable = DefaultLockTable
    80  	}
    81  
    82  	px := &CockroachDb{
    83  		db:     instance,
    84  		config: config,
    85  	}
    86  
    87  	// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
    88  	if err := px.ensureLockTable(); err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	if err := px.ensureVersionTable(); err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	return px, nil
    97  }
    98  
    99  func (c *CockroachDb) Open(url string) (database.Driver, error) {
   100  	purl, err := nurl.Parse(url)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
   106  	// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
   107  	re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
   108  	connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
   109  
   110  	db, err := sql.Open("postgres", connectString)
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	migrationsTable := purl.Query().Get("x-migrations-table")
   116  	if len(migrationsTable) == 0 {
   117  		migrationsTable = DefaultMigrationsTable
   118  	}
   119  
   120  	lockTable := purl.Query().Get("x-lock-table")
   121  	if len(lockTable) == 0 {
   122  		lockTable = DefaultLockTable
   123  	}
   124  
   125  	forceLockQuery := purl.Query().Get("x-force-lock")
   126  	forceLock, err := strconv.ParseBool(forceLockQuery)
   127  	if err != nil {
   128  		forceLock = false
   129  	}
   130  
   131  	px, err := WithInstance(db, &Config{
   132  		DatabaseName:    purl.Path,
   133  		MigrationsTable: migrationsTable,
   134  		LockTable:       lockTable,
   135  		ForceLock:       forceLock,
   136  	})
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	return px, nil
   142  }
   143  
   144  func (c *CockroachDb) Close() error {
   145  	return c.db.Close()
   146  }
   147  
   148  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   149  // See: https://github.com/cockroachdb/cockroach/issues/13546
   150  func (c *CockroachDb) Lock() error {
   151  	return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
   152  		return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
   153  			aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   154  			if err != nil {
   155  				return err
   156  			}
   157  
   158  			query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
   159  			rows, err := tx.Query(query, aid)
   160  			if err != nil {
   161  				return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   162  			}
   163  			defer func() {
   164  				if errClose := rows.Close(); errClose != nil {
   165  					err = multierror.Append(err, errClose)
   166  				}
   167  			}()
   168  
   169  			// If row exists at all, lock is present
   170  			locked := rows.Next()
   171  			if locked && !c.config.ForceLock {
   172  				return database.ErrLocked
   173  			}
   174  
   175  			query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
   176  			if _, err := tx.Exec(query, aid); err != nil {
   177  				return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   178  			}
   179  
   180  			return nil
   181  		})
   182  	})
   183  }
   184  
   185  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   186  // See: https://github.com/cockroachdb/cockroach/issues/13546
   187  func (c *CockroachDb) Unlock() error {
   188  	return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
   189  		aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   190  		if err != nil {
   191  			return err
   192  		}
   193  
   194  		// In the event of an implementation (non-migration) error, it is possible for the lock to not be released.  Until
   195  		// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
   196  		query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
   197  		if _, err := c.db.Exec(query, aid); err != nil {
   198  			if e, ok := err.(*pq.Error); ok {
   199  				// 42P01 is "UndefinedTableError" in CockroachDB
   200  				// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   201  				if e.Code == "42P01" {
   202  					// On drops, the lock table is fully removed;  This is fine, and is a valid "unlocked" state for the schema
   203  					return nil
   204  				}
   205  			}
   206  
   207  			return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   208  		}
   209  
   210  		return nil
   211  	})
   212  }
   213  
   214  func (c *CockroachDb) Run(migration io.Reader) error {
   215  	migr, err := ioutil.ReadAll(migration)
   216  	if err != nil {
   217  		return err
   218  	}
   219  
   220  	// run migration
   221  	query := string(migr[:])
   222  	if _, err := c.db.Exec(query); err != nil {
   223  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   224  	}
   225  
   226  	return nil
   227  }
   228  
   229  func (c *CockroachDb) SetVersion(version int, dirty bool) error {
   230  	return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
   231  		if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
   232  			return err
   233  		}
   234  
   235  		// Also re-write the schema version for nil dirty versions to prevent
   236  		// empty schema version for failed down migration on the first migration
   237  		// See: https://github.com/getsynq/migrate/issues/330
   238  		if version >= 0 || (version == database.NilVersion && dirty) {
   239  			if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
   240  				return err
   241  			}
   242  		}
   243  
   244  		return nil
   245  	})
   246  }
   247  
   248  func (c *CockroachDb) Version() (version int, dirty bool, err error) {
   249  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   250  	err = c.db.QueryRow(query).Scan(&version, &dirty)
   251  
   252  	switch {
   253  	case err == sql.ErrNoRows:
   254  		return database.NilVersion, false, nil
   255  
   256  	case err != nil:
   257  		if e, ok := err.(*pq.Error); ok {
   258  			// 42P01 is "UndefinedTableError" in CockroachDB
   259  			// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   260  			if e.Code == "42P01" {
   261  				return database.NilVersion, false, nil
   262  			}
   263  		}
   264  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   265  
   266  	default:
   267  		return version, dirty, nil
   268  	}
   269  }
   270  
   271  func (c *CockroachDb) Drop() (err error) {
   272  	// select all tables in current schema
   273  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   274  	tables, err := c.db.Query(query)
   275  	if err != nil {
   276  		return &database.Error{OrigErr: err, Query: []byte(query)}
   277  	}
   278  	defer func() {
   279  		if errClose := tables.Close(); errClose != nil {
   280  			err = multierror.Append(err, errClose)
   281  		}
   282  	}()
   283  
   284  	// delete one table after another
   285  	tableNames := make([]string, 0)
   286  	for tables.Next() {
   287  		var tableName string
   288  		if err := tables.Scan(&tableName); err != nil {
   289  			return err
   290  		}
   291  		if len(tableName) > 0 {
   292  			tableNames = append(tableNames, tableName)
   293  		}
   294  	}
   295  	if err := tables.Err(); err != nil {
   296  		return &database.Error{OrigErr: err, Query: []byte(query)}
   297  	}
   298  
   299  	if len(tableNames) > 0 {
   300  		// delete one by one ...
   301  		for _, t := range tableNames {
   302  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   303  			if _, err := c.db.Exec(query); err != nil {
   304  				return &database.Error{OrigErr: err, Query: []byte(query)}
   305  			}
   306  		}
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  // ensureVersionTable checks if versions table exists and, if not, creates it.
   313  // Note that this function locks the database, which deviates from the usual
   314  // convention of "caller locks" in the CockroachDb type.
   315  func (c *CockroachDb) ensureVersionTable() (err error) {
   316  	if err = c.Lock(); err != nil {
   317  		return err
   318  	}
   319  
   320  	defer func() {
   321  		if e := c.Unlock(); e != nil {
   322  			if err == nil {
   323  				err = e
   324  			} else {
   325  				err = multierror.Append(err, e)
   326  			}
   327  		}
   328  	}()
   329  
   330  	// check if migration table exists
   331  	var count int
   332  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   333  	if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
   334  		return &database.Error{OrigErr: err, Query: []byte(query)}
   335  	}
   336  	if count == 1 {
   337  		return nil
   338  	}
   339  
   340  	// if not, create the empty migration table
   341  	query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
   342  	if _, err := c.db.Exec(query); err != nil {
   343  		return &database.Error{OrigErr: err, Query: []byte(query)}
   344  	}
   345  	return nil
   346  }
   347  
   348  func (c *CockroachDb) ensureLockTable() error {
   349  	// check if lock table exists
   350  	var count int
   351  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   352  	if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
   353  		return &database.Error{OrigErr: err, Query: []byte(query)}
   354  	}
   355  	if count == 1 {
   356  		return nil
   357  	}
   358  
   359  	// if not, create the empty lock table
   360  	query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
   361  	if _, err := c.db.Exec(query); err != nil {
   362  		return &database.Error{OrigErr: err, Query: []byte(query)}
   363  	}
   364  
   365  	return nil
   366  }