github.com/dhui/migrate@v3.4.0+incompatible/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 ) 21 22 import ( 23 "github.com/golang-migrate/migrate" 24 "github.com/golang-migrate/migrate/database" 25 ) 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ) 39 40 type Config struct { 41 MigrationsTable string 42 DatabaseName string 43 } 44 45 type Mysql struct { 46 // mysql RELEASE_LOCK must be called from the same conn, so 47 // just do everything over a single conn anyway. 48 conn *sql.Conn 49 isLocked bool 50 51 config *Config 52 } 53 54 // instance must have `multiStatements` set to true 55 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 56 if config == nil { 57 return nil, ErrNilConfig 58 } 59 60 if err := instance.Ping(); err != nil { 61 return nil, err 62 } 63 64 query := `SELECT DATABASE()` 65 var databaseName sql.NullString 66 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 67 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 68 } 69 70 if len(databaseName.String) == 0 { 71 return nil, ErrNoDatabaseName 72 } 73 74 config.DatabaseName = databaseName.String 75 76 if len(config.MigrationsTable) == 0 { 77 config.MigrationsTable = DefaultMigrationsTable 78 } 79 80 conn, err := instance.Conn(context.Background()) 81 if err != nil { 82 return nil, err 83 } 84 85 mx := &Mysql{ 86 conn: conn, 87 config: config, 88 } 89 90 if err := mx.ensureVersionTable(); err != nil { 91 return nil, err 92 } 93 94 return mx, nil 95 } 96 97 // urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config. 98 // Manually sets username and password to avoid net/url from url-encoding the reserved URL characters 99 func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) { 100 origUserInfo := u.User 101 u.User = nil 102 103 c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://")) 104 if err != nil { 105 return nil, err 106 } 107 if origUserInfo != nil { 108 c.User = origUserInfo.Username() 109 if p, ok := origUserInfo.Password(); ok { 110 c.Passwd = p 111 } 112 } 113 return c, nil 114 } 115 116 func (m *Mysql) Open(url string) (database.Driver, error) { 117 purl, err := nurl.Parse(url) 118 if err != nil { 119 return nil, err 120 } 121 122 q := purl.Query() 123 q.Set("multiStatements", "true") 124 purl.RawQuery = q.Encode() 125 126 c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl)) 127 if err != nil { 128 return nil, err 129 } 130 db, err := sql.Open("mysql", c.FormatDSN()) 131 if err != nil { 132 return nil, err 133 } 134 135 migrationsTable := purl.Query().Get("x-migrations-table") 136 if len(migrationsTable) == 0 { 137 migrationsTable = DefaultMigrationsTable 138 } 139 140 // use custom TLS? 141 ctls := purl.Query().Get("tls") 142 if len(ctls) > 0 { 143 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 144 rootCertPool := x509.NewCertPool() 145 pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) 146 if err != nil { 147 return nil, err 148 } 149 150 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 151 return nil, ErrAppendPEM 152 } 153 154 certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key")) 155 if err != nil { 156 return nil, err 157 } 158 159 insecureSkipVerify := false 160 if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { 161 x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) 162 if err != nil { 163 return nil, err 164 } 165 insecureSkipVerify = x 166 } 167 168 mysql.RegisterTLSConfig(ctls, &tls.Config{ 169 RootCAs: rootCertPool, 170 Certificates: []tls.Certificate{certs}, 171 InsecureSkipVerify: insecureSkipVerify, 172 }) 173 } 174 } 175 176 mx, err := WithInstance(db, &Config{ 177 DatabaseName: purl.Path, 178 MigrationsTable: migrationsTable, 179 }) 180 if err != nil { 181 return nil, err 182 } 183 184 return mx, nil 185 } 186 187 func (m *Mysql) Close() error { 188 return m.conn.Close() 189 } 190 191 func (m *Mysql) Lock() error { 192 if m.isLocked { 193 return database.ErrLocked 194 } 195 196 aid, err := database.GenerateAdvisoryLockId( 197 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 198 if err != nil { 199 return err 200 } 201 202 query := "SELECT GET_LOCK(?, 10)" 203 var success bool 204 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 205 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 206 } 207 208 if success { 209 m.isLocked = true 210 return nil 211 } 212 213 return database.ErrLocked 214 } 215 216 func (m *Mysql) Unlock() error { 217 if !m.isLocked { 218 return nil 219 } 220 221 aid, err := database.GenerateAdvisoryLockId( 222 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 223 if err != nil { 224 return err 225 } 226 227 query := `SELECT RELEASE_LOCK(?)` 228 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 229 return &database.Error{OrigErr: err, Query: []byte(query)} 230 } 231 232 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 233 // in which case isLocked should be true until the timeout expires -- synchronizing 234 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 235 236 m.isLocked = false 237 return nil 238 } 239 240 func (m *Mysql) Run(migration io.Reader) error { 241 migr, err := ioutil.ReadAll(migration) 242 if err != nil { 243 return err 244 } 245 246 query := string(migr[:]) 247 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 248 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 249 } 250 251 return nil 252 } 253 254 func (m *Mysql) SetVersion(version int, dirty bool) error { 255 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 256 if err != nil { 257 return &database.Error{OrigErr: err, Err: "transaction start failed"} 258 } 259 260 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 261 if _, err := tx.ExecContext(context.Background(), query); err != nil { 262 tx.Rollback() 263 return &database.Error{OrigErr: err, Query: []byte(query)} 264 } 265 266 if version >= 0 { 267 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 268 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 269 tx.Rollback() 270 return &database.Error{OrigErr: err, Query: []byte(query)} 271 } 272 } 273 274 if err := tx.Commit(); err != nil { 275 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 276 } 277 278 return nil 279 } 280 281 func (m *Mysql) Version() (version int, dirty bool, err error) { 282 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 283 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 284 switch { 285 case err == sql.ErrNoRows: 286 return database.NilVersion, false, nil 287 288 case err != nil: 289 if e, ok := err.(*mysql.MySQLError); ok { 290 if e.Number == 0 { 291 return database.NilVersion, false, nil 292 } 293 } 294 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 295 296 default: 297 return version, dirty, nil 298 } 299 } 300 301 func (m *Mysql) Drop() error { 302 // select all tables 303 query := `SHOW TABLES LIKE '%'` 304 tables, err := m.conn.QueryContext(context.Background(), query) 305 if err != nil { 306 return &database.Error{OrigErr: err, Query: []byte(query)} 307 } 308 defer tables.Close() 309 310 // delete one table after another 311 tableNames := make([]string, 0) 312 for tables.Next() { 313 var tableName string 314 if err := tables.Scan(&tableName); err != nil { 315 return err 316 } 317 if len(tableName) > 0 { 318 tableNames = append(tableNames, tableName) 319 } 320 } 321 322 if len(tableNames) > 0 { 323 // delete one by one ... 324 for _, t := range tableNames { 325 query = "DROP TABLE IF EXISTS `" + t + "` CASCADE" 326 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 327 return &database.Error{OrigErr: err, Query: []byte(query)} 328 } 329 } 330 if err := m.ensureVersionTable(); err != nil { 331 return err 332 } 333 } 334 335 return nil 336 } 337 338 func (m *Mysql) ensureVersionTable() error { 339 // check if migration table exists 340 var result string 341 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 342 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 343 if err != sql.ErrNoRows { 344 return &database.Error{OrigErr: err, Query: []byte(query)} 345 } 346 } else { 347 return nil 348 } 349 350 // if not, create the empty migration table 351 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 352 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 353 return &database.Error{OrigErr: err, Query: []byte(query)} 354 } 355 return nil 356 } 357 358 // Returns the bool value of the input. 359 // The 2nd return value indicates if the input was a valid bool value 360 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 361 func readBool(input string) (value bool, valid bool) { 362 switch input { 363 case "1", "true", "TRUE", "True": 364 return true, true 365 case "0", "false", "FALSE", "False": 366 return false, true 367 } 368 369 // Not a valid bool value 370 return 371 }