github.com/nshntarora/pop@v0.1.2/migrator.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "io" 6 "os" 7 "path/filepath" 8 "regexp" 9 "sort" 10 "text/tabwriter" 11 "time" 12 13 "github.com/nshntarora/pop/logging" 14 ) 15 16 var mrx = regexp.MustCompile(`^(\d+)_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.(sql|fizz)$`) 17 18 // NewMigrator returns a new "blank" migrator. It is recommended 19 // to use something like MigrationBox or FileMigrator. A "blank" 20 // Migrator should only be used as the basis for a new type of 21 // migration system. 22 func NewMigrator(c *Connection) Migrator { 23 return Migrator{ 24 Connection: c, 25 } 26 } 27 28 // Migrator forms the basis of all migrations systems. 29 // It does the actual heavy lifting of running migrations. 30 // When building a new migration system, you should embed this 31 // type into your migrator. 32 type Migrator struct { 33 Connection *Connection 34 SchemaPath string 35 UpMigrations UpMigrations 36 DownMigrations DownMigrations 37 } 38 39 func (m Migrator) migrationIsCompatible(d dialect, mi Migration) bool { 40 if mi.DBType == "all" || mi.DBType == d.Name() { 41 return true 42 } 43 return false 44 } 45 46 // UpLogOnly insert pending "up" migrations logs only, without applying the patch. 47 // It's used when loading the schema dump, instead of the migrations. 48 func (m Migrator) UpLogOnly() error { 49 c := m.Connection 50 return m.exec(func() error { 51 mtn := c.MigrationTableName() 52 mfs := m.UpMigrations 53 sort.Sort(mfs) 54 return c.Transaction(func(tx *Connection) error { 55 for _, mi := range mfs.Migrations { 56 if !m.migrationIsCompatible(c.Dialect, mi) { 57 continue 58 } 59 exists, err := c.Where("version = ?", mi.Version).Exists(mtn) 60 if err != nil { 61 return fmt.Errorf("problem checking for migration version %s: %w", mi.Version, err) 62 } 63 if exists { 64 continue 65 } 66 _, err = tx.Store.Exec(fmt.Sprintf("insert into %s (version) values ('%s')", mtn, mi.Version)) 67 if err != nil { 68 return fmt.Errorf("problem inserting migration version %s: %w", mi.Version, err) 69 } 70 } 71 return nil 72 }) 73 }) 74 } 75 76 // Up runs pending "up" migrations and applies them to the database. 77 func (m Migrator) Up() error { 78 _, err := m.UpTo(0) 79 return err 80 } 81 82 // UpTo runs up to step "up" migrations and applies them to the database. 83 // If step <= 0 all pending migrations are run. 84 func (m Migrator) UpTo(step int) (applied int, err error) { 85 c := m.Connection 86 err = m.exec(func() error { 87 mtn := c.MigrationTableName() 88 mfs := m.UpMigrations 89 mfs.Filter(func(mf Migration) bool { 90 return m.migrationIsCompatible(c.Dialect, mf) 91 }) 92 sort.Sort(mfs) 93 for _, mi := range mfs.Migrations { 94 exists, err := c.Where("version = ?", mi.Version).Exists(mtn) 95 if err != nil { 96 return fmt.Errorf("problem checking for migration version %s: %w", mi.Version, err) 97 } 98 if exists { 99 continue 100 } 101 err = c.Transaction(func(tx *Connection) error { 102 err := mi.Run(tx) 103 if err != nil { 104 return err 105 } 106 _, err = tx.Store.Exec(fmt.Sprintf("insert into %s (version) values ('%s')", mtn, mi.Version)) 107 if err != nil { 108 return fmt.Errorf("problem inserting migration version %s: %w", mi.Version, err) 109 } 110 return nil 111 }) 112 if err != nil { 113 return err 114 } 115 log(logging.Info, "> %s", mi.Name) 116 applied++ 117 if step > 0 && applied >= step { 118 break 119 } 120 } 121 if applied == 0 { 122 log(logging.Info, "Migrations already up to date, nothing to apply") 123 } else { 124 log(logging.Info, "Successfully applied %d migrations.", applied) 125 } 126 return nil 127 }) 128 return 129 } 130 131 // Down runs pending "down" migrations and rolls back the 132 // database by the specified number of steps. 133 func (m Migrator) Down(step int) error { 134 c := m.Connection 135 return m.exec(func() error { 136 mtn := c.MigrationTableName() 137 count, err := c.Count(mtn) 138 if err != nil { 139 return fmt.Errorf("migration down: unable count existing migration: %w", err) 140 } 141 mfs := m.DownMigrations 142 mfs.Filter(func(mf Migration) bool { 143 return m.migrationIsCompatible(c.Dialect, mf) 144 }) 145 sort.Sort(mfs) 146 // skip all ran migration 147 if len(mfs.Migrations) > count { 148 mfs.Migrations = mfs.Migrations[len(mfs.Migrations)-count:] 149 } 150 // run only required steps 151 if step > 0 && len(mfs.Migrations) >= step { 152 mfs.Migrations = mfs.Migrations[:step] 153 } 154 for _, mi := range mfs.Migrations { 155 exists, err := c.Where("version = ?", mi.Version).Exists(mtn) 156 if err != nil { 157 return fmt.Errorf("problem checking for migration version %s: %w", mi.Version, err) 158 } 159 if !exists { 160 return fmt.Errorf("migration version %s does not exist", mi.Version) 161 } 162 err = c.Transaction(func(tx *Connection) error { 163 err := mi.Run(tx) 164 if err != nil { 165 return err 166 } 167 err = tx.RawQuery(fmt.Sprintf("delete from %s where version = ?", mtn), mi.Version).Exec() 168 if err != nil { 169 return fmt.Errorf("problem deleting migration version %s: %w", mi.Version, err) 170 } 171 return nil 172 }) 173 if err != nil { 174 return err 175 } 176 177 log(logging.Info, "< %s", mi.Name) 178 } 179 return nil 180 }) 181 } 182 183 // Reset the database by running the down migrations followed by the up migrations. 184 func (m Migrator) Reset() error { 185 err := m.Down(-1) 186 if err != nil { 187 return err 188 } 189 return m.Up() 190 } 191 192 // CreateSchemaMigrations sets up a table to track migrations. This is an idempotent 193 // operation. 194 func CreateSchemaMigrations(c *Connection) error { 195 mtn := c.MigrationTableName() 196 err := c.Open() 197 if err != nil { 198 return fmt.Errorf("could not open connection: %w", err) 199 } 200 _, err = c.Store.Exec(fmt.Sprintf("select * from %s", mtn)) 201 if err == nil { 202 return nil 203 } 204 205 return c.Transaction(func(tx *Connection) error { 206 schemaMigrations := newSchemaMigrations(mtn) 207 smSQL, err := c.Dialect.FizzTranslator().CreateTable(schemaMigrations) 208 if err != nil { 209 return fmt.Errorf("could not build SQL for schema migration table: %w", err) 210 } 211 err = tx.RawQuery(smSQL).Exec() 212 if err != nil { 213 return fmt.Errorf("could not execute %s: %w", smSQL, err) 214 } 215 return nil 216 }) 217 } 218 219 // CreateSchemaMigrations sets up a table to track migrations. This is an idempotent 220 // operation. 221 func (m Migrator) CreateSchemaMigrations() error { 222 return CreateSchemaMigrations(m.Connection) 223 } 224 225 // Status prints out the status of applied/pending migrations. 226 func (m Migrator) Status(out io.Writer) error { 227 err := m.CreateSchemaMigrations() 228 if err != nil { 229 return err 230 } 231 w := tabwriter.NewWriter(out, 0, 0, 3, ' ', tabwriter.TabIndent) 232 _, _ = fmt.Fprintln(w, "Version\tName\tStatus\t") 233 for _, mf := range m.UpMigrations.Migrations { 234 exists, err := m.Connection.Where("version = ?", mf.Version).Exists(m.Connection.MigrationTableName()) 235 if err != nil { 236 return fmt.Errorf("problem with migration: %w", err) 237 } 238 state := "Pending" 239 if exists { 240 state = "Applied" 241 } 242 _, _ = fmt.Fprintf(w, "%s\t%s\t%s\t\n", mf.Version, mf.Name, state) 243 } 244 return w.Flush() 245 } 246 247 // DumpMigrationSchema will generate a file of the current database schema 248 // based on the value of Migrator.SchemaPath 249 func (m Migrator) DumpMigrationSchema() error { 250 if m.SchemaPath == "" { 251 return nil 252 } 253 c := m.Connection 254 schema := filepath.Join(m.SchemaPath, "schema.sql") 255 f, err := os.Create(schema) 256 if err != nil { 257 return err 258 } 259 err = c.Dialect.DumpSchema(f) 260 if err != nil { 261 os.RemoveAll(schema) 262 return err 263 } 264 return nil 265 } 266 267 func (m Migrator) exec(fn func() error) error { 268 now := time.Now() 269 defer func() { 270 err := m.DumpMigrationSchema() 271 if err != nil { 272 log(logging.Warn, "Migrator: unable to dump schema: %v", err) 273 } 274 }() 275 defer printTimer(now) 276 277 err := m.CreateSchemaMigrations() 278 if err != nil { 279 return fmt.Errorf("Migrator: problem creating schema migrations: %w", err) 280 } 281 return fn() 282 } 283 284 func printTimer(timerStart time.Time) { 285 diff := time.Since(timerStart).Seconds() 286 if diff > 60 { 287 log(logging.Info, "%.4f minutes", diff/60) 288 } else { 289 log(logging.Info, "%.4f seconds", diff) 290 } 291 }