github.com/seashell-org/golang-migrate/v4@v4.15.3-0.20220722221203-6ab6c6c062d1/database/mysql/mysql.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package mysql
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"database/sql"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	nurl "net/url"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"go.uber.org/atomic"
    19  
    20  	"github.com/go-sql-driver/mysql"
    21  	"github.com/hashicorp/go-multierror"
    22  	"github.com/seashell-org/golang-migrate/v4/database"
    23  )
    24  
    25  var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
    26  
    27  func init() {
    28  	database.Register("mysql", &Mysql{})
    29  }
    30  
    31  var DefaultMigrationsTable = "schema_migrations"
    32  
    33  var (
    34  	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
    35  	ErrNilConfig        = fmt.Errorf("no config")
    36  	ErrNoDatabaseName   = fmt.Errorf("no database name")
    37  	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
    38  	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    39  )
    40  
    41  type Config struct {
    42  	MigrationsTable string
    43  	DatabaseName    string
    44  	NoLock          bool
    45  }
    46  
    47  type Mysql struct {
    48  	// mysql RELEASE_LOCK must be called from the same conn, so
    49  	// just do everything over a single conn anyway.
    50  	conn     *sql.Conn
    51  	db       *sql.DB
    52  	isLocked atomic.Bool
    53  
    54  	config *Config
    55  }
    56  
    57  // connection instance must have `multiStatements` set to true
    58  func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
    59  	if config == nil {
    60  		return nil, ErrNilConfig
    61  	}
    62  
    63  	if err := conn.PingContext(ctx); err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	mx := &Mysql{
    68  		conn:   conn,
    69  		db:     nil,
    70  		config: config,
    71  	}
    72  
    73  	if config.DatabaseName == "" {
    74  		query := `SELECT DATABASE()`
    75  		var databaseName sql.NullString
    76  		if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
    77  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    78  		}
    79  
    80  		if len(databaseName.String) == 0 {
    81  			return nil, ErrNoDatabaseName
    82  		}
    83  
    84  		config.DatabaseName = databaseName.String
    85  	}
    86  
    87  	if len(config.MigrationsTable) == 0 {
    88  		config.MigrationsTable = DefaultMigrationsTable
    89  	}
    90  
    91  	if err := mx.ensureVersionTable(); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return mx, nil
    96  }
    97  
    98  // instance must have `multiStatements` set to true
    99  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
   100  	ctx := context.Background()
   101  
   102  	if err := instance.Ping(); err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	conn, err := instance.Conn(ctx)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	mx, err := WithConnection(ctx, conn, config)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	mx.db = instance
   117  
   118  	return mx, nil
   119  }
   120  
   121  // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
   122  // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
   123  func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
   124  	if c == nil {
   125  		return nil, ErrNilConfig
   126  	}
   127  	customQueryParams := map[string]string{}
   128  
   129  	for k, v := range c.Params {
   130  		if strings.HasPrefix(k, "x-") {
   131  			customQueryParams[k] = v
   132  			delete(c.Params, k)
   133  		}
   134  	}
   135  	return customQueryParams, nil
   136  }
   137  
   138  func urlToMySQLConfig(url string) (*mysql.Config, error) {
   139  	// Need to parse out custom TLS parameters and call
   140  	// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
   141  	// which consumes the registered tls.Config
   142  	// Fixes: https://github.com/seashell-org/golang-migrate/issues/411
   143  	//
   144  	// Can't use url.Parse() since it fails to parse MySQL DSNs
   145  	// mysql.ParseDSN() also searches for "?" to find query parameters:
   146  	// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
   147  	if idx := strings.LastIndex(url, "?"); idx > 0 {
   148  		rawParams := url[idx+1:]
   149  		parsedParams, err := nurl.ParseQuery(rawParams)
   150  		if err != nil {
   151  			return nil, err
   152  		}
   153  
   154  		ctls := parsedParams.Get("tls")
   155  		if len(ctls) > 0 {
   156  			if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   157  				rootCertPool := x509.NewCertPool()
   158  				pem, err := ioutil.ReadFile(parsedParams.Get("x-tls-ca"))
   159  				if err != nil {
   160  					return nil, err
   161  				}
   162  
   163  				if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   164  					return nil, ErrAppendPEM
   165  				}
   166  
   167  				clientCert := make([]tls.Certificate, 0, 1)
   168  				if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
   169  					if ccert == "" || ckey == "" {
   170  						return nil, ErrTLSCertKeyConfig
   171  					}
   172  					certs, err := tls.LoadX509KeyPair(ccert, ckey)
   173  					if err != nil {
   174  						return nil, err
   175  					}
   176  					clientCert = append(clientCert, certs)
   177  				}
   178  
   179  				insecureSkipVerify := false
   180  				insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
   181  				if len(insecureSkipVerifyStr) > 0 {
   182  					x, err := strconv.ParseBool(insecureSkipVerifyStr)
   183  					if err != nil {
   184  						return nil, err
   185  					}
   186  					insecureSkipVerify = x
   187  				}
   188  
   189  				err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   190  					RootCAs:            rootCertPool,
   191  					Certificates:       clientCert,
   192  					InsecureSkipVerify: insecureSkipVerify,
   193  				})
   194  				if err != nil {
   195  					return nil, err
   196  				}
   197  			}
   198  		}
   199  	}
   200  
   201  	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	config.MultiStatements = true
   207  
   208  	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
   209  	// net/url.Parse() would automatically unescape it for us.
   210  	// See: https://play.golang.org/p/q9j1io-YICQ
   211  	user, err := nurl.QueryUnescape(config.User)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	config.User = user
   216  
   217  	password, err := nurl.QueryUnescape(config.Passwd)
   218  	if err != nil {
   219  		return nil, err
   220  	}
   221  	config.Passwd = password
   222  
   223  	return config, nil
   224  }
   225  
   226  func (m *Mysql) Open(url string) (database.Driver, error) {
   227  	config, err := urlToMySQLConfig(url)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	customParams, err := extractCustomQueryParams(config)
   233  	if err != nil {
   234  		return nil, err
   235  	}
   236  
   237  	noLockParam, noLock := customParams["x-no-lock"], false
   238  	if noLockParam != "" {
   239  		noLock, err = strconv.ParseBool(noLockParam)
   240  		if err != nil {
   241  			return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
   242  		}
   243  	}
   244  
   245  	db, err := sql.Open("mysql", config.FormatDSN())
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	mx, err := WithInstance(db, &Config{
   251  		DatabaseName:    config.DBName,
   252  		MigrationsTable: customParams["x-migrations-table"],
   253  		NoLock:          noLock,
   254  	})
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  
   259  	return mx, nil
   260  }
   261  
   262  func (m *Mysql) Close() error {
   263  	connErr := m.conn.Close()
   264  	var dbErr error
   265  	if m.db != nil {
   266  		dbErr = m.db.Close()
   267  	}
   268  
   269  	if connErr != nil || dbErr != nil {
   270  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   271  	}
   272  	return nil
   273  }
   274  
   275  func (m *Mysql) Lock() error {
   276  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   277  		if m.config.NoLock {
   278  			return nil
   279  		}
   280  		aid, err := database.GenerateAdvisoryLockId(
   281  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   282  		if err != nil {
   283  			return err
   284  		}
   285  
   286  		query := "SELECT GET_LOCK(?, 10)"
   287  		var success bool
   288  		if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   289  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   290  		}
   291  
   292  		if !success {
   293  			return database.ErrLocked
   294  		}
   295  
   296  		return nil
   297  	})
   298  }
   299  
   300  func (m *Mysql) Unlock() error {
   301  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   302  		if m.config.NoLock {
   303  			return nil
   304  		}
   305  
   306  		aid, err := database.GenerateAdvisoryLockId(
   307  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   308  		if err != nil {
   309  			return err
   310  		}
   311  
   312  		query := `SELECT RELEASE_LOCK(?)`
   313  		if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   314  			return &database.Error{OrigErr: err, Query: []byte(query)}
   315  		}
   316  
   317  		// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   318  		// in which case isLocked should be true until the timeout expires -- synchronizing
   319  		// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   320  
   321  		return nil
   322  	})
   323  }
   324  
   325  func (m *Mysql) Run(migration io.Reader) error {
   326  	migr, err := ioutil.ReadAll(migration)
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	query := string(migr[:])
   332  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   333  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   334  	}
   335  
   336  	return nil
   337  }
   338  
   339  func (m *Mysql) SetVersion(version int, dirty bool) error {
   340  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
   341  	if err != nil {
   342  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   343  	}
   344  
   345  	query := "DELETE FROM `" + m.config.MigrationsTable + "`"
   346  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   347  		if errRollback := tx.Rollback(); errRollback != nil {
   348  			err = multierror.Append(err, errRollback)
   349  		}
   350  		return &database.Error{OrigErr: err, Query: []byte(query)}
   351  	}
   352  
   353  	// Also re-write the schema version for nil dirty versions to prevent
   354  	// empty schema version for failed down migration on the first migration
   355  	// See: https://github.com/seashell-org/golang-migrate/issues/330
   356  	if version >= 0 || (version == database.NilVersion && dirty) {
   357  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   358  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   359  			if errRollback := tx.Rollback(); errRollback != nil {
   360  				err = multierror.Append(err, errRollback)
   361  			}
   362  			return &database.Error{OrigErr: err, Query: []byte(query)}
   363  		}
   364  	}
   365  
   366  	if err := tx.Commit(); err != nil {
   367  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   368  	}
   369  
   370  	return nil
   371  }
   372  
   373  func (m *Mysql) Version() (version int, dirty bool, err error) {
   374  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   375  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   376  	switch {
   377  	case err == sql.ErrNoRows:
   378  		return database.NilVersion, false, nil
   379  
   380  	case err != nil:
   381  		if e, ok := err.(*mysql.MySQLError); ok {
   382  			if e.Number == 0 {
   383  				return database.NilVersion, false, nil
   384  			}
   385  		}
   386  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   387  
   388  	default:
   389  		return version, dirty, nil
   390  	}
   391  }
   392  
   393  func (m *Mysql) Drop() (err error) {
   394  	// select all tables
   395  	query := `SHOW TABLES LIKE '%'`
   396  	tables, err := m.conn.QueryContext(context.Background(), query)
   397  	if err != nil {
   398  		return &database.Error{OrigErr: err, Query: []byte(query)}
   399  	}
   400  	defer func() {
   401  		if errClose := tables.Close(); errClose != nil {
   402  			err = multierror.Append(err, errClose)
   403  		}
   404  	}()
   405  
   406  	// delete one table after another
   407  	tableNames := make([]string, 0)
   408  	for tables.Next() {
   409  		var tableName string
   410  		if err := tables.Scan(&tableName); err != nil {
   411  			return err
   412  		}
   413  		if len(tableName) > 0 {
   414  			tableNames = append(tableNames, tableName)
   415  		}
   416  	}
   417  	if err := tables.Err(); err != nil {
   418  		return &database.Error{OrigErr: err, Query: []byte(query)}
   419  	}
   420  
   421  	if len(tableNames) > 0 {
   422  		// disable checking foreign key constraints until finished
   423  		query = `SET foreign_key_checks = 0`
   424  		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   425  			return &database.Error{OrigErr: err, Query: []byte(query)}
   426  		}
   427  
   428  		defer func() {
   429  			// enable foreign key checks
   430  			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
   431  		}()
   432  
   433  		// delete one by one ...
   434  		for _, t := range tableNames {
   435  			query = "DROP TABLE IF EXISTS `" + t + "`"
   436  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   437  				return &database.Error{OrigErr: err, Query: []byte(query)}
   438  			}
   439  		}
   440  	}
   441  
   442  	return nil
   443  }
   444  
   445  // ensureVersionTable checks if versions table exists and, if not, creates it.
   446  // Note that this function locks the database, which deviates from the usual
   447  // convention of "caller locks" in the Mysql type.
   448  func (m *Mysql) ensureVersionTable() (err error) {
   449  	if err = m.Lock(); err != nil {
   450  		return err
   451  	}
   452  
   453  	defer func() {
   454  		if e := m.Unlock(); e != nil {
   455  			if err == nil {
   456  				err = e
   457  			} else {
   458  				err = multierror.Append(err, e)
   459  			}
   460  		}
   461  	}()
   462  
   463  	// check if migration table exists
   464  	var result string
   465  	query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
   466  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   467  		if err != sql.ErrNoRows {
   468  			return &database.Error{OrigErr: err, Query: []byte(query)}
   469  		}
   470  	} else {
   471  		return nil
   472  	}
   473  
   474  	// if not, create the empty migration table
   475  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   476  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   477  		return &database.Error{OrigErr: err, Query: []byte(query)}
   478  	}
   479  	return nil
   480  }
   481  
   482  // Returns the bool value of the input.
   483  // The 2nd return value indicates if the input was a valid bool value
   484  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   485  func readBool(input string) (value bool, valid bool) {
   486  	switch input {
   487  	case "1", "true", "TRUE", "True":
   488  		return true, true
   489  	case "0", "false", "FALSE", "False":
   490  		return false, true
   491  	}
   492  
   493  	// Not a valid bool value
   494  	return
   495  }