github.com/mrqzzz/migrate@v5.1.7+incompatible/database/mysql/mysql.go (about)

     1  // +build go1.9
     2  
     3  package mysql
     4  
     5  import (
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"database/sql"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	nurl "net/url"
    14  	"strconv"
    15  	"strings"
    16  )
    17  
    18  import (
    19  	"github.com/go-sql-driver/mysql"
    20  )
    21  
    22  import (
    23  	"github.com/golang-migrate/migrate/v4"
    24  	"github.com/golang-migrate/migrate/v4/database"
    25  )
    26  
    27  func init() {
    28  	database.Register("mysql", &Mysql{})
    29  }
    30  
    31  var DefaultMigrationsTable = "schema_migrations"
    32  
    33  var (
    34  	ErrDatabaseDirty    = fmt.Errorf("database is dirty")
    35  	ErrNilConfig        = fmt.Errorf("no config")
    36  	ErrNoDatabaseName   = fmt.Errorf("no database name")
    37  	ErrAppendPEM        = fmt.Errorf("failed to append PEM")
    38  	ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
    39  )
    40  
    41  type Config struct {
    42  	MigrationsTable string
    43  	DatabaseName    string
    44  }
    45  
    46  type Mysql struct {
    47  	// mysql RELEASE_LOCK must be called from the same conn, so
    48  	// just do everything over a single conn anyway.
    49  	conn     *sql.Conn
    50  	db       *sql.DB
    51  	isLocked bool
    52  
    53  	config *Config
    54  }
    55  
    56  // instance must have `multiStatements` set to true
    57  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    58  	if config == nil {
    59  		return nil, ErrNilConfig
    60  	}
    61  
    62  	if err := instance.Ping(); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	query := `SELECT DATABASE()`
    67  	var databaseName sql.NullString
    68  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    69  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    70  	}
    71  
    72  	if len(databaseName.String) == 0 {
    73  		return nil, ErrNoDatabaseName
    74  	}
    75  
    76  	config.DatabaseName = databaseName.String
    77  
    78  	if len(config.MigrationsTable) == 0 {
    79  		config.MigrationsTable = DefaultMigrationsTable
    80  	}
    81  
    82  	conn, err := instance.Conn(context.Background())
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	mx := &Mysql{
    88  		conn:   conn,
    89  		db:     instance,
    90  		config: config,
    91  	}
    92  
    93  	if err := mx.ensureVersionTable(); err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return mx, nil
    98  }
    99  
   100  // urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config.
   101  // Manually sets username and password to avoid net/url from url-encoding the reserved URL characters
   102  func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) {
   103  	origUserInfo := u.User
   104  	u.User = nil
   105  
   106  	c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://"))
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	if origUserInfo != nil {
   111  		c.User = origUserInfo.Username()
   112  		if p, ok := origUserInfo.Password(); ok {
   113  			c.Passwd = p
   114  		}
   115  	}
   116  	return c, nil
   117  }
   118  
   119  func (m *Mysql) Open(url string) (database.Driver, error) {
   120  	purl, err := nurl.Parse(url)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  
   125  	q := purl.Query()
   126  	q.Set("multiStatements", "true")
   127  	purl.RawQuery = q.Encode()
   128  
   129  	migrationsTable := purl.Query().Get("x-migrations-table")
   130  	if len(migrationsTable) == 0 {
   131  		migrationsTable = DefaultMigrationsTable
   132  	}
   133  
   134  	// use custom TLS?
   135  	ctls := purl.Query().Get("tls")
   136  	if len(ctls) > 0 {
   137  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   138  			rootCertPool := x509.NewCertPool()
   139  			pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
   140  			if err != nil {
   141  				return nil, err
   142  			}
   143  
   144  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   145  				return nil, ErrAppendPEM
   146  			}
   147  
   148  			clientCert := make([]tls.Certificate, 0, 1)
   149  			if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" {
   150  				if ccert == "" || ckey == "" {
   151  					return nil, ErrTLSCertKeyConfig
   152  				}
   153  				certs, err := tls.LoadX509KeyPair(ccert, ckey)
   154  				if err != nil {
   155  					return nil, err
   156  				}
   157  				clientCert = append(clientCert, certs)
   158  			}
   159  
   160  			insecureSkipVerify := false
   161  			if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
   162  				x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
   163  				if err != nil {
   164  					return nil, err
   165  				}
   166  				insecureSkipVerify = x
   167  			}
   168  
   169  			mysql.RegisterTLSConfig(ctls, &tls.Config{
   170  				RootCAs:            rootCertPool,
   171  				Certificates:       clientCert,
   172  				InsecureSkipVerify: insecureSkipVerify,
   173  			})
   174  		}
   175  	}
   176  
   177  	c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl))
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	db, err := sql.Open("mysql", c.FormatDSN())
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	mx, err := WithInstance(db, &Config{
   187  		DatabaseName:    purl.Path,
   188  		MigrationsTable: migrationsTable,
   189  	})
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	return mx, nil
   195  }
   196  
   197  func (m *Mysql) Close() error {
   198  	connErr := m.conn.Close()
   199  	dbErr := m.db.Close()
   200  	if connErr != nil || dbErr != nil {
   201  		return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
   202  	}
   203  	return nil
   204  }
   205  
   206  func (m *Mysql) Lock() error {
   207  	if m.isLocked {
   208  		return database.ErrLocked
   209  	}
   210  
   211  	aid, err := database.GenerateAdvisoryLockId(
   212  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	query := "SELECT GET_LOCK(?, 10)"
   218  	var success bool
   219  	if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
   220  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   221  	}
   222  
   223  	if success {
   224  		m.isLocked = true
   225  		return nil
   226  	}
   227  
   228  	return database.ErrLocked
   229  }
   230  
   231  func (m *Mysql) Unlock() error {
   232  	if !m.isLocked {
   233  		return nil
   234  	}
   235  
   236  	aid, err := database.GenerateAdvisoryLockId(
   237  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   238  	if err != nil {
   239  		return err
   240  	}
   241  
   242  	query := `SELECT RELEASE_LOCK(?)`
   243  	if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
   244  		return &database.Error{OrigErr: err, Query: []byte(query)}
   245  	}
   246  
   247  	// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
   248  	// in which case isLocked should be true until the timeout expires -- synchronizing
   249  	// these states is likely not worth trying to do; reconsider the necessity of isLocked.
   250  
   251  	m.isLocked = false
   252  	return nil
   253  }
   254  
   255  func (m *Mysql) Run(migration io.Reader) error {
   256  	migr, err := ioutil.ReadAll(migration)
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	query := string(migr[:])
   262  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   263  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   264  	}
   265  
   266  	return nil
   267  }
   268  
   269  func (m *Mysql) SetVersion(version int, dirty bool) error {
   270  	tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{})
   271  	if err != nil {
   272  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   273  	}
   274  
   275  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   276  	if _, err := tx.ExecContext(context.Background(), query); err != nil {
   277  		tx.Rollback()
   278  		return &database.Error{OrigErr: err, Query: []byte(query)}
   279  	}
   280  
   281  	if version >= 0 {
   282  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   283  		if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
   284  			tx.Rollback()
   285  			return &database.Error{OrigErr: err, Query: []byte(query)}
   286  		}
   287  	}
   288  
   289  	if err := tx.Commit(); err != nil {
   290  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   291  	}
   292  
   293  	return nil
   294  }
   295  
   296  func (m *Mysql) Version() (version int, dirty bool, err error) {
   297  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   298  	err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
   299  	switch {
   300  	case err == sql.ErrNoRows:
   301  		return database.NilVersion, false, nil
   302  
   303  	case err != nil:
   304  		if e, ok := err.(*mysql.MySQLError); ok {
   305  			if e.Number == 0 {
   306  				return database.NilVersion, false, nil
   307  			}
   308  		}
   309  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   310  
   311  	default:
   312  		return version, dirty, nil
   313  	}
   314  }
   315  
   316  func (m *Mysql) Drop() error {
   317  	// select all tables
   318  	query := `SHOW TABLES LIKE '%'`
   319  	tables, err := m.conn.QueryContext(context.Background(), query)
   320  	if err != nil {
   321  		return &database.Error{OrigErr: err, Query: []byte(query)}
   322  	}
   323  	defer tables.Close()
   324  
   325  	// delete one table after another
   326  	tableNames := make([]string, 0)
   327  	for tables.Next() {
   328  		var tableName string
   329  		if err := tables.Scan(&tableName); err != nil {
   330  			return err
   331  		}
   332  		if len(tableName) > 0 {
   333  			tableNames = append(tableNames, tableName)
   334  		}
   335  	}
   336  
   337  	if len(tableNames) > 0 {
   338  		// delete one by one ...
   339  		for _, t := range tableNames {
   340  			query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
   341  			if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   342  				return &database.Error{OrigErr: err, Query: []byte(query)}
   343  			}
   344  		}
   345  		if err := m.ensureVersionTable(); err != nil {
   346  			return err
   347  		}
   348  	}
   349  
   350  	return nil
   351  }
   352  
   353  func (m *Mysql) ensureVersionTable() error {
   354  	// check if migration table exists
   355  	var result string
   356  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   357  	if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
   358  		if err != sql.ErrNoRows {
   359  			return &database.Error{OrigErr: err, Query: []byte(query)}
   360  		}
   361  	} else {
   362  		return nil
   363  	}
   364  
   365  	// if not, create the empty migration table
   366  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   367  	if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
   368  		return &database.Error{OrigErr: err, Query: []byte(query)}
   369  	}
   370  	return nil
   371  }
   372  
   373  // Returns the bool value of the input.
   374  // The 2nd return value indicates if the input was a valid bool value
   375  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   376  func readBool(input string) (value bool, valid bool) {
   377  	switch input {
   378  	case "1", "true", "TRUE", "True":
   379  		return true, true
   380  	case "0", "false", "FALSE", "False":
   381  		return false, true
   382  	}
   383  
   384  	// Not a valid bool value
   385  	return
   386  }