github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/yugabytedb/yugabytedb.go (about)

     1  package yugabytedb
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"io"
     8  	"net/url"
     9  	"regexp"
    10  	"strconv"
    11  	"time"
    12  
    13  	"github.com/cenkalti/backoff/v4"
    14  	"github.com/golang-migrate/migrate/v4"
    15  	"github.com/golang-migrate/migrate/v4/database"
    16  	"github.com/hashicorp/go-multierror"
    17  	"github.com/jackc/pgconn"
    18  	"github.com/jackc/pgerrcode"
    19  	"github.com/lib/pq"
    20  	"go.uber.org/atomic"
    21  )
    22  
    23  const (
    24  	DefaultMaxRetryInterval    = time.Second * 15
    25  	DefaultMaxRetryElapsedTime = time.Second * 30
    26  	DefaultMaxRetries          = 10
    27  	DefaultMigrationsTable     = "migrations"
    28  	DefaultLockTable           = "migrations_locks"
    29  )
    30  
    31  var (
    32  	ErrNilConfig          = errors.New("no config")
    33  	ErrNoDatabaseName     = errors.New("no database name")
    34  	ErrMaxRetriesExceeded = errors.New("max retries exceeded")
    35  )
    36  
    37  func init() {
    38  	db := YugabyteDB{}
    39  	database.Register("yugabyte", &db)
    40  	database.Register("yugabytedb", &db)
    41  	database.Register("ysql", &db)
    42  }
    43  
    44  type Config struct {
    45  	MigrationsTable     string
    46  	LockTable           string
    47  	ForceLock           bool
    48  	DatabaseName        string
    49  	MaxRetryInterval    time.Duration
    50  	MaxRetryElapsedTime time.Duration
    51  	MaxRetries          int
    52  }
    53  
// YugabyteDB is the migrate driver for YugabyteDB's YSQL API. WithInstance
// builds ready-to-use values; init registers a zero value so that Open can be
// dispatched by URL scheme.
type YugabyteDB struct {
	// db is the database handle every query in this driver runs against.
	db *sql.DB

	// isLocked records whether this instance believes it holds the migration
	// lock; Lock and Unlock flip it through database.CasRestoreOnErr.
	isLocked atomic.Bool

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    61  
    62  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    63  	if config == nil {
    64  		return nil, ErrNilConfig
    65  	}
    66  
    67  	if err := instance.Ping(); err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	if config.DatabaseName == "" {
    72  		query := `SELECT current_database()`
    73  		var databaseName string
    74  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    75  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    76  		}
    77  
    78  		if len(databaseName) == 0 {
    79  			return nil, ErrNoDatabaseName
    80  		}
    81  
    82  		config.DatabaseName = databaseName
    83  	}
    84  
    85  	if len(config.MigrationsTable) == 0 {
    86  		config.MigrationsTable = DefaultMigrationsTable
    87  	}
    88  
    89  	if len(config.LockTable) == 0 {
    90  		config.LockTable = DefaultLockTable
    91  	}
    92  
    93  	if config.MaxRetryInterval == 0 {
    94  		config.MaxRetryInterval = DefaultMaxRetryInterval
    95  	}
    96  
    97  	if config.MaxRetryElapsedTime == 0 {
    98  		config.MaxRetryElapsedTime = DefaultMaxRetryElapsedTime
    99  	}
   100  
   101  	if config.MaxRetries == 0 {
   102  		config.MaxRetries = DefaultMaxRetries
   103  	}
   104  
   105  	px := &YugabyteDB{
   106  		db:     instance,
   107  		config: config,
   108  	}
   109  
   110  	// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
   111  	if err := px.ensureLockTable(); err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	if err := px.ensureVersionTable(); err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	return px, nil
   120  }
   121  
   122  func (c *YugabyteDB) Open(dbURL string) (database.Driver, error) {
   123  	purl, err := url.Parse(dbURL)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	// As YugabyteDB uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
   129  	// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
   130  	re := regexp.MustCompile("^(yugabyte(db)?|ysql)")
   131  	connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
   132  
   133  	db, err := sql.Open("postgres", connectString)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	migrationsTable := purl.Query().Get("x-migrations-table")
   139  	if len(migrationsTable) == 0 {
   140  		migrationsTable = DefaultMigrationsTable
   141  	}
   142  
   143  	lockTable := purl.Query().Get("x-lock-table")
   144  	if len(lockTable) == 0 {
   145  		lockTable = DefaultLockTable
   146  	}
   147  
   148  	forceLockQuery := purl.Query().Get("x-force-lock")
   149  	forceLock, err := strconv.ParseBool(forceLockQuery)
   150  	if err != nil {
   151  		forceLock = false
   152  	}
   153  
   154  	maxIntervalStr := purl.Query().Get("x-max-retry-interval")
   155  	maxInterval, err := time.ParseDuration(maxIntervalStr)
   156  	if err != nil {
   157  		maxInterval = DefaultMaxRetryInterval
   158  	}
   159  
   160  	maxElapsedTimeStr := purl.Query().Get("x-max-retry-elapsed-time")
   161  	maxElapsedTime, err := time.ParseDuration(maxElapsedTimeStr)
   162  	if err != nil {
   163  		maxElapsedTime = DefaultMaxRetryElapsedTime
   164  	}
   165  
   166  	maxRetriesStr := purl.Query().Get("x-max-retries")
   167  	maxRetries, err := strconv.Atoi(maxRetriesStr)
   168  	if err != nil {
   169  		maxRetries = DefaultMaxRetries
   170  	}
   171  
   172  	px, err := WithInstance(db, &Config{
   173  		DatabaseName:        purl.Path,
   174  		MigrationsTable:     migrationsTable,
   175  		LockTable:           lockTable,
   176  		ForceLock:           forceLock,
   177  		MaxRetryInterval:    maxInterval,
   178  		MaxRetryElapsedTime: maxElapsedTime,
   179  		MaxRetries:          maxRetries,
   180  	})
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	return px, nil
   186  }
   187  
   188  func (c *YugabyteDB) Close() error {
   189  	return c.db.Close()
   190  }
   191  
   192  // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed
   193  // See: https://github.com/yugabyte/yugabyte-db/issues/3642
   194  func (c *YugabyteDB) Lock() error {
   195  	return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
   196  		return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) (err error) {
   197  			aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   198  			if err != nil {
   199  				return err
   200  			}
   201  
   202  			query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
   203  			rows, err := tx.Query(query, aid)
   204  			if err != nil {
   205  				return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
   206  			}
   207  			defer func() {
   208  				if errClose := rows.Close(); errClose != nil {
   209  					err = multierror.Append(err, errClose)
   210  				}
   211  			}()
   212  
   213  			// If row exists at all, lock is present
   214  			locked := rows.Next()
   215  			if locked && !c.config.ForceLock {
   216  				return database.ErrLocked
   217  			}
   218  
   219  			query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
   220  			if _, err := tx.Exec(query, aid); err != nil {
   221  				return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
   222  			}
   223  
   224  			return nil
   225  		})
   226  	})
   227  }
   228  
   229  // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed
   230  // See: https://github.com/yugabyte/yugabyte-db/issues/3642
   231  func (c *YugabyteDB) Unlock() error {
   232  	return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
   233  		aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
   234  		if err != nil {
   235  			return err
   236  		}
   237  
   238  		// In the event of an implementation (non-migration) error, it is possible for the lock to not be released. Until
   239  		// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
   240  		query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
   241  		if _, err := c.db.Exec(query, aid); err != nil {
   242  			if e, ok := err.(*pq.Error); ok {
   243  				// 42P01 is "UndefinedTableError" in YugabyteDB
   244  				// https://github.com/yugabyte/yugabyte-db/blob/9c6b8e6beb56eed8eeb357178c0c6b837eb49896/src/postgres/src/backend/utils/errcodes.txt#L366
   245  				if e.Code == "42P01" {
   246  					// On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema
   247  					return nil
   248  				}
   249  			}
   250  
   251  			return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
   252  		}
   253  
   254  		return nil
   255  	})
   256  }
   257  
   258  func (c *YugabyteDB) Run(migration io.Reader) error {
   259  	migr, err := io.ReadAll(migration)
   260  	if err != nil {
   261  		return err
   262  	}
   263  
   264  	// run migration
   265  	query := string(migr[:])
   266  	if _, err := c.db.Exec(query); err != nil {
   267  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   268  	}
   269  
   270  	return nil
   271  }
   272  
   273  func (c *YugabyteDB) SetVersion(version int, dirty bool) error {
   274  	return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error {
   275  		if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
   276  			return err
   277  		}
   278  
   279  		// Also re-write the schema version for nil dirty versions to prevent
   280  		// empty schema version for failed down migration on the first migration
   281  		// See: https://github.com/golang-migrate/migrate/issues/330
   282  		if version >= 0 || (version == database.NilVersion && dirty) {
   283  			if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
   284  				return err
   285  			}
   286  		}
   287  
   288  		return nil
   289  	})
   290  }
   291  
   292  func (c *YugabyteDB) Version() (version int, dirty bool, err error) {
   293  	query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
   294  	err = c.db.QueryRow(query).Scan(&version, &dirty)
   295  
   296  	switch {
   297  	case err == sql.ErrNoRows:
   298  		return database.NilVersion, false, nil
   299  
   300  	case err != nil:
   301  		if e, ok := err.(*pq.Error); ok {
   302  			// 42P01 is "UndefinedTableError" in YugabyteDB
   303  			// https://github.com/yugabyte/yugabyte-db/blob/9c6b8e6beb56eed8eeb357178c0c6b837eb49896/src/postgres/src/backend/utils/errcodes.txt#L366
   304  			if e.Code == "42P01" {
   305  				return database.NilVersion, false, nil
   306  			}
   307  		}
   308  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   309  
   310  	default:
   311  		return version, dirty, nil
   312  	}
   313  }
   314  
   315  func (c *YugabyteDB) Drop() (err error) {
   316  	query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
   317  	tables, err := c.db.Query(query)
   318  	if err != nil {
   319  		return &database.Error{OrigErr: err, Query: []byte(query)}
   320  	}
   321  	defer func() {
   322  		if errClose := tables.Close(); errClose != nil {
   323  			err = multierror.Append(err, errClose)
   324  		}
   325  	}()
   326  
   327  	// delete one table after another
   328  	tableNames := make([]string, 0)
   329  	for tables.Next() {
   330  		var tableName string
   331  		if err := tables.Scan(&tableName); err != nil {
   332  			return err
   333  		}
   334  		if len(tableName) > 0 {
   335  			tableNames = append(tableNames, tableName)
   336  		}
   337  	}
   338  	if err := tables.Err(); err != nil {
   339  		return &database.Error{OrigErr: err, Query: []byte(query)}
   340  	}
   341  
   342  	if len(tableNames) > 0 {
   343  		for _, t := range tableNames {
   344  			query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
   345  			if _, err := c.db.Exec(query); err != nil {
   346  				return &database.Error{OrigErr: err, Query: []byte(query)}
   347  			}
   348  		}
   349  	}
   350  
   351  	return nil
   352  }
   353  
   354  // ensureVersionTable checks if versions table exists and, if not, creates it.
   355  // Note that this function locks the database
   356  func (c *YugabyteDB) ensureVersionTable() (err error) {
   357  	if err = c.Lock(); err != nil {
   358  		return err
   359  	}
   360  
   361  	defer func() {
   362  		if e := c.Unlock(); e != nil {
   363  			if err == nil {
   364  				err = e
   365  			} else {
   366  				err = multierror.Append(err, e)
   367  			}
   368  		}
   369  	}()
   370  
   371  	// check if migration table exists
   372  	var count int
   373  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   374  	if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
   375  		return &database.Error{OrigErr: err, Query: []byte(query)}
   376  	}
   377  	if count == 1 {
   378  		return nil
   379  	}
   380  
   381  	// if not, create the empty migration table
   382  	query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
   383  	if _, err := c.db.Exec(query); err != nil {
   384  		return &database.Error{OrigErr: err, Query: []byte(query)}
   385  	}
   386  	return nil
   387  }
   388  
   389  func (c *YugabyteDB) ensureLockTable() error {
   390  	// check if lock table exists
   391  	var count int
   392  	query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
   393  	if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
   394  		return &database.Error{OrigErr: err, Query: []byte(query)}
   395  	}
   396  	if count == 1 {
   397  		return nil
   398  	}
   399  
   400  	// if not, create the empty lock table
   401  	query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id TEXT NOT NULL PRIMARY KEY)`
   402  	if _, err := c.db.Exec(query); err != nil {
   403  		return &database.Error{OrigErr: err, Query: []byte(query)}
   404  	}
   405  
   406  	return nil
   407  }
   408  
   409  func (c *YugabyteDB) doTxWithRetry(
   410  	ctx context.Context,
   411  	txOpts *sql.TxOptions,
   412  	fn func(tx *sql.Tx) error,
   413  ) error {
   414  	backOff := c.newBackoff(ctx)
   415  
   416  	return backoff.Retry(func() error {
   417  		tx, err := c.db.BeginTx(ctx, txOpts)
   418  		if err != nil {
   419  			return backoff.Permanent(err)
   420  		}
   421  
   422  		// If we've tried to commit the transaction Rollback just returns sql.ErrTxDone.
   423  		//nolint:errcheck
   424  		defer tx.Rollback()
   425  
   426  		if err := fn(tx); err != nil {
   427  			if errIsRetryable(err) {
   428  				return err
   429  			}
   430  
   431  			return backoff.Permanent(err)
   432  		}
   433  
   434  		if err := tx.Commit(); err != nil {
   435  			if errIsRetryable(err) {
   436  				return err
   437  			}
   438  
   439  			return backoff.Permanent(err)
   440  		}
   441  
   442  		return nil
   443  	}, backOff)
   444  }
   445  
   446  func (c *YugabyteDB) newBackoff(ctx context.Context) backoff.BackOff {
   447  	if ctx == nil {
   448  		ctx = context.Background()
   449  	}
   450  
   451  	retrier := backoff.WithMaxRetries(backoff.WithContext(&backoff.ExponentialBackOff{
   452  		InitialInterval:     backoff.DefaultInitialInterval,
   453  		RandomizationFactor: backoff.DefaultRandomizationFactor,
   454  		Multiplier:          backoff.DefaultMultiplier,
   455  		MaxInterval:         c.config.MaxRetryInterval,
   456  		MaxElapsedTime:      c.config.MaxRetryElapsedTime,
   457  		Stop:                backoff.Stop,
   458  		Clock:               backoff.SystemClock,
   459  	}, ctx), uint64(c.config.MaxRetries))
   460  
   461  	retrier.Reset()
   462  
   463  	return retrier
   464  }
   465  
   466  func errIsRetryable(err error) bool {
   467  	var pgErr *pgconn.PgError
   468  	if !errors.As(err, &pgErr) {
   469  		return false
   470  	}
   471  
   472  	// Assume that it's safe to retry 08006 and XX000 because we check for lock existence
   473  	// before creating and lock ID is primary key. Version field in migrations table is primary key too
   474  	// and delete all versions is an idempotent operation.
   475  	return pgErr.Code == pgerrcode.SerializationFailure || // optimistic locking conflict
   476  		pgErr.Code == pgerrcode.DeadlockDetected ||
   477  		pgErr.Code == pgerrcode.ConnectionFailure || // node down, need to reconnect
   478  		pgErr.Code == pgerrcode.InternalError // may happen during HA
   479  }