github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/mysql/mysql.go (about)

     1  //go:build go1.9
     2  // +build go1.9
     3  
     4  package mysql
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"database/sql"
    11  	"fmt"
    12  	"io"
    13  	nurl "net/url"
    14  	"os"
    15  	"strconv"
    16  	"strings"
    17  	"time"
    18  
    19  	"go.uber.org/atomic"
    20  
    21  	"github.com/go-sql-driver/mysql"
    22  	"github.com/golang-migrate/migrate/v4/database"
    23  	"github.com/hashicorp/go-multierror"
    24  )
    25  
    26  var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
    27  
    28  func init() {
    29  	database.Register("mysql", &Mysql{})
    30  }
    31  
    32  var DefaultMigrationsTable = "schema_migrations"
    33  
    34  var (
    35  	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
    36  	ErrNilConfig        = fmt.Errorf("no config")
    37  	ErrNoDatabaseName   = fmt.Errorf("no database name")
    38  	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
    39  	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    40  )
    41  
    42  type Config struct {
    43  	MigrationsTable  string
    44  	DatabaseName     string
    45  	NoLock           bool
    46  	StatementTimeout time.Duration
    47  }
    48  
    49  type Mysql struct {
    50  	// mysql RELEASE_LOCK must be called from the same conn, so
    51  	// just do everything over a single conn anyway.
    52  	conn     *sql.Conn
    53  	db       *sql.DB
    54  	isLocked atomic.Bool
    55  
    56  	config *Config
    57  }
    58  
    59  // connection instance must have `multiStatements` set to true
    60  func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
    61  	if config == nil {
    62  		return nil, ErrNilConfig
    63  	}
    64  
    65  	if err := conn.PingContext(ctx); err != nil {
    66  		return nil, err
    67  	}
    68  
    69  	mx := &Mysql{
    70  		conn:   conn,
    71  		db:     nil,
    72  		config: config,
    73  	}
    74  
    75  	if config.DatabaseName == "" {
    76  		query := `SELECT DATABASE()`
    77  		var databaseName sql.NullString
    78  		if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
    79  			return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    80  		}
    81  
    82  		if len(databaseName.String) == 0 {
    83  			return nil, ErrNoDatabaseName
    84  		}
    85  
    86  		config.DatabaseName = databaseName.String
    87  	}
    88  
    89  	if len(config.MigrationsTable) == 0 {
    90  		config.MigrationsTable = DefaultMigrationsTable
    91  	}
    92  
    93  	if err := mx.ensureVersionTable(); err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return mx, nil
    98  }
    99  
   100  // instance must have `multiStatements` set to true
   101  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
   102  	ctx := context.Background()
   103  
   104  	if err := instance.Ping(); err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	conn, err := instance.Conn(ctx)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	mx, err := WithConnection(ctx, conn, config)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	mx.db = instance
   119  
   120  	return mx, nil
   121  }
   122  
   123  // extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
   124  // mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
   125  func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
   126  	if c == nil {
   127  		return nil, ErrNilConfig
   128  	}
   129  	customQueryParams := map[string]string{}
   130  
   131  	for k, v := range c.Params {
   132  		if strings.HasPrefix(k, "x-") {
   133  			customQueryParams[k] = v
   134  			delete(c.Params, k)
   135  		}
   136  	}
   137  	return customQueryParams, nil
   138  }
   139  
   140  func urlToMySQLConfig(url string) (*mysql.Config, error) {
   141  	// Need to parse out custom TLS parameters and call
   142  	// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
   143  	// which consumes the registered tls.Config
   144  	// Fixes: https://github.com/golang-migrate/migrate/issues/411
   145  	//
   146  	// Can't use url.Parse() since it fails to parse MySQL DSNs
   147  	// mysql.ParseDSN() also searches for "?" to find query parameters:
   148  	// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
   149  	if idx := strings.LastIndex(url, "?"); idx > 0 {
   150  		rawParams := url[idx+1:]
   151  		parsedParams, err := nurl.ParseQuery(rawParams)
   152  		if err != nil {
   153  			return nil, err
   154  		}
   155  
   156  		ctls := parsedParams.Get("tls")
   157  		if len(ctls) > 0 {
   158  			if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   159  				rootCertPool := x509.NewCertPool()
   160  				pem, err := os.ReadFile(parsedParams.Get("x-tls-ca"))
   161  				if err != nil {
   162  					return nil, err
   163  				}
   164  
   165  				if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   166  					return nil, ErrAppendPEM
   167  				}
   168  
   169  				clientCert := make([]tls.Certificate, 0, 1)
   170  				if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
   171  					if ccert == "" || ckey == "" {
   172  						return nil, ErrTLSCertKeyConfig
   173  					}
   174  					certs, err := tls.LoadX509KeyPair(ccert, ckey)
   175  					if err != nil {
   176  						return nil, err
   177  					}
   178  					clientCert = append(clientCert, certs)
   179  				}
   180  
   181  				insecureSkipVerify := false
   182  				insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
   183  				if len(insecureSkipVerifyStr) > 0 {
   184  					x, err := strconv.ParseBool(insecureSkipVerifyStr)
   185  					if err != nil {
   186  						return nil, err
   187  					}
   188  					insecureSkipVerify = x
   189  				}
   190  
   191  				err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   192  					RootCAs:            rootCertPool,
   193  					Certificates:       clientCert,
   194  					InsecureSkipVerify: insecureSkipVerify,
   195  				})
   196  				if err != nil {
   197  					return nil, err
   198  				}
   199  			}
   200  		}
   201  	}
   202  
   203  	config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	config.MultiStatements = true
   209  
   210  	// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
   211  	// net/url.Parse() would automatically unescape it for us.
   212  	// See: https://play.golang.org/p/q9j1io-YICQ
   213  	user, err := nurl.QueryUnescape(config.User)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	config.User = user
   218  
   219  	password, err := nurl.QueryUnescape(config.Passwd)
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  	config.Passwd = password
   224  
   225  	return config, nil
   226  }
   227  
   228  func (m *Mysql) Open(url string) (database.Driver, error) {
   229  	config, err := urlToMySQLConfig(url)
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  
   234  	customParams, err := extractCustomQueryParams(config)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	noLockParam, noLock := customParams["x-no-lock"], false
   240  	if noLockParam != "" {
   241  		noLock, err = strconv.ParseBool(noLockParam)
   242  		if err != nil {
   243  			return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
   244  		}
   245  	}
   246  
   247  	statementTimeoutParam := customParams["x-statement-timeout"]
   248  	statementTimeout := 0
   249  	if statementTimeoutParam != "" {
   250  		statementTimeout, err = strconv.Atoi(statementTimeoutParam)
   251  		if err != nil {
   252  			return nil, fmt.Errorf("could not parse x-statement-timeout as float: %w", err)
   253  		}
   254  	}
   255  
   256  	db, err := sql.Open("mysql", config.FormatDSN())
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  
   261  	mx, err := WithInstance(db, &Config{
   262  		DatabaseName:     config.DBName,
   263  		MigrationsTable:  customParams["x-migrations-table"],
   264  		NoLock:           noLock,
   265  		StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
   266  	})
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  
   271  	return mx, nil
   272  }
   273  
   274  func (m *Mysql) Close() error {
   275  	connErr := m.conn.Close()
   276  	var dbErr error
   277  	if m.db != nil {
   278  		dbErr = m.db.Close()
   279  	}
   280  
   281  	if connErr != nil || dbErr != nil {
   282  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   283  	}
   284  	return nil
   285  }
   286  
   287  func (m *Mysql) Lock() error {
   288  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   289  		if m.config.NoLock {
   290  			return nil
   291  		}
   292  		aid, err := database.GenerateAdvisoryLockId(
   293  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   294  		if err != nil {
   295  			return err
   296  		}
   297  
   298  		query := "SELECT GET_LOCK(?, 10)"
   299  		var success bool
   300  		if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   301  			return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   302  		}
   303  
   304  		if !success {
   305  			return database.ErrLocked
   306  		}
   307  
   308  		return nil
   309  	})
   310  }
   311  
   312  func (m *Mysql) Unlock() error {
   313  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   314  		if m.config.NoLock {
   315  			return nil
   316  		}
   317  
   318  		aid, err := database.GenerateAdvisoryLockId(
   319  			fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   320  		if err != nil {
   321  			return err
   322  		}
   323  
   324  		query := `SELECT RELEASE_LOCK(?)`
   325  		if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   326  			return &database.Error{OrigErr: err, Query: []byte(query)}
   327  		}
   328  
   329  		// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   330  		// in which case isLocked should be true until the timeout expires -- synchronizing
   331  		// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   332  
   333  		return nil
   334  	})
   335  }
   336  
   337  func (m *Mysql) Run(migration io.Reader) error {
   338  	migr, err := io.ReadAll(migration)
   339  	if err != nil {
   340  		return err
   341  	}
   342  
   343  	ctx := context.Background()
   344  	if m.config.StatementTimeout != 0 {
   345  		var cancel context.CancelFunc
   346  		ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout)
   347  		defer cancel()
   348  	}
   349  
   350  	query := string(migr[:])
   351  	if _, err := m.conn.ExecContext(ctx, query); err != nil {
   352  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   353  	}
   354  
   355  	return nil
   356  }
   357  
   358  func (m *Mysql) SetVersion(version int, dirty bool) error {
   359  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
   360  	if err != nil {
   361  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   362  	}
   363  
   364  	query := "DELETE FROM `" + m.config.MigrationsTable + "`"
   365  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   366  		if errRollback := tx.Rollback(); errRollback != nil {
   367  			err = multierror.Append(err, errRollback)
   368  		}
   369  		return &database.Error{OrigErr: err, Query: []byte(query)}
   370  	}
   371  
   372  	// Also re-write the schema version for nil dirty versions to prevent
   373  	// empty schema version for failed down migration on the first migration
   374  	// See: https://github.com/golang-migrate/migrate/issues/330
   375  	if version >= 0 || (version == database.NilVersion && dirty) {
   376  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   377  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   378  			if errRollback := tx.Rollback(); errRollback != nil {
   379  				err = multierror.Append(err, errRollback)
   380  			}
   381  			return &database.Error{OrigErr: err, Query: []byte(query)}
   382  		}
   383  	}
   384  
   385  	if err := tx.Commit(); err != nil {
   386  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   387  	}
   388  
   389  	return nil
   390  }
   391  
   392  func (m *Mysql) Version() (version int, dirty bool, err error) {
   393  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   394  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   395  	switch {
   396  	case err == sql.ErrNoRows:
   397  		return database.NilVersion, false, nil
   398  
   399  	case err != nil:
   400  		if e, ok := err.(*mysql.MySQLError); ok {
   401  			if e.Number == 0 {
   402  				return database.NilVersion, false, nil
   403  			}
   404  		}
   405  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   406  
   407  	default:
   408  		return version, dirty, nil
   409  	}
   410  }
   411  
   412  func (m *Mysql) Drop() (err error) {
   413  	// select all tables
   414  	query := `SHOW TABLES LIKE '%'`
   415  	tables, err := m.conn.QueryContext(context.Background(), query)
   416  	if err != nil {
   417  		return &database.Error{OrigErr: err, Query: []byte(query)}
   418  	}
   419  	defer func() {
   420  		if errClose := tables.Close(); errClose != nil {
   421  			err = multierror.Append(err, errClose)
   422  		}
   423  	}()
   424  
   425  	// delete one table after another
   426  	tableNames := make([]string, 0)
   427  	for tables.Next() {
   428  		var tableName string
   429  		if err := tables.Scan(&tableName); err != nil {
   430  			return err
   431  		}
   432  		if len(tableName) > 0 {
   433  			tableNames = append(tableNames, tableName)
   434  		}
   435  	}
   436  	if err := tables.Err(); err != nil {
   437  		return &database.Error{OrigErr: err, Query: []byte(query)}
   438  	}
   439  
   440  	if len(tableNames) > 0 {
   441  		// disable checking foreign key constraints until finished
   442  		query = `SET foreign_key_checks = 0`
   443  		if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   444  			return &database.Error{OrigErr: err, Query: []byte(query)}
   445  		}
   446  
   447  		defer func() {
   448  			// enable foreign key checks
   449  			_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
   450  		}()
   451  
   452  		// delete one by one ...
   453  		for _, t := range tableNames {
   454  			query = "DROP TABLE IF EXISTS `" + t + "`"
   455  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   456  				return &database.Error{OrigErr: err, Query: []byte(query)}
   457  			}
   458  		}
   459  	}
   460  
   461  	return nil
   462  }
   463  
   464  // ensureVersionTable checks if versions table exists and, if not, creates it.
   465  // Note that this function locks the database, which deviates from the usual
   466  // convention of "caller locks" in the Mysql type.
   467  func (m *Mysql) ensureVersionTable() (err error) {
   468  	if err = m.Lock(); err != nil {
   469  		return err
   470  	}
   471  
   472  	defer func() {
   473  		if e := m.Unlock(); e != nil {
   474  			if err == nil {
   475  				err = e
   476  			} else {
   477  				err = multierror.Append(err, e)
   478  			}
   479  		}
   480  	}()
   481  
   482  	// check if migration table exists
   483  	var result string
   484  	query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
   485  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   486  		if err != sql.ErrNoRows {
   487  			return &database.Error{OrigErr: err, Query: []byte(query)}
   488  		}
   489  	} else {
   490  		return nil
   491  	}
   492  
   493  	// if not, create the empty migration table
   494  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   495  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   496  		return &database.Error{OrigErr: err, Query: []byte(query)}
   497  	}
   498  	return nil
   499  }
   500  
   501  // Returns the bool value of the input.
   502  // The 2nd return value indicates if the input was a valid bool value
   503  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   504  func readBool(input string) (value bool, valid bool) {
   505  	switch input {
   506  	case "1", "true", "TRUE", "True":
   507  		return true, true
   508  	case "0", "false", "FALSE", "False":
   509  		return false, true
   510  	}
   511  
   512  	// Not a valid bool value
   513  	return
   514  }