github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/dbmate/db.go (about)

     1  package dbmate
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/fs"
     9  	"net/url"
    10  	"os"
    11  	"path/filepath"
    12  	"regexp"
    13  	"sort"
    14  	"time"
    15  
    16  	"github.com/amacneil/dbmate/pkg/dbutil"
    17  )
    18  
    19  // Error codes
    20  var (
    21  	ErrNoMigrationFiles      = errors.New("no migration files found")
    22  	ErrInvalidURL            = errors.New("invalid url, have you set your --url flag or DATABASE_URL environment variable?")
    23  	ErrNoRollback            = errors.New("can't rollback: no migrations have been applied")
    24  	ErrCantConnect           = errors.New("unable to connect to database")
    25  	ErrUnsupportedDriver     = errors.New("unsupported driver")
    26  	ErrNoMigrationName       = errors.New("please specify a name for the new migration")
    27  	ErrMigrationAlreadyExist = errors.New("file already exists")
    28  	ErrMigrationDirNotFound  = errors.New("could not find migrations directory")
    29  	ErrMigrationNotFound     = errors.New("can't find migration file")
    30  	ErrCreateDirectory       = errors.New("unable to create directory")
    31  )
    32  
    33  // migrationFileRegexp pattern for valid migration files
    34  var migrationFileRegexp = regexp.MustCompile(`^(\d+).*\.sql$`)
    35  
    36  // DB allows dbmate actions to be performed on a specified database
    37  type DB struct {
    38  	// AutoDumpSchema generates schema.sql after each action
    39  	AutoDumpSchema bool
    40  	// DatabaseURL is the database connection string
    41  	DatabaseURL *url.URL
    42  	// FS allows overriding the filesystem
    43  	FS fs.FS
    44  	// Log is the interface to write stdout
    45  	Log io.Writer
    46  	// MigrationsDir specifies the directory to find migration files
    47  	MigrationsDir string
    48  	// MigrationsTableName specifies the database table to record migrations in
    49  	MigrationsTableName string
    50  	// SchemaFile specifies the location for schema.sql file
    51  	SchemaFile string
    52  	// Verbose prints the result of each statement execution
    53  	Verbose bool
    54  	// WaitBefore will wait for database to become available before running any actions
    55  	WaitBefore bool
    56  	// WaitInterval specifies length of time between connection attempts
    57  	WaitInterval time.Duration
    58  	// WaitTimeout specifies maximum time for connection attempts
    59  	WaitTimeout time.Duration
    60  }
    61  
    62  // StatusResult represents an available migration status
    63  type StatusResult struct {
    64  	Filename string
    65  	Applied  bool
    66  }
    67  
    68  // New initializes a new dbmate database
    69  func New(databaseURL *url.URL) *DB {
    70  	return &DB{
    71  		AutoDumpSchema:      true,
    72  		DatabaseURL:         databaseURL,
    73  		FS:                  os.DirFS("."),
    74  		Log:                 os.Stdout,
    75  		MigrationsDir:       "./db/migrations",
    76  		MigrationsTableName: "schema_migrations",
    77  		SchemaFile:          "./db/schema.sql",
    78  		Verbose:             false,
    79  		WaitBefore:          false,
    80  		WaitInterval:        time.Second,
    81  		WaitTimeout:         60 * time.Second,
    82  	}
    83  }
    84  
    85  // Driver initializes the appropriate database driver
    86  func (db *DB) Driver() (Driver, error) {
    87  	if db.DatabaseURL == nil || db.DatabaseURL.Scheme == "" {
    88  		return nil, ErrInvalidURL
    89  	}
    90  
    91  	driverFunc := drivers[db.DatabaseURL.Scheme]
    92  	if driverFunc == nil {
    93  		return nil, fmt.Errorf("%w: %s", ErrUnsupportedDriver, db.DatabaseURL.Scheme)
    94  	}
    95  
    96  	config := DriverConfig{
    97  		DatabaseURL:         db.DatabaseURL,
    98  		Log:                 db.Log,
    99  		MigrationsTableName: db.MigrationsTableName,
   100  	}
   101  	drv := driverFunc(config)
   102  
   103  	if db.WaitBefore {
   104  		if err := db.wait(drv); err != nil {
   105  			return nil, err
   106  		}
   107  	}
   108  
   109  	return drv, nil
   110  }
   111  
   112  func (db *DB) wait(drv Driver) error {
   113  	// attempt connection to database server
   114  	err := drv.Ping()
   115  	if err == nil {
   116  		// connection successful
   117  		return nil
   118  	}
   119  
   120  	fmt.Fprint(db.Log, "Waiting for database")
   121  	for i := 0 * time.Second; i < db.WaitTimeout; i += db.WaitInterval {
   122  		fmt.Fprint(db.Log, ".")
   123  		time.Sleep(db.WaitInterval)
   124  
   125  		// attempt connection to database server
   126  		err = drv.Ping()
   127  		if err == nil {
   128  			// connection successful
   129  			fmt.Fprint(db.Log, "\n")
   130  			return nil
   131  		}
   132  	}
   133  
   134  	// if we find outselves here, we could not connect within the timeout
   135  	fmt.Fprint(db.Log, "\n")
   136  	return fmt.Errorf("%w: %s", ErrCantConnect, err)
   137  }
   138  
   139  // Wait blocks until the database server is available. It does not verify that
   140  // the specified database exists, only that the host is ready to accept connections.
   141  func (db *DB) Wait() error {
   142  	drv, err := db.Driver()
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	// if db.WaitBefore is true, wait() will get called twice, no harm
   148  	return db.wait(drv)
   149  }
   150  
   151  // CreateAndMigrate creates the database (if necessary) and runs migrations
   152  func (db *DB) CreateAndMigrate() error {
   153  	drv, err := db.Driver()
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	// create database if it does not already exist
   159  	// skip this step if we cannot determine status
   160  	// (e.g. user does not have list database permission)
   161  	exists, err := drv.DatabaseExists()
   162  	if err == nil && !exists {
   163  		if err := drv.CreateDatabase(); err != nil {
   164  			return err
   165  		}
   166  	}
   167  
   168  	// migrate
   169  	return db.Migrate()
   170  }
   171  
   172  // Create creates the current database
   173  func (db *DB) Create() error {
   174  	drv, err := db.Driver()
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	return drv.CreateDatabase()
   180  }
   181  
   182  // Drop drops the current database (if it exists)
   183  func (db *DB) Drop() error {
   184  	drv, err := db.Driver()
   185  	if err != nil {
   186  		return err
   187  	}
   188  
   189  	return drv.DropDatabase()
   190  }
   191  
   192  // DumpSchema writes the current database schema to a file
   193  func (db *DB) DumpSchema() error {
   194  	drv, err := db.Driver()
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	sqlDB, err := db.openDatabaseForMigration(drv)
   200  	if err != nil {
   201  		return err
   202  	}
   203  	defer dbutil.MustClose(sqlDB)
   204  
   205  	schema, err := drv.DumpSchema(sqlDB)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	fmt.Fprintf(db.Log, "Writing: %s\n", db.SchemaFile)
   211  
   212  	// ensure schema directory exists
   213  	if err = ensureDir(filepath.Dir(db.SchemaFile)); err != nil {
   214  		return err
   215  	}
   216  
   217  	// write schema to file
   218  	return os.WriteFile(db.SchemaFile, schema, 0o644)
   219  }
   220  
   221  // ensureDir creates a directory if it does not already exist
   222  func ensureDir(dir string) error {
   223  	if err := os.MkdirAll(dir, 0o755); err != nil {
   224  		return fmt.Errorf("%w `%s`", ErrCreateDirectory, dir)
   225  	}
   226  
   227  	return nil
   228  }
   229  
   230  const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
   231  
   232  // NewMigration creates a new migration file
   233  func (db *DB) NewMigration(name string) error {
   234  	// new migration name
   235  	timestamp := time.Now().UTC().Format("20060102150405")
   236  	if name == "" {
   237  		return ErrNoMigrationName
   238  	}
   239  	name = fmt.Sprintf("%s_%s.sql", timestamp, name)
   240  
   241  	// create migrations dir if missing
   242  	if err := ensureDir(db.MigrationsDir); err != nil {
   243  		return err
   244  	}
   245  
   246  	// check file does not already exist
   247  	path := filepath.Join(db.MigrationsDir, name)
   248  	fmt.Fprintf(db.Log, "Creating migration: %s\n", path)
   249  
   250  	if _, err := os.Stat(path); !os.IsNotExist(err) {
   251  		return ErrMigrationAlreadyExist
   252  	}
   253  
   254  	// write new migration
   255  	file, err := os.Create(path)
   256  	if err != nil {
   257  		return err
   258  	}
   259  
   260  	defer dbutil.MustClose(file)
   261  	_, err = file.WriteString(migrationTemplate)
   262  	return err
   263  }
   264  
   265  func doTransaction(sqlDB *sql.DB, txFunc func(dbutil.Transaction) error) error {
   266  	tx, err := sqlDB.Begin()
   267  	if err != nil {
   268  		return err
   269  	}
   270  
   271  	if err := txFunc(tx); err != nil {
   272  		if err1 := tx.Rollback(); err1 != nil {
   273  			return err1
   274  		}
   275  
   276  		return err
   277  	}
   278  
   279  	return tx.Commit()
   280  }
   281  
   282  func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) {
   283  	sqlDB, err := drv.Open()
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  
   288  	if err := drv.CreateMigrationsTable(sqlDB); err != nil {
   289  		dbutil.MustClose(sqlDB)
   290  		return nil, err
   291  	}
   292  
   293  	return sqlDB, nil
   294  }
   295  
   296  // Migrate migrates database to the latest version
   297  func (db *DB) Migrate() error {
   298  	drv, err := db.Driver()
   299  	if err != nil {
   300  		return err
   301  	}
   302  
   303  	migrations, err := db.FindMigrations()
   304  	if err != nil {
   305  		return err
   306  	}
   307  
   308  	if len(migrations) == 0 {
   309  		return ErrNoMigrationFiles
   310  	}
   311  
   312  	sqlDB, err := db.openDatabaseForMigration(drv)
   313  	if err != nil {
   314  		return err
   315  	}
   316  	defer dbutil.MustClose(sqlDB)
   317  
   318  	for _, migration := range migrations {
   319  		if migration.Applied {
   320  			continue
   321  		}
   322  
   323  		fmt.Fprintf(db.Log, "Applying: %s\n", migration.FileName)
   324  
   325  		parsed, err := migration.Parse()
   326  		if err != nil {
   327  			return err
   328  		}
   329  
   330  		execMigration := func(tx dbutil.Transaction) error {
   331  			// run actual migration
   332  			result, err := tx.Exec(parsed.Up)
   333  			if err != nil {
   334  				return err
   335  			} else if db.Verbose {
   336  				db.printVerbose(result)
   337  			}
   338  
   339  			// record migration
   340  			return drv.InsertMigration(tx, migration.Version)
   341  		}
   342  
   343  		if parsed.UpOptions.Transaction() {
   344  			// begin transaction
   345  			err = doTransaction(sqlDB, execMigration)
   346  		} else {
   347  			// run outside of transaction
   348  			err = execMigration(sqlDB)
   349  		}
   350  
   351  		if err != nil {
   352  			return err
   353  		}
   354  	}
   355  
   356  	// automatically update schema file, silence errors
   357  	if db.AutoDumpSchema {
   358  		_ = db.DumpSchema()
   359  	}
   360  
   361  	return nil
   362  }
   363  
   364  func (db *DB) printVerbose(result sql.Result) {
   365  	lastInsertID, err := result.LastInsertId()
   366  	if err == nil {
   367  		fmt.Fprintf(db.Log, "Last insert ID: %d\n", lastInsertID)
   368  	}
   369  	rowsAffected, err := result.RowsAffected()
   370  	if err == nil {
   371  		fmt.Fprintf(db.Log, "Rows affected: %d\n", rowsAffected)
   372  	}
   373  }
   374  
   375  // FindMigrations lists all available migrations
   376  func (db *DB) FindMigrations() ([]Migration, error) {
   377  	drv, err := db.Driver()
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	sqlDB, err := drv.Open()
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	defer dbutil.MustClose(sqlDB)
   387  
   388  	// find applied migrations
   389  	appliedMigrations := map[string]bool{}
   390  	migrationsTableExists, err := drv.MigrationsTableExists(sqlDB)
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  
   395  	if migrationsTableExists {
   396  		appliedMigrations, err = drv.SelectMigrations(sqlDB, -1)
   397  		if err != nil {
   398  			return nil, err
   399  		}
   400  	}
   401  
   402  	// find filesystem migrations
   403  	files, err := fs.ReadDir(db.FS, filepath.Clean(db.MigrationsDir))
   404  	if err != nil {
   405  		return nil, fmt.Errorf("%w `%s`", ErrMigrationDirNotFound, db.MigrationsDir)
   406  	}
   407  
   408  	migrations := []Migration{}
   409  	for _, file := range files {
   410  		if file.IsDir() {
   411  			continue
   412  		}
   413  
   414  		matches := migrationFileRegexp.FindStringSubmatch(file.Name())
   415  		if len(matches) < 2 {
   416  			continue
   417  		}
   418  
   419  		migration := Migration{
   420  			Applied:  false,
   421  			FileName: matches[0],
   422  			FilePath: filepath.Join(db.MigrationsDir, matches[0]),
   423  			FS:       db.FS,
   424  			Version:  matches[1],
   425  		}
   426  		if ok := appliedMigrations[migration.Version]; ok {
   427  			migration.Applied = true
   428  		}
   429  
   430  		migrations = append(migrations, migration)
   431  	}
   432  
   433  	sort.Slice(migrations, func(i, j int) bool {
   434  		return migrations[i].FileName < migrations[j].FileName
   435  	})
   436  
   437  	return migrations, nil
   438  }
   439  
   440  // Rollback rolls back the most recent migration
   441  func (db *DB) Rollback() error {
   442  	drv, err := db.Driver()
   443  	if err != nil {
   444  		return err
   445  	}
   446  
   447  	sqlDB, err := db.openDatabaseForMigration(drv)
   448  	if err != nil {
   449  		return err
   450  	}
   451  	defer dbutil.MustClose(sqlDB)
   452  
   453  	// find last applied migration
   454  	var latest *Migration
   455  	migrations, err := db.FindMigrations()
   456  	if err != nil {
   457  		return err
   458  	}
   459  
   460  	for _, migration := range migrations {
   461  		if migration.Applied {
   462  			latest = &migration
   463  		}
   464  	}
   465  
   466  	if latest == nil {
   467  		return ErrNoRollback
   468  	}
   469  
   470  	fmt.Fprintf(db.Log, "Rolling back: %s\n", latest.FileName)
   471  
   472  	parsed, err := latest.Parse()
   473  	if err != nil {
   474  		return err
   475  	}
   476  
   477  	execMigration := func(tx dbutil.Transaction) error {
   478  		// rollback migration
   479  		result, err := tx.Exec(parsed.Down)
   480  		if err != nil {
   481  			return err
   482  		} else if db.Verbose {
   483  			db.printVerbose(result)
   484  		}
   485  
   486  		// remove migration record
   487  		return drv.DeleteMigration(tx, latest.Version)
   488  	}
   489  
   490  	if parsed.DownOptions.Transaction() {
   491  		// begin transaction
   492  		err = doTransaction(sqlDB, execMigration)
   493  	} else {
   494  		// run outside of transaction
   495  		err = execMigration(sqlDB)
   496  	}
   497  
   498  	if err != nil {
   499  		return err
   500  	}
   501  
   502  	// automatically update schema file, silence errors
   503  	if db.AutoDumpSchema {
   504  		_ = db.DumpSchema()
   505  	}
   506  
   507  	return nil
   508  }
   509  
   510  // Status shows the status of all migrations
   511  func (db *DB) Status(quiet bool) (int, error) {
   512  	results, err := db.FindMigrations()
   513  	if err != nil {
   514  		return -1, err
   515  	}
   516  
   517  	var totalApplied int
   518  	var line string
   519  
   520  	for _, res := range results {
   521  		if res.Applied {
   522  			line = fmt.Sprintf("[X] %s", res.FileName)
   523  			totalApplied++
   524  		} else {
   525  			line = fmt.Sprintf("[ ] %s", res.FileName)
   526  		}
   527  		if !quiet {
   528  			fmt.Fprintln(db.Log, line)
   529  		}
   530  	}
   531  
   532  	totalPending := len(results) - totalApplied
   533  	if !quiet {
   534  		fmt.Fprintln(db.Log)
   535  		fmt.Fprintf(db.Log, "Applied: %d\n", totalApplied)
   536  		fmt.Fprintf(db.Log, "Pending: %d\n", totalPending)
   537  	}
   538  
   539  	return totalPending, nil
   540  }