github.com/woremacx/kocha@v0.7.1-0.20150731103243-a5889322afc9/db.go (about) 1 package kocha 2 3 import ( 4 "database/sql" 5 "fmt" 6 "math" 7 "os" 8 "reflect" 9 "regexp" 10 "sort" 11 12 "github.com/naoina/genmai" 13 ) 14 15 type DatabaseMap map[string]DatabaseConfig 16 17 // DatabaseConfig represents a configuration of the database. 18 type DatabaseConfig struct { 19 // name of database driver such as "mysql". 20 Driver string 21 22 // Data Source Name. 23 // e.g. such as "travis@/db_name". 24 DSN string 25 } 26 27 var TxTypeMap = make(map[string]Transactioner) 28 29 // Transactioner is an interface for a transaction type. 30 type Transactioner interface { 31 // ImportPath returns the import path to import to use a transaction. 32 // Usually, import path of ORM like "github.com/naoina/genmai". 33 ImportPath() string 34 35 // TransactionType returns an object of a transaction type. 36 // Return value will only used to determine an argument type for the 37 // methods of migration when generate the template. 38 TransactionType() interface{} 39 40 // Begin starts a transaction from driver name and data source name. 41 Begin(driverName, dsn string) (tx interface{}, err error) 42 43 // Commit commits the transaction. 44 Commit() error 45 46 // Rollback rollbacks the transaction. 47 Rollback() error 48 } 49 50 // RegisterTransactionType registers a transaction type. 51 // If already registered, it overwrites. 52 func RegisterTransactionType(name string, tx Transactioner) { 53 TxTypeMap[name] = tx 54 } 55 56 // GenmaiTransaction implements Transactioner interface. 57 type GenmaiTransaction struct { 58 tx *genmai.DB 59 } 60 61 // ImportPath returns the import path of Genmai. 62 func (t *GenmaiTransaction) ImportPath() string { 63 return "github.com/naoina/genmai" 64 } 65 66 // TransactionType returns the transaction type of Genmai. 67 func (t *GenmaiTransaction) TransactionType() interface{} { 68 return &genmai.DB{} 69 } 70 71 // Begin starts a transaction of Genmai. 72 // If unsupported driver name given or any error occurred, it returns nil and error. 73 func (t *GenmaiTransaction) Begin(driverName, dsn string) (tx interface{}, err error) { 74 var d genmai.Dialect 75 switch driverName { 76 case "mysql": 77 d = &genmai.MySQLDialect{} 78 case "postgres": 79 d = &genmai.PostgresDialect{} 80 case "sqlite3": 81 d = &genmai.SQLite3Dialect{} 82 default: 83 return nil, fmt.Errorf("kocha: migration: genmai: unsupported driver type `%v'", driverName) 84 } 85 t.tx, err = genmai.New(d, dsn) 86 if err != nil { 87 return nil, err 88 } 89 if err := t.tx.Begin(); err != nil { 90 return nil, err 91 } 92 return t.tx, nil 93 } 94 95 // Commit commits the transaction of Genmai. 96 func (t *GenmaiTransaction) Commit() error { 97 return t.tx.Commit() 98 } 99 100 // Rollback rollbacks the transaction of Genmai. 101 func (t *GenmaiTransaction) Rollback() error { 102 return t.tx.Rollback() 103 } 104 105 var ( 106 upMethodRegexp = regexp.MustCompile(`\AUp_(\d{14})_\w+\z`) 107 downMethodRegexp = regexp.MustCompile(`\ADown_(\d{14})_\w+\z`) 108 ) 109 110 const MigrationTableName = "schema_migration" 111 112 type Migration struct { 113 config DatabaseConfig 114 m interface{} 115 } 116 117 func Migrate(config DatabaseConfig, m interface{}) *Migration { 118 return &Migration{ 119 config: config, 120 m: m, 121 } 122 } 123 124 func (mig *Migration) Up(limit int) error { 125 if err := mig.transaction(func(tx *sql.Tx) error { 126 stmt, err := tx.Prepare(fmt.Sprintf( 127 `CREATE TABLE IF NOT EXISTS %s (version varchar(255) PRIMARY KEY)`, 128 MigrationTableName)) 129 if err != nil { 130 return err 131 } 132 defer stmt.Close() 133 if _, err := stmt.Exec(); err != nil { 134 return err 135 } 136 return nil 137 }); err != nil { 138 return err 139 } 140 var version string 141 if err := mig.transaction(func(tx *sql.Tx) error { 142 stmt, err := tx.Prepare(fmt.Sprintf(`SELECT version FROM %s ORDER BY version DESC LIMIT 1`, MigrationTableName)) 143 if err != nil { 144 return err 145 } 146 defer stmt.Close() 147 if err := stmt.QueryRow().Scan(&version); err != nil && err != sql.ErrNoRows { 148 return err 149 } 150 return nil 151 }); err != nil { 152 return err 153 } 154 minfos, err := mig.collectInfos(upMethodRegexp, func(p string) bool { 155 return p > version 156 }) 157 if err != nil { 158 return err 159 } 160 limit = int(math.Min(float64(limit), float64(len(minfos)))) 161 if limit < 0 { 162 limit = len(minfos) 163 } 164 if len(minfos[:limit]) < 1 { 165 fmt.Fprintf(os.Stderr, "kocha: migrate: there is no need to migrate.\n") 166 return nil 167 } 168 sort.Sort(migrationInfoSlice(minfos)) 169 return mig.run("migrating", minfos[:limit], func(version string) { 170 if err := mig.transaction(func(tx *sql.Tx) error { 171 stmt, err := tx.Prepare(fmt.Sprintf(`INSERT INTO %s (version) VALUES (?)`, MigrationTableName)) 172 if err != nil { 173 return err 174 } 175 defer stmt.Close() 176 if _, err := stmt.Exec(version); err != nil { 177 return err 178 } 179 return nil 180 }); err != nil { 181 panic(err) 182 } 183 }) 184 } 185 186 func (mig *Migration) Down(limit int) error { 187 var positions []string 188 if err := mig.transaction(func(tx *sql.Tx) error { 189 stmt, err := tx.Prepare(fmt.Sprintf(`SELECT version FROM %s ORDER BY version DESC LIMIT ?`, MigrationTableName)) 190 if err != nil { 191 return err 192 } 193 defer stmt.Close() 194 if limit < 1 { 195 limit = 1 196 } 197 rows, err := stmt.Query(limit) 198 if err != nil { 199 return err 200 } 201 defer rows.Close() 202 for rows.Next() { 203 var version string 204 if err := rows.Scan(&version); err != nil { 205 return err 206 } 207 positions = append(positions, version) 208 } 209 return rows.Err() 210 }); err != nil { 211 return err 212 } 213 minfos, err := mig.collectInfos(downMethodRegexp, func(p string) bool { 214 for _, version := range positions { 215 if p == version { 216 return true 217 } 218 } 219 return false 220 }) 221 if err != nil { 222 return err 223 } 224 if len(minfos) < 1 { 225 fmt.Fprintf(os.Stderr, "kocha: migrate: there is no need to migrate.\n") 226 return nil 227 } 228 sort.Sort(sort.Reverse(migrationInfoSlice(minfos))) 229 return mig.run("rollback", minfos, func(version string) { 230 if err := mig.transaction(func(tx *sql.Tx) error { 231 stmt, err := tx.Prepare(fmt.Sprintf(`DELETE FROM %s WHERE version = ?`, MigrationTableName)) 232 if err != nil { 233 return err 234 } 235 defer stmt.Close() 236 if _, err := stmt.Exec(version); err != nil { 237 return err 238 } 239 return nil 240 }); err != nil { 241 panic(err) 242 } 243 }) 244 } 245 246 func (mig *Migration) transaction(f func(tx *sql.Tx) error) error { 247 db, err := sql.Open(mig.config.Driver, mig.config.DSN) 248 if err != nil { 249 return err 250 } 251 defer db.Close() 252 tx, err := db.Begin() 253 if err != nil { 254 return err 255 } 256 defer func() { 257 if err := recover(); err != nil { 258 tx.Rollback() 259 panic(err) 260 } 261 tx.Commit() 262 }() 263 if err := f(tx); err != nil { 264 return tx.Rollback() 265 } 266 return tx.Commit() 267 } 268 269 func (mig *Migration) run(msg string, minfos []migrationInfo, afterFunc func(version string)) error { 270 v := reflect.ValueOf(mig.m) 271 for _, mi := range minfos { 272 func(mi migrationInfo) { 273 tx, err := mi.tx.Begin(mig.config.Driver, mig.config.DSN) 274 if err != nil { 275 panic(err) 276 } 277 defer func() { 278 if err := recover(); err != nil { 279 mi.tx.Rollback() 280 panic(err) 281 } 282 mi.tx.Commit() 283 }() 284 fmt.Printf("%v by %v...\n", msg, mi.methodName) 285 meth := v.MethodByName(mi.methodName) 286 meth.Call([]reflect.Value{reflect.ValueOf(tx)}) 287 }(mi) 288 afterFunc(mi.version) 289 } 290 return nil 291 } 292 293 func (mig *Migration) collectInfos(r *regexp.Regexp, isTarget func(string) bool) ([]migrationInfo, error) { 294 v := reflect.ValueOf(mig.m) 295 t := v.Type() 296 var minfos []migrationInfo 297 for i := 0; i < t.NumMethod(); i++ { 298 meth := t.Method(i) 299 name := meth.Name 300 matches := r.FindStringSubmatch(name) 301 if matches == nil || !isTarget(matches[1]) { 302 continue 303 } 304 if meth.Type.NumIn() != 2 { 305 return nil, fmt.Errorf("kocha: migrate: %v: arguments number must be 1", meth.Name) 306 } 307 argType := meth.Type.In(1) 308 tx := mig.findTransactioner(argType) 309 if tx == nil { 310 return nil, fmt.Errorf("kocha: migrate: argument type `%v' is undefined", argType) 311 } 312 minfos = append(minfos, migrationInfo{ 313 methodName: name, 314 version: matches[1], 315 tx: tx, 316 }) 317 } 318 return minfos, nil 319 } 320 321 func (mig *Migration) findTransactioner(t reflect.Type) Transactioner { 322 for _, tx := range TxTypeMap { 323 if t == reflect.TypeOf(tx.TransactionType()) { 324 return tx 325 } 326 } 327 return nil 328 } 329 330 // migrationInfo is an intermediate information of a migration. 331 type migrationInfo struct { 332 methodName string 333 version string 334 tx Transactioner 335 } 336 337 // migrationInfoSlice implements sort.Interface interface. 338 type migrationInfoSlice []migrationInfo 339 340 // Len implements sort.Interface.Len. 341 func (ms migrationInfoSlice) Len() int { 342 return len(ms) 343 } 344 345 // Less implements sort.Interface.Less. 346 func (ms migrationInfoSlice) Less(i, j int) bool { 347 return ms[i].version < ms[j].version 348 } 349 350 // Swap implements sort.Interface.Swap. 351 func (ms migrationInfoSlice) Swap(i, j int) { 352 ms[i], ms[j] = ms[j], ms[i] 353 }