github.com/gocaveman/caveman@v0.0.0-20191211162744-0ddf99dbdf6e/migrate/migrate.go (about) 1 // Manage and apply/unapply database schema changes, grouped by category and db driver. 2 package migrate 3 4 import ( 5 "bufio" 6 "bytes" 7 "database/sql" 8 "fmt" 9 "io" 10 "log" 11 "net/http" 12 "path" 13 "sort" 14 "strings" 15 "text/template" 16 ) 17 18 // OpenFunc is the function to use to "open" connections. It defaults to 19 // sql.Open but you can set it to something else to intercept the call. 20 var OpenFunc func(driverName, dataSourceName string) (*sql.DB, error) = sql.Open 21 22 // CloseFunc is used to close a database instead of directly calling sql.DB.Close(). 23 // It defaults to a func that just calls sql.DB.Close() but you can set it to 24 // something else as needed to match your custom OpenFunc. 25 var CloseFunc func(*sql.DB) error = func(db *sql.DB) error { 26 return db.Close() 27 } 28 29 // multiple named migrations 30 // must be able to test migrations 31 // in a cluster it must not explode when mulitiple servers try to migrate at the same time 32 33 // registry? how do the various components ensure their migrations get done 34 35 // what about "minimum veresion required for this code to run" 36 37 // Versioner interface is implemented by things that record which version is applied to a database and category. 38 type Versioner interface { 39 Categories() ([]string, error) 40 Version(category string) (string, error) 41 StartVersionChange(category, curVersionName string) error 42 EndVersionChange(category, newVersionName string) error 43 } 44 45 // NewRunner creates a Runner. 46 func NewRunner(driverName, dsn string, versioner Versioner, migrations MigrationList) *Runner { 47 return &Runner{ 48 DriverName: driverName, 49 DSN: dsn, 50 Versioner: versioner, 51 Migrations: migrations, 52 } 53 } 54 55 // Runner is used to apply migration changes. 
56 type Runner struct { 57 DriverName string 58 DSN string 59 Versioner 60 Migrations MigrationList 61 } 62 63 // FIXME: Runner should also be able to tell us if there are outstanding migrations 64 // that need to be run. In local environments we'll probably just call RunAllUpToLatest() 65 // at startup, but for staging and production we'll instead just check at startup 66 // and show messages if migrations are needed. 67 68 // CheckResult is a list of CheckResultItem 69 type CheckResult []CheckResultItem 70 71 // CheckResultItem gives the current and latest versions for a specific driver/dsn/category. 72 type CheckResultItem struct { 73 DriverName string 74 DSN string 75 Category string 76 CurrentVersion string 77 LatestVersion string 78 } 79 80 // IsCurrent returns true if latest version is current version. 81 func (i CheckResultItem) IsCurrent() bool { 82 return i.CurrentVersion == i.LatestVersion 83 } 84 85 // CheckAll checks all categories and compares the current version to the latest version 86 // and returns the results. If you pass true then all results will be returned, otherwise 87 // only results where the version is not current will be returned. 
88 func (r *Runner) CheckAll(returnAll bool) (CheckResult, error) { 89 90 var res CheckResult 91 92 ms := r.Migrations.WithDriverName(r.DriverName) 93 94 cats := ms.Categories() 95 for _, cat := range cats { 96 msc := ms.WithCategory(cat).Sorted() 97 latestVersion := "" 98 if len(msc) > 0 { 99 latestVersion = msc[len(msc)-1].Version() 100 } 101 currentVersion, err := r.Versioner.Version(cat) 102 if err != nil { 103 return nil, err 104 } 105 res = append(res, CheckResultItem{ 106 DriverName: r.DriverName, 107 DSN: r.DSN, 108 Category: cat, 109 CurrentVersion: currentVersion, 110 LatestVersion: latestVersion, 111 }) 112 } 113 114 if returnAll { 115 return res, nil 116 } 117 118 var res2 CheckResult 119 for _, item := range res { 120 if !item.IsCurrent() { 121 res2 = append(res2, item) 122 } 123 } 124 125 return res2, nil 126 127 } 128 129 // RunAllUpToLatest runs all migrations for all categories up to the latest version. 130 func (r *Runner) RunAllUpToLatest() error { 131 cats := r.Migrations.Categories() 132 for _, cat := range cats { 133 err := r.RunUpToLatest(cat) 134 if err != nil { 135 return err 136 } 137 } 138 return nil 139 } 140 141 // RunUpToLatest runs all migrations for a specific category up to the latest. 142 func (r *Runner) RunUpToLatest(category string) error { 143 144 mc := r.Migrations.WithDriverName(r.DriverName).WithCategory(category).Sorted() 145 146 if len(mc) == 0 { 147 return nil 148 } 149 150 ver := mc[len(mc)-1].Version() 151 152 return r.RunUpTo(category, ver) 153 } 154 155 // RunTo will run the migrations for a speific category up or down to a specific version. 
156 func (r *Runner) RunTo(category, targetVersion string) error { 157 158 if targetVersion == "" { 159 return fmt.Errorf("RunTo with empty target version not allowed, call RunDownTo() explicitly") 160 } 161 162 mc := r.Migrations.WithDriverName(r.DriverName).WithCategory(category).Sorted() 163 curVer, err := r.Versioner.Version(category) 164 if err != nil { 165 return err 166 } 167 168 if curVer == targetVersion { 169 return nil 170 } 171 172 // empty cur ver always means it's an up 173 if curVer == "" { 174 return r.RunUpTo(category, targetVersion) 175 } 176 177 curIdx := -1 178 tgtIdx := -1 179 for i := range mc { 180 if mc[i].Version() == curVer { 181 curIdx = i 182 } 183 if mc[i].Version() == targetVersion { 184 tgtIdx = i 185 } 186 } 187 188 if curIdx < 0 { 189 return fmt.Errorf("current version %q not found in category %q", curVer, category) 190 } 191 if tgtIdx < 0 { 192 return fmt.Errorf("target version %q not found in category %q", targetVersion, category) 193 } 194 195 if curIdx > tgtIdx { 196 return r.RunDownTo(category, targetVersion) 197 } 198 199 return r.RunUpTo(category, targetVersion) 200 } 201 202 // RunUpTo runs migrations up to a specific version. Will only run up, will error if 203 // this version is lower than the current one. 
204 func (r *Runner) RunUpTo(category, targetVersion string) error { 205 206 curVer, err := r.Versioner.Version(category) 207 if err != nil { 208 return err 209 } 210 ml := r.Migrations.WithDriverName(r.DriverName).WithCategory(category).Sorted() 211 if len(ml) == 0 { 212 return nil 213 } 214 215 if !ml.HasVersion(targetVersion) { 216 return fmt.Errorf("version %q not found", targetVersion) 217 } 218 219 active := curVer == "" // start active if empty current version 220 for _, m := range ml { 221 222 if curVer == m.Version() { 223 active = true 224 continue 225 } 226 227 if !active { 228 continue 229 } 230 231 err := r.Versioner.StartVersionChange(category, curVer) 232 if err != nil { 233 return err 234 } 235 236 err = m.ExecUp(r.DSN) 237 if err != nil { 238 // try to revert the version 239 err2 := r.Versioner.EndVersionChange(category, curVer) 240 if err2 != nil { // just log the error in this case, so the orignal error is preserved 241 log.Printf("EndVersionChange returned error: %v", err2) 242 } 243 return fmt.Errorf("ExecUp(%q) error: %v", r.DSN, err) 244 } 245 246 // Update version to the migration we just ran. 247 // NOTE: This will leave things in an inconsistent state if it errors but nothing we can do... 248 err = r.Versioner.EndVersionChange(category, m.Version()) 249 if err != nil { 250 return err 251 } 252 253 curVer = m.Version() 254 255 if targetVersion == m.Version() { 256 break 257 } 258 259 } 260 261 return nil 262 263 } 264 265 // RunUpTo runs migrations down to a specific version. Will only run down, will error if 266 // this version is higher than the current one. 
267 func (r *Runner) RunDownTo(category, targetVersion string) error { 268 269 // log.Printf("RunDownTo %q %q", category, targetVersion) 270 271 curVer, err := r.Versioner.Version(category) 272 if err != nil { 273 return err 274 } 275 ml := r.Migrations.WithDriverName(r.DriverName).WithCategory(category).Sorted() 276 sort.Sort(sort.Reverse(ml)) 277 if len(ml) == 0 { 278 return nil 279 } 280 281 if targetVersion != "" && !ml.HasVersion(targetVersion) { 282 return fmt.Errorf("version %q not found", targetVersion) 283 } 284 285 active := false 286 for mlidx, m := range ml { 287 288 // check for target version, in which case we're done 289 if targetVersion == m.Version() { 290 break 291 } 292 293 // if we're on current version, mark active and continue 294 if curVer == m.Version() { 295 active = true 296 } 297 298 if !active { 299 continue 300 } 301 302 err := r.Versioner.StartVersionChange(category, curVer) 303 if err != nil { 304 return err 305 } 306 307 err = m.ExecDown(r.DSN) 308 if err != nil { 309 // try to revert the version 310 err2 := r.Versioner.EndVersionChange(category, curVer) 311 if err2 != nil { // just log the error in this case, so the orignal error is preserved 312 log.Printf("EndVersionChange returned error: %v", err2) 313 } 314 return fmt.Errorf("ExecDown(%q) error: %v", r.DSN, err) 315 } 316 317 // Update version to the NEXT migration in the sequence or empty string if at the end 318 nextLowerVersion := "" 319 if mlidx+1 < len(ml) { 320 nextLowerVersion = ml[mlidx+1].Version() 321 } 322 323 // NOTE: This will leave things in an inconsistent state if it errors but nothing we can do... 324 err = r.Versioner.EndVersionChange(category, nextLowerVersion) 325 if err != nil { 326 return err 327 } 328 329 curVer = m.Version() 330 331 } 332 333 return nil 334 } 335 336 // Migration represents a driver name, category and version and functionality to perform an 337 // "up" and "down" to and from this version. 
// See SQLMigration, SQLTmplMigration and FuncsMigration for implementations.
type Migration interface {
	// DriverName is the database driver this migration is written for.
	DriverName() string
	// Category groups related migrations; versions are tracked per category.
	Category() string
	// Version identifies this migration; MigrationList orders versions by
	// plain string comparison.
	Version() string
	// ExecUp applies the migration against the database at dsn.
	ExecUp(dsn string) error
	// ExecDown reverts the migration against the database at dsn.
	ExecDown(dsn string) error
}

// MigrationList is a slice of Migration. It implements sort.Interface,
// ordering by driver name, then category, then version.
type MigrationList []Migration

// String returns a JSON-like summary of each migration's type, driver name,
// category and version, intended for debugging output.
func (ml MigrationList) String() string {
	var buf bytes.Buffer
	buf.WriteString("[")
	for _, m := range ml {
		fmt.Fprintf(&buf, `{"type":%q,"driverName":%q,"category":%q,"version":%q},`,
			fmt.Sprintf("%T", m), m.DriverName(), m.Category(), m.Version())
	}
	if len(ml) > 0 {
		buf.Truncate(buf.Len() - 1) // remove trailing comma
	}
	buf.WriteString("]")
	return buf.String()
}

func (ml MigrationList) Len() int      { return len(ml) }
func (ml MigrationList) Swap(i, j int) { ml[i], ml[j] = ml[j], ml[i] }

// Less orders by driver name, then category, then version. Each field is
// compared on its own rather than as one concatenated string, so that e.g.
// driver "a" sorts before "a b" and no string is built per comparison.
func (ml MigrationList) Less(i, j int) bool {
	a, b := ml[i], ml[j]
	if ad, bd := a.DriverName(), b.DriverName(); ad != bd {
		return ad < bd
	}
	if ac, bc := a.Category(), b.Category(); ac != bc {
		return ac < bc
	}
	return a.Version() < b.Version()
}

// HasVersion reports whether any migration in the list has version ver.
func (ml MigrationList) HasVersion(ver string) bool {
	for _, m := range ml {
		if m.Version() == ver {
			return true
		}
	}
	return false
}

// Categories returns the unique categories of these Migrations, sorted.
func (ml MigrationList) Categories() []string {
	var ret []string
	catMap := make(map[string]bool)
	for _, m := range ml {
		cat := m.Category()
		if !catMap[cat] {
			catMap[cat] = true
			ret = append(ret, cat)
		}
	}
	sort.Strings(ret)
	return ret
}

// Versions returns a unique, sorted list of versions from these Migrations.
400 func (ml MigrationList) Versions() []string { 401 var ret []string 402 verMap := make(map[string]bool) 403 for _, m := range ml { 404 ver := m.Version() 405 if !verMap[ver] { 406 verMap[ver] = true 407 ret = append(ret, ver) 408 } 409 } 410 sort.Strings(ret) 411 return ret 412 } 413 414 // WithDriverName returns a new list filtered to only include migrations with the specified driver name. 415 func (ml MigrationList) WithDriverName(driverName string) MigrationList { 416 417 var ret MigrationList 418 419 for _, m := range ml { 420 if m.DriverName() == driverName { 421 ret = append(ret, m) 422 } 423 } 424 425 return ret 426 } 427 428 // WithCategory returns a new list filtered to only include migrations with the specified category. 429 func (ml MigrationList) WithCategory(category string) MigrationList { 430 431 var ret MigrationList 432 433 for _, m := range ml { 434 if m.Category() == category { 435 ret = append(ret, m) 436 } 437 } 438 439 return ret 440 } 441 442 // ExcludeCategory returns a new MigrationList without records for the specified category. 443 func (ml MigrationList) ExcludeCategory(category string) MigrationList { 444 var ret MigrationList 445 for _, m := range ml { 446 if m.Category() != category { 447 ret = append(ret, m) 448 } 449 } 450 return ret 451 } 452 453 // Sorted returns a sorted copy of the list. Sequence is by driver, category and then version. 454 func (ml MigrationList) Sorted() MigrationList { 455 456 ml2 := make(MigrationList, len(ml)) 457 copy(ml2, ml) 458 459 sort.Sort(ml2) 460 461 return ml2 462 463 } 464 465 // LoadSQLMigrationsHFS is like LoadMigrations but loads from an http.FileSystem, so you can control the file source. 
466 func LoadSQLMigrationsHFS(hfs http.FileSystem, dir string) (MigrationList, error) { 467 468 f, err := hfs.Open(dir) 469 if err != nil { 470 return nil, err 471 } 472 defer f.Close() 473 474 fis, err := f.Readdir(-1) 475 if err != nil { 476 return nil, err 477 } 478 479 migMap := make(map[string]*SQLMigration) 480 481 for _, fi := range fis { 482 483 // skip dirs 484 if fi.IsDir() { 485 continue 486 } 487 488 fname := path.Base(fi.Name()) 489 490 // skip anything not a .sql file 491 if path.Ext(fname) != ".sql" { 492 continue 493 } 494 495 parts := strings.Split(strings.TrimSuffix(fname, ".sql"), "-") 496 if !(len(parts) == 4 && (parts[3] == "up" || parts[3] == "down")) { 497 return nil, fmt.Errorf("LoadSQLMigrationsHFS(hfs=%v, dir=%q): filename %q is wrong format, expected exactly 4 parts separated by dashes and the last part must be 'up' or 'down'", hfs, dir, fname) 498 } 499 500 key := strings.Join(parts[:3], "-") 501 502 // check for existing migration so we can fill in either up or down 503 var sqlMigration *SQLMigration 504 if migMap[key] != nil { 505 sqlMigration = migMap[key] 506 } else { 507 sqlMigration = &SQLMigration{ 508 DriverNameValue: parts[0], 509 CategoryValue: parts[1], 510 VersionValue: parts[2], 511 } 512 } 513 migMap[key] = sqlMigration 514 515 // figure out up/down part 516 var stmts *[]string 517 if parts[3] == "up" { 518 519 if sqlMigration.UpSQL != nil { 520 return nil, fmt.Errorf("LoadSQLMigrationsHFS(hfs=%v, dir=%q): filename %q - more than one up migration found", hfs, dir, fname) 521 } 522 523 stmts = &sqlMigration.UpSQL 524 525 } else { // down 526 527 if sqlMigration.DownSQL != nil { 528 return nil, fmt.Errorf("LoadSQLMigrationsHFS(hfs=%v, dir=%q): filename %q - more than one down migration found", hfs, dir, fname) 529 } 530 531 stmts = &sqlMigration.DownSQL 532 533 } 534 535 f, err := hfs.Open(path.Join(dir, fname)) 536 if err != nil { 537 return nil, err 538 } 539 defer f.Close() 540 541 // make sure it's non-nil 542 (*stmts) = 
make([]string, 0) 543 544 r := bufio.NewReader(f) 545 var thisStmt bytes.Buffer 546 for { 547 line, err := r.ReadBytes('\n') 548 if err == io.EOF { 549 break 550 } 551 if err != nil { 552 return nil, fmt.Errorf("LoadSQLMigrationsHFS(hfs=%v, dir=%q): filename %q error reading file: %v", hfs, dir, fname, err) 553 } 554 thisStmt.Write(line) 555 556 // check for end of statement 557 if bytes.HasSuffix(bytes.TrimSpace(line), []byte(";")) { 558 thisStmtStr := thisStmt.String() 559 if len(strings.TrimSpace(thisStmtStr)) > 0 { 560 (*stmts) = append((*stmts), thisStmtStr) 561 } 562 thisStmt.Truncate(0) 563 } 564 565 } 566 567 thisStmtStr := thisStmt.String() 568 if len(strings.TrimSpace(thisStmtStr)) > 0 { 569 (*stmts) = append((*stmts), thisStmtStr) 570 } 571 572 } 573 574 var ret MigrationList 575 for _, m := range migMap { 576 ret = append(ret, m) 577 } 578 ret = ret.Sorted() 579 580 return ret, nil 581 } 582 583 // LoadSQLMigrations loads migrations from the specified directory. File names are 584 // expected to be in exactly four parts each separated with a dash and have a .sql extension: 585 // `driver-category-version-up.sql` is the format for up migrations, the corresponding down 586 // migration is the same but with 'down' instead of 'up'. Another example: 587 // 588 // mysql-users-2017120301_create-up.sql 589 // 590 // mysql-users-2017120301_create-down.sql 591 // 592 // Both the up and down files must be present for a migration or an error will be returned. 593 // Files are plain text with SQL in them. Each line that ends with a semicolon (ignoring whitespace after) 594 // will be treated as a separate SQL statement. 595 func LoadSQLMigrations(dir string) (MigrationList, error) { 596 return LoadSQLMigrationsHFS(http.Dir(dir), "/") 597 } 598 599 // SQLMigration implements Migration with a simple slice of SQL strings for the up and down migration steps. 
600 // Common migrations which are just one or more static SQL statements can be implemented easily using SQLMigration. 601 type SQLMigration struct { 602 DriverNameValue string 603 CategoryValue string 604 VersionValue string 605 UpSQL []string 606 DownSQL []string 607 } 608 609 // NewWithDriverName as a convenience returns a copy with DriverNameValue set to the specified value. 610 func (m *SQLMigration) NewWithDriverName(driverName string) *SQLMigration { 611 ret := *m 612 ret.DriverNameValue = driverName 613 return &ret 614 } 615 616 func (m *SQLMigration) DriverName() string { return m.DriverNameValue } 617 func (m *SQLMigration) Category() string { return m.CategoryValue } 618 func (m *SQLMigration) Version() string { return m.VersionValue } 619 620 func (m *SQLMigration) exec(dsn string, stmts []string) error { 621 622 db, err := OpenFunc(m.DriverNameValue, dsn) 623 if err != nil { 624 return err 625 } 626 defer CloseFunc(db) 627 628 for n, s := range stmts { 629 _, err := db.Exec(s) 630 if err != nil { 631 return fmt.Errorf("SQLMigration (driverName=%q, category=%q, version=%q, stmtidx=%d) Exec on dsn=%q failed with error: %v\nSQL Statement:\n%s", 632 m.DriverNameValue, m.CategoryValue, m.VersionValue, n, dsn, err, s) 633 } 634 } 635 636 return nil 637 } 638 639 func (m *SQLMigration) ExecUp(dsn string) error { 640 return m.exec(dsn, m.UpSQL) 641 } 642 643 func (m *SQLMigration) ExecDown(dsn string) error { 644 return m.exec(dsn, m.DownSQL) 645 } 646 647 // SQLTmplMigration implements Migration with a simple slice of strings which are 648 // interpreted as Go templates with SQL as the up and down migration steps. 649 // This allows you to customize the SQL with things like table prefixes. 650 // Template are executed using text/template and the SQLTmplMigration instance 651 // is passed as the data to the Execute() call. 
652 type SQLTmplMigration struct { 653 DriverNameValue string 654 CategoryValue string 655 VersionValue string 656 UpSQL []string 657 DownSQL []string 658 659 // a common reason to use SQLTmplMigration is be able to configure the table prefix 660 TablePrefix string `autowire:"db.TablePrefix,optional"` 661 // other custom data needed by the template(s) can go here 662 Data interface{} 663 } 664 665 // NewWithDriverName as a convenience returns a copy with DriverNameValue set to the specified value. 666 func (m *SQLTmplMigration) NewWithDriverName(driverName string) *SQLTmplMigration { 667 ret := *m 668 ret.DriverNameValue = driverName 669 return &ret 670 } 671 672 func (m *SQLTmplMigration) DriverName() string { return m.DriverNameValue } 673 func (m *SQLTmplMigration) Category() string { return m.CategoryValue } 674 func (m *SQLTmplMigration) Version() string { return m.VersionValue } 675 676 func (m *SQLTmplMigration) tmplExec(dsn string, stmts []string) error { 677 678 db, err := OpenFunc(m.DriverNameValue, dsn) 679 if err != nil { 680 return err 681 } 682 defer CloseFunc(db) 683 684 for n, s := range stmts { 685 686 t := template.New("sql") 687 t, err := t.Parse(s) 688 if err != nil { 689 return fmt.Errorf("SQLTmplMigration (driverName=%q, category=%q, version=%q, stmtidx=%d) template parse on dsn=%q failed with error: %v\nSQL Statement:\n%s", 690 m.DriverNameValue, m.CategoryValue, m.VersionValue, n, dsn, err, s) 691 } 692 693 var buf bytes.Buffer 694 err = t.Execute(&buf, m) 695 if err != nil { 696 return fmt.Errorf("SQLTmplMigration (driverName=%q, category=%q, version=%q, stmtidx=%d) template execute on dsn=%q failed with error: %v\nSQL Statement:\n%s", 697 m.DriverNameValue, m.CategoryValue, m.VersionValue, n, dsn, err, s) 698 } 699 700 newS := buf.String() 701 702 _, err = db.Exec(newS) 703 if err != nil { 704 return fmt.Errorf("SQLTmplMigration (driverName=%q, category=%q, version=%q, stmtidx=%d) Exec on dsn=%q failed with error: %v\nSQL Statement:\n%s", 
705 m.DriverNameValue, m.CategoryValue, m.VersionValue, n, dsn, err, newS) 706 } 707 } 708 709 return nil 710 } 711 712 func (m *SQLTmplMigration) ExecUp(dsn string) error { 713 return m.tmplExec(dsn, m.UpSQL) 714 } 715 716 func (m *SQLTmplMigration) ExecDown(dsn string) error { 717 return m.tmplExec(dsn, m.DownSQL) 718 } 719 720 // NewFuncsMigration makes and returns a new FuncsMigration pointer with the data you provide. 721 func NewFuncsMigration(driverName, category, version string, upFunc, downFunc MigrationFunc) *FuncsMigration { 722 return &FuncsMigration{ 723 DriverNameValue: driverName, 724 CategoryValue: category, 725 VersionValue: version, 726 UpFunc: upFunc, 727 DownFunc: downFunc, 728 } 729 } 730 731 type MigrationFunc func(driverName, dsn string) error 732 733 // FuncsMigration is a Migration implementation that simply has up and down migration functions. 734 type FuncsMigration struct { 735 DriverNameValue string 736 CategoryValue string 737 VersionValue string 738 UpFunc MigrationFunc 739 DownFunc MigrationFunc 740 } 741 742 func (m *FuncsMigration) DriverName() string { return m.DriverNameValue } 743 func (m *FuncsMigration) Category() string { return m.CategoryValue } 744 func (m *FuncsMigration) Version() string { return m.VersionValue } 745 func (m *FuncsMigration) ExecUp(dsn string) error { 746 return m.UpFunc(m.DriverNameValue, dsn) 747 } 748 func (m *FuncsMigration) ExecDown(dsn string) error { 749 return m.DownFunc(m.DriverNameValue, dsn) 750 }