github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/driver/mysql/mysql.go (about) 1 package mysql 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "io" 8 "net/url" 9 "regexp" 10 "strings" 11 12 "github.com/amacneil/dbmate/pkg/dbmate" 13 "github.com/amacneil/dbmate/pkg/dbutil" 14 15 _ "github.com/go-sql-driver/mysql" // database/sql driver 16 ) 17 18 func init() { 19 dbmate.RegisterDriver(NewDriver, "mysql") 20 } 21 22 // Driver provides top level database functions 23 type Driver struct { 24 migrationsTableName string 25 databaseURL *url.URL 26 log io.Writer 27 } 28 29 // NewDriver initializes the driver 30 func NewDriver(config dbmate.DriverConfig) dbmate.Driver { 31 return &Driver{ 32 migrationsTableName: config.MigrationsTableName, 33 databaseURL: config.DatabaseURL, 34 log: config.Log, 35 } 36 } 37 38 func connectionString(u *url.URL) string { 39 query := u.Query() 40 query.Set("multiStatements", "true") 41 42 host := u.Host 43 protocol := "tcp" 44 45 if query.Get("socket") != "" { 46 protocol = "unix" 47 host = query.Get("socket") 48 query.Del("socket") 49 } else if u.Port() == "" { 50 // set default port 51 host = fmt.Sprintf("%s:3306", host) 52 } 53 54 // Get decoded user:pass 55 userPassEncoded := u.User.String() 56 userPass, _ := url.PathUnescape(userPassEncoded) 57 58 // Build DSN w/ user:pass percent-decoded 59 normalizedString := "" 60 61 if userPass != "" { // user:pass can be empty 62 normalizedString = userPass + "@" 63 } 64 65 // connection string format required by go-sql-driver/mysql 66 normalizedString = fmt.Sprintf("%s%s(%s)%s?%s", normalizedString, 67 protocol, host, u.Path, query.Encode()) 68 69 return normalizedString 70 } 71 72 // Open creates a new database connection 73 func (drv *Driver) Open() (*sql.DB, error) { 74 return sql.Open("mysql", connectionString(drv.databaseURL)) 75 } 76 77 func (drv *Driver) openRootDB() (*sql.DB, error) { 78 // clone databaseURL 79 rootURL, err := url.Parse(drv.databaseURL.String()) 80 if err != nil { 81 return nil, err 82 } 83 84 // connect to no particular database 85 rootURL.Path = "/" 86 87 return sql.Open("mysql", connectionString(rootURL)) 88 } 89 90 func (drv *Driver) quoteIdentifier(str string) string { 91 str = strings.Replace(str, "`", "\\`", -1) 92 93 return fmt.Sprintf("`%s`", str) 94 } 95 96 // CreateDatabase creates the specified database 97 func (drv *Driver) CreateDatabase() error { 98 name := dbutil.DatabaseName(drv.databaseURL) 99 fmt.Fprintf(drv.log, "Creating: %s\n", name) 100 101 db, err := drv.openRootDB() 102 if err != nil { 103 return err 104 } 105 defer dbutil.MustClose(db) 106 107 _, err = db.Exec(fmt.Sprintf("create database %s", 108 drv.quoteIdentifier(name))) 109 110 return err 111 } 112 113 // DropDatabase drops the specified database (if it exists) 114 func (drv *Driver) DropDatabase() error { 115 name := dbutil.DatabaseName(drv.databaseURL) 116 fmt.Fprintf(drv.log, "Dropping: %s\n", name) 117 118 db, err := drv.openRootDB() 119 if err != nil { 120 return err 121 } 122 defer dbutil.MustClose(db) 123 124 _, err = db.Exec(fmt.Sprintf("drop database if exists %s", 125 drv.quoteIdentifier(name))) 126 127 return err 128 } 129 130 func (drv *Driver) mysqldumpArgs() []string { 131 // generate CLI arguments 132 args := []string{"--opt", "--routines", "--no-data", 133 "--skip-dump-date", "--skip-add-drop-table"} 134 135 socket := drv.databaseURL.Query().Get("socket") 136 if socket != "" { 137 args = append(args, "--socket="+socket) 138 } else { 139 if hostname := drv.databaseURL.Hostname(); hostname != "" { 140 args = append(args, "--host="+hostname) 141 } 142 if port := drv.databaseURL.Port(); port != "" { 143 args = append(args, "--port="+port) 144 } 145 } 146 147 if username := drv.databaseURL.User.Username(); username != "" { 148 args = append(args, "--user="+username) 149 } 150 if password, set := drv.databaseURL.User.Password(); set { 151 args = append(args, "--password="+password) 152 } 153 154 // add database name 155 args = append(args, dbutil.DatabaseName(drv.databaseURL)) 156 157 return args 158 } 159 160 func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) { 161 migrationsTable := drv.quotedMigrationsTableName() 162 163 // load applied migrations 164 migrations, err := dbutil.QueryColumn(db, 165 fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable)) 166 if err != nil { 167 return nil, err 168 } 169 170 // build schema_migrations table data 171 var buf bytes.Buffer 172 buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" + 173 fmt.Sprintf("LOCK TABLES %s WRITE;\n", migrationsTable)) 174 175 if len(migrations) > 0 { 176 buf.WriteString( 177 fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) + 178 strings.Join(migrations, "),\n (") + 179 ");\n") 180 } 181 182 buf.WriteString("UNLOCK TABLES;\n") 183 184 return buf.Bytes(), nil 185 } 186 187 // DumpSchema returns the current database schema 188 func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) { 189 schema, err := dbutil.RunCommand("mysqldump", drv.mysqldumpArgs()...) 190 if err != nil { 191 return nil, err 192 } 193 194 migrations, err := drv.schemaMigrationsDump(db) 195 if err != nil { 196 return nil, err 197 } 198 199 schema = append(schema, migrations...) 200 schema, err = dbutil.TrimLeadingSQLComments(schema) 201 if err != nil { 202 return nil, err 203 } 204 return trimAutoincrementValues(schema), nil 205 } 206 207 // trimAutoincrementValues removes AUTO_INCREMENT values from MySQL schema dumps 208 func trimAutoincrementValues(data []byte) []byte { 209 aiPattern := regexp.MustCompile(" AUTO_INCREMENT=[0-9]*") 210 return aiPattern.ReplaceAll(data, []byte("")) 211 } 212 213 // DatabaseExists determines whether the database exists 214 func (drv *Driver) DatabaseExists() (bool, error) { 215 name := dbutil.DatabaseName(drv.databaseURL) 216 217 db, err := drv.openRootDB() 218 if err != nil { 219 return false, err 220 } 221 defer dbutil.MustClose(db) 222 223 exists := false 224 err = db.QueryRow("select true from information_schema.schemata "+ 225 "where schema_name = ?", name).Scan(&exists) 226 if err == sql.ErrNoRows { 227 return false, nil 228 } 229 230 return exists, err 231 } 232 233 // MigrationsTableExists checks if the schema_migrations table exists 234 func (drv *Driver) MigrationsTableExists(db *sql.DB) (bool, error) { 235 match := "" 236 err := db.QueryRow(fmt.Sprintf("SHOW TABLES LIKE \"%s\"", 237 drv.migrationsTableName)). 238 Scan(&match) 239 if err == sql.ErrNoRows { 240 return false, nil 241 } 242 243 return match != "", err 244 } 245 246 // CreateMigrationsTable creates the schema_migrations table 247 func (drv *Driver) CreateMigrationsTable(db *sql.DB) error { 248 _, err := db.Exec(fmt.Sprintf( 249 "create table if not exists %s (version varchar(128) primary key)", 250 drv.quotedMigrationsTableName())) 251 252 return err 253 } 254 255 // SelectMigrations returns a list of applied migrations 256 // with an optional limit (in descending order) 257 func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { 258 query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName()) 259 if limit >= 0 { 260 query = fmt.Sprintf("%s limit %d", query, limit) 261 } 262 rows, err := db.Query(query) 263 if err != nil { 264 return nil, err 265 } 266 267 defer dbutil.MustClose(rows) 268 269 migrations := map[string]bool{} 270 for rows.Next() { 271 var version string 272 if err := rows.Scan(&version); err != nil { 273 return nil, err 274 } 275 276 migrations[version] = true 277 } 278 279 if err = rows.Err(); err != nil { 280 return nil, err 281 } 282 283 return migrations, nil 284 } 285 286 // InsertMigration adds a new migration record 287 func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error { 288 _, err := db.Exec( 289 fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()), 290 version) 291 292 return err 293 } 294 295 // DeleteMigration removes a migration record 296 func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error { 297 _, err := db.Exec( 298 fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()), 299 version) 300 301 return err 302 } 303 304 // Ping verifies a connection to the database server. It does not verify whether the 305 // specified database exists. 306 func (drv *Driver) Ping() error { 307 db, err := drv.openRootDB() 308 if err != nil { 309 return err 310 } 311 defer dbutil.MustClose(db) 312 313 return db.Ping() 314 } 315 316 func (drv *Driver) quotedMigrationsTableName() string { 317 return drv.quoteIdentifier(drv.migrationsTableName) 318 }