github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/dbmate/db_test.go (about)

     1  package dbmate_test
     2  
     3  import (
     4  	"net/url"
     5  	"os"
     6  	"path/filepath"
     7  	"testing"
     8  	"testing/fstest"
     9  	"time"
    10  
    11  	"github.com/amacneil/dbmate/pkg/dbmate"
    12  	"github.com/amacneil/dbmate/pkg/dbutil"
    13  	_ "github.com/amacneil/dbmate/pkg/driver/mysql"
    14  	_ "github.com/amacneil/dbmate/pkg/driver/postgres"
    15  	_ "github.com/amacneil/dbmate/pkg/driver/sqlite"
    16  
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/zenizh/go-capturer"
    19  )
    20  
    21  var testdataDir string
    22  
    23  func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
    24  	var err error
    25  
    26  	// only chdir once, because testdata is relative to current directory
    27  	if testdataDir == "" {
    28  		testdataDir, err = filepath.Abs("../../testdata")
    29  		require.NoError(t, err)
    30  
    31  		err = os.Chdir(testdataDir)
    32  		require.NoError(t, err)
    33  	}
    34  
    35  	db := dbmate.New(u)
    36  	db.AutoDumpSchema = false
    37  
    38  	return db
    39  }
    40  
    41  func TestNew(t *testing.T) {
    42  	db := dbmate.New(dbutil.MustParseURL("foo:test"))
    43  	require.True(t, db.AutoDumpSchema)
    44  	require.Equal(t, "foo:test", db.DatabaseURL.String())
    45  	require.Equal(t, "./db/migrations", db.MigrationsDir)
    46  	require.Equal(t, "schema_migrations", db.MigrationsTableName)
    47  	require.Equal(t, "./db/schema.sql", db.SchemaFile)
    48  	require.False(t, db.WaitBefore)
    49  	require.Equal(t, time.Second, db.WaitInterval)
    50  	require.Equal(t, 60*time.Second, db.WaitTimeout)
    51  }
    52  
    53  func TestGetDriver(t *testing.T) {
    54  	t.Run("missing URL", func(t *testing.T) {
    55  		db := dbmate.New(nil)
    56  		drv, err := db.Driver()
    57  		require.Nil(t, drv)
    58  		require.EqualError(t, err, "invalid url, have you set your --url flag or DATABASE_URL environment variable?")
    59  	})
    60  
    61  	t.Run("missing schema", func(t *testing.T) {
    62  		db := dbmate.New(dbutil.MustParseURL("//hi"))
    63  		drv, err := db.Driver()
    64  		require.Nil(t, drv)
    65  		require.EqualError(t, err, "invalid url, have you set your --url flag or DATABASE_URL environment variable?")
    66  	})
    67  
    68  	t.Run("invalid driver", func(t *testing.T) {
    69  		db := dbmate.New(dbutil.MustParseURL("foo://bar"))
    70  		drv, err := db.Driver()
    71  		require.EqualError(t, err, "unsupported driver: foo")
    72  		require.Nil(t, drv)
    73  	})
    74  }
    75  
    76  func TestWait(t *testing.T) {
    77  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
    78  	db := newTestDB(t, u)
    79  
    80  	// speed up our retry loop for testing
    81  	db.WaitInterval = time.Millisecond
    82  	db.WaitTimeout = 5 * time.Millisecond
    83  
    84  	// drop database
    85  	err := db.Drop()
    86  	require.NoError(t, err)
    87  
    88  	// test wait
    89  	err = db.Wait()
    90  	require.NoError(t, err)
    91  
    92  	// test invalid connection
    93  	u.Host = "postgres:404"
    94  	err = db.Wait()
    95  	require.Error(t, err)
    96  	require.Contains(t, err.Error(), "unable to connect to database: dial tcp")
    97  	require.Contains(t, err.Error(), "connect: connection refused")
    98  }
    99  
   100  func TestDumpSchema(t *testing.T) {
   101  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   102  	db := newTestDB(t, u)
   103  
   104  	// create custom schema file directory
   105  	dir, err := os.MkdirTemp("", "dbmate")
   106  	require.NoError(t, err)
   107  	defer func() {
   108  		err := os.RemoveAll(dir)
   109  		require.NoError(t, err)
   110  	}()
   111  
   112  	// create schema.sql in subdirectory to test creating directory
   113  	db.SchemaFile = filepath.Join(dir, "/schema/schema.sql")
   114  
   115  	// drop database
   116  	err = db.Drop()
   117  	require.NoError(t, err)
   118  
   119  	// create and migrate
   120  	err = db.CreateAndMigrate()
   121  	require.NoError(t, err)
   122  
   123  	// schema.sql should not exist
   124  	_, err = os.Stat(db.SchemaFile)
   125  	require.True(t, os.IsNotExist(err))
   126  
   127  	// dump schema
   128  	err = db.DumpSchema()
   129  	require.NoError(t, err)
   130  
   131  	// verify schema
   132  	schema, err := os.ReadFile(db.SchemaFile)
   133  	require.NoError(t, err)
   134  	require.Contains(t, string(schema), "-- PostgreSQL database dump")
   135  }
   136  
   137  func TestAutoDumpSchema(t *testing.T) {
   138  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   139  	db := newTestDB(t, u)
   140  	db.AutoDumpSchema = true
   141  
   142  	// create custom schema file directory
   143  	dir, err := os.MkdirTemp("", "dbmate")
   144  	require.NoError(t, err)
   145  	defer func() {
   146  		err := os.RemoveAll(dir)
   147  		require.NoError(t, err)
   148  	}()
   149  
   150  	// create schema.sql in subdirectory to test creating directory
   151  	db.SchemaFile = filepath.Join(dir, "/schema/schema.sql")
   152  
   153  	// drop database
   154  	err = db.Drop()
   155  	require.NoError(t, err)
   156  
   157  	// schema.sql should not exist
   158  	_, err = os.Stat(db.SchemaFile)
   159  	require.True(t, os.IsNotExist(err))
   160  
   161  	// create and migrate
   162  	err = db.CreateAndMigrate()
   163  	require.NoError(t, err)
   164  
   165  	// verify schema
   166  	schema, err := os.ReadFile(db.SchemaFile)
   167  	require.NoError(t, err)
   168  	require.Contains(t, string(schema), "-- PostgreSQL database dump")
   169  
   170  	// remove schema
   171  	err = os.Remove(db.SchemaFile)
   172  	require.NoError(t, err)
   173  
   174  	// rollback
   175  	err = db.Rollback()
   176  	require.NoError(t, err)
   177  
   178  	// schema should be recreated
   179  	schema, err = os.ReadFile(db.SchemaFile)
   180  	require.NoError(t, err)
   181  	require.Contains(t, string(schema), "-- PostgreSQL database dump")
   182  }
   183  
   184  func checkWaitCalled(t *testing.T, u *url.URL, command func() error) {
   185  	oldHost := u.Host
   186  	u.Host = "postgres:404"
   187  	err := command()
   188  	require.Error(t, err)
   189  	require.Contains(t, err.Error(), "unable to connect to database: dial tcp")
   190  	require.Contains(t, err.Error(), "connect: connection refused")
   191  	u.Host = oldHost
   192  }
   193  
   194  func testWaitBefore(t *testing.T, verbose bool) {
   195  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   196  	db := newTestDB(t, u)
   197  	db.Verbose = verbose
   198  	db.WaitBefore = true
   199  	// so that checkWaitCalled returns quickly
   200  	db.WaitInterval = time.Millisecond
   201  	db.WaitTimeout = 5 * time.Millisecond
   202  
   203  	// drop database
   204  	err := db.Drop()
   205  	require.NoError(t, err)
   206  	checkWaitCalled(t, u, db.Drop)
   207  
   208  	// create
   209  	err = db.Create()
   210  	require.NoError(t, err)
   211  	checkWaitCalled(t, u, db.Create)
   212  
   213  	// create and migrate
   214  	err = db.CreateAndMigrate()
   215  	require.NoError(t, err)
   216  	checkWaitCalled(t, u, db.CreateAndMigrate)
   217  
   218  	// migrate
   219  	err = db.Migrate()
   220  	require.NoError(t, err)
   221  	checkWaitCalled(t, u, db.Migrate)
   222  
   223  	// rollback
   224  	err = db.Rollback()
   225  	require.NoError(t, err)
   226  	checkWaitCalled(t, u, db.Rollback)
   227  
   228  	// dump
   229  	err = db.DumpSchema()
   230  	require.NoError(t, err)
   231  	checkWaitCalled(t, u, db.DumpSchema)
   232  }
   233  
   234  func TestWaitBefore(t *testing.T) {
   235  	testWaitBefore(t, false)
   236  }
   237  
   238  func TestWaitBeforeVerbose(t *testing.T) {
   239  	output := capturer.CaptureOutput(func() {
   240  		testWaitBefore(t, true)
   241  	})
   242  	require.Contains(t, output,
   243  		`Applying: 20151129054053_test_migration.sql
   244  Rows affected: 1
   245  Applying: 20200227231541_test_posts.sql
   246  Rows affected: 0`)
   247  	require.Contains(t, output,
   248  		`Rolling back: 20200227231541_test_posts.sql
   249  Rows affected: 0`)
   250  }
   251  
   252  func testURLs() []*url.URL {
   253  	return []*url.URL{
   254  		dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL")),
   255  		dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL")),
   256  		dbutil.MustParseURL(os.Getenv("SQLITE_TEST_URL")),
   257  	}
   258  }
   259  
   260  func TestMigrate(t *testing.T) {
   261  	for _, u := range testURLs() {
   262  		t.Run(u.Scheme, func(t *testing.T) {
   263  			db := newTestDB(t, u)
   264  			drv, err := db.Driver()
   265  			require.NoError(t, err)
   266  
   267  			// drop and recreate database
   268  			err = db.Drop()
   269  			require.NoError(t, err)
   270  			err = db.Create()
   271  			require.NoError(t, err)
   272  
   273  			// migrate
   274  			err = db.Migrate()
   275  			require.NoError(t, err)
   276  
   277  			// verify results
   278  			sqlDB, err := drv.Open()
   279  			require.NoError(t, err)
   280  			defer dbutil.MustClose(sqlDB)
   281  
   282  			count := 0
   283  			err = sqlDB.QueryRow(`select count(*) from schema_migrations
   284  				where version = '20151129054053'`).Scan(&count)
   285  			require.NoError(t, err)
   286  			require.Equal(t, 1, count)
   287  
   288  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   289  			require.NoError(t, err)
   290  			require.Equal(t, 1, count)
   291  		})
   292  	}
   293  }
   294  
   295  func TestUp(t *testing.T) {
   296  	for _, u := range testURLs() {
   297  		t.Run(u.Scheme, func(t *testing.T) {
   298  			db := newTestDB(t, u)
   299  			drv, err := db.Driver()
   300  			require.NoError(t, err)
   301  
   302  			// drop database
   303  			err = db.Drop()
   304  			require.NoError(t, err)
   305  
   306  			// create and migrate
   307  			err = db.CreateAndMigrate()
   308  			require.NoError(t, err)
   309  
   310  			// verify results
   311  			sqlDB, err := drv.Open()
   312  			require.NoError(t, err)
   313  			defer dbutil.MustClose(sqlDB)
   314  
   315  			count := 0
   316  			err = sqlDB.QueryRow(`select count(*) from schema_migrations
   317  				where version = '20151129054053'`).Scan(&count)
   318  			require.NoError(t, err)
   319  			require.Equal(t, 1, count)
   320  
   321  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   322  			require.NoError(t, err)
   323  			require.Equal(t, 1, count)
   324  		})
   325  	}
   326  }
   327  
   328  func TestRollback(t *testing.T) {
   329  	for _, u := range testURLs() {
   330  		t.Run(u.Scheme, func(t *testing.T) {
   331  			db := newTestDB(t, u)
   332  			drv, err := db.Driver()
   333  			require.NoError(t, err)
   334  
   335  			// drop and create database
   336  			err = db.Drop()
   337  			require.NoError(t, err)
   338  			err = db.Create()
   339  			require.NoError(t, err)
   340  
   341  			// rollback should return error
   342  			err = db.Rollback()
   343  			require.Error(t, err)
   344  			require.ErrorContains(t, err, "can't rollback: no migrations have been applied")
   345  
   346  			// migrate database
   347  			err = db.Migrate()
   348  			require.NoError(t, err)
   349  
   350  			// verify migration
   351  			sqlDB, err := drv.Open()
   352  			require.NoError(t, err)
   353  			defer dbutil.MustClose(sqlDB)
   354  
   355  			count := 0
   356  			err = sqlDB.QueryRow(`select count(*) from schema_migrations
   357  				where version = '20151129054053'`).Scan(&count)
   358  			require.NoError(t, err)
   359  			require.Equal(t, 1, count)
   360  
   361  			err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
   362  			require.Nil(t, err)
   363  
   364  			// rollback
   365  			err = db.Rollback()
   366  			require.NoError(t, err)
   367  
   368  			// verify rollback
   369  			err = sqlDB.QueryRow("select count(*) from schema_migrations").Scan(&count)
   370  			require.NoError(t, err)
   371  			require.Equal(t, 1, count)
   372  
   373  			err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
   374  			require.NotNil(t, err)
   375  			require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
   376  		})
   377  	}
   378  }
   379  
   380  func TestFindMigrations(t *testing.T) {
   381  	for _, u := range testURLs() {
   382  		t.Run(u.Scheme, func(t *testing.T) {
   383  			db := newTestDB(t, u)
   384  			drv, err := db.Driver()
   385  			require.NoError(t, err)
   386  
   387  			// drop, recreate, and migrate database
   388  			err = db.Drop()
   389  			require.NoError(t, err)
   390  			err = db.Create()
   391  			require.NoError(t, err)
   392  
   393  			// verify migration
   394  			sqlDB, err := drv.Open()
   395  			require.NoError(t, err)
   396  			defer dbutil.MustClose(sqlDB)
   397  
   398  			// two pending
   399  			results, err := db.FindMigrations()
   400  			require.NoError(t, err)
   401  			require.Len(t, results, 2)
   402  			require.False(t, results[0].Applied)
   403  			require.False(t, results[1].Applied)
   404  			migrationsTableExists, err := drv.MigrationsTableExists(sqlDB)
   405  			require.NoError(t, err)
   406  			require.False(t, migrationsTableExists)
   407  
   408  			// run migrations
   409  			err = db.Migrate()
   410  			require.NoError(t, err)
   411  
   412  			// two applied
   413  			results, err = db.FindMigrations()
   414  			require.NoError(t, err)
   415  			require.Len(t, results, 2)
   416  			require.True(t, results[0].Applied)
   417  			require.True(t, results[1].Applied)
   418  
   419  			// rollback last migration
   420  			err = db.Rollback()
   421  			require.NoError(t, err)
   422  
   423  			// one applied, one pending
   424  			results, err = db.FindMigrations()
   425  			require.NoError(t, err)
   426  			require.Len(t, results, 2)
   427  			require.True(t, results[0].Applied)
   428  			require.False(t, results[1].Applied)
   429  		})
   430  	}
   431  }
   432  
   433  func TestFindMigrationsFS(t *testing.T) {
   434  	mapFS := fstest.MapFS{
   435  		"db/migrations/20151129054053_test_migration.sql": {},
   436  		"db/migrations/001_test_migration.sql": {
   437  			Data: []byte(`-- migrate:up
   438  create table users (id serial, name text);
   439  -- migrate:down
   440  drop table users;
   441  `),
   442  		},
   443  		"db/migrations/002_test_migration.sql":                {},
   444  		"db/migrations/003_not_sql.txt":                       {},
   445  		"db/migrations/missing_version.sql":                   {},
   446  		"db/not_migrations/20151129054053_test_migration.sql": {},
   447  	}
   448  
   449  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   450  	db := newTestDB(t, u)
   451  	db.FS = mapFS
   452  
   453  	// drop and recreate database
   454  	err := db.Drop()
   455  	require.NoError(t, err)
   456  	err = db.Create()
   457  	require.NoError(t, err)
   458  
   459  	actual, err := db.FindMigrations()
   460  	require.NoError(t, err)
   461  
   462  	// test migrations are correct and in order
   463  	require.Equal(t, "001_test_migration.sql", actual[0].FileName)
   464  	require.Equal(t, "db/migrations/001_test_migration.sql", actual[0].FilePath)
   465  	require.Equal(t, "001", actual[0].Version)
   466  	require.Equal(t, false, actual[0].Applied)
   467  
   468  	require.Equal(t, "002_test_migration.sql", actual[1].FileName)
   469  	require.Equal(t, "db/migrations/002_test_migration.sql", actual[1].FilePath)
   470  	require.Equal(t, "002", actual[1].Version)
   471  	require.Equal(t, false, actual[1].Applied)
   472  
   473  	require.Equal(t, "20151129054053_test_migration.sql", actual[2].FileName)
   474  	require.Equal(t, "db/migrations/20151129054053_test_migration.sql", actual[2].FilePath)
   475  	require.Equal(t, "20151129054053", actual[2].Version)
   476  	require.Equal(t, false, actual[2].Applied)
   477  
   478  	// test parsing first migration
   479  	parsed, err := actual[0].Parse()
   480  	require.Nil(t, err)
   481  	require.Equal(t, "-- migrate:up\ncreate table users (id serial, name text);\n", parsed.Up)
   482  	require.True(t, parsed.UpOptions.Transaction())
   483  	require.Equal(t, "-- migrate:down\ndrop table users;\n", parsed.Down)
   484  	require.True(t, parsed.DownOptions.Transaction())
   485  }