github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/mysql/mysql.go (about)

     1  package mysql
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"io"
     8  	"net/url"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"github.com/amacneil/dbmate/pkg/dbmate"
    13  	"github.com/amacneil/dbmate/pkg/dbutil"
    14  
    15  	_ "github.com/go-sql-driver/mysql" // database/sql driver
    16  )
    17  
    18  func init() {
    19  	dbmate.RegisterDriver(NewDriver, "mysql")
    20  }
    21  
// Driver provides top level database functions for MySQL.
// Its fields are copied from dbmate.DriverConfig by NewDriver.
type Driver struct {
	migrationsTableName string    // unquoted name of the applied-migrations table
	databaseURL         *url.URL  // connection URL; also the source of the database name
	log                 io.Writer // destination for progress messages ("Creating: ...")
}
    28  
    29  // NewDriver initializes the driver
    30  func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
    31  	return &Driver{
    32  		migrationsTableName: config.MigrationsTableName,
    33  		databaseURL:         config.DatabaseURL,
    34  		log:                 config.Log,
    35  	}
    36  }
    37  
    38  func connectionString(u *url.URL) string {
    39  	query := u.Query()
    40  	query.Set("multiStatements", "true")
    41  
    42  	host := u.Host
    43  	protocol := "tcp"
    44  
    45  	if query.Get("socket") != "" {
    46  		protocol = "unix"
    47  		host = query.Get("socket")
    48  		query.Del("socket")
    49  	} else if u.Port() == "" {
    50  		// set default port
    51  		host = fmt.Sprintf("%s:3306", host)
    52  	}
    53  
    54  	// Get decoded user:pass
    55  	userPassEncoded := u.User.String()
    56  	userPass, _ := url.PathUnescape(userPassEncoded)
    57  
    58  	// Build DSN w/ user:pass percent-decoded
    59  	normalizedString := ""
    60  
    61  	if userPass != "" { // user:pass can be empty
    62  		normalizedString = userPass + "@"
    63  	}
    64  
    65  	// connection string format required by go-sql-driver/mysql
    66  	normalizedString = fmt.Sprintf("%s%s(%s)%s?%s", normalizedString,
    67  		protocol, host, u.Path, query.Encode())
    68  
    69  	return normalizedString
    70  }
    71  
    72  // Open creates a new database connection
    73  func (drv *Driver) Open() (*sql.DB, error) {
    74  	return sql.Open("mysql", connectionString(drv.databaseURL))
    75  }
    76  
    77  func (drv *Driver) openRootDB() (*sql.DB, error) {
    78  	// clone databaseURL
    79  	rootURL, err := url.Parse(drv.databaseURL.String())
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	// connect to no particular database
    85  	rootURL.Path = "/"
    86  
    87  	return sql.Open("mysql", connectionString(rootURL))
    88  }
    89  
    90  func (drv *Driver) quoteIdentifier(str string) string {
    91  	str = strings.Replace(str, "`", "\\`", -1)
    92  
    93  	return fmt.Sprintf("`%s`", str)
    94  }
    95  
    96  // CreateDatabase creates the specified database
    97  func (drv *Driver) CreateDatabase() error {
    98  	name := dbutil.DatabaseName(drv.databaseURL)
    99  	fmt.Fprintf(drv.log, "Creating: %s\n", name)
   100  
   101  	db, err := drv.openRootDB()
   102  	if err != nil {
   103  		return err
   104  	}
   105  	defer dbutil.MustClose(db)
   106  
   107  	_, err = db.Exec(fmt.Sprintf("create database %s",
   108  		drv.quoteIdentifier(name)))
   109  
   110  	return err
   111  }
   112  
   113  // DropDatabase drops the specified database (if it exists)
   114  func (drv *Driver) DropDatabase() error {
   115  	name := dbutil.DatabaseName(drv.databaseURL)
   116  	fmt.Fprintf(drv.log, "Dropping: %s\n", name)
   117  
   118  	db, err := drv.openRootDB()
   119  	if err != nil {
   120  		return err
   121  	}
   122  	defer dbutil.MustClose(db)
   123  
   124  	_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
   125  		drv.quoteIdentifier(name)))
   126  
   127  	return err
   128  }
   129  
   130  func (drv *Driver) mysqldumpArgs() []string {
   131  	// generate CLI arguments
   132  	args := []string{"--opt", "--routines", "--no-data",
   133  		"--skip-dump-date", "--skip-add-drop-table"}
   134  
   135  	socket := drv.databaseURL.Query().Get("socket")
   136  	if socket != "" {
   137  		args = append(args, "--socket="+socket)
   138  	} else {
   139  		if hostname := drv.databaseURL.Hostname(); hostname != "" {
   140  			args = append(args, "--host="+hostname)
   141  		}
   142  		if port := drv.databaseURL.Port(); port != "" {
   143  			args = append(args, "--port="+port)
   144  		}
   145  	}
   146  
   147  	if username := drv.databaseURL.User.Username(); username != "" {
   148  		args = append(args, "--user="+username)
   149  	}
   150  	if password, set := drv.databaseURL.User.Password(); set {
   151  		args = append(args, "--password="+password)
   152  	}
   153  
   154  	// add database name
   155  	args = append(args, dbutil.DatabaseName(drv.databaseURL))
   156  
   157  	return args
   158  }
   159  
   160  func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
   161  	migrationsTable := drv.quotedMigrationsTableName()
   162  
   163  	// load applied migrations
   164  	migrations, err := dbutil.QueryColumn(db,
   165  		fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	// build schema_migrations table data
   171  	var buf bytes.Buffer
   172  	buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" +
   173  		fmt.Sprintf("LOCK TABLES %s WRITE;\n", migrationsTable))
   174  
   175  	if len(migrations) > 0 {
   176  		buf.WriteString(
   177  			fmt.Sprintf("INSERT INTO %s (version) VALUES\n  (", migrationsTable) +
   178  				strings.Join(migrations, "),\n  (") +
   179  				");\n")
   180  	}
   181  
   182  	buf.WriteString("UNLOCK TABLES;\n")
   183  
   184  	return buf.Bytes(), nil
   185  }
   186  
   187  // DumpSchema returns the current database schema
   188  func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
   189  	schema, err := dbutil.RunCommand("mysqldump", drv.mysqldumpArgs()...)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	migrations, err := drv.schemaMigrationsDump(db)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  
   199  	schema = append(schema, migrations...)
   200  	schema, err = dbutil.TrimLeadingSQLComments(schema)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  	return trimAutoincrementValues(schema), nil
   205  }
   206  
   207  // trimAutoincrementValues removes AUTO_INCREMENT values from MySQL schema dumps
   208  func trimAutoincrementValues(data []byte) []byte {
   209  	aiPattern := regexp.MustCompile(" AUTO_INCREMENT=[0-9]*")
   210  	return aiPattern.ReplaceAll(data, []byte(""))
   211  }
   212  
   213  // DatabaseExists determines whether the database exists
   214  func (drv *Driver) DatabaseExists() (bool, error) {
   215  	name := dbutil.DatabaseName(drv.databaseURL)
   216  
   217  	db, err := drv.openRootDB()
   218  	if err != nil {
   219  		return false, err
   220  	}
   221  	defer dbutil.MustClose(db)
   222  
   223  	exists := false
   224  	err = db.QueryRow("select true from information_schema.schemata "+
   225  		"where schema_name = ?", name).Scan(&exists)
   226  	if err == sql.ErrNoRows {
   227  		return false, nil
   228  	}
   229  
   230  	return exists, err
   231  }
   232  
   233  // MigrationsTableExists checks if the schema_migrations table exists
   234  func (drv *Driver) MigrationsTableExists(db *sql.DB) (bool, error) {
   235  	match := ""
   236  	err := db.QueryRow(fmt.Sprintf("SHOW TABLES LIKE \"%s\"",
   237  		drv.migrationsTableName)).
   238  		Scan(&match)
   239  	if err == sql.ErrNoRows {
   240  		return false, nil
   241  	}
   242  
   243  	return match != "", err
   244  }
   245  
   246  // CreateMigrationsTable creates the schema_migrations table
   247  func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
   248  	_, err := db.Exec(fmt.Sprintf(
   249  		"create table if not exists %s (version varchar(128) primary key)",
   250  		drv.quotedMigrationsTableName()))
   251  
   252  	return err
   253  }
   254  
   255  // SelectMigrations returns a list of applied migrations
   256  // with an optional limit (in descending order)
   257  func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
   258  	query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
   259  	if limit >= 0 {
   260  		query = fmt.Sprintf("%s limit %d", query, limit)
   261  	}
   262  	rows, err := db.Query(query)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	defer dbutil.MustClose(rows)
   268  
   269  	migrations := map[string]bool{}
   270  	for rows.Next() {
   271  		var version string
   272  		if err := rows.Scan(&version); err != nil {
   273  			return nil, err
   274  		}
   275  
   276  		migrations[version] = true
   277  	}
   278  
   279  	if err = rows.Err(); err != nil {
   280  		return nil, err
   281  	}
   282  
   283  	return migrations, nil
   284  }
   285  
   286  // InsertMigration adds a new migration record
   287  func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
   288  	_, err := db.Exec(
   289  		fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
   290  		version)
   291  
   292  	return err
   293  }
   294  
   295  // DeleteMigration removes a migration record
   296  func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
   297  	_, err := db.Exec(
   298  		fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
   299  		version)
   300  
   301  	return err
   302  }
   303  
   304  // Ping verifies a connection to the database server. It does not verify whether the
   305  // specified database exists.
   306  func (drv *Driver) Ping() error {
   307  	db, err := drv.openRootDB()
   308  	if err != nil {
   309  		return err
   310  	}
   311  	defer dbutil.MustClose(db)
   312  
   313  	return db.Ping()
   314  }
   315  
   316  func (drv *Driver) quotedMigrationsTableName() string {
   317  	return drv.quoteIdentifier(drv.migrationsTableName)
   318  }