github.com/kubeshop/testkube@v1.17.23/pkg/dbmigrator/dbmigrator.go (about)

     1  package dbmigrator
     2  
     3  import (
     4  	"context"
     5  	"io/fs"
     6  	"path/filepath"
     7  	"reflect"
     8  	"regexp"
     9  	"sort"
    10  
    11  	"github.com/pkg/errors"
    12  	"go.mongodb.org/mongo-driver/bson"
    13  	"golang.org/x/exp/slices"
    14  )
    15  
// DbPlan describes the steps needed to reconcile the applied migrations with
// the migrator's migration list, as computed by DbMigrator.Plan.
type DbPlan struct {
	Ups   []DbMigration // migrations from the migrator's list that still have to be applied
	Downs []DbMigration // applied migrations to roll back, in reverse of the order GetApplied returned them
	Total int           // len(Ups) + len(Downs)
}
    21  
// DbMigration is a single named migration.
//
// UpScript holds the commands run to apply the migration and DownScript the
// commands run to roll it back. When loaded by GetDbMigrationsFromFs, Name is
// the base name of the "*.up.json" file with that suffix removed, and
// DownScript is empty when no matching "*.down.json" file exists.
type DbMigration struct {
	Name       string   `bson:"name"`
	UpScript   []bson.D `bson:"up"`
	DownScript []bson.D `bson:"down"`
}
    27  
// DbMigrator plans and executes migrations against a Database.
type DbMigrator struct {
	db   Database      // backend used to run commands and to read/write migration state
	list []DbMigration // desired migrations, compared position by position with the applied ones
}
    32  
    33  func GetDbMigrationsFromFs(fsys fs.FS) ([]DbMigration, error) {
    34  	filePaths, err := fs.Glob(fsys, "*.json")
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  	sort.Slice(filePaths, func(i, j int) bool {
    39  		return filePaths[i] < filePaths[j]
    40  	})
    41  	var list []DbMigration
    42  	upRe := regexp.MustCompile(`\.up\.json$`)
    43  	for _, filePath := range filePaths {
    44  		if !upRe.MatchString(filePath) {
    45  			continue
    46  		}
    47  		name := upRe.ReplaceAllString(filepath.Base(filePath), "")
    48  		downFilePath := upRe.ReplaceAllString(filePath, ".down.json")
    49  		var upBytes []byte
    50  		downBytes := []byte("[]")
    51  		if slices.Contains(filePaths, downFilePath) {
    52  			downBytes, err = fs.ReadFile(fsys, downFilePath)
    53  			if err != nil {
    54  				return nil, err
    55  			}
    56  		}
    57  		upBytes, err = fs.ReadFile(fsys, filePath)
    58  		if err != nil {
    59  			return nil, err
    60  		}
    61  		var downScript, upScript []bson.D
    62  		err = bson.UnmarshalExtJSON(downBytes, true, &downScript)
    63  		if err != nil {
    64  			return nil, errors.Wrapf(err, "migration '%s' has invalid rollback commands", name)
    65  		}
    66  		err = bson.UnmarshalExtJSON(upBytes, true, &upScript)
    67  		if err != nil {
    68  			return nil, errors.Wrapf(err, "migration '%s' has invalid commands", name)
    69  		}
    70  		list = append(list, DbMigration{
    71  			Name:       name,
    72  			UpScript:   upScript,
    73  			DownScript: downScript,
    74  		})
    75  	}
    76  	return list, nil
    77  }
    78  
    79  func NewDbMigrator(db Database, list []DbMigration) *DbMigrator {
    80  	return &DbMigrator{db: db, list: list}
    81  }
    82  
    83  func (d *DbMigrator) up(ctx context.Context, migration *DbMigration) (err error) {
    84  	err = d.db.RunCommands(ctx, migration.UpScript)
    85  	if err != nil {
    86  		downErr := d.down(ctx, migration)
    87  		if downErr == nil {
    88  			return errors.Wrapf(err, "migration '%s' failed, rolled back.", migration.Name)
    89  		} else {
    90  			return errors.Wrapf(err, "migration '%s' failed, rolled failed to: %v", migration.Name, downErr.Error())
    91  		}
    92  	}
    93  	err = d.db.InsertMigrationState(ctx, migration)
    94  	if err != nil {
    95  		return errors.Wrapf(err, "failed to save '%s' migration state to database", migration.Name)
    96  	}
    97  	return nil
    98  }
    99  
   100  // TODO: Consider transactions, but it requires MongoDB with replicaset
   101  func (d *DbMigrator) down(ctx context.Context, migration *DbMigration) (err error) {
   102  	err = d.db.RunCommands(ctx, migration.DownScript)
   103  	if err != nil {
   104  		return errors.Wrapf(err, "rolling back '%s' failed.", migration.Name)
   105  	}
   106  	err = d.db.DeleteMigrationState(ctx, migration)
   107  	if err != nil {
   108  		return errors.Wrapf(err, "failed to save '%s' rollback state to database", migration.Name)
   109  	}
   110  	return err
   111  }
   112  
   113  func (d *DbMigrator) GetApplied(ctx context.Context) (results []DbMigration, err error) {
   114  	return d.db.GetAppliedMigrations(ctx)
   115  }
   116  
   117  func (d *DbMigrator) Plan(ctx context.Context) (plan DbPlan, err error) {
   118  	applied, err := d.GetApplied(ctx)
   119  	if err != nil {
   120  		return plan, err
   121  	}
   122  	matchCount := 0
   123  	for i, migration := range d.list {
   124  		if i >= len(applied) || applied[i].Name != migration.Name || !reflect.DeepEqual(applied[i].UpScript, migration.UpScript) {
   125  			break
   126  		}
   127  		matchCount++
   128  	}
   129  
   130  	if matchCount < len(applied) {
   131  		plan.Downs = applied[matchCount:]
   132  		slices.Reverse(plan.Downs)
   133  	}
   134  	if len(d.list) > matchCount {
   135  		plan.Ups = d.list[matchCount:]
   136  	}
   137  	plan.Total = len(plan.Ups) + len(plan.Downs)
   138  	return plan, err
   139  }
   140  
   141  func (d *DbMigrator) Apply(ctx context.Context) error {
   142  	plan, err := d.Plan(ctx)
   143  	if err != nil {
   144  		return err
   145  	}
   146  	for _, migration := range plan.Downs {
   147  		err = d.down(ctx, &migration)
   148  		if err != nil {
   149  			return err
   150  		}
   151  	}
   152  	for _, migration := range plan.Ups {
   153  		err = d.up(ctx, &migration)
   154  		if err != nil {
   155  			return err
   156  		}
   157  	}
   158  	return nil
   159  }