github.com/woremacx/kocha@v0.7.1-0.20150731103243-a5889322afc9/db.go (about)

     1  package kocha
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"math"
     7  	"os"
     8  	"reflect"
     9  	"regexp"
    10  	"sort"
    11  
    12  	"github.com/naoina/genmai"
    13  )
    14  
    15  type DatabaseMap map[string]DatabaseConfig
    16  
    17  // DatabaseConfig represents a configuration of the database.
    18  type DatabaseConfig struct {
    19  	// name of database driver such as "mysql".
    20  	Driver string
    21  
    22  	// Data Source Name.
    23  	// e.g. such as "travis@/db_name".
    24  	DSN string
    25  }
    26  
    27  var TxTypeMap = make(map[string]Transactioner)
    28  
    29  // Transactioner is an interface for a transaction type.
    30  type Transactioner interface {
    31  	// ImportPath returns the import path to import to use a transaction.
    32  	// Usually, import path of ORM like "github.com/naoina/genmai".
    33  	ImportPath() string
    34  
    35  	// TransactionType returns an object of a transaction type.
    36  	// Return value will only used to determine an argument type for the
    37  	// methods of migration when generate the template.
    38  	TransactionType() interface{}
    39  
    40  	// Begin starts a transaction from driver name and data source name.
    41  	Begin(driverName, dsn string) (tx interface{}, err error)
    42  
    43  	// Commit commits the transaction.
    44  	Commit() error
    45  
    46  	// Rollback rollbacks the transaction.
    47  	Rollback() error
    48  }
    49  
    50  // RegisterTransactionType registers a transaction type.
    51  // If already registered, it overwrites.
    52  func RegisterTransactionType(name string, tx Transactioner) {
    53  	TxTypeMap[name] = tx
    54  }
    55  
    56  // GenmaiTransaction implements Transactioner interface.
    57  type GenmaiTransaction struct {
    58  	tx *genmai.DB
    59  }
    60  
    61  // ImportPath returns the import path of Genmai.
    62  func (t *GenmaiTransaction) ImportPath() string {
    63  	return "github.com/naoina/genmai"
    64  }
    65  
    66  // TransactionType returns the transaction type of Genmai.
    67  func (t *GenmaiTransaction) TransactionType() interface{} {
    68  	return &genmai.DB{}
    69  }
    70  
    71  // Begin starts a transaction of Genmai.
    72  // If unsupported driver name given or any error occurred, it returns nil and error.
    73  func (t *GenmaiTransaction) Begin(driverName, dsn string) (tx interface{}, err error) {
    74  	var d genmai.Dialect
    75  	switch driverName {
    76  	case "mysql":
    77  		d = &genmai.MySQLDialect{}
    78  	case "postgres":
    79  		d = &genmai.PostgresDialect{}
    80  	case "sqlite3":
    81  		d = &genmai.SQLite3Dialect{}
    82  	default:
    83  		return nil, fmt.Errorf("kocha: migration: genmai: unsupported driver type `%v'", driverName)
    84  	}
    85  	t.tx, err = genmai.New(d, dsn)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	if err := t.tx.Begin(); err != nil {
    90  		return nil, err
    91  	}
    92  	return t.tx, nil
    93  }
    94  
    95  // Commit commits the transaction of Genmai.
    96  func (t *GenmaiTransaction) Commit() error {
    97  	return t.tx.Commit()
    98  }
    99  
   100  // Rollback rollbacks the transaction of Genmai.
   101  func (t *GenmaiTransaction) Rollback() error {
   102  	return t.tx.Rollback()
   103  }
   104  
   105  var (
   106  	upMethodRegexp   = regexp.MustCompile(`\AUp_(\d{14})_\w+\z`)
   107  	downMethodRegexp = regexp.MustCompile(`\ADown_(\d{14})_\w+\z`)
   108  )
   109  
   110  const MigrationTableName = "schema_migration"
   111  
   112  type Migration struct {
   113  	config DatabaseConfig
   114  	m      interface{}
   115  }
   116  
   117  func Migrate(config DatabaseConfig, m interface{}) *Migration {
   118  	return &Migration{
   119  		config: config,
   120  		m:      m,
   121  	}
   122  }
   123  
   124  func (mig *Migration) Up(limit int) error {
   125  	if err := mig.transaction(func(tx *sql.Tx) error {
   126  		stmt, err := tx.Prepare(fmt.Sprintf(
   127  			`CREATE TABLE IF NOT EXISTS %s (version varchar(255) PRIMARY KEY)`,
   128  			MigrationTableName))
   129  		if err != nil {
   130  			return err
   131  		}
   132  		defer stmt.Close()
   133  		if _, err := stmt.Exec(); err != nil {
   134  			return err
   135  		}
   136  		return nil
   137  	}); err != nil {
   138  		return err
   139  	}
   140  	var version string
   141  	if err := mig.transaction(func(tx *sql.Tx) error {
   142  		stmt, err := tx.Prepare(fmt.Sprintf(`SELECT version FROM %s ORDER BY version DESC LIMIT 1`, MigrationTableName))
   143  		if err != nil {
   144  			return err
   145  		}
   146  		defer stmt.Close()
   147  		if err := stmt.QueryRow().Scan(&version); err != nil && err != sql.ErrNoRows {
   148  			return err
   149  		}
   150  		return nil
   151  	}); err != nil {
   152  		return err
   153  	}
   154  	minfos, err := mig.collectInfos(upMethodRegexp, func(p string) bool {
   155  		return p > version
   156  	})
   157  	if err != nil {
   158  		return err
   159  	}
   160  	limit = int(math.Min(float64(limit), float64(len(minfos))))
   161  	if limit < 0 {
   162  		limit = len(minfos)
   163  	}
   164  	if len(minfos[:limit]) < 1 {
   165  		fmt.Fprintf(os.Stderr, "kocha: migrate: there is no need to migrate.\n")
   166  		return nil
   167  	}
   168  	sort.Sort(migrationInfoSlice(minfos))
   169  	return mig.run("migrating", minfos[:limit], func(version string) {
   170  		if err := mig.transaction(func(tx *sql.Tx) error {
   171  			stmt, err := tx.Prepare(fmt.Sprintf(`INSERT INTO %s (version) VALUES (?)`, MigrationTableName))
   172  			if err != nil {
   173  				return err
   174  			}
   175  			defer stmt.Close()
   176  			if _, err := stmt.Exec(version); err != nil {
   177  				return err
   178  			}
   179  			return nil
   180  		}); err != nil {
   181  			panic(err)
   182  		}
   183  	})
   184  }
   185  
   186  func (mig *Migration) Down(limit int) error {
   187  	var positions []string
   188  	if err := mig.transaction(func(tx *sql.Tx) error {
   189  		stmt, err := tx.Prepare(fmt.Sprintf(`SELECT version FROM %s ORDER BY version DESC LIMIT ?`, MigrationTableName))
   190  		if err != nil {
   191  			return err
   192  		}
   193  		defer stmt.Close()
   194  		if limit < 1 {
   195  			limit = 1
   196  		}
   197  		rows, err := stmt.Query(limit)
   198  		if err != nil {
   199  			return err
   200  		}
   201  		defer rows.Close()
   202  		for rows.Next() {
   203  			var version string
   204  			if err := rows.Scan(&version); err != nil {
   205  				return err
   206  			}
   207  			positions = append(positions, version)
   208  		}
   209  		return rows.Err()
   210  	}); err != nil {
   211  		return err
   212  	}
   213  	minfos, err := mig.collectInfos(downMethodRegexp, func(p string) bool {
   214  		for _, version := range positions {
   215  			if p == version {
   216  				return true
   217  			}
   218  		}
   219  		return false
   220  	})
   221  	if err != nil {
   222  		return err
   223  	}
   224  	if len(minfos) < 1 {
   225  		fmt.Fprintf(os.Stderr, "kocha: migrate: there is no need to migrate.\n")
   226  		return nil
   227  	}
   228  	sort.Sort(sort.Reverse(migrationInfoSlice(minfos)))
   229  	return mig.run("rollback", minfos, func(version string) {
   230  		if err := mig.transaction(func(tx *sql.Tx) error {
   231  			stmt, err := tx.Prepare(fmt.Sprintf(`DELETE FROM %s WHERE version = ?`, MigrationTableName))
   232  			if err != nil {
   233  				return err
   234  			}
   235  			defer stmt.Close()
   236  			if _, err := stmt.Exec(version); err != nil {
   237  				return err
   238  			}
   239  			return nil
   240  		}); err != nil {
   241  			panic(err)
   242  		}
   243  	})
   244  }
   245  
   246  func (mig *Migration) transaction(f func(tx *sql.Tx) error) error {
   247  	db, err := sql.Open(mig.config.Driver, mig.config.DSN)
   248  	if err != nil {
   249  		return err
   250  	}
   251  	defer db.Close()
   252  	tx, err := db.Begin()
   253  	if err != nil {
   254  		return err
   255  	}
   256  	defer func() {
   257  		if err := recover(); err != nil {
   258  			tx.Rollback()
   259  			panic(err)
   260  		}
   261  		tx.Commit()
   262  	}()
   263  	if err := f(tx); err != nil {
   264  		return tx.Rollback()
   265  	}
   266  	return tx.Commit()
   267  }
   268  
   269  func (mig *Migration) run(msg string, minfos []migrationInfo, afterFunc func(version string)) error {
   270  	v := reflect.ValueOf(mig.m)
   271  	for _, mi := range minfos {
   272  		func(mi migrationInfo) {
   273  			tx, err := mi.tx.Begin(mig.config.Driver, mig.config.DSN)
   274  			if err != nil {
   275  				panic(err)
   276  			}
   277  			defer func() {
   278  				if err := recover(); err != nil {
   279  					mi.tx.Rollback()
   280  					panic(err)
   281  				}
   282  				mi.tx.Commit()
   283  			}()
   284  			fmt.Printf("%v by %v...\n", msg, mi.methodName)
   285  			meth := v.MethodByName(mi.methodName)
   286  			meth.Call([]reflect.Value{reflect.ValueOf(tx)})
   287  		}(mi)
   288  		afterFunc(mi.version)
   289  	}
   290  	return nil
   291  }
   292  
   293  func (mig *Migration) collectInfos(r *regexp.Regexp, isTarget func(string) bool) ([]migrationInfo, error) {
   294  	v := reflect.ValueOf(mig.m)
   295  	t := v.Type()
   296  	var minfos []migrationInfo
   297  	for i := 0; i < t.NumMethod(); i++ {
   298  		meth := t.Method(i)
   299  		name := meth.Name
   300  		matches := r.FindStringSubmatch(name)
   301  		if matches == nil || !isTarget(matches[1]) {
   302  			continue
   303  		}
   304  		if meth.Type.NumIn() != 2 {
   305  			return nil, fmt.Errorf("kocha: migrate: %v: arguments number must be 1", meth.Name)
   306  		}
   307  		argType := meth.Type.In(1)
   308  		tx := mig.findTransactioner(argType)
   309  		if tx == nil {
   310  			return nil, fmt.Errorf("kocha: migrate: argument type `%v' is undefined", argType)
   311  		}
   312  		minfos = append(minfos, migrationInfo{
   313  			methodName: name,
   314  			version:    matches[1],
   315  			tx:         tx,
   316  		})
   317  	}
   318  	return minfos, nil
   319  }
   320  
   321  func (mig *Migration) findTransactioner(t reflect.Type) Transactioner {
   322  	for _, tx := range TxTypeMap {
   323  		if t == reflect.TypeOf(tx.TransactionType()) {
   324  			return tx
   325  		}
   326  	}
   327  	return nil
   328  }
   329  
   330  // migrationInfo is an intermediate information of a migration.
   331  type migrationInfo struct {
   332  	methodName string
   333  	version    string
   334  	tx         Transactioner
   335  }
   336  
   337  // migrationInfoSlice implements sort.Interface interface.
   338  type migrationInfoSlice []migrationInfo
   339  
   340  // Len implements sort.Interface.Len.
   341  func (ms migrationInfoSlice) Len() int {
   342  	return len(ms)
   343  }
   344  
   345  // Less implements sort.Interface.Less.
   346  func (ms migrationInfoSlice) Less(i, j int) bool {
   347  	return ms[i].version < ms[j].version
   348  }
   349  
   350  // Swap implements sort.Interface.Swap.
   351  func (ms migrationInfoSlice) Swap(i, j int) {
   352  	ms[i], ms[j] = ms[j], ms[i]
   353  }