github.com/nokia/migrate/v4@v4.16.0/database/mysql/mysql_test.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "crypto/ed25519" 6 "crypto/x509" 7 "database/sql" 8 sqldriver "database/sql/driver" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "io/ioutil" 13 "log" 14 "math/big" 15 "math/rand" 16 "net/url" 17 "os" 18 "strconv" 19 "testing" 20 21 "github.com/dhui/dktest" 22 "github.com/go-sql-driver/mysql" 23 "github.com/nokia/migrate/v4" 24 "github.com/stretchr/testify/assert" 25 26 dt "github.com/nokia/migrate/v4/database/testing" 27 "github.com/nokia/migrate/v4/dktesting" 28 29 _ "github.com/nokia/migrate/v4/source/file" 30 ) 31 32 const defaultPort = 3306 33 34 var ( 35 opts = dktest.Options{ 36 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 37 PortRequired: true, ReadyFunc: isReady, 38 } 39 optsAnsiQuotes = dktest.Options{ 40 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 41 PortRequired: true, ReadyFunc: isReady, 42 Cmd: []string{"--sql-mode=ANSI_QUOTES"}, 43 } 44 // Supported versions: https://www.mysql.com/support/supportedplatforms/database.html 45 specs = []dktesting.ContainerSpec{ 46 {ImageName: "mysql:5.5", Options: opts}, 47 {ImageName: "mysql:5.6", Options: opts}, 48 {ImageName: "mysql:5.7", Options: opts}, 49 {ImageName: "mysql:8", Options: opts}, 50 } 51 specsAnsiQuotes = []dktesting.ContainerSpec{ 52 {ImageName: "mysql:5.5", Options: optsAnsiQuotes}, 53 {ImageName: "mysql:5.6", Options: optsAnsiQuotes}, 54 {ImageName: "mysql:5.7", Options: optsAnsiQuotes}, 55 {ImageName: "mysql:8", Options: optsAnsiQuotes}, 56 } 57 ) 58 59 func isReady(ctx context.Context, c dktest.ContainerInfo) bool { 60 ip, port, err := c.Port(defaultPort) 61 if err != nil { 62 return false 63 } 64 65 db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port)) 66 if err != nil { 67 return false 68 } 69 defer func() { 70 if err := db.Close(); err != nil { 71 log.Println("close error:", err) 72 } 73 }() 74 if err = db.PingContext(ctx); err != nil { 75 switch err { 76 case sqldriver.ErrBadConn, mysql.ErrInvalidConn: 77 return false 78 default: 79 fmt.Println(err) 80 } 81 return false 82 } 83 84 return true 85 } 86 87 func Test(t *testing.T) { 88 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 89 90 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 91 ip, port, err := c.Port(defaultPort) 92 if err != nil { 93 t.Fatal(err) 94 } 95 96 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 97 p := &Mysql{} 98 d, err := p.Open(addr) 99 if err != nil { 100 t.Fatal(err) 101 } 102 defer func() { 103 if err := d.Close(); err != nil { 104 t.Error(err) 105 } 106 }() 107 dt.Test(t, d, []byte("SELECT 1")) 108 109 // check ensureVersionTable 110 if err := d.(*Mysql).ensureVersionTable(); err != nil { 111 t.Fatal(err) 112 } 113 // check again 114 if err := d.(*Mysql).ensureVersionTable(); err != nil { 115 t.Fatal(err) 116 } 117 }) 118 } 119 120 func TestMigrate(t *testing.T) { 121 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 122 123 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 124 ip, port, err := c.Port(defaultPort) 125 if err != nil { 126 t.Fatal(err) 127 } 128 129 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 130 p := &Mysql{} 131 d, err := p.Open(addr) 132 if err != nil { 133 t.Fatal(err) 134 } 135 defer func() { 136 if err := d.Close(); err != nil { 137 t.Error(err) 138 } 139 }() 140 141 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 142 if err != nil { 143 t.Fatal(err) 144 } 145 dt.TestMigrate(t, m) 146 147 // check ensureVersionTable 148 if err := d.(*Mysql).ensureVersionTable(); err != nil { 149 t.Fatal(err) 150 } 151 // check again 152 if err := d.(*Mysql).ensureVersionTable(); err != nil { 153 t.Fatal(err) 154 } 155 }) 156 } 157 158 func TestMigrateAnsiQuotes(t *testing.T) { 159 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 160 161 dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) { 162 ip, port, err := c.Port(defaultPort) 163 if err != nil { 164 t.Fatal(err) 165 } 166 167 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 168 p := &Mysql{} 169 d, err := p.Open(addr) 170 if err != nil { 171 t.Fatal(err) 172 } 173 defer func() { 174 if err := d.Close(); err != nil { 175 t.Error(err) 176 } 177 }() 178 179 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 180 if err != nil { 181 t.Fatal(err) 182 } 183 dt.TestMigrate(t, m) 184 185 // check ensureVersionTable 186 if err := d.(*Mysql).ensureVersionTable(); err != nil { 187 t.Fatal(err) 188 } 189 // check again 190 if err := d.(*Mysql).ensureVersionTable(); err != nil { 191 t.Fatal(err) 192 } 193 }) 194 } 195 196 func TestLockWorks(t *testing.T) { 197 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 198 ip, port, err := c.Port(defaultPort) 199 if err != nil { 200 t.Fatal(err) 201 } 202 203 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 204 p := &Mysql{} 205 d, err := p.Open(addr) 206 if err != nil { 207 t.Fatal(err) 208 } 209 dt.Test(t, d, []byte("SELECT 1")) 210 211 ms := d.(*Mysql) 212 213 err = ms.Lock() 214 if err != nil { 215 t.Fatal(err) 216 } 217 err = ms.Unlock() 218 if err != nil { 219 t.Fatal(err) 220 } 221 222 // make sure the 2nd lock works (RELEASE_LOCK is very finicky) 223 err = ms.Lock() 224 if err != nil { 225 t.Fatal(err) 226 } 227 err = ms.Unlock() 228 if err != nil { 229 t.Fatal(err) 230 } 231 }) 232 } 233 234 func TestNoLockParamValidation(t *testing.T) { 235 ip := "127.0.0.1" 236 port := 3306 237 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 238 p := &Mysql{} 239 _, err := p.Open(addr + "?x-no-lock=not-a-bool") 240 if !errors.Is(err, strconv.ErrSyntax) { 241 t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter") 242 } 243 } 244 245 func TestNoLockWorks(t *testing.T) { 246 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 247 ip, port, err := c.Port(defaultPort) 248 if err != nil { 249 t.Fatal(err) 250 } 251 252 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 253 p := &Mysql{} 254 d, err := p.Open(addr) 255 if err != nil { 256 t.Fatal(err) 257 } 258 259 lock := d.(*Mysql) 260 261 p = &Mysql{} 262 d, err = p.Open(addr + "?x-no-lock=true") 263 if err != nil { 264 t.Fatal(err) 265 } 266 267 noLock := d.(*Mysql) 268 269 // Should be possible to take real lock and no-lock at the same time 270 if err = lock.Lock(); err != nil { 271 t.Fatal(err) 272 } 273 if err = noLock.Lock(); err != nil { 274 t.Fatal(err) 275 } 276 if err = lock.Unlock(); err != nil { 277 t.Fatal(err) 278 } 279 if err = noLock.Unlock(); err != nil { 280 t.Fatal(err) 281 } 282 }) 283 } 284 285 func TestExtractCustomQueryParams(t *testing.T) { 286 testcases := []struct { 287 name string 288 config *mysql.Config 289 expectedParams map[string]string 290 expectedCustomParams map[string]string 291 expectedErr error 292 }{ 293 {name: "nil config", expectedErr: ErrNilConfig}, 294 { 295 name: "no params", 296 config: mysql.NewConfig(), 297 expectedCustomParams: map[string]string{}, 298 }, 299 { 300 name: "no custom params", 301 config: &mysql.Config{Params: map[string]string{"hello": "world"}}, 302 expectedParams: map[string]string{"hello": "world"}, 303 expectedCustomParams: map[string]string{}, 304 }, 305 { 306 name: "one param, one custom param", 307 config: &mysql.Config{ 308 Params: map[string]string{"hello": "world", "x-foo": "bar"}, 309 }, 310 expectedParams: map[string]string{"hello": "world"}, 311 expectedCustomParams: map[string]string{"x-foo": "bar"}, 312 }, 313 { 314 name: "multiple params, multiple custom params", 315 config: &mysql.Config{ 316 Params: map[string]string{ 317 "hello": "world", 318 "x-foo": "bar", 319 "dead": "beef", 320 "x-cat": "hat", 321 }, 322 }, 323 expectedParams: map[string]string{"hello": "world", "dead": "beef"}, 324 expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"}, 325 }, 326 } 327 for _, tc := range testcases { 328 t.Run(tc.name, func(t *testing.T) { 329 customParams, err := extractCustomQueryParams(tc.config) 330 if tc.config != nil { 331 assert.Equal(t, tc.expectedParams, tc.config.Params, 332 "Expected config params have custom params properly removed") 333 } 334 assert.Equal(t, tc.expectedErr, err, "Expected errors to match") 335 assert.Equal(t, tc.expectedCustomParams, customParams, 336 "Expected custom params to be properly extracted") 337 }) 338 } 339 } 340 341 func createTmpCert(t *testing.T) string { 342 tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert") 343 if err != nil { 344 t.Fatal("Failed to create temp cert file:", err) 345 } 346 t.Cleanup(func() { 347 if err := os.Remove(tmpCertFile.Name()); err != nil { 348 t.Log("Failed to cleanup temp cert file:", err) 349 } 350 }) 351 352 r := rand.New(rand.NewSource(0)) 353 pub, priv, err := ed25519.GenerateKey(r) 354 if err != nil { 355 t.Fatal("Failed to generate ed25519 key for temp cert file:", err) 356 } 357 tmpl := x509.Certificate{ 358 SerialNumber: big.NewInt(0), 359 } 360 derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv) 361 if err != nil { 362 t.Fatal("Failed to generate temp cert file:", err) 363 } 364 if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 365 t.Fatal("Failed to encode ") 366 } 367 if err := tmpCertFile.Close(); err != nil { 368 t.Fatal("Failed to close temp cert file:", err) 369 } 370 return tmpCertFile.Name() 371 } 372 373 func TestURLToMySQLConfig(t *testing.T) { 374 tmpCertFilename := createTmpCert(t) 375 tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename) 376 377 testcases := []struct { 378 name string 379 urlStr string 380 expectedDSN string // empty string signifies that an error is expected 381 }{ 382 { 383 name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true", 384 expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true", 385 }, 386 { 387 name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 388 expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 389 }, 390 { 391 name: "only user - with encoded :", 392 urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 393 expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 394 }, 395 { 396 name: "only user - with encoded @", 397 urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 398 expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 399 }, 400 { 401 name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 402 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 403 }, 404 // Not supported yet: https://github.com/go-sql-driver/mysql/issues/591 405 // {name: "user/password - user with encoded :", 406 // urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 407 // expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 408 { 409 name: "user/password - user with encoded @", 410 urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 411 expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 412 }, 413 { 414 name: "user/password - password with encoded :", 415 urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 416 expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 417 }, 418 { 419 name: "user/password - password with encoded @", 420 urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 421 expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 422 }, 423 { 424 name: "custom tls", 425 urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped, 426 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped, 427 }, 428 } 429 for _, tc := range testcases { 430 t.Run(tc.name, func(t *testing.T) { 431 config, err := urlToMySQLConfig(tc.urlStr) 432 if err != nil { 433 t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) 434 } 435 dsn := config.FormatDSN() 436 if dsn != tc.expectedDSN { 437 t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) 438 } 439 }) 440 } 441 }