// Source: github.com/dynastymasra/migrate/v4@v4.11.0/database/mysql/mysql_test.go

     1  package mysql
     2  
import (
	"context"
	"database/sql"
	sqldriver "database/sql/driver"
	"errors"
	"fmt"
	"log"
	"testing"

	"github.com/dhui/dktest"
	"github.com/go-sql-driver/mysql"
	"github.com/stretchr/testify/assert"

	"github.com/golang-migrate/migrate/v4"
	dt "github.com/golang-migrate/migrate/v4/database/testing"
	"github.com/golang-migrate/migrate/v4/dktesting"
	_ "github.com/golang-migrate/migrate/v4/source/file"
)
    24  
    25  const defaultPort = 3306
    26  
    27  var (
    28  	opts = dktest.Options{
    29  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    30  		PortRequired: true, ReadyFunc: isReady,
    31  	}
    32  	// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
    33  	specs = []dktesting.ContainerSpec{
    34  		{ImageName: "mysql:5.5", Options: opts},
    35  		{ImageName: "mysql:5.6", Options: opts},
    36  		{ImageName: "mysql:5.7", Options: opts},
    37  		{ImageName: "mysql:8", Options: opts},
    38  	}
    39  )
    40  
    41  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    42  	ip, port, err := c.Port(defaultPort)
    43  	if err != nil {
    44  		return false
    45  	}
    46  
    47  	db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
    48  	if err != nil {
    49  		return false
    50  	}
    51  	defer func() {
    52  		if err := db.Close(); err != nil {
    53  			log.Println("close error:", err)
    54  		}
    55  	}()
    56  	if err = db.PingContext(ctx); err != nil {
    57  		switch err {
    58  		case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
    59  			return false
    60  		default:
    61  			fmt.Println(err)
    62  		}
    63  		return false
    64  	}
    65  
    66  	return true
    67  }
    68  
    69  func Test(t *testing.T) {
    70  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
    71  
    72  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    73  		ip, port, err := c.Port(defaultPort)
    74  		if err != nil {
    75  			t.Fatal(err)
    76  		}
    77  
    78  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
    79  		p := &Mysql{}
    80  		d, err := p.Open(addr)
    81  		if err != nil {
    82  			t.Fatal(err)
    83  		}
    84  		defer func() {
    85  			if err := d.Close(); err != nil {
    86  				t.Error(err)
    87  			}
    88  		}()
    89  		dt.Test(t, d, []byte("SELECT 1"))
    90  
    91  		// check ensureVersionTable
    92  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
    93  			t.Fatal(err)
    94  		}
    95  		// check again
    96  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
    97  			t.Fatal(err)
    98  		}
    99  	})
   100  }
   101  
   102  func TestMigrate(t *testing.T) {
   103  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   104  
   105  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   106  		ip, port, err := c.Port(defaultPort)
   107  		if err != nil {
   108  			t.Fatal(err)
   109  		}
   110  
   111  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   112  		p := &Mysql{}
   113  		d, err := p.Open(addr)
   114  		if err != nil {
   115  			t.Fatal(err)
   116  		}
   117  		defer func() {
   118  			if err := d.Close(); err != nil {
   119  				t.Error(err)
   120  			}
   121  		}()
   122  
   123  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   124  		if err != nil {
   125  			t.Fatal(err)
   126  		}
   127  		dt.TestMigrate(t, m)
   128  
   129  		// check ensureVersionTable
   130  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   131  			t.Fatal(err)
   132  		}
   133  		// check again
   134  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   135  			t.Fatal(err)
   136  		}
   137  	})
   138  }
   139  
   140  func TestLockWorks(t *testing.T) {
   141  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   142  		ip, port, err := c.Port(defaultPort)
   143  		if err != nil {
   144  			t.Fatal(err)
   145  		}
   146  
   147  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   148  		p := &Mysql{}
   149  		d, err := p.Open(addr)
   150  		if err != nil {
   151  			t.Fatal(err)
   152  		}
   153  		dt.Test(t, d, []byte("SELECT 1"))
   154  
   155  		ms := d.(*Mysql)
   156  
   157  		err = ms.Lock()
   158  		if err != nil {
   159  			t.Fatal(err)
   160  		}
   161  		err = ms.Unlock()
   162  		if err != nil {
   163  			t.Fatal(err)
   164  		}
   165  
   166  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   167  		err = ms.Lock()
   168  		if err != nil {
   169  			t.Fatal(err)
   170  		}
   171  		err = ms.Unlock()
   172  		if err != nil {
   173  			t.Fatal(err)
   174  		}
   175  	})
   176  }
   177  
   178  func TestExtractCustomQueryParams(t *testing.T) {
   179  	testcases := []struct {
   180  		name                 string
   181  		config               *mysql.Config
   182  		expectedParams       map[string]string
   183  		expectedCustomParams map[string]string
   184  		expectedErr          error
   185  	}{
   186  		{name: "nil config", expectedErr: ErrNilConfig},
   187  		{
   188  			name:                 "no params",
   189  			config:               mysql.NewConfig(),
   190  			expectedCustomParams: map[string]string{},
   191  		},
   192  		{
   193  			name:                 "no custom params",
   194  			config:               &mysql.Config{Params: map[string]string{"hello": "world"}},
   195  			expectedParams:       map[string]string{"hello": "world"},
   196  			expectedCustomParams: map[string]string{},
   197  		},
   198  		{
   199  			name: "one param, one custom param",
   200  			config: &mysql.Config{
   201  				Params: map[string]string{"hello": "world", "x-foo": "bar"},
   202  			},
   203  			expectedParams:       map[string]string{"hello": "world"},
   204  			expectedCustomParams: map[string]string{"x-foo": "bar"},
   205  		},
   206  		{
   207  			name: "multiple params, multiple custom params",
   208  			config: &mysql.Config{
   209  				Params: map[string]string{
   210  					"hello": "world",
   211  					"x-foo": "bar",
   212  					"dead":  "beef",
   213  					"x-cat": "hat",
   214  				},
   215  			},
   216  			expectedParams:       map[string]string{"hello": "world", "dead": "beef"},
   217  			expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
   218  		},
   219  	}
   220  	for _, tc := range testcases {
   221  		t.Run(tc.name, func(t *testing.T) {
   222  			customParams, err := extractCustomQueryParams(tc.config)
   223  			if tc.config != nil {
   224  				assert.Equal(t, tc.expectedParams, tc.config.Params,
   225  					"Expected config params have custom params properly removed")
   226  			}
   227  			assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
   228  			assert.Equal(t, tc.expectedCustomParams, customParams,
   229  				"Expected custom params to be properly extracted")
   230  		})
   231  	}
   232  }
   233  
   234  func TestURLToMySQLConfig(t *testing.T) {
   235  	testcases := []struct {
   236  		name        string
   237  		urlStr      string
   238  		expectedDSN string // empty string signifies that an error is expected
   239  	}{
   240  		{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   241  			expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   242  		{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   243  			expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   244  		{name: "only user - with encoded :",
   245  			urlStr:      "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   246  			expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   247  		{name: "only user - with encoded @",
   248  			urlStr:      "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   249  			expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   250  		{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   251  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   252  		// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
   253  		// {name: "user/password - user with encoded :",
   254  		// 	urlStr:      "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   255  		// 	expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   256  		{name: "user/password - user with encoded @",
   257  			urlStr:      "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   258  			expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   259  		{name: "user/password - password with encoded :",
   260  			urlStr:      "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   261  			expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   262  		{name: "user/password - password with encoded @",
   263  			urlStr:      "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   264  			expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   265  	}
   266  	for _, tc := range testcases {
   267  		t.Run(tc.name, func(t *testing.T) {
   268  			config, err := urlToMySQLConfig(tc.urlStr)
   269  			if err != nil {
   270  				t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
   271  			}
   272  			dsn := config.FormatDSN()
   273  			if dsn != tc.expectedDSN {
   274  				t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
   275  			}
   276  		})
   277  	}
   278  }