github.com/Elate-DevOps/migrate/v4@v4.0.12/database/cockroachdb/cockroachdb.go (about)

     1  package cockroachdb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	nurl "net/url"
     9  	"regexp"
    10  	"strconv"
    11  
    12  	"github.com/Elate-DevOps/migrate/v4"
    13  	"github.com/Elate-DevOps/migrate/v4/database"
    14  	"github.com/cockroachdb/cockroach-go/v2/crdb"
    15  	"github.com/hashicorp/go-multierror"
    16  	"github.com/lib/pq"
    17  	"go.uber.org/atomic"
    18  )
    19  
    20  func init() {
    21  	db := CockroachDb{}
    22  	database.Register("cockroach", &db)
    23  	database.Register("cockroachdb", &db)
    24  	database.Register("crdb-postgres", &db)
    25  }
    26  
    27  var (
    28  	DefaultMigrationsTable = "schema_migrations"
    29  	DefaultLockTable       = "schema_lock"
    30  )
    31  
    32  var (
    33  	ErrNilConfig      = fmt.Errorf("no config")
    34  	ErrNoDatabaseName = fmt.Errorf("no database name")
    35  )
    36  
    37  type Config struct {
    38  	MigrationsTable string
    39  	LockTable       string
    40  	ForceLock       bool
    41  	DatabaseName    string
    42  }
    43  
    44  type CockroachDb struct {
    45  	db       *sql.DB
    46  	isLocked atomic.Bool
    47  
    48  	// Open and WithInstance need to guarantee that config is never nil
    49  	config *Config
    50  }
    51  
    52  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    53  	if config == nil {
    54  		return nil, ErrNilConfig
    55  	}
    56  
    57  	if err := instance.Ping(); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	if config.DatabaseName == "" {
    62  		query := `SELECT current_database()`
    63  		var databaseName string
    64  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    65  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    66  		}
    67  
    68  		if len(databaseName) == 0 {
    69  			return nil, ErrNoDatabaseName
    70  		}
    71  
    72  		config.DatabaseName = databaseName
    73  	}
    74  
    75  	if len(config.MigrationsTable) == 0 {
    76  		config.MigrationsTable = DefaultMigrationsTable
    77  	}
    78  
    79  	if len(config.LockTable) == 0 {
    80  		config.LockTable = DefaultLockTable
    81  	}
    82  
    83  	px := &CockroachDb{
    84  		db:     instance,
    85  		config: config,
    86  	}
    87  
    88  	// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
    89  	if err := px.ensureLockTable(); err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	if err := px.ensureVersionTable(); err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return px, nil
    98  }
    99  
   100  func (c *CockroachDb) Open(url string) (database.Driver, error) {
   101  	purl, err := nurl.Parse(url)
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
   107  	// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
   108  	re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
   109  	connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
   110  
   111  	db, err := sql.Open("postgres", connectString)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	migrationsTable := purl.Query().Get("x-migrations-table")
   117  	if len(migrationsTable) == 0 {
   118  		migrationsTable = DefaultMigrationsTable
   119  	}
   120  
   121  	lockTable := purl.Query().Get("x-lock-table")
   122  	if len(lockTable) == 0 {
   123  		lockTable = DefaultLockTable
   124  	}
   125  
   126  	forceLockQuery := purl.Query().Get("x-force-lock")
   127  	forceLock, err := strconv.ParseBool(forceLockQuery)
   128  	if err != nil {
   129  		forceLock = false
   130  	}
   131  
   132  	px, err := WithInstance(db, &Config{
   133  		DatabaseName:    purl.Path,
   134  		MigrationsTable: migrationsTable,
   135  		LockTable:       lockTable,
   136  		ForceLock:       forceLock,
   137  	})
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	return px, nil
   143  }
   144  
   145  func (c *CockroachDb) Close() error {
   146  	return c.db.Close()
   147  }
   148  
   149  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   150  // See: https://github.com/cockroachdb/cockroach/issues/13546
   151  func (c *CockroachDb) Lock() error {
   152  	return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
   153  		return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
   154  			aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   155  			if err != nil {
   156  				return err
   157  			}
   158  
   159  			query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
   160  			rows, err := tx.Query(query, aid)
   161  			if err != nil {
   162  				return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   163  			}
   164  			defer func() {
   165  				if errClose := rows.Close(); errClose != nil {
   166  					err = multierror.Append(err, errClose)
   167  				}
   168  			}()
   169  
   170  			// If row exists at all, lock is present
   171  			locked := rows.Next()
   172  			if locked && !c.config.ForceLock {
   173  				return database.ErrLocked
   174  			}
   175  
   176  			query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
   177  			if _, err := tx.Exec(query, aid); err != nil {
   178  				return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   179  			}
   180  
   181  			return nil
   182  		})
   183  	})
   184  }
   185  
   186  // Locking is done manually with a separate lock table.  Implementing advisory locks in CRDB is being discussed
   187  // See: https://github.com/cockroachdb/cockroach/issues/13546
   188  func (c *CockroachDb) Unlock() error {
   189  	return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
   190  		aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   191  		if err != nil {
   192  			return err
   193  		}
   194  
   195  		// In the event of an implementation (non-migration) error, it is possible for the lock to not be released.  Until
   196  		// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
   197  		query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
   198  		if _, err := c.db.Exec(query, aid); err != nil {
   199  			if e, ok := err.(*pq.Error); ok {
   200  				// 42P01 is "UndefinedTableError" in CockroachDB
   201  				// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   202  				if e.Code == "42P01" {
   203  					// On drops, the lock table is fully removed;  This is fine, and is a valid "unlocked" state for the schema
   204  					return nil
   205  				}
   206  			}
   207  
   208  			return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   209  		}
   210  
   211  		return nil
   212  	})
   213  }
   214  
   215  func (c *CockroachDb) Run(migration io.Reader) error {
   216  	migr, err := io.ReadAll(migration)
   217  	if err != nil {
   218  		return err
   219  	}
   220  
   221  	// run migration
   222  	query := string(migr[:])
   223  	if _, err := c.db.Exec(query); err != nil {
   224  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   225  	}
   226  
   227  	return nil
   228  }
   229  
   230  func (c *CockroachDb) SetVersion(version int, dirty bool) error {
   231  	return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
   232  		if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
   233  			return err
   234  		}
   235  
   236  		// Also re-write the schema version for nil dirty versions to prevent
   237  		// empty schema version for failed down migration on the first migration
   238  		// See: https://github.com/golang-migrate/migrate/issues/330
   239  		if version >= 0 || (version == database.NilVersion && dirty) {
   240  			if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
   241  				return err
   242  			}
   243  		}
   244  
   245  		return nil
   246  	})
   247  }
   248  
   249  func (c *CockroachDb) Version() (version int, dirty bool, err error) {
   250  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   251  	err = c.db.QueryRow(query).Scan(&version, &dirty)
   252  
   253  	switch {
   254  	case err == sql.ErrNoRows:
   255  		return database.NilVersion, false, nil
   256  
   257  	case err != nil:
   258  		if e, ok := err.(*pq.Error); ok {
   259  			// 42P01 is "UndefinedTableError" in CockroachDB
   260  			// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
   261  			if e.Code == "42P01" {
   262  				return database.NilVersion, false, nil
   263  			}
   264  		}
   265  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   266  
   267  	default:
   268  		return version, dirty, nil
   269  	}
   270  }
   271  
   272  func (c *CockroachDb) Drop() (err error) {
   273  	// select all tables in current schema
   274  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
   275  	tables, err := c.db.Query(query)
   276  	if err != nil {
   277  		return &database.Error{OrigErr: err, Query: []byte(query)}
   278  	}
   279  	defer func() {
   280  		if errClose := tables.Close(); errClose != nil {
   281  			err = multierror.Append(err, errClose)
   282  		}
   283  	}()
   284  
   285  	// delete one table after another
   286  	tableNames := make([]string, 0)
   287  	for tables.Next() {
   288  		var tableName string
   289  		if err := tables.Scan(&tableName); err != nil {
   290  			return err
   291  		}
   292  		if len(tableName) > 0 {
   293  			tableNames = append(tableNames, tableName)
   294  		}
   295  	}
   296  	if err := tables.Err(); err != nil {
   297  		return &database.Error{OrigErr: err, Query: []byte(query)}
   298  	}
   299  
   300  	if len(tableNames) > 0 {
   301  		// delete one by one ...
   302  		for _, t := range tableNames {
   303  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   304  			if _, err := c.db.Exec(query); err != nil {
   305  				return &database.Error{OrigErr: err, Query: []byte(query)}
   306  			}
   307  		}
   308  	}
   309  
   310  	return nil
   311  }
   312  
   313  // ensureVersionTable checks if versions table exists and, if not, creates it.
   314  // Note that this function locks the database, which deviates from the usual
   315  // convention of "caller locks" in the CockroachDb type.
   316  func (c *CockroachDb) ensureVersionTable() (err error) {
   317  	if err = c.Lock(); err != nil {
   318  		return err
   319  	}
   320  
   321  	defer func() {
   322  		if e := c.Unlock(); e != nil {
   323  			if err == nil {
   324  				err = e
   325  			} else {
   326  				err = multierror.Append(err, e)
   327  			}
   328  		}
   329  	}()
   330  
   331  	// check if migration table exists
   332  	var count int
   333  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   334  	if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
   335  		return &database.Error{OrigErr: err, Query: []byte(query)}
   336  	}
   337  	if count == 1 {
   338  		return nil
   339  	}
   340  
   341  	// if not, create the empty migration table
   342  	query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
   343  	if _, err := c.db.Exec(query); err != nil {
   344  		return &database.Error{OrigErr: err, Query: []byte(query)}
   345  	}
   346  	return nil
   347  }
   348  
   349  func (c *CockroachDb) ensureLockTable() error {
   350  	// check if lock table exists
   351  	var count int
   352  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   353  	if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
   354  		return &database.Error{OrigErr: err, Query: []byte(query)}
   355  	}
   356  	if count == 1 {
   357  		return nil
   358  	}
   359  
   360  	// if not, create the empty lock table
   361  	query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
   362  	if _, err := c.db.Exec(query); err != nil {
   363  		return &database.Error{OrigErr: err, Query: []byte(query)}
   364  	}
   365  
   366  	return nil
   367  }