github.com/nokia/migrate/v4@v4.16.0/database/mysql/mysql.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package mysql 5 6 import ( 7 "context" 8 "crypto/tls" 9 "crypto/x509" 10 "database/sql" 11 "fmt" 12 "io" 13 "io/ioutil" 14 nurl "net/url" 15 "strconv" 16 "strings" 17 18 "go.uber.org/atomic" 19 20 "github.com/go-sql-driver/mysql" 21 "github.com/hashicorp/go-multierror" 22 "github.com/nokia/migrate/v4/database" 23 "github.com/nokia/migrate/v4/source" 24 ) 25 26 var _ database.Driver = (*Mysql)(nil) // explicit compile time type check 27 28 func init() { 29 database.Register("mysql", &Mysql{}) 30 } 31 32 var DefaultMigrationsTable = "schema_migrations" 33 34 var ( 35 ErrDatabaseDirty = fmt.Errorf("database is dirty") 36 ErrNilConfig = fmt.Errorf("no config") 37 ErrNoDatabaseName = fmt.Errorf("no database name") 38 ErrAppendPEM = fmt.Errorf("failed to append PEM") 39 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 40 ) 41 42 type Config struct { 43 MigrationsTable string 44 DatabaseName string 45 NoLock bool 46 } 47 48 type Mysql struct { 49 // mysql RELEASE_LOCK must be called from the same conn, so 50 // just do everything over a single conn anyway. 51 conn *sql.Conn 52 db *sql.DB 53 isLocked atomic.Bool 54 55 config *Config 56 } 57 58 // connection instance must have `multiStatements` set to true 59 func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) { 60 if config == nil { 61 return nil, ErrNilConfig 62 } 63 64 if err := conn.PingContext(ctx); err != nil { 65 return nil, err 66 } 67 68 mx := &Mysql{ 69 conn: conn, 70 db: nil, 71 config: config, 72 } 73 74 if config.DatabaseName == "" { 75 query := `SELECT DATABASE()` 76 var databaseName sql.NullString 77 if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { 78 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 79 } 80 81 if len(databaseName.String) == 0 { 82 return nil, ErrNoDatabaseName 83 } 84 85 config.DatabaseName = databaseName.String 86 } 87 88 if len(config.MigrationsTable) == 0 { 89 config.MigrationsTable = DefaultMigrationsTable 90 } 91 92 if err := mx.ensureVersionTable(); err != nil { 93 return nil, err 94 } 95 96 return mx, nil 97 } 98 99 // instance must have `multiStatements` set to true 100 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 101 ctx := context.Background() 102 103 if err := instance.Ping(); err != nil { 104 return nil, err 105 } 106 107 conn, err := instance.Conn(ctx) 108 if err != nil { 109 return nil, err 110 } 111 112 mx, err := WithConnection(ctx, conn, config) 113 if err != nil { 114 return nil, err 115 } 116 117 mx.db = instance 118 119 return mx, nil 120 } 121 122 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 123 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 124 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 125 if c == nil { 126 return nil, ErrNilConfig 127 } 128 customQueryParams := map[string]string{} 129 130 for k, v := range c.Params { 131 if strings.HasPrefix(k, "x-") { 132 customQueryParams[k] = v 133 delete(c.Params, k) 134 } 135 } 136 return customQueryParams, nil 137 } 138 139 func urlToMySQLConfig(url string) (*mysql.Config, error) { 140 // Need to parse out custom TLS parameters and call 141 // mysql.RegisterTLSConfig() before mysql.ParseDSN() is called 142 // which consumes the registered tls.Config 143 // Fixes: https://github.com/nokia/migrate/issues/411 144 // 145 // Can't use url.Parse() since it fails to parse MySQL DSNs 146 // mysql.ParseDSN() also searches for "?" to find query parameters: 147 // https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344 148 if idx := strings.LastIndex(url, "?"); idx > 0 { 149 rawParams := url[idx+1:] 150 parsedParams, err := nurl.ParseQuery(rawParams) 151 if err != nil { 152 return nil, err 153 } 154 155 ctls := parsedParams.Get("tls") 156 if len(ctls) > 0 { 157 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 158 rootCertPool := x509.NewCertPool() 159 pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca")) 160 if err != nil { 161 return nil, err 162 } 163 164 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 165 return nil, ErrAppendPEM 166 } 167 168 clientCert := make([]tls.Certificate, 0, 1) 169 if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" { 170 if ccert == "" || ckey == "" { 171 return nil, ErrTLSCertKeyConfig 172 } 173 certs, err := tls.LoadX509KeyPair(ccert, ckey) 174 if err != nil { 175 return nil, err 176 } 177 clientCert = append(clientCert, certs) 178 } 179 180 insecureSkipVerify := false 181 insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify") 182 if len(insecureSkipVerifyStr) > 0 { 183 x, err := strconv.ParseBool(insecureSkipVerifyStr) 184 if err != nil { 185 return nil, err 186 } 187 insecureSkipVerify = x 188 } 189 190 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 191 RootCAs: rootCertPool, 192 Certificates: clientCert, 193 InsecureSkipVerify: insecureSkipVerify, 194 }) 195 if err != nil { 196 return nil, err 197 } 198 } 199 } 200 } 201 202 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 203 if err != nil { 204 return nil, err 205 } 206 207 config.MultiStatements = true 208 209 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 210 // net/url.Parse() would automatically unescape it for us. 211 // See: https://play.golang.org/p/q9j1io-YICQ 212 user, err := nurl.QueryUnescape(config.User) 213 if err != nil { 214 return nil, err 215 } 216 config.User = user 217 218 password, err := nurl.QueryUnescape(config.Passwd) 219 if err != nil { 220 return nil, err 221 } 222 config.Passwd = password 223 224 return config, nil 225 } 226 227 func (m *Mysql) Open(url string) (database.Driver, error) { 228 config, err := urlToMySQLConfig(url) 229 if err != nil { 230 return nil, err 231 } 232 233 customParams, err := extractCustomQueryParams(config) 234 if err != nil { 235 return nil, err 236 } 237 238 noLockParam, noLock := customParams["x-no-lock"], false 239 if noLockParam != "" { 240 noLock, err = strconv.ParseBool(noLockParam) 241 if err != nil { 242 return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err) 243 } 244 } 245 246 db, err := sql.Open("mysql", config.FormatDSN()) 247 if err != nil { 248 return nil, err 249 } 250 251 mx, err := WithInstance(db, &Config{ 252 DatabaseName: config.DBName, 253 MigrationsTable: customParams["x-migrations-table"], 254 NoLock: noLock, 255 }) 256 if err != nil { 257 return nil, err 258 } 259 260 return mx, nil 261 } 262 263 func (m *Mysql) Close() error { 264 connErr := m.conn.Close() 265 var dbErr error 266 if m.db != nil { 267 dbErr = m.db.Close() 268 } 269 270 if connErr != nil || dbErr != nil { 271 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 272 } 273 return nil 274 } 275 276 func (m *Mysql) Lock() error { 277 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { 278 if m.config.NoLock { 279 return nil 280 } 281 aid, err := database.GenerateAdvisoryLockId( 282 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 283 if err != nil { 284 return err 285 } 286 287 query := "SELECT GET_LOCK(?, 10)" 288 var success bool 289 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 290 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 291 } 292 293 if !success { 294 return database.ErrLocked 295 } 296 297 return nil 298 }) 299 } 300 301 func (m *Mysql) Unlock() error { 302 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { 303 if m.config.NoLock { 304 return nil 305 } 306 307 aid, err := database.GenerateAdvisoryLockId( 308 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 309 if err != nil { 310 return err 311 } 312 313 query := `SELECT RELEASE_LOCK(?)` 314 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 315 return &database.Error{OrigErr: err, Query: []byte(query)} 316 } 317 318 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 319 // in which case isLocked should be true until the timeout expires -- synchronizing 320 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 321 322 return nil 323 }) 324 } 325 326 func (m *Mysql) Run(migration io.Reader) error { 327 migr, err := ioutil.ReadAll(migration) 328 if err != nil { 329 return err 330 } 331 332 query := string(migr[:]) 333 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 334 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 335 } 336 337 return nil 338 } 339 340 func (m *Mysql) RunFunctionMigration(fn source.MigrationFunc) error { 341 return database.ErrNotImpl 342 } 343 344 func (m *Mysql) SetVersion(version int, dirty bool) error { 345 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) 346 if err != nil { 347 return &database.Error{OrigErr: err, Err: "transaction start failed"} 348 } 349 350 query := "DELETE FROM `" + m.config.MigrationsTable + "`" 351 if _, err := tx.ExecContext(context.Background(), query); err != nil { 352 if errRollback := tx.Rollback(); errRollback != nil { 353 err = multierror.Append(err, errRollback) 354 } 355 return &database.Error{OrigErr: err, Query: []byte(query)} 356 } 357 358 // Also re-write the schema version for nil dirty versions to prevent 359 // empty schema version for failed down migration on the first migration 360 // See: https://github.com/nokia/migrate/issues/330 361 if version >= 0 || (version == database.NilVersion && dirty) { 362 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 363 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 364 if errRollback := tx.Rollback(); errRollback != nil { 365 err = multierror.Append(err, errRollback) 366 } 367 return &database.Error{OrigErr: err, Query: []byte(query)} 368 } 369 } 370 371 if err := tx.Commit(); err != nil { 372 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 373 } 374 375 return nil 376 } 377 378 func (m *Mysql) Version() (version int, dirty bool, err error) { 379 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 380 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 381 switch { 382 case err == sql.ErrNoRows: 383 return database.NilVersion, false, nil 384 385 case err != nil: 386 if e, ok := err.(*mysql.MySQLError); ok { 387 if e.Number == 0 { 388 return database.NilVersion, false, nil 389 } 390 } 391 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 392 393 default: 394 return version, dirty, nil 395 } 396 } 397 398 func (m *Mysql) Drop() (err error) { 399 // select all tables 400 query := `SHOW TABLES LIKE '%'` 401 tables, err := m.conn.QueryContext(context.Background(), query) 402 if err != nil { 403 return &database.Error{OrigErr: err, Query: []byte(query)} 404 } 405 defer func() { 406 if errClose := tables.Close(); errClose != nil { 407 err = multierror.Append(err, errClose) 408 } 409 }() 410 411 // delete one table after another 412 tableNames := make([]string, 0) 413 for tables.Next() { 414 var tableName string 415 if err := tables.Scan(&tableName); err != nil { 416 return err 417 } 418 if len(tableName) > 0 { 419 tableNames = append(tableNames, tableName) 420 } 421 } 422 if err := tables.Err(); err != nil { 423 return &database.Error{OrigErr: err, Query: []byte(query)} 424 } 425 426 if len(tableNames) > 0 { 427 // disable checking foreign key constraints until finished 428 query = `SET foreign_key_checks = 0` 429 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 430 return &database.Error{OrigErr: err, Query: []byte(query)} 431 } 432 433 defer func() { 434 // enable foreign key checks 435 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 436 }() 437 438 // delete one by one ... 439 for _, t := range tableNames { 440 query = "DROP TABLE IF EXISTS `" + t + "`" 441 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 442 return &database.Error{OrigErr: err, Query: []byte(query)} 443 } 444 } 445 } 446 447 return nil 448 } 449 450 // ensureVersionTable checks if versions table exists and, if not, creates it. 451 // Note that this function locks the database, which deviates from the usual 452 // convention of "caller locks" in the Mysql type. 453 func (m *Mysql) ensureVersionTable() (err error) { 454 if err = m.Lock(); err != nil { 455 return err 456 } 457 458 defer func() { 459 if e := m.Unlock(); e != nil { 460 if err == nil { 461 err = e 462 } else { 463 err = multierror.Append(err, e) 464 } 465 } 466 }() 467 468 // check if migration table exists 469 var result string 470 query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'` 471 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 472 if err != sql.ErrNoRows { 473 return &database.Error{OrigErr: err, Query: []byte(query)} 474 } 475 } else { 476 return nil 477 } 478 479 // if not, create the empty migration table 480 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 481 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 482 return &database.Error{OrigErr: err, Query: []byte(query)} 483 } 484 return nil 485 } 486 487 // Returns the bool value of the input. 488 // The 2nd return value indicates if the input was a valid bool value 489 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 490 func readBool(input string) (value bool, valid bool) { 491 switch input { 492 case "1", "true", "TRUE", "True": 493 return true, true 494 case "0", "false", "FALSE", "False": 495 return false, true 496 } 497 498 // Not a valid bool value 499 return 500 }