github.com/tooolbox/migrate/v4@v4.6.2-0.20200325001913-461b03b92064/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 "github.com/hashicorp/go-multierror" 21 ) 22 23 import ( 24 "github.com/tooolbox/migrate/v4/database" 25 ) 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 } 45 46 type Mysql struct { 47 // mysql RELEASE_LOCK must be called from the same conn, so 48 // just do everything over a single conn anyway. 49 conn *sql.Conn 50 db *sql.DB 51 isLocked bool 52 53 config *Config 54 } 55 56 // instance must have `multiStatements` set to true 57 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 58 if config == nil { 59 return nil, ErrNilConfig 60 } 61 62 if err := instance.Ping(); err != nil { 63 return nil, err 64 } 65 66 if config.DatabaseName == "" { 67 query := `SELECT DATABASE()` 68 var databaseName sql.NullString 69 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 70 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 71 } 72 73 if len(databaseName.String) == 0 { 74 return nil, ErrNoDatabaseName 75 } 76 77 config.DatabaseName = databaseName.String 78 } 79 80 if len(config.MigrationsTable) == 0 { 81 config.MigrationsTable = DefaultMigrationsTable 82 } 83 84 conn, err := instance.Conn(context.Background()) 85 if err != nil { 86 return nil, err 87 } 88 89 mx := &Mysql{ 90 conn: conn, 91 db: instance, 92 config: config, 93 } 94 95 if err := mx.ensureVersionTable(); err != nil { 96 return nil, err 97 } 98 99 return mx, nil 100 } 101 102 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 103 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 104 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 105 if c == nil { 106 return nil, ErrNilConfig 107 } 108 customQueryParams := map[string]string{} 109 110 for k, v := range c.Params { 111 if strings.HasPrefix(k, "x-") { 112 customQueryParams[k] = v 113 delete(c.Params, k) 114 } 115 } 116 return customQueryParams, nil 117 } 118 119 func urlToMySQLConfig(url string) (*mysql.Config, error) { 120 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 121 if err != nil { 122 return nil, err 123 } 124 125 config.MultiStatements = true 126 127 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 128 // net/url.Parse() would automatically unescape it for us. 129 // See: https://play.golang.org/p/q9j1io-YICQ 130 user, err := nurl.QueryUnescape(config.User) 131 if err != nil { 132 return nil, err 133 } 134 config.User = user 135 136 password, err := nurl.QueryUnescape(config.Passwd) 137 if err != nil { 138 return nil, err 139 } 140 config.Passwd = password 141 142 // use custom TLS? 143 ctls := config.TLSConfig 144 if len(ctls) > 0 { 145 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 146 rootCertPool := x509.NewCertPool() 147 pem, err := ioutil.ReadFile(config.Params["x-tls-ca"]) 148 if err != nil { 149 return nil, err 150 } 151 152 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 153 return nil, ErrAppendPEM 154 } 155 156 clientCert := make([]tls.Certificate, 0, 1) 157 if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" { 158 if ccert == "" || ckey == "" { 159 return nil, ErrTLSCertKeyConfig 160 } 161 certs, err := tls.LoadX509KeyPair(ccert, ckey) 162 if err != nil { 163 return nil, err 164 } 165 clientCert = append(clientCert, certs) 166 } 167 168 insecureSkipVerify := false 169 if len(config.Params["x-tls-insecure-skip-verify"]) > 0 { 170 x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"]) 171 if err != nil { 172 return nil, err 173 } 174 insecureSkipVerify = x 175 } 176 177 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 178 RootCAs: rootCertPool, 179 Certificates: clientCert, 180 InsecureSkipVerify: insecureSkipVerify, 181 }) 182 if err != nil { 183 return nil, err 184 } 185 } 186 } 187 188 return config, nil 189 } 190 191 func (m *Mysql) Open(url string) (database.Driver, error) { 192 config, err := urlToMySQLConfig(url) 193 if err != nil { 194 return nil, err 195 } 196 197 customParams, err := extractCustomQueryParams(config) 198 if err != nil { 199 return nil, err 200 } 201 202 db, err := sql.Open("mysql", config.FormatDSN()) 203 if err != nil { 204 return nil, err 205 } 206 207 mx, err := WithInstance(db, &Config{ 208 DatabaseName: config.DBName, 209 MigrationsTable: customParams["x-migrations-table"], 210 }) 211 if err != nil { 212 return nil, err 213 } 214 215 return mx, nil 216 } 217 218 func (m *Mysql) Close() error { 219 connErr := m.conn.Close() 220 dbErr := m.db.Close() 221 if connErr != nil || dbErr != nil { 222 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 223 } 224 return nil 225 } 226 227 func (m *Mysql) Lock() error { 228 if m.isLocked { 229 return database.ErrLocked 230 } 231 232 aid, err := database.GenerateAdvisoryLockId( 233 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 234 if err != nil { 235 return err 236 } 237 238 query := "SELECT GET_LOCK(?, 10)" 239 var success bool 240 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 241 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 242 } 243 244 if success { 245 m.isLocked = true 246 return nil 247 } 248 249 return database.ErrLocked 250 } 251 252 func (m *Mysql) Unlock() error { 253 if !m.isLocked { 254 return nil 255 } 256 257 aid, err := database.GenerateAdvisoryLockId( 258 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 259 if err != nil { 260 return err 261 } 262 263 query := `SELECT RELEASE_LOCK(?)` 264 var success *bool 265 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 266 return &database.Error{OrigErr: err, Query: []byte(query)} 267 } 268 269 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 270 // in which case isLocked should be true until the timeout expires -- synchronizing 271 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 272 273 if success == nil { 274 return database.Error{OrigErr: database.ErrNotLocked, Err: "can't unlock, named lock did not exist"} 275 } else if !(*success) { 276 return database.Error{OrigErr: database.ErrNotLocked, Err: "can't unlock, named lock established in different thread"} 277 } 278 279 m.isLocked = false 280 return nil 281 } 282 283 func (m *Mysql) Run(migration io.Reader) error { 284 migr, err := ioutil.ReadAll(migration) 285 if err != nil { 286 return err 287 } 288 289 query := string(migr[:]) 290 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 291 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 292 } 293 294 return nil 295 } 296 297 func (m *Mysql) SetVersion(version int, dirty bool) error { 298 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 299 if err != nil { 300 return &database.Error{OrigErr: err, Err: "transaction start failed"} 301 } 302 303 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 304 if _, err := tx.ExecContext(context.Background(), query); err != nil { 305 if errRollback := tx.Rollback(); errRollback != nil { 306 err = multierror.Append(err, errRollback) 307 } 308 return &database.Error{OrigErr: err, Query: []byte(query)} 309 } 310 311 if version >= 0 { 312 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 313 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 314 if errRollback := tx.Rollback(); errRollback != nil { 315 err = multierror.Append(err, errRollback) 316 } 317 return &database.Error{OrigErr: err, Query: []byte(query)} 318 } 319 } 320 321 if err := tx.Commit(); err != nil { 322 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 323 } 324 325 return nil 326 } 327 328 func (m *Mysql) Version() (version int, dirty bool, err error) { 329 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 330 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 331 switch { 332 case err == sql.ErrNoRows: 333 return database.NilVersion, false, nil 334 335 case err != nil: 336 if e, ok := err.(*mysql.MySQLError); ok { 337 if e.Number == 0 { 338 return database.NilVersion, false, nil 339 } 340 } 341 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 342 343 default: 344 return version, dirty, nil 345 } 346 } 347 348 func (m *Mysql) Drop() (err error) { 349 // select all tables 350 query := `SHOW TABLES LIKE '%'` 351 tables, err := m.conn.QueryContext(context.Background(), query) 352 if err != nil { 353 return &database.Error{OrigErr: err, Query: []byte(query)} 354 } 355 defer func() { 356 if errClose := tables.Close(); errClose != nil { 357 err = multierror.Append(err, errClose) 358 } 359 }() 360 361 // delete one table after another 362 tableNames := make([]string, 0) 363 for tables.Next() { 364 var tableName string 365 if err := tables.Scan(&tableName); err != nil { 366 return err 367 } 368 if len(tableName) > 0 { 369 tableNames = append(tableNames, tableName) 370 } 371 } 372 373 if len(tableNames) > 0 { 374 // disable checking foreign key constraints until finished 375 query = `SET foreign_key_checks = 0` 376 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 377 return &database.Error{OrigErr: err, Query: []byte(query)} 378 } 379 380 defer func() { 381 // enable foreign key checks 382 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 383 }() 384 385 // delete one by one ... 386 for _, t := range tableNames { 387 query = "DROP TABLE IF EXISTS `" + t + "`" 388 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 389 return &database.Error{OrigErr: err, Query: []byte(query)} 390 } 391 } 392 } 393 394 return nil 395 } 396 397 // ensureVersionTable checks if versions table exists and, if not, creates it. 398 // Note that this function locks the database, which deviates from the usual 399 // convention of "caller locks" in the Mysql type. 400 func (m *Mysql) ensureVersionTable() (err error) { 401 if err = m.Lock(); err != nil { 402 return err 403 } 404 405 defer func() { 406 if e := m.Unlock(); e != nil { 407 if err == nil { 408 err = e 409 } else { 410 err = multierror.Append(err, e) 411 } 412 } 413 }() 414 415 // check if migration table exists 416 var result string 417 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 418 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 419 if err != sql.ErrNoRows { 420 return &database.Error{OrigErr: err, Query: []byte(query)} 421 } 422 } else { 423 return nil 424 } 425 426 // if not, create the empty migration table 427 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 428 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 429 return &database.Error{OrigErr: err, Query: []byte(query)} 430 } 431 return nil 432 } 433 434 // Returns the bool value of the input. 435 // The 2nd return value indicates if the input was a valid bool value 436 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 437 func readBool(input string) (value bool, valid bool) { 438 switch input { 439 case "1", "true", "TRUE", "True": 440 return true, true 441 case "0", "false", "FALSE", "False": 442 return false, true 443 } 444 445 // Not a valid bool value 446 return 447 }