github.com/marinho/drone@v0.2.1-0.20140504195434-d3ba962e89a7/pkg/database/migrate/migrate.go (about)

     1  package migrate
     2  
     3  import (
     4  	"database/sql"
     5  	"log"
     6  )
     7  
     8  const migrationTableStmt = `
     9  CREATE TABLE IF NOT EXISTS migration (
    10  	revision BIGINT PRIMARY KEY
    11  )
    12  `
    13  
    14  const migrationSelectStmt = `
    15  SELECT revision FROM migration
    16  WHERE revision = ?
    17  `
    18  
    19  const migrationSelectMaxStmt = `
    20  SELECT max(revision) FROM migration
    21  `
    22  
    23  const insertRevisionStmt = `
    24  INSERT INTO migration (revision) VALUES (?)
    25  `
    26  
    27  const deleteRevisionStmt = `
    28  DELETE FROM migration where revision = ?
    29  `
    30  
    31  type Revision interface {
    32  	Up(mg *MigrationDriver) error
    33  	Down(mg *MigrationDriver) error
    34  	Revision() int64
    35  }
    36  
// Migration runs an ordered list of revisions against a database.
// Create one with New and register revisions with Add.
type Migration struct {
	db   *sql.DB    // database the migration table and revisions operate on
	revs []Revision // revisions in the order they were added
}
    41  
// Driver builds the *MigrationDriver handed to each Revision's Up/Down
// method; up and down call it with the *sql.Tx the migration runs in.
// NOTE(review): mutable package-level state — presumably it must be assigned
// (e.g. by a dialect-specific package) before MigrateTo runs; confirm.
var Driver DriverBuilder
    43  
    44  func New(db *sql.DB) *Migration {
    45  	return &Migration{db: db}
    46  }
    47  
    48  // Add the Revision to the list of migrations.
    49  func (m *Migration) Add(rev ...Revision) *Migration {
    50  	m.revs = append(m.revs, rev...)
    51  	return m
    52  }
    53  
    54  // Migrate executes the full list of migrations.
    55  func (m *Migration) Migrate() error {
    56  	var target int64
    57  	if len(m.revs) > 0 {
    58  		// get the last revision number in
    59  		// the list. This is what we'll
    60  		// migrate toward.
    61  		target = m.revs[len(m.revs)-1].Revision()
    62  	}
    63  	return m.MigrateTo(target)
    64  }
    65  
    66  // MigrateTo executes all database migration until
    67  // you are at the specified revision number.
    68  // If the revision number is less than the
    69  // current revision, then we will downgrade.
    70  func (m *Migration) MigrateTo(target int64) error {
    71  
    72  	// make sure the migration table is created.
    73  	if _, err := m.db.Exec(migrationTableStmt); err != nil {
    74  		return err
    75  	}
    76  
    77  	// get the current revision
    78  	var current int64
    79  	m.db.QueryRow(migrationSelectMaxStmt).Scan(&current)
    80  
    81  	// already up to date
    82  	if current == target {
    83  		log.Println("Database already up-to-date.")
    84  		return nil
    85  	}
    86  
    87  	// should we downgrade?
    88  	if target < current {
    89  		return m.down(target, current)
    90  	}
    91  
    92  	// else upgrade
    93  	return m.up(target, current)
    94  }
    95  
    96  func (m *Migration) up(target, current int64) error {
    97  	// create the database transaction
    98  	tx, err := m.db.Begin()
    99  	if err != nil {
   100  		return err
   101  	}
   102  
   103  	mg := Driver(tx)
   104  
   105  	// loop through and execute revisions
   106  	for _, rev := range m.revs {
   107  		if rev.Revision() > current && rev.Revision() <= target {
   108  			current = rev.Revision()
   109  			// execute the revision Upgrade.
   110  			if err := rev.Up(mg); err != nil {
   111  				log.Printf("Failed to upgrade to Revision Number %v\n", current)
   112  				log.Println(err)
   113  				return tx.Rollback()
   114  			}
   115  			// update the revision number in the database
   116  			if _, err := tx.Exec(insertRevisionStmt, current); err != nil {
   117  				log.Printf("Failed to register Revision Number %v\n", current)
   118  				log.Println(err)
   119  				return tx.Rollback()
   120  			}
   121  
   122  			log.Printf("Successfully upgraded to Revision %v\n", current)
   123  		}
   124  	}
   125  
   126  	return tx.Commit()
   127  }
   128  
   129  func (m *Migration) down(target, current int64) error {
   130  	// create the database transaction
   131  	tx, err := m.db.Begin()
   132  	if err != nil {
   133  		return err
   134  	}
   135  
   136  	mg := Driver(tx)
   137  
   138  	// reverse the list of revisions
   139  	revs := []Revision{}
   140  	for _, rev := range m.revs {
   141  		revs = append([]Revision{rev}, revs...)
   142  	}
   143  
   144  	// loop through the (reversed) list of
   145  	// revisions and execute.
   146  	for _, rev := range revs {
   147  		if rev.Revision() > target {
   148  			current = rev.Revision()
   149  			// execute the revision Upgrade.
   150  			if err := rev.Down(mg); err != nil {
   151  				log.Printf("Failed to downgrade from Revision Number %v\n", current)
   152  				log.Println(err)
   153  				return tx.Rollback()
   154  			}
   155  			// update the revision number in the database
   156  			if _, err := tx.Exec(deleteRevisionStmt, current); err != nil {
   157  				log.Printf("Failed to unregistser Revision Number %v\n", current)
   158  				log.Println(err)
   159  				return tx.Rollback()
   160  			}
   161  
   162  			log.Printf("Successfully downgraded from Revision %v\n", current)
   163  		}
   164  	}
   165  
   166  	return tx.Commit()
   167  }