github.com/status-im/status-go@v1.1.0/sqlite/migrate.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"sort"
     7  
     8  	"github.com/status-im/migrate/v4"
     9  	"github.com/status-im/migrate/v4/database/sqlcipher"
    10  	bindata "github.com/status-im/migrate/v4/source/go_bindata"
    11  )
    12  
// CustomMigrationFunc is the callback run by a PostStep. It receives a transaction opened with db.Begin by the
// migration runner; the runner commits that transaction when the callback returns nil and rolls it back when the
// callback returns an error.
type CustomMigrationFunc func(tx *sql.Tx) error

// PostStep attaches Go code to a point in the migration sequence:
//   - Version: Migrate first applies the SQL migrations up to this version and then calls CustomMigration.
//     Steps whose Version is not newer than the database's current version are skipped.
//   - CustomMigration: the callback to run once the schema is at Version.
//   - RollBackVersion: when > 0, the version the schema is migrated to if CustomMigration returns an error;
//     0 disables that rollback.
type PostStep struct {
	Version         uint
	CustomMigration CustomMigrationFunc
	RollBackVersion uint
}
    20  
// migrationTable is the table that records applied migrations. It is passed to the sqlcipher driver as
// MigrationsTable, and GetLastMigrationVersion checks for its existence.
var migrationTable = "status_go_" + sqlcipher.DefaultMigrationsTable
    22  
    23  // Migrate database with option to augment the migration steps with additional processing using the customSteps
    24  // parameter. For each PostStep entry in customSteps the CustomMigration will be called after the migration step
    25  // with the matching Version number has been executed. If the CustomMigration returns an error, the migration process
    26  // is aborted. In case the custom step failures the migrations are run down to RollBackVersion if > 0.
    27  //
    28  // The recommended way to create a custom migration is by providing empty and versioned run/down sql files as markers.
    29  // Then running all the SQL code inside the same transaction to transform and commit provides the possibility
    30  // to completely rollback the migration in case of failure, avoiding to leave the DB in an inconsistent state.
    31  //
    32  // Marker migrations can be created by using PostStep structs with specific Version numbers and a callback function,
    33  // even when no accompanying SQL migration is needed. This can be used to trigger Go code at specific points
    34  // during the migration process.
    35  //
    36  // Caution: This mechanism should be used as a last resort. Prefer data migration using SQL migration files
    37  // whenever possible to ensure consistency and compatibility with standard migration tools.
    38  //
    39  // untilVersion, for testing purposes optional parameter, can be used to limit the migration to a specific version.
    40  // Pass nil to migrate to the latest available version.
    41  func Migrate(db *sql.DB, resources *bindata.AssetSource, customSteps []*PostStep, untilVersion *uint) error {
    42  	source, err := bindata.WithInstance(resources)
    43  	if err != nil {
    44  		return fmt.Errorf("failed to create bindata migration source: %w", err)
    45  	}
    46  
    47  	driver, err := sqlcipher.WithInstance(db, &sqlcipher.Config{
    48  		MigrationsTable: migrationTable,
    49  	})
    50  	if err != nil {
    51  		return fmt.Errorf("failed to create sqlcipher driver: %w", err)
    52  	}
    53  
    54  	m, err := migrate.NewWithInstance("go-bindata", source, "sqlcipher", driver)
    55  	if err != nil {
    56  		return fmt.Errorf("failed to create migration instance: %w", err)
    57  	}
    58  
    59  	if len(customSteps) == 0 {
    60  		return runRemainingMigrations(m, untilVersion)
    61  	}
    62  
    63  	sort.Slice(customSteps, func(i, j int) bool {
    64  		return customSteps[i].Version < customSteps[j].Version
    65  	})
    66  
    67  	lastVersion, err := getCurrentVersion(m, db)
    68  	if err != nil {
    69  		return err
    70  	}
    71  
    72  	customIndex := 0
    73  	// ignore processed versions
    74  	for customIndex < len(customSteps) && customSteps[customIndex].Version <= lastVersion {
    75  		customIndex++
    76  	}
    77  
    78  	if err := runCustomMigrations(m, db, customSteps, customIndex, untilVersion); err != nil {
    79  		return err
    80  	}
    81  
    82  	return runRemainingMigrations(m, untilVersion)
    83  }
    84  
    85  // runCustomMigrations performs source migrations from current to each custom steps, then runs custom migration callback
    86  // until it executes all custom migrations or an error occurs and it tries to rollback to RollBackVersion if > 0.
    87  func runCustomMigrations(m *migrate.Migrate, db *sql.DB, customSteps []*PostStep, customIndex int, untilVersion *uint) error {
    88  	for customIndex < len(customSteps) && (untilVersion == nil || customSteps[customIndex].Version <= *untilVersion) {
    89  		customStep := customSteps[customIndex]
    90  
    91  		if err := m.Migrate(customStep.Version); err != nil && err != migrate.ErrNoChange {
    92  			return fmt.Errorf("failed to migrate to version %d: %w", customStep.Version, err)
    93  		}
    94  
    95  		if err := runCustomMigrationStep(db, customStep, m); err != nil {
    96  			return err
    97  		}
    98  
    99  		customIndex++
   100  	}
   101  	return nil
   102  }
   103  
   104  func runCustomMigrationStep(db *sql.DB, customStep *PostStep, m *migrate.Migrate) error {
   105  
   106  	sqlTx, err := db.Begin()
   107  	if err != nil {
   108  		return fmt.Errorf("failed to begin transaction: %w", err)
   109  	}
   110  
   111  	if err := customStep.CustomMigration(sqlTx); err != nil {
   112  		_ = sqlTx.Rollback()
   113  		return rollbackCustomMigration(m, customStep, err)
   114  	}
   115  
   116  	if err := sqlTx.Commit(); err != nil {
   117  		return fmt.Errorf("failed to commit transaction: %w", err)
   118  	}
   119  	return nil
   120  }
   121  
   122  func rollbackCustomMigration(m *migrate.Migrate, customStep *PostStep, customErr error) error {
   123  	if customStep.RollBackVersion > 0 {
   124  		err := m.Migrate(customStep.RollBackVersion)
   125  		newV, _, _ := m.Version()
   126  		if err != nil {
   127  			return fmt.Errorf("failed to rollback migration to version %d: %w", customStep.RollBackVersion, err)
   128  		}
   129  		return fmt.Errorf("custom migration step failed for version %d. Successfully rolled back migration to version %d: %w", customStep.Version, newV, customErr)
   130  	}
   131  	return fmt.Errorf("custom migration step failed for version %d: %w", customStep.Version, customErr)
   132  }
   133  
   134  func runRemainingMigrations(m *migrate.Migrate, untilVersion *uint) error {
   135  	if untilVersion != nil {
   136  		if err := m.Migrate(*untilVersion); err != nil && err != migrate.ErrNoChange {
   137  			return fmt.Errorf("failed to migrate to version %d: %w", *untilVersion, err)
   138  		}
   139  	} else {
   140  		if err := m.Up(); err != nil && err != migrate.ErrNoChange {
   141  			ver, _, _ := m.Version()
   142  			return fmt.Errorf("failed to migrate up: %w, current version: %d", err, ver)
   143  		}
   144  	}
   145  	return nil
   146  }
   147  
   148  func getCurrentVersion(m *migrate.Migrate, db *sql.DB) (uint, error) {
   149  	lastVersion, dirty, err := m.Version()
   150  	if err != nil && err != migrate.ErrNilVersion {
   151  		return 0, fmt.Errorf("failed to get migration version: %w", err)
   152  	}
   153  	if dirty {
   154  		return 0, fmt.Errorf("DB is dirty after migration version %d", lastVersion)
   155  	}
   156  	if err == migrate.ErrNilVersion {
   157  		lastVersion, _, err = GetLastMigrationVersion(db)
   158  		return lastVersion, err
   159  	}
   160  	return lastVersion, nil
   161  }
   162  
   163  // GetLastMigrationVersion returns the last migration version stored in the migration table.
   164  // Returns 0 for version in case migrationTableExists is true
   165  func GetLastMigrationVersion(db *sql.DB) (version uint, migrationTableExists bool, err error) {
   166  	// Check if the migration table exists
   167  	row := db.QueryRow("SELECT exists(SELECT name FROM sqlite_master WHERE type='table' AND name=?)", migrationTable)
   168  	migrationTableExists = false
   169  	err = row.Scan(&migrationTableExists)
   170  	if err != nil && err != sql.ErrNoRows {
   171  		return 0, false, err
   172  	}
   173  
   174  	var lastMigration uint64 = 0
   175  	if migrationTableExists {
   176  		row = db.QueryRow("SELECT version FROM status_go_schema_migrations")
   177  		err = row.Scan(&lastMigration)
   178  		if err != nil && err != sql.ErrNoRows {
   179  			return 0, true, err
   180  		}
   181  	}
   182  	return uint(lastMigration), migrationTableExists, nil
   183  }