github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"crypto/ed25519"
     6  	"crypto/x509"
     7  	"database/sql"
     8  	sqldriver "database/sql/driver"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"log"
    13  	"math/big"
    14  	"math/rand"
    15  	"net/url"
    16  	"os"
    17  	"strconv"
    18  	"testing"
    19  
    20  	"github.com/dhui/dktest"
    21  	"github.com/go-sql-driver/mysql"
    22  	"github.com/golang-migrate/migrate/v4"
    23  	dt "github.com/golang-migrate/migrate/v4/database/testing"
    24  	"github.com/golang-migrate/migrate/v4/dktesting"
    25  	_ "github.com/golang-migrate/migrate/v4/source/file"
    26  	"github.com/stretchr/testify/assert"
    27  )
    28  
    29  const defaultPort = 3306
    30  
    31  var (
    32  	opts = dktest.Options{
    33  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    34  		PortRequired: true, ReadyFunc: isReady,
    35  	}
    36  	optsAnsiQuotes = dktest.Options{
    37  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    38  		PortRequired: true, ReadyFunc: isReady,
    39  		Cmd: []string{"--sql-mode=ANSI_QUOTES"},
    40  	}
    41  	// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
    42  	specs = []dktesting.ContainerSpec{
    43  		{ImageName: "mysql:5.5", Options: opts},
    44  		{ImageName: "mysql:5.6", Options: opts},
    45  		{ImageName: "mysql:5.7", Options: opts},
    46  		{ImageName: "mysql:8", Options: opts},
    47  	}
    48  	specsAnsiQuotes = []dktesting.ContainerSpec{
    49  		{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
    50  		{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
    51  		{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
    52  		{ImageName: "mysql:8", Options: optsAnsiQuotes},
    53  	}
    54  )
    55  
    56  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    57  	ip, port, err := c.Port(defaultPort)
    58  	if err != nil {
    59  		return false
    60  	}
    61  
    62  	db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
    63  	if err != nil {
    64  		return false
    65  	}
    66  	defer func() {
    67  		if err := db.Close(); err != nil {
    68  			log.Println("close error:", err)
    69  		}
    70  	}()
    71  	if err = db.PingContext(ctx); err != nil {
    72  		switch err {
    73  		case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
    74  			return false
    75  		default:
    76  			fmt.Println(err)
    77  		}
    78  		return false
    79  	}
    80  
    81  	return true
    82  }
    83  
    84  func Test(t *testing.T) {
    85  	// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
    86  
    87  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    88  		ip, port, err := c.Port(defaultPort)
    89  		if err != nil {
    90  			t.Fatal(err)
    91  		}
    92  
    93  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
    94  		p := &Mysql{}
    95  		d, err := p.Open(addr)
    96  		if err != nil {
    97  			t.Fatal(err)
    98  		}
    99  		defer func() {
   100  			if err := d.Close(); err != nil {
   101  				t.Error(err)
   102  			}
   103  		}()
   104  		dt.Test(t, d, []byte("SELECT 1"))
   105  
   106  		// check ensureVersionTable
   107  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   108  			t.Fatal(err)
   109  		}
   110  		// check again
   111  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   112  			t.Fatal(err)
   113  		}
   114  	})
   115  }
   116  
   117  func TestMigrate(t *testing.T) {
   118  	// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
   119  
   120  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   121  		ip, port, err := c.Port(defaultPort)
   122  		if err != nil {
   123  			t.Fatal(err)
   124  		}
   125  
   126  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   127  		p := &Mysql{}
   128  		d, err := p.Open(addr)
   129  		if err != nil {
   130  			t.Fatal(err)
   131  		}
   132  		defer func() {
   133  			if err := d.Close(); err != nil {
   134  				t.Error(err)
   135  			}
   136  		}()
   137  
   138  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   139  		if err != nil {
   140  			t.Fatal(err)
   141  		}
   142  		dt.TestMigrate(t, m)
   143  
   144  		// check ensureVersionTable
   145  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   146  			t.Fatal(err)
   147  		}
   148  		// check again
   149  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   150  			t.Fatal(err)
   151  		}
   152  	})
   153  }
   154  
   155  func TestMigrateAnsiQuotes(t *testing.T) {
   156  	// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
   157  
   158  	dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
   159  		ip, port, err := c.Port(defaultPort)
   160  		if err != nil {
   161  			t.Fatal(err)
   162  		}
   163  
   164  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   165  		p := &Mysql{}
   166  		d, err := p.Open(addr)
   167  		if err != nil {
   168  			t.Fatal(err)
   169  		}
   170  		defer func() {
   171  			if err := d.Close(); err != nil {
   172  				t.Error(err)
   173  			}
   174  		}()
   175  
   176  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   177  		if err != nil {
   178  			t.Fatal(err)
   179  		}
   180  		dt.TestMigrate(t, m)
   181  
   182  		// check ensureVersionTable
   183  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   184  			t.Fatal(err)
   185  		}
   186  		// check again
   187  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   188  			t.Fatal(err)
   189  		}
   190  	})
   191  }
   192  
   193  func TestLockWorks(t *testing.T) {
   194  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   195  		ip, port, err := c.Port(defaultPort)
   196  		if err != nil {
   197  			t.Fatal(err)
   198  		}
   199  
   200  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   201  		p := &Mysql{}
   202  		d, err := p.Open(addr)
   203  		if err != nil {
   204  			t.Fatal(err)
   205  		}
   206  		dt.Test(t, d, []byte("SELECT 1"))
   207  
   208  		ms := d.(*Mysql)
   209  
   210  		err = ms.Lock()
   211  		if err != nil {
   212  			t.Fatal(err)
   213  		}
   214  		err = ms.Unlock()
   215  		if err != nil {
   216  			t.Fatal(err)
   217  		}
   218  
   219  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   220  		err = ms.Lock()
   221  		if err != nil {
   222  			t.Fatal(err)
   223  		}
   224  		err = ms.Unlock()
   225  		if err != nil {
   226  			t.Fatal(err)
   227  		}
   228  	})
   229  }
   230  
   231  func TestNoLockParamValidation(t *testing.T) {
   232  	ip := "127.0.0.1"
   233  	port := 3306
   234  	addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   235  	p := &Mysql{}
   236  	_, err := p.Open(addr + "?x-no-lock=not-a-bool")
   237  	if !errors.Is(err, strconv.ErrSyntax) {
   238  		t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
   239  	}
   240  }
   241  
   242  func TestNoLockWorks(t *testing.T) {
   243  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   244  		ip, port, err := c.Port(defaultPort)
   245  		if err != nil {
   246  			t.Fatal(err)
   247  		}
   248  
   249  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   250  		p := &Mysql{}
   251  		d, err := p.Open(addr)
   252  		if err != nil {
   253  			t.Fatal(err)
   254  		}
   255  
   256  		lock := d.(*Mysql)
   257  
   258  		p = &Mysql{}
   259  		d, err = p.Open(addr + "?x-no-lock=true")
   260  		if err != nil {
   261  			t.Fatal(err)
   262  		}
   263  
   264  		noLock := d.(*Mysql)
   265  
   266  		// Should be possible to take real lock and no-lock at the same time
   267  		if err = lock.Lock(); err != nil {
   268  			t.Fatal(err)
   269  		}
   270  		if err = noLock.Lock(); err != nil {
   271  			t.Fatal(err)
   272  		}
   273  		if err = lock.Unlock(); err != nil {
   274  			t.Fatal(err)
   275  		}
   276  		if err = noLock.Unlock(); err != nil {
   277  			t.Fatal(err)
   278  		}
   279  	})
   280  }
   281  
   282  func TestExtractCustomQueryParams(t *testing.T) {
   283  	testcases := []struct {
   284  		name                 string
   285  		config               *mysql.Config
   286  		expectedParams       map[string]string
   287  		expectedCustomParams map[string]string
   288  		expectedErr          error
   289  	}{
   290  		{name: "nil config", expectedErr: ErrNilConfig},
   291  		{
   292  			name:                 "no params",
   293  			config:               mysql.NewConfig(),
   294  			expectedCustomParams: map[string]string{},
   295  		},
   296  		{
   297  			name:                 "no custom params",
   298  			config:               &mysql.Config{Params: map[string]string{"hello": "world"}},
   299  			expectedParams:       map[string]string{"hello": "world"},
   300  			expectedCustomParams: map[string]string{},
   301  		},
   302  		{
   303  			name: "one param, one custom param",
   304  			config: &mysql.Config{
   305  				Params: map[string]string{"hello": "world", "x-foo": "bar"},
   306  			},
   307  			expectedParams:       map[string]string{"hello": "world"},
   308  			expectedCustomParams: map[string]string{"x-foo": "bar"},
   309  		},
   310  		{
   311  			name: "multiple params, multiple custom params",
   312  			config: &mysql.Config{
   313  				Params: map[string]string{
   314  					"hello": "world",
   315  					"x-foo": "bar",
   316  					"dead":  "beef",
   317  					"x-cat": "hat",
   318  				},
   319  			},
   320  			expectedParams:       map[string]string{"hello": "world", "dead": "beef"},
   321  			expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
   322  		},
   323  	}
   324  	for _, tc := range testcases {
   325  		t.Run(tc.name, func(t *testing.T) {
   326  			customParams, err := extractCustomQueryParams(tc.config)
   327  			if tc.config != nil {
   328  				assert.Equal(t, tc.expectedParams, tc.config.Params,
   329  					"Expected config params have custom params properly removed")
   330  			}
   331  			assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
   332  			assert.Equal(t, tc.expectedCustomParams, customParams,
   333  				"Expected custom params to be properly extracted")
   334  		})
   335  	}
   336  }
   337  
   338  func createTmpCert(t *testing.T) string {
   339  	tmpCertFile, err := os.CreateTemp("", "migrate_test_cert")
   340  	if err != nil {
   341  		t.Fatal("Failed to create temp cert file:", err)
   342  	}
   343  	t.Cleanup(func() {
   344  		if err := os.Remove(tmpCertFile.Name()); err != nil {
   345  			t.Log("Failed to cleanup temp cert file:", err)
   346  		}
   347  	})
   348  
   349  	r := rand.New(rand.NewSource(0))
   350  	pub, priv, err := ed25519.GenerateKey(r)
   351  	if err != nil {
   352  		t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
   353  	}
   354  	tmpl := x509.Certificate{
   355  		SerialNumber: big.NewInt(0),
   356  	}
   357  	derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
   358  	if err != nil {
   359  		t.Fatal("Failed to generate temp cert file:", err)
   360  	}
   361  	if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
   362  		t.Fatal("Failed to encode ")
   363  	}
   364  	if err := tmpCertFile.Close(); err != nil {
   365  		t.Fatal("Failed to close temp cert file:", err)
   366  	}
   367  	return tmpCertFile.Name()
   368  }
   369  
   370  func TestURLToMySQLConfig(t *testing.T) {
   371  	tmpCertFilename := createTmpCert(t)
   372  	tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
   373  
   374  	testcases := []struct {
   375  		name        string
   376  		urlStr      string
   377  		expectedDSN string // empty string signifies that an error is expected
   378  	}{
   379  		{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   380  			expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   381  		{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   382  			expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   383  		{name: "only user - with encoded :",
   384  			urlStr:      "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   385  			expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   386  		{name: "only user - with encoded @",
   387  			urlStr:      "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   388  			expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   389  		{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   390  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   391  		// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
   392  		// {name: "user/password - user with encoded :",
   393  		// 	urlStr:      "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   394  		// 	expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   395  		{name: "user/password - user with encoded @",
   396  			urlStr:      "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   397  			expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   398  		{name: "user/password - password with encoded :",
   399  			urlStr:      "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   400  			expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   401  		{name: "user/password - password with encoded @",
   402  			urlStr:      "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   403  			expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   404  		{name: "custom tls",
   405  			urlStr:      "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
   406  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
   407  	}
   408  	for _, tc := range testcases {
   409  		t.Run(tc.name, func(t *testing.T) {
   410  			config, err := urlToMySQLConfig(tc.urlStr)
   411  			if err != nil {
   412  				t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
   413  			}
   414  			dsn := config.FormatDSN()
   415  			if dsn != tc.expectedDSN {
   416  				t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
   417  			}
   418  		})
   419  	}
   420  }