github.com/fr-nvriep/migrate/v4@v4.3.2/database/mysql/mysql_test.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	sqldriver "database/sql/driver"
     7  	"fmt"
     8  	"log"
     9  
    10  	"github.com/fr-nvriep/migrate/v4"
    11  	"net/url"
    12  	"testing"
    13  )
    14  
    15  import (
    16  	"github.com/dhui/dktest"
    17  	"github.com/go-sql-driver/mysql"
    18  )
    19  
    20  import (
    21  	dt "github.com/fr-nvriep/migrate/v4/database/testing"
    22  	"github.com/fr-nvriep/migrate/v4/dktesting"
    23  	_ "github.com/fr-nvriep/migrate/v4/source/file"
    24  )
    25  
    26  const defaultPort = 3306
    27  
    28  var (
    29  	opts = dktest.Options{
    30  		Env:          map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
    31  		PortRequired: true, ReadyFunc: isReady,
    32  	}
    33  	// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
    34  	specs = []dktesting.ContainerSpec{
    35  		{ImageName: "mysql:5.5", Options: opts},
    36  		{ImageName: "mysql:5.6", Options: opts},
    37  		{ImageName: "mysql:5.7", Options: opts},
    38  		{ImageName: "mysql:8", Options: opts},
    39  	}
    40  )
    41  
    42  func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
    43  	ip, port, err := c.Port(defaultPort)
    44  	if err != nil {
    45  		return false
    46  	}
    47  
    48  	db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
    49  	if err != nil {
    50  		return false
    51  	}
    52  	defer func() {
    53  		if err := db.Close(); err != nil {
    54  			log.Println("close error:", err)
    55  		}
    56  	}()
    57  	if err = db.PingContext(ctx); err != nil {
    58  		switch err {
    59  		case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
    60  			return false
    61  		default:
    62  			fmt.Println(err)
    63  		}
    64  		return false
    65  	}
    66  
    67  	return true
    68  }
    69  
    70  func Test(t *testing.T) {
    71  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
    72  
    73  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
    74  		ip, port, err := c.Port(defaultPort)
    75  		if err != nil {
    76  			t.Fatal(err)
    77  		}
    78  
    79  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
    80  		p := &Mysql{}
    81  		d, err := p.Open(addr)
    82  		if err != nil {
    83  			t.Fatal(err)
    84  		}
    85  		defer func() {
    86  			if err := d.Close(); err != nil {
    87  				t.Error(err)
    88  			}
    89  		}()
    90  		dt.Test(t, d, []byte("SELECT 1"))
    91  
    92  		// check ensureVersionTable
    93  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
    94  			t.Fatal(err)
    95  		}
    96  		// check again
    97  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
    98  			t.Fatal(err)
    99  		}
   100  	})
   101  }
   102  
   103  func TestMigrate(t *testing.T) {
   104  	// mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime)))
   105  
   106  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   107  		ip, port, err := c.Port(defaultPort)
   108  		if err != nil {
   109  			t.Fatal(err)
   110  		}
   111  
   112  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   113  		p := &Mysql{}
   114  		d, err := p.Open(addr)
   115  		if err != nil {
   116  			t.Fatal(err)
   117  		}
   118  		defer func() {
   119  			if err := d.Close(); err != nil {
   120  				t.Error(err)
   121  			}
   122  		}()
   123  
   124  		m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
   125  		if err != nil {
   126  			t.Fatal(err)
   127  		}
   128  		dt.TestMigrate(t, m, []byte("SELECT 1"))
   129  
   130  		// check ensureVersionTable
   131  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   132  			t.Fatal(err)
   133  		}
   134  		// check again
   135  		if err := d.(*Mysql).ensureVersionTable(); err != nil {
   136  			t.Fatal(err)
   137  		}
   138  	})
   139  }
   140  
   141  func TestLockWorks(t *testing.T) {
   142  	dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
   143  		ip, port, err := c.Port(defaultPort)
   144  		if err != nil {
   145  			t.Fatal(err)
   146  		}
   147  
   148  		addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
   149  		p := &Mysql{}
   150  		d, err := p.Open(addr)
   151  		if err != nil {
   152  			t.Fatal(err)
   153  		}
   154  		dt.Test(t, d, []byte("SELECT 1"))
   155  
   156  		ms := d.(*Mysql)
   157  
   158  		err = ms.Lock()
   159  		if err != nil {
   160  			t.Fatal(err)
   161  		}
   162  		err = ms.Unlock()
   163  		if err != nil {
   164  			t.Fatal(err)
   165  		}
   166  
   167  		// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
   168  		err = ms.Lock()
   169  		if err != nil {
   170  			t.Fatal(err)
   171  		}
   172  		err = ms.Unlock()
   173  		if err != nil {
   174  			t.Fatal(err)
   175  		}
   176  	})
   177  }
   178  
   179  func TestURLToMySQLConfig(t *testing.T) {
   180  	testcases := []struct {
   181  		name        string
   182  		urlStr      string
   183  		expectedDSN string // empty string signifies that an error is expected
   184  	}{
   185  		{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   186  			expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   187  		{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   188  			expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   189  		{name: "only user - with encoded :",
   190  			urlStr:      "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   191  			expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   192  		{name: "only user - with encoded @",
   193  			urlStr:      "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   194  			expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   195  		{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   196  			expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   197  		// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
   198  		// {name: "user/password - user with encoded :",
   199  		// 	urlStr:      "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   200  		// 	expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   201  		{name: "user/password - user with encoded @",
   202  			urlStr:      "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   203  			expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   204  		{name: "user/password - password with encoded :",
   205  			urlStr:      "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   206  			expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   207  		{name: "user/password - password with encoded @",
   208  			urlStr:      "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
   209  			expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
   210  	}
   211  	for _, tc := range testcases {
   212  		t.Run(tc.name, func(t *testing.T) {
   213  			u, err := url.Parse(tc.urlStr)
   214  			if err != nil {
   215  				t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
   216  			}
   217  			if config, err := urlToMySQLConfig(*u); err == nil {
   218  				dsn := config.FormatDSN()
   219  				if dsn != tc.expectedDSN {
   220  					t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
   221  				}
   222  			} else {
   223  				if tc.expectedDSN != "" {
   224  					t.Error("Got unexpected error:", err, "urlStr:", tc.urlStr)
   225  				}
   226  			}
   227  		})
   228  	}
   229  }