github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"crypto/ed25519"
     6  	"crypto/x509"
     7  	"database/sql"
     8  	sqldriver "database/sql/driver"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"log"
    14  	"math/big"
    15  	"math/rand"
    16  	"net/url"
    17  	"os"
    18  	"strconv"
    19  	"testing"
    20  
    21  	"github.com/dhui/dktest"
    22  	"github.com/go-sql-driver/mysql"
    23  	"github.com/stretchr/testify/assert"
    24  
    25  	migrate "github.com/seashell-org/golang-migrate/v4"
    26  
    27  	dt "github.com/seashell-org/golang-migrate/v4/database/testing"
    28  	"github.com/seashell-org/golang-migrate/v4/dktesting"
    29  
    30  	_ "github.com/seashell-org/golang-migrate/v4/source/file"
    31  )
    32  
    33  const defaultPort = 3306
    34  
    35  var (
    36  	opts = dktest.Options{
    37  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    38  		PortRequired: true, ReadyFunc: isReady,
    39  	}
    40  	optsAnsiQuotes = dktest.Options{
    41  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    42  		PortRequired: true, ReadyFunc: isReady,
    43  		Cmd: []string{"--sql-mode=ANSI_QUOTES"},
    44  	}
    45  	// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
    46  	specs = []dktesting.ContainerSpec{
    47  		{ImageName: "mysql:5.5", Options: opts},
    48  		{ImageName: "mysql:5.6", Options: opts},
    49  		{ImageName: "mysql:5.7", Options: opts},
    50  		{ImageName: "mysql:8", Options: opts},
    51  	}
    52  	specsAnsiQuotes = []dktesting.ContainerSpec{
    53  		{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
    54  		{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
    55  		{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
    56  		{ImageName: "mysql:8", Options: optsAnsiQuotes},
    57  	}
    58  )
    59  
    60  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    61  	ip, port, err := c.Port(defaultPort)
    62  	if err != nil {
    63  		return false
    64  	}
    65  
    66  	db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
    67  	if err != nil {
    68  		return false
    69  	}
    70  	defer func() {
    71  		if err := db.Close(); err != nil {
    72  			log.Println("close error:", err)
    73  		}
    74  	}()
    75  	if err = db.PingContext(ctx); err != nil {
    76  		switch err {
    77  		case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
    78  			return false
    79  		default:
    80  			fmt.Println(err)
    81  		}
    82  		return false
    83  	}
    84  
    85  	return true
    86  }
    87  
    88  func Test(t *testing.T) {
    89  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
    90  
    91  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    92  		ip, port, err := c.Port(defaultPort)
    93  		if err != nil {
    94  			t.Fatal(err)
    95  		}
    96  
    97  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
    98  		p := &Mysql{}
    99  		d, err := p.Open(addr)
   100  		if err != nil {
   101  			t.Fatal(err)
   102  		}
   103  		defer func() {
   104  			if err := d.Close(); err != nil {
   105  				t.Error(err)
   106  			}
   107  		}()
   108  		dt.Test(t, d, []byte("SELECT 1"))
   109  
   110  		// check ensureVersionTable
   111  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   112  			t.Fatal(err)
   113  		}
   114  		// check again
   115  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   116  			t.Fatal(err)
   117  		}
   118  	})
   119  }
   120  
   121  func TestMigrate(t *testing.T) {
   122  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   123  
   124  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   125  		ip, port, err := c.Port(defaultPort)
   126  		if err != nil {
   127  			t.Fatal(err)
   128  		}
   129  
   130  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   131  		p := &Mysql{}
   132  		d, err := p.Open(addr)
   133  		if err != nil {
   134  			t.Fatal(err)
   135  		}
   136  		defer func() {
   137  			if err := d.Close(); err != nil {
   138  				t.Error(err)
   139  			}
   140  		}()
   141  
   142  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   143  		if err != nil {
   144  			t.Fatal(err)
   145  		}
   146  		dt.TestMigrate(t, m)
   147  
   148  		// check ensureVersionTable
   149  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   150  			t.Fatal(err)
   151  		}
   152  		// check again
   153  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   154  			t.Fatal(err)
   155  		}
   156  	})
   157  }
   158  
   159  func TestMigrateAnsiQuotes(t *testing.T) {
   160  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   161  
   162  	dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
   163  		ip, port, err := c.Port(defaultPort)
   164  		if err != nil {
   165  			t.Fatal(err)
   166  		}
   167  
   168  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   169  		p := &Mysql{}
   170  		d, err := p.Open(addr)
   171  		if err != nil {
   172  			t.Fatal(err)
   173  		}
   174  		defer func() {
   175  			if err := d.Close(); err != nil {
   176  				t.Error(err)
   177  			}
   178  		}()
   179  
   180  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   181  		if err != nil {
   182  			t.Fatal(err)
   183  		}
   184  		dt.TestMigrate(t, m)
   185  
   186  		// check ensureVersionTable
   187  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   188  			t.Fatal(err)
   189  		}
   190  		// check again
   191  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   192  			t.Fatal(err)
   193  		}
   194  	})
   195  }
   196  
   197  func TestLockWorks(t *testing.T) {
   198  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   199  		ip, port, err := c.Port(defaultPort)
   200  		if err != nil {
   201  			t.Fatal(err)
   202  		}
   203  
   204  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   205  		p := &Mysql{}
   206  		d, err := p.Open(addr)
   207  		if err != nil {
   208  			t.Fatal(err)
   209  		}
   210  		dt.Test(t, d, []byte("SELECT 1"))
   211  
   212  		ms := d.(*Mysql)
   213  
   214  		err = ms.Lock()
   215  		if err != nil {
   216  			t.Fatal(err)
   217  		}
   218  		err = ms.Unlock()
   219  		if err != nil {
   220  			t.Fatal(err)
   221  		}
   222  
   223  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   224  		err = ms.Lock()
   225  		if err != nil {
   226  			t.Fatal(err)
   227  		}
   228  		err = ms.Unlock()
   229  		if err != nil {
   230  			t.Fatal(err)
   231  		}
   232  	})
   233  }
   234  
   235  func TestNoLockParamValidation(t *testing.T) {
   236  	ip := "127.0.0.1"
   237  	port := 3306
   238  	addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   239  	p := &Mysql{}
   240  	_, err := p.Open(addr + "?x-no-lock=not-a-bool")
   241  	if !errors.Is(err, strconv.ErrSyntax) {
   242  		t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
   243  	}
   244  }
   245  
   246  func TestNoLockWorks(t *testing.T) {
   247  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   248  		ip, port, err := c.Port(defaultPort)
   249  		if err != nil {
   250  			t.Fatal(err)
   251  		}
   252  
   253  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   254  		p := &Mysql{}
   255  		d, err := p.Open(addr)
   256  		if err != nil {
   257  			t.Fatal(err)
   258  		}
   259  
   260  		lock := d.(*Mysql)
   261  
   262  		p = &Mysql{}
   263  		d, err = p.Open(addr + "?x-no-lock=true")
   264  		if err != nil {
   265  			t.Fatal(err)
   266  		}
   267  
   268  		noLock := d.(*Mysql)
   269  
   270  		// Should be possible to take real lock and no-lock at the same time
   271  		if err = lock.Lock(); err != nil {
   272  			t.Fatal(err)
   273  		}
   274  		if err = noLock.Lock(); err != nil {
   275  			t.Fatal(err)
   276  		}
   277  		if err = lock.Unlock(); err != nil {
   278  			t.Fatal(err)
   279  		}
   280  		if err = noLock.Unlock(); err != nil {
   281  			t.Fatal(err)
   282  		}
   283  	})
   284  }
   285  
   286  func TestExtractCustomQueryParams(t *testing.T) {
   287  	testcases := []struct {
   288  		name                 string
   289  		config               *mysql.Config
   290  		expectedParams       map[string]string
   291  		expectedCustomParams map[string]string
   292  		expectedErr          error
   293  	}{
   294  		{name: "nil config", expectedErr: ErrNilConfig},
   295  		{
   296  			name:                 "no params",
   297  			config:               mysql.NewConfig(),
   298  			expectedCustomParams: map[string]string{},
   299  		},
   300  		{
   301  			name:                 "no custom params",
   302  			config:               &mysql.Config{Params: map[string]string{"hello": "world"}},
   303  			expectedParams:       map[string]string{"hello": "world"},
   304  			expectedCustomParams: map[string]string{},
   305  		},
   306  		{
   307  			name: "one param, one custom param",
   308  			config: &mysql.Config{
   309  				Params: map[string]string{"hello": "world", "x-foo": "bar"},
   310  			},
   311  			expectedParams:       map[string]string{"hello": "world"},
   312  			expectedCustomParams: map[string]string{"x-foo": "bar"},
   313  		},
   314  		{
   315  			name: "multiple params, multiple custom params",
   316  			config: &mysql.Config{
   317  				Params: map[string]string{
   318  					"hello": "world",
   319  					"x-foo": "bar",
   320  					"dead":  "beef",
   321  					"x-cat": "hat",
   322  				},
   323  			},
   324  			expectedParams:       map[string]string{"hello": "world", "dead": "beef"},
   325  			expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
   326  		},
   327  	}
   328  	for _, tc := range testcases {
   329  		t.Run(tc.name, func(t *testing.T) {
   330  			customParams, err := extractCustomQueryParams(tc.config)
   331  			if tc.config != nil {
   332  				assert.Equal(t, tc.expectedParams, tc.config.Params,
   333  					"Expected config params have custom params properly removed")
   334  			}
   335  			assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
   336  			assert.Equal(t, tc.expectedCustomParams, customParams,
   337  				"Expected custom params to be properly extracted")
   338  		})
   339  	}
   340  }
   341  
   342  func createTmpCert(t *testing.T) string {
   343  	tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert")
   344  	if err != nil {
   345  		t.Fatal("Failed to create temp cert file:", err)
   346  	}
   347  	t.Cleanup(func() {
   348  		if err := os.Remove(tmpCertFile.Name()); err != nil {
   349  			t.Log("Failed to cleanup temp cert file:", err)
   350  		}
   351  	})
   352  
   353  	r := rand.New(rand.NewSource(0))
   354  	pub, priv, err := ed25519.GenerateKey(r)
   355  	if err != nil {
   356  		t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
   357  	}
   358  	tmpl := x509.Certificate{
   359  		SerialNumber: big.NewInt(0),
   360  	}
   361  	derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
   362  	if err != nil {
   363  		t.Fatal("Failed to generate temp cert file:", err)
   364  	}
   365  	if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
   366  		t.Fatal("Failed to encode ")
   367  	}
   368  	if err := tmpCertFile.Close(); err != nil {
   369  		t.Fatal("Failed to close temp cert file:", err)
   370  	}
   371  	return tmpCertFile.Name()
   372  }
   373  
   374  func TestURLToMySQLConfig(t *testing.T) {
   375  	tmpCertFilename := createTmpCert(t)
   376  	tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
   377  
   378  	testcases := []struct {
   379  		name        string
   380  		urlStr      string
   381  		expectedDSN string // empty string signifies that an error is expected
   382  	}{
   383  		{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   384  			expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   385  		{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   386  			expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   387  		{name: "only user - with encoded :",
   388  			urlStr:      "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   389  			expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   390  		{name: "only user - with encoded @",
   391  			urlStr:      "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   392  			expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   393  		{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   394  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   395  		// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
   396  		// {name: "user/password - user with encoded :",
   397  		// 	urlStr:      "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   398  		// 	expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   399  		{name: "user/password - user with encoded @",
   400  			urlStr:      "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   401  			expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   402  		{name: "user/password - password with encoded :",
   403  			urlStr:      "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   404  			expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   405  		{name: "user/password - password with encoded @",
   406  			urlStr:      "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   407  			expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   408  		{name: "custom tls",
   409  			urlStr:      "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
   410  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
   411  	}
   412  	for _, tc := range testcases {
   413  		t.Run(tc.name, func(t *testing.T) {
   414  			config, err := urlToMySQLConfig(tc.urlStr)
   415  			if err != nil {
   416  				t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
   417  			}
   418  			dsn := config.FormatDSN()
   419  			if dsn != tc.expectedDSN {
   420  				t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
   421  			}
   422  		})
   423  	}
   424  }