github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/migration/migration.go (about)

     1  package migration
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"sort"
     8  	"time"
     9  
    10  	"code.cloudfoundry.org/lager"
    11  	"github.com/pf-qiu/concourse/v6/atc/db/encryption"
    12  	"github.com/pf-qiu/concourse/v6/atc/db/lock"
    13  	"github.com/pf-qiu/concourse/v6/atc/db/migration/migrations"
    14  	"github.com/gobuffalo/packr"
    15  	multierror "github.com/hashicorp/go-multierror"
    16  	_ "github.com/lib/pq"
    17  )
    18  
    19  func NewOpenHelper(driver, name string, lockFactory lock.LockFactory, newKey *encryption.Key, oldKey *encryption.Key) *OpenHelper {
    20  	return &OpenHelper{
    21  		driver,
    22  		name,
    23  		lockFactory,
    24  		newKey,
    25  		oldKey,
    26  	}
    27  }
    28  
    29  type OpenHelper struct {
    30  	driver         string
    31  	dataSourceName string
    32  	lockFactory    lock.LockFactory
    33  	newKey         *encryption.Key
    34  	oldKey         *encryption.Key
    35  }
    36  
    37  func (helper *OpenHelper) CurrentVersion() (int, error) {
    38  	db, err := sql.Open(helper.driver, helper.dataSourceName)
    39  	if err != nil {
    40  		return -1, err
    41  	}
    42  
    43  	defer db.Close()
    44  
    45  	return NewMigrator(db, helper.lockFactory).CurrentVersion()
    46  }
    47  
    48  func (helper *OpenHelper) SupportedVersion() (int, error) {
    49  	db, err := sql.Open(helper.driver, helper.dataSourceName)
    50  	if err != nil {
    51  		return -1, err
    52  	}
    53  
    54  	defer db.Close()
    55  
    56  	return NewMigrator(db, helper.lockFactory).SupportedVersion()
    57  }
    58  
    59  func (helper *OpenHelper) Open() (*sql.DB, error) {
    60  	db, err := sql.Open(helper.driver, helper.dataSourceName)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	if err := NewMigrator(db, helper.lockFactory).Up(helper.newKey, helper.oldKey); err != nil {
    66  		_ = db.Close()
    67  		return nil, err
    68  	}
    69  
    70  	return db, nil
    71  }
    72  
    73  func (helper *OpenHelper) OpenAtVersion(version int) (*sql.DB, error) {
    74  	db, err := sql.Open(helper.driver, helper.dataSourceName)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	if err := NewMigrator(db, helper.lockFactory).Migrate(helper.newKey, helper.oldKey, version); err != nil {
    80  		_ = db.Close()
    81  		return nil, err
    82  	}
    83  
    84  	return db, nil
    85  }
    86  
    87  func (helper *OpenHelper) MigrateToVersion(version int) error {
    88  	db, err := sql.Open(helper.driver, helper.dataSourceName)
    89  	if err != nil {
    90  		return err
    91  	}
    92  
    93  	defer db.Close()
    94  	m := NewMigrator(db, helper.lockFactory)
    95  
    96  	err = helper.migrateFromMigrationVersion(db)
    97  	if err != nil {
    98  		return err
    99  	}
   100  
   101  	return m.Migrate(helper.newKey, helper.oldKey, version)
   102  }
   103  
   104  func (helper *OpenHelper) migrateFromMigrationVersion(db *sql.DB) error {
   105  
   106  	legacySchemaExists, err := checkTableExist(db, "migration_version")
   107  	if err != nil {
   108  		return err
   109  	}
   110  
   111  	if !legacySchemaExists {
   112  		return nil
   113  	}
   114  
   115  	oldMigrationLastVersion := 189
   116  	newMigrationStartVersion := 1510262030
   117  
   118  	var dbVersion int
   119  
   120  	if err = db.QueryRow("SELECT version FROM migration_version").Scan(&dbVersion); err != nil {
   121  		return err
   122  	}
   123  
   124  	if dbVersion != oldMigrationLastVersion {
   125  		return fmt.Errorf("Must upgrade from db version %d (concourse 3.6.0), current db version: %d", oldMigrationLastVersion, dbVersion)
   126  	}
   127  
   128  	if _, err = db.Exec("DROP TABLE IF EXISTS migration_version"); err != nil {
   129  		return err
   130  	}
   131  
   132  	_, err = db.Exec("CREATE TABLE IF NOT EXISTS schema_migrations (version bigint, dirty boolean)")
   133  	if err != nil {
   134  		return err
   135  	}
   136  
   137  	_, err = db.Exec("INSERT INTO schema_migrations (version, dirty) VALUES ($1, false)", newMigrationStartVersion)
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  type Migrator interface {
   146  	CurrentVersion() (int, error)
   147  	SupportedVersion() (int, error)
   148  	Migrate(newKey, oldKey *encryption.Key, version int) error
   149  	Up(newKey, oldKey *encryption.Key) error
   150  	Migrations() ([]migration, error)
   151  }
   152  
   153  func NewMigrator(db *sql.DB, lockFactory lock.LockFactory) Migrator {
   154  	return NewMigratorForMigrations(db, lockFactory, &packrSource{packr.NewBox("./migrations")})
   155  }
   156  
   157  func NewMigratorForMigrations(db *sql.DB, lockFactory lock.LockFactory, bindata Bindata) Migrator {
   158  	return &migrator{
   159  		db,
   160  		lockFactory,
   161  		lager.NewLogger("migrations"),
   162  		bindata,
   163  	}
   164  }
   165  
   166  type migrator struct {
   167  	db          *sql.DB
   168  	lockFactory lock.LockFactory
   169  	logger      lager.Logger
   170  	bindata     Bindata
   171  }
   172  
   173  func (m *migrator) SupportedVersion() (int, error) {
   174  	matches := []migration{}
   175  
   176  	assets := m.bindata.AssetNames()
   177  
   178  	var parser = NewParser(m.bindata)
   179  	for _, match := range assets {
   180  		if migration, err := parser.ParseMigrationFilename(match); err == nil {
   181  			matches = append(matches, migration)
   182  		}
   183  	}
   184  	sortMigrations(matches)
   185  	return matches[len(matches)-1].Version, nil
   186  }
   187  
   188  func (helper *migrator) CurrentVersion() (int, error) {
   189  	var currentVersion int
   190  	var direction string
   191  	err := helper.db.QueryRow("SELECT version, direction FROM migrations_history WHERE status!='failed' ORDER BY tstamp DESC LIMIT 1").Scan(&currentVersion, &direction)
   192  	if err != nil {
   193  		if err == sql.ErrNoRows {
   194  			return 0, nil
   195  		}
   196  		return -1, err
   197  	}
   198  	migrations, err := helper.Migrations()
   199  	if err != nil {
   200  		return -1, err
   201  	}
   202  	versions := []int{migrations[0].Version}
   203  	for _, m := range migrations {
   204  		if m.Version > versions[len(versions)-1] {
   205  			versions = append(versions, m.Version)
   206  		}
   207  	}
   208  	for i, version := range versions {
   209  		if currentVersion == version && direction == "down" {
   210  			currentVersion = versions[i-1]
   211  			break
   212  		}
   213  	}
   214  	return currentVersion, nil
   215  }
   216  
   217  func (helper *migrator) Migrate(newKey, oldKey *encryption.Key, toVersion int) error {
   218  	var strategy encryption.Strategy
   219  	if oldKey != nil {
   220  		strategy = oldKey
   221  	} else if newKey != nil {
   222  		// special case - if the old key is not provided but the new key is,
   223  		// this might mean the data was not encrypted, or that it was encrypted with newKey
   224  		strategy = encryption.NewFallbackStrategy(newKey, encryption.NewNoEncryption())
   225  	} else if newKey == nil {
   226  		strategy = encryption.NewNoEncryption()
   227  	}
   228  
   229  	lock, err := helper.acquireLock()
   230  	if err != nil {
   231  		return err
   232  	}
   233  
   234  	if lock != nil {
   235  		defer lock.Release()
   236  	}
   237  
   238  	existingDBVersion, err := helper.migrateFromSchemaMigrations()
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	_, err = helper.db.Exec("CREATE TABLE IF NOT EXISTS migrations_history (version bigint, tstamp timestamp with time zone, direction varchar, status varchar, dirty boolean)")
   244  	if err != nil {
   245  		return err
   246  	}
   247  
   248  	if existingDBVersion > 0 {
   249  		var containsOldMigrationInfo bool
   250  		err = helper.db.QueryRow("SELECT EXISTS (SELECT 1 FROM migrations_history where version=$1)", existingDBVersion).Scan(&containsOldMigrationInfo)
   251  		if err != nil {
   252  			return err
   253  		}
   254  
   255  		if !containsOldMigrationInfo {
   256  			_, err = helper.db.Exec("INSERT INTO migrations_history (version, tstamp, direction, status, dirty) VALUES ($1, current_timestamp, 'up', 'passed', false)", existingDBVersion)
   257  			if err != nil {
   258  				return err
   259  			}
   260  		}
   261  	}
   262  
   263  	currentVersion, err := helper.CurrentVersion()
   264  	if err != nil {
   265  		return err
   266  	}
   267  
   268  	migrations, err := helper.Migrations()
   269  	if err != nil {
   270  		return err
   271  	}
   272  
   273  	if currentVersion <= toVersion {
   274  		for _, m := range migrations {
   275  			if currentVersion < m.Version && m.Version <= toVersion && m.Direction == "up" {
   276  				err = helper.runMigration(m, strategy)
   277  				if err != nil {
   278  					return err
   279  				}
   280  			}
   281  		}
   282  	} else {
   283  		for i := len(migrations) - 1; i >= 0; i-- {
   284  			if currentVersion >= migrations[i].Version && migrations[i].Version > toVersion && migrations[i].Direction == "down" {
   285  				err = helper.runMigration(migrations[i], strategy)
   286  				if err != nil {
   287  					return err
   288  				}
   289  
   290  			}
   291  		}
   292  
   293  		err = helper.migrateToSchemaMigrations(toVersion)
   294  		if err != nil {
   295  			return err
   296  		}
   297  	}
   298  
   299  	switch {
   300  	case oldKey != nil && newKey == nil:
   301  		err = helper.decryptToPlaintext(oldKey)
   302  	case oldKey != nil && newKey != nil:
   303  		err = helper.encryptWithNewKey(newKey, oldKey)
   304  	}
   305  	if err != nil {
   306  		return err
   307  	}
   308  
   309  	if newKey != nil {
   310  		err = helper.encryptPlaintext(newKey)
   311  		if err != nil {
   312  			return err
   313  		}
   314  	}
   315  
   316  	return nil
   317  }
   318  
   319  type Strategy int
   320  
   321  const (
   322  	GoMigration Strategy = iota
   323  	SQLMigration
   324  )
   325  
   326  type migration struct {
   327  	Name       string
   328  	Version    int
   329  	Direction  string
   330  	Statements string
   331  	Strategy   Strategy
   332  }
   333  
   334  func (m *migrator) recordMigrationFailure(migration migration, migrationErr error, dirty bool) error {
   335  	_, recordErr := m.db.Exec("INSERT INTO migrations_history (version, tstamp, direction, status, dirty) VALUES ($1, current_timestamp, $2, 'failed', $3)", migration.Version, migration.Direction, dirty)
   336  	if recordErr != nil {
   337  		return multierror.Append(
   338  			migrationErr,
   339  			fmt.Errorf("record failure to migration history: %w", recordErr),
   340  		)
   341  	}
   342  
   343  	return migrationErr
   344  }
   345  
   346  func (m *migrator) runMigration(migration migration, strategy encryption.Strategy) error {
   347  	var err error
   348  
   349  	switch migration.Strategy {
   350  	case GoMigration:
   351  		err = migrations.NewMigrations(m.db, strategy).Run(migration.Name)
   352  		if err != nil {
   353  			return m.recordMigrationFailure(
   354  				migration,
   355  				fmt.Errorf("migration '%s' failed: %w", migration.Name, err),
   356  				false,
   357  			)
   358  		}
   359  	case SQLMigration:
   360  		_, err = m.db.Exec(migration.Statements)
   361  		if err != nil {
   362  			// rollback in case the migration was BEGIN ... COMMIT and failed
   363  			//
   364  			// note that this succeeds and does a no-op (with a warning) if no
   365  			// transaction was opened; we're just OK with that
   366  			_, rbErr := m.db.Exec(`ROLLBACK`)
   367  			if rbErr != nil {
   368  				return multierror.Append(err, fmt.Errorf("rollback failed: %w", rbErr))
   369  			}
   370  
   371  			return m.recordMigrationFailure(
   372  				migration,
   373  				fmt.Errorf("migration '%s' failed and was rolled back: %w", migration.Name, err),
   374  				false,
   375  			)
   376  		}
   377  	}
   378  
   379  	_, err = m.db.Exec("INSERT INTO migrations_history (version, tstamp, direction, status, dirty) VALUES ($1, current_timestamp, $2, 'passed', false)", migration.Version, migration.Direction)
   380  	return err
   381  }
   382  
   383  func (helper *migrator) Migrations() ([]migration, error) {
   384  	migrationList := []migration{}
   385  	assets := helper.bindata.AssetNames()
   386  	var parser = NewParser(helper.bindata)
   387  	for _, assetName := range assets {
   388  		parsedMigration, err := parser.ParseFileToMigration(assetName)
   389  		if err != nil {
   390  			return nil, err
   391  		}
   392  		migrationList = append(migrationList, parsedMigration)
   393  	}
   394  
   395  	sortMigrations(migrationList)
   396  
   397  	return migrationList, nil
   398  }
   399  
   400  func (helper *migrator) Up(newKey, oldKey *encryption.Key) error {
   401  	migrations, err := helper.Migrations()
   402  	if err != nil {
   403  		return err
   404  	}
   405  	return helper.Migrate(newKey, oldKey, migrations[len(migrations)-1].Version)
   406  }
   407  
   408  func (helper *migrator) acquireLock() (lock.Lock, error) {
   409  
   410  	var err error
   411  	var acquired bool
   412  	var newLock lock.Lock
   413  
   414  	if helper.lockFactory != nil {
   415  		for {
   416  			newLock, acquired, err = helper.lockFactory.Acquire(helper.logger, lock.NewDatabaseMigrationLockID())
   417  
   418  			if err != nil {
   419  				return nil, err
   420  			}
   421  
   422  			if acquired {
   423  				break
   424  			}
   425  
   426  			time.Sleep(1 * time.Second)
   427  		}
   428  	}
   429  
   430  	return newLock, err
   431  }
   432  
   433  func checkTableExist(db *sql.DB, tableName string) (bool, error) {
   434  	var existingTable sql.NullString
   435  	err := db.QueryRow("SELECT to_regclass($1)", tableName).Scan(&existingTable)
   436  	if err != nil {
   437  		return false, err
   438  	}
   439  
   440  	return existingTable.Valid, nil
   441  }
   442  
   443  func (helper *migrator) migrateFromSchemaMigrations() (int, error) {
   444  	oldSchemaExists, err := checkTableExist(helper.db, "schema_migrations")
   445  	if err != nil {
   446  		return 0, err
   447  	}
   448  
   449  	newSchemaExists, err := checkTableExist(helper.db, "migrations_history")
   450  	if err != nil {
   451  		return 0, err
   452  	}
   453  
   454  	if !oldSchemaExists || newSchemaExists {
   455  		return 0, nil
   456  	}
   457  
   458  	var isDirty = false
   459  	var existingVersion int
   460  	err = helper.db.QueryRow("SELECT dirty, version FROM schema_migrations LIMIT 1").Scan(&isDirty, &existingVersion)
   461  	if err != nil {
   462  		return 0, err
   463  	}
   464  
   465  	if isDirty {
   466  		return 0, errors.New("cannot begin migration: database is in a dirty state")
   467  	}
   468  
   469  	return existingVersion, nil
   470  }
   471  
   472  func sortMigrations(migrationList []migration) {
   473  	sort.Slice(migrationList, func(i, j int) bool {
   474  		return migrationList[i].Version < migrationList[j].Version
   475  	})
   476  }
   477  
   478  func (helper *migrator) migrateToSchemaMigrations(toVersion int) error {
   479  	newMigrationsHistoryFirstVersion := 1532706545
   480  
   481  	if toVersion >= newMigrationsHistoryFirstVersion {
   482  		return nil
   483  	}
   484  
   485  	oldSchemaExists, err := checkTableExist(helper.db, "schema_migrations")
   486  	if err != nil {
   487  		return err
   488  	}
   489  
   490  	if !oldSchemaExists {
   491  		_, err := helper.db.Exec("CREATE TABLE schema_migrations (version bigint, dirty boolean)")
   492  		if err != nil {
   493  			return err
   494  		}
   495  
   496  		_, err = helper.db.Exec("INSERT INTO schema_migrations (version, dirty) VALUES ($1, false)", toVersion)
   497  		if err != nil {
   498  			return err
   499  		}
   500  	} else {
   501  		_, err := helper.db.Exec("UPDATE schema_migrations SET version=$1, dirty=false", toVersion)
   502  		if err != nil {
   503  			return err
   504  		}
   505  	}
   506  
   507  	return nil
   508  }