github.com/nagyist/migrate/v4@v4.14.6/database/mysql/mysql.go (about)

     1  // +build go1.9
     2  
     3  package mysql
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	nurl "net/url"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  import (
    19  	"github.com/go-sql-driver/mysql"
    20  	"github.com/hashicorp/go-multierror"
    21  )
    22  
    23  import (
    24  	"github.com/golang-migrate/migrate/v4/database"
    25  )
    26  
    27  func init() {
    28  	database.Register("mysql", &Mysql{})
    29  }
    30  
    // DefaultMigrationsTable is the table used to record the schema version
    // when Config.MigrationsTable is left empty.
    var DefaultMigrationsTable = "schema_migrations"

    // Sentinel errors returned by this driver.
    var (
    	// ErrDatabaseDirty reports a dirty database; it is not referenced by
    	// any function in this file.
    	ErrDatabaseDirty = fmt.Errorf("database is dirty")

    	// ErrNilConfig is returned when a nil configuration is supplied.
    	ErrNilConfig = fmt.Errorf("no config")

    	// ErrNoDatabaseName is returned when no schema name is configured and
    	// SELECT DATABASE() yields none.
    	ErrNoDatabaseName = fmt.Errorf("no database name")

    	// ErrAppendPEM is returned when the x-tls-ca file holds no usable PEM
    	// certificate.
    	ErrAppendPEM = fmt.Errorf("failed to append PEM")

    	// ErrTLSCertKeyConfig is returned when only one of x-tls-cert and
    	// x-tls-key is given.
    	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    )
    40  
    // Config holds the user-tunable settings of the MySQL driver.
    type Config struct {
    	// MigrationsTable names the table storing the current version;
    	// WithInstance substitutes DefaultMigrationsTable when it is empty.
    	MigrationsTable string
    	// DatabaseName is the schema being migrated; WithInstance fills it from
    	// SELECT DATABASE() when it is empty.
    	DatabaseName string
    	// NoLock skips the server-side GET_LOCK/RELEASE_LOCK calls; Lock and
    	// Unlock then only toggle the in-memory flag.
    	NoLock bool
    }
    46  
    47  type Mysql struct {
    48  	// mysql RELEASE_LOCK must be called from the same conn, so
    49  	// just do everything over a single conn anyway.
    50  	conn     *sql.Conn
    51  	db       *sql.DB
    52  	isLocked bool
    53  
    54  	config *Config
    55  }
    56  
    57  // instance must have `multiStatements` set to true
    58  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    59  	if config == nil {
    60  		return nil, ErrNilConfig
    61  	}
    62  
    63  	if err := instance.Ping(); err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	if config.DatabaseName == "" {
    68  		query := `SELECT DATABASE()`
    69  		var databaseName sql.NullString
    70  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    71  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    72  		}
    73  
    74  		if len(databaseName.String) == 0 {
    75  			return nil, ErrNoDatabaseName
    76  		}
    77  
    78  		config.DatabaseName = databaseName.String
    79  	}
    80  
    81  	if len(config.MigrationsTable) == 0 {
    82  		config.MigrationsTable = DefaultMigrationsTable
    83  	}
    84  
    85  	conn, err := instance.Conn(context.Background())
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	mx := &Mysql{
    91  		conn:   conn,
    92  		db:     instance,
    93  		config: config,
    94  	}
    95  
    96  	if err := mx.ensureVersionTable(); err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	return mx, nil
   101  }
   102  
   103  // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
   104  // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
   105  func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
   106  	if c == nil {
   107  		return nil, ErrNilConfig
   108  	}
   109  	customQueryParams := map[string]string{}
   110  
   111  	for k, v := range c.Params {
   112  		if strings.HasPrefix(k, "x-") {
   113  			customQueryParams[k] = v
   114  			delete(c.Params, k)
   115  		}
   116  	}
   117  	return customQueryParams, nil
   118  }
   119  
   120  func urlToMySQLConfig(url string) (*mysql.Config, error) {
   121  	// Need to parse out custom TLS parameters and call
   122  	// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
   123  	// which consumes the registered tls.Config
   124  	// Fixes: https://github.com/golang-migrate/migrate/issues/411
   125  	//
   126  	// Can't use url.Parse() since it fails to parse MySQL DSNs
   127  	// mysql.ParseDSN() also searches for "?" to find query parameters:
   128  	// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
   129  	if idx := strings.LastIndex(url, "?"); idx > 0 {
   130  		rawParams := url[idx+1:]
   131  		parsedParams, err := nurl.ParseQuery(rawParams)
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  
   136  		ctls := parsedParams.Get("tls")
   137  		if len(ctls) > 0 {
   138  			if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   139  				rootCertPool := x509.NewCertPool()
   140  				pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
   141  				if err != nil {
   142  					return nil, err
   143  				}
   144  
   145  				if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   146  					return nil, ErrAppendPEM
   147  				}
   148  
   149  				clientCert := make([]tls.Certificate, 0, 1)
   150  				if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
   151  					if ccert == "" || ckey == "" {
   152  						return nil, ErrTLSCertKeyConfig
   153  					}
   154  					certs, err := tls.LoadX509KeyPair(ccert, ckey)
   155  					if err != nil {
   156  						return nil, err
   157  					}
   158  					clientCert = append(clientCert, certs)
   159  				}
   160  
   161  				insecureSkipVerify := false
   162  				insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
   163  				if len(insecureSkipVerifyStr) > 0 {
   164  					x, err := strconv.ParseBool(insecureSkipVerifyStr)
   165  					if err != nil {
   166  						return nil, err
   167  					}
   168  					insecureSkipVerify = x
   169  				}
   170  
   171  				err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   172  					RootCAs:            rootCertPool,
   173  					Certificates:       clientCert,
   174  					InsecureSkipVerify: insecureSkipVerify,
   175  				})
   176  				if err != nil {
   177  					return nil, err
   178  				}
   179  			}
   180  		}
   181  	}
   182  
   183  	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	config.MultiStatements = true
   189  
   190  	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
   191  	// net/url.Parse() would automatically unescape it for us.
   192  	// See: https://play.golang.org/p/q9j1io-YICQ
   193  	user, err := nurl.QueryUnescape(config.User)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	config.User = user
   198  
   199  	password, err := nurl.QueryUnescape(config.Passwd)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  	config.Passwd = password
   204  
   205  	return config, nil
   206  }
   207  
   208  func (m *Mysql) Open(url string) (database.Driver, error) {
   209  	config, err := urlToMySQLConfig(url)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  
   214  	customParams, err := extractCustomQueryParams(config)
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  
   219  	noLockParam, noLock := customParams["x-no-lock"], false
   220  	if noLockParam != "" {
   221  		noLock, err = strconv.ParseBool(noLockParam)
   222  		if err != nil {
   223  			return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
   224  		}
   225  	}
   226  
   227  	db, err := sql.Open("mysql", config.FormatDSN())
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	mx, err := WithInstance(db, &Config{
   233  		DatabaseName:    config.DBName,
   234  		MigrationsTable: customParams["x-migrations-table"],
   235  		NoLock:          noLock,
   236  	})
   237  	if err != nil {
   238  		return nil, err
   239  	}
   240  
   241  	return mx, nil
   242  }
   243  
   244  func (m *Mysql) Close() error {
   245  	connErr := m.conn.Close()
   246  	dbErr := m.db.Close()
   247  	if connErr != nil || dbErr != nil {
   248  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   249  	}
   250  	return nil
   251  }
   252  
   253  func (m *Mysql) Lock() error {
   254  	if m.isLocked {
   255  		return database.ErrLocked
   256  	}
   257  
   258  	if m.config.NoLock {
   259  		m.isLocked = true
   260  		return nil
   261  	}
   262  
   263  	aid, err := database.GenerateAdvisoryLockId(
   264  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   265  	if err != nil {
   266  		return err
   267  	}
   268  
   269  	query := "SELECT GET_LOCK(?, 10)"
   270  	var success bool
   271  	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   272  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   273  	}
   274  
   275  	if success {
   276  		m.isLocked = true
   277  		return nil
   278  	}
   279  
   280  	return database.ErrLocked
   281  }
   282  
   283  func (m *Mysql) Unlock() error {
   284  	if !m.isLocked {
   285  		return nil
   286  	}
   287  
   288  	if m.config.NoLock {
   289  		m.isLocked = false
   290  		return nil
   291  	}
   292  
   293  	aid, err := database.GenerateAdvisoryLockId(
   294  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   295  	if err != nil {
   296  		return err
   297  	}
   298  
   299  	query := `SELECT RELEASE_LOCK(?)`
   300  	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   301  		return &database.Error{OrigErr: err, Query: []byte(query)}
   302  	}
   303  
   304  	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   305  	// in which case isLocked should be true until the timeout expires -- synchronizing
   306  	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   307  
   308  	m.isLocked = false
   309  	return nil
   310  }
   311  
   312  func (m *Mysql) Run(migration io.Reader) error {
   313  	migr, err := ioutil.ReadAll(migration)
   314  	if err != nil {
   315  		return err
   316  	}
   317  
   318  	query := string(migr[:])
   319  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   320  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   321  	}
   322  
   323  	return nil
   324  }
   325  
   326  func (m *Mysql) SetVersion(version int, dirty bool) error {
   327  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   328  	if err != nil {
   329  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   330  	}
   331  
   332  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   333  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   334  		if errRollback := tx.Rollback(); errRollback != nil {
   335  			err = multierror.Append(err, errRollback)
   336  		}
   337  		return &database.Error{OrigErr: err, Query: []byte(query)}
   338  	}
   339  
   340  	// Also re-write the schema version for nil dirty versions to prevent
   341  	// empty schema version for failed down migration on the first migration
   342  	// See: https://github.com/golang-migrate/migrate/issues/330
   343  	if version >= 0 || (version == database.NilVersion && dirty) {
   344  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   345  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   346  			if errRollback := tx.Rollback(); errRollback != nil {
   347  				err = multierror.Append(err, errRollback)
   348  			}
   349  			return &database.Error{OrigErr: err, Query: []byte(query)}
   350  		}
   351  	}
   352  
   353  	if err := tx.Commit(); err != nil {
   354  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   355  	}
   356  
   357  	return nil
   358  }
   359  
   360  func (m *Mysql) Version() (version int, dirty bool, err error) {
   361  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   362  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   363  	switch {
   364  	case err == sql.ErrNoRows:
   365  		return database.NilVersion, false, nil
   366  
   367  	case err != nil:
   368  		if e, ok := err.(*mysql.MySQLError); ok {
   369  			if e.Number == 0 {
   370  				return database.NilVersion, false, nil
   371  			}
   372  		}
   373  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   374  
   375  	default:
   376  		return version, dirty, nil
   377  	}
   378  }
   379  
   380  func (m *Mysql) Drop() (err error) {
   381  	// select all tables
   382  	query := `SHOW TABLES LIKE '%'`
   383  	tables, err := m.conn.QueryContext(context.Background(), query)
   384  	if err != nil {
   385  		return &database.Error{OrigErr: err, Query: []byte(query)}
   386  	}
   387  	defer func() {
   388  		if errClose := tables.Close(); errClose != nil {
   389  			err = multierror.Append(err, errClose)
   390  		}
   391  	}()
   392  
   393  	// delete one table after another
   394  	tableNames := make([]string, 0)
   395  	for tables.Next() {
   396  		var tableName string
   397  		if err := tables.Scan(&tableName); err != nil {
   398  			return err
   399  		}
   400  		if len(tableName) > 0 {
   401  			tableNames = append(tableNames, tableName)
   402  		}
   403  	}
   404  	if err := tables.Err(); err != nil {
   405  		return &database.Error{OrigErr: err, Query: []byte(query)}
   406  	}
   407  
   408  	if len(tableNames) > 0 {
   409  		// disable checking foreign key constraints until finished
   410  		query = `SET foreign_key_checks = 0`
   411  		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   412  			return &database.Error{OrigErr: err, Query: []byte(query)}
   413  		}
   414  
   415  		defer func() {
   416  			// enable foreign key checks
   417  			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
   418  		}()
   419  
   420  		// delete one by one ...
   421  		for _, t := range tableNames {
   422  			query = "DROP TABLE IF EXISTS `" + t + "`"
   423  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   424  				return &database.Error{OrigErr: err, Query: []byte(query)}
   425  			}
   426  		}
   427  	}
   428  
   429  	return nil
   430  }
   431  
   432  // ensureVersionTable checks if versions table exists and, if not, creates it.
   433  // Note that this function locks the database, which deviates from the usual
   434  // convention of "caller locks" in the Mysql type.
   435  func (m *Mysql) ensureVersionTable() (err error) {
   436  	if err = m.Lock(); err != nil {
   437  		return err
   438  	}
   439  
   440  	defer func() {
   441  		if e := m.Unlock(); e != nil {
   442  			if err == nil {
   443  				err = e
   444  			} else {
   445  				err = multierror.Append(err, e)
   446  			}
   447  		}
   448  	}()
   449  
   450  	// check if migration table exists
   451  	var result string
   452  	query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
   453  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   454  		if err != sql.ErrNoRows {
   455  			return &database.Error{OrigErr: err, Query: []byte(query)}
   456  		}
   457  	} else {
   458  		return nil
   459  	}
   460  
   461  	// if not, create the empty migration table
   462  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   463  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   464  		return &database.Error{OrigErr: err, Query: []byte(query)}
   465  	}
   466  	return nil
   467  }
   468  
   // readBool interprets input using the boolean spellings accepted by
   // go-sql-driver/mysql: "1", "true", "TRUE", "True" and "0", "false",
   // "FALSE", "False". The second result reports whether input was one of
   // them. For any other input both results are false.
   // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   func readBool(input string) (value bool, valid bool) {
   	switch input {
   	case "1", "true", "TRUE", "True":
   		value, valid = true, true
   	case "0", "false", "FALSE", "False":
   		valid = true
   	}
   	return value, valid
   }