github.com/CloudCom/goose@v0.0.0-20151110184009-e03c3249c21b/lib/goose/migrate.go (about)

     1  package goose
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"log"
     8  	"os"
     9  	"path/filepath"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"text/template"
    14  	"time"
    15  )
    16  
    17  var (
    18  	ErrTableDoesNotExist = errors.New("table does not exist")
    19  	ErrNoPreviousVersion = errors.New("no previous version found")
    20  )
    21  
    22  type Direction bool
    23  
    24  func (d Direction) String() string {
    25  	if d == DirectionUp {
    26  		return "up"
    27  	} else {
    28  		return "down"
    29  	}
    30  }
    31  
    32  const (
    33  	DirectionDown = Direction(false)
    34  	DirectionUp   = Direction(true)
    35  )
    36  
    37  //go:generate sh -c "go get github.com/jteeuwen/go-bindata/go-bindata && go-bindata -pkg goose -o templates.go -nometadata -nocompress ./templates && gofmt -w templates.go"
    38  var goMigrationDriverTemplate = template.Must(template.New("").Parse(string(_templatesMigrationMainGoTmpl)))
    39  var goMigrationTemplate = template.Must(template.New("").Parse(string(_templatesMigrationGoTmpl)))
    40  var sqlMigrationTemplate = template.Must(template.New("").Parse(string(_templatesMigrationSqlTmpl)))
    41  
    42  type Migration struct {
    43  	Version   int64
    44  	IsApplied bool
    45  	TStamp    time.Time
    46  	Source    string // path to .go or .sql script
    47  }
    48  
    49  type migrationSorter []*Migration
    50  
    51  // helpers so we can use pkg sort
    52  func (ms migrationSorter) Len() int           { return len(ms) }
    53  func (ms migrationSorter) Swap(i, j int)      { ms[i], ms[j] = ms[j], ms[i] }
    54  func (ms migrationSorter) Less(i, j int) bool { return ms[i].Version < ms[j].Version }
    55  
    56  func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) {
    57  	db, err := OpenDBFromDBConf(conf)
    58  	if err != nil {
    59  		return err
    60  	}
    61  	defer db.Close()
    62  
    63  	return RunMigrationsOnDb(conf, migrationsDir, target, db)
    64  }
    65  
    66  // Runs migration on a specific database instance.
    67  func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql.DB) (err error) {
    68  	//TODO get rid of migrationsDir, it's already in conf.MigrationsDir
    69  	current, err := EnsureDBVersion(conf, db)
    70  	if err != nil {
    71  		return err
    72  	}
    73  
    74  	migrations, err := CollectMigrations(migrationsDir)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	if err := getMigrationsStatus(conf, db, migrations); err != nil {
    80  		return err
    81  	}
    82  
    83  	direction := DirectionUp
    84  	if target < current {
    85  		direction = DirectionDown
    86  	}
    87  
    88  	var neededMigrations []*Migration
    89  	for _, m := range migrations {
    90  		if direction == DirectionUp {
    91  			if m.Version > target {
    92  				continue
    93  			}
    94  			if m.IsApplied {
    95  				continue
    96  			}
    97  		} else {
    98  			if m.Version <= target {
    99  				continue
   100  			}
   101  			if !m.IsApplied {
   102  				continue
   103  			}
   104  		}
   105  		neededMigrations = append(neededMigrations, m)
   106  	}
   107  
   108  	if len(neededMigrations) == 0 {
   109  		fmt.Printf("goose: no migrations to run. current version: %d, target: %d\n", current, target)
   110  		return nil
   111  	}
   112  
   113  	fmt.Printf("goose: migrating db, current version: %d, target: %d\n", current, target)
   114  
   115  	ms := migrationSorter(neededMigrations)
   116  	if direction == DirectionUp {
   117  		sort.Sort(ms)
   118  	} else {
   119  		sort.Sort(sort.Reverse(ms))
   120  	}
   121  
   122  	for _, m := range ms {
   123  		switch filepath.Ext(m.Source) {
   124  		case ".go":
   125  			err = runGoMigration(conf, m.Source, m.Version, direction)
   126  		case ".sql":
   127  			err = runSQLMigration(conf, db, m.Source, m.Version, direction)
   128  		}
   129  
   130  		if err != nil {
   131  			return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err))
   132  		}
   133  
   134  		fmt.Println("OK   ", filepath.Base(m.Source))
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  // collect all the valid looking migration scripts in the
   141  // migrations folder, and key them by version
   142  func CollectMigrations(dirpath string) (m []*Migration, err error) {
   143  	// extract the numeric component of each migration,
   144  	// filter out any uninteresting files,
   145  	// and ensure we only have one file per migration version.
   146  	filepath.Walk(dirpath, func(name string, info os.FileInfo, err error) error {
   147  
   148  		if v, e := NumericComponent(name); e == nil {
   149  
   150  			for _, g := range m {
   151  				if v == g.Version {
   152  					log.Fatalf("more than one file specifies the migration for version %d (%s and %s)",
   153  						v, g.Source, filepath.Join(dirpath, name))
   154  				}
   155  			}
   156  
   157  			m = append(m, &Migration{Version: v, Source: name})
   158  		}
   159  
   160  		return nil
   161  	})
   162  
   163  	return m, nil
   164  }
   165  
   166  // look for migration scripts with names in the form:
   167  //  XXX_descriptivename.ext
   168  // where XXX specifies the version number
   169  // and ext specifies the type of migration
   170  func NumericComponent(name string) (int64, error) {
   171  	base := filepath.Base(name)
   172  
   173  	if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
   174  		return 0, errors.New("not a recognized migration file type")
   175  	}
   176  
   177  	idx := strings.Index(base, "_")
   178  	if idx < 0 {
   179  		return 0, errors.New("no separator found")
   180  	}
   181  
   182  	n, e := strconv.ParseInt(base[:idx], 10, 64)
   183  	if e == nil && n <= 0 {
   184  		return 0, errors.New("migration IDs must be greater than zero")
   185  	}
   186  
   187  	return n, e
   188  }
   189  
   190  func getMigrationsStatus(conf *DBConf, db *sql.DB, migrations []*Migration) error {
   191  	rows, err := conf.Driver.Dialect.dbVersionQuery(db)
   192  	if err != nil {
   193  		if err == ErrTableDoesNotExist {
   194  			for _, m := range migrations {
   195  				m.IsApplied = false
   196  			}
   197  			return nil
   198  		}
   199  		return fmt.Errorf("getting db version: %s", err)
   200  	}
   201  	defer rows.Close()
   202  
   203  	mm := map[int64]*Migration{}
   204  	for _, m := range migrations {
   205  		mm[m.Version] = m
   206  		// default to false so if the DB doesn't know about the migration...
   207  		m.IsApplied = false
   208  	}
   209  
   210  	for rows.Next() {
   211  		var row Migration
   212  		if err = rows.Scan(&row.Version, &row.IsApplied, &row.TStamp); err != nil {
   213  			log.Fatal("error scanning rows:", err)
   214  		}
   215  
   216  		m, ok := mm[row.Version]
   217  		if !ok {
   218  			continue
   219  		}
   220  		if !row.TStamp.After(m.TStamp) {
   221  			// If the migration went up, then down, it'll have multiple rows.
   222  			// But we only want the newest, so skip this row if it's older.
   223  			continue
   224  		}
   225  		m.IsApplied = row.IsApplied
   226  		m.TStamp = row.TStamp
   227  	}
   228  
   229  	return nil
   230  }
   231  
   232  // retrieve the current version for this DB.
   233  // Create and initialize the DB version table if it doesn't exist.
   234  func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
   235  	rows, err := conf.Driver.Dialect.dbVersionQuery(db)
   236  	if err != nil {
   237  		if err == ErrTableDoesNotExist {
   238  			return 0, createVersionTable(conf, db)
   239  		}
   240  		return 0, fmt.Errorf("getting db version: %#v", err)
   241  	}
   242  	defer rows.Close()
   243  
   244  	// The most recent record for each migration specifies
   245  	// whether it has been applied or rolled back.
   246  	// The first version we find that has been applied is the current version.
   247  
   248  	toSkip := make([]int64, 0)
   249  
   250  	for rows.Next() {
   251  		var row Migration
   252  		if err = rows.Scan(&row.Version, &row.IsApplied, &row.TStamp); err != nil {
   253  			log.Fatal("error scanning rows:", err)
   254  		}
   255  
   256  		// have we already marked this version to be skipped?
   257  		skip := false
   258  		for _, v := range toSkip {
   259  			if v == row.Version {
   260  				skip = true
   261  				break
   262  			}
   263  		}
   264  
   265  		if skip {
   266  			continue
   267  		}
   268  
   269  		// if version has been applied we're done
   270  		if row.IsApplied {
   271  			return row.Version, nil
   272  		}
   273  
   274  		// latest version of migration has not been applied.
   275  		toSkip = append(toSkip, row.Version)
   276  	}
   277  
   278  	panic("failure in EnsureDBVersion()")
   279  }
   280  
   281  // Create the goose_db_version table
   282  // and insert the initial 0 value into it
   283  func createVersionTable(conf *DBConf, db *sql.DB) error {
   284  	txn, err := db.Begin()
   285  	if err != nil {
   286  		return err
   287  	}
   288  
   289  	d := conf.Driver.Dialect
   290  
   291  	if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
   292  		txn.Rollback()
   293  		return fmt.Errorf("creating migration table: %s", err)
   294  	}
   295  
   296  	version := 0
   297  	applied := true
   298  	if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil {
   299  		txn.Rollback()
   300  		return fmt.Errorf("inserting first migration: %s", err)
   301  	}
   302  
   303  	return txn.Commit()
   304  }
   305  
   306  // wrapper for EnsureDBVersion for callers that don't already have
   307  // their own DB instance
   308  func GetDBVersion(conf *DBConf) (version int64, err error) {
   309  	db, err := OpenDBFromDBConf(conf)
   310  	if err != nil {
   311  		return -1, err
   312  	}
   313  	defer db.Close()
   314  
   315  	version, err = EnsureDBVersion(conf, db)
   316  	if err != nil {
   317  		return -1, err
   318  	}
   319  
   320  	return version, nil
   321  }
   322  
   323  func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) {
   324  	previous = -1
   325  	sawGivenVersion := false
   326  
   327  	filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
   328  
   329  		if !info.IsDir() {
   330  			if v, e := NumericComponent(name); e == nil {
   331  				if v > previous && v < version {
   332  					previous = v
   333  				}
   334  				if v == version {
   335  					sawGivenVersion = true
   336  				}
   337  			}
   338  		}
   339  
   340  		return nil
   341  	})
   342  
   343  	if previous == -1 {
   344  		if sawGivenVersion {
   345  			// the given version is (likely) valid but we didn't find
   346  			// anything before it.
   347  			// 'previous' must reflect that no migrations have been applied.
   348  			previous = 0
   349  		} else {
   350  			err = ErrNoPreviousVersion
   351  		}
   352  	}
   353  
   354  	return
   355  }
   356  
   357  // helper to identify the most recent possible version
   358  // within a folder of migration scripts
   359  func GetMostRecentDBVersion(dirpath string) (version int64, err error) {
   360  	version = -1
   361  
   362  	filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
   363  		if walkerr != nil {
   364  			return walkerr
   365  		}
   366  
   367  		if !info.IsDir() {
   368  			if v, e := NumericComponent(name); e == nil {
   369  				if v > version {
   370  					version = v
   371  				}
   372  			}
   373  		}
   374  
   375  		return nil
   376  	})
   377  
   378  	if version == -1 {
   379  		err = errors.New("no valid version found")
   380  	}
   381  
   382  	return
   383  }
   384  
   385  func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {
   386  	if migrationType != "go" && migrationType != "sql" {
   387  		return "", errors.New("migration type must be 'go' or 'sql'")
   388  	}
   389  
   390  	timestamp := t.Format("20060102150405")
   391  	filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType)
   392  
   393  	fpath := filepath.Join(dir, filename)
   394  
   395  	var tmpl *template.Template
   396  	if migrationType == "sql" {
   397  		tmpl = sqlMigrationTemplate
   398  	} else {
   399  		tmpl = goMigrationTemplate
   400  	}
   401  
   402  	path, err = writeTemplateToFile(fpath, tmpl, timestamp)
   403  
   404  	return
   405  }
   406  
   407  // Update the version table for the given migration,
   408  // and finalize the transaction.
   409  func FinalizeMigration(conf *DBConf, txn *sql.Tx, direction Direction, v int64) error {
   410  	// XXX: drop goose_db_version table on some minimum version number?
   411  	stmt := conf.Driver.Dialect.insertVersionSql()
   412  	if _, err := txn.Exec(stmt, v, bool(direction)); err != nil {
   413  		txn.Rollback()
   414  		return err
   415  	}
   416  
   417  	return txn.Commit()
   418  }