github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"database/sql"
     5  	"net/url"
     6  	"os"
     7  	"testing"
     8  
     9  	"github.com/amacneil/dbmate/pkg/dbmate"
    10  	"github.com/amacneil/dbmate/pkg/dbutil"
    11  
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func testMySQLDriver(t *testing.T) *Driver {
    16  	u := dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL"))
    17  	drv, err := dbmate.New(u).Driver()
    18  	require.NoError(t, err)
    19  
    20  	return drv.(*Driver)
    21  }
    22  
    23  func prepTestMySQLDB(t *testing.T) *sql.DB {
    24  	drv := testMySQLDriver(t)
    25  
    26  	// drop any existing database
    27  	err := drv.DropDatabase()
    28  	require.NoError(t, err)
    29  
    30  	// create database
    31  	err = drv.CreateDatabase()
    32  	require.NoError(t, err)
    33  
    34  	// connect database
    35  	db, err := drv.Open()
    36  	require.NoError(t, err)
    37  
    38  	return db
    39  }
    40  
    41  func TestGetDriver(t *testing.T) {
    42  	db := dbmate.New(dbutil.MustParseURL("mysql://"))
    43  	drvInterface, err := db.Driver()
    44  	require.NoError(t, err)
    45  
    46  	// driver should have URL and default migrations table set
    47  	drv, ok := drvInterface.(*Driver)
    48  	require.True(t, ok)
    49  	require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
    50  	require.Equal(t, "schema_migrations", drv.migrationsTableName)
    51  }
    52  
    53  func TestConnectionString(t *testing.T) {
    54  	t.Run("defaults", func(t *testing.T) {
    55  		u, err := url.Parse("mysql://host/foo")
    56  		require.NoError(t, err)
    57  		require.Equal(t, "", u.Port())
    58  
    59  		s := connectionString(u)
    60  		require.Equal(t, "tcp(host:3306)/foo?multiStatements=true", s)
    61  	})
    62  
    63  	t.Run("custom", func(t *testing.T) {
    64  		u, err := url.Parse("mysql://bob:secret@host:123/foo?flag=on")
    65  		require.NoError(t, err)
    66  		require.Equal(t, "123", u.Port())
    67  
    68  		s := connectionString(u)
    69  		require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
    70  	})
    71  
    72  	t.Run("special chars", func(t *testing.T) {
    73  		u, err := url.Parse("mysql://duhfsd7s:123!@123!@@host:123/foo?flag=on")
    74  		require.NoError(t, err)
    75  		require.Equal(t, "123", u.Port())
    76  
    77  		s := connectionString(u)
    78  		require.Equal(t, "duhfsd7s:123!@123!@@tcp(host:123)/foo?flag=on&multiStatements=true", s)
    79  	})
    80  
    81  	t.Run("url encoding", func(t *testing.T) {
    82  		u, err := url.Parse("mysql://bob%2Balice:secret%5E%5B%2A%28%29@host:123/foo")
    83  		require.NoError(t, err)
    84  		require.Equal(t, "bob+alice:secret%5E%5B%2A%28%29", u.User.String())
    85  		require.Equal(t, "123", u.Port())
    86  
    87  		s := connectionString(u)
    88  		// ensure that '+' is correctly encoded by url.PathUnescape as '+'
    89  		// (not whitespace as url.QueryUnescape generates)
    90  		require.Equal(t, "bob+alice:secret^[*()@tcp(host:123)/foo?multiStatements=true", s)
    91  	})
    92  
    93  	t.Run("socket", func(t *testing.T) {
    94  		// test with no user/pass
    95  		u, err := url.Parse("mysql:///foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
    96  		require.NoError(t, err)
    97  		require.Equal(t, "", u.Host)
    98  
    99  		s := connectionString(u)
   100  		require.Equal(t, "unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
   101  
   102  		// test with user/pass
   103  		u, err = url.Parse("mysql://bob:secret@fakehost/foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
   104  		require.NoError(t, err)
   105  
   106  		s = connectionString(u)
   107  		require.Equal(t, "bob:secret@unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
   108  	})
   109  }
   110  
   111  func TestMySQLCreateDropDatabase(t *testing.T) {
   112  	drv := testMySQLDriver(t)
   113  
   114  	// drop any existing database
   115  	err := drv.DropDatabase()
   116  	require.NoError(t, err)
   117  
   118  	// create database
   119  	err = drv.CreateDatabase()
   120  	require.NoError(t, err)
   121  
   122  	// check that database exists and we can connect to it
   123  	func() {
   124  		db, err := drv.Open()
   125  		require.NoError(t, err)
   126  		defer dbutil.MustClose(db)
   127  
   128  		err = db.Ping()
   129  		require.NoError(t, err)
   130  	}()
   131  
   132  	// drop the database
   133  	err = drv.DropDatabase()
   134  	require.NoError(t, err)
   135  
   136  	// check that database no longer exists
   137  	func() {
   138  		db, err := drv.Open()
   139  		require.NoError(t, err)
   140  		defer dbutil.MustClose(db)
   141  
   142  		err = db.Ping()
   143  		require.Error(t, err)
   144  		require.Regexp(t, "Unknown database 'dbmate_test'", err.Error())
   145  	}()
   146  }
   147  
   148  func TestMySQLDumpArgs(t *testing.T) {
   149  	drv := testMySQLDriver(t)
   150  	drv.databaseURL = dbutil.MustParseURL("mysql://bob/mydb")
   151  
   152  	require.Equal(t, []string{"--opt",
   153  		"--routines",
   154  		"--no-data",
   155  		"--skip-dump-date",
   156  		"--skip-add-drop-table",
   157  		"--host=bob",
   158  		"mydb"}, drv.mysqldumpArgs())
   159  
   160  	drv.databaseURL = dbutil.MustParseURL("mysql://alice:pw@bob:5678/mydb")
   161  	require.Equal(t, []string{"--opt",
   162  		"--routines",
   163  		"--no-data",
   164  		"--skip-dump-date",
   165  		"--skip-add-drop-table",
   166  		"--host=bob",
   167  		"--port=5678",
   168  		"--user=alice",
   169  		"--password=pw",
   170  		"mydb"}, drv.mysqldumpArgs())
   171  
   172  	drv.databaseURL = dbutil.MustParseURL("mysql://alice:pw@bob:5678/mydb?socket=/var/run/mysqld/mysqld.sock")
   173  	require.Equal(t, []string{"--opt",
   174  		"--routines",
   175  		"--no-data",
   176  		"--skip-dump-date",
   177  		"--skip-add-drop-table",
   178  		"--socket=/var/run/mysqld/mysqld.sock",
   179  		"--user=alice",
   180  		"--password=pw",
   181  		"mydb"}, drv.mysqldumpArgs())
   182  }
   183  
   184  func TestMySQLDumpSchema(t *testing.T) {
   185  	drv := testMySQLDriver(t)
   186  	drv.migrationsTableName = "test_migrations"
   187  
   188  	// prepare database
   189  	db := prepTestMySQLDB(t)
   190  	defer dbutil.MustClose(db)
   191  	err := drv.CreateMigrationsTable(db)
   192  	require.NoError(t, err)
   193  
   194  	// insert migration
   195  	err = drv.InsertMigration(db, "abc1")
   196  	require.NoError(t, err)
   197  	err = drv.InsertMigration(db, "abc2")
   198  	require.NoError(t, err)
   199  
   200  	// DumpSchema should return schema
   201  	schema, err := drv.DumpSchema(db)
   202  	require.NoError(t, err)
   203  	require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
   204  	require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
   205  		"--\n"+
   206  		"-- Dbmate schema migrations\n"+
   207  		"--\n\n"+
   208  		"LOCK TABLES `test_migrations` WRITE;\n"+
   209  		"INSERT INTO `test_migrations` (version) VALUES\n"+
   210  		"  ('abc1'),\n"+
   211  		"  ('abc2');\n"+
   212  		"UNLOCK TABLES;\n")
   213  
   214  	// DumpSchema should return error if command fails
   215  	drv.databaseURL.Path = "/fakedb"
   216  	schema, err = drv.DumpSchema(db)
   217  	require.Nil(t, schema)
   218  	require.Error(t, err)
   219  	require.Contains(t, err.Error(), "Unknown database 'fakedb'")
   220  }
   221  
   222  func TestMySQLDumpSchemaContainsNoAutoIncrement(t *testing.T) {
   223  	drv := testMySQLDriver(t)
   224  
   225  	db := prepTestMySQLDB(t)
   226  	defer dbutil.MustClose(db)
   227  	err := drv.CreateMigrationsTable(db)
   228  	require.NoError(t, err)
   229  
   230  	// create table with AUTO_INCREMENT column
   231  	_, err = db.Exec(`create table foo_table (id int not null primary key auto_increment)`)
   232  	require.NoError(t, err)
   233  
   234  	// create a record
   235  	_, err = db.Exec(`insert into foo_table values ()`)
   236  	require.NoError(t, err)
   237  
   238  	// AUTO_INCREMENT should appear on the table definition
   239  	var tblName, tblCreate string
   240  	err = db.QueryRow(`show create table foo_table`).Scan(&tblName, &tblCreate)
   241  	require.NoError(t, err)
   242  	require.Contains(t, tblCreate, "AUTO_INCREMENT=")
   243  
   244  	// AUTO_INCREMENT should not appear in the dump
   245  	schema, err := drv.DumpSchema(db)
   246  	require.NoError(t, err)
   247  	require.NotContains(t, string(schema), "AUTO_INCREMENT=")
   248  }
   249  
   250  func TestMySQLDatabaseExists(t *testing.T) {
   251  	drv := testMySQLDriver(t)
   252  
   253  	// drop any existing database
   254  	err := drv.DropDatabase()
   255  	require.NoError(t, err)
   256  
   257  	// DatabaseExists should return false
   258  	exists, err := drv.DatabaseExists()
   259  	require.NoError(t, err)
   260  	require.Equal(t, false, exists)
   261  
   262  	// create database
   263  	err = drv.CreateDatabase()
   264  	require.NoError(t, err)
   265  
   266  	// DatabaseExists should return true
   267  	exists, err = drv.DatabaseExists()
   268  	require.NoError(t, err)
   269  	require.Equal(t, true, exists)
   270  }
   271  
   272  func TestMySQLDatabaseExists_Error(t *testing.T) {
   273  	drv := testMySQLDriver(t)
   274  	drv.databaseURL.User = url.User("invalid")
   275  
   276  	exists, err := drv.DatabaseExists()
   277  	require.Error(t, err)
   278  	require.Regexp(t, "Access denied for user 'invalid'@", err.Error())
   279  	require.Equal(t, false, exists)
   280  }
   281  
   282  func TestMySQLCreateMigrationsTable(t *testing.T) {
   283  	drv := testMySQLDriver(t)
   284  	drv.migrationsTableName = "test_migrations"
   285  
   286  	db := prepTestMySQLDB(t)
   287  	defer dbutil.MustClose(db)
   288  
   289  	// migrations table should not exist
   290  	count := 0
   291  	err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
   292  	require.Error(t, err)
   293  	require.Regexp(t, "Table 'dbmate_test.test_migrations' doesn't exist", err.Error())
   294  
   295  	// create table
   296  	err = drv.CreateMigrationsTable(db)
   297  	require.NoError(t, err)
   298  
   299  	// migrations table should exist
   300  	err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
   301  	require.NoError(t, err)
   302  
   303  	// create table should be idempotent
   304  	err = drv.CreateMigrationsTable(db)
   305  	require.NoError(t, err)
   306  }
   307  
   308  func TestMySQLSelectMigrations(t *testing.T) {
   309  	drv := testMySQLDriver(t)
   310  	drv.migrationsTableName = "test_migrations"
   311  
   312  	db := prepTestMySQLDB(t)
   313  	defer dbutil.MustClose(db)
   314  
   315  	err := drv.CreateMigrationsTable(db)
   316  	require.NoError(t, err)
   317  
   318  	_, err = db.Exec(`insert into test_migrations (version)
   319  		values ('abc2'), ('abc1'), ('abc3')`)
   320  	require.NoError(t, err)
   321  
   322  	migrations, err := drv.SelectMigrations(db, -1)
   323  	require.NoError(t, err)
   324  	require.Equal(t, true, migrations["abc1"])
   325  	require.Equal(t, true, migrations["abc2"])
   326  	require.Equal(t, true, migrations["abc2"])
   327  
   328  	// test limit param
   329  	migrations, err = drv.SelectMigrations(db, 1)
   330  	require.NoError(t, err)
   331  	require.Equal(t, true, migrations["abc3"])
   332  	require.Equal(t, false, migrations["abc1"])
   333  	require.Equal(t, false, migrations["abc2"])
   334  }
   335  
   336  func TestMySQLInsertMigration(t *testing.T) {
   337  	drv := testMySQLDriver(t)
   338  	drv.migrationsTableName = "test_migrations"
   339  
   340  	db := prepTestMySQLDB(t)
   341  	defer dbutil.MustClose(db)
   342  
   343  	err := drv.CreateMigrationsTable(db)
   344  	require.NoError(t, err)
   345  
   346  	count := 0
   347  	err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
   348  	require.NoError(t, err)
   349  	require.Equal(t, 0, count)
   350  
   351  	// insert migration
   352  	err = drv.InsertMigration(db, "abc1")
   353  	require.NoError(t, err)
   354  
   355  	err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
   356  		Scan(&count)
   357  	require.NoError(t, err)
   358  	require.Equal(t, 1, count)
   359  }
   360  
   361  func TestMySQLDeleteMigration(t *testing.T) {
   362  	drv := testMySQLDriver(t)
   363  	drv.migrationsTableName = "test_migrations"
   364  
   365  	db := prepTestMySQLDB(t)
   366  	defer dbutil.MustClose(db)
   367  
   368  	err := drv.CreateMigrationsTable(db)
   369  	require.NoError(t, err)
   370  
   371  	_, err = db.Exec(`insert into test_migrations (version)
   372  		values ('abc1'), ('abc2')`)
   373  	require.NoError(t, err)
   374  
   375  	err = drv.DeleteMigration(db, "abc2")
   376  	require.NoError(t, err)
   377  
   378  	count := 0
   379  	err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
   380  	require.NoError(t, err)
   381  	require.Equal(t, 1, count)
   382  }
   383  
   384  func TestMySQLPing(t *testing.T) {
   385  	drv := testMySQLDriver(t)
   386  
   387  	// drop any existing database
   388  	err := drv.DropDatabase()
   389  	require.NoError(t, err)
   390  
   391  	// ping database
   392  	err = drv.Ping()
   393  	require.NoError(t, err)
   394  
   395  	// ping invalid host should return error
   396  	drv.databaseURL.Host = "mysql:404"
   397  	err = drv.Ping()
   398  	require.Error(t, err)
   399  	require.Contains(t, err.Error(), "connect: connection refused")
   400  }
   401  
   402  func TestMySQLQuotedMigrationsTableName(t *testing.T) {
   403  	t.Run("default name", func(t *testing.T) {
   404  		drv := testMySQLDriver(t)
   405  		name := drv.quotedMigrationsTableName()
   406  		require.Equal(t, "`schema_migrations`", name)
   407  	})
   408  
   409  	t.Run("custom name", func(t *testing.T) {
   410  		drv := testMySQLDriver(t)
   411  		drv.migrationsTableName = "fooMigrations"
   412  
   413  		name := drv.quotedMigrationsTableName()
   414  		require.Equal(t, "`fooMigrations`", name)
   415  	})
   416  }