github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/mysql/mysql_test.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "crypto/ed25519" 6 "crypto/x509" 7 "database/sql" 8 sqldriver "database/sql/driver" 9 "encoding/pem" 10 "errors" 11 "fmt" 12 "log" 13 "math/big" 14 "math/rand" 15 "net/url" 16 "os" 17 "strconv" 18 "testing" 19 20 "github.com/dhui/dktest" 21 "github.com/go-sql-driver/mysql" 22 "github.com/golang-migrate/migrate/v4" 23 dt "github.com/golang-migrate/migrate/v4/database/testing" 24 "github.com/golang-migrate/migrate/v4/dktesting" 25 _ "github.com/golang-migrate/migrate/v4/source/file" 26 "github.com/stretchr/testify/assert" 27 ) 28 29 const defaultPort = 3306 30 31 var ( 32 opts = dktest.Options{ 33 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 34 PortRequired: true, ReadyFunc: isReady, 35 } 36 optsAnsiQuotes = dktest.Options{ 37 Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"}, 38 PortRequired: true, ReadyFunc: isReady, 39 Cmd: []string{"--sql-mode=ANSI_QUOTES"}, 40 } 41 // Supported versions: https://www.mysql.com/support/supportedplatforms/database.html 42 specs = []dktesting.ContainerSpec{ 43 {ImageName: "mysql:5.5", Options: opts}, 44 {ImageName: "mysql:5.6", Options: opts}, 45 {ImageName: "mysql:5.7", Options: opts}, 46 {ImageName: "mysql:8", Options: opts}, 47 } 48 specsAnsiQuotes = []dktesting.ContainerSpec{ 49 {ImageName: "mysql:5.5", Options: optsAnsiQuotes}, 50 {ImageName: "mysql:5.6", Options: optsAnsiQuotes}, 51 {ImageName: "mysql:5.7", Options: optsAnsiQuotes}, 52 {ImageName: "mysql:8", Options: optsAnsiQuotes}, 53 } 54 ) 55 56 func isReady(ctx context.Context, c dktest.ContainerInfo) bool { 57 ip, port, err := c.Port(defaultPort) 58 if err != nil { 59 return false 60 } 61 62 db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port)) 63 if err != nil { 64 return false 65 } 66 defer func() { 67 if err := db.Close(); err != nil { 68 log.Println("close error:", err) 69 } 70 }() 71 if err = db.PingContext(ctx); err != nil { 72 switch err { 73 case sqldriver.ErrBadConn, mysql.ErrInvalidConn: 74 return false 75 default: 76 fmt.Println(err) 77 } 78 return false 79 } 80 81 return true 82 } 83 84 func Test(t *testing.T) { 85 // mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime))) 86 87 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 88 ip, port, err := c.Port(defaultPort) 89 if err != nil { 90 t.Fatal(err) 91 } 92 93 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 94 p := &Mysql{} 95 d, err := p.Open(addr) 96 if err != nil { 97 t.Fatal(err) 98 } 99 defer func() { 100 if err := d.Close(); err != nil { 101 t.Error(err) 102 } 103 }() 104 dt.Test(t, d, []byte("SELECT 1")) 105 106 // check ensureVersionTable 107 if err := d.(*Mysql).ensureVersionTable(); err != nil { 108 t.Fatal(err) 109 } 110 // check again 111 if err := d.(*Mysql).ensureVersionTable(); err != nil { 112 t.Fatal(err) 113 } 114 }) 115 } 116 117 func TestMigrate(t *testing.T) { 118 // mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime))) 119 120 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 121 ip, port, err := c.Port(defaultPort) 122 if err != nil { 123 t.Fatal(err) 124 } 125 126 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 127 p := &Mysql{} 128 d, err := p.Open(addr) 129 if err != nil { 130 t.Fatal(err) 131 } 132 defer func() { 133 if err := d.Close(); err != nil { 134 t.Error(err) 135 } 136 }() 137 138 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 139 if err != nil { 140 t.Fatal(err) 141 } 142 dt.TestMigrate(t, m) 143 144 // check ensureVersionTable 145 if err := d.(*Mysql).ensureVersionTable(); err != nil { 146 t.Fatal(err) 147 } 148 // check again 149 if err := d.(*Mysql).ensureVersionTable(); err != nil { 150 t.Fatal(err) 151 } 152 }) 153 } 154 155 func TestMigrateAnsiQuotes(t *testing.T) { 156 // mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime))) 157 158 dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) { 159 ip, port, err := c.Port(defaultPort) 160 if err != nil { 161 t.Fatal(err) 162 } 163 164 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 165 p := &Mysql{} 166 d, err := p.Open(addr) 167 if err != nil { 168 t.Fatal(err) 169 } 170 defer func() { 171 if err := d.Close(); err != nil { 172 t.Error(err) 173 } 174 }() 175 176 m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) 177 if err != nil { 178 t.Fatal(err) 179 } 180 dt.TestMigrate(t, m) 181 182 // check ensureVersionTable 183 if err := d.(*Mysql).ensureVersionTable(); err != nil { 184 t.Fatal(err) 185 } 186 // check again 187 if err := d.(*Mysql).ensureVersionTable(); err != nil { 188 t.Fatal(err) 189 } 190 }) 191 } 192 193 func TestLockWorks(t *testing.T) { 194 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 195 ip, port, err := c.Port(defaultPort) 196 if err != nil { 197 t.Fatal(err) 198 } 199 200 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 201 p := &Mysql{} 202 d, err := p.Open(addr) 203 if err != nil { 204 t.Fatal(err) 205 } 206 dt.Test(t, d, []byte("SELECT 1")) 207 208 ms := d.(*Mysql) 209 210 err = ms.Lock() 211 if err != nil { 212 t.Fatal(err) 213 } 214 err = ms.Unlock() 215 if err != nil { 216 t.Fatal(err) 217 } 218 219 // make sure the 2nd lock works (RELEASE_LOCK is very finicky) 220 err = ms.Lock() 221 if err != nil { 222 t.Fatal(err) 223 } 224 err = ms.Unlock() 225 if err != nil { 226 t.Fatal(err) 227 } 228 }) 229 } 230 231 func TestNoLockParamValidation(t *testing.T) { 232 ip := "127.0.0.1" 233 port := 3306 234 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 235 p := &Mysql{} 236 _, err := p.Open(addr + "?x-no-lock=not-a-bool") 237 if !errors.Is(err, strconv.ErrSyntax) { 238 t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter") 239 } 240 } 241 242 func TestNoLockWorks(t *testing.T) { 243 dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { 244 ip, port, err := c.Port(defaultPort) 245 if err != nil { 246 t.Fatal(err) 247 } 248 249 addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) 250 p := &Mysql{} 251 d, err := p.Open(addr) 252 if err != nil { 253 t.Fatal(err) 254 } 255 256 lock := d.(*Mysql) 257 258 p = &Mysql{} 259 d, err = p.Open(addr + "?x-no-lock=true") 260 if err != nil { 261 t.Fatal(err) 262 } 263 264 noLock := d.(*Mysql) 265 266 // Should be possible to take real lock and no-lock at the same time 267 if err = lock.Lock(); err != nil { 268 t.Fatal(err) 269 } 270 if err = noLock.Lock(); err != nil { 271 t.Fatal(err) 272 } 273 if err = lock.Unlock(); err != nil { 274 t.Fatal(err) 275 } 276 if err = noLock.Unlock(); err != nil { 277 t.Fatal(err) 278 } 279 }) 280 } 281 282 func TestExtractCustomQueryParams(t *testing.T) { 283 testcases := []struct { 284 name string 285 config *mysql.Config 286 expectedParams map[string]string 287 expectedCustomParams map[string]string 288 expectedErr error 289 }{ 290 {name: "nil config", expectedErr: ErrNilConfig}, 291 { 292 name: "no params", 293 config: mysql.NewConfig(), 294 expectedCustomParams: map[string]string{}, 295 }, 296 { 297 name: "no custom params", 298 config: &mysql.Config{Params: map[string]string{"hello": "world"}}, 299 expectedParams: map[string]string{"hello": "world"}, 300 expectedCustomParams: map[string]string{}, 301 }, 302 { 303 name: "one param, one custom param", 304 config: &mysql.Config{ 305 Params: map[string]string{"hello": "world", "x-foo": "bar"}, 306 }, 307 expectedParams: map[string]string{"hello": "world"}, 308 expectedCustomParams: map[string]string{"x-foo": "bar"}, 309 }, 310 { 311 name: "multiple params, multiple custom params", 312 config: &mysql.Config{ 313 Params: map[string]string{ 314 "hello": "world", 315 "x-foo": "bar", 316 "dead": "beef", 317 "x-cat": "hat", 318 }, 319 }, 320 expectedParams: map[string]string{"hello": "world", "dead": "beef"}, 321 expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"}, 322 }, 323 } 324 for _, tc := range testcases { 325 t.Run(tc.name, func(t *testing.T) { 326 customParams, err := extractCustomQueryParams(tc.config) 327 if tc.config != nil { 328 assert.Equal(t, tc.expectedParams, tc.config.Params, 329 "Expected config params have custom params properly removed") 330 } 331 assert.Equal(t, tc.expectedErr, err, "Expected errors to match") 332 assert.Equal(t, tc.expectedCustomParams, customParams, 333 "Expected custom params to be properly extracted") 334 }) 335 } 336 } 337 338 func createTmpCert(t *testing.T) string { 339 tmpCertFile, err := os.CreateTemp("", "migrate_test_cert") 340 if err != nil { 341 t.Fatal("Failed to create temp cert file:", err) 342 } 343 t.Cleanup(func() { 344 if err := os.Remove(tmpCertFile.Name()); err != nil { 345 t.Log("Failed to cleanup temp cert file:", err) 346 } 347 }) 348 349 r := rand.New(rand.NewSource(0)) 350 pub, priv, err := ed25519.GenerateKey(r) 351 if err != nil { 352 t.Fatal("Failed to generate ed25519 key for temp cert file:", err) 353 } 354 tmpl := x509.Certificate{ 355 SerialNumber: big.NewInt(0), 356 } 357 derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv) 358 if err != nil { 359 t.Fatal("Failed to generate temp cert file:", err) 360 } 361 if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 362 t.Fatal("Failed to encode ") 363 } 364 if err := tmpCertFile.Close(); err != nil { 365 t.Fatal("Failed to close temp cert file:", err) 366 } 367 return tmpCertFile.Name() 368 } 369 370 func TestURLToMySQLConfig(t *testing.T) { 371 tmpCertFilename := createTmpCert(t) 372 tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename) 373 374 testcases := []struct { 375 name string 376 urlStr string 377 expectedDSN string // empty string signifies that an error is expected 378 }{ 379 {name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true", 380 expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 381 {name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 382 expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 383 {name: "only user - with encoded :", 384 urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 385 expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 386 {name: "only user - with encoded @", 387 urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 388 expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 389 {name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 390 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 391 // Not supported yet: https://github.com/go-sql-driver/mysql/issues/591 392 // {name: "user/password - user with encoded :", 393 // urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 394 // expectedDSN: "username::pasword@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 395 {name: "user/password - user with encoded @", 396 urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 397 expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 398 {name: "user/password - password with encoded :", 399 urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 400 expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 401 {name: "user/password - password with encoded @", 402 urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true", 403 expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"}, 404 {name: "custom tls", 405 urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped, 406 expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped}, 407 } 408 for _, tc := range testcases { 409 t.Run(tc.name, func(t *testing.T) { 410 config, err := urlToMySQLConfig(tc.urlStr) 411 if err != nil { 412 t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err) 413 } 414 dsn := config.FormatDSN() 415 if dsn != tc.expectedDSN { 416 t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN) 417 } 418 }) 419 } 420 }