github.com/eatigo/migrate@v3.0.2-0.20210729130915-7610befb1b6b+incompatible/database/mysql/mysql.go (about)

     1  package mysql
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"database/sql"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	nurl "net/url"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/go-sql-driver/mysql"
    15  	"github.com/eatigo/migrate"
    16  	"github.com/eatigo/migrate/database"
    17  )
    18  
    19  func init() {
    20  	database.Register("mysql", &Mysql{})
    21  }
    22  
    23  var DefaultMigrationsTable = "schema_migrations"
    24  
    25  var (
    26  	ErrDatabaseDirty  = fmt.Errorf("database is dirty")
    27  	ErrNilConfig      = fmt.Errorf("no config")
    28  	ErrNoDatabaseName = fmt.Errorf("no database name")
    29  	ErrAppendPEM      = fmt.Errorf("failed to append PEM")
    30  )
    31  
    32  type Config struct {
    33  	MigrationsTable string
    34  	DatabaseName    string
    35  }
    36  
    37  type Mysql struct {
    38  	db       *sql.DB
    39  	isLocked bool
    40  
    41  	config *Config
    42  }
    43  
    44  // instance must have `multiStatements` set to true
    45  func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
    46  	if config == nil {
    47  		return nil, ErrNilConfig
    48  	}
    49  
    50  	if err := instance.Ping(); err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	query := `SELECT DATABASE()`
    55  	var databaseName sql.NullString
    56  	if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
    57  		return nil, &database.Error{OrigErr: err, Query: []byte(query)}
    58  	}
    59  
    60  	if len(databaseName.String) == 0 {
    61  		return nil, ErrNoDatabaseName
    62  	}
    63  
    64  	config.DatabaseName = databaseName.String
    65  
    66  	if len(config.MigrationsTable) == 0 {
    67  		config.MigrationsTable = DefaultMigrationsTable
    68  	}
    69  
    70  	mx := &Mysql{
    71  		db:     instance,
    72  		config: config,
    73  	}
    74  
    75  	if err := mx.ensureVersionTable(); err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	return mx, nil
    80  }
    81  
    82  func (m *Mysql) Open(url string) (database.Driver, error) {
    83  	purl, err := nurl.Parse(url)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	q := purl.Query()
    89  	q.Set("multiStatements", "true")
    90  	purl.RawQuery = q.Encode()
    91  
    92  	db, err := sql.Open("mysql", strings.Replace(
    93  		migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1))
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	migrationsTable := purl.Query().Get("x-migrations-table")
    99  	if len(migrationsTable) == 0 {
   100  		migrationsTable = DefaultMigrationsTable
   101  	}
   102  
   103  	// use custom TLS?
   104  	ctls := purl.Query().Get("tls")
   105  	if len(ctls) > 0 {
   106  		if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
   107  			rootCertPool := x509.NewCertPool()
   108  			pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
   109  			if err != nil {
   110  				return nil, err
   111  			}
   112  
   113  			if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
   114  				return nil, ErrAppendPEM
   115  			}
   116  
   117  			certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"))
   118  			if err != nil {
   119  				return nil, err
   120  			}
   121  
   122  			insecureSkipVerify := false
   123  			if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
   124  				x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
   125  				if err != nil {
   126  					return nil, err
   127  				}
   128  				insecureSkipVerify = x
   129  			}
   130  
   131  			mysql.RegisterTLSConfig(ctls, &tls.Config{
   132  				RootCAs:            rootCertPool,
   133  				Certificates:       []tls.Certificate{certs},
   134  				InsecureSkipVerify: insecureSkipVerify,
   135  			})
   136  		}
   137  	}
   138  
   139  	mx, err := WithInstance(db, &Config{
   140  		DatabaseName:    purl.Path,
   141  		MigrationsTable: migrationsTable,
   142  	})
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	return mx, nil
   148  }
   149  
   150  func (m *Mysql) Close() error {
   151  	return m.db.Close()
   152  }
   153  
   154  func (m *Mysql) Lock() error {
   155  	if m.isLocked {
   156  		return database.ErrLocked
   157  	}
   158  
   159  	aid, err := database.GenerateAdvisoryLockId(
   160  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	query := "SELECT GET_LOCK(?, 1)"
   166  	var success bool
   167  	if err := m.db.QueryRow(query, aid).Scan(&success); err != nil {
   168  		return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
   169  	}
   170  
   171  	if success {
   172  		m.isLocked = true
   173  		return nil
   174  	}
   175  
   176  	return database.ErrLocked
   177  }
   178  
   179  func (m *Mysql) Unlock() error {
   180  	if !m.isLocked {
   181  		return nil
   182  	}
   183  
   184  	aid, err := database.GenerateAdvisoryLockId(
   185  		fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	query := `SELECT RELEASE_LOCK(?)`
   191  	if _, err := m.db.Exec(query, aid); err != nil {
   192  		return &database.Error{OrigErr: err, Query: []byte(query)}
   193  	}
   194  
   195  	m.isLocked = false
   196  	return nil
   197  }
   198  
   199  func (m *Mysql) Run(migration io.Reader) error {
   200  	migr, err := ioutil.ReadAll(migration)
   201  	if err != nil {
   202  		return err
   203  	}
   204  
   205  	query := string(migr[:])
   206  	if _, err := m.db.Exec(query); err != nil {
   207  		return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
   208  	}
   209  
   210  	return nil
   211  }
   212  
   213  func (m *Mysql) SetVersion(version int, dirty bool) error {
   214  	tx, err := m.db.Begin()
   215  	if err != nil {
   216  		return &database.Error{OrigErr: err, Err: "transaction start failed"}
   217  	}
   218  
   219  	query := "TRUNCATE `" + m.config.MigrationsTable + "`"
   220  	if _, err := m.db.Exec(query); err != nil {
   221  		return &database.Error{OrigErr: err, Query: []byte(query)}
   222  	}
   223  
   224  	if version >= 0 {
   225  		query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
   226  		if _, err := m.db.Exec(query, version, dirty); err != nil {
   227  			tx.Rollback()
   228  			return &database.Error{OrigErr: err, Query: []byte(query)}
   229  		}
   230  	}
   231  
   232  	if err := tx.Commit(); err != nil {
   233  		return &database.Error{OrigErr: err, Err: "transaction commit failed"}
   234  	}
   235  
   236  	return nil
   237  }
   238  
   239  func (m *Mysql) Version() (version int, dirty bool, err error) {
   240  	query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
   241  	err = m.db.QueryRow(query).Scan(&version, &dirty)
   242  	switch {
   243  	case err == sql.ErrNoRows:
   244  		return database.NilVersion, false, nil
   245  
   246  	case err != nil:
   247  		if e, ok := err.(*mysql.MySQLError); ok {
   248  			if e.Number == 0 {
   249  				return database.NilVersion, false, nil
   250  			}
   251  		}
   252  		return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
   253  
   254  	default:
   255  		return version, dirty, nil
   256  	}
   257  }
   258  
   259  func (m *Mysql) Drop() error {
   260  	// select all tables
   261  	query := `SHOW TABLES LIKE '%'`
   262  	tables, err := m.db.Query(query)
   263  	if err != nil {
   264  		return &database.Error{OrigErr: err, Query: []byte(query)}
   265  	}
   266  	defer tables.Close()
   267  
   268  	// delete one table after another
   269  	tableNames := make([]string, 0)
   270  	for tables.Next() {
   271  		var tableName string
   272  		if err := tables.Scan(&tableName); err != nil {
   273  			return err
   274  		}
   275  		if len(tableName) > 0 {
   276  			tableNames = append(tableNames, tableName)
   277  		}
   278  	}
   279  
   280  	if len(tableNames) > 0 {
   281  		// delete one by one ...
   282  		for _, t := range tableNames {
   283  			query = "SET FOREIGN_KEY_CHECKS=0; DROP TABLE IF EXISTS `" + t + "` CASCADE; SET FOREIGN_KEY_CHECKS=1;"
   284  			if _, err := m.db.Exec(query); err != nil {
   285  				return &database.Error{OrigErr: err, Query: []byte(query)}
   286  			}
   287  		}
   288  		if err := m.ensureVersionTable(); err != nil {
   289  			return err
   290  		}
   291  	}
   292  
   293  	return nil
   294  }
   295  
   296  func (m *Mysql) ensureVersionTable() error {
   297  	// check if migration table exists
   298  	var result string
   299  	query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"`
   300  	if err := m.db.QueryRow(query).Scan(&result); err != nil {
   301  		if err != sql.ErrNoRows {
   302  			return &database.Error{OrigErr: err, Query: []byte(query)}
   303  		}
   304  	} else {
   305  		return nil
   306  	}
   307  
   308  	// if not, create the empty migration table
   309  	query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
   310  	if _, err := m.db.Exec(query); err != nil {
   311  		return &database.Error{OrigErr: err, Query: []byte(query)}
   312  	}
   313  	return nil
   314  }
   315  
   316  // Returns the bool value of the input.
   317  // The 2nd return value indicates if the input was a valid bool value
   318  // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
   319  func readBool(input string) (value bool, valid bool) {
   320  	switch input {
   321  	case "1", "true", "TRUE", "True":
   322  		return true, true
   323  	case "0", "false", "FALSE", "False":
   324  		return false, true
   325  	}
   326  
   327  	// Not a valid bool value
   328  	return
   329  }