github.com/bishtawi/migrate/v4@v4.8.11/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 "github.com/hashicorp/go-multierror" 21 ) 22 23 import ( 24 "github.com/bishtawi/migrate/v4/database" 25 ) 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 } 45 46 type Mysql struct { 47 // mysql RELEASE_LOCK must be called from the same conn, so 48 // just do everything over a single conn anyway. 49 conn *sql.Conn 50 db *sql.DB 51 isLocked bool 52 53 config *Config 54 } 55 56 // instance must have `multiStatements` set to true 57 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 58 if config == nil { 59 return nil, ErrNilConfig 60 } 61 62 if err := instance.Ping(); err != nil { 63 return nil, err 64 } 65 66 if config.DatabaseName == "" { 67 query := `SELECT DATABASE()` 68 var databaseName sql.NullString 69 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 70 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 71 } 72 73 if len(databaseName.String) == 0 { 74 return nil, ErrNoDatabaseName 75 } 76 77 config.DatabaseName = databaseName.String 78 } 79 80 if len(config.MigrationsTable) == 0 { 81 config.MigrationsTable = DefaultMigrationsTable 82 } 83 84 conn, err := instance.Conn(context.Background()) 85 if err != nil { 86 return nil, err 87 } 88 89 mx := &Mysql{ 90 conn: conn, 91 db: instance, 92 config: config, 93 } 94 95 if err := mx.ensureVersionTable(); err != nil { 96 return nil, err 97 } 98 99 return mx, nil 100 } 101 102 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 103 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 104 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 105 if c == nil { 106 return nil, ErrNilConfig 107 } 108 customQueryParams := map[string]string{} 109 110 for k, v := range c.Params { 111 if strings.HasPrefix(k, "x-") { 112 customQueryParams[k] = v 113 delete(c.Params, k) 114 } 115 } 116 return customQueryParams, nil 117 } 118 119 func urlToMySQLConfig(url string) (*mysql.Config, error) { 120 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 121 if err != nil { 122 return nil, err 123 } 124 125 config.MultiStatements = true 126 127 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 128 // net/url.Parse() would automatically unescape it for us. 129 // See: https://play.golang.org/p/q9j1io-YICQ 130 user, err := nurl.QueryUnescape(config.User) 131 if err != nil { 132 return nil, err 133 } 134 config.User = user 135 136 password, err := nurl.QueryUnescape(config.Passwd) 137 if err != nil { 138 return nil, err 139 } 140 config.Passwd = password 141 142 // use custom TLS? 143 ctls := config.TLSConfig 144 if len(ctls) > 0 { 145 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 146 rootCertPool := x509.NewCertPool() 147 pem, err := ioutil.ReadFile(config.Params["x-tls-ca"]) 148 if err != nil { 149 return nil, err 150 } 151 152 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 153 return nil, ErrAppendPEM 154 } 155 156 clientCert := make([]tls.Certificate, 0, 1) 157 if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" { 158 if ccert == "" || ckey == "" { 159 return nil, ErrTLSCertKeyConfig 160 } 161 certs, err := tls.LoadX509KeyPair(ccert, ckey) 162 if err != nil { 163 return nil, err 164 } 165 clientCert = append(clientCert, certs) 166 } 167 168 insecureSkipVerify := false 169 if len(config.Params["x-tls-insecure-skip-verify"]) > 0 { 170 x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"]) 171 if err != nil { 172 return nil, err 173 } 174 insecureSkipVerify = x 175 } 176 177 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 178 RootCAs: rootCertPool, 179 Certificates: clientCert, 180 InsecureSkipVerify: insecureSkipVerify, 181 }) 182 if err != nil { 183 return nil, err 184 } 185 } 186 } 187 188 return config, nil 189 } 190 191 func (m *Mysql) Open(url string) (database.Driver, error) { 192 config, err := urlToMySQLConfig(url) 193 if err != nil { 194 return nil, err 195 } 196 197 customParams, err := extractCustomQueryParams(config) 198 if err != nil { 199 return nil, err 200 } 201 202 db, err := sql.Open("mysql", config.FormatDSN()) 203 if err != nil { 204 return nil, err 205 } 206 207 mx, err := WithInstance(db, &Config{ 208 DatabaseName: config.DBName, 209 MigrationsTable: customParams["x-migrations-table"], 210 }) 211 if err != nil { 212 return nil, err 213 } 214 215 return mx, nil 216 } 217 218 func (m *Mysql) Close() error { 219 connErr := m.conn.Close() 220 dbErr := m.db.Close() 221 if connErr != nil || dbErr != nil { 222 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 223 } 224 return nil 225 } 226 227 func (m *Mysql) Lock() error { 228 if m.isLocked { 229 return database.ErrLocked 230 } 231 232 aid, err := database.GenerateAdvisoryLockId( 233 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 234 if err != nil { 235 return err 236 } 237 238 query := "SELECT GET_LOCK(?, 10)" 239 var success bool 240 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 241 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 242 } 243 244 if success { 245 m.isLocked = true 246 return nil 247 } 248 249 return database.ErrLocked 250 } 251 252 func (m *Mysql) Unlock() error { 253 if !m.isLocked { 254 return nil 255 } 256 257 aid, err := database.GenerateAdvisoryLockId( 258 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 259 if err != nil { 260 return err 261 } 262 263 query := `SELECT RELEASE_LOCK(?)` 264 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 265 return &database.Error{OrigErr: err, Query: []byte(query)} 266 } 267 268 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 269 // in which case isLocked should be true until the timeout expires -- synchronizing 270 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 271 272 m.isLocked = false 273 return nil 274 } 275 276 func (m *Mysql) Run(migration io.Reader) error { 277 migr, err := ioutil.ReadAll(migration) 278 if err != nil { 279 return err 280 } 281 282 query := string(migr[:]) 283 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 284 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 285 } 286 287 return nil 288 } 289 290 func (m *Mysql) SetVersion(version int, dirty bool) error { 291 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 292 if err != nil { 293 return &database.Error{OrigErr: err, Err: "transaction start failed"} 294 } 295 296 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 297 if _, err := tx.ExecContext(context.Background(), query); err != nil { 298 if errRollback := tx.Rollback(); errRollback != nil { 299 err = multierror.Append(err, errRollback) 300 } 301 return &database.Error{OrigErr: err, Query: []byte(query)} 302 } 303 304 if version >= 0 { 305 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 306 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 307 if errRollback := tx.Rollback(); errRollback != nil { 308 err = multierror.Append(err, errRollback) 309 } 310 return &database.Error{OrigErr: err, Query: []byte(query)} 311 } 312 } 313 314 if err := tx.Commit(); err != nil { 315 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 316 } 317 318 return nil 319 } 320 321 func (m *Mysql) Version() (version int, dirty bool, err error) { 322 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 323 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 324 switch { 325 case err == sql.ErrNoRows: 326 return database.NilVersion, false, nil 327 328 case err != nil: 329 if e, ok := err.(*mysql.MySQLError); ok { 330 if e.Number == 0 { 331 return database.NilVersion, false, nil 332 } 333 } 334 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 335 336 default: 337 return version, dirty, nil 338 } 339 } 340 341 func (m *Mysql) Drop() (err error) { 342 // select all tables 343 query := `SHOW TABLES LIKE '%'` 344 tables, err := m.conn.QueryContext(context.Background(), query) 345 if err != nil { 346 return &database.Error{OrigErr: err, Query: []byte(query)} 347 } 348 defer func() { 349 if errClose := tables.Close(); errClose != nil { 350 err = multierror.Append(err, errClose) 351 } 352 }() 353 354 // delete one table after another 355 tableNames := make([]string, 0) 356 for tables.Next() { 357 var tableName string 358 if err := tables.Scan(&tableName); err != nil { 359 return err 360 } 361 if len(tableName) > 0 { 362 tableNames = append(tableNames, tableName) 363 } 364 } 365 366 if len(tableNames) > 0 { 367 // disable checking foreign key constraints until finished 368 query = `SET foreign_key_checks = 0` 369 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 370 return &database.Error{OrigErr: err, Query: []byte(query)} 371 } 372 373 defer func() { 374 // enable foreign key checks 375 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 376 }() 377 378 // delete one by one ... 379 for _, t := range tableNames { 380 query = "DROP TABLE IF EXISTS `" + t + "`" 381 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 382 return &database.Error{OrigErr: err, Query: []byte(query)} 383 } 384 } 385 } 386 387 return nil 388 } 389 390 // ensureVersionTable checks if versions table exists and, if not, creates it. 391 // Note that this function locks the database, which deviates from the usual 392 // convention of "caller locks" in the Mysql type. 393 func (m *Mysql) ensureVersionTable() (err error) { 394 if err = m.Lock(); err != nil { 395 return err 396 } 397 398 defer func() { 399 if e := m.Unlock(); e != nil { 400 if err == nil { 401 err = e 402 } else { 403 err = multierror.Append(err, e) 404 } 405 } 406 }() 407 408 // check if migration table exists 409 var result string 410 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 411 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 412 if err != sql.ErrNoRows { 413 return &database.Error{OrigErr: err, Query: []byte(query)} 414 } 415 } else { 416 return nil 417 } 418 419 // if not, create the empty migration table 420 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 421 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 422 return &database.Error{OrigErr: err, Query: []byte(query)} 423 } 424 return nil 425 } 426 427 // Returns the bool value of the input. 428 // The 2nd return value indicates if the input was a valid bool value 429 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 430 func readBool(input string) (value bool, valid bool) { 431 switch input { 432 case "1", "true", "TRUE", "True": 433 return true, true 434 case "0", "false", "FALSE", "False": 435 return false, true 436 } 437 438 // Not a valid bool value 439 return 440 }