github.com/status-im/status-go@v1.1.0/sqlite/migrate.go (about) 1 package sqlite 2 3 import ( 4 "database/sql" 5 "fmt" 6 "sort" 7 8 "github.com/status-im/migrate/v4" 9 "github.com/status-im/migrate/v4/database/sqlcipher" 10 bindata "github.com/status-im/migrate/v4/source/go_bindata" 11 ) 12 13 type CustomMigrationFunc func(tx *sql.Tx) error 14 15 type PostStep struct { 16 Version uint 17 CustomMigration CustomMigrationFunc 18 RollBackVersion uint 19 } 20 21 var migrationTable = "status_go_" + sqlcipher.DefaultMigrationsTable 22 23 // Migrate database with option to augment the migration steps with additional processing using the customSteps 24 // parameter. For each PostStep entry in customSteps the CustomMigration will be called after the migration step 25 // with the matching Version number has been executed. If the CustomMigration returns an error, the migration process 26 // is aborted. In case the custom step failures the migrations are run down to RollBackVersion if > 0. 27 // 28 // The recommended way to create a custom migration is by providing empty and versioned run/down sql files as markers. 29 // Then running all the SQL code inside the same transaction to transform and commit provides the possibility 30 // to completely rollback the migration in case of failure, avoiding to leave the DB in an inconsistent state. 31 // 32 // Marker migrations can be created by using PostStep structs with specific Version numbers and a callback function, 33 // even when no accompanying SQL migration is needed. This can be used to trigger Go code at specific points 34 // during the migration process. 35 // 36 // Caution: This mechanism should be used as a last resort. Prefer data migration using SQL migration files 37 // whenever possible to ensure consistency and compatibility with standard migration tools. 38 // 39 // untilVersion, for testing purposes optional parameter, can be used to limit the migration to a specific version. 40 // Pass nil to migrate to the latest available version. 41 func Migrate(db *sql.DB, resources *bindata.AssetSource, customSteps []*PostStep, untilVersion *uint) error { 42 source, err := bindata.WithInstance(resources) 43 if err != nil { 44 return fmt.Errorf("failed to create bindata migration source: %w", err) 45 } 46 47 driver, err := sqlcipher.WithInstance(db, &sqlcipher.Config{ 48 MigrationsTable: migrationTable, 49 }) 50 if err != nil { 51 return fmt.Errorf("failed to create sqlcipher driver: %w", err) 52 } 53 54 m, err := migrate.NewWithInstance("go-bindata", source, "sqlcipher", driver) 55 if err != nil { 56 return fmt.Errorf("failed to create migration instance: %w", err) 57 } 58 59 if len(customSteps) == 0 { 60 return runRemainingMigrations(m, untilVersion) 61 } 62 63 sort.Slice(customSteps, func(i, j int) bool { 64 return customSteps[i].Version < customSteps[j].Version 65 }) 66 67 lastVersion, err := getCurrentVersion(m, db) 68 if err != nil { 69 return err 70 } 71 72 customIndex := 0 73 // ignore processed versions 74 for customIndex < len(customSteps) && customSteps[customIndex].Version <= lastVersion { 75 customIndex++ 76 } 77 78 if err := runCustomMigrations(m, db, customSteps, customIndex, untilVersion); err != nil { 79 return err 80 } 81 82 return runRemainingMigrations(m, untilVersion) 83 } 84 85 // runCustomMigrations performs source migrations from current to each custom steps, then runs custom migration callback 86 // until it executes all custom migrations or an error occurs and it tries to rollback to RollBackVersion if > 0. 87 func runCustomMigrations(m *migrate.Migrate, db *sql.DB, customSteps []*PostStep, customIndex int, untilVersion *uint) error { 88 for customIndex < len(customSteps) && (untilVersion == nil || customSteps[customIndex].Version <= *untilVersion) { 89 customStep := customSteps[customIndex] 90 91 if err := m.Migrate(customStep.Version); err != nil && err != migrate.ErrNoChange { 92 return fmt.Errorf("failed to migrate to version %d: %w", customStep.Version, err) 93 } 94 95 if err := runCustomMigrationStep(db, customStep, m); err != nil { 96 return err 97 } 98 99 customIndex++ 100 } 101 return nil 102 } 103 104 func runCustomMigrationStep(db *sql.DB, customStep *PostStep, m *migrate.Migrate) error { 105 106 sqlTx, err := db.Begin() 107 if err != nil { 108 return fmt.Errorf("failed to begin transaction: %w", err) 109 } 110 111 if err := customStep.CustomMigration(sqlTx); err != nil { 112 _ = sqlTx.Rollback() 113 return rollbackCustomMigration(m, customStep, err) 114 } 115 116 if err := sqlTx.Commit(); err != nil { 117 return fmt.Errorf("failed to commit transaction: %w", err) 118 } 119 return nil 120 } 121 122 func rollbackCustomMigration(m *migrate.Migrate, customStep *PostStep, customErr error) error { 123 if customStep.RollBackVersion > 0 { 124 err := m.Migrate(customStep.RollBackVersion) 125 newV, _, _ := m.Version() 126 if err != nil { 127 return fmt.Errorf("failed to rollback migration to version %d: %w", customStep.RollBackVersion, err) 128 } 129 return fmt.Errorf("custom migration step failed for version %d. Successfully rolled back migration to version %d: %w", customStep.Version, newV, customErr) 130 } 131 return fmt.Errorf("custom migration step failed for version %d: %w", customStep.Version, customErr) 132 } 133 134 func runRemainingMigrations(m *migrate.Migrate, untilVersion *uint) error { 135 if untilVersion != nil { 136 if err := m.Migrate(*untilVersion); err != nil && err != migrate.ErrNoChange { 137 return fmt.Errorf("failed to migrate to version %d: %w", *untilVersion, err) 138 } 139 } else { 140 if err := m.Up(); err != nil && err != migrate.ErrNoChange { 141 ver, _, _ := m.Version() 142 return fmt.Errorf("failed to migrate up: %w, current version: %d", err, ver) 143 } 144 } 145 return nil 146 } 147 148 func getCurrentVersion(m *migrate.Migrate, db *sql.DB) (uint, error) { 149 lastVersion, dirty, err := m.Version() 150 if err != nil && err != migrate.ErrNilVersion { 151 return 0, fmt.Errorf("failed to get migration version: %w", err) 152 } 153 if dirty { 154 return 0, fmt.Errorf("DB is dirty after migration version %d", lastVersion) 155 } 156 if err == migrate.ErrNilVersion { 157 lastVersion, _, err = GetLastMigrationVersion(db) 158 return lastVersion, err 159 } 160 return lastVersion, nil 161 } 162 163 // GetLastMigrationVersion returns the last migration version stored in the migration table. 164 // Returns 0 for version in case migrationTableExists is true 165 func GetLastMigrationVersion(db *sql.DB) (version uint, migrationTableExists bool, err error) { 166 // Check if the migration table exists 167 row := db.QueryRow("SELECT exists(SELECT name FROM sqlite_master WHERE type='table' AND name=?)", migrationTable) 168 migrationTableExists = false 169 err = row.Scan(&migrationTableExists) 170 if err != nil && err != sql.ErrNoRows { 171 return 0, false, err 172 } 173 174 var lastMigration uint64 = 0 175 if migrationTableExists { 176 row = db.QueryRow("SELECT version FROM status_go_schema_migrations") 177 err = row.Scan(&lastMigration) 178 if err != nil && err != sql.ErrNoRows { 179 return 0, true, err 180 } 181 } 182 return uint(lastMigration), migrationTableExists, nil 183 }