git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/migrate/migrate.go (about)

     1  package migrate
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"sort"
     9  
    10  	"log/slog"
    11  
    12  	"git.sr.ht/~pingoo/stdx/db"
    13  	"git.sr.ht/~pingoo/stdx/log/slogx"
    14  )
    15  
    16  type Migration struct {
    17  	ID   int64
    18  	Up   func(ctx context.Context, tx db.Queryer) (err error)
    19  	Down func(ctx context.Context, tx db.Queryer) (err error)
    20  }
    21  
    22  func Migrate(ctx context.Context, db db.DB, migrations []Migration) (err error) {
    23  	logger := slogx.FromCtx(ctx)
    24  	if logger == nil {
    25  		err = errors.New("migrate.Migrate: logger is missing from context")
    26  		return
    27  	}
    28  
    29  	logger.Debug("migrate: Creating/checking migrations table...")
    30  	err = createMigrationTable(ctx, db)
    31  	if err != nil {
    32  		return err
    33  	}
    34  
    35  	tx, err := db.Begin(ctx)
    36  	if err != nil {
    37  		err = fmt.Errorf("migrate.Migrate: Starting DB transaction: %w", err)
    38  		return
    39  	}
    40  	defer tx.Rollback()
    41  
    42  	_, err = tx.Exec(ctx, "LOCK TABLE migrations IN ACCESS EXCLUSIVE MODE")
    43  	if err != nil {
    44  		err = fmt.Errorf("migrate.Migrate: Locking table: %w", err)
    45  		return
    46  	}
    47  
    48  	for _, migration := range migrations {
    49  		var found string
    50  
    51  		err = tx.Get(ctx, &found, "SELECT id FROM migrations WHERE id = $1 FOR UPDATE", migration.ID)
    52  		switch err {
    53  		case sql.ErrNoRows:
    54  			logger.Info("migrate: Running migration", slog.Int64("migrations.id", migration.ID))
    55  			// we need to run the migration so we continue to code below
    56  		case nil:
    57  			logger.Debug("migrate: Skipping migration", slog.Int64("migrations.id", migration.ID))
    58  			continue
    59  		default:
    60  			err = fmt.Errorf("migrate.Migrate: looking up migration by id: %w", err)
    61  			return
    62  		}
    63  
    64  		_, err = tx.Exec(ctx, "INSERT INTO migrations (id) VALUES ($1)", migration.ID)
    65  		if err != nil {
    66  			err = fmt.Errorf("migrate.Migrate: inserting migration: %w", err)
    67  			return
    68  		}
    69  
    70  		err = migration.Up(ctx, tx)
    71  		if err != nil {
    72  			err = fmt.Errorf("migrate.Migrate: executing migration (migration id = %d): %w", migration.ID, err)
    73  			return
    74  		}
    75  	}
    76  
    77  	err = tx.Commit()
    78  	if err != nil {
    79  		err = fmt.Errorf("migrate: Committing transaction: %w", err)
    80  		return
    81  	}
    82  
    83  	return
    84  }
    85  
    86  // Rollback undo the latest migration
    87  func Rollback(ctx context.Context, db db.DB, migrations []Migration, numberToRollback int64) (err error) {
    88  	logger := slogx.FromCtx(ctx)
    89  	if logger == nil {
    90  		err = errors.New("migrate.Rollback: logger is missing from context")
    91  		return
    92  	}
    93  
    94  	logger.Debug("migrate: Creating/checking migrations table...")
    95  	err = createMigrationTable(ctx, db)
    96  	if err != nil {
    97  		return err
    98  	}
    99  
   100  	// reverse migration
   101  	sort.SliceStable(migrations, func(i, j int) bool {
   102  		return migrations[i].ID > migrations[j].ID
   103  	})
   104  
   105  	tx, err := db.Begin(ctx)
   106  	if err != nil {
   107  		err = fmt.Errorf("migrate: Starting DB transaction: %w", err)
   108  		return
   109  	}
   110  	defer tx.Rollback()
   111  
   112  	_, err = tx.Exec(ctx, "LOCK TABLE migrations IN ACCESS EXCLUSIVE MODE")
   113  	if err != nil {
   114  		err = fmt.Errorf("migrate.Rollback: Locking table: %w", err)
   115  		return
   116  	}
   117  
   118  	for i := int64(0); i < numberToRollback; i += 1 {
   119  		migration := migrations[i]
   120  
   121  		var found string
   122  		err = tx.Get(ctx, &found, "SELECT id FROM migrations WHERE id = $1", migration.ID)
   123  		switch err {
   124  		case sql.ErrNoRows:
   125  			logger.Debug("migrate: Skipping rollback", slog.Int64("migration.id", migration.ID))
   126  			continue
   127  		case nil:
   128  			logger.Info("migrate: Running rollback", slog.Int64("migration.id", migration.ID))
   129  			// we need to run the rollback so we continue to code below
   130  		default:
   131  			err = fmt.Errorf("migrate.Rollback: looking up rollback by id: %w", err)
   132  			return
   133  		}
   134  
   135  		_, err = tx.Exec(ctx, "DELETE FROM migrations WHERE id=$1", migration.ID)
   136  		if err != nil {
   137  			err = fmt.Errorf("migrate.Rollback: deleting migration: %w", err)
   138  			return
   139  		}
   140  
   141  		err = migration.Down(ctx, tx)
   142  		if err != nil {
   143  			err = fmt.Errorf("migrate.Rollback: executing rollback (migration id = %d): %w", migration.ID, err)
   144  			return
   145  		}
   146  	}
   147  
   148  	err = tx.Commit()
   149  	if err != nil {
   150  		err = fmt.Errorf("migrate.Rollback: Committing transaction: %w", err)
   151  		return
   152  	}
   153  
   154  	return
   155  }
   156  
   157  func createMigrationTable(ctx context.Context, db db.DB) error {
   158  	_, err := db.Exec(ctx, "CREATE TABLE IF NOT EXISTS migrations (id BIGINT PRIMARY KEY )")
   159  	if err != nil {
   160  		return fmt.Errorf("migrate: Creating migrations table: %w", err)
   161  	}
   162  	return nil
   163  }