github.hscsec.cn/amacneil/dbmate@v1.4.1/pkg/dbmate/mysql.go (about) 1 package dbmate 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "net/url" 8 "strings" 9 10 _ "github.com/go-sql-driver/mysql" // mysql driver for database/sql 11 ) 12 13 func init() { 14 RegisterDriver(MySQLDriver{}, "mysql") 15 } 16 17 // MySQLDriver provides top level database functions 18 type MySQLDriver struct { 19 } 20 21 func normalizeMySQLURL(u *url.URL) string { 22 normalizedURL := *u 23 normalizedURL.Scheme = "" 24 25 // set default port 26 if normalizedURL.Port() == "" { 27 normalizedURL.Host = fmt.Sprintf("%s:3306", normalizedURL.Host) 28 } 29 30 // host format required by go-sql-driver/mysql 31 normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host) 32 33 query := normalizedURL.Query() 34 query.Set("multiStatements", "true") 35 normalizedURL.RawQuery = query.Encode() 36 37 str := normalizedURL.String() 38 return strings.TrimLeft(str, "/") 39 } 40 41 // Open creates a new database connection 42 func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) { 43 return sql.Open("mysql", normalizeMySQLURL(u)) 44 } 45 46 func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) { 47 // connect to no particular database 48 rootURL := *u 49 rootURL.Path = "/" 50 51 return drv.Open(&rootURL) 52 } 53 54 func mysqlQuoteIdentifier(str string) string { 55 str = strings.Replace(str, "`", "\\`", -1) 56 57 return fmt.Sprintf("`%s`", str) 58 } 59 60 // CreateDatabase creates the specified database 61 func (drv MySQLDriver) CreateDatabase(u *url.URL) error { 62 name := databaseName(u) 63 fmt.Printf("Creating: %s\n", name) 64 65 db, err := drv.openRootDB(u) 66 if err != nil { 67 return err 68 } 69 defer mustClose(db) 70 71 _, err = db.Exec(fmt.Sprintf("create database %s", 72 mysqlQuoteIdentifier(name))) 73 74 return err 75 } 76 77 // DropDatabase drops the specified database (if it exists) 78 func (drv MySQLDriver) DropDatabase(u *url.URL) error { 79 name := databaseName(u) 80 fmt.Printf("Dropping: %s\n", name) 81 82 db, err := drv.openRootDB(u) 83 if err != nil { 84 return err 85 } 86 defer mustClose(db) 87 88 _, err = db.Exec(fmt.Sprintf("drop database if exists %s", 89 mysqlQuoteIdentifier(name))) 90 91 return err 92 } 93 94 func mysqldumpArgs(u *url.URL) []string { 95 // generate CLI arguments 96 args := []string{"--opt", "--routines", "--no-data", 97 "--skip-dump-date", "--skip-add-drop-table"} 98 99 if hostname := u.Hostname(); hostname != "" { 100 args = append(args, "--host="+hostname) 101 } 102 if port := u.Port(); port != "" { 103 args = append(args, "--port="+port) 104 } 105 if username := u.User.Username(); username != "" { 106 args = append(args, "--user="+username) 107 } 108 // mysql recommends against using environment variables to supply password 109 // https://dev.mysql.com/doc/refman/5.7/en/password-security-user.html 110 if password, set := u.User.Password(); set { 111 args = append(args, "--password="+password) 112 } 113 114 // add database name 115 args = append(args, strings.TrimLeft(u.Path, "/")) 116 117 return args 118 } 119 120 func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) { 121 // load applied migrations 122 migrations, err := queryColumn(db, 123 "select quote(version) from schema_migrations order by version asc") 124 if err != nil { 125 return nil, err 126 } 127 128 // build schema_migrations table data 129 var buf bytes.Buffer 130 buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" + 131 "LOCK TABLES `schema_migrations` WRITE;\n") 132 133 if len(migrations) > 0 { 134 buf.WriteString("INSERT INTO `schema_migrations` (version) VALUES\n (" + 135 strings.Join(migrations, "),\n (") + 136 ");\n") 137 } 138 139 buf.WriteString("UNLOCK TABLES;\n") 140 141 return buf.Bytes(), nil 142 } 143 144 // DumpSchema returns the current database schema 145 func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { 146 schema, err := runCommand("mysqldump", mysqldumpArgs(u)...) 147 if err != nil { 148 return nil, err 149 } 150 151 migrations, err := mysqlSchemaMigrationsDump(db) 152 if err != nil { 153 return nil, err 154 } 155 156 schema = append(schema, migrations...) 157 return trimLeadingSQLComments(schema) 158 } 159 160 // DatabaseExists determines whether the database exists 161 func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) { 162 name := databaseName(u) 163 164 db, err := drv.openRootDB(u) 165 if err != nil { 166 return false, err 167 } 168 defer mustClose(db) 169 170 exists := false 171 err = db.QueryRow("select true from information_schema.schemata "+ 172 "where schema_name = ?", name).Scan(&exists) 173 if err == sql.ErrNoRows { 174 return false, nil 175 } 176 177 return exists, err 178 } 179 180 // CreateMigrationsTable creates the schema_migrations table 181 func (drv MySQLDriver) CreateMigrationsTable(db *sql.DB) error { 182 _, err := db.Exec("create table if not exists schema_migrations " + 183 "(version varchar(255) primary key)") 184 185 return err 186 } 187 188 // SelectMigrations returns a list of applied migrations 189 // with an optional limit (in descending order) 190 func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { 191 query := "select version from schema_migrations order by version desc" 192 if limit >= 0 { 193 query = fmt.Sprintf("%s limit %d", query, limit) 194 } 195 rows, err := db.Query(query) 196 if err != nil { 197 return nil, err 198 } 199 200 defer mustClose(rows) 201 202 migrations := map[string]bool{} 203 for rows.Next() { 204 var version string 205 if err := rows.Scan(&version); err != nil { 206 return nil, err 207 } 208 209 migrations[version] = true 210 } 211 212 return migrations, nil 213 } 214 215 // InsertMigration adds a new migration record 216 func (drv MySQLDriver) InsertMigration(db Transaction, version string) error { 217 _, err := db.Exec("insert into schema_migrations (version) values (?)", version) 218 219 return err 220 } 221 222 // DeleteMigration removes a migration record 223 func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error { 224 _, err := db.Exec("delete from schema_migrations where version = ?", version) 225 226 return err 227 } 228 229 // Ping verifies a connection to the database server. It does not verify whether the 230 // specified database exists. 231 func (drv MySQLDriver) Ping(u *url.URL) error { 232 db, err := drv.openRootDB(u) 233 if err != nil { 234 return err 235 } 236 defer mustClose(db) 237 238 return db.Ping() 239 }