github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/mysql/mysql.go (about) 1 //go:build go1.9 2 // +build go1.9 3 4 package mysql 5 6 import ( 7 "context" 8 "crypto/tls" 9 "crypto/x509" 10 "database/sql" 11 "fmt" 12 "io" 13 nurl "net/url" 14 "os" 15 "strconv" 16 "strings" 17 "time" 18 19 "go.uber.org/atomic" 20 21 "github.com/go-sql-driver/mysql" 22 "github.com/golang-migrate/migrate/v4/database" 23 "github.com/hashicorp/go-multierror" 24 ) 25 26 var _ database.Driver = (*Mysql)(nil) // explicit compile time type check 27 28 func init() { 29 database.Register("mysql", &Mysql{}) 30 } 31 32 var DefaultMigrationsTable = "schema_migrations" 33 34 var ( 35 ErrDatabaseDirty = fmt.Errorf("database is dirty") 36 ErrNilConfig = fmt.Errorf("no config") 37 ErrNoDatabaseName = fmt.Errorf("no database name") 38 ErrAppendPEM = fmt.Errorf("failed to append PEM") 39 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 40 ) 41 42 type Config struct { 43 MigrationsTable string 44 DatabaseName string 45 NoLock bool 46 StatementTimeout time.Duration 47 } 48 49 type Mysql struct { 50 // mysql RELEASE_LOCK must be called from the same conn, so 51 // just do everything over a single conn anyway. 52 conn *sql.Conn 53 db *sql.DB 54 isLocked atomic.Bool 55 56 config *Config 57 } 58 59 // connection instance must have `multiStatements` set to true 60 func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) { 61 if config == nil { 62 return nil, ErrNilConfig 63 } 64 65 if err := conn.PingContext(ctx); err != nil { 66 return nil, err 67 } 68 69 mx := &Mysql{ 70 conn: conn, 71 db: nil, 72 config: config, 73 } 74 75 if config.DatabaseName == "" { 76 query := `SELECT DATABASE()` 77 var databaseName sql.NullString 78 if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil { 79 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 80 } 81 82 if len(databaseName.String) == 0 { 83 return nil, ErrNoDatabaseName 84 } 85 86 config.DatabaseName = databaseName.String 87 } 88 89 if len(config.MigrationsTable) == 0 { 90 config.MigrationsTable = DefaultMigrationsTable 91 } 92 93 if err := mx.ensureVersionTable(); err != nil { 94 return nil, err 95 } 96 97 return mx, nil 98 } 99 100 // instance must have `multiStatements` set to true 101 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 102 ctx := context.Background() 103 104 if err := instance.Ping(); err != nil { 105 return nil, err 106 } 107 108 conn, err := instance.Conn(ctx) 109 if err != nil { 110 return nil, err 111 } 112 113 mx, err := WithConnection(ctx, conn, config) 114 if err != nil { 115 return nil, err 116 } 117 118 mx.db = instance 119 120 return mx, nil 121 } 122 123 // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from 124 // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL 125 func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) { 126 if c == nil { 127 return nil, ErrNilConfig 128 } 129 customQueryParams := map[string]string{} 130 131 for k, v := range c.Params { 132 if strings.HasPrefix(k, "x-") { 133 customQueryParams[k] = v 134 delete(c.Params, k) 135 } 136 } 137 return customQueryParams, nil 138 } 139 140 func urlToMySQLConfig(url string) (*mysql.Config, error) { 141 // Need to parse out custom TLS parameters and call 142 // mysql.RegisterTLSConfig() before mysql.ParseDSN() is called 143 // which consumes the registered tls.Config 144 // Fixes: https://github.com/golang-migrate/migrate/issues/411 145 // 146 // Can't use url.Parse() since it fails to parse MySQL DSNs 147 // mysql.ParseDSN() also searches for "?" to find query parameters: 148 // https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344 149 if idx := strings.LastIndex(url, "?"); idx > 0 { 150 rawParams := url[idx+1:] 151 parsedParams, err := nurl.ParseQuery(rawParams) 152 if err != nil { 153 return nil, err 154 } 155 156 ctls := parsedParams.Get("tls") 157 if len(ctls) > 0 { 158 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 159 rootCertPool := x509.NewCertPool() 160 pem, err := os.ReadFile(parsedParams.Get("x-tls-ca")) 161 if err != nil { 162 return nil, err 163 } 164 165 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 166 return nil, ErrAppendPEM 167 } 168 169 clientCert := make([]tls.Certificate, 0, 1) 170 if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" { 171 if ccert == "" || ckey == "" { 172 return nil, ErrTLSCertKeyConfig 173 } 174 certs, err := tls.LoadX509KeyPair(ccert, ckey) 175 if err != nil { 176 return nil, err 177 } 178 clientCert = append(clientCert, certs) 179 } 180 181 insecureSkipVerify := false 182 insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify") 183 if len(insecureSkipVerifyStr) > 0 { 184 x, err := strconv.ParseBool(insecureSkipVerifyStr) 185 if err != nil { 186 return nil, err 187 } 188 insecureSkipVerify = x 189 } 190 191 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 192 RootCAs: rootCertPool, 193 Certificates: clientCert, 194 InsecureSkipVerify: insecureSkipVerify, 195 }) 196 if err != nil { 197 return nil, err 198 } 199 } 200 } 201 } 202 203 config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://")) 204 if err != nil { 205 return nil, err 206 } 207 208 config.MultiStatements = true 209 210 // Keep backwards compatibility from when we used net/url.Parse() to parse the DSN. 211 // net/url.Parse() would automatically unescape it for us. 212 // See: https://play.golang.org/p/q9j1io-YICQ 213 user, err := nurl.QueryUnescape(config.User) 214 if err != nil { 215 return nil, err 216 } 217 config.User = user 218 219 password, err := nurl.QueryUnescape(config.Passwd) 220 if err != nil { 221 return nil, err 222 } 223 config.Passwd = password 224 225 return config, nil 226 } 227 228 func (m *Mysql) Open(url string) (database.Driver, error) { 229 config, err := urlToMySQLConfig(url) 230 if err != nil { 231 return nil, err 232 } 233 234 customParams, err := extractCustomQueryParams(config) 235 if err != nil { 236 return nil, err 237 } 238 239 noLockParam, noLock := customParams["x-no-lock"], false 240 if noLockParam != "" { 241 noLock, err = strconv.ParseBool(noLockParam) 242 if err != nil { 243 return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err) 244 } 245 } 246 247 statementTimeoutParam := customParams["x-statement-timeout"] 248 statementTimeout := 0 249 if statementTimeoutParam != "" { 250 statementTimeout, err = strconv.Atoi(statementTimeoutParam) 251 if err != nil { 252 return nil, fmt.Errorf("could not parse x-statement-timeout as float: %w", err) 253 } 254 } 255 256 db, err := sql.Open("mysql", config.FormatDSN()) 257 if err != nil { 258 return nil, err 259 } 260 261 mx, err := WithInstance(db, &Config{ 262 DatabaseName: config.DBName, 263 MigrationsTable: customParams["x-migrations-table"], 264 NoLock: noLock, 265 StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, 266 }) 267 if err != nil { 268 return nil, err 269 } 270 271 return mx, nil 272 } 273 274 func (m *Mysql) Close() error { 275 connErr := m.conn.Close() 276 var dbErr error 277 if m.db != nil { 278 dbErr = m.db.Close() 279 } 280 281 if connErr != nil || dbErr != nil { 282 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 283 } 284 return nil 285 } 286 287 func (m *Mysql) Lock() error { 288 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { 289 if m.config.NoLock { 290 return nil 291 } 292 aid, err := database.GenerateAdvisoryLockId( 293 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 294 if err != nil { 295 return err 296 } 297 298 query := "SELECT GET_LOCK(?, 10)" 299 var success bool 300 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 301 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 302 } 303 304 if !success { 305 return database.ErrLocked 306 } 307 308 return nil 309 }) 310 } 311 312 func (m *Mysql) Unlock() error { 313 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { 314 if m.config.NoLock { 315 return nil 316 } 317 318 aid, err := database.GenerateAdvisoryLockId( 319 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 320 if err != nil { 321 return err 322 } 323 324 query := `SELECT RELEASE_LOCK(?)` 325 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 326 return &database.Error{OrigErr: err, Query: []byte(query)} 327 } 328 329 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 330 // in which case isLocked should be true until the timeout expires -- synchronizing 331 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 332 333 return nil 334 }) 335 } 336 337 func (m *Mysql) Run(migration io.Reader) error { 338 migr, err := io.ReadAll(migration) 339 if err != nil { 340 return err 341 } 342 343 ctx := context.Background() 344 if m.config.StatementTimeout != 0 { 345 var cancel context.CancelFunc 346 ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout) 347 defer cancel() 348 } 349 350 query := string(migr[:]) 351 if _, err := m.conn.ExecContext(ctx, query); err != nil { 352 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 353 } 354 355 return nil 356 } 357 358 func (m *Mysql) SetVersion(version int, dirty bool) error { 359 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) 360 if err != nil { 361 return &database.Error{OrigErr: err, Err: "transaction start failed"} 362 } 363 364 query := "DELETE FROM `" + m.config.MigrationsTable + "`" 365 if _, err := tx.ExecContext(context.Background(), query); err != nil { 366 if errRollback := tx.Rollback(); errRollback != nil { 367 err = multierror.Append(err, errRollback) 368 } 369 return &database.Error{OrigErr: err, Query: []byte(query)} 370 } 371 372 // Also re-write the schema version for nil dirty versions to prevent 373 // empty schema version for failed down migration on the first migration 374 // See: https://github.com/golang-migrate/migrate/issues/330 375 if version >= 0 || (version == database.NilVersion && dirty) { 376 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 377 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 378 if errRollback := tx.Rollback(); errRollback != nil { 379 err = multierror.Append(err, errRollback) 380 } 381 return &database.Error{OrigErr: err, Query: []byte(query)} 382 } 383 } 384 385 if err := tx.Commit(); err != nil { 386 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 387 } 388 389 return nil 390 } 391 392 func (m *Mysql) Version() (version int, dirty bool, err error) { 393 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 394 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 395 switch { 396 case err == sql.ErrNoRows: 397 return database.NilVersion, false, nil 398 399 case err != nil: 400 if e, ok := err.(*mysql.MySQLError); ok { 401 if e.Number == 0 { 402 return database.NilVersion, false, nil 403 } 404 } 405 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 406 407 default: 408 return version, dirty, nil 409 } 410 } 411 412 func (m *Mysql) Drop() (err error) { 413 // select all tables 414 query := `SHOW TABLES LIKE '%'` 415 tables, err := m.conn.QueryContext(context.Background(), query) 416 if err != nil { 417 return &database.Error{OrigErr: err, Query: []byte(query)} 418 } 419 defer func() { 420 if errClose := tables.Close(); errClose != nil { 421 err = multierror.Append(err, errClose) 422 } 423 }() 424 425 // delete one table after another 426 tableNames := make([]string, 0) 427 for tables.Next() { 428 var tableName string 429 if err := tables.Scan(&tableName); err != nil { 430 return err 431 } 432 if len(tableName) > 0 { 433 tableNames = append(tableNames, tableName) 434 } 435 } 436 if err := tables.Err(); err != nil { 437 return &database.Error{OrigErr: err, Query: []byte(query)} 438 } 439 440 if len(tableNames) > 0 { 441 // disable checking foreign key constraints until finished 442 query = `SET foreign_key_checks = 0` 443 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 444 return &database.Error{OrigErr: err, Query: []byte(query)} 445 } 446 447 defer func() { 448 // enable foreign key checks 449 _, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`) 450 }() 451 452 // delete one by one ... 453 for _, t := range tableNames { 454 query = "DROP TABLE IF EXISTS `" + t + "`" 455 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 456 return &database.Error{OrigErr: err, Query: []byte(query)} 457 } 458 } 459 } 460 461 return nil 462 } 463 464 // ensureVersionTable checks if versions table exists and, if not, creates it. 465 // Note that this function locks the database, which deviates from the usual 466 // convention of "caller locks" in the Mysql type. 467 func (m *Mysql) ensureVersionTable() (err error) { 468 if err = m.Lock(); err != nil { 469 return err 470 } 471 472 defer func() { 473 if e := m.Unlock(); e != nil { 474 if err == nil { 475 err = e 476 } else { 477 err = multierror.Append(err, e) 478 } 479 } 480 }() 481 482 // check if migration table exists 483 var result string 484 query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'` 485 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 486 if err != sql.ErrNoRows { 487 return &database.Error{OrigErr: err, Query: []byte(query)} 488 } 489 } else { 490 return nil 491 } 492 493 // if not, create the empty migration table 494 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 495 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 496 return &database.Error{OrigErr: err, Query: []byte(query)} 497 } 498 return nil 499 } 500 501 // Returns the bool value of the input. 502 // The 2nd return value indicates if the input was a valid bool value 503 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 504 func readBool(input string) (value bool, valid bool) { 505 switch input { 506 case "1", "true", "TRUE", "True": 507 return true, true 508 case "0", "false", "FALSE", "False": 509 return false, true 510 } 511 512 // Not a valid bool value 513 return 514 }