github.com/nagyist/migrate/v4@v4.14.6/database/mysql/mysql_test.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "crypto/ed25519" 6 "crypto/x509" 7 "database/sql" 8 sqldriver "database/sql/driver" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "io/ioutil" 13 "log" 14 "math/big" 15 "math/rand" 16 "net/url" 17 "os" 18 "strconv" 19 "testing" 20 ) 21 22 import ( 23 "github.com/dhui/dktest" 24 "github.com/go-sql-driver/mysql" 25 "github.com/stretchr/testify/assert" 26 ) 27 28 import ( 29 "github.com/golang-migrate/migrate/v4" 30 dt "github.com/golang-migrate/migrate/v4/database/testing" 31 "github.com/golang-migrate/migrate/v4/dktesting" 32 _ "github.com/golang-migrate/migrate/v4/source/file" 33 ) 34 35 const defaultPort = 3306 36 37 var ( 38 opts = dktest.Options{ 39 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 40 PortRequired: true, ReadyFunc: isReady, 41 } 42 optsAnsiQuotes = dktest.Options{ 43 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 44 PortRequired: true, ReadyFunc: isReady, 45 Cmd: []string{"--sql-mode=ANSI_QUOTES"}, 46 } 47 // Supported versions: https://www.mysql.com/support/supportedplatforms/database.html 48 specs = []dktesting.ContainerSpec{ 49 {ImageName: "mysql:5.5", Options: opts}, 50 {ImageName: "mysql:5.6", Options: opts}, 51 {ImageName: "mysql:5.7", Options: opts}, 52 {ImageName: "mysql:8", Options: opts}, 53 } 54 specsAnsiQuotes = []dktesting.ContainerSpec{ 55 {ImageName: "mysql:5.5", Options: optsAnsiQuotes}, 56 {ImageName: "mysql:5.6", Options: optsAnsiQuotes}, 57 {ImageName: "mysql:5.7", Options: optsAnsiQuotes}, 58 {ImageName: "mysql:8", Options: optsAnsiQuotes}, 59 } 60 ) 61 62 func isReady(ctx context.Context, c dktest.ContainerInfo) bool { 63 ip, port, err := c.Port(defaultPort) 64 if err != nil { 65 return false 66 } 67 68 db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port)) 69 if err != nil { 70 return false 71 } 72 defer func() { 73 if err := db.Close(); err != nil { 74 log.Println("close error:", err) 75 } 76 }() 77 if err = db.PingContext(ctx); err != nil { 78 switch err { 79 case sqldriver.ErrBadConn, mysql.ErrInvalidConn: 80 return false 81 default: 82 fmt.Println(err) 83 } 84 return false 85 } 86 87 return true 88 } 89 90 func Test(t *testing.T) { 91 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 92 93 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 94 ip, port, err := c.Port(defaultPort) 95 if err != nil { 96 t.Fatal(err) 97 } 98 99 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 100 p := &Mysql{} 101 d, err := p.Open(addr) 102 if err != nil { 103 t.Fatal(err) 104 } 105 defer func() { 106 if err := d.Close(); err != nil { 107 t.Error(err) 108 } 109 }() 110 dt.Test(t, d, []byte("SELECT 1")) 111 112 // check ensureVersionTable 113 if err := d.(*Mysql).ensureVersionTable(); err != nil { 114 t.Fatal(err) 115 } 116 // check again 117 if err := d.(*Mysql).ensureVersionTable(); err != nil { 118 t.Fatal(err) 119 } 120 }) 121 } 122 123 func TestMigrate(t *testing.T) { 124 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 125 126 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 127 ip, port, err := c.Port(defaultPort) 128 if err != nil { 129 t.Fatal(err) 130 } 131 132 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 133 p := &Mysql{} 134 d, err := p.Open(addr) 135 if err != nil { 136 t.Fatal(err) 137 } 138 defer func() { 139 if err := d.Close(); err != nil { 140 t.Error(err) 141 } 142 }() 143 144 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 145 if err != nil { 146 t.Fatal(err) 147 } 148 dt.TestMigrate(t, m) 149 150 // check ensureVersionTable 151 if err := d.(*Mysql).ensureVersionTable(); err != nil { 152 t.Fatal(err) 153 } 154 // check again 155 if err := d.(*Mysql).ensureVersionTable(); err != nil { 156 t.Fatal(err) 157 } 158 }) 159 } 160 161 func TestMigrateAnsiQuotes(t *testing.T) { 162 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 163 164 dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) { 165 ip, port, err := c.Port(defaultPort) 166 if err != nil { 167 t.Fatal(err) 168 } 169 170 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 171 p := &Mysql{} 172 d, err := p.Open(addr) 173 if err != nil { 174 t.Fatal(err) 175 } 176 defer func() { 177 if err := d.Close(); err != nil { 178 t.Error(err) 179 } 180 }() 181 182 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 183 if err != nil { 184 t.Fatal(err) 185 } 186 dt.TestMigrate(t, m) 187 188 // check ensureVersionTable 189 if err := d.(*Mysql).ensureVersionTable(); err != nil { 190 t.Fatal(err) 191 } 192 // check again 193 if err := d.(*Mysql).ensureVersionTable(); err != nil { 194 t.Fatal(err) 195 } 196 }) 197 } 198 199 func TestLockWorks(t *testing.T) { 200 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 201 ip, port, err := c.Port(defaultPort) 202 if err != nil { 203 t.Fatal(err) 204 } 205 206 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 207 p := &Mysql{} 208 d, err := p.Open(addr) 209 if err != nil { 210 t.Fatal(err) 211 } 212 dt.Test(t, d, []byte("SELECT 1")) 213 214 ms := d.(*Mysql) 215 216 err = ms.Lock() 217 if err != nil { 218 t.Fatal(err) 219 } 220 err = ms.Unlock() 221 if err != nil { 222 t.Fatal(err) 223 } 224 225 // make sure the 2nd lock works (RELEASE_LOCK is very finicky) 226 err = ms.Lock() 227 if err != nil { 228 t.Fatal(err) 229 } 230 err = ms.Unlock() 231 if err != nil { 232 t.Fatal(err) 233 } 234 }) 235 } 236 237 func TestNoLockParamValidation(t *testing.T) { 238 ip := "127.0.0.1" 239 port := 3306 240 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 241 p := &Mysql{} 242 _, err := p.Open(addr + "?x-no-lock=not-a-bool") 243 if !errors.Is(err, strconv.ErrSyntax) { 244 t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter") 245 } 246 } 247 248 func TestNoLockWorks(t *testing.T) { 249 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 250 ip, port, err := c.Port(defaultPort) 251 if err != nil { 252 t.Fatal(err) 253 } 254 255 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 256 p := &Mysql{} 257 d, err := p.Open(addr) 258 if err != nil { 259 t.Fatal(err) 260 } 261 262 lock := d.(*Mysql) 263 264 p = &Mysql{} 265 d, err = p.Open(addr + "?x-no-lock=true") 266 if err != nil { 267 t.Fatal(err) 268 } 269 270 noLock := d.(*Mysql) 271 272 // Should be possible to take real lock and no-lock at the same time 273 if err = lock.Lock(); err != nil { 274 t.Fatal(err) 275 } 276 if err = noLock.Lock(); err != nil { 277 t.Fatal(err) 278 } 279 if err = lock.Unlock(); err != nil { 280 t.Fatal(err) 281 } 282 if err = noLock.Unlock(); err != nil { 283 t.Fatal(err) 284 } 285 }) 286 } 287 288 func TestExtractCustomQueryParams(t *testing.T) { 289 testcases := []struct { 290 name string 291 config *mysql.Config 292 expectedParams map[string]string 293 expectedCustomParams map[string]string 294 expectedErr error 295 }{ 296 {name: "nil config", expectedErr: ErrNilConfig}, 297 { 298 name: "no params", 299 config: mysql.NewConfig(), 300 expectedCustomParams: map[string]string{}, 301 }, 302 { 303 name: "no custom params", 304 config: &mysql.Config{Params: map[string]string{"hello": "world"}}, 305 expectedParams: map[string]string{"hello": "world"}, 306 expectedCustomParams: map[string]string{}, 307 }, 308 { 309 name: "one param, one custom param", 310 config: &mysql.Config{ 311 Params: map[string]string{"hello": "world", "x-foo": "bar"}, 312 }, 313 expectedParams: map[string]string{"hello": "world"}, 314 expectedCustomParams: map[string]string{"x-foo": "bar"}, 315 }, 316 { 317 name: "multiple params, multiple custom params", 318 config: &mysql.Config{ 319 Params: map[string]string{ 320 "hello": "world", 321 "x-foo": "bar", 322 "dead": "beef", 323 "x-cat": "hat", 324 }, 325 }, 326 expectedParams: map[string]string{"hello": "world", "dead": "beef"}, 327 expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"}, 328 }, 329 } 330 for _, tc := range testcases { 331 t.Run(tc.name, func(t *testing.T) { 332 customParams, err := extractCustomQueryParams(tc.config) 333 if tc.config != nil { 334 assert.Equal(t, tc.expectedParams, tc.config.Params, 335 "Expected config params have custom params properly removed") 336 } 337 assert.Equal(t, tc.expectedErr, err, "Expected errors to match") 338 assert.Equal(t, tc.expectedCustomParams, customParams, 339 "Expected custom params to be properly extracted") 340 }) 341 } 342 } 343 344 func createTmpCert(t *testing.T) string { 345 tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert") 346 if err != nil { 347 t.Fatal("Failed to create temp cert file:", err) 348 } 349 t.Cleanup(func() { 350 if err := os.Remove(tmpCertFile.Name()); err != nil { 351 t.Log("Failed to cleanup temp cert file:", err) 352 } 353 }) 354 355 r := rand.New(rand.NewSource(0)) 356 pub, priv, err := ed25519.GenerateKey(r) 357 if err != nil { 358 t.Fatal("Failed to generate ed25519 key for temp cert file:", err) 359 } 360 tmpl := x509.Certificate{ 361 SerialNumber: big.NewInt(0), 362 } 363 derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv) 364 if err != nil { 365 t.Fatal("Failed to generate temp cert file:", err) 366 } 367 if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 368 t.Fatal("Failed to encode ") 369 } 370 if err := tmpCertFile.Close(); err != nil { 371 t.Fatal("Failed to close temp cert file:", err) 372 } 373 return tmpCertFile.Name() 374 } 375 376 func TestURLToMySQLConfig(t *testing.T) { 377 tmpCertFilename := createTmpCert(t) 378 tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename) 379 380 testcases := []struct { 381 name string 382 urlStr string 383 expectedDSN string // empty string signifies that an error is expected 384 }{ 385 {name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true", 386 expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 387 {name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 388 expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 389 {name: "only user - with encoded :", 390 urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 391 expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 392 {name: "only user - with encoded @", 393 urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 394 expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 395 {name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 396 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 397 // Not supported yet: https://github.com/go-sql-driver/mysql/issues/591 398 // {name: "user/password - user with encoded :", 399 // urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 400 // expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 401 {name: "user/password - user with encoded @", 402 urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 403 expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 404 {name: "user/password - password with encoded :", 405 urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 406 expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 407 {name: "user/password - password with encoded @", 408 urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 409 expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 410 {name: "custom tls", 411 urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped, 412 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped}, 413 } 414 for _, tc := range testcases { 415 t.Run(tc.name, func(t *testing.T) { 416 config, err := urlToMySQLConfig(tc.urlStr) 417 if err != nil { 418 t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) 419 } 420 dsn := config.FormatDSN() 421 if dsn != tc.expectedDSN { 422 t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) 423 } 424 }) 425 } 426 }