github.com/matcornic/migrate@v3.3.2-0.20180717234201-feea45c20506+incompatible/database/mysql/mysql.go (about)

     1  // +build go1.9
     2  
     3  package mysql
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	nurl "net/url"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"github.com/go-sql-driver/mysql"
    18  	"github.com/golang-migrate/migrate"
    19  	"github.com/golang-migrate/migrate/database"
    20  )
    21  
    22  func init() {
    23  	database.Register("mysql", &Mysql{})
    24  }
    25  
    26  var DefaultMigrationsTable = "schema_migrations"
    27  
    28  var (
    29  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    30  	ErrNilConfig      = fmt.Errorf("no config")
    31  	ErrNoDatabaseName = fmt.Errorf("no database name")
    32  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    33  )
    34  
    35  type Config struct {
    36  	MigrationsTable string
    37  	DatabaseName    string
    38  }
    39  
    40  type Mysql struct {
    41  	// mysql RELEASE_LOCK must be called from the same conn, so
    42  	// just do everything over a single conn anyway.
    43  	conn     *sql.Conn
    44  	isLocked bool
    45  
    46  	config *Config
    47  }
    48  
    49  // instance must have `multiStatements` set to true
    50  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    51  	if config == nil {
    52  		return nil, ErrNilConfig
    53  	}
    54  
    55  	if err := instance.Ping(); err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	query := `SELECT DATABASE()`
    60  	var databaseName sql.NullString
    61  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    62  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    63  	}
    64  
    65  	if len(databaseName.String) == 0 {
    66  		return nil, ErrNoDatabaseName
    67  	}
    68  
    69  	config.DatabaseName = databaseName.String
    70  
    71  	if len(config.MigrationsTable) == 0 {
    72  		config.MigrationsTable = DefaultMigrationsTable
    73  	}
    74  
    75  	conn, err := instance.Conn(context.Background())
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	mx := &Mysql{
    81  		conn:   conn,
    82  		config: config,
    83  	}
    84  
    85  	if err := mx.ensureVersionTable(); err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	return mx, nil
    90  }
    91  
    92  func (m *Mysql) Open(url string) (database.Driver, error) {
    93  	url = strings.TrimPrefix(url, "mysql://")
    94  	purl, err := nurl.Parse(url)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	q := purl.Query()
   100  	q.Set("multiStatements", "true")
   101  	purl.RawQuery = q.Encode()
   102  
   103  	db, err := sql.Open("mysql", migrate.FilterCustomQuery(purl).String())
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	migrationsTable := purl.Query().Get("x-migrations-table")
   109  	if len(migrationsTable) == 0 {
   110  		migrationsTable = DefaultMigrationsTable
   111  	}
   112  
   113  	// use custom TLS?
   114  	ctls := purl.Query().Get("tls")
   115  	if len(ctls) > 0 {
   116  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   117  			rootCertPool := x509.NewCertPool()
   118  			pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
   119  			if err != nil {
   120  				return nil, err
   121  			}
   122  
   123  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   124  				return nil, ErrAppendPEM
   125  			}
   126  
   127  			certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"))
   128  			if err != nil {
   129  				return nil, err
   130  			}
   131  
   132  			insecureSkipVerify := false
   133  			if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
   134  				x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
   135  				if err != nil {
   136  					return nil, err
   137  				}
   138  				insecureSkipVerify = x
   139  			}
   140  
   141  			mysql.RegisterTLSConfig(ctls, &tls.Config{
   142  				RootCAs:            rootCertPool,
   143  				Certificates:       []tls.Certificate{certs},
   144  				InsecureSkipVerify: insecureSkipVerify,
   145  			})
   146  		}
   147  	}
   148  
   149  	mx, err := WithInstance(db, &Config{
   150  		DatabaseName:    purl.Path,
   151  		MigrationsTable: migrationsTable,
   152  	})
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  
   157  	return mx, nil
   158  }
   159  
   160  func (m *Mysql) Close() error {
   161  	return m.conn.Close()
   162  }
   163  
   164  func (m *Mysql) Lock() error {
   165  	if m.isLocked {
   166  		return database.ErrLocked
   167  	}
   168  
   169  	aid, err := database.GenerateAdvisoryLockId(
   170  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	query := "SELECT GET_LOCK(?, 10)"
   176  	var success bool
   177  	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   178  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   179  	}
   180  
   181  	if success {
   182  		m.isLocked = true
   183  		return nil
   184  	}
   185  
   186  	return database.ErrLocked
   187  }
   188  
   189  func (m *Mysql) Unlock() error {
   190  	if !m.isLocked {
   191  		return nil
   192  	}
   193  
   194  	aid, err := database.GenerateAdvisoryLockId(
   195  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   196  	if err != nil {
   197  		return err
   198  	}
   199  
   200  	query := `SELECT RELEASE_LOCK(?)`
   201  	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   202  		return &database.Error{OrigErr: err, Query: []byte(query)}
   203  	}
   204  
   205  	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   206  	// in which case isLocked should be true until the timeout expires -- synchronizing
   207  	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   208  
   209  	m.isLocked = false
   210  	return nil
   211  }
   212  
   213  func (m *Mysql) Run(migration io.Reader) error {
   214  	migr, err := ioutil.ReadAll(migration)
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	query := string(migr[:])
   220  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   221  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   222  	}
   223  
   224  	return nil
   225  }
   226  
   227  func (m *Mysql) SetVersion(version int, dirty bool) error {
   228  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   229  	if err != nil {
   230  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   231  	}
   232  
   233  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   234  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   235  		tx.Rollback()
   236  		return &database.Error{OrigErr: err, Query: []byte(query)}
   237  	}
   238  
   239  	if version >= 0 {
   240  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   241  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   242  			tx.Rollback()
   243  			return &database.Error{OrigErr: err, Query: []byte(query)}
   244  		}
   245  	}
   246  
   247  	if err := tx.Commit(); err != nil {
   248  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   249  	}
   250  
   251  	return nil
   252  }
   253  
   254  func (m *Mysql) Version() (version int, dirty bool, err error) {
   255  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   256  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   257  	switch {
   258  	case err == sql.ErrNoRows:
   259  		return database.NilVersion, false, nil
   260  
   261  	case err != nil:
   262  		if e, ok := err.(*mysql.MySQLError); ok {
   263  			if e.Number == 0 {
   264  				return database.NilVersion, false, nil
   265  			}
   266  		}
   267  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   268  
   269  	default:
   270  		return version, dirty, nil
   271  	}
   272  }
   273  
   274  func (m *Mysql) Drop() error {
   275  	// select all tables
   276  	query := `SHOW TABLES LIKE '%'`
   277  	tables, err := m.conn.QueryContext(context.Background(), query)
   278  	if err != nil {
   279  		return &database.Error{OrigErr: err, Query: []byte(query)}
   280  	}
   281  	defer tables.Close()
   282  
   283  	// delete one table after another
   284  	tableNames := make([]string, 0)
   285  	for tables.Next() {
   286  		var tableName string
   287  		if err := tables.Scan(&tableName); err != nil {
   288  			return err
   289  		}
   290  		if len(tableName) > 0 {
   291  			tableNames = append(tableNames, tableName)
   292  		}
   293  	}
   294  
   295  	if len(tableNames) > 0 {
   296  		// delete one by one ...
   297  		for _, t := range tableNames {
   298  			query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
   299  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   300  				return &database.Error{OrigErr: err, Query: []byte(query)}
   301  			}
   302  		}
   303  		if err := m.ensureVersionTable(); err != nil {
   304  			return err
   305  		}
   306  	}
   307  
   308  	return nil
   309  }
   310  
   311  func (m *Mysql) ensureVersionTable() error {
   312  	// check if migration table exists
   313  	var result string
   314  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   315  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   316  		if err != sql.ErrNoRows {
   317  			return &database.Error{OrigErr: err, Query: []byte(query)}
   318  		}
   319  	} else {
   320  		return nil
   321  	}
   322  
   323  	// if not, create the empty migration table
   324  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   325  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   326  		return &database.Error{OrigErr: err, Query: []byte(query)}
   327  	}
   328  	return nil
   329  }
   330  
   331  // Returns the bool value of the input.
   332  // The 2nd return value indicates if the input was a valid bool value
   333  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   334  func readBool(input string) (value bool, valid bool) {
   335  	switch input {
   336  	case "1", "true", "TRUE", "True":
   337  		return true, true
   338  	case "0", "false", "FALSE", "False":
   339  		return false, true
   340  	}
   341  
   342  	// Not a valid bool value
   343  	return
   344  }