github.com/kubeshop/testkube@v1.17.23/pkg/dbmigrator/dbmigrator.go (about) 1 package dbmigrator 2 3 import ( 4 "context" 5 "io/fs" 6 "path/filepath" 7 "reflect" 8 "regexp" 9 "sort" 10 11 "github.com/pkg/errors" 12 "go.mongodb.org/mongo-driver/bson" 13 "golang.org/x/exp/slices" 14 ) 15 16 type DbPlan struct { 17 Ups []DbMigration 18 Downs []DbMigration 19 Total int 20 } 21 22 type DbMigration struct { 23 Name string `bson:"name"` 24 UpScript []bson.D `bson:"up"` 25 DownScript []bson.D `bson:"down"` 26 } 27 28 type DbMigrator struct { 29 db Database 30 list []DbMigration 31 } 32 33 func GetDbMigrationsFromFs(fsys fs.FS) ([]DbMigration, error) { 34 filePaths, err := fs.Glob(fsys, "*.json") 35 if err != nil { 36 return nil, err 37 } 38 sort.Slice(filePaths, func(i, j int) bool { 39 return filePaths[i] < filePaths[j] 40 }) 41 var list []DbMigration 42 upRe := regexp.MustCompile(`\.up\.json$`) 43 for _, filePath := range filePaths { 44 if !upRe.MatchString(filePath) { 45 continue 46 } 47 name := upRe.ReplaceAllString(filepath.Base(filePath), "") 48 downFilePath := upRe.ReplaceAllString(filePath, ".down.json") 49 var upBytes []byte 50 downBytes := []byte("[]") 51 if slices.Contains(filePaths, downFilePath) { 52 downBytes, err = fs.ReadFile(fsys, downFilePath) 53 if err != nil { 54 return nil, err 55 } 56 } 57 upBytes, err = fs.ReadFile(fsys, filePath) 58 if err != nil { 59 return nil, err 60 } 61 var downScript, upScript []bson.D 62 err = bson.UnmarshalExtJSON(downBytes, true, &downScript) 63 if err != nil { 64 return nil, errors.Wrapf(err, "migration '%s' has invalid rollback commands", name) 65 } 66 err = bson.UnmarshalExtJSON(upBytes, true, &upScript) 67 if err != nil { 68 return nil, errors.Wrapf(err, "migration '%s' has invalid commands", name) 69 } 70 list = append(list, DbMigration{ 71 Name: name, 72 UpScript: upScript, 73 DownScript: downScript, 74 }) 75 } 76 return list, nil 77 } 78 79 func NewDbMigrator(db Database, list []DbMigration) *DbMigrator { 80 return &DbMigrator{db: db, list: list} 81 } 82 83 func (d *DbMigrator) up(ctx context.Context, migration *DbMigration) (err error) { 84 err = d.db.RunCommands(ctx, migration.UpScript) 85 if err != nil { 86 downErr := d.down(ctx, migration) 87 if downErr == nil { 88 return errors.Wrapf(err, "migration '%s' failed, rolled back.", migration.Name) 89 } else { 90 return errors.Wrapf(err, "migration '%s' failed, rolled failed to: %v", migration.Name, downErr.Error()) 91 } 92 } 93 err = d.db.InsertMigrationState(ctx, migration) 94 if err != nil { 95 return errors.Wrapf(err, "failed to save '%s' migration state to database", migration.Name) 96 } 97 return nil 98 } 99 100 // TODO: Consider transactions, but it requires MongoDB with replicaset 101 func (d *DbMigrator) down(ctx context.Context, migration *DbMigration) (err error) { 102 err = d.db.RunCommands(ctx, migration.DownScript) 103 if err != nil { 104 return errors.Wrapf(err, "rolling back '%s' failed.", migration.Name) 105 } 106 err = d.db.DeleteMigrationState(ctx, migration) 107 if err != nil { 108 return errors.Wrapf(err, "failed to save '%s' rollback state to database", migration.Name) 109 } 110 return err 111 } 112 113 func (d *DbMigrator) GetApplied(ctx context.Context) (results []DbMigration, err error) { 114 return d.db.GetAppliedMigrations(ctx) 115 } 116 117 func (d *DbMigrator) Plan(ctx context.Context) (plan DbPlan, err error) { 118 applied, err := d.GetApplied(ctx) 119 if err != nil { 120 return plan, err 121 } 122 matchCount := 0 123 for i, migration := range d.list { 124 if i >= len(applied) || applied[i].Name != migration.Name || !reflect.DeepEqual(applied[i].UpScript, migration.UpScript) { 125 break 126 } 127 matchCount++ 128 } 129 130 if matchCount < len(applied) { 131 plan.Downs = applied[matchCount:] 132 slices.Reverse(plan.Downs) 133 } 134 if len(d.list) > matchCount { 135 plan.Ups = d.list[matchCount:] 136 } 137 plan.Total = len(plan.Ups) + len(plan.Downs) 138 return plan, err 139 } 140 141 func (d *DbMigrator) Apply(ctx context.Context) error { 142 plan, err := d.Plan(ctx) 143 if err != nil { 144 return err 145 } 146 for _, migration := range plan.Downs { 147 err = d.down(ctx, &migration) 148 if err != nil { 149 return err 150 } 151 } 152 for _, migration := range plan.Ups { 153 err = d.up(ctx, &migration) 154 if err != nil { 155 return err 156 } 157 } 158 return nil 159 }