github.com/movio/dbmate@v1.4.1-0.20180723211402-1f8d6073f20c/pkg/dbmate/db.go (about)

     1  package dbmate
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/url"
     8  	"os"
     9  	"path/filepath"
    10  	"regexp"
    11  	"sort"
    12  	"time"
    13  )
    14  
    15  // DefaultMigrationsDir specifies default directory to find migration files
    16  const DefaultMigrationsDir = "./db/migrations"
    17  
    18  // DefaultSchemaFile specifies default location for schema.sql
    19  const DefaultSchemaFile = "./db/schema.sql"
    20  
    21  // DefaultWaitInterval specifies length of time between connection attempts
    22  const DefaultWaitInterval = time.Second
    23  
    24  // DefaultWaitTimeout specifies maximum time for connection attempts
    25  const DefaultWaitTimeout = 60 * time.Second
    26  
// DB allows dbmate actions to be performed on a specified database.
// It bundles the target database URL with the file locations and retry
// settings that the methods in this file read.
type DB struct {
	AutoDumpSchema bool          // when true, Migrate and Rollback call DumpSchema after succeeding
	DatabaseURL    *url.URL      // target database; its Scheme is used to look up the driver
	MigrationsDir  string        // directory holding the migration .sql files
	SchemaFile     string        // path DumpSchema writes the schema to
	WaitInterval   time.Duration // sleep between connection attempts in Wait
	WaitTimeout    time.Duration // bound on the accumulated WaitInterval in Wait
}
    36  
    37  // New initializes a new dbmate database
    38  func New(databaseURL *url.URL) *DB {
    39  	return &DB{
    40  		AutoDumpSchema: true,
    41  		DatabaseURL:    databaseURL,
    42  		MigrationsDir:  DefaultMigrationsDir,
    43  		SchemaFile:     DefaultSchemaFile,
    44  		WaitInterval:   DefaultWaitInterval,
    45  		WaitTimeout:    DefaultWaitTimeout,
    46  	}
    47  }
    48  
    49  // GetDriver loads the required database driver
    50  func (db *DB) GetDriver() (Driver, error) {
    51  	return GetDriver(db.DatabaseURL.Scheme)
    52  }
    53  
    54  // Wait blocks until the database server is available. It does not verify that
    55  // the specified database exists, only that the host is ready to accept connections.
    56  func (db *DB) Wait() error {
    57  	drv, err := db.GetDriver()
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	// attempt connection to database server
    63  	err = drv.Ping(db.DatabaseURL)
    64  	if err == nil {
    65  		// connection successful
    66  		return nil
    67  	}
    68  
    69  	fmt.Print("Waiting for database")
    70  	for i := 0 * time.Second; i < db.WaitTimeout; i += db.WaitInterval {
    71  		fmt.Print(".")
    72  		time.Sleep(db.WaitInterval)
    73  
    74  		// attempt connection to database server
    75  		err = drv.Ping(db.DatabaseURL)
    76  		if err == nil {
    77  			// connection successful
    78  			fmt.Print("\n")
    79  			return nil
    80  		}
    81  	}
    82  
    83  	// if we find outselves here, we could not connect within the timeout
    84  	fmt.Print("\n")
    85  	return fmt.Errorf("unable to connect to database: %s", err)
    86  }
    87  
    88  // CreateAndMigrate creates the database (if necessary) and runs migrations
    89  func (db *DB) CreateAndMigrate() error {
    90  	drv, err := db.GetDriver()
    91  	if err != nil {
    92  		return err
    93  	}
    94  
    95  	// create database if it does not already exist
    96  	// skip this step if we cannot determine status
    97  	// (e.g. user does not have list database permission)
    98  	exists, err := drv.DatabaseExists(db.DatabaseURL)
    99  	if err == nil && !exists {
   100  		if err := drv.CreateDatabase(db.DatabaseURL); err != nil {
   101  			return err
   102  		}
   103  	}
   104  
   105  	// migrate
   106  	return db.Migrate()
   107  }
   108  
   109  // Create creates the current database
   110  func (db *DB) Create() error {
   111  	drv, err := db.GetDriver()
   112  	if err != nil {
   113  		return err
   114  	}
   115  
   116  	return drv.CreateDatabase(db.DatabaseURL)
   117  }
   118  
   119  // Drop drops the current database (if it exists)
   120  func (db *DB) Drop() error {
   121  	drv, err := db.GetDriver()
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	return drv.DropDatabase(db.DatabaseURL)
   127  }
   128  
   129  // DumpSchema writes the current database schema to a file
   130  func (db *DB) DumpSchema() error {
   131  	drv, sqlDB, err := db.openDatabaseForMigration()
   132  	if err != nil {
   133  		return err
   134  	}
   135  	defer mustClose(sqlDB)
   136  
   137  	schema, err := drv.DumpSchema(db.DatabaseURL, sqlDB)
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	fmt.Printf("Writing: %s\n", db.SchemaFile)
   143  
   144  	// ensure schema directory exists
   145  	if err = ensureDir(filepath.Dir(db.SchemaFile)); err != nil {
   146  		return err
   147  	}
   148  
   149  	// write schema to file
   150  	return ioutil.WriteFile(db.SchemaFile, schema, 0644)
   151  }
   152  
   153  const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
   154  
   155  // NewMigration creates a new migration file
   156  func (db *DB) NewMigration(name string) error {
   157  	// new migration name
   158  	timestamp := time.Now().UTC().Format("20060102150405")
   159  	if name == "" {
   160  		return fmt.Errorf("please specify a name for the new migration")
   161  	}
   162  	name = fmt.Sprintf("%s_%s.sql", timestamp, name)
   163  
   164  	// create migrations dir if missing
   165  	if err := ensureDir(db.MigrationsDir); err != nil {
   166  		return err
   167  	}
   168  
   169  	// check file does not already exist
   170  	path := filepath.Join(db.MigrationsDir, name)
   171  	fmt.Printf("Creating migration: %s\n", path)
   172  
   173  	if _, err := os.Stat(path); !os.IsNotExist(err) {
   174  		return fmt.Errorf("file already exists")
   175  	}
   176  
   177  	// write new migration
   178  	file, err := os.Create(path)
   179  	if err != nil {
   180  		return err
   181  	}
   182  
   183  	defer mustClose(file)
   184  	_, err = file.WriteString(migrationTemplate)
   185  	return err
   186  }
   187  
   188  func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
   189  	tx, err := db.Begin()
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	if err := txFunc(tx); err != nil {
   195  		if err1 := tx.Rollback(); err1 != nil {
   196  			return err1
   197  		}
   198  
   199  		return err
   200  	}
   201  
   202  	return tx.Commit()
   203  }
   204  
   205  func (db *DB) openDatabaseForMigration() (Driver, *sql.DB, error) {
   206  	drv, err := db.GetDriver()
   207  	if err != nil {
   208  		return nil, nil, err
   209  	}
   210  
   211  	sqlDB, err := drv.Open(db.DatabaseURL)
   212  	if err != nil {
   213  		return nil, nil, err
   214  	}
   215  
   216  	if err := drv.CreateMigrationsTable(sqlDB); err != nil {
   217  		mustClose(sqlDB)
   218  		return nil, nil, err
   219  	}
   220  
   221  	return drv, sqlDB, nil
   222  }
   223  
   224  // Migrate migrates database to the latest version
   225  func (db *DB) Migrate() error {
   226  	re := regexp.MustCompile(`^\d.*\.sql$`)
   227  	files, err := findMigrationFiles(db.MigrationsDir, re)
   228  	if err != nil {
   229  		return err
   230  	}
   231  
   232  	if len(files) == 0 {
   233  		return fmt.Errorf("no migration files found")
   234  	}
   235  
   236  	drv, sqlDB, err := db.openDatabaseForMigration()
   237  	if err != nil {
   238  		return err
   239  	}
   240  	defer mustClose(sqlDB)
   241  
   242  	applied, err := drv.SelectMigrations(sqlDB, -1)
   243  	if err != nil {
   244  		return err
   245  	}
   246  
   247  	for _, filename := range files {
   248  		ver := migrationVersion(filename)
   249  		if ok := applied[ver]; ok {
   250  			// migration already applied
   251  			continue
   252  		}
   253  
   254  		fmt.Printf("Applying: %s\n", filename)
   255  
   256  		migration, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
   257  		if err != nil {
   258  			return err
   259  		}
   260  
   261  		// begin transaction
   262  		err = doTransaction(sqlDB, func(tx Transaction) error {
   263  			// run actual migration
   264  			if _, err := tx.Exec(migration["up"]); err != nil {
   265  				return err
   266  			}
   267  
   268  			// record migration
   269  			return drv.InsertMigration(tx, ver)
   270  		})
   271  		if err != nil {
   272  			return err
   273  		}
   274  
   275  	}
   276  
   277  	// automatically update schema file, silence errors
   278  	if db.AutoDumpSchema {
   279  		_ = db.DumpSchema()
   280  	}
   281  
   282  	return nil
   283  }
   284  
   285  func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
   286  	files, err := ioutil.ReadDir(dir)
   287  	if err != nil {
   288  		return nil, fmt.Errorf("could not find migrations directory `%s`", dir)
   289  	}
   290  
   291  	matches := []string{}
   292  	for _, file := range files {
   293  		if file.IsDir() {
   294  			continue
   295  		}
   296  
   297  		name := file.Name()
   298  		if !re.MatchString(name) {
   299  			continue
   300  		}
   301  
   302  		matches = append(matches, name)
   303  	}
   304  
   305  	sort.Strings(matches)
   306  
   307  	return matches, nil
   308  }
   309  
   310  func findMigrationFile(dir string, ver string) (string, error) {
   311  	if ver == "" {
   312  		panic("migration version is required")
   313  	}
   314  
   315  	ver = regexp.QuoteMeta(ver)
   316  	re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))
   317  
   318  	files, err := findMigrationFiles(dir, re)
   319  	if err != nil {
   320  		return "", err
   321  	}
   322  
   323  	if len(files) == 0 {
   324  		return "", fmt.Errorf("can't find migration file: %s*.sql", ver)
   325  	}
   326  
   327  	return files[0], nil
   328  }
   329  
   330  func migrationVersion(filename string) string {
   331  	return regexp.MustCompile(`^\d+`).FindString(filename)
   332  }
   333  
   334  // parseMigration reads a migration file into a map with up/down keys
   335  // implementation is similar to regexp.Split()
   336  func parseMigration(path string) (map[string]string, error) {
   337  	// read migration file into string
   338  	data, err := ioutil.ReadFile(path)
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  	contents := string(data)
   343  
   344  	// split string on our trigger comment
   345  	separatorRegexp := regexp.MustCompile(`(?m)^-- migrate:(.*)$`)
   346  	matches := separatorRegexp.FindAllStringSubmatchIndex(contents, -1)
   347  
   348  	migrations := map[string]string{}
   349  	direction := ""
   350  	beg := 0
   351  	end := 0
   352  
   353  	for _, match := range matches {
   354  		end = match[0]
   355  		if direction != "" {
   356  			// write previous direction to output map
   357  			migrations[direction] = contents[beg:end]
   358  		}
   359  
   360  		// each match records the start of a new direction
   361  		direction = contents[match[2]:match[3]]
   362  		beg = match[1]
   363  	}
   364  
   365  	// write final direction to output map
   366  	migrations[direction] = contents[beg:]
   367  
   368  	return migrations, nil
   369  }
   370  
   371  // Rollback rolls back the most recent migration
   372  func (db *DB) Rollback() error {
   373  	drv, sqlDB, err := db.openDatabaseForMigration()
   374  	if err != nil {
   375  		return err
   376  	}
   377  	defer mustClose(sqlDB)
   378  
   379  	applied, err := drv.SelectMigrations(sqlDB, 1)
   380  	if err != nil {
   381  		return err
   382  	}
   383  
   384  	// grab most recent applied migration (applied has len=1)
   385  	latest := ""
   386  	for ver := range applied {
   387  		latest = ver
   388  	}
   389  	if latest == "" {
   390  		return fmt.Errorf("can't rollback: no migrations have been applied")
   391  	}
   392  
   393  	filename, err := findMigrationFile(db.MigrationsDir, latest)
   394  	if err != nil {
   395  		return err
   396  	}
   397  
   398  	fmt.Printf("Rolling back: %s\n", filename)
   399  
   400  	migration, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
   401  	if err != nil {
   402  		return err
   403  	}
   404  
   405  	// begin transaction
   406  	err = doTransaction(sqlDB, func(tx Transaction) error {
   407  		// rollback migration
   408  		if _, err := tx.Exec(migration["down"]); err != nil {
   409  			return err
   410  		}
   411  
   412  		// remove migration record
   413  		return drv.DeleteMigration(tx, latest)
   414  	})
   415  	if err != nil {
   416  		return err
   417  	}
   418  
   419  	// automatically update schema file, silence errors
   420  	if db.AutoDumpSchema {
   421  		_ = db.DumpSchema()
   422  	}
   423  
   424  	return nil
   425  }