github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/cockroachdb/cockroachdb.go (about)

     1  package cockroachdb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	nurl "net/url"
     9  	"regexp"
    10  	"strconv"
    11  
    12  	"github.com/cockroachdb/cockroach-go/v2/crdb"
    13  	"github.com/golang-migrate/migrate/v4"
    14  	"github.com/golang-migrate/migrate/v4/database"
    15  	"github.com/hashicorp/go-multierror"
    16  	"github.com/lib/pq"
    17  	"go.uber.org/atomic"
    18  )
    19  
    20  func init() {
    21  	db := CockroachDb{}
    22  	database.Register("cockroach", &db)
    23  	database.Register("cockroachdb", &db)
    24  	database.Register("crdb-postgres", &db)
    25  }
    26  
    27  var DefaultMigrationsTable = "schema_migrations"
    28  var DefaultLockTable = "schema_lock"
    29  
    30  var (
    31  	ErrNilConfig      = fmt.Errorf("no config")
    32  	ErrNoDatabaseName = fmt.Errorf("no database name")
    33  )
    34  
    35  type Config struct {
    36  	MigrationsTable string
    37  	LockTable       string
    38  	ForceLock       bool
    39  	DatabaseName    string
    40  }
    41  
    42  type CockroachDb struct {
    43  	db       *sql.DB
    44  	isLocked atomic.Bool
    45  
    46  	// Open and WithInstance need to guarantee that config is never nil
    47  	config *Config
    48  }
    49  
    50  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    51  	if config == nil {
    52  		return nil, ErrNilConfig
    53  	}
    54  
    55  	if err := instance.Ping(); err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	if config.DatabaseName == "" {
    60  		query := `SELECT current_database()`
    61  		var databaseName string
    62  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    63  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    64  		}
    65  
    66  		if len(databaseName) == 0 {
    67  			return nil, ErrNoDatabaseName
    68  		}
    69  
    70  		config.DatabaseName = databaseName
    71  	}
    72  
    73  	if len(config.MigrationsTable) == 0 {
    74  		config.MigrationsTable = DefaultMigrationsTable
    75  	}
    76  
    77  	if len(config.LockTable) == 0 {
    78  		config.LockTable = DefaultLockTable
    79  	}
    80  
    81  	px := &CockroachDb{
    82  		db:     instance,
    83  		config: config,
    84  	}
    85  
    86  	// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
    87  	if err := px.ensureLockTable(); err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	if err := px.ensureVersionTable(); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return px, nil
    96  }
    97  
    98  func (c *CockroachDb) Open(url string) (database.Driver, error) {
    99  	purl, err := nurl.Parse(url)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
   105  	// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
   106  	re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
   107  	connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
   108  
   109  	db, err := sql.Open("postgres", connectString)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	migrationsTable := purl.Query().Get("x-migrations-table")
   115  	if len(migrationsTable) == 0 {
   116  		migrationsTable = DefaultMigrationsTable
   117  	}
   118  
   119  	lockTable := purl.Query().Get("x-lock-table")
   120  	if len(lockTable) == 0 {
   121  		lockTable = DefaultLockTable
   122  	}
   123  
   124  	forceLockQuery := purl.Query().Get("x-force-lock")
   125  	forceLock, err := strconv.ParseBool(forceLockQuery)
   126  	if err != nil {
   127  		forceLock = false
   128  	}
   129  
   130  	px, err := WithInstance(db, &Config{
   131  		DatabaseName:    purl.Path,
   132  		MigrationsTable: migrationsTable,
   133  		LockTable:       lockTable,
   134  		ForceLock:       forceLock,
   135  	})
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	return px, nil
   141  }
   142  
   143  func (c *CockroachDb) Close() error {
   144  	return c.db.Close()
   145  }
   146  
   147  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   148  // See: https://github.com/cockroachdb/cockroach/issues/13546
   149  func (c *CockroachDb) Lock() error {
   150  	return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
   151  		return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
   152  			aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   153  			if err != nil {
   154  				return err
   155  			}
   156  
   157  			query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
   158  			rows, err := tx.Query(query, aid)
   159  			if err != nil {
   160  				return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   161  			}
   162  			defer func() {
   163  				if errClose := rows.Close(); errClose != nil {
   164  					err = multierror.Append(err, errClose)
   165  				}
   166  			}()
   167  
   168  			// If row exists at all, lock is present
   169  			locked := rows.Next()
   170  			if locked && !c.config.ForceLock {
   171  				return database.ErrLocked
   172  			}
   173  
   174  			query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
   175  			if _, err := tx.Exec(query, aid); err != nil {
   176  				return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   177  			}
   178  
   179  			return nil
   180  		})
   181  	})
   182  }
   183  
   184  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   185  // See: https://github.com/cockroachdb/cockroach/issues/13546
   186  func (c *CockroachDb) Unlock() error {
   187  	return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
   188  		aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   189  		if err != nil {
   190  			return err
   191  		}
   192  
   193  		// In the event of an implementation (non-migration) error, it is possible for the lock to not be released.  Until
   194  		// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
   195  		query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
   196  		if _, err := c.db.Exec(query, aid); err != nil {
   197  			if e, ok := err.(*pq.Error); ok {
   198  				// 42P01 is "UndefinedTableError" in CockroachDB
   199  				// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   200  				if e.Code == "42P01" {
   201  					// On drops, the lock table is fully removed;  This is fine, and is a valid "unlocked" state for the schema
   202  					return nil
   203  				}
   204  			}
   205  
   206  			return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   207  		}
   208  
   209  		return nil
   210  	})
   211  }
   212  
   213  func (c *CockroachDb) Run(migration io.Reader) error {
   214  	migr, err := io.ReadAll(migration)
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	// run migration
   220  	query := string(migr[:])
   221  	if _, err := c.db.Exec(query); err != nil {
   222  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   223  	}
   224  
   225  	return nil
   226  }
   227  
   228  func (c *CockroachDb) SetVersion(version int, dirty bool) error {
   229  	return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
   230  		if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
   231  			return err
   232  		}
   233  
   234  		// Also re-write the schema version for nil dirty versions to prevent
   235  		// empty schema version for failed down migration on the first migration
   236  		// See: https://github.com/golang-migrate/migrate/issues/330
   237  		if version >= 0 || (version == database.NilVersion && dirty) {
   238  			if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
   239  				return err
   240  			}
   241  		}
   242  
   243  		return nil
   244  	})
   245  }
   246  
   247  func (c *CockroachDb) Version() (version int, dirty bool, err error) {
   248  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   249  	err = c.db.QueryRow(query).Scan(&version, &dirty)
   250  
   251  	switch {
   252  	case err == sql.ErrNoRows:
   253  		return database.NilVersion, false, nil
   254  
   255  	case err != nil:
   256  		if e, ok := err.(*pq.Error); ok {
   257  			// 42P01 is "UndefinedTableError" in CockroachDB
   258  			// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   259  			if e.Code == "42P01" {
   260  				return database.NilVersion, false, nil
   261  			}
   262  		}
   263  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   264  
   265  	default:
   266  		return version, dirty, nil
   267  	}
   268  }
   269  
   270  func (c *CockroachDb) Drop() (err error) {
   271  	// select all tables in current schema
   272  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   273  	tables, err := c.db.Query(query)
   274  	if err != nil {
   275  		return &database.Error{OrigErr: err, Query: []byte(query)}
   276  	}
   277  	defer func() {
   278  		if errClose := tables.Close(); errClose != nil {
   279  			err = multierror.Append(err, errClose)
   280  		}
   281  	}()
   282  
   283  	// delete one table after another
   284  	tableNames := make([]string, 0)
   285  	for tables.Next() {
   286  		var tableName string
   287  		if err := tables.Scan(&tableName); err != nil {
   288  			return err
   289  		}
   290  		if len(tableName) > 0 {
   291  			tableNames = append(tableNames, tableName)
   292  		}
   293  	}
   294  	if err := tables.Err(); err != nil {
   295  		return &database.Error{OrigErr: err, Query: []byte(query)}
   296  	}
   297  
   298  	if len(tableNames) > 0 {
   299  		// delete one by one ...
   300  		for _, t := range tableNames {
   301  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   302  			if _, err := c.db.Exec(query); err != nil {
   303  				return &database.Error{OrigErr: err, Query: []byte(query)}
   304  			}
   305  		}
   306  	}
   307  
   308  	return nil
   309  }
   310  
   311  // ensureVersionTable checks if versions table exists and, if not, creates it.
   312  // Note that this function locks the database, which deviates from the usual
   313  // convention of "caller locks" in the CockroachDb type.
   314  func (c *CockroachDb) ensureVersionTable() (err error) {
   315  	if err = c.Lock(); err != nil {
   316  		return err
   317  	}
   318  
   319  	defer func() {
   320  		if e := c.Unlock(); e != nil {
   321  			if err == nil {
   322  				err = e
   323  			} else {
   324  				err = multierror.Append(err, e)
   325  			}
   326  		}
   327  	}()
   328  
   329  	// check if migration table exists
   330  	var count int
   331  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   332  	if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
   333  		return &database.Error{OrigErr: err, Query: []byte(query)}
   334  	}
   335  	if count == 1 {
   336  		return nil
   337  	}
   338  
   339  	// if not, create the empty migration table
   340  	query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
   341  	if _, err := c.db.Exec(query); err != nil {
   342  		return &database.Error{OrigErr: err, Query: []byte(query)}
   343  	}
   344  	return nil
   345  }
   346  
   347  func (c *CockroachDb) ensureLockTable() error {
   348  	// check if lock table exists
   349  	var count int
   350  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   351  	if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
   352  		return &database.Error{OrigErr: err, Query: []byte(query)}
   353  	}
   354  	if count == 1 {
   355  		return nil
   356  	}
   357  
   358  	// if not, create the empty lock table
   359  	query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
   360  	if _, err := c.db.Exec(query); err != nil {
   361  		return &database.Error{OrigErr: err, Query: []byte(query)}
   362  	}
   363  
   364  	return nil
   365  }