github.com/nagyist/migrate/v4@v4.14.6/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 "github.com/hashicorp/go-multierror" 21 ) 22 23 import ( 24 "github.com/golang-migrate/migrate/v4/database" 25 ) 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 NoLock bool 45 } 46 47 type Mysql struct { 48 // mysql RELEASE_LOCK must be called from the same conn, so 49 // just do everything over a single conn anyway. 50 conn *sql.Conn 51 db *sql.DB 52 isLocked bool 53 54 config *Config 55 } 56 57 // instance must have `multiStatements` set to true 58 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 59 if config == nil { 60 return nil, ErrNilConfig 61 } 62 63 if err := instance.Ping(); err != nil { 64 return nil, err 65 } 66 67 if config.DatabaseName == "" { 68 query := `SELECT DATABASE()` 69 var databaseName sql.NullString 70 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 71 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 72 } 73 74 if len(databaseName.String) == 0 { 75 return nil, ErrNoDatabaseName 76 } 77 78 config.DatabaseName = databaseName.String 79 } 80 81 if len(config.MigrationsTable) == 0 { 82 config.MigrationsTable = DefaultMigrationsTable 83 } 84 85 conn, err := instance.Conn(context.Background()) 86 if err != nil { 87 return nil, err 88 } 89 90 mx := &Mysql{ 91 conn: conn, 92 db: instance, 93 config: config, 94 } 95 96 if err := mx.ensureVersionTable(); err != nil { 97 return nil, err 98 } 99 100 return mx, nil 101 } 102 103 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 104 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 105 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 106 if c == nil { 107 return nil, ErrNilConfig 108 } 109 customQueryParams := map[string]string{} 110 111 for k, v := range c.Params { 112 if strings.HasPrefix(k, "x-") { 113 customQueryParams[k] = v 114 delete(c.Params, k) 115 } 116 } 117 return customQueryParams, nil 118 } 119 120 func urlToMySQLConfig(url string) (*mysql.Config, error) { 121 // Need to parse out custom TLS parameters and call 122 // mysql.RegisterTLSConfig() before mysql.ParseDSN() is called 123 // which consumes the registered tls.Config 124 // Fixes: https://github.com/golang-migrate/migrate/issues/411 125 // 126 // Can't use url.Parse() since it fails to parse MySQL DSNs 127 // mysql.ParseDSN() also searches for "?" to find query parameters: 128 // https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344 129 if idx := strings.LastIndex(url, "?"); idx > 0 { 130 rawParams := url[idx+1:] 131 parsedParams, err := nurl.ParseQuery(rawParams) 132 if err != nil { 133 return nil, err 134 } 135 136 ctls := parsedParams.Get("tls") 137 if len(ctls) > 0 { 138 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 139 rootCertPool := x509.NewCertPool() 140 pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca")) 141 if err != nil { 142 return nil, err 143 } 144 145 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 146 return nil, ErrAppendPEM 147 } 148 149 clientCert := make([]tls.Certificate, 0, 1) 150 if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" { 151 if ccert == "" || ckey == "" { 152 return nil, ErrTLSCertKeyConfig 153 } 154 certs, err := tls.LoadX509KeyPair(ccert, ckey) 155 if err != nil { 156 return nil, err 157 } 158 clientCert = append(clientCert, certs) 159 } 160 161 insecureSkipVerify := false 162 insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify") 163 if len(insecureSkipVerifyStr) > 0 { 164 x, err := strconv.ParseBool(insecureSkipVerifyStr) 165 if err != nil { 166 return nil, err 167 } 168 insecureSkipVerify = x 169 } 170 171 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 172 RootCAs: rootCertPool, 173 Certificates: clientCert, 174 InsecureSkipVerify: insecureSkipVerify, 175 }) 176 if err != nil { 177 return nil, err 178 } 179 } 180 } 181 } 182 183 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 184 if err != nil { 185 return nil, err 186 } 187 188 config.MultiStatements = true 189 190 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 191 // net/url.Parse() would automatically unescape it for us. 192 // See: https://play.golang.org/p/q9j1io-YICQ 193 user, err := nurl.QueryUnescape(config.User) 194 if err != nil { 195 return nil, err 196 } 197 config.User = user 198 199 password, err := nurl.QueryUnescape(config.Passwd) 200 if err != nil { 201 return nil, err 202 } 203 config.Passwd = password 204 205 return config, nil 206 } 207 208 func (m *Mysql) Open(url string) (database.Driver, error) { 209 config, err := urlToMySQLConfig(url) 210 if err != nil { 211 return nil, err 212 } 213 214 customParams, err := extractCustomQueryParams(config) 215 if err != nil { 216 return nil, err 217 } 218 219 noLockParam, noLock := customParams["x-no-lock"], false 220 if noLockParam != "" { 221 noLock, err = strconv.ParseBool(noLockParam) 222 if err != nil { 223 return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err) 224 } 225 } 226 227 db, err := sql.Open("mysql", config.FormatDSN()) 228 if err != nil { 229 return nil, err 230 } 231 232 mx, err := WithInstance(db, &Config{ 233 DatabaseName: config.DBName, 234 MigrationsTable: customParams["x-migrations-table"], 235 NoLock: noLock, 236 }) 237 if err != nil { 238 return nil, err 239 } 240 241 return mx, nil 242 } 243 244 func (m *Mysql) Close() error { 245 connErr := m.conn.Close() 246 dbErr := m.db.Close() 247 if connErr != nil || dbErr != nil { 248 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 249 } 250 return nil 251 } 252 253 func (m *Mysql) Lock() error { 254 if m.isLocked { 255 return database.ErrLocked 256 } 257 258 if m.config.NoLock { 259 m.isLocked = true 260 return nil 261 } 262 263 aid, err := database.GenerateAdvisoryLockId( 264 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 265 if err != nil { 266 return err 267 } 268 269 query := "SELECT GET_LOCK(?, 10)" 270 var success bool 271 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 272 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 273 } 274 275 if success { 276 m.isLocked = true 277 return nil 278 } 279 280 return database.ErrLocked 281 } 282 283 func (m *Mysql) Unlock() error { 284 if !m.isLocked { 285 return nil 286 } 287 288 if m.config.NoLock { 289 m.isLocked = false 290 return nil 291 } 292 293 aid, err := database.GenerateAdvisoryLockId( 294 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 295 if err != nil { 296 return err 297 } 298 299 query := `SELECT RELEASE_LOCK(?)` 300 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 301 return &database.Error{OrigErr: err, Query: []byte(query)} 302 } 303 304 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 305 // in which case isLocked should be true until the timeout expires -- synchronizing 306 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 307 308 m.isLocked = false 309 return nil 310 } 311 312 func (m *Mysql) Run(migration io.Reader) error { 313 migr, err := ioutil.ReadAll(migration) 314 if err != nil { 315 return err 316 } 317 318 query := string(migr[:]) 319 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 320 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 321 } 322 323 return nil 324 } 325 326 func (m *Mysql) SetVersion(version int, dirty bool) error { 327 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 328 if err != nil { 329 return &database.Error{OrigErr: err, Err: "transaction start failed"} 330 } 331 332 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 333 if _, err := tx.ExecContext(context.Background(), query); err != nil { 334 if errRollback := tx.Rollback(); errRollback != nil { 335 err = multierror.Append(err, errRollback) 336 } 337 return &database.Error{OrigErr: err, Query: []byte(query)} 338 } 339 340 // Also re-write the schema version for nil dirty versions to prevent 341 // empty schema version for failed down migration on the first migration 342 // See: https://github.com/golang-migrate/migrate/issues/330 343 if version >= 0 || (version == database.NilVersion && dirty) { 344 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 345 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 346 if errRollback := tx.Rollback(); errRollback != nil { 347 err = multierror.Append(err, errRollback) 348 } 349 return &database.Error{OrigErr: err, Query: []byte(query)} 350 } 351 } 352 353 if err := tx.Commit(); err != nil { 354 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 355 } 356 357 return nil 358 } 359 360 func (m *Mysql) Version() (version int, dirty bool, err error) { 361 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 362 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 363 switch { 364 case err == sql.ErrNoRows: 365 return database.NilVersion, false, nil 366 367 case err != nil: 368 if e, ok := err.(*mysql.MySQLError); ok { 369 if e.Number == 0 { 370 return database.NilVersion, false, nil 371 } 372 } 373 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 374 375 default: 376 return version, dirty, nil 377 } 378 } 379 380 func (m *Mysql) Drop() (err error) { 381 // select all tables 382 query := `SHOW TABLES LIKE '%'` 383 tables, err := m.conn.QueryContext(context.Background(), query) 384 if err != nil { 385 return &database.Error{OrigErr: err, Query: []byte(query)} 386 } 387 defer func() { 388 if errClose := tables.Close(); errClose != nil { 389 err = multierror.Append(err, errClose) 390 } 391 }() 392 393 // delete one table after another 394 tableNames := make([]string, 0) 395 for tables.Next() { 396 var tableName string 397 if err := tables.Scan(&tableName); err != nil { 398 return err 399 } 400 if len(tableName) > 0 { 401 tableNames = append(tableNames, tableName) 402 } 403 } 404 if err := tables.Err(); err != nil { 405 return &database.Error{OrigErr: err, Query: []byte(query)} 406 } 407 408 if len(tableNames) > 0 { 409 // disable checking foreign key constraints until finished 410 query = `SET foreign_key_checks = 0` 411 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 412 return &database.Error{OrigErr: err, Query: []byte(query)} 413 } 414 415 defer func() { 416 // enable foreign key checks 417 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 418 }() 419 420 // delete one by one ... 421 for _, t := range tableNames { 422 query = "DROP TABLE IF EXISTS `" + t + "`" 423 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 424 return &database.Error{OrigErr: err, Query: []byte(query)} 425 } 426 } 427 } 428 429 return nil 430 } 431 432 // ensureVersionTable checks if versions table exists and, if not, creates it. 433 // Note that this function locks the database, which deviates from the usual 434 // convention of "caller locks" in the Mysql type. 435 func (m *Mysql) ensureVersionTable() (err error) { 436 if err = m.Lock(); err != nil { 437 return err 438 } 439 440 defer func() { 441 if e := m.Unlock(); e != nil { 442 if err == nil { 443 err = e 444 } else { 445 err = multierror.Append(err, e) 446 } 447 } 448 }() 449 450 // check if migration table exists 451 var result string 452 query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'` 453 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 454 if err != sql.ErrNoRows { 455 return &database.Error{OrigErr: err, Query: []byte(query)} 456 } 457 } else { 458 return nil 459 } 460 461 // if not, create the empty migration table 462 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 463 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 464 return &database.Error{OrigErr: err, Query: []byte(query)} 465 } 466 return nil 467 } 468 469 // Returns the bool value of the input. 470 // The 2nd return value indicates if the input was a valid bool value 471 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 472 func readBool(input string) (value bool, valid bool) { 473 switch input { 474 case "1", "true", "TRUE", "True": 475 return true, true 476 case "0", "false", "FALSE", "False": 477 return false, true 478 } 479 480 // Not a valid bool value 481 return 482 }