github.com/dynastymasra/migrate/v4@v4.11.0/database/mysql/mysql_test.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "database/sql" 6 sqldriver "database/sql/driver" 7 "fmt" 8 "log" 9 "testing" 10 ) 11 12 import ( 13 "github.com/dhui/dktest" 14 "github.com/go-sql-driver/mysql" 15 "github.com/stretchr/testify/assert" 16 ) 17 18 import ( 19 "github.com/golang-migrate/migrate/v4" 20 dt "github.com/golang-migrate/migrate/v4/database/testing" 21 "github.com/golang-migrate/migrate/v4/dktesting" 22 _ "github.com/golang-migrate/migrate/v4/source/file" 23 ) 24 25 const defaultPort = 3306 26 27 var ( 28 opts = dktest.Options{ 29 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 30 PortRequired: true, ReadyFunc: isReady, 31 } 32 // Supported versions: https://www.mysql.com/support/supportedplatforms/database.html 33 specs = []dktesting.ContainerSpec{ 34 {ImageName: "mysql:5.5", Options: opts}, 35 {ImageName: "mysql:5.6", Options: opts}, 36 {ImageName: "mysql:5.7", Options: opts}, 37 {ImageName: "mysql:8", Options: opts}, 38 } 39 ) 40 41 func isReady(ctx context.Context, c dktest.ContainerInfo) bool { 42 ip, port, err := c.Port(defaultPort) 43 if err != nil { 44 return false 45 } 46 47 db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port)) 48 if err != nil { 49 return false 50 } 51 defer func() { 52 if err := db.Close(); err != nil { 53 log.Println("close error:", err) 54 } 55 }() 56 if err = db.PingContext(ctx); err != nil { 57 switch err { 58 case sqldriver.ErrBadConn, mysql.ErrInvalidConn: 59 return false 60 default: 61 fmt.Println(err) 62 } 63 return false 64 } 65 66 return true 67 } 68 69 func Test(t *testing.T) { 70 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 71 72 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 73 ip, port, err := c.Port(defaultPort) 74 if err != nil { 75 t.Fatal(err) 76 } 77 78 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 79 p := &Mysql{} 80 d, err := p.Open(addr) 81 if err != nil { 82 t.Fatal(err) 83 } 84 defer func() { 85 if err := d.Close(); err != nil { 86 t.Error(err) 87 } 88 }() 89 dt.Test(t, d, []byte("SELECT 1")) 90 91 // check ensureVersionTable 92 if err := d.(*Mysql).ensureVersionTable(); err != nil { 93 t.Fatal(err) 94 } 95 // check again 96 if err := d.(*Mysql).ensureVersionTable(); err != nil { 97 t.Fatal(err) 98 } 99 }) 100 } 101 102 func TestMigrate(t *testing.T) { 103 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 104 105 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 106 ip, port, err := c.Port(defaultPort) 107 if err != nil { 108 t.Fatal(err) 109 } 110 111 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 112 p := &Mysql{} 113 d, err := p.Open(addr) 114 if err != nil { 115 t.Fatal(err) 116 } 117 defer func() { 118 if err := d.Close(); err != nil { 119 t.Error(err) 120 } 121 }() 122 123 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 124 if err != nil { 125 t.Fatal(err) 126 } 127 dt.TestMigrate(t, m) 128 129 // check ensureVersionTable 130 if err := d.(*Mysql).ensureVersionTable(); err != nil { 131 t.Fatal(err) 132 } 133 // check again 134 if err := d.(*Mysql).ensureVersionTable(); err != nil { 135 t.Fatal(err) 136 } 137 }) 138 } 139 140 func TestLockWorks(t *testing.T) { 141 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 142 ip, port, err := c.Port(defaultPort) 143 if err != nil { 144 t.Fatal(err) 145 } 146 147 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 148 p := &Mysql{} 149 d, err := p.Open(addr) 150 if err != nil { 151 t.Fatal(err) 152 } 153 dt.Test(t, d, []byte("SELECT 1")) 154 155 ms := d.(*Mysql) 156 157 err = ms.Lock() 158 if err != nil { 159 t.Fatal(err) 160 } 161 err = ms.Unlock() 162 if err != nil { 163 t.Fatal(err) 164 } 165 166 // make sure the 2nd lock works (RELEASE_LOCK is very finicky) 167 err = ms.Lock() 168 if err != nil { 169 t.Fatal(err) 170 } 171 err = ms.Unlock() 172 if err != nil { 173 t.Fatal(err) 174 } 175 }) 176 } 177 178 func TestExtractCustomQueryParams(t *testing.T) { 179 testcases := []struct { 180 name string 181 config *mysql.Config 182 expectedParams map[string]string 183 expectedCustomParams map[string]string 184 expectedErr error 185 }{ 186 {name: "nil config", expectedErr: ErrNilConfig}, 187 { 188 name: "no params", 189 config: mysql.NewConfig(), 190 expectedCustomParams: map[string]string{}, 191 }, 192 { 193 name: "no custom params", 194 config: &mysql.Config{Params: map[string]string{"hello": "world"}}, 195 expectedParams: map[string]string{"hello": "world"}, 196 expectedCustomParams: map[string]string{}, 197 }, 198 { 199 name: "one param, one custom param", 200 config: &mysql.Config{ 201 Params: map[string]string{"hello": "world", "x-foo": "bar"}, 202 }, 203 expectedParams: map[string]string{"hello": "world"}, 204 expectedCustomParams: map[string]string{"x-foo": "bar"}, 205 }, 206 { 207 name: "multiple params, multiple custom params", 208 config: &mysql.Config{ 209 Params: map[string]string{ 210 "hello": "world", 211 "x-foo": "bar", 212 "dead": "beef", 213 "x-cat": "hat", 214 }, 215 }, 216 expectedParams: map[string]string{"hello": "world", "dead": "beef"}, 217 expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"}, 218 }, 219 } 220 for _, tc := range testcases { 221 t.Run(tc.name, func(t *testing.T) { 222 customParams, err := extractCustomQueryParams(tc.config) 223 if tc.config != nil { 224 assert.Equal(t, tc.expectedParams, tc.config.Params, 225 "Expected config params have custom params properly removed") 226 } 227 assert.Equal(t, tc.expectedErr, err, "Expected errors to match") 228 assert.Equal(t, tc.expectedCustomParams, customParams, 229 "Expected custom params to be properly extracted") 230 }) 231 } 232 } 233 234 func TestURLToMySQLConfig(t *testing.T) { 235 testcases := []struct { 236 name string 237 urlStr string 238 expectedDSN string // empty string signifies that an error is expected 239 }{ 240 {name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true", 241 expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 242 {name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 243 expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 244 {name: "only user - with encoded :", 245 urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 246 expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 247 {name: "only user - with encoded @", 248 urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 249 expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 250 {name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 251 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 252 // Not supported yet: https://github.com/go-sql-driver/mysql/issues/591 253 // {name: "user/password - user with encoded :", 254 // urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 255 // expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 256 {name: "user/password - user with encoded @", 257 urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 258 expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 259 {name: "user/password - password with encoded :", 260 urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 261 expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 262 {name: "user/password - password with encoded @", 263 urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 264 expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 265 } 266 for _, tc := range testcases { 267 t.Run(tc.name, func(t *testing.T) { 268 config, err := urlToMySQLConfig(tc.urlStr) 269 if err != nil { 270 t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) 271 } 272 dsn := config.FormatDSN() 273 if dsn != tc.expectedDSN { 274 t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) 275 } 276 }) 277 } 278 }