git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/migrate/migrate.go (about) 1 package migrate 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "sort" 9 10 "log/slog" 11 12 "git.sr.ht/~pingoo/stdx/db" 13 "git.sr.ht/~pingoo/stdx/log/slogx" 14 ) 15 16 type Migration struct { 17 ID int64 18 Up func(ctx context.Context, tx db.Queryer) (err error) 19 Down func(ctx context.Context, tx db.Queryer) (err error) 20 } 21 22 func Migrate(ctx context.Context, db db.DB, migrations []Migration) (err error) { 23 logger := slogx.FromCtx(ctx) 24 if logger == nil { 25 err = errors.New("migrate.Migrate: logger is missing from context") 26 return 27 } 28 29 logger.Debug("migrate: Creating/checking migrations table...") 30 err = createMigrationTable(ctx, db) 31 if err != nil { 32 return err 33 } 34 35 tx, err := db.Begin(ctx) 36 if err != nil { 37 err = fmt.Errorf("migrate.Migrate: Starting DB transaction: %w", err) 38 return 39 } 40 defer tx.Rollback() 41 42 _, err = tx.Exec(ctx, "LOCK TABLE migrations IN ACCESS EXCLUSIVE MODE") 43 if err != nil { 44 err = fmt.Errorf("migrate.Migrate: Locking table: %w", err) 45 return 46 } 47 48 for _, migration := range migrations { 49 var found string 50 51 err = tx.Get(ctx, &found, "SELECT id FROM migrations WHERE id = $1 FOR UPDATE", migration.ID) 52 switch err { 53 case sql.ErrNoRows: 54 logger.Info("migrate: Running migration", slog.Int64("migrations.id", migration.ID)) 55 // we need to run the migration so we continue to code below 56 case nil: 57 logger.Debug("migrate: Skipping migration", slog.Int64("migrations.id", migration.ID)) 58 continue 59 default: 60 err = fmt.Errorf("migrate.Migrate: looking up migration by id: %w", err) 61 return 62 } 63 64 _, err = tx.Exec(ctx, "INSERT INTO migrations (id) VALUES ($1)", migration.ID) 65 if err != nil { 66 err = fmt.Errorf("migrate.Migrate: inserting migration: %w", err) 67 return 68 } 69 70 err = migration.Up(ctx, tx) 71 if err != nil { 72 err = fmt.Errorf("migrate.Migrate: executing migration (migration id = %d): %w", migration.ID, err) 73 return 74 } 75 } 76 77 err = tx.Commit() 78 if err != nil { 79 err = fmt.Errorf("migrate: Committing transaction: %w", err) 80 return 81 } 82 83 return 84 } 85 86 // Rollback undo the latest migration 87 func Rollback(ctx context.Context, db db.DB, migrations []Migration, numberToRollback int64) (err error) { 88 logger := slogx.FromCtx(ctx) 89 if logger == nil { 90 err = errors.New("migrate.Rollback: logger is missing from context") 91 return 92 } 93 94 logger.Debug("migrate: Creating/checking migrations table...") 95 err = createMigrationTable(ctx, db) 96 if err != nil { 97 return err 98 } 99 100 // reverse migration 101 sort.SliceStable(migrations, func(i, j int) bool { 102 return migrations[i].ID > migrations[j].ID 103 }) 104 105 tx, err := db.Begin(ctx) 106 if err != nil { 107 err = fmt.Errorf("migrate: Starting DB transaction: %w", err) 108 return 109 } 110 defer tx.Rollback() 111 112 _, err = tx.Exec(ctx, "LOCK TABLE migrations IN ACCESS EXCLUSIVE MODE") 113 if err != nil { 114 err = fmt.Errorf("migrate.Rollback: Locking table: %w", err) 115 return 116 } 117 118 for i := int64(0); i < numberToRollback; i += 1 { 119 migration := migrations[i] 120 121 var found string 122 err = tx.Get(ctx, &found, "SELECT id FROM migrations WHERE id = $1", migration.ID) 123 switch err { 124 case sql.ErrNoRows: 125 logger.Debug("migrate: Skipping rollback", slog.Int64("migration.id", migration.ID)) 126 continue 127 case nil: 128 logger.Info("migrate: Running rollback", slog.Int64("migration.id", migration.ID)) 129 // we need to run the rollback so we continue to code below 130 default: 131 err = fmt.Errorf("migrate.Rollback: looking up rollback by id: %w", err) 132 return 133 } 134 135 _, err = tx.Exec(ctx, "DELETE FROM migrations WHERE id=$1", migration.ID) 136 if err != nil { 137 err = fmt.Errorf("migrate.Rollback: deleting migration: %w", err) 138 return 139 } 140 141 err = migration.Down(ctx, tx) 142 if err != nil { 143 err = fmt.Errorf("migrate.Rollback: executing rollback (migration id = %d): %w", migration.ID, err) 144 return 145 } 146 } 147 148 err = tx.Commit() 149 if err != nil { 150 err = fmt.Errorf("migrate.Rollback: Committing transaction: %w", err) 151 return 152 } 153 154 return 155 } 156 157 func createMigrationTable(ctx context.Context, db db.DB) error { 158 _, err := db.Exec(ctx, "CREATE TABLE IF NOT EXISTS migrations (id BIGINT PRIMARY KEY )") 159 if err != nil { 160 return fmt.Errorf("migrate: Creating migrations table: %w", err) 161 } 162 return nil 163 }