github.com/kaichaosun/dbmate@v0.0.3/pkg/dbmate/db_test.go (about)

     1  package dbmate_test
     2  
     3  import (
     4  	"net/url"
     5  	"os"
     6  	"path/filepath"
     7  	"testing"
     8  	"testing/fstest"
     9  	"time"
    10  
    11  	"github.com/kaichaosun/dbmate/pkg/dbmate"
    12  	"github.com/kaichaosun/dbmate/pkg/dbutil"
    13  	_ "github.com/kaichaosun/dbmate/pkg/driver/mysql"
    14  	_ "github.com/kaichaosun/dbmate/pkg/driver/postgres"
    15  	_ "github.com/kaichaosun/dbmate/pkg/driver/sqlite"
    16  
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/zenizh/go-capturer"
    19  )
    20  
    21  var rootDir string
    22  
    23  func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
    24  	var err error
    25  
    26  	// find root directory relative to current directory
    27  	if rootDir == "" {
    28  		rootDir, err = filepath.Abs("../..")
    29  		require.NoError(t, err)
    30  	}
    31  
    32  	err = os.Chdir(rootDir + "/testdata")
    33  	require.NoError(t, err)
    34  
    35  	db := dbmate.New(u)
    36  	db.AutoDumpSchema = false
    37  
    38  	return db
    39  }
    40  
    41  func TestNew(t *testing.T) {
    42  	db := dbmate.New(dbutil.MustParseURL("foo:test"))
    43  	require.True(t, db.AutoDumpSchema)
    44  	require.Equal(t, "foo:test", db.DatabaseURL.String())
    45  	require.Equal(t, []string{"./db/migrations"}, db.MigrationsDir)
    46  	require.Equal(t, "schema_migrations", db.MigrationsTableName)
    47  	require.Equal(t, "./db/schema.sql", db.SchemaFile)
    48  	require.False(t, db.WaitBefore)
    49  	require.Equal(t, time.Second, db.WaitInterval)
    50  	require.Equal(t, 60*time.Second, db.WaitTimeout)
    51  }
    52  
    53  func TestGetDriver(t *testing.T) {
    54  	t.Run("missing URL", func(t *testing.T) {
    55  		db := dbmate.New(nil)
    56  		drv, err := db.Driver()
    57  		require.Nil(t, drv)
    58  		require.EqualError(t, err, "invalid url, have you set your --url flag or DATABASE_URL environment variable?")
    59  	})
    60  
    61  	t.Run("missing schema", func(t *testing.T) {
    62  		db := dbmate.New(dbutil.MustParseURL("//hi"))
    63  		drv, err := db.Driver()
    64  		require.Nil(t, drv)
    65  		require.EqualError(t, err, "invalid url, have you set your --url flag or DATABASE_URL environment variable?")
    66  	})
    67  
    68  	t.Run("invalid driver", func(t *testing.T) {
    69  		db := dbmate.New(dbutil.MustParseURL("foo://bar"))
    70  		drv, err := db.Driver()
    71  		require.EqualError(t, err, "unsupported driver: foo")
    72  		require.Nil(t, drv)
    73  	})
    74  }
    75  
    76  func TestWait(t *testing.T) {
    77  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
    78  	db := newTestDB(t, u)
    79  
    80  	// speed up our retry loop for testing
    81  	db.WaitInterval = time.Millisecond
    82  	db.WaitTimeout = 5 * time.Millisecond
    83  
    84  	// drop database
    85  	err := db.Drop()
    86  	require.NoError(t, err)
    87  
    88  	// test wait
    89  	err = db.Wait()
    90  	require.NoError(t, err)
    91  
    92  	// test invalid connection
    93  	u.Host = "postgres:404"
    94  	err = db.Wait()
    95  	require.Error(t, err)
    96  	require.Contains(t, err.Error(), "unable to connect to database: dial tcp")
    97  	require.Contains(t, err.Error(), "connect: connection refused")
    98  }
    99  
   100  func TestDumpSchema(t *testing.T) {
   101  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   102  	db := newTestDB(t, u)
   103  
   104  	// create custom schema file directory
   105  	dir, err := os.MkdirTemp("", "dbmate")
   106  	require.NoError(t, err)
   107  	defer os.RemoveAll(dir)
   108  
   109  	// create schema.sql in subdirectory to test creating directory
   110  	db.SchemaFile = filepath.Join(dir, "/schema/schema.sql")
   111  
   112  	// drop database
   113  	err = db.Drop()
   114  	require.NoError(t, err)
   115  
   116  	// create and migrate
   117  	err = db.CreateAndMigrate()
   118  	require.NoError(t, err)
   119  
   120  	// schema.sql should not exist
   121  	_, err = os.Stat(db.SchemaFile)
   122  	require.True(t, os.IsNotExist(err))
   123  
   124  	// dump schema
   125  	err = db.DumpSchema()
   126  	require.NoError(t, err)
   127  
   128  	// verify schema
   129  	schema, err := os.ReadFile(db.SchemaFile)
   130  	require.NoError(t, err)
   131  	require.Contains(t, string(schema), "-- PostgreSQL database dump")
   132  }
   133  
   134  func TestAutoDumpSchema(t *testing.T) {
   135  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   136  	db := newTestDB(t, u)
   137  	db.AutoDumpSchema = true
   138  
   139  	// create custom schema file directory
   140  	dir, err := os.MkdirTemp("", "dbmate")
   141  	require.NoError(t, err)
   142  	defer os.RemoveAll(dir)
   143  
   144  	// create schema.sql in subdirectory to test creating directory
   145  	db.SchemaFile = filepath.Join(dir, "/schema/schema.sql")
   146  
   147  	// drop database
   148  	err = db.Drop()
   149  	require.NoError(t, err)
   150  
   151  	// schema.sql should not exist
   152  	_, err = os.Stat(db.SchemaFile)
   153  	require.True(t, os.IsNotExist(err))
   154  
   155  	// create and migrate
   156  	err = db.CreateAndMigrate()
   157  	require.NoError(t, err)
   158  
   159  	// verify schema
   160  	schema, err := os.ReadFile(db.SchemaFile)
   161  	require.NoError(t, err)
   162  	require.Contains(t, string(schema), "-- PostgreSQL database dump")
   163  
   164  	// remove schema
   165  	err = os.Remove(db.SchemaFile)
   166  	require.NoError(t, err)
   167  
   168  	// rollback
   169  	err = db.Rollback()
   170  	require.NoError(t, err)
   171  
   172  	// schema should be recreated
   173  	schema, err = os.ReadFile(db.SchemaFile)
   174  	require.NoError(t, err)
   175  	require.Contains(t, string(schema), "-- PostgreSQL database dump")
   176  }
   177  
   178  func checkWaitCalled(t *testing.T, u *url.URL, command func() error) {
   179  	oldHost := u.Host
   180  	u.Host = "postgres:404"
   181  	err := command()
   182  	require.Error(t, err)
   183  	require.Contains(t, err.Error(), "unable to connect to database: dial tcp")
   184  	require.Contains(t, err.Error(), "connect: connection refused")
   185  	u.Host = oldHost
   186  }
   187  
   188  func testWaitBefore(t *testing.T, verbose bool) {
   189  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   190  	db := newTestDB(t, u)
   191  	db.Verbose = verbose
   192  	db.WaitBefore = true
   193  	// so that checkWaitCalled returns quickly
   194  	db.WaitInterval = time.Millisecond
   195  	db.WaitTimeout = 5 * time.Millisecond
   196  
   197  	// drop database
   198  	err := db.Drop()
   199  	require.NoError(t, err)
   200  	checkWaitCalled(t, u, db.Drop)
   201  
   202  	// create
   203  	err = db.Create()
   204  	require.NoError(t, err)
   205  	checkWaitCalled(t, u, db.Create)
   206  
   207  	// create and migrate
   208  	err = db.CreateAndMigrate()
   209  	require.NoError(t, err)
   210  	checkWaitCalled(t, u, db.CreateAndMigrate)
   211  
   212  	// migrate
   213  	err = db.Migrate()
   214  	require.NoError(t, err)
   215  	checkWaitCalled(t, u, db.Migrate)
   216  
   217  	// rollback
   218  	err = db.Rollback()
   219  	require.NoError(t, err)
   220  	checkWaitCalled(t, u, db.Rollback)
   221  
   222  	// dump
   223  	err = db.DumpSchema()
   224  	require.NoError(t, err)
   225  	checkWaitCalled(t, u, db.DumpSchema)
   226  }
   227  
   228  func TestWaitBefore(t *testing.T) {
   229  	testWaitBefore(t, false)
   230  }
   231  
   232  func TestWaitBeforeVerbose(t *testing.T) {
   233  	output := capturer.CaptureOutput(func() {
   234  		testWaitBefore(t, true)
   235  	})
   236  	require.Contains(t, output,
   237  		`Applying: 20151129054053_test_migration.sql
   238  Rows affected: 1
   239  Applying: 20200227231541_test_posts.sql
   240  Rows affected: 0`)
   241  	require.Contains(t, output,
   242  		`Rolling back: 20200227231541_test_posts.sql
   243  Rows affected: 0`)
   244  }
   245  
   246  func testURLs() []*url.URL {
   247  	return []*url.URL{
   248  		dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL")),
   249  		dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL")),
   250  		dbutil.MustParseURL(os.Getenv("SQLITE_TEST_URL")),
   251  	}
   252  }
   253  
   254  func TestMigrate(t *testing.T) {
   255  	for _, u := range testURLs() {
   256  		t.Run(u.Scheme, func(t *testing.T) {
   257  			db := newTestDB(t, u)
   258  			drv, err := db.Driver()
   259  			require.NoError(t, err)
   260  
   261  			// drop and recreate database
   262  			err = db.Drop()
   263  			require.NoError(t, err)
   264  			err = db.Create()
   265  			require.NoError(t, err)
   266  
   267  			// migrate
   268  			err = db.Migrate()
   269  			require.NoError(t, err)
   270  
   271  			// verify results
   272  			sqlDB, err := drv.Open()
   273  			require.NoError(t, err)
   274  			defer dbutil.MustClose(sqlDB)
   275  
   276  			count := 0
   277  			err = sqlDB.QueryRow(`select count(*) from schema_migrations
   278  				where version = '20151129054053'`).Scan(&count)
   279  			require.NoError(t, err)
   280  			require.Equal(t, 1, count)
   281  
   282  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   283  			require.NoError(t, err)
   284  			require.Equal(t, 1, count)
   285  		})
   286  	}
   287  }
   288  
   289  func TestUp(t *testing.T) {
   290  	for _, u := range testURLs() {
   291  		t.Run(u.Scheme, func(t *testing.T) {
   292  			db := newTestDB(t, u)
   293  			drv, err := db.Driver()
   294  			require.NoError(t, err)
   295  
   296  			// drop database
   297  			err = db.Drop()
   298  			require.NoError(t, err)
   299  
   300  			// create and migrate
   301  			err = db.CreateAndMigrate()
   302  			require.NoError(t, err)
   303  
   304  			// verify results
   305  			sqlDB, err := drv.Open()
   306  			require.NoError(t, err)
   307  			defer dbutil.MustClose(sqlDB)
   308  
   309  			count := 0
   310  			err = sqlDB.QueryRow(`select count(*) from schema_migrations
   311  				where version = '20151129054053'`).Scan(&count)
   312  			require.NoError(t, err)
   313  			require.Equal(t, 1, count)
   314  
   315  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   316  			require.NoError(t, err)
   317  			require.Equal(t, 1, count)
   318  		})
   319  	}
   320  }
   321  
   322  func TestRollback(t *testing.T) {
   323  	for _, u := range testURLs() {
   324  		t.Run(u.Scheme, func(t *testing.T) {
   325  			db := newTestDB(t, u)
   326  			drv, err := db.Driver()
   327  			require.NoError(t, err)
   328  
   329  			// drop and create database
   330  			err = db.Drop()
   331  			require.NoError(t, err)
   332  			err = db.Create()
   333  			require.NoError(t, err)
   334  
   335  			// rollback should return error
   336  			err = db.Rollback()
   337  			require.Error(t, err)
   338  			require.ErrorContains(t, err, "can't rollback: no migrations have been applied")
   339  
   340  			// migrate database
   341  			err = db.Migrate()
   342  			require.NoError(t, err)
   343  
   344  			// verify migration
   345  			sqlDB, err := drv.Open()
   346  			require.NoError(t, err)
   347  			defer dbutil.MustClose(sqlDB)
   348  
   349  			var applied []string
   350  			rows, err := sqlDB.Query("select version from schema_migrations order by version asc")
   351  			require.NoError(t, err)
   352  			defer rows.Close()
   353  			for rows.Next() {
   354  				var version string
   355  				require.NoError(t, rows.Scan(&version))
   356  				applied = append(applied, version)
   357  			}
   358  			require.NoError(t, rows.Err())
   359  			require.Equal(t, []string{"20151129054053", "20200227231541"}, applied)
   360  
   361  			// users and posts tables have been created
   362  			var count int
   363  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   364  			require.Nil(t, err)
   365  			err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
   366  			require.Nil(t, err)
   367  
   368  			// rollback second migration
   369  			err = db.Rollback()
   370  			require.NoError(t, err)
   371  
   372  			// one migration remaining
   373  			err = sqlDB.QueryRow("select count(*) from schema_migrations").Scan(&count)
   374  			require.NoError(t, err)
   375  			require.Equal(t, 1, count)
   376  
   377  			// posts table was deleted
   378  			err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
   379  			require.NotNil(t, err)
   380  			require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
   381  
   382  			// users table still exists
   383  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   384  			require.Nil(t, err)
   385  
   386  			// rollback first migration
   387  			err = db.Rollback()
   388  			require.NoError(t, err)
   389  
   390  			// no migrations remaining
   391  			err = sqlDB.QueryRow("select count(*) from schema_migrations").Scan(&count)
   392  			require.NoError(t, err)
   393  			require.Equal(t, 0, count)
   394  
   395  			// posts table was deleted
   396  			err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
   397  			require.NotNil(t, err)
   398  			require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
   399  
   400  			// users table was deleted
   401  			err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
   402  			require.NotNil(t, err)
   403  			require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
   404  		})
   405  	}
   406  }
   407  
   408  func TestFindMigrations(t *testing.T) {
   409  	for _, u := range testURLs() {
   410  		t.Run(u.Scheme, func(t *testing.T) {
   411  			db := newTestDB(t, u)
   412  			drv, err := db.Driver()
   413  			require.NoError(t, err)
   414  
   415  			// drop, recreate, and migrate database
   416  			err = db.Drop()
   417  			require.NoError(t, err)
   418  			err = db.Create()
   419  			require.NoError(t, err)
   420  
   421  			// verify migration
   422  			sqlDB, err := drv.Open()
   423  			require.NoError(t, err)
   424  			defer dbutil.MustClose(sqlDB)
   425  
   426  			// two pending
   427  			results, err := db.FindMigrations()
   428  			require.NoError(t, err)
   429  			require.Len(t, results, 2)
   430  			require.False(t, results[0].Applied)
   431  			require.False(t, results[1].Applied)
   432  			migrationsTableExists, err := drv.MigrationsTableExists(sqlDB)
   433  			require.NoError(t, err)
   434  			require.False(t, migrationsTableExists)
   435  
   436  			// run migrations
   437  			err = db.Migrate()
   438  			require.NoError(t, err)
   439  
   440  			// two applied
   441  			results, err = db.FindMigrations()
   442  			require.NoError(t, err)
   443  			require.Len(t, results, 2)
   444  			require.True(t, results[0].Applied)
   445  			require.True(t, results[1].Applied)
   446  
   447  			// rollback last migration
   448  			err = db.Rollback()
   449  			require.NoError(t, err)
   450  
   451  			// one applied, one pending
   452  			results, err = db.FindMigrations()
   453  			require.NoError(t, err)
   454  			require.Len(t, results, 2)
   455  			require.True(t, results[0].Applied)
   456  			require.False(t, results[1].Applied)
   457  		})
   458  	}
   459  }
   460  
   461  func TestFindMigrationsAbsolute(t *testing.T) {
   462  	t.Run("relative path", func(t *testing.T) {
   463  		u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   464  		db := newTestDB(t, u)
   465  		db.MigrationsDir = []string{"db/migrations"}
   466  
   467  		migrations, err := db.FindMigrations()
   468  		require.NoError(t, err)
   469  
   470  		require.Equal(t, "db/migrations/20151129054053_test_migration.sql", migrations[0].FilePath)
   471  	})
   472  
   473  	t.Run("absolute path", func(t *testing.T) {
   474  		dir, err := os.MkdirTemp("", "dbmate")
   475  		require.NoError(t, err)
   476  		defer os.RemoveAll(dir)
   477  		require.True(t, filepath.IsAbs(dir))
   478  
   479  		file, err := os.Create(filepath.Join(dir, "1234_example.sql"))
   480  		require.NoError(t, err)
   481  		defer file.Close()
   482  
   483  		u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   484  		db := newTestDB(t, u)
   485  		db.MigrationsDir = []string{dir}
   486  		require.Nil(t, db.FS)
   487  
   488  		migrations, err := db.FindMigrations()
   489  		require.NoError(t, err)
   490  		require.Len(t, migrations, 1)
   491  		require.Equal(t, dir+"/1234_example.sql", migrations[0].FilePath)
   492  		require.True(t, filepath.IsAbs(migrations[0].FilePath))
   493  		require.Nil(t, migrations[0].FS)
   494  		require.Equal(t, "1234_example.sql", migrations[0].FileName)
   495  		require.Equal(t, "1234", migrations[0].Version)
   496  		require.False(t, migrations[0].Applied)
   497  	})
   498  }
   499  
   500  func TestFindMigrationsFS(t *testing.T) {
   501  	mapFS := fstest.MapFS{
   502  		"db/migrations/20151129054053_test_migration.sql": {},
   503  		"db/migrations/001_test_migration.sql": {
   504  			Data: []byte(`-- migrate:up
   505  create table users (id serial, name text);
   506  -- migrate:down
   507  drop table users;
   508  `),
   509  		},
   510  		"db/migrations/002_test_migration.sql":                {},
   511  		"db/migrations/003_not_sql.txt":                       {},
   512  		"db/migrations/missing_version.sql":                   {},
   513  		"db/not_migrations/20151129054053_test_migration.sql": {},
   514  	}
   515  
   516  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   517  	db := newTestDB(t, u)
   518  	db.FS = mapFS
   519  
   520  	// drop and recreate database
   521  	err := db.Drop()
   522  	require.NoError(t, err)
   523  	err = db.Create()
   524  	require.NoError(t, err)
   525  
   526  	actual, err := db.FindMigrations()
   527  	require.NoError(t, err)
   528  
   529  	// test migrations are correct and in order
   530  	require.Equal(t, "001_test_migration.sql", actual[0].FileName)
   531  	require.Equal(t, "db/migrations/001_test_migration.sql", actual[0].FilePath)
   532  	require.Equal(t, "001", actual[0].Version)
   533  	require.Equal(t, false, actual[0].Applied)
   534  
   535  	require.Equal(t, "002_test_migration.sql", actual[1].FileName)
   536  	require.Equal(t, "db/migrations/002_test_migration.sql", actual[1].FilePath)
   537  	require.Equal(t, "002", actual[1].Version)
   538  	require.Equal(t, false, actual[1].Applied)
   539  
   540  	require.Equal(t, "20151129054053_test_migration.sql", actual[2].FileName)
   541  	require.Equal(t, "db/migrations/20151129054053_test_migration.sql", actual[2].FilePath)
   542  	require.Equal(t, "20151129054053", actual[2].Version)
   543  	require.Equal(t, false, actual[2].Applied)
   544  
   545  	// test parsing first migration
   546  	parsed, err := actual[0].Parse()
   547  	require.Nil(t, err)
   548  	require.Equal(t, "-- migrate:up\ncreate table users (id serial, name text);\n", parsed.Up)
   549  	require.True(t, parsed.UpOptions.Transaction())
   550  	require.Equal(t, "-- migrate:down\ndrop table users;\n", parsed.Down)
   551  	require.True(t, parsed.DownOptions.Transaction())
   552  }
   553  
   554  func TestFindMigrationsFSMultipleDirs(t *testing.T) {
   555  	mapFS := fstest.MapFS{
   556  		"db/migrations_a/001_test_migration_a.sql": {},
   557  		"db/migrations_a/005_test_migration_a.sql": {},
   558  		"db/migrations_b/003_test_migration_b.sql": {},
   559  		"db/migrations_b/004_test_migration_b.sql": {},
   560  		"db/migrations_c/002_test_migration_c.sql": {},
   561  		"db/migrations_c/006_test_migration_c.sql": {},
   562  	}
   563  
   564  	u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
   565  	db := newTestDB(t, u)
   566  	db.FS = mapFS
   567  	db.MigrationsDir = []string{"./db/migrations_a", "./db/migrations_b", "./db/migrations_c"}
   568  
   569  	// drop and recreate database
   570  	err := db.Drop()
   571  	require.NoError(t, err)
   572  	err = db.Create()
   573  	require.NoError(t, err)
   574  
   575  	actual, err := db.FindMigrations()
   576  	require.NoError(t, err)
   577  
   578  	// test migrations are correct and in order
   579  	require.Equal(t, "db/migrations_a/001_test_migration_a.sql", actual[0].FilePath)
   580  	require.Equal(t, "db/migrations_c/002_test_migration_c.sql", actual[1].FilePath)
   581  	require.Equal(t, "db/migrations_b/003_test_migration_b.sql", actual[2].FilePath)
   582  	require.Equal(t, "db/migrations_b/004_test_migration_b.sql", actual[3].FilePath)
   583  	require.Equal(t, "db/migrations_a/005_test_migration_a.sql", actual[4].FilePath)
   584  	require.Equal(t, "db/migrations_c/006_test_migration_c.sql", actual[5].FilePath)
   585  }