github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/mysql/mysql.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package mysql 5 6 import ( 7 "context" 8 "crypto/tls" 9 "crypto/x509" 10 "database/sql" 11 "fmt" 12 "io" 13 "io/ioutil" 14 nurl "net/url" 15 "strconv" 16 "strings" 17 18 "go.uber.org/atomic" 19 20 "github.com/go-sql-driver/mysql" 21 "github.com/hashicorp/go-multierror" 22 "github.com/seashell-org/golang-migrate/v4/database" 23 ) 24 25 var _ database.Driver = (*Mysql)(nil) // explicit compile time type check 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 NoLock bool 45 } 46 47 type Mysql struct { 48 // mysql RELEASE_LOCK must be called from the same conn, so 49 // just do everything over a single conn anyway. 50 conn *sql.Conn 51 db *sql.DB 52 isLocked atomic.Bool 53 54 config *Config 55 } 56 57 // connection instance must have `multiStatements` set to true 58 func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) { 59 if config == nil { 60 return nil, ErrNilConfig 61 } 62 63 if err := conn.PingContext(ctx); err != nil { 64 return nil, err 65 } 66 67 mx := &Mysql{ 68 conn: conn, 69 db: nil, 70 config: config, 71 } 72 73 if config.DatabaseName == "" { 74 query := `SELECT DATABASE()` 75 var databaseName sql.NullString 76 if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { 77 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 78 } 79 80 if len(databaseName.String) == 0 { 81 return nil, ErrNoDatabaseName 82 } 83 84 config.DatabaseName = databaseName.String 85 } 86 87 if len(config.MigrationsTable) == 0 { 88 config.MigrationsTable = DefaultMigrationsTable 89 } 90 91 if err := mx.ensureVersionTable(); err != nil { 92 return nil, err 93 } 94 95 return mx, nil 96 } 97 98 // instance must have `multiStatements` set to true 99 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 100 ctx := context.Background() 101 102 if err := instance.Ping(); err != nil { 103 return nil, err 104 } 105 106 conn, err := instance.Conn(ctx) 107 if err != nil { 108 return nil, err 109 } 110 111 mx, err := WithConnection(ctx, conn, config) 112 if err != nil { 113 return nil, err 114 } 115 116 mx.db = instance 117 118 return mx, nil 119 } 120 121 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 122 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 123 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 124 if c == nil { 125 return nil, ErrNilConfig 126 } 127 customQueryParams := map[string]string{} 128 129 for k, v := range c.Params { 130 if strings.HasPrefix(k, "x-") { 131 customQueryParams[k] = v 132 delete(c.Params, k) 133 } 134 } 135 return customQueryParams, nil 136 } 137 138 func urlToMySQLConfig(url string) (*mysql.Config, error) { 139 // Need to parse out custom TLS parameters and call 140 // mysql.RegisterTLSConfig() before mysql.ParseDSN() is called 141 // which consumes the registered tls.Config 142 // Fixes: https://github.com/seashell-org/golang-migrate/issues/411 143 // 144 // Can't use url.Parse() since it fails to parse MySQL DSNs 145 // mysql.ParseDSN() also searches for "?" to find query parameters: 146 // https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344 147 if idx := strings.LastIndex(url, "?"); idx > 0 { 148 rawParams := url[idx+1:] 149 parsedParams, err := nurl.ParseQuery(rawParams) 150 if err != nil { 151 return nil, err 152 } 153 154 ctls := parsedParams.Get("tls") 155 if len(ctls) > 0 { 156 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 157 rootCertPool := x509.NewCertPool() 158 pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca")) 159 if err != nil { 160 return nil, err 161 } 162 163 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 164 return nil, ErrAppendPEM 165 } 166 167 clientCert := make([]tls.Certificate, 0, 1) 168 if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" { 169 if ccert == "" || ckey == "" { 170 return nil, ErrTLSCertKeyConfig 171 } 172 certs, err := tls.LoadX509KeyPair(ccert, ckey) 173 if err != nil { 174 return nil, err 175 } 176 clientCert = append(clientCert, certs) 177 } 178 179 insecureSkipVerify := false 180 insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify") 181 if len(insecureSkipVerifyStr) > 0 { 182 x, err := strconv.ParseBool(insecureSkipVerifyStr) 183 if err != nil { 184 return nil, err 185 } 186 insecureSkipVerify = x 187 } 188 189 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 190 RootCAs: rootCertPool, 191 Certificates: clientCert, 192 InsecureSkipVerify: insecureSkipVerify, 193 }) 194 if err != nil { 195 return nil, err 196 } 197 } 198 } 199 } 200 201 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 202 if err != nil { 203 return nil, err 204 } 205 206 config.MultiStatements = true 207 208 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 209 // net/url.Parse() would automatically unescape it for us. 210 // See: https://play.golang.org/p/q9j1io-YICQ 211 user, err := nurl.QueryUnescape(config.User) 212 if err != nil { 213 return nil, err 214 } 215 config.User = user 216 217 password, err := nurl.QueryUnescape(config.Passwd) 218 if err != nil { 219 return nil, err 220 } 221 config.Passwd = password 222 223 return config, nil 224 } 225 226 func (m *Mysql) Open(url string) (database.Driver, error) { 227 config, err := urlToMySQLConfig(url) 228 if err != nil { 229 return nil, err 230 } 231 232 customParams, err := extractCustomQueryParams(config) 233 if err != nil { 234 return nil, err 235 } 236 237 noLockParam, noLock := customParams["x-no-lock"], false 238 if noLockParam != "" { 239 noLock, err = strconv.ParseBool(noLockParam) 240 if err != nil { 241 return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err) 242 } 243 } 244 245 db, err := sql.Open("mysql", config.FormatDSN()) 246 if err != nil { 247 return nil, err 248 } 249 250 mx, err := WithInstance(db, &Config{ 251 DatabaseName: config.DBName, 252 MigrationsTable: customParams["x-migrations-table"], 253 NoLock: noLock, 254 }) 255 if err != nil { 256 return nil, err 257 } 258 259 return mx, nil 260 } 261 262 func (m *Mysql) Close() error { 263 connErr := m.conn.Close() 264 var dbErr error 265 if m.db != nil { 266 dbErr = m.db.Close() 267 } 268 269 if connErr != nil || dbErr != nil { 270 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 271 } 272 return nil 273 } 274 275 func (m *Mysql) Lock() error { 276 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { 277 if m.config.NoLock { 278 return nil 279 } 280 aid, err := database.GenerateAdvisoryLockId( 281 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 282 if err != nil { 283 return err 284 } 285 286 query := "SELECT GET_LOCK(?, 10)" 287 var success bool 288 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 289 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 290 } 291 292 if !success { 293 return database.ErrLocked 294 } 295 296 return nil 297 }) 298 } 299 300 func (m *Mysql) Unlock() error { 301 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { 302 if m.config.NoLock { 303 return nil 304 } 305 306 aid, err := database.GenerateAdvisoryLockId( 307 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 308 if err != nil { 309 return err 310 } 311 312 query := `SELECT RELEASE_LOCK(?)` 313 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 314 return &database.Error{OrigErr: err, Query: []byte(query)} 315 } 316 317 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 318 // in which case isLocked should be true until the timeout expires -- synchronizing 319 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 320 321 return nil 322 }) 323 } 324 325 func (m *Mysql) Run(migration io.Reader) error { 326 migr, err := ioutil.ReadAll(migration) 327 if err != nil { 328 return err 329 } 330 331 query := string(migr[:]) 332 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 333 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 334 } 335 336 return nil 337 } 338 339 func (m *Mysql) SetVersion(version int, dirty bool) error { 340 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) 341 if err != nil { 342 return &database.Error{OrigErr: err, Err: "transaction start failed"} 343 } 344 345 query := "DELETE FROM `" + m.config.MigrationsTable + "`" 346 if _, err := tx.ExecContext(context.Background(), query); err != nil { 347 if errRollback := tx.Rollback(); errRollback != nil { 348 err = multierror.Append(err, errRollback) 349 } 350 return &database.Error{OrigErr: err, Query: []byte(query)} 351 } 352 353 // Also re-write the schema version for nil dirty versions to prevent 354 // empty schema version for failed down migration on the first migration 355 // See: https://github.com/seashell-org/golang-migrate/issues/330 356 if version >= 0 || (version == database.NilVersion && dirty) { 357 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 358 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 359 if errRollback := tx.Rollback(); errRollback != nil { 360 err = multierror.Append(err, errRollback) 361 } 362 return &database.Error{OrigErr: err, Query: []byte(query)} 363 } 364 } 365 366 if err := tx.Commit(); err != nil { 367 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 368 } 369 370 return nil 371 } 372 373 func (m *Mysql) Version() (version int, dirty bool, err error) { 374 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 375 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 376 switch { 377 case err == sql.ErrNoRows: 378 return database.NilVersion, false, nil 379 380 case err != nil: 381 if e, ok := err.(*mysql.MySQLError); ok { 382 if e.Number == 0 { 383 return database.NilVersion, false, nil 384 } 385 } 386 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 387 388 default: 389 return version, dirty, nil 390 } 391 } 392 393 func (m *Mysql) Drop() (err error) { 394 // select all tables 395 query := `SHOW TABLES LIKE '%'` 396 tables, err := m.conn.QueryContext(context.Background(), query) 397 if err != nil { 398 return &database.Error{OrigErr: err, Query: []byte(query)} 399 } 400 defer func() { 401 if errClose := tables.Close(); errClose != nil { 402 err = multierror.Append(err, errClose) 403 } 404 }() 405 406 // delete one table after another 407 tableNames := make([]string, 0) 408 for tables.Next() { 409 var tableName string 410 if err := tables.Scan(&tableName); err != nil { 411 return err 412 } 413 if len(tableName) > 0 { 414 tableNames = append(tableNames, tableName) 415 } 416 } 417 if err := tables.Err(); err != nil { 418 return &database.Error{OrigErr: err, Query: []byte(query)} 419 } 420 421 if len(tableNames) > 0 { 422 // disable checking foreign key constraints until finished 423 query = `SET foreign_key_checks = 0` 424 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 425 return &database.Error{OrigErr: err, Query: []byte(query)} 426 } 427 428 defer func() { 429 // enable foreign key checks 430 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 431 }() 432 433 // delete one by one ... 434 for _, t := range tableNames { 435 query = "DROP TABLE IF EXISTS `" + t + "`" 436 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 437 return &database.Error{OrigErr: err, Query: []byte(query)} 438 } 439 } 440 } 441 442 return nil 443 } 444 445 // ensureVersionTable checks if versions table exists and, if not, creates it. 446 // Note that this function locks the database, which deviates from the usual 447 // convention of "caller locks" in the Mysql type. 448 func (m *Mysql) ensureVersionTable() (err error) { 449 if err = m.Lock(); err != nil { 450 return err 451 } 452 453 defer func() { 454 if e := m.Unlock(); e != nil { 455 if err == nil { 456 err = e 457 } else { 458 err = multierror.Append(err, e) 459 } 460 } 461 }() 462 463 // check if migration table exists 464 var result string 465 query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'` 466 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 467 if err != sql.ErrNoRows { 468 return &database.Error{OrigErr: err, Query: []byte(query)} 469 } 470 } else { 471 return nil 472 } 473 474 // if not, create the empty migration table 475 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 476 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 477 return &database.Error{OrigErr: err, Query: []byte(query)} 478 } 479 return nil 480 } 481 482 // Returns the bool value of the input. 483 // The 2nd return value indicates if the input was a valid bool value 484 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 485 func readBool(input string) (value bool, valid bool) { 486 switch input { 487 case "1", "true", "TRUE", "True": 488 return true, true 489 case "0", "false", "FALSE", "False": 490 return false, true 491 } 492 493 // Not a valid bool value 494 return 495 }