github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/mysql/mysql_test.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "crypto/ed25519" 6 "crypto/x509" 7 "database/sql" 8 sqldriver "database/sql/driver" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "io/ioutil" 13 "log" 14 "math/big" 15 "math/rand" 16 "net/url" 17 "os" 18 "strconv" 19 "testing" 20 21 "github.com/dhui/dktest" 22 "github.com/go-sql-driver/mysql" 23 "github.com/stretchr/testify/assert" 24 25 migrate "github.com/seashell-org/golang-migrate/v4" 26 27 dt "github.com/seashell-org/golang-migrate/v4/database/testing" 28 "github.com/seashell-org/golang-migrate/v4/dktesting" 29 30 _ "github.com/seashell-org/golang-migrate/v4/source/file" 31 ) 32 33 const defaultPort = 3306 34 35 var ( 36 opts = dktest.Options{ 37 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 38 PortRequired: true, ReadyFunc: isReady, 39 } 40 optsAnsiQuotes = dktest.Options{ 41 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 42 PortRequired: true, ReadyFunc: isReady, 43 Cmd: []string{"--sql-mode=ANSI_QUOTES"}, 44 } 45 // Supported versions: https://www.mysql.com/support/supportedplatforms/database.html 46 specs = []dktesting.ContainerSpec{ 47 {ImageName: "mysql:5.5", Options: opts}, 48 {ImageName: "mysql:5.6", Options: opts}, 49 {ImageName: "mysql:5.7", Options: opts}, 50 {ImageName: "mysql:8", Options: opts}, 51 } 52 specsAnsiQuotes = []dktesting.ContainerSpec{ 53 {ImageName: "mysql:5.5", Options: optsAnsiQuotes}, 54 {ImageName: "mysql:5.6", Options: optsAnsiQuotes}, 55 {ImageName: "mysql:5.7", Options: optsAnsiQuotes}, 56 {ImageName: "mysql:8", Options: optsAnsiQuotes}, 57 } 58 ) 59 60 func isReady(ctx context.Context, c dktest.ContainerInfo) bool { 61 ip, port, err := c.Port(defaultPort) 62 if err != nil { 63 return false 64 } 65 66 db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port)) 67 if err != nil { 68 return false 69 } 70 defer func() { 71 if err := db.Close(); err != nil { 72 log.Println("close error:", err) 73 } 74 }() 75 if err = db.PingContext(ctx); err != nil { 76 switch err { 77 case sqldriver.ErrBadConn, mysql.ErrInvalidConn: 78 return false 79 default: 80 fmt.Println(err) 81 } 82 return false 83 } 84 85 return true 86 } 87 88 func Test(t *testing.T) { 89 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 90 91 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 92 ip, port, err := c.Port(defaultPort) 93 if err != nil { 94 t.Fatal(err) 95 } 96 97 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 98 p := &Mysql{} 99 d, err := p.Open(addr) 100 if err != nil { 101 t.Fatal(err) 102 } 103 defer func() { 104 if err := d.Close(); err != nil { 105 t.Error(err) 106 } 107 }() 108 dt.Test(t, d, []byte("SELECT 1")) 109 110 // check ensureVersionTable 111 if err := d.(*Mysql).ensureVersionTable(); err != nil { 112 t.Fatal(err) 113 } 114 // check again 115 if err := d.(*Mysql).ensureVersionTable(); err != nil { 116 t.Fatal(err) 117 } 118 }) 119 } 120 121 func TestMigrate(t *testing.T) { 122 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 123 124 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 125 ip, port, err := c.Port(defaultPort) 126 if err != nil { 127 t.Fatal(err) 128 } 129 130 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 131 p := &Mysql{} 132 d, err := p.Open(addr) 133 if err != nil { 134 t.Fatal(err) 135 } 136 defer func() { 137 if err := d.Close(); err != nil { 138 t.Error(err) 139 } 140 }() 141 142 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 143 if err != nil { 144 t.Fatal(err) 145 } 146 dt.TestMigrate(t, m) 147 148 // check ensureVersionTable 149 if err := d.(*Mysql).ensureVersionTable(); err != nil { 150 t.Fatal(err) 151 } 152 // check again 153 if err := d.(*Mysql).ensureVersionTable(); err != nil { 154 t.Fatal(err) 155 } 156 }) 157 } 158 159 func TestMigrateAnsiQuotes(t *testing.T) { 160 // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) 161 162 dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) { 163 ip, port, err := c.Port(defaultPort) 164 if err != nil { 165 t.Fatal(err) 166 } 167 168 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 169 p := &Mysql{} 170 d, err := p.Open(addr) 171 if err != nil { 172 t.Fatal(err) 173 } 174 defer func() { 175 if err := d.Close(); err != nil { 176 t.Error(err) 177 } 178 }() 179 180 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 181 if err != nil { 182 t.Fatal(err) 183 } 184 dt.TestMigrate(t, m) 185 186 // check ensureVersionTable 187 if err := d.(*Mysql).ensureVersionTable(); err != nil { 188 t.Fatal(err) 189 } 190 // check again 191 if err := d.(*Mysql).ensureVersionTable(); err != nil { 192 t.Fatal(err) 193 } 194 }) 195 } 196 197 func TestLockWorks(t *testing.T) { 198 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 199 ip, port, err := c.Port(defaultPort) 200 if err != nil { 201 t.Fatal(err) 202 } 203 204 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 205 p := &Mysql{} 206 d, err := p.Open(addr) 207 if err != nil { 208 t.Fatal(err) 209 } 210 dt.Test(t, d, []byte("SELECT 1")) 211 212 ms := d.(*Mysql) 213 214 err = ms.Lock() 215 if err != nil { 216 t.Fatal(err) 217 } 218 err = ms.Unlock() 219 if err != nil { 220 t.Fatal(err) 221 } 222 223 // make sure the 2nd lock works (RELEASE_LOCK is very finicky) 224 err = ms.Lock() 225 if err != nil { 226 t.Fatal(err) 227 } 228 err = ms.Unlock() 229 if err != nil { 230 t.Fatal(err) 231 } 232 }) 233 } 234 235 func TestNoLockParamValidation(t *testing.T) { 236 ip := "127.0.0.1" 237 port := 3306 238 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 239 p := &Mysql{} 240 _, err := p.Open(addr + "?x-no-lock=not-a-bool") 241 if !errors.Is(err, strconv.ErrSyntax) { 242 t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter") 243 } 244 } 245 246 func TestNoLockWorks(t *testing.T) { 247 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 248 ip, port, err := c.Port(defaultPort) 249 if err != nil { 250 t.Fatal(err) 251 } 252 253 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 254 p := &Mysql{} 255 d, err := p.Open(addr) 256 if err != nil { 257 t.Fatal(err) 258 } 259 260 lock := d.(*Mysql) 261 262 p = &Mysql{} 263 d, err = p.Open(addr + "?x-no-lock=true") 264 if err != nil { 265 t.Fatal(err) 266 } 267 268 noLock := d.(*Mysql) 269 270 // Should be possible to take real lock and no-lock at the same time 271 if err = lock.Lock(); err != nil { 272 t.Fatal(err) 273 } 274 if err = noLock.Lock(); err != nil { 275 t.Fatal(err) 276 } 277 if err = lock.Unlock(); err != nil { 278 t.Fatal(err) 279 } 280 if err = noLock.Unlock(); err != nil { 281 t.Fatal(err) 282 } 283 }) 284 } 285 286 func TestExtractCustomQueryParams(t *testing.T) { 287 testcases := []struct { 288 name string 289 config *mysql.Config 290 expectedParams map[string]string 291 expectedCustomParams map[string]string 292 expectedErr error 293 }{ 294 {name: "nil config", expectedErr: ErrNilConfig}, 295 { 296 name: "no params", 297 config: mysql.NewConfig(), 298 expectedCustomParams: map[string]string{}, 299 }, 300 { 301 name: "no custom params", 302 config: &mysql.Config{Params: map[string]string{"hello": "world"}}, 303 expectedParams: map[string]string{"hello": "world"}, 304 expectedCustomParams: map[string]string{}, 305 }, 306 { 307 name: "one param, one custom param", 308 config: &mysql.Config{ 309 Params: map[string]string{"hello": "world", "x-foo": "bar"}, 310 }, 311 expectedParams: map[string]string{"hello": "world"}, 312 expectedCustomParams: map[string]string{"x-foo": "bar"}, 313 }, 314 { 315 name: "multiple params, multiple custom params", 316 config: &mysql.Config{ 317 Params: map[string]string{ 318 "hello": "world", 319 "x-foo": "bar", 320 "dead": "beef", 321 "x-cat": "hat", 322 }, 323 }, 324 expectedParams: map[string]string{"hello": "world", "dead": "beef"}, 325 expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"}, 326 }, 327 } 328 for _, tc := range testcases { 329 t.Run(tc.name, func(t *testing.T) { 330 customParams, err := extractCustomQueryParams(tc.config) 331 if tc.config != nil { 332 assert.Equal(t, tc.expectedParams, tc.config.Params, 333 "Expected config params have custom params properly removed") 334 } 335 assert.Equal(t, tc.expectedErr, err, "Expected errors to match") 336 assert.Equal(t, tc.expectedCustomParams, customParams, 337 "Expected custom params to be properly extracted") 338 }) 339 } 340 } 341 342 func createTmpCert(t *testing.T) string { 343 tmpCertFile, err := ioutil.TempFile("", "migrate_test_cert") 344 if err != nil { 345 t.Fatal("Failed to create temp cert file:", err) 346 } 347 t.Cleanup(func() { 348 if err := os.Remove(tmpCertFile.Name()); err != nil { 349 t.Log("Failed to cleanup temp cert file:", err) 350 } 351 }) 352 353 r := rand.New(rand.NewSource(0)) 354 pub, priv, err := ed25519.GenerateKey(r) 355 if err != nil { 356 t.Fatal("Failed to generate ed25519 key for temp cert file:", err) 357 } 358 tmpl := x509.Certificate{ 359 SerialNumber: big.NewInt(0), 360 } 361 derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv) 362 if err != nil { 363 t.Fatal("Failed to generate temp cert file:", err) 364 } 365 if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 366 t.Fatal("Failed to encode ") 367 } 368 if err := tmpCertFile.Close(); err != nil { 369 t.Fatal("Failed to close temp cert file:", err) 370 } 371 return tmpCertFile.Name() 372 } 373 374 func TestURLToMySQLConfig(t *testing.T) { 375 tmpCertFilename := createTmpCert(t) 376 tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename) 377 378 testcases := []struct { 379 name string 380 urlStr string 381 expectedDSN string // empty string signifies that an error is expected 382 }{ 383 {name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true", 384 expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 385 {name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 386 expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 387 {name: "only user - with encoded :", 388 urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 389 expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 390 {name: "only user - with encoded @", 391 urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 392 expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 393 {name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 394 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 395 // Not supported yet: https://github.com/go-sql-driver/mysql/issues/591 396 // {name: "user/password - user with encoded :", 397 // urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 398 // expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 399 {name: "user/password - user with encoded @", 400 urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 401 expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 402 {name: "user/password - password with encoded :", 403 urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 404 expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 405 {name: "user/password - password with encoded @", 406 urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 407 expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 408 {name: "custom tls", 409 urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped, 410 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped}, 411 } 412 for _, tc := range testcases { 413 t.Run(tc.name, func(t *testing.T) { 414 config, err := urlToMySQLConfig(tc.urlStr) 415 if err != nil { 416 t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) 417 } 418 dsn := config.FormatDSN() 419 if dsn != tc.expectedDSN { 420 t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) 421 } 422 }) 423 } 424 }