github.com/nokia/migrate/v4@v4.16.0/database/mysql/mysql.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package mysql
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"database/sql"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	nurl "net/url"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"go.uber.org/atomic"
    19  
    20  	"github.com/go-sql-driver/mysql"
    21  	"github.com/hashicorp/go-multierror"
    22  	"github.com/nokia/migrate/v4/database"
    23  	"github.com/nokia/migrate/v4/source"
    24  )
    25  
    26  var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
    27  
    28  func init() {
    29  	database.Register("mysql", &Mysql{})
    30  }
    31  
    32  var DefaultMigrationsTable = "schema_migrations"
    33  
    34  var (
    35  	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
    36  	ErrNilConfig        = fmt.Errorf("no config")
    37  	ErrNoDatabaseName   = fmt.Errorf("no database name")
    38  	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
    39  	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    40  )
    41  
// Config holds the settings accepted by WithConnection and WithInstance.
// WithConnection fills in empty fields and writes the results back into the
// Config it is given.
type Config struct {
	// MigrationsTable names the table holding the current version; when
	// empty, WithConnection sets it to DefaultMigrationsTable.
	MigrationsTable string
	// DatabaseName is combined with MigrationsTable to derive the advisory
	// lock id; when empty, WithConnection fills it from SELECT DATABASE().
	DatabaseName    string
	// NoLock makes Lock and Unlock skip the GET_LOCK / RELEASE_LOCK queries.
	NoLock          bool
}
    47  
// Mysql implements database.Driver for MySQL.
type Mysql struct {
	// mysql RELEASE_LOCK must be called from the same conn, so
	// just do everything over a single conn anyway.
	conn     *sql.Conn
	// db is the pool conn was taken from when the driver is built through
	// WithInstance (and therefore Open); it stays nil for WithConnection.
	// Close closes it when non-nil.
	db       *sql.DB
	// isLocked records whether this driver considers the advisory lock held;
	// Lock and Unlock flip it via database.CasRestoreOnErr.
	isLocked atomic.Bool

	// config holds the effective settings, including the defaults filled in
	// by WithConnection.
	config *Config
}
    57  
    58  // connection instance must have `multiStatements` set to true
    59  func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
    60  	if config == nil {
    61  		return nil, ErrNilConfig
    62  	}
    63  
    64  	if err := conn.PingContext(ctx); err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	mx := &Mysql{
    69  		conn:   conn,
    70  		db:     nil,
    71  		config: config,
    72  	}
    73  
    74  	if config.DatabaseName == "" {
    75  		query := `SELECT DATABASE()`
    76  		var databaseName sql.NullString
    77  		if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
    78  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    79  		}
    80  
    81  		if len(databaseName.String) == 0 {
    82  			return nil, ErrNoDatabaseName
    83  		}
    84  
    85  		config.DatabaseName = databaseName.String
    86  	}
    87  
    88  	if len(config.MigrationsTable) == 0 {
    89  		config.MigrationsTable = DefaultMigrationsTable
    90  	}
    91  
    92  	if err := mx.ensureVersionTable(); err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	return mx, nil
    97  }
    98  
    99  // instance must have `multiStatements` set to true
   100  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
   101  	ctx := context.Background()
   102  
   103  	if err := instance.Ping(); err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	conn, err := instance.Conn(ctx)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	mx, err := WithConnection(ctx, conn, config)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	mx.db = instance
   118  
   119  	return mx, nil
   120  }
   121  
   122  // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
   123  // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
   124  func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
   125  	if c == nil {
   126  		return nil, ErrNilConfig
   127  	}
   128  	customQueryParams := map[string]string{}
   129  
   130  	for k, v := range c.Params {
   131  		if strings.HasPrefix(k, "x-") {
   132  			customQueryParams[k] = v
   133  			delete(c.Params, k)
   134  		}
   135  	}
   136  	return customQueryParams, nil
   137  }
   138  
   139  func urlToMySQLConfig(url string) (*mysql.Config, error) {
   140  	// Need to parse out custom TLS parameters and call
   141  	// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
   142  	// which consumes the registered tls.Config
   143  	// Fixes: https://github.com/nokia/migrate/issues/411
   144  	//
   145  	// Can't use url.Parse() since it fails to parse MySQL DSNs
   146  	// mysql.ParseDSN() also searches for "?" to find query parameters:
   147  	// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
   148  	if idx := strings.LastIndex(url, "?"); idx > 0 {
   149  		rawParams := url[idx+1:]
   150  		parsedParams, err := nurl.ParseQuery(rawParams)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  
   155  		ctls := parsedParams.Get("tls")
   156  		if len(ctls) > 0 {
   157  			if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   158  				rootCertPool := x509.NewCertPool()
   159  				pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
   160  				if err != nil {
   161  					return nil, err
   162  				}
   163  
   164  				if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   165  					return nil, ErrAppendPEM
   166  				}
   167  
   168  				clientCert := make([]tls.Certificate, 0, 1)
   169  				if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
   170  					if ccert == "" || ckey == "" {
   171  						return nil, ErrTLSCertKeyConfig
   172  					}
   173  					certs, err := tls.LoadX509KeyPair(ccert, ckey)
   174  					if err != nil {
   175  						return nil, err
   176  					}
   177  					clientCert = append(clientCert, certs)
   178  				}
   179  
   180  				insecureSkipVerify := false
   181  				insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
   182  				if len(insecureSkipVerifyStr) > 0 {
   183  					x, err := strconv.ParseBool(insecureSkipVerifyStr)
   184  					if err != nil {
   185  						return nil, err
   186  					}
   187  					insecureSkipVerify = x
   188  				}
   189  
   190  				err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   191  					RootCAs:            rootCertPool,
   192  					Certificates:       clientCert,
   193  					InsecureSkipVerify: insecureSkipVerify,
   194  				})
   195  				if err != nil {
   196  					return nil, err
   197  				}
   198  			}
   199  		}
   200  	}
   201  
   202  	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	config.MultiStatements = true
   208  
   209  	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
   210  	// net/url.Parse() would automatically unescape it for us.
   211  	// See: https://play.golang.org/p/q9j1io-YICQ
   212  	user, err := nurl.QueryUnescape(config.User)
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	config.User = user
   217  
   218  	password, err := nurl.QueryUnescape(config.Passwd)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	config.Passwd = password
   223  
   224  	return config, nil
   225  }
   226  
   227  func (m *Mysql) Open(url string) (database.Driver, error) {
   228  	config, err := urlToMySQLConfig(url)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	customParams, err := extractCustomQueryParams(config)
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	noLockParam, noLock := customParams["x-no-lock"], false
   239  	if noLockParam != "" {
   240  		noLock, err = strconv.ParseBool(noLockParam)
   241  		if err != nil {
   242  			return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
   243  		}
   244  	}
   245  
   246  	db, err := sql.Open("mysql", config.FormatDSN())
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  
   251  	mx, err := WithInstance(db, &Config{
   252  		DatabaseName:    config.DBName,
   253  		MigrationsTable: customParams["x-migrations-table"],
   254  		NoLock:          noLock,
   255  	})
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  
   260  	return mx, nil
   261  }
   262  
   263  func (m *Mysql) Close() error {
   264  	connErr := m.conn.Close()
   265  	var dbErr error
   266  	if m.db != nil {
   267  		dbErr = m.db.Close()
   268  	}
   269  
   270  	if connErr != nil || dbErr != nil {
   271  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   272  	}
   273  	return nil
   274  }
   275  
   276  func (m *Mysql) Lock() error {
   277  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   278  		if m.config.NoLock {
   279  			return nil
   280  		}
   281  		aid, err := database.GenerateAdvisoryLockId(
   282  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   283  		if err != nil {
   284  			return err
   285  		}
   286  
   287  		query := "SELECT GET_LOCK(?, 10)"
   288  		var success bool
   289  		if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   290  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   291  		}
   292  
   293  		if !success {
   294  			return database.ErrLocked
   295  		}
   296  
   297  		return nil
   298  	})
   299  }
   300  
   301  func (m *Mysql) Unlock() error {
   302  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   303  		if m.config.NoLock {
   304  			return nil
   305  		}
   306  
   307  		aid, err := database.GenerateAdvisoryLockId(
   308  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   309  		if err != nil {
   310  			return err
   311  		}
   312  
   313  		query := `SELECT RELEASE_LOCK(?)`
   314  		if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   315  			return &database.Error{OrigErr: err, Query: []byte(query)}
   316  		}
   317  
   318  		// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   319  		// in which case isLocked should be true until the timeout expires -- synchronizing
   320  		// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   321  
   322  		return nil
   323  	})
   324  }
   325  
   326  func (m *Mysql) Run(migration io.Reader) error {
   327  	migr, err := ioutil.ReadAll(migration)
   328  	if err != nil {
   329  		return err
   330  	}
   331  
   332  	query := string(migr[:])
   333  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   334  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   335  	}
   336  
   337  	return nil
   338  }
   339  
   340  func (m *Mysql) RunFunctionMigration(fn source.MigrationFunc) error {
   341  	return database.ErrNotImpl
   342  }
   343  
   344  func (m *Mysql) SetVersion(version int, dirty bool) error {
   345  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
   346  	if err != nil {
   347  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   348  	}
   349  
   350  	query := "DELETE FROM `" + m.config.MigrationsTable + "`"
   351  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   352  		if errRollback := tx.Rollback(); errRollback != nil {
   353  			err = multierror.Append(err, errRollback)
   354  		}
   355  		return &database.Error{OrigErr: err, Query: []byte(query)}
   356  	}
   357  
   358  	// Also re-write the schema version for nil dirty versions to prevent
   359  	// empty schema version for failed down migration on the first migration
   360  	// See: https://github.com/nokia/migrate/issues/330
   361  	if version >= 0 || (version == database.NilVersion && dirty) {
   362  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   363  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   364  			if errRollback := tx.Rollback(); errRollback != nil {
   365  				err = multierror.Append(err, errRollback)
   366  			}
   367  			return &database.Error{OrigErr: err, Query: []byte(query)}
   368  		}
   369  	}
   370  
   371  	if err := tx.Commit(); err != nil {
   372  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   373  	}
   374  
   375  	return nil
   376  }
   377  
   378  func (m *Mysql) Version() (version int, dirty bool, err error) {
   379  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   380  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   381  	switch {
   382  	case err == sql.ErrNoRows:
   383  		return database.NilVersion, false, nil
   384  
   385  	case err != nil:
   386  		if e, ok := err.(*mysql.MySQLError); ok {
   387  			if e.Number == 0 {
   388  				return database.NilVersion, false, nil
   389  			}
   390  		}
   391  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   392  
   393  	default:
   394  		return version, dirty, nil
   395  	}
   396  }
   397  
   398  func (m *Mysql) Drop() (err error) {
   399  	// select all tables
   400  	query := `SHOW TABLES LIKE '%'`
   401  	tables, err := m.conn.QueryContext(context.Background(), query)
   402  	if err != nil {
   403  		return &database.Error{OrigErr: err, Query: []byte(query)}
   404  	}
   405  	defer func() {
   406  		if errClose := tables.Close(); errClose != nil {
   407  			err = multierror.Append(err, errClose)
   408  		}
   409  	}()
   410  
   411  	// delete one table after another
   412  	tableNames := make([]string, 0)
   413  	for tables.Next() {
   414  		var tableName string
   415  		if err := tables.Scan(&tableName); err != nil {
   416  			return err
   417  		}
   418  		if len(tableName) > 0 {
   419  			tableNames = append(tableNames, tableName)
   420  		}
   421  	}
   422  	if err := tables.Err(); err != nil {
   423  		return &database.Error{OrigErr: err, Query: []byte(query)}
   424  	}
   425  
   426  	if len(tableNames) > 0 {
   427  		// disable checking foreign key constraints until finished
   428  		query = `SET foreign_key_checks = 0`
   429  		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   430  			return &database.Error{OrigErr: err, Query: []byte(query)}
   431  		}
   432  
   433  		defer func() {
   434  			// enable foreign key checks
   435  			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
   436  		}()
   437  
   438  		// delete one by one ...
   439  		for _, t := range tableNames {
   440  			query = "DROP TABLE IF EXISTS `" + t + "`"
   441  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   442  				return &database.Error{OrigErr: err, Query: []byte(query)}
   443  			}
   444  		}
   445  	}
   446  
   447  	return nil
   448  }
   449  
   450  // ensureVersionTable checks if versions table exists and, if not, creates it.
   451  // Note that this function locks the database, which deviates from the usual
   452  // convention of "caller locks" in the Mysql type.
   453  func (m *Mysql) ensureVersionTable() (err error) {
   454  	if err = m.Lock(); err != nil {
   455  		return err
   456  	}
   457  
   458  	defer func() {
   459  		if e := m.Unlock(); e != nil {
   460  			if err == nil {
   461  				err = e
   462  			} else {
   463  				err = multierror.Append(err, e)
   464  			}
   465  		}
   466  	}()
   467  
   468  	// check if migration table exists
   469  	var result string
   470  	query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
   471  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   472  		if err != sql.ErrNoRows {
   473  			return &database.Error{OrigErr: err, Query: []byte(query)}
   474  		}
   475  	} else {
   476  		return nil
   477  	}
   478  
   479  	// if not, create the empty migration table
   480  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   481  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   482  		return &database.Error{OrigErr: err, Query: []byte(query)}
   483  	}
   484  	return nil
   485  }
   486  
   487  // Returns the bool value of the input.
   488  // The 2nd return value indicates if the input was a valid bool value
   489  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   490  func readBool(input string) (value bool, valid bool) {
   491  	switch input {
   492  	case "1", "true", "TRUE", "True":
   493  		return true, true
   494  	case "0", "false", "FALSE", "False":
   495  		return false, true
   496  	}
   497  
   498  	// Not a valid bool value
   499  	return
   500  }