gopkg.in/yuukihogo/migrate.v3@v3.0.0/database/mysql/mysql.go (about)

     1  package mysql
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"database/sql"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	nurl "net/url"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/go-sql-driver/mysql"
    15  	"github.com/mattes/migrate"
    16  	"github.com/mattes/migrate/database"
    17  )
    18  
    19  func init() {
    20  	database.Register("mysql", &Mysql{})
    21  }
    22  
    23  var DefaultMigrationsTable = "schema_migrations"
    24  
    25  var (
    26  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    27  	ErrNilConfig      = fmt.Errorf("no config")
    28  	ErrNoDatabaseName = fmt.Errorf("no database name")
    29  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    30  )
    31  
    32  type Config struct {
    33  	MigrationsTable string
    34  	DatabaseName    string
    35  }
    36  
    37  type Mysql struct {
    38  	db       *sql.DB
    39  	isLocked bool
    40  
    41  	config *Config
    42  }
    43  
    44  // instance must have `multiStatements` set to true
    45  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    46  	if config == nil {
    47  		return nil, ErrNilConfig
    48  	}
    49  
    50  	if err := instance.Ping(); err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	query := `SELECT DATABASE()`
    55  	var databaseName sql.NullString
    56  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    57  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    58  	}
    59  
    60  	if len(databaseName.String) == 0 {
    61  		return nil, ErrNoDatabaseName
    62  	}
    63  
    64  	config.DatabaseName = databaseName.String
    65  
    66  	if len(config.MigrationsTable) == 0 {
    67  		config.MigrationsTable = DefaultMigrationsTable
    68  	}
    69  
    70  	mx := &Mysql{
    71  		db:     instance,
    72  		config: config,
    73  	}
    74  
    75  	if err := mx.ensureVersionTable(); err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	return mx, nil
    80  }
    81  
    82  func (m *Mysql) Open(url string) (database.Driver, error) {
    83  	purl, err := nurl.Parse(url)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	purl.Query().Set("multiStatements", "true")
    89  
    90  	db, err := sql.Open("mysql", strings.Replace(
    91  		migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1))
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	migrationsTable := purl.Query().Get("x-migrations-table")
    97  	if len(migrationsTable) == 0 {
    98  		migrationsTable = DefaultMigrationsTable
    99  	}
   100  
   101  	// use custom TLS?
   102  	ctls := purl.Query().Get("tls")
   103  	if len(ctls) > 0 {
   104  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   105  			rootCertPool := x509.NewCertPool()
   106  			pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
   107  			if err != nil {
   108  				return nil, err
   109  			}
   110  
   111  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   112  				return nil, ErrAppendPEM
   113  			}
   114  
   115  			certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"))
   116  			if err != nil {
   117  				return nil, err
   118  			}
   119  
   120  			insecureSkipVerify := false
   121  			if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
   122  				x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
   123  				if err != nil {
   124  					return nil, err
   125  				}
   126  				insecureSkipVerify = x
   127  			}
   128  
   129  			mysql.RegisterTLSConfig(ctls, &tls.Config{
   130  				RootCAs:            rootCertPool,
   131  				Certificates:       []tls.Certificate{certs},
   132  				InsecureSkipVerify: insecureSkipVerify,
   133  			})
   134  		}
   135  	}
   136  
   137  	mx, err := WithInstance(db, &Config{
   138  		DatabaseName:    purl.Path,
   139  		MigrationsTable: migrationsTable,
   140  	})
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	return mx, nil
   146  }
   147  
   148  func (m *Mysql) Close() error {
   149  	return m.db.Close()
   150  }
   151  
   152  func (m *Mysql) Lock() error {
   153  	if m.isLocked {
   154  		return database.ErrLocked
   155  	}
   156  
   157  	aid, err := database.GenerateAdvisoryLockId(m.config.DatabaseName)
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	query := "SELECT GET_LOCK(?, 1)"
   163  	var success bool
   164  	if err := m.db.QueryRow(query, aid).Scan(&success); err != nil {
   165  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   166  	}
   167  
   168  	if success {
   169  		m.isLocked = true
   170  		return nil
   171  	}
   172  
   173  	return database.ErrLocked
   174  }
   175  
   176  func (m *Mysql) Unlock() error {
   177  	if !m.isLocked {
   178  		return nil
   179  	}
   180  
   181  	aid, err := database.GenerateAdvisoryLockId(m.config.DatabaseName)
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	query := `SELECT RELEASE_LOCK(?)`
   187  	if _, err := m.db.Exec(query, aid); err != nil {
   188  		return &database.Error{OrigErr: err, Query: []byte(query)}
   189  	}
   190  
   191  	m.isLocked = false
   192  	return nil
   193  }
   194  
   195  func (m *Mysql) Run(migration io.Reader) error {
   196  	migr, err := ioutil.ReadAll(migration)
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	query := string(migr[:])
   202  	if _, err := m.db.Exec(query); err != nil {
   203  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  func (m *Mysql) SetVersion(version int, dirty bool) error {
   210  	tx, err := m.db.Begin()
   211  	if err != nil {
   212  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   213  	}
   214  
   215  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   216  	if _, err := m.db.Exec(query); err != nil {
   217  		return &database.Error{OrigErr: err, Query: []byte(query)}
   218  	}
   219  
   220  	if version >= 0 {
   221  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   222  		if _, err := m.db.Exec(query, version, dirty); err != nil {
   223  			tx.Rollback()
   224  			return &database.Error{OrigErr: err, Query: []byte(query)}
   225  		}
   226  	}
   227  
   228  	if err := tx.Commit(); err != nil {
   229  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   230  	}
   231  
   232  	return nil
   233  }
   234  
   235  func (m *Mysql) Version() (version int, dirty bool, err error) {
   236  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   237  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   238  	switch {
   239  	case err == sql.ErrNoRows:
   240  		return database.NilVersion, false, nil
   241  
   242  	case err != nil:
   243  		if e, ok := err.(*mysql.MySQLError); ok {
   244  			if e.Number == 0 {
   245  				return database.NilVersion, false, nil
   246  			}
   247  		}
   248  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   249  
   250  	default:
   251  		return version, dirty, nil
   252  	}
   253  }
   254  
   255  func (m *Mysql) Drop() error {
   256  	// select all tables
   257  	query := `SHOW TABLES LIKE '%'`
   258  	tables, err := m.db.Query(query)
   259  	if err != nil {
   260  		return &database.Error{OrigErr: err, Query: []byte(query)}
   261  	}
   262  	defer tables.Close()
   263  
   264  	// delete one table after another
   265  	tableNames := make([]string, 0)
   266  	for tables.Next() {
   267  		var tableName string
   268  		if err := tables.Scan(&tableName); err != nil {
   269  			return err
   270  		}
   271  		if len(tableName) > 0 {
   272  			tableNames = append(tableNames, tableName)
   273  		}
   274  	}
   275  
   276  	if len(tableNames) > 0 {
   277  		// delete one by one ...
   278  		for _, t := range tableNames {
   279  			query = "DROP TABLE IF EXISTS `" + t + "` CASCADE"
   280  			if _, err := m.db.Exec(query); err != nil {
   281  				return &database.Error{OrigErr: err, Query: []byte(query)}
   282  			}
   283  		}
   284  		if err := m.ensureVersionTable(); err != nil {
   285  			return err
   286  		}
   287  	}
   288  
   289  	return nil
   290  }
   291  
   292  func (m *Mysql) ensureVersionTable() error {
   293  	// check if migration table exists
   294  	var result string
   295  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   296  	if err := m.db.QueryRow(query).Scan(&result); err != nil {
   297  		if err != sql.ErrNoRows {
   298  			return &database.Error{OrigErr: err, Query: []byte(query)}
   299  		}
   300  	} else {
   301  		return nil
   302  	}
   303  
   304  	// if not, create the empty migration table
   305  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   306  	if _, err := m.db.Exec(query); err != nil {
   307  		return &database.Error{OrigErr: err, Query: []byte(query)}
   308  	}
   309  	return nil
   310  }
   311  
   312  // Returns the bool value of the input.
   313  // The 2nd return value indicates if the input was a valid bool value
   314  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   315  func readBool(input string) (value bool, valid bool) {
   316  	switch input {
   317  	case "1", "true", "TRUE", "True":
   318  		return true, true
   319  	case "0", "false", "FALSE", "False":
   320  		return false, true
   321  	}
   322  
   323  	// Not a valid bool value
   324  	return
   325  }