github.com/dhui/migrate@v3.4.0+incompatible/database/mysql/mysql.go (about)

     1  // +build go1.9
     2  
     3  package mysql
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	nurl "net/url"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  import (
    19  	"github.com/go-sql-driver/mysql"
    20  )
    21  
    22  import (
    23  	"github.com/golang-migrate/migrate"
    24  	"github.com/golang-migrate/migrate/database"
    25  )
    26  
    27  func init() {
    28  	database.Register("mysql", &Mysql{})
    29  }
    30  
    31  var DefaultMigrationsTable = "schema_migrations"
    32  
    33  var (
    34  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    35  	ErrNilConfig      = fmt.Errorf("no config")
    36  	ErrNoDatabaseName = fmt.Errorf("no database name")
    37  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    38  )
    39  
    40  type Config struct {
    41  	MigrationsTable string
    42  	DatabaseName    string
    43  }
    44  
    45  type Mysql struct {
    46  	// mysql RELEASE_LOCK must be called from the same conn, so
    47  	// just do everything over a single conn anyway.
    48  	conn     *sql.Conn
    49  	isLocked bool
    50  
    51  	config *Config
    52  }
    53  
    54  // instance must have `multiStatements` set to true
    55  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    56  	if config == nil {
    57  		return nil, ErrNilConfig
    58  	}
    59  
    60  	if err := instance.Ping(); err != nil {
    61  		return nil, err
    62  	}
    63  
    64  	query := `SELECT DATABASE()`
    65  	var databaseName sql.NullString
    66  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    67  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    68  	}
    69  
    70  	if len(databaseName.String) == 0 {
    71  		return nil, ErrNoDatabaseName
    72  	}
    73  
    74  	config.DatabaseName = databaseName.String
    75  
    76  	if len(config.MigrationsTable) == 0 {
    77  		config.MigrationsTable = DefaultMigrationsTable
    78  	}
    79  
    80  	conn, err := instance.Conn(context.Background())
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	mx := &Mysql{
    86  		conn:   conn,
    87  		config: config,
    88  	}
    89  
    90  	if err := mx.ensureVersionTable(); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	return mx, nil
    95  }
    96  
    97  // urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
    98  // Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
    99  func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
   100  	origUserInfo := u.User
   101  	u.User = nil
   102  
   103  	c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	if origUserInfo != nil {
   108  		c.User = origUserInfo.Username()
   109  		if p, ok := origUserInfo.Password(); ok {
   110  			c.Passwd = p
   111  		}
   112  	}
   113  	return c, nil
   114  }
   115  
   116  func (m *Mysql) Open(url string) (database.Driver, error) {
   117  	purl, err := nurl.Parse(url)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	q := purl.Query()
   123  	q.Set("multiStatements", "true")
   124  	purl.RawQuery = q.Encode()
   125  
   126  	c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	db, err := sql.Open("mysql", c.FormatDSN())
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	migrationsTable := purl.Query().Get("x-migrations-table")
   136  	if len(migrationsTable) == 0 {
   137  		migrationsTable = DefaultMigrationsTable
   138  	}
   139  
   140  	// use custom TLS?
   141  	ctls := purl.Query().Get("tls")
   142  	if len(ctls) > 0 {
   143  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   144  			rootCertPool := x509.NewCertPool()
   145  			pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
   146  			if err != nil {
   147  				return nil, err
   148  			}
   149  
   150  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   151  				return nil, ErrAppendPEM
   152  			}
   153  
   154  			certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"))
   155  			if err != nil {
   156  				return nil, err
   157  			}
   158  
   159  			insecureSkipVerify := false
   160  			if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
   161  				x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
   162  				if err != nil {
   163  					return nil, err
   164  				}
   165  				insecureSkipVerify = x
   166  			}
   167  
   168  			mysql.RegisterTLSConfig(ctls, &tls.Config{
   169  				RootCAs:            rootCertPool,
   170  				Certificates:       []tls.Certificate{certs},
   171  				InsecureSkipVerify: insecureSkipVerify,
   172  			})
   173  		}
   174  	}
   175  
   176  	mx, err := WithInstance(db, &Config{
   177  		DatabaseName:    purl.Path,
   178  		MigrationsTable: migrationsTable,
   179  	})
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  
   184  	return mx, nil
   185  }
   186  
   187  func (m *Mysql) Close() error {
   188  	return m.conn.Close()
   189  }
   190  
   191  func (m *Mysql) Lock() error {
   192  	if m.isLocked {
   193  		return database.ErrLocked
   194  	}
   195  
   196  	aid, err := database.GenerateAdvisoryLockId(
   197  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   198  	if err != nil {
   199  		return err
   200  	}
   201  
   202  	query := "SELECT GET_LOCK(?, 10)"
   203  	var success bool
   204  	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   205  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   206  	}
   207  
   208  	if success {
   209  		m.isLocked = true
   210  		return nil
   211  	}
   212  
   213  	return database.ErrLocked
   214  }
   215  
   216  func (m *Mysql) Unlock() error {
   217  	if !m.isLocked {
   218  		return nil
   219  	}
   220  
   221  	aid, err := database.GenerateAdvisoryLockId(
   222  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   223  	if err != nil {
   224  		return err
   225  	}
   226  
   227  	query := `SELECT RELEASE_LOCK(?)`
   228  	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   229  		return &database.Error{OrigErr: err, Query: []byte(query)}
   230  	}
   231  
   232  	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   233  	// in which case isLocked should be true until the timeout expires -- synchronizing
   234  	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   235  
   236  	m.isLocked = false
   237  	return nil
   238  }
   239  
   240  func (m *Mysql) Run(migration io.Reader) error {
   241  	migr, err := ioutil.ReadAll(migration)
   242  	if err != nil {
   243  		return err
   244  	}
   245  
   246  	query := string(migr[:])
   247  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   248  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func (m *Mysql) SetVersion(version int, dirty bool) error {
   255  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   256  	if err != nil {
   257  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   258  	}
   259  
   260  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   261  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   262  		tx.Rollback()
   263  		return &database.Error{OrigErr: err, Query: []byte(query)}
   264  	}
   265  
   266  	if version >= 0 {
   267  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   268  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   269  			tx.Rollback()
   270  			return &database.Error{OrigErr: err, Query: []byte(query)}
   271  		}
   272  	}
   273  
   274  	if err := tx.Commit(); err != nil {
   275  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   276  	}
   277  
   278  	return nil
   279  }
   280  
   281  func (m *Mysql) Version() (version int, dirty bool, err error) {
   282  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   283  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   284  	switch {
   285  	case err == sql.ErrNoRows:
   286  		return database.NilVersion, false, nil
   287  
   288  	case err != nil:
   289  		if e, ok := err.(*mysql.MySQLError); ok {
   290  			if e.Number == 0 {
   291  				return database.NilVersion, false, nil
   292  			}
   293  		}
   294  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   295  
   296  	default:
   297  		return version, dirty, nil
   298  	}
   299  }
   300  
   301  func (m *Mysql) Drop() error {
   302  	// select all tables
   303  	query := `SHOW TABLES LIKE '%'`
   304  	tables, err := m.conn.QueryContext(context.Background(), query)
   305  	if err != nil {
   306  		return &database.Error{OrigErr: err, Query: []byte(query)}
   307  	}
   308  	defer tables.Close()
   309  
   310  	// delete one table after another
   311  	tableNames := make([]string, 0)
   312  	for tables.Next() {
   313  		var tableName string
   314  		if err := tables.Scan(&tableName); err != nil {
   315  			return err
   316  		}
   317  		if len(tableName) > 0 {
   318  			tableNames = append(tableNames, tableName)
   319  		}
   320  	}
   321  
   322  	if len(tableNames) > 0 {
   323  		// delete one by one ...
   324  		for _, t := range tableNames {
   325  			query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
   326  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   327  				return &database.Error{OrigErr: err, Query: []byte(query)}
   328  			}
   329  		}
   330  		if err := m.ensureVersionTable(); err != nil {
   331  			return err
   332  		}
   333  	}
   334  
   335  	return nil
   336  }
   337  
   338  func (m *Mysql) ensureVersionTable() error {
   339  	// check if migration table exists
   340  	var result string
   341  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   342  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   343  		if err != sql.ErrNoRows {
   344  			return &database.Error{OrigErr: err, Query: []byte(query)}
   345  		}
   346  	} else {
   347  		return nil
   348  	}
   349  
   350  	// if not, create the empty migration table
   351  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   352  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   353  		return &database.Error{OrigErr: err, Query: []byte(query)}
   354  	}
   355  	return nil
   356  }
   357  
   358  // Returns the bool value of the input.
   359  // The 2nd return value indicates if the input was a valid bool value
   360  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   361  func readBool(input string) (value bool, valid bool) {
   362  	switch input {
   363  	case "1", "true", "TRUE", "True":
   364  		return true, true
   365  	case "0", "false", "FALSE", "False":
   366  		return false, true
   367  	}
   368  
   369  	// Not a valid bool value
   370  	return
   371  }