github.com/dynastymasra/migrate/v4@v4.11.0/database/mysql/mysql.go (about)

     1  // +build go1.9
     2  
     3  package mysql
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	nurl "net/url"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  import (
    19  	"github.com/go-sql-driver/mysql"
    20  	"github.com/hashicorp/go-multierror"
    21  )
    22  
    23  import (
    24  	"github.com/golang-migrate/migrate/v4/database"
    25  )
    26  
    27  func init() {
    28  	database.Register("mysql", &Mysql{})
    29  }
    30  
    31  var DefaultMigrationsTable = "schema_migrations"
    32  
    33  var (
    34  	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
    35  	ErrNilConfig        = fmt.Errorf("no config")
    36  	ErrNoDatabaseName   = fmt.Errorf("no database name")
    37  	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
    38  	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    39  )
    40  
    41  type Config struct {
    42  	MigrationsTable string
    43  	DatabaseName    string
    44  }
    45  
    46  type Mysql struct {
    47  	// mysql RELEASE_LOCK must be called from the same conn, so
    48  	// just do everything over a single conn anyway.
    49  	conn     *sql.Conn
    50  	db       *sql.DB
    51  	isLocked bool
    52  
    53  	config *Config
    54  }
    55  
    56  // instance must have `multiStatements` set to true
    57  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    58  	if config == nil {
    59  		return nil, ErrNilConfig
    60  	}
    61  
    62  	if err := instance.Ping(); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	if config.DatabaseName == "" {
    67  		query := `SELECT DATABASE()`
    68  		var databaseName sql.NullString
    69  		if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    70  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    71  		}
    72  
    73  		if len(databaseName.String) == 0 {
    74  			return nil, ErrNoDatabaseName
    75  		}
    76  
    77  		config.DatabaseName = databaseName.String
    78  	}
    79  
    80  	if len(config.MigrationsTable) == 0 {
    81  		config.MigrationsTable = DefaultMigrationsTable
    82  	}
    83  
    84  	conn, err := instance.Conn(context.Background())
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	mx := &Mysql{
    90  		conn:   conn,
    91  		db:     instance,
    92  		config: config,
    93  	}
    94  
    95  	if err := mx.ensureVersionTable(); err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	return mx, nil
   100  }
   101  
   102  // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
   103  // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
   104  func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
   105  	if c == nil {
   106  		return nil, ErrNilConfig
   107  	}
   108  	customQueryParams := map[string]string{}
   109  
   110  	for k, v := range c.Params {
   111  		if strings.HasPrefix(k, "x-") {
   112  			customQueryParams[k] = v
   113  			delete(c.Params, k)
   114  		}
   115  	}
   116  	return customQueryParams, nil
   117  }
   118  
   119  func urlToMySQLConfig(url string) (*mysql.Config, error) {
   120  	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  
   125  	config.MultiStatements = true
   126  
   127  	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
   128  	// net/url.Parse() would automatically unescape it for us.
   129  	// See: https://play.golang.org/p/q9j1io-YICQ
   130  	user, err := nurl.QueryUnescape(config.User)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	config.User = user
   135  
   136  	password, err := nurl.QueryUnescape(config.Passwd)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	config.Passwd = password
   141  
   142  	// use custom TLS?
   143  	ctls := config.TLSConfig
   144  	if len(ctls) > 0 {
   145  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   146  			rootCertPool := x509.NewCertPool()
   147  			pem, err := ioutil.ReadFile(config.Params["x-tls-ca"])
   148  			if err != nil {
   149  				return nil, err
   150  			}
   151  
   152  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   153  				return nil, ErrAppendPEM
   154  			}
   155  
   156  			clientCert := make([]tls.Certificate, 0, 1)
   157  			if ccert, ckey := config.Params["x-tls-cert"], config.Params["x-tls-key"]; ccert != "" || ckey != "" {
   158  				if ccert == "" || ckey == "" {
   159  					return nil, ErrTLSCertKeyConfig
   160  				}
   161  				certs, err := tls.LoadX509KeyPair(ccert, ckey)
   162  				if err != nil {
   163  					return nil, err
   164  				}
   165  				clientCert = append(clientCert, certs)
   166  			}
   167  
   168  			insecureSkipVerify := false
   169  			if len(config.Params["x-tls-insecure-skip-verify"]) > 0 {
   170  				x, err := strconv.ParseBool(config.Params["x-tls-insecure-skip-verify"])
   171  				if err != nil {
   172  					return nil, err
   173  				}
   174  				insecureSkipVerify = x
   175  			}
   176  
   177  			err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   178  				RootCAs:            rootCertPool,
   179  				Certificates:       clientCert,
   180  				InsecureSkipVerify: insecureSkipVerify,
   181  			})
   182  			if err != nil {
   183  				return nil, err
   184  			}
   185  		}
   186  	}
   187  
   188  	return config, nil
   189  }
   190  
   191  func (m *Mysql) Open(url string) (database.Driver, error) {
   192  	config, err := urlToMySQLConfig(url)
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  
   197  	customParams, err := extractCustomQueryParams(config)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  
   202  	db, err := sql.Open("mysql", config.FormatDSN())
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	mx, err := WithInstance(db, &Config{
   208  		DatabaseName:    config.DBName,
   209  		MigrationsTable: customParams["x-migrations-table"],
   210  	})
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	return mx, nil
   216  }
   217  
   218  func (m *Mysql) Close() error {
   219  	connErr := m.conn.Close()
   220  	dbErr := m.db.Close()
   221  	if connErr != nil || dbErr != nil {
   222  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   223  	}
   224  	return nil
   225  }
   226  
   227  func (m *Mysql) Lock() error {
   228  	if m.isLocked {
   229  		return database.ErrLocked
   230  	}
   231  
   232  	aid, err := database.GenerateAdvisoryLockId(
   233  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   234  	if err != nil {
   235  		return err
   236  	}
   237  
   238  	query := "SELECT GET_LOCK(?, 10)"
   239  	var success bool
   240  	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   241  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   242  	}
   243  
   244  	if success {
   245  		m.isLocked = true
   246  		return nil
   247  	}
   248  
   249  	return database.ErrLocked
   250  }
   251  
   252  func (m *Mysql) Unlock() error {
   253  	if !m.isLocked {
   254  		return nil
   255  	}
   256  
   257  	aid, err := database.GenerateAdvisoryLockId(
   258  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	query := `SELECT RELEASE_LOCK(?)`
   264  	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   265  		return &database.Error{OrigErr: err, Query: []byte(query)}
   266  	}
   267  
   268  	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   269  	// in which case isLocked should be true until the timeout expires -- synchronizing
   270  	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   271  
   272  	m.isLocked = false
   273  	return nil
   274  }
   275  
   276  func (m *Mysql) Run(migration io.Reader) error {
   277  	migr, err := ioutil.ReadAll(migration)
   278  	if err != nil {
   279  		return err
   280  	}
   281  
   282  	query := string(migr[:])
   283  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   284  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   285  	}
   286  
   287  	return nil
   288  }
   289  
   290  func (m *Mysql) SetVersion(version int, dirty bool) error {
   291  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   292  	if err != nil {
   293  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   294  	}
   295  
   296  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   297  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   298  		if errRollback := tx.Rollback(); errRollback != nil {
   299  			err = multierror.Append(err, errRollback)
   300  		}
   301  		return &database.Error{OrigErr: err, Query: []byte(query)}
   302  	}
   303  
   304  	// Also re-write the schema version for nil dirty versions to prevent
   305  	// empty schema version for failed down migration on the first migration
   306  	// See: https://github.com/golang-migrate/migrate/issues/330
   307  	if version >= 0 || (version == database.NilVersion && dirty) {
   308  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   309  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   310  			if errRollback := tx.Rollback(); errRollback != nil {
   311  				err = multierror.Append(err, errRollback)
   312  			}
   313  			return &database.Error{OrigErr: err, Query: []byte(query)}
   314  		}
   315  	}
   316  
   317  	if err := tx.Commit(); err != nil {
   318  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   319  	}
   320  
   321  	return nil
   322  }
   323  
   324  func (m *Mysql) Version() (version int, dirty bool, err error) {
   325  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   326  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   327  	switch {
   328  	case err == sql.ErrNoRows:
   329  		return database.NilVersion, false, nil
   330  
   331  	case err != nil:
   332  		if e, ok := err.(*mysql.MySQLError); ok {
   333  			if e.Number == 0 {
   334  				return database.NilVersion, false, nil
   335  			}
   336  		}
   337  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   338  
   339  	default:
   340  		return version, dirty, nil
   341  	}
   342  }
   343  
   344  func (m *Mysql) Drop() (err error) {
   345  	// select all tables
   346  	query := `SHOW TABLES LIKE '%'`
   347  	tables, err := m.conn.QueryContext(context.Background(), query)
   348  	if err != nil {
   349  		return &database.Error{OrigErr: err, Query: []byte(query)}
   350  	}
   351  	defer func() {
   352  		if errClose := tables.Close(); errClose != nil {
   353  			err = multierror.Append(err, errClose)
   354  		}
   355  	}()
   356  
   357  	// delete one table after another
   358  	tableNames := make([]string, 0)
   359  	for tables.Next() {
   360  		var tableName string
   361  		if err := tables.Scan(&tableName); err != nil {
   362  			return err
   363  		}
   364  		if len(tableName) > 0 {
   365  			tableNames = append(tableNames, tableName)
   366  		}
   367  	}
   368  
   369  	if len(tableNames) > 0 {
   370  		// disable checking foreign key constraints until finished
   371  		query = `SET foreign_key_checks = 0`
   372  		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   373  			return &database.Error{OrigErr: err, Query: []byte(query)}
   374  		}
   375  
   376  		defer func() {
   377  			// enable foreign key checks
   378  			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
   379  		}()
   380  
   381  		// delete one by one ...
   382  		for _, t := range tableNames {
   383  			query = "DROP TABLE IF EXISTS `" + t + "`"
   384  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   385  				return &database.Error{OrigErr: err, Query: []byte(query)}
   386  			}
   387  		}
   388  	}
   389  
   390  	return nil
   391  }
   392  
   393  // ensureVersionTable checks if versions table exists and, if not, creates it.
   394  // Note that this function locks the database, which deviates from the usual
   395  // convention of "caller locks" in the Mysql type.
   396  func (m *Mysql) ensureVersionTable() (err error) {
   397  	if err = m.Lock(); err != nil {
   398  		return err
   399  	}
   400  
   401  	defer func() {
   402  		if e := m.Unlock(); e != nil {
   403  			if err == nil {
   404  				err = e
   405  			} else {
   406  				err = multierror.Append(err, e)
   407  			}
   408  		}
   409  	}()
   410  
   411  	// check if migration table exists
   412  	var result string
   413  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   414  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   415  		if err != sql.ErrNoRows {
   416  			return &database.Error{OrigErr: err, Query: []byte(query)}
   417  		}
   418  	} else {
   419  		return nil
   420  	}
   421  
   422  	// if not, create the empty migration table
   423  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   424  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   425  		return &database.Error{OrigErr: err, Query: []byte(query)}
   426  	}
   427  	return nil
   428  }
   429  
   430  // Returns the bool value of the input.
   431  // The 2nd return value indicates if the input was a valid bool value
   432  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   433  func readBool(input string) (value bool, valid bool) {
   434  	switch input {
   435  	case "1", "true", "TRUE", "True":
   436  		return true, true
   437  	case "0", "false", "FALSE", "False":
   438  		return false, true
   439  	}
   440  
   441  	// Not a valid bool value
   442  	return
   443  }