github.com/dynastymasra/migrate/v4@v4.11.0/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 "github.com/hashicorp/go-multierror" 21 ) 22 23 import ( 24 "github.com/golang-migrate/migrate/v4/database" 25 ) 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 } 45 46 type Mysql struct { 47 // mysql RELEASE_LOCK must be called from the same conn, so 48 // just do everything over a single conn anyway. 49 conn *sql.Conn 50 db *sql.DB 51 isLocked bool 52 53 config *Config 54 } 55 56 // instance must have `multiStatements` set to true 57 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 58 if config == nil { 59 return nil, ErrNilConfig 60 } 61 62 if err := instance.Ping(); err != nil { 63 return nil, err 64 } 65 66 if config.DatabaseName == "" { 67 query := `SELECT DATABASE()` 68 var databaseName sql.NullString 69 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 70 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 71 } 72 73 if len(databaseName.String) == 0 { 74 return nil, ErrNoDatabaseName 75 } 76 77 config.DatabaseName = databaseName.String 78 } 79 80 if len(config.MigrationsTable) == 0 { 81 config.MigrationsTable = DefaultMigrationsTable 82 } 83 84 conn, err := instance.Conn(context.Background()) 85 if err != nil { 86 return nil, err 87 } 88 89 mx := &Mysql{ 90 conn: conn, 91 db: instance, 92 config: config, 93 } 94 95 if err := mx.ensureVersionTable(); err != nil { 96 return nil, err 97 } 98 99 return mx, nil 100 } 101 102 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 103 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 104 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 105 if c == nil { 106 return nil, ErrNilConfig 107 } 108 customQueryParams := map[string]string{} 109 110 for k, v := range c.Params { 111 if strings.HasPrefix(k, "x-") { 112 customQueryParams[k] = v 113 delete(c.Params, k) 114 } 115 } 116 return customQueryParams, nil 117 } 118 119 func urlToMySQLConfig(url string) (*mysql.Config, error) { 120 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 121 if err != nil { 122 return nil, err 123 } 124 125 config.MultiStatements = true 126 127 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 128 // net/url.Parse() would automatically unescape it for us. 129 // See: https://play.golang.org/p/q9j1io-YICQ 130 user, err := nurl.QueryUnescape(config.User) 131 if err != nil { 132 return nil, err 133 } 134 config.User = user 135 136 password, err := nurl.QueryUnescape(config.Passwd) 137 if err != nil { 138 return nil, err 139 } 140 config.Passwd = password 141 142 // use custom TLS? 143 ctls := config.TLSConfig 144 if len(ctls) > 0 { 145 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 146 rootCertPool := x509.NewCertPool() 147 pem, err := ioutil.ReadFile(config.Params["x-tls-ca"]) 148 if err != nil { 149 return nil, err 150 } 151 152 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 153 return nil, ErrAppendPEM 154 } 155 156 clientCert := make([]tls.Certificate, 0, 1) 157 if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" { 158 if ccert == "" || ckey == "" { 159 return nil, ErrTLSCertKeyConfig 160 } 161 certs, err := tls.LoadX509KeyPair(ccert, ckey) 162 if err != nil { 163 return nil, err 164 } 165 clientCert = append(clientCert, certs) 166 } 167 168 insecureSkipVerify := false 169 if len(config.Params["x-tls-insecure-skip-verify"]) > 0 { 170 x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"]) 171 if err != nil { 172 return nil, err 173 } 174 insecureSkipVerify = x 175 } 176 177 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 178 RootCAs: rootCertPool, 179 Certificates: clientCert, 180 InsecureSkipVerify: insecureSkipVerify, 181 }) 182 if err != nil { 183 return nil, err 184 } 185 } 186 } 187 188 return config, nil 189 } 190 191 func (m *Mysql) Open(url string) (database.Driver, error) { 192 config, err := urlToMySQLConfig(url) 193 if err != nil { 194 return nil, err 195 } 196 197 customParams, err := extractCustomQueryParams(config) 198 if err != nil { 199 return nil, err 200 } 201 202 db, err := sql.Open("mysql", config.FormatDSN()) 203 if err != nil { 204 return nil, err 205 } 206 207 mx, err := WithInstance(db, &Config{ 208 DatabaseName: config.DBName, 209 MigrationsTable: customParams["x-migrations-table"], 210 }) 211 if err != nil { 212 return nil, err 213 } 214 215 return mx, nil 216 } 217 218 func (m *Mysql) Close() error { 219 connErr := m.conn.Close() 220 dbErr := m.db.Close() 221 if connErr != nil || dbErr != nil { 222 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 223 } 224 return nil 225 } 226 227 func (m *Mysql) Lock() error { 228 if m.isLocked { 229 return database.ErrLocked 230 } 231 232 aid, err := database.GenerateAdvisoryLockId( 233 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 234 if err != nil { 235 return err 236 } 237 238 query := "SELECT GET_LOCK(?, 10)" 239 var success bool 240 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 241 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 242 } 243 244 if success { 245 m.isLocked = true 246 return nil 247 } 248 249 return database.ErrLocked 250 } 251 252 func (m *Mysql) Unlock() error { 253 if !m.isLocked { 254 return nil 255 } 256 257 aid, err := database.GenerateAdvisoryLockId( 258 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 259 if err != nil { 260 return err 261 } 262 263 query := `SELECT RELEASE_LOCK(?)` 264 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 265 return &database.Error{OrigErr: err, Query: []byte(query)} 266 } 267 268 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 269 // in which case isLocked should be true until the timeout expires -- synchronizing 270 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 271 272 m.isLocked = false 273 return nil 274 } 275 276 func (m *Mysql) Run(migration io.Reader) error { 277 migr, err := ioutil.ReadAll(migration) 278 if err != nil { 279 return err 280 } 281 282 query := string(migr[:]) 283 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 284 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 285 } 286 287 return nil 288 } 289 290 func (m *Mysql) SetVersion(version int, dirty bool) error { 291 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 292 if err != nil { 293 return &database.Error{OrigErr: err, Err: "transaction start failed"} 294 } 295 296 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 297 if _, err := tx.ExecContext(context.Background(), query); err != nil { 298 if errRollback := tx.Rollback(); errRollback != nil { 299 err = multierror.Append(err, errRollback) 300 } 301 return &database.Error{OrigErr: err, Query: []byte(query)} 302 } 303 304 // Also re-write the schema version for nil dirty versions to prevent 305 // empty schema version for failed down migration on the first migration 306 // See: https://github.com/golang-migrate/migrate/issues/330 307 if version >= 0 || (version == database.NilVersion && dirty) { 308 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 309 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 310 if errRollback := tx.Rollback(); errRollback != nil { 311 err = multierror.Append(err, errRollback) 312 } 313 return &database.Error{OrigErr: err, Query: []byte(query)} 314 } 315 } 316 317 if err := tx.Commit(); err != nil { 318 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 319 } 320 321 return nil 322 } 323 324 func (m *Mysql) Version() (version int, dirty bool, err error) { 325 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 326 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 327 switch { 328 case err == sql.ErrNoRows: 329 return database.NilVersion, false, nil 330 331 case err != nil: 332 if e, ok := err.(*mysql.MySQLError); ok { 333 if e.Number == 0 { 334 return database.NilVersion, false, nil 335 } 336 } 337 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 338 339 default: 340 return version, dirty, nil 341 } 342 } 343 344 func (m *Mysql) Drop() (err error) { 345 // select all tables 346 query := `SHOW TABLES LIKE '%'` 347 tables, err := m.conn.QueryContext(context.Background(), query) 348 if err != nil { 349 return &database.Error{OrigErr: err, Query: []byte(query)} 350 } 351 defer func() { 352 if errClose := tables.Close(); errClose != nil { 353 err = multierror.Append(err, errClose) 354 } 355 }() 356 357 // delete one table after another 358 tableNames := make([]string, 0) 359 for tables.Next() { 360 var tableName string 361 if err := tables.Scan(&tableName); err != nil { 362 return err 363 } 364 if len(tableName) > 0 { 365 tableNames = append(tableNames, tableName) 366 } 367 } 368 369 if len(tableNames) > 0 { 370 // disable checking foreign key constraints until finished 371 query = `SET foreign_key_checks = 0` 372 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 373 return &database.Error{OrigErr: err, Query: []byte(query)} 374 } 375 376 defer func() { 377 // enable foreign key checks 378 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 379 }() 380 381 // delete one by one ... 382 for _, t := range tableNames { 383 query = "DROP TABLE IF EXISTS `" + t + "`" 384 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 385 return &database.Error{OrigErr: err, Query: []byte(query)} 386 } 387 } 388 } 389 390 return nil 391 } 392 393 // ensureVersionTable checks if versions table exists and, if not, creates it. 394 // Note that this function locks the database, which deviates from the usual 395 // convention of "caller locks" in the Mysql type. 396 func (m *Mysql) ensureVersionTable() (err error) { 397 if err = m.Lock(); err != nil { 398 return err 399 } 400 401 defer func() { 402 if e := m.Unlock(); e != nil { 403 if err == nil { 404 err = e 405 } else { 406 err = multierror.Append(err, e) 407 } 408 } 409 }() 410 411 // check if migration table exists 412 var result string 413 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 414 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 415 if err != sql.ErrNoRows { 416 return &database.Error{OrigErr: err, Query: []byte(query)} 417 } 418 } else { 419 return nil 420 } 421 422 // if not, create the empty migration table 423 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 424 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 425 return &database.Error{OrigErr: err, Query: []byte(query)} 426 } 427 return nil 428 } 429 430 // Returns the bool value of the input. 431 // The 2nd return value indicates if the input was a valid bool value 432 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 433 func readBool(input string) (value bool, valid bool) { 434 switch input { 435 case "1", "true", "TRUE", "True": 436 return true, true 437 case "0", "false", "FALSE", "False": 438 return false, true 439 } 440 441 // Not a valid bool value 442 return 443 }