github.com/nagyist/migrate/v4@v4.14.6/database/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"crypto/ed25519"
     6  	"crypto/x509"
     7  	"database/sql"
     8  	sqldriver "database/sql/driver"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"log"
    14  	"math/big"
    15  	"math/rand"
    16  	"net/url"
    17  	"os"
    18  	"strconv"
    19  	"testing"
    20  )
    21  
    22  import (
    23  	"github.com/dhui/dktest"
    24  	"github.com/go-sql-driver/mysql"
    25  	"github.com/stretchr/testify/assert"
    26  )
    27  
    28  import (
    29  	"github.com/golang-migrate/migrate/v4"
    30  	dt "github.com/golang-migrate/migrate/v4/database/testing"
    31  	"github.com/golang-migrate/migrate/v4/dktesting"
    32  	_ "github.com/golang-migrate/migrate/v4/source/file"
    33  )
    34  
    35  const defaultPort = 3306
    36  
    37  var (
    38  	opts = dktest.Options{
    39  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    40  		PortRequired: true, ReadyFunc: isReady,
    41  	}
    42  	optsAnsiQuotes = dktest.Options{
    43  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    44  		PortRequired: true, ReadyFunc: isReady,
    45  		Cmd: []string{"--sql-mode=ANSI_QUOTES"},
    46  	}
    47  	// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
    48  	specs = []dktesting.ContainerSpec{
    49  		{ImageName: "mysql:5.5", Options: opts},
    50  		{ImageName: "mysql:5.6", Options: opts},
    51  		{ImageName: "mysql:5.7", Options: opts},
    52  		{ImageName: "mysql:8", Options: opts},
    53  	}
    54  	specsAnsiQuotes = []dktesting.ContainerSpec{
    55  		{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
    56  		{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
    57  		{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
    58  		{ImageName: "mysql:8", Options: optsAnsiQuotes},
    59  	}
    60  )
    61  
    62  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    63  	ip, port, err := c.Port(defaultPort)
    64  	if err != nil {
    65  		return false
    66  	}
    67  
    68  	db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
    69  	if err != nil {
    70  		return false
    71  	}
    72  	defer func() {
    73  		if err := db.Close(); err != nil {
    74  			log.Println("close error:", err)
    75  		}
    76  	}()
    77  	if err = db.PingContext(ctx); err != nil {
    78  		switch err {
    79  		case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
    80  			return false
    81  		default:
    82  			fmt.Println(err)
    83  		}
    84  		return false
    85  	}
    86  
    87  	return true
    88  }
    89  
    90  func Test(t *testing.T) {
    91  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
    92  
    93  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    94  		ip, port, err := c.Port(defaultPort)
    95  		if err != nil {
    96  			t.Fatal(err)
    97  		}
    98  
    99  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   100  		p := &Mysql{}
   101  		d, err := p.Open(addr)
   102  		if err != nil {
   103  			t.Fatal(err)
   104  		}
   105  		defer func() {
   106  			if err := d.Close(); err != nil {
   107  				t.Error(err)
   108  			}
   109  		}()
   110  		dt.Test(t, d, []byte("SELECT 1"))
   111  
   112  		// check ensureVersionTable
   113  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   114  			t.Fatal(err)
   115  		}
   116  		// check again
   117  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   118  			t.Fatal(err)
   119  		}
   120  	})
   121  }
   122  
   123  func TestMigrate(t *testing.T) {
   124  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   125  
   126  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   127  		ip, port, err := c.Port(defaultPort)
   128  		if err != nil {
   129  			t.Fatal(err)
   130  		}
   131  
   132  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   133  		p := &Mysql{}
   134  		d, err := p.Open(addr)
   135  		if err != nil {
   136  			t.Fatal(err)
   137  		}
   138  		defer func() {
   139  			if err := d.Close(); err != nil {
   140  				t.Error(err)
   141  			}
   142  		}()
   143  
   144  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   145  		if err != nil {
   146  			t.Fatal(err)
   147  		}
   148  		dt.TestMigrate(t, m)
   149  
   150  		// check ensureVersionTable
   151  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   152  			t.Fatal(err)
   153  		}
   154  		// check again
   155  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   156  			t.Fatal(err)
   157  		}
   158  	})
   159  }
   160  
   161  func TestMigrateAnsiQuotes(t *testing.T) {
   162  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   163  
   164  	dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
   165  		ip, port, err := c.Port(defaultPort)
   166  		if err != nil {
   167  			t.Fatal(err)
   168  		}
   169  
   170  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   171  		p := &Mysql{}
   172  		d, err := p.Open(addr)
   173  		if err != nil {
   174  			t.Fatal(err)
   175  		}
   176  		defer func() {
   177  			if err := d.Close(); err != nil {
   178  				t.Error(err)
   179  			}
   180  		}()
   181  
   182  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   183  		if err != nil {
   184  			t.Fatal(err)
   185  		}
   186  		dt.TestMigrate(t, m)
   187  
   188  		// check ensureVersionTable
   189  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   190  			t.Fatal(err)
   191  		}
   192  		// check again
   193  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   194  			t.Fatal(err)
   195  		}
   196  	})
   197  }
   198  
   199  func TestLockWorks(t *testing.T) {
   200  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   201  		ip, port, err := c.Port(defaultPort)
   202  		if err != nil {
   203  			t.Fatal(err)
   204  		}
   205  
   206  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   207  		p := &Mysql{}
   208  		d, err := p.Open(addr)
   209  		if err != nil {
   210  			t.Fatal(err)
   211  		}
   212  		dt.Test(t, d, []byte("SELECT 1"))
   213  
   214  		ms := d.(*Mysql)
   215  
   216  		err = ms.Lock()
   217  		if err != nil {
   218  			t.Fatal(err)
   219  		}
   220  		err = ms.Unlock()
   221  		if err != nil {
   222  			t.Fatal(err)
   223  		}
   224  
   225  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   226  		err = ms.Lock()
   227  		if err != nil {
   228  			t.Fatal(err)
   229  		}
   230  		err = ms.Unlock()
   231  		if err != nil {
   232  			t.Fatal(err)
   233  		}
   234  	})
   235  }
   236  
   237  func TestNoLockParamValidation(t *testing.T) {
   238  	ip := "127.0.0.1"
   239  	port := 3306
   240  	addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   241  	p := &Mysql{}
   242  	_, err := p.Open(addr + "?x-no-lock=not-a-bool")
   243  	if !errors.Is(err, strconv.ErrSyntax) {
   244  		t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
   245  	}
   246  }
   247  
   248  func TestNoLockWorks(t *testing.T) {
   249  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   250  		ip, port, err := c.Port(defaultPort)
   251  		if err != nil {
   252  			t.Fatal(err)
   253  		}
   254  
   255  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   256  		p := &Mysql{}
   257  		d, err := p.Open(addr)
   258  		if err != nil {
   259  			t.Fatal(err)
   260  		}
   261  
   262  		lock := d.(*Mysql)
   263  
   264  		p = &Mysql{}
   265  		d, err = p.Open(addr + "?x-no-lock=true")
   266  		if err != nil {
   267  			t.Fatal(err)
   268  		}
   269  
   270  		noLock := d.(*Mysql)
   271  
   272  		// Should be possible to take real lock and no-lock at the same time
   273  		if err = lock.Lock(); err != nil {
   274  			t.Fatal(err)
   275  		}
   276  		if err = noLock.Lock(); err != nil {
   277  			t.Fatal(err)
   278  		}
   279  		if err = lock.Unlock(); err != nil {
   280  			t.Fatal(err)
   281  		}
   282  		if err = noLock.Unlock(); err != nil {
   283  			t.Fatal(err)
   284  		}
   285  	})
   286  }
   287  
   288  func TestExtractCustomQueryParams(t *testing.T) {
   289  	testcases := []struct {
   290  		name                 string
   291  		config               *mysql.Config
   292  		expectedParams       map[string]string
   293  		expectedCustomParams map[string]string
   294  		expectedErr          error
   295  	}{
   296  		{name: "nil config", expectedErr: ErrNilConfig},
   297  		{
   298  			name:                 "no params",
   299  			config:               mysql.NewConfig(),
   300  			expectedCustomParams: map[string]string{},
   301  		},
   302  		{
   303  			name:                 "no custom params",
   304  			config:               &mysql.Config{Params: map[string]string{"hello": "world"}},
   305  			expectedParams:       map[string]string{"hello": "world"},
   306  			expectedCustomParams: map[string]string{},
   307  		},
   308  		{
   309  			name: "one param, one custom param",
   310  			config: &mysql.Config{
   311  				Params: map[string]string{"hello": "world", "x-foo": "bar"},
   312  			},
   313  			expectedParams:       map[string]string{"hello": "world"},
   314  			expectedCustomParams: map[string]string{"x-foo": "bar"},
   315  		},
   316  		{
   317  			name: "multiple params, multiple custom params",
   318  			config: &mysql.Config{
   319  				Params: map[string]string{
   320  					"hello": "world",
   321  					"x-foo": "bar",
   322  					"dead":  "beef",
   323  					"x-cat": "hat",
   324  				},
   325  			},
   326  			expectedParams:       map[string]string{"hello": "world", "dead": "beef"},
   327  			expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
   328  		},
   329  	}
   330  	for _, tc := range testcases {
   331  		t.Run(tc.name, func(t *testing.T) {
   332  			customParams, err := extractCustomQueryParams(tc.config)
   333  			if tc.config != nil {
   334  				assert.Equal(t, tc.expectedParams, tc.config.Params,
   335  					"Expected config params have custom params properly removed")
   336  			}
   337  			assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
   338  			assert.Equal(t, tc.expectedCustomParams, customParams,
   339  				"Expected custom params to be properly extracted")
   340  		})
   341  	}
   342  }
   343  
   344  func createTmpCert(t *testing.T) string {
   345  	tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert")
   346  	if err != nil {
   347  		t.Fatal("Failed to create temp cert file:", err)
   348  	}
   349  	t.Cleanup(func() {
   350  		if err := os.Remove(tmpCertFile.Name()); err != nil {
   351  			t.Log("Failed to cleanup temp cert file:", err)
   352  		}
   353  	})
   354  
   355  	r := rand.New(rand.NewSource(0))
   356  	pub, priv, err := ed25519.GenerateKey(r)
   357  	if err != nil {
   358  		t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
   359  	}
   360  	tmpl := x509.Certificate{
   361  		SerialNumber: big.NewInt(0),
   362  	}
   363  	derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
   364  	if err != nil {
   365  		t.Fatal("Failed to generate temp cert file:", err)
   366  	}
   367  	if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
   368  		t.Fatal("Failed to encode ")
   369  	}
   370  	if err := tmpCertFile.Close(); err != nil {
   371  		t.Fatal("Failed to close temp cert file:", err)
   372  	}
   373  	return tmpCertFile.Name()
   374  }
   375  
   376  func TestURLToMySQLConfig(t *testing.T) {
   377  	tmpCertFilename := createTmpCert(t)
   378  	tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
   379  
   380  	testcases := []struct {
   381  		name        string
   382  		urlStr      string
   383  		expectedDSN string // empty string signifies that an error is expected
   384  	}{
   385  		{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   386  			expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   387  		{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   388  			expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   389  		{name: "only user - with encoded :",
   390  			urlStr:      "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   391  			expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   392  		{name: "only user - with encoded @",
   393  			urlStr:      "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   394  			expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   395  		{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   396  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   397  		// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
   398  		// {name: "user/password - user with encoded :",
   399  		// 	urlStr:      "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   400  		// 	expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   401  		{name: "user/password - user with encoded @",
   402  			urlStr:      "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   403  			expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   404  		{name: "user/password - password with encoded :",
   405  			urlStr:      "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   406  			expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   407  		{name: "user/password - password with encoded @",
   408  			urlStr:      "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   409  			expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   410  		{name: "custom tls",
   411  			urlStr:      "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
   412  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
   413  	}
   414  	for _, tc := range testcases {
   415  		t.Run(tc.name, func(t *testing.T) {
   416  			config, err := urlToMySQLConfig(tc.urlStr)
   417  			if err != nil {
   418  				t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
   419  			}
   420  			dsn := config.FormatDSN()
   421  			if dsn != tc.expectedDSN {
   422  				t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
   423  			}
   424  		})
   425  	}
   426  }