github.com/fr-nvriep/migrate/v4@v4.3.2/database/mysql/mysql.go (about)

     1  // +build go1.9
     2  
     3  package mysql
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	nurl "net/url"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  import (
    19  	"github.com/go-sql-driver/mysql"
    20  	"github.com/hashicorp/go-multierror"
    21  )
    22  
    23  import (
    24  	"github.com/fr-nvriep/migrate/v4"
    25  	"github.com/fr-nvriep/migrate/v4/database"
    26  )
    27  
    28  func init() {
    29  	database.Register("mysql", &Mysql{})
    30  }
    31  
// DefaultMigrationsTable is the table name WithInstance falls back to when
// Config.MigrationsTable is empty.
var DefaultMigrationsTable = "schema_migrations"
    33  
// Sentinel errors exposed by this driver.
//
//   - ErrDatabaseDirty: signals a dirty database.
//   - ErrNilConfig: returned by WithInstance when it is given a nil *Config.
//   - ErrNoDatabaseName: returned by WithInstance when `SELECT DATABASE()`
//     yields an empty name, i.e. no database is selected on the connection.
//   - ErrAppendPEM: returned by Open when the contents of the x-tls-ca file
//     cannot be appended to the root certificate pool.
//   - ErrTLSCertKeyConfig: returned by Open when only one of x-tls-cert and
//     x-tls-key is set.
var (
	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
	ErrNilConfig        = fmt.Errorf("no config")
	ErrNoDatabaseName   = fmt.Errorf("no database name")
	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
)
    41  
// Config carries the settings of a Mysql driver instance.
//
// MigrationsTable is the name of the table that stores the schema version;
// WithInstance replaces an empty value with DefaultMigrationsTable.
//
// DatabaseName is overwritten by WithInstance with the result of
// `SELECT DATABASE()`, so any value supplied by the caller is discarded.
type Config struct {
	MigrationsTable string
	DatabaseName    string
}
    46  
// Mysql is the driver type registered under the "mysql" name. It keeps the
// *sql.DB pool together with one dedicated connection taken from it; the
// methods of Mysql run their statements on that connection. isLocked records
// whether Lock has succeeded without a matching Unlock yet.
type Mysql struct {
	// mysql RELEASE_LOCK must be called from the same conn, so
	// just do everything over a single conn anyway.
	conn     *sql.Conn
	db       *sql.DB
	isLocked bool

	config *Config
}
    56  
    57  // instance must have `multiStatements` set to true
    58  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    59  	if config == nil {
    60  		return nil, ErrNilConfig
    61  	}
    62  
    63  	if err := instance.Ping(); err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	query := `SELECT DATABASE()`
    68  	var databaseName sql.NullString
    69  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    70  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    71  	}
    72  
    73  	if len(databaseName.String) == 0 {
    74  		return nil, ErrNoDatabaseName
    75  	}
    76  
    77  	config.DatabaseName = databaseName.String
    78  
    79  	if len(config.MigrationsTable) == 0 {
    80  		config.MigrationsTable = DefaultMigrationsTable
    81  	}
    82  
    83  	conn, err := instance.Conn(context.Background())
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	mx := &Mysql{
    89  		conn:   conn,
    90  		db:     instance,
    91  		config: config,
    92  	}
    93  
    94  	if err := mx.ensureVersionTable(); err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	return mx, nil
    99  }
   100  
   101  // urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
   102  // Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
   103  func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
   104  	origUserInfo := u.User
   105  	u.User = nil
   106  
   107  	c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	if origUserInfo != nil {
   112  		c.User = origUserInfo.Username()
   113  		if p, ok := origUserInfo.Password(); ok {
   114  			c.Passwd = p
   115  		}
   116  	}
   117  	return c, nil
   118  }
   119  
   120  func (m *Mysql) Open(url string) (database.Driver, error) {
   121  	purl, err := nurl.Parse(url)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	q := purl.Query()
   127  	q.Set("multiStatements", "true")
   128  	purl.RawQuery = q.Encode()
   129  
   130  	migrationsTable := purl.Query().Get("x-migrations-table")
   131  
   132  	// use custom TLS?
   133  	ctls := purl.Query().Get("tls")
   134  	if len(ctls) > 0 {
   135  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   136  			rootCertPool := x509.NewCertPool()
   137  			pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
   138  			if err != nil {
   139  				return nil, err
   140  			}
   141  
   142  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   143  				return nil, ErrAppendPEM
   144  			}
   145  
   146  			clientCert := make([]tls.Certificate, 0, 1)
   147  			if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" {
   148  				if ccert == "" || ckey == "" {
   149  					return nil, ErrTLSCertKeyConfig
   150  				}
   151  				certs, err := tls.LoadX509KeyPair(ccert, ckey)
   152  				if err != nil {
   153  					return nil, err
   154  				}
   155  				clientCert = append(clientCert, certs)
   156  			}
   157  
   158  			insecureSkipVerify := false
   159  			if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
   160  				x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
   161  				if err != nil {
   162  					return nil, err
   163  				}
   164  				insecureSkipVerify = x
   165  			}
   166  
   167  			err = mysql.RegisterTLSConfig(ctls, &tls.Config{
   168  				RootCAs:            rootCertPool,
   169  				Certificates:       clientCert,
   170  				InsecureSkipVerify: insecureSkipVerify,
   171  			})
   172  			if err != nil {
   173  				return nil, err
   174  			}
   175  		}
   176  	}
   177  
   178  	c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	db, err := sql.Open("mysql", c.FormatDSN())
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	mx, err := WithInstance(db, &Config{
   188  		DatabaseName:    purl.Path,
   189  		MigrationsTable: migrationsTable,
   190  	})
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	return mx, nil
   196  }
   197  
   198  func (m *Mysql) Close() error {
   199  	connErr := m.conn.Close()
   200  	dbErr := m.db.Close()
   201  	if connErr != nil || dbErr != nil {
   202  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   203  	}
   204  	return nil
   205  }
   206  
   207  func (m *Mysql) Lock() error {
   208  	if m.isLocked {
   209  		return database.ErrLocked
   210  	}
   211  
   212  	aid, err := database.GenerateAdvisoryLockId(
   213  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   214  	if err != nil {
   215  		return err
   216  	}
   217  
   218  	query := "SELECT GET_LOCK(?, 10)"
   219  	var success bool
   220  	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   221  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   222  	}
   223  
   224  	if success {
   225  		m.isLocked = true
   226  		return nil
   227  	}
   228  
   229  	return database.ErrLocked
   230  }
   231  
   232  func (m *Mysql) Unlock() error {
   233  	if !m.isLocked {
   234  		return nil
   235  	}
   236  
   237  	aid, err := database.GenerateAdvisoryLockId(
   238  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	query := `SELECT RELEASE_LOCK(?)`
   244  	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   245  		return &database.Error{OrigErr: err, Query: []byte(query)}
   246  	}
   247  
   248  	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   249  	// in which case isLocked should be true until the timeout expires -- synchronizing
   250  	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   251  
   252  	m.isLocked = false
   253  	return nil
   254  }
   255  
   256  func (m *Mysql) Run(migration io.Reader) error {
   257  	migr, err := ioutil.ReadAll(migration)
   258  	if err != nil {
   259  		return err
   260  	}
   261  
   262  	query := string(migr[:])
   263  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   264  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   265  	}
   266  
   267  	return nil
   268  }
   269  
   270  func (m *Mysql) SetVersion(version int, dirty bool) error {
   271  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   272  	if err != nil {
   273  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   274  	}
   275  
   276  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   277  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   278  		if errRollback := tx.Rollback(); errRollback != nil {
   279  			err = multierror.Append(err, errRollback)
   280  		}
   281  		return &database.Error{OrigErr: err, Query: []byte(query)}
   282  	}
   283  
   284  	if version >= 0 {
   285  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   286  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   287  			if errRollback := tx.Rollback(); errRollback != nil {
   288  				err = multierror.Append(err, errRollback)
   289  			}
   290  			return &database.Error{OrigErr: err, Query: []byte(query)}
   291  		}
   292  	}
   293  
   294  	if err := tx.Commit(); err != nil {
   295  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  func (m *Mysql) Version() (version int, dirty bool, err error) {
   302  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   303  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   304  	switch {
   305  	case err == sql.ErrNoRows:
   306  		return database.NilVersion, false, nil
   307  
   308  	case err != nil:
   309  		if e, ok := err.(*mysql.MySQLError); ok {
   310  			if e.Number == 0 {
   311  				return database.NilVersion, false, nil
   312  			}
   313  		}
   314  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   315  
   316  	default:
   317  		return version, dirty, nil
   318  	}
   319  }
   320  
   321  func (m *Mysql) Drop() (err error) {
   322  	// select all tables
   323  	query := `SHOW TABLES LIKE '%'`
   324  	tables, err := m.conn.QueryContext(context.Background(), query)
   325  	if err != nil {
   326  		return &database.Error{OrigErr: err, Query: []byte(query)}
   327  	}
   328  	defer func() {
   329  		if errClose := tables.Close(); errClose != nil {
   330  			err = multierror.Append(err, errClose)
   331  		}
   332  	}()
   333  
   334  	// delete one table after another
   335  	tableNames := make([]string, 0)
   336  	for tables.Next() {
   337  		var tableName string
   338  		if err := tables.Scan(&tableName); err != nil {
   339  			return err
   340  		}
   341  		if len(tableName) > 0 {
   342  			tableNames = append(tableNames, tableName)
   343  		}
   344  	}
   345  
   346  	if len(tableNames) > 0 {
   347  		// delete one by one ...
   348  		for _, t := range tableNames {
   349  			query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
   350  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   351  				return &database.Error{OrigErr: err, Query: []byte(query)}
   352  			}
   353  		}
   354  	}
   355  
   356  	return nil
   357  }
   358  
   359  // ensureVersionTable checks if versions table exists and, if not, creates it.
   360  // Note that this function locks the database, which deviates from the usual
   361  // convention of "caller locks" in the Mysql type.
   362  func (m *Mysql) ensureVersionTable() (err error) {
   363  	if err = m.Lock(); err != nil {
   364  		return err
   365  	}
   366  
   367  	defer func() {
   368  		if e := m.Unlock(); e != nil {
   369  			if err == nil {
   370  				err = e
   371  			} else {
   372  				err = multierror.Append(err, e)
   373  			}
   374  		}
   375  	}()
   376  
   377  	// check if migration table exists
   378  	var result string
   379  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   380  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   381  		if err != sql.ErrNoRows {
   382  			return &database.Error{OrigErr: err, Query: []byte(query)}
   383  		}
   384  	} else {
   385  		return nil
   386  	}
   387  
   388  	// if not, create the empty migration table
   389  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   390  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   391  		return &database.Error{OrigErr: err, Query: []byte(query)}
   392  	}
   393  	return nil
   394  }
   395  
   396  // Returns the bool value of the input.
   397  // The 2nd return value indicates if the input was a valid bool value
   398  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   399  func readBool(input string) (value bool, valid bool) {
   400  	switch input {
   401  	case "1", "true", "TRUE", "True":
   402  		return true, true
   403  	case "0", "false", "FALSE", "False":
   404  		return false, true
   405  	}
   406  
   407  	// Not a valid bool value
   408  	return
   409  }