github.com/kyleu/dbaudit@v0.0.2-0.20240321155047-ff2f2c940496/app/lib/database/migrate/migrate.go (about)

     1  // Package migrate - Content managed by Project Forge, see [projectforge.md] for details.
     2  package migrate
     3  
     4  import (
     5  	"context"
     6  	"slices"
     7  	"strings"
     8  
     9  	"github.com/jmoiron/sqlx"
    10  	"github.com/pkg/errors"
    11  	"github.com/samber/lo"
    12  
    13  	"github.com/kyleu/dbaudit/app/lib/database"
    14  	"github.com/kyleu/dbaudit/app/lib/telemetry"
    15  	"github.com/kyleu/dbaudit/app/util"
    16  )
    17  
    18  func Migrate(ctx context.Context, s *database.Service, logger util.Logger, matchesTags ...string) error {
    19  	logger = logger.With("svc", "migrate")
    20  	ctx, span, logger := telemetry.StartSpan(ctx, "database:migrate", logger)
    21  	defer span.Complete()
    22  
    23  	var positiveTags, negativeTags []string
    24  	for _, t := range matchesTags {
    25  		if strings.HasPrefix(t, "-") {
    26  			negativeTags = append(negativeTags, strings.TrimPrefix(t, "-"))
    27  		} else {
    28  			positiveTags = append(positiveTags, t)
    29  		}
    30  	}
    31  
    32  	migrations := lo.Filter(databaseMigrations, func(m *MigrationFile, _ int) bool {
    33  		if len(matchesTags) == 0 {
    34  			return true
    35  		}
    36  		good := len(positiveTags) == 0 || lo.ContainsBy(positiveTags, func(x string) bool {
    37  			return slices.Contains(m.Tags, x)
    38  		})
    39  		bad := lo.ContainsBy(negativeTags, func(x string) bool {
    40  			return slices.Contains(m.Tags, x)
    41  		})
    42  		return good && !bad
    43  	})
    44  
    45  	err := createMigrationTableIfNeeded(ctx, s, nil, logger)
    46  	if err != nil {
    47  		return errors.Wrapf(err, "unable to create migration table for database [%s]", s.Key)
    48  	}
    49  
    50  	tx, err := s.StartTransaction(logger)
    51  	if err != nil {
    52  		return errors.Wrap(err, "unable to start transaction")
    53  	}
    54  	defer func() {
    55  		_ = tx.Rollback()
    56  	}()
    57  
    58  	maxIdx := maxMigrationIdx(ctx, s, tx, logger)
    59  
    60  	if len(migrations) > maxIdx+1 {
    61  		c := len(migrations) - maxIdx
    62  		logger.Infof("applying [%s] to database [%s]...", util.StringPlural(c, "migration"), s.Key)
    63  	}
    64  
    65  	for i, file := range migrations {
    66  		err = run(ctx, maxIdx, i, file, s, tx, logger)
    67  		if err != nil {
    68  			return errors.Wrapf(err, "error running database migration [%s]", file.Title)
    69  		}
    70  	}
    71  
    72  	if err := tx.Commit(); err != nil {
    73  		return err
    74  	}
    75  
    76  	logger.Infof("verified [%s] in database [%s]", util.StringPlural(maxIdx, "migration"), s.Key)
    77  	return nil
    78  }
    79  
    80  func run(ctx context.Context, maxIdx int, i int, file *MigrationFile, s *database.Service, tx *sqlx.Tx, logger util.Logger) error {
    81  	idx := i + 1
    82  	switch {
    83  	case idx == maxIdx:
    84  		m := getMigrationByIdx(ctx, s, maxIdx, tx, logger)
    85  		if m == nil {
    86  			return nil
    87  		}
    88  		if m.Title != file.Title {
    89  			logger.Infof("migration [%d] name has changed from [%s] to [%s]", idx, m.Title, file.Title)
    90  			err := removeMigrationByIdx(ctx, s, idx, tx, logger)
    91  			if err != nil {
    92  				return err
    93  			}
    94  			err = applyMigration(ctx, s, idx, file, tx, logger)
    95  			if err != nil {
    96  				return err
    97  			}
    98  			return nil
    99  		}
   100  		nc := file.Content
   101  		if nc != m.Src {
   102  			logger.Infof("migration [%d:%s] content has changed from [%dB] to [%dB]", idx, file.Title, len(nc), len(m.Src))
   103  			err := removeMigrationByIdx(ctx, s, idx, tx, logger)
   104  			if err != nil {
   105  				return err
   106  			}
   107  			err = applyMigration(ctx, s, idx, file, tx, logger)
   108  			if err != nil {
   109  				return err
   110  			}
   111  		}
   112  	case idx > maxIdx:
   113  		err := applyMigration(ctx, s, idx, file, tx, logger)
   114  		if err != nil {
   115  			return err
   116  		}
   117  	default:
   118  		// noop
   119  	}
   120  	return nil
   121  }
   122  
   123  func applyMigration(ctx context.Context, s *database.Service, idx int, file *MigrationFile, tx *sqlx.Tx, logger util.Logger) error {
   124  	logger.Infof("applying database migration [%d]: %s", idx, file.Title)
   125  	sql, err := exec(ctx, file, s, tx, logger)
   126  	if err != nil {
   127  		return err
   128  	}
   129  	m := &Migration{Idx: idx, Title: file.Title, Src: sql, Created: util.TimeCurrent()}
   130  	return newMigration(ctx, s, m, tx, logger)
   131  }