github.com/nokia/migrate/v4@v4.16.0/database/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"crypto/ed25519"
     6  	"crypto/x509"
     7  	"database/sql"
     8  	sqldriver "database/sql/driver"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"log"
    14  	"math/big"
    15  	"math/rand"
    16  	"net/url"
    17  	"os"
    18  	"strconv"
    19  	"testing"
    20  
    21  	"github.com/dhui/dktest"
    22  	"github.com/go-sql-driver/mysql"
    23  	"github.com/nokia/migrate/v4"
    24  	"github.com/stretchr/testify/assert"
    25  
    26  	dt "github.com/nokia/migrate/v4/database/testing"
    27  	"github.com/nokia/migrate/v4/dktesting"
    28  
    29  	_ "github.com/nokia/migrate/v4/source/file"
    30  )
    31  
    32  const defaultPort = 3306
    33  
    34  var (
    35  	opts = dktest.Options{
    36  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    37  		PortRequired: true, ReadyFunc: isReady,
    38  	}
    39  	optsAnsiQuotes = dktest.Options{
    40  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    41  		PortRequired: true, ReadyFunc: isReady,
    42  		Cmd: []string{"--sql-mode=ANSI_QUOTES"},
    43  	}
    44  	// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
    45  	specs = []dktesting.ContainerSpec{
    46  		{ImageName: "mysql:5.5", Options: opts},
    47  		{ImageName: "mysql:5.6", Options: opts},
    48  		{ImageName: "mysql:5.7", Options: opts},
    49  		{ImageName: "mysql:8", Options: opts},
    50  	}
    51  	specsAnsiQuotes = []dktesting.ContainerSpec{
    52  		{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
    53  		{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
    54  		{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
    55  		{ImageName: "mysql:8", Options: optsAnsiQuotes},
    56  	}
    57  )
    58  
    59  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    60  	ip, port, err := c.Port(defaultPort)
    61  	if err != nil {
    62  		return false
    63  	}
    64  
    65  	db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
    66  	if err != nil {
    67  		return false
    68  	}
    69  	defer func() {
    70  		if err := db.Close(); err != nil {
    71  			log.Println("close error:", err)
    72  		}
    73  	}()
    74  	if err = db.PingContext(ctx); err != nil {
    75  		switch err {
    76  		case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
    77  			return false
    78  		default:
    79  			fmt.Println(err)
    80  		}
    81  		return false
    82  	}
    83  
    84  	return true
    85  }
    86  
    87  func Test(t *testing.T) {
    88  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
    89  
    90  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    91  		ip, port, err := c.Port(defaultPort)
    92  		if err != nil {
    93  			t.Fatal(err)
    94  		}
    95  
    96  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
    97  		p := &Mysql{}
    98  		d, err := p.Open(addr)
    99  		if err != nil {
   100  			t.Fatal(err)
   101  		}
   102  		defer func() {
   103  			if err := d.Close(); err != nil {
   104  				t.Error(err)
   105  			}
   106  		}()
   107  		dt.Test(t, d, []byte("SELECT 1"))
   108  
   109  		// check ensureVersionTable
   110  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   111  			t.Fatal(err)
   112  		}
   113  		// check again
   114  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   115  			t.Fatal(err)
   116  		}
   117  	})
   118  }
   119  
   120  func TestMigrate(t *testing.T) {
   121  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   122  
   123  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   124  		ip, port, err := c.Port(defaultPort)
   125  		if err != nil {
   126  			t.Fatal(err)
   127  		}
   128  
   129  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   130  		p := &Mysql{}
   131  		d, err := p.Open(addr)
   132  		if err != nil {
   133  			t.Fatal(err)
   134  		}
   135  		defer func() {
   136  			if err := d.Close(); err != nil {
   137  				t.Error(err)
   138  			}
   139  		}()
   140  
   141  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   142  		if err != nil {
   143  			t.Fatal(err)
   144  		}
   145  		dt.TestMigrate(t, m)
   146  
   147  		// check ensureVersionTable
   148  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   149  			t.Fatal(err)
   150  		}
   151  		// check again
   152  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   153  			t.Fatal(err)
   154  		}
   155  	})
   156  }
   157  
   158  func TestMigrateAnsiQuotes(t *testing.T) {
   159  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   160  
   161  	dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
   162  		ip, port, err := c.Port(defaultPort)
   163  		if err != nil {
   164  			t.Fatal(err)
   165  		}
   166  
   167  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   168  		p := &Mysql{}
   169  		d, err := p.Open(addr)
   170  		if err != nil {
   171  			t.Fatal(err)
   172  		}
   173  		defer func() {
   174  			if err := d.Close(); err != nil {
   175  				t.Error(err)
   176  			}
   177  		}()
   178  
   179  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   180  		if err != nil {
   181  			t.Fatal(err)
   182  		}
   183  		dt.TestMigrate(t, m)
   184  
   185  		// check ensureVersionTable
   186  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   187  			t.Fatal(err)
   188  		}
   189  		// check again
   190  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   191  			t.Fatal(err)
   192  		}
   193  	})
   194  }
   195  
   196  func TestLockWorks(t *testing.T) {
   197  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   198  		ip, port, err := c.Port(defaultPort)
   199  		if err != nil {
   200  			t.Fatal(err)
   201  		}
   202  
   203  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   204  		p := &Mysql{}
   205  		d, err := p.Open(addr)
   206  		if err != nil {
   207  			t.Fatal(err)
   208  		}
   209  		dt.Test(t, d, []byte("SELECT 1"))
   210  
   211  		ms := d.(*Mysql)
   212  
   213  		err = ms.Lock()
   214  		if err != nil {
   215  			t.Fatal(err)
   216  		}
   217  		err = ms.Unlock()
   218  		if err != nil {
   219  			t.Fatal(err)
   220  		}
   221  
   222  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   223  		err = ms.Lock()
   224  		if err != nil {
   225  			t.Fatal(err)
   226  		}
   227  		err = ms.Unlock()
   228  		if err != nil {
   229  			t.Fatal(err)
   230  		}
   231  	})
   232  }
   233  
   234  func TestNoLockParamValidation(t *testing.T) {
   235  	ip := "127.0.0.1"
   236  	port := 3306
   237  	addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   238  	p := &Mysql{}
   239  	_, err := p.Open(addr + "?x-no-lock=not-a-bool")
   240  	if !errors.Is(err, strconv.ErrSyntax) {
   241  		t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
   242  	}
   243  }
   244  
   245  func TestNoLockWorks(t *testing.T) {
   246  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   247  		ip, port, err := c.Port(defaultPort)
   248  		if err != nil {
   249  			t.Fatal(err)
   250  		}
   251  
   252  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   253  		p := &Mysql{}
   254  		d, err := p.Open(addr)
   255  		if err != nil {
   256  			t.Fatal(err)
   257  		}
   258  
   259  		lock := d.(*Mysql)
   260  
   261  		p = &Mysql{}
   262  		d, err = p.Open(addr + "?x-no-lock=true")
   263  		if err != nil {
   264  			t.Fatal(err)
   265  		}
   266  
   267  		noLock := d.(*Mysql)
   268  
   269  		// Should be possible to take real lock and no-lock at the same time
   270  		if err = lock.Lock(); err != nil {
   271  			t.Fatal(err)
   272  		}
   273  		if err = noLock.Lock(); err != nil {
   274  			t.Fatal(err)
   275  		}
   276  		if err = lock.Unlock(); err != nil {
   277  			t.Fatal(err)
   278  		}
   279  		if err = noLock.Unlock(); err != nil {
   280  			t.Fatal(err)
   281  		}
   282  	})
   283  }
   284  
   285  func TestExtractCustomQueryParams(t *testing.T) {
   286  	testcases := []struct {
   287  		name                 string
   288  		config               *mysql.Config
   289  		expectedParams       map[string]string
   290  		expectedCustomParams map[string]string
   291  		expectedErr          error
   292  	}{
   293  		{name: "nil config", expectedErr: ErrNilConfig},
   294  		{
   295  			name:                 "no params",
   296  			config:               mysql.NewConfig(),
   297  			expectedCustomParams: map[string]string{},
   298  		},
   299  		{
   300  			name:                 "no custom params",
   301  			config:               &mysql.Config{Params: map[string]string{"hello": "world"}},
   302  			expectedParams:       map[string]string{"hello": "world"},
   303  			expectedCustomParams: map[string]string{},
   304  		},
   305  		{
   306  			name: "one param, one custom param",
   307  			config: &mysql.Config{
   308  				Params: map[string]string{"hello": "world", "x-foo": "bar"},
   309  			},
   310  			expectedParams:       map[string]string{"hello": "world"},
   311  			expectedCustomParams: map[string]string{"x-foo": "bar"},
   312  		},
   313  		{
   314  			name: "multiple params, multiple custom params",
   315  			config: &mysql.Config{
   316  				Params: map[string]string{
   317  					"hello": "world",
   318  					"x-foo": "bar",
   319  					"dead":  "beef",
   320  					"x-cat": "hat",
   321  				},
   322  			},
   323  			expectedParams:       map[string]string{"hello": "world", "dead": "beef"},
   324  			expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
   325  		},
   326  	}
   327  	for _, tc := range testcases {
   328  		t.Run(tc.name, func(t *testing.T) {
   329  			customParams, err := extractCustomQueryParams(tc.config)
   330  			if tc.config != nil {
   331  				assert.Equal(t, tc.expectedParams, tc.config.Params,
   332  					"Expected config params have custom params properly removed")
   333  			}
   334  			assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
   335  			assert.Equal(t, tc.expectedCustomParams, customParams,
   336  				"Expected custom params to be properly extracted")
   337  		})
   338  	}
   339  }
   340  
   341  func createTmpCert(t *testing.T) string {
   342  	tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert")
   343  	if err != nil {
   344  		t.Fatal("Failed to create temp cert file:", err)
   345  	}
   346  	t.Cleanup(func() {
   347  		if err := os.Remove(tmpCertFile.Name()); err != nil {
   348  			t.Log("Failed to cleanup temp cert file:", err)
   349  		}
   350  	})
   351  
   352  	r := rand.New(rand.NewSource(0))
   353  	pub, priv, err := ed25519.GenerateKey(r)
   354  	if err != nil {
   355  		t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
   356  	}
   357  	tmpl := x509.Certificate{
   358  		SerialNumber: big.NewInt(0),
   359  	}
   360  	derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
   361  	if err != nil {
   362  		t.Fatal("Failed to generate temp cert file:", err)
   363  	}
   364  	if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
   365  		t.Fatal("Failed to encode ")
   366  	}
   367  	if err := tmpCertFile.Close(); err != nil {
   368  		t.Fatal("Failed to close temp cert file:", err)
   369  	}
   370  	return tmpCertFile.Name()
   371  }
   372  
   373  func TestURLToMySQLConfig(t *testing.T) {
   374  	tmpCertFilename := createTmpCert(t)
   375  	tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
   376  
   377  	testcases := []struct {
   378  		name        string
   379  		urlStr      string
   380  		expectedDSN string // empty string signifies that an error is expected
   381  	}{
   382  		{
   383  			name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   384  			expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   385  		},
   386  		{
   387  			name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   388  			expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   389  		},
   390  		{
   391  			name:        "only user - with encoded :",
   392  			urlStr:      "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   393  			expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   394  		},
   395  		{
   396  			name:        "only user - with encoded @",
   397  			urlStr:      "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   398  			expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   399  		},
   400  		{
   401  			name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   402  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   403  		},
   404  		// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
   405  		// {name: "user/password - user with encoded :",
   406  		// 	urlStr:      "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   407  		// 	expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   408  		{
   409  			name:        "user/password - user with encoded @",
   410  			urlStr:      "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   411  			expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   412  		},
   413  		{
   414  			name:        "user/password - password with encoded :",
   415  			urlStr:      "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   416  			expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   417  		},
   418  		{
   419  			name:        "user/password - password with encoded @",
   420  			urlStr:      "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   421  			expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   422  		},
   423  		{
   424  			name:        "custom tls",
   425  			urlStr:      "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
   426  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
   427  		},
   428  	}
   429  	for _, tc := range testcases {
   430  		t.Run(tc.name, func(t *testing.T) {
   431  			config, err := urlToMySQLConfig(tc.urlStr)
   432  			if err != nil {
   433  				t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
   434  			}
   435  			dsn := config.FormatDSN()
   436  			if dsn != tc.expectedDSN {
   437  				t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
   438  			}
   439  		})
   440  	}
   441  }