gopkg.in/yuukihogo/migrate.v3@v3.0.0/database/mysql/mysql.go (about) 1 package mysql 2 3 import ( 4 "crypto/tls" 5 "crypto/x509" 6 "database/sql" 7 "fmt" 8 "io" 9 "io/ioutil" 10 nurl "net/url" 11 "strconv" 12 "strings" 13 14 "github.com/go-sql-driver/mysql" 15 "github.com/mattes/migrate" 16 "github.com/mattes/migrate/database" 17 ) 18 19 func init() { 20 database.Register("mysql", &Mysql{}) 21 } 22 23 var DefaultMigrationsTable = "schema_migrations" 24 25 var ( 26 ErrDatabaseDirty = fmt.Errorf("database is dirty") 27 ErrNilConfig = fmt.Errorf("no config") 28 ErrNoDatabaseName = fmt.Errorf("no database name") 29 ErrAppendPEM = fmt.Errorf("failed to append PEM") 30 ) 31 32 type Config struct { 33 MigrationsTable string 34 DatabaseName string 35 } 36 37 type Mysql struct { 38 db *sql.DB 39 isLocked bool 40 41 config *Config 42 } 43 44 // instance must have `multiStatements` set to true 45 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 46 if config == nil { 47 return nil, ErrNilConfig 48 } 49 50 if err := instance.Ping(); err != nil { 51 return nil, err 52 } 53 54 query := `SELECT DATABASE()` 55 var databaseName sql.NullString 56 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 57 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 58 } 59 60 if len(databaseName.String) == 0 { 61 return nil, ErrNoDatabaseName 62 } 63 64 config.DatabaseName = databaseName.String 65 66 if len(config.MigrationsTable) == 0 { 67 config.MigrationsTable = DefaultMigrationsTable 68 } 69 70 mx := &Mysql{ 71 db: instance, 72 config: config, 73 } 74 75 if err := mx.ensureVersionTable(); err != nil { 76 return nil, err 77 } 78 79 return mx, nil 80 } 81 82 func (m *Mysql) Open(url string) (database.Driver, error) { 83 purl, err := nurl.Parse(url) 84 if err != nil { 85 return nil, err 86 } 87 88 purl.Query().Set("multiStatements", "true") 89 90 db, err := sql.Open("mysql", strings.Replace( 91 migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1)) 92 if err != nil { 93 return nil, err 94 } 95 96 migrationsTable := purl.Query().Get("x-migrations-table") 97 if len(migrationsTable) == 0 { 98 migrationsTable = DefaultMigrationsTable 99 } 100 101 // use custom TLS? 102 ctls := purl.Query().Get("tls") 103 if len(ctls) > 0 { 104 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 105 rootCertPool := x509.NewCertPool() 106 pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) 107 if err != nil { 108 return nil, err 109 } 110 111 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 112 return nil, ErrAppendPEM 113 } 114 115 certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key")) 116 if err != nil { 117 return nil, err 118 } 119 120 insecureSkipVerify := false 121 if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { 122 x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) 123 if err != nil { 124 return nil, err 125 } 126 insecureSkipVerify = x 127 } 128 129 mysql.RegisterTLSConfig(ctls, &tls.Config{ 130 RootCAs: rootCertPool, 131 Certificates: []tls.Certificate{certs}, 132 InsecureSkipVerify: insecureSkipVerify, 133 }) 134 } 135 } 136 137 mx, err := WithInstance(db, &Config{ 138 DatabaseName: purl.Path, 139 MigrationsTable: migrationsTable, 140 }) 141 if err != nil { 142 return nil, err 143 } 144 145 return mx, nil 146 } 147 148 func (m *Mysql) Close() error { 149 return m.db.Close() 150 } 151 152 func (m *Mysql) Lock() error { 153 if m.isLocked { 154 return database.ErrLocked 155 } 156 157 aid, err := database.GenerateAdvisoryLockId(m.config.DatabaseName) 158 if err != nil { 159 return err 160 } 161 162 query := "SELECT GET_LOCK(?, 1)" 163 var success bool 164 if err := m.db.QueryRow(query, aid).Scan(&success); err != nil { 165 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 166 } 167 168 if success { 169 m.isLocked = true 170 return nil 171 } 172 173 return database.ErrLocked 174 } 175 176 func (m *Mysql) Unlock() error { 177 if !m.isLocked { 178 return nil 179 } 180 181 aid, err := database.GenerateAdvisoryLockId(m.config.DatabaseName) 182 if err != nil { 183 return err 184 } 185 186 query := `SELECT RELEASE_LOCK(?)` 187 if _, err := m.db.Exec(query, aid); err != nil { 188 return &database.Error{OrigErr: err, Query: []byte(query)} 189 } 190 191 m.isLocked = false 192 return nil 193 } 194 195 func (m *Mysql) Run(migration io.Reader) error { 196 migr, err := ioutil.ReadAll(migration) 197 if err != nil { 198 return err 199 } 200 201 query := string(migr[:]) 202 if _, err := m.db.Exec(query); err != nil { 203 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 204 } 205 206 return nil 207 } 208 209 func (m *Mysql) SetVersion(version int, dirty bool) error { 210 tx, err := m.db.Begin() 211 if err != nil { 212 return &database.Error{OrigErr: err, Err: "transaction start failed"} 213 } 214 215 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 216 if _, err := m.db.Exec(query); err != nil { 217 return &database.Error{OrigErr: err, Query: []byte(query)} 218 } 219 220 if version >= 0 { 221 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 222 if _, err := m.db.Exec(query, version, dirty); err != nil { 223 tx.Rollback() 224 return &database.Error{OrigErr: err, Query: []byte(query)} 225 } 226 } 227 228 if err := tx.Commit(); err != nil { 229 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 230 } 231 232 return nil 233 } 234 235 func (m *Mysql) Version() (version int, dirty bool, err error) { 236 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 237 err = m.db.QueryRow(query).Scan(&version, &dirty) 238 switch { 239 case err == sql.ErrNoRows: 240 return database.NilVersion, false, nil 241 242 case err != nil: 243 if e, ok := err.(*mysql.MySQLError); ok { 244 if e.Number == 0 { 245 return database.NilVersion, false, nil 246 } 247 } 248 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 249 250 default: 251 return version, dirty, nil 252 } 253 } 254 255 func (m *Mysql) Drop() error { 256 // select all tables 257 query := `SHOW TABLES LIKE '%'` 258 tables, err := m.db.Query(query) 259 if err != nil { 260 return &database.Error{OrigErr: err, Query: []byte(query)} 261 } 262 defer tables.Close() 263 264 // delete one table after another 265 tableNames := make([]string, 0) 266 for tables.Next() { 267 var tableName string 268 if err := tables.Scan(&tableName); err != nil { 269 return err 270 } 271 if len(tableName) > 0 { 272 tableNames = append(tableNames, tableName) 273 } 274 } 275 276 if len(tableNames) > 0 { 277 // delete one by one ... 278 for _, t := range tableNames { 279 query = "DROP TABLE IF EXISTS `" + t + "` CASCADE" 280 if _, err := m.db.Exec(query); err != nil { 281 return &database.Error{OrigErr: err, Query: []byte(query)} 282 } 283 } 284 if err := m.ensureVersionTable(); err != nil { 285 return err 286 } 287 } 288 289 return nil 290 } 291 292 func (m *Mysql) ensureVersionTable() error { 293 // check if migration table exists 294 var result string 295 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 296 if err := m.db.QueryRow(query).Scan(&result); err != nil { 297 if err != sql.ErrNoRows { 298 return &database.Error{OrigErr: err, Query: []byte(query)} 299 } 300 } else { 301 return nil 302 } 303 304 // if not, create the empty migration table 305 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 306 if _, err := m.db.Exec(query); err != nil { 307 return &database.Error{OrigErr: err, Query: []byte(query)} 308 } 309 return nil 310 } 311 312 // Returns the bool value of the input. 313 // The 2nd return value indicates if the input was a valid bool value 314 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 315 func readBool(input string) (value bool, valid bool) { 316 switch input { 317 case "1", "true", "TRUE", "True": 318 return true, true 319 case "0", "false", "FALSE", "False": 320 return false, true 321 } 322 323 // Not a valid bool value 324 return 325 }