github.com/CloudCom/goose@v0.0.0-20151110184009-e03c3249c21b/lib/goose/migrate.go (about) 1 package goose 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "log" 8 "os" 9 "path/filepath" 10 "sort" 11 "strconv" 12 "strings" 13 "text/template" 14 "time" 15 ) 16 17 var ( 18 ErrTableDoesNotExist = errors.New("table does not exist") 19 ErrNoPreviousVersion = errors.New("no previous version found") 20 ) 21 22 type Direction bool 23 24 func (d Direction) String() string { 25 if d == DirectionUp { 26 return "up" 27 } else { 28 return "down" 29 } 30 } 31 32 const ( 33 DirectionDown = Direction(false) 34 DirectionUp = Direction(true) 35 ) 36 37 //go:generate sh -c "go get github.com/jteeuwen/go-bindata/go-bindata && go-bindata -pkg goose -o templates.go -nometadata -nocompress ./templates && gofmt -w templates.go" 38 var goMigrationDriverTemplate = template.Must(template.New("").Parse(string(_templatesMigrationMainGoTmpl))) 39 var goMigrationTemplate = template.Must(template.New("").Parse(string(_templatesMigrationGoTmpl))) 40 var sqlMigrationTemplate = template.Must(template.New("").Parse(string(_templatesMigrationSqlTmpl))) 41 42 type Migration struct { 43 Version int64 44 IsApplied bool 45 TStamp time.Time 46 Source string // path to .go or .sql script 47 } 48 49 type migrationSorter []*Migration 50 51 // helpers so we can use pkg sort 52 func (ms migrationSorter) Len() int { return len(ms) } 53 func (ms migrationSorter) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] } 54 func (ms migrationSorter) Less(i, j int) bool { return ms[i].Version < ms[j].Version } 55 56 func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) { 57 db, err := OpenDBFromDBConf(conf) 58 if err != nil { 59 return err 60 } 61 defer db.Close() 62 63 return RunMigrationsOnDb(conf, migrationsDir, target, db) 64 } 65 66 // Runs migration on a specific database instance. 67 func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql.DB) (err error) { 68 //TODO get rid of migrationsDir, it's already in conf.MigrationsDir 69 current, err := EnsureDBVersion(conf, db) 70 if err != nil { 71 return err 72 } 73 74 migrations, err := CollectMigrations(migrationsDir) 75 if err != nil { 76 return err 77 } 78 79 if err := getMigrationsStatus(conf, db, migrations); err != nil { 80 return err 81 } 82 83 direction := DirectionUp 84 if target < current { 85 direction = DirectionDown 86 } 87 88 var neededMigrations []*Migration 89 for _, m := range migrations { 90 if direction == DirectionUp { 91 if m.Version > target { 92 continue 93 } 94 if m.IsApplied { 95 continue 96 } 97 } else { 98 if m.Version <= target { 99 continue 100 } 101 if !m.IsApplied { 102 continue 103 } 104 } 105 neededMigrations = append(neededMigrations, m) 106 } 107 108 if len(neededMigrations) == 0 { 109 fmt.Printf("goose: no migrations to run. current version: %d, target: %d\n", current, target) 110 return nil 111 } 112 113 fmt.Printf("goose: migrating db, current version: %d, target: %d\n", current, target) 114 115 ms := migrationSorter(neededMigrations) 116 if direction == DirectionUp { 117 sort.Sort(ms) 118 } else { 119 sort.Sort(sort.Reverse(ms)) 120 } 121 122 for _, m := range ms { 123 switch filepath.Ext(m.Source) { 124 case ".go": 125 err = runGoMigration(conf, m.Source, m.Version, direction) 126 case ".sql": 127 err = runSQLMigration(conf, db, m.Source, m.Version, direction) 128 } 129 130 if err != nil { 131 return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err)) 132 } 133 134 fmt.Println("OK ", filepath.Base(m.Source)) 135 } 136 137 return nil 138 } 139 140 // collect all the valid looking migration scripts in the 141 // migrations folder, and key them by version 142 func CollectMigrations(dirpath string) (m []*Migration, err error) { 143 // extract the numeric component of each migration, 144 // filter out any uninteresting files, 145 // and ensure we only have one file per migration version. 146 filepath.Walk(dirpath, func(name string, info os.FileInfo, err error) error { 147 148 if v, e := NumericComponent(name); e == nil { 149 150 for _, g := range m { 151 if v == g.Version { 152 log.Fatalf("more than one file specifies the migration for version %d (%s and %s)", 153 v, g.Source, filepath.Join(dirpath, name)) 154 } 155 } 156 157 m = append(m, &Migration{Version: v, Source: name}) 158 } 159 160 return nil 161 }) 162 163 return m, nil 164 } 165 166 // look for migration scripts with names in the form: 167 // XXX_descriptivename.ext 168 // where XXX specifies the version number 169 // and ext specifies the type of migration 170 func NumericComponent(name string) (int64, error) { 171 base := filepath.Base(name) 172 173 if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { 174 return 0, errors.New("not a recognized migration file type") 175 } 176 177 idx := strings.Index(base, "_") 178 if idx < 0 { 179 return 0, errors.New("no separator found") 180 } 181 182 n, e := strconv.ParseInt(base[:idx], 10, 64) 183 if e == nil && n <= 0 { 184 return 0, errors.New("migration IDs must be greater than zero") 185 } 186 187 return n, e 188 } 189 190 func getMigrationsStatus(conf *DBConf, db *sql.DB, migrations []*Migration) error { 191 rows, err := conf.Driver.Dialect.dbVersionQuery(db) 192 if err != nil { 193 if err == ErrTableDoesNotExist { 194 for _, m := range migrations { 195 m.IsApplied = false 196 } 197 return nil 198 } 199 return fmt.Errorf("getting db version: %s", err) 200 } 201 defer rows.Close() 202 203 mm := map[int64]*Migration{} 204 for _, m := range migrations { 205 mm[m.Version] = m 206 // default to false so if the DB doesn't know about the migration... 207 m.IsApplied = false 208 } 209 210 for rows.Next() { 211 var row Migration 212 if err = rows.Scan(&row.Version, &row.IsApplied, &row.TStamp); err != nil { 213 log.Fatal("error scanning rows:", err) 214 } 215 216 m, ok := mm[row.Version] 217 if !ok { 218 continue 219 } 220 if !row.TStamp.After(m.TStamp) { 221 // If the migration went up, then down, it'll have multiple rows. 222 // But we only want the newest, so skip this row if it's older. 223 continue 224 } 225 m.IsApplied = row.IsApplied 226 m.TStamp = row.TStamp 227 } 228 229 return nil 230 } 231 232 // retrieve the current version for this DB. 233 // Create and initialize the DB version table if it doesn't exist. 234 func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) { 235 rows, err := conf.Driver.Dialect.dbVersionQuery(db) 236 if err != nil { 237 if err == ErrTableDoesNotExist { 238 return 0, createVersionTable(conf, db) 239 } 240 return 0, fmt.Errorf("getting db version: %#v", err) 241 } 242 defer rows.Close() 243 244 // The most recent record for each migration specifies 245 // whether it has been applied or rolled back. 246 // The first version we find that has been applied is the current version. 247 248 toSkip := make([]int64, 0) 249 250 for rows.Next() { 251 var row Migration 252 if err = rows.Scan(&row.Version, &row.IsApplied, &row.TStamp); err != nil { 253 log.Fatal("error scanning rows:", err) 254 } 255 256 // have we already marked this version to be skipped? 257 skip := false 258 for _, v := range toSkip { 259 if v == row.Version { 260 skip = true 261 break 262 } 263 } 264 265 if skip { 266 continue 267 } 268 269 // if version has been applied we're done 270 if row.IsApplied { 271 return row.Version, nil 272 } 273 274 // latest version of migration has not been applied. 275 toSkip = append(toSkip, row.Version) 276 } 277 278 panic("failure in EnsureDBVersion()") 279 } 280 281 // Create the goose_db_version table 282 // and insert the initial 0 value into it 283 func createVersionTable(conf *DBConf, db *sql.DB) error { 284 txn, err := db.Begin() 285 if err != nil { 286 return err 287 } 288 289 d := conf.Driver.Dialect 290 291 if _, err := txn.Exec(d.createVersionTableSql()); err != nil { 292 txn.Rollback() 293 return fmt.Errorf("creating migration table: %s", err) 294 } 295 296 version := 0 297 applied := true 298 if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil { 299 txn.Rollback() 300 return fmt.Errorf("inserting first migration: %s", err) 301 } 302 303 return txn.Commit() 304 } 305 306 // wrapper for EnsureDBVersion for callers that don't already have 307 // their own DB instance 308 func GetDBVersion(conf *DBConf) (version int64, err error) { 309 db, err := OpenDBFromDBConf(conf) 310 if err != nil { 311 return -1, err 312 } 313 defer db.Close() 314 315 version, err = EnsureDBVersion(conf, db) 316 if err != nil { 317 return -1, err 318 } 319 320 return version, nil 321 } 322 323 func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) { 324 previous = -1 325 sawGivenVersion := false 326 327 filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { 328 329 if !info.IsDir() { 330 if v, e := NumericComponent(name); e == nil { 331 if v > previous && v < version { 332 previous = v 333 } 334 if v == version { 335 sawGivenVersion = true 336 } 337 } 338 } 339 340 return nil 341 }) 342 343 if previous == -1 { 344 if sawGivenVersion { 345 // the given version is (likely) valid but we didn't find 346 // anything before it. 347 // 'previous' must reflect that no migrations have been applied. 348 previous = 0 349 } else { 350 err = ErrNoPreviousVersion 351 } 352 } 353 354 return 355 } 356 357 // helper to identify the most recent possible version 358 // within a folder of migration scripts 359 func GetMostRecentDBVersion(dirpath string) (version int64, err error) { 360 version = -1 361 362 filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { 363 if walkerr != nil { 364 return walkerr 365 } 366 367 if !info.IsDir() { 368 if v, e := NumericComponent(name); e == nil { 369 if v > version { 370 version = v 371 } 372 } 373 } 374 375 return nil 376 }) 377 378 if version == -1 { 379 err = errors.New("no valid version found") 380 } 381 382 return 383 } 384 385 func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { 386 if migrationType != "go" && migrationType != "sql" { 387 return "", errors.New("migration type must be 'go' or 'sql'") 388 } 389 390 timestamp := t.Format("20060102150405") 391 filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType) 392 393 fpath := filepath.Join(dir, filename) 394 395 var tmpl *template.Template 396 if migrationType == "sql" { 397 tmpl = sqlMigrationTemplate 398 } else { 399 tmpl = goMigrationTemplate 400 } 401 402 path, err = writeTemplateToFile(fpath, tmpl, timestamp) 403 404 return 405 } 406 407 // Update the version table for the given migration, 408 // and finalize the transaction. 409 func FinalizeMigration(conf *DBConf, txn *sql.Tx, direction Direction, v int64) error { 410 // XXX: drop goose_db_version table on some minimum version number? 411 stmt := conf.Driver.Dialect.insertVersionSql() 412 if _, err := txn.Exec(stmt, v, bool(direction)); err != nil { 413 txn.Rollback() 414 return err 415 } 416 417 return txn.Commit() 418 }