github.com/mrqzzz/migrate@v5.1.7+incompatible/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 ) 21 22 import ( 23 "github.com/golang-migrate/migrate/v4" 24 "github.com/golang-migrate/migrate/v4/database" 25 ) 26 27 func init() { 28 database.Register("mysql", &Mysql{}) 29 } 30 31 var DefaultMigrationsTable = "schema_migrations" 32 33 var ( 34 ErrDatabaseDirty = fmt.Errorf("database is dirty") 35 ErrNilConfig = fmt.Errorf("no config") 36 ErrNoDatabaseName = fmt.Errorf("no database name") 37 ErrAppendPEM = fmt.Errorf("failed to append PEM") 38 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 39 ) 40 41 type Config struct { 42 MigrationsTable string 43 DatabaseName string 44 } 45 46 type Mysql struct { 47 // mysql RELEASE_LOCK must be called from the same conn, so 48 // just do everything over a single conn anyway. 
49 conn *sql.Conn 50 db *sql.DB 51 isLocked bool 52 53 config *Config 54 } 55 56 // instance must have `multiStatements` set to true 57 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 58 if config == nil { 59 return nil, ErrNilConfig 60 } 61 62 if err := instance.Ping(); err != nil { 63 return nil, err 64 } 65 66 query := `SELECT DATABASE()` 67 var databaseName sql.NullString 68 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 69 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 70 } 71 72 if len(databaseName.String) == 0 { 73 return nil, ErrNoDatabaseName 74 } 75 76 config.DatabaseName = databaseName.String 77 78 if len(config.MigrationsTable) == 0 { 79 config.MigrationsTable = DefaultMigrationsTable 80 } 81 82 conn, err := instance.Conn(context.Background()) 83 if err != nil { 84 return nil, err 85 } 86 87 mx := &Mysql{ 88 conn: conn, 89 db: instance, 90 config: config, 91 } 92 93 if err := mx.ensureVersionTable(); err != nil { 94 return nil, err 95 } 96 97 return mx, nil 98 } 99 100 // urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config. 
101 // Manually sets username and password to avoid net/url from url-encoding the reserved URL characters 102 func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) { 103 origUserInfo := u.User 104 u.User = nil 105 106 c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://")) 107 if err != nil { 108 return nil, err 109 } 110 if origUserInfo != nil { 111 c.User = origUserInfo.Username() 112 if p, ok := origUserInfo.Password(); ok { 113 c.Passwd = p 114 } 115 } 116 return c, nil 117 } 118 119 func (m *Mysql) Open(url string) (database.Driver, error) { 120 purl, err := nurl.Parse(url) 121 if err != nil { 122 return nil, err 123 } 124 125 q := purl.Query() 126 q.Set("multiStatements", "true") 127 purl.RawQuery = q.Encode() 128 129 migrationsTable := purl.Query().Get("x-migrations-table") 130 if len(migrationsTable) == 0 { 131 migrationsTable = DefaultMigrationsTable 132 } 133 134 // use custom TLS? 135 ctls := purl.Query().Get("tls") 136 if len(ctls) > 0 { 137 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 138 rootCertPool := x509.NewCertPool() 139 pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) 140 if err != nil { 141 return nil, err 142 } 143 144 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 145 return nil, ErrAppendPEM 146 } 147 148 clientCert := make([]tls.Certificate, 0, 1) 149 if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" { 150 if ccert == "" || ckey == "" { 151 return nil, ErrTLSCertKeyConfig 152 } 153 certs, err := tls.LoadX509KeyPair(ccert, ckey) 154 if err != nil { 155 return nil, err 156 } 157 clientCert = append(clientCert, certs) 158 } 159 160 insecureSkipVerify := false 161 if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { 162 x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) 163 if err != nil { 164 return nil, err 165 } 166 insecureSkipVerify = x 167 } 168 169 mysql.RegisterTLSConfig(ctls, 
&tls.Config{ 170 RootCAs: rootCertPool, 171 Certificates: clientCert, 172 InsecureSkipVerify: insecureSkipVerify, 173 }) 174 } 175 } 176 177 c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl)) 178 if err != nil { 179 return nil, err 180 } 181 db, err := sql.Open("mysql", c.FormatDSN()) 182 if err != nil { 183 return nil, err 184 } 185 186 mx, err := WithInstance(db, &Config{ 187 DatabaseName: purl.Path, 188 MigrationsTable: migrationsTable, 189 }) 190 if err != nil { 191 return nil, err 192 } 193 194 return mx, nil 195 } 196 197 func (m *Mysql) Close() error { 198 connErr := m.conn.Close() 199 dbErr := m.db.Close() 200 if connErr != nil || dbErr != nil { 201 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 202 } 203 return nil 204 } 205 206 func (m *Mysql) Lock() error { 207 if m.isLocked { 208 return database.ErrLocked 209 } 210 211 aid, err := database.GenerateAdvisoryLockId( 212 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 213 if err != nil { 214 return err 215 } 216 217 query := "SELECT GET_LOCK(?, 10)" 218 var success bool 219 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 220 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 221 } 222 223 if success { 224 m.isLocked = true 225 return nil 226 } 227 228 return database.ErrLocked 229 } 230 231 func (m *Mysql) Unlock() error { 232 if !m.isLocked { 233 return nil 234 } 235 236 aid, err := database.GenerateAdvisoryLockId( 237 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 238 if err != nil { 239 return err 240 } 241 242 query := `SELECT RELEASE_LOCK(?)` 243 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 244 return &database.Error{OrigErr: err, Query: []byte(query)} 245 } 246 247 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 248 // in which case isLocked should be true until the timeout expires -- synchronizing 
249 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 250 251 m.isLocked = false 252 return nil 253 } 254 255 func (m *Mysql) Run(migration io.Reader) error { 256 migr, err := ioutil.ReadAll(migration) 257 if err != nil { 258 return err 259 } 260 261 query := string(migr[:]) 262 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 263 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 264 } 265 266 return nil 267 } 268 269 func (m *Mysql) SetVersion(version int, dirty bool) error { 270 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 271 if err != nil { 272 return &database.Error{OrigErr: err, Err: "transaction start failed"} 273 } 274 275 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 276 if _, err := tx.ExecContext(context.Background(), query); err != nil { 277 tx.Rollback() 278 return &database.Error{OrigErr: err, Query: []byte(query)} 279 } 280 281 if version >= 0 { 282 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 283 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 284 tx.Rollback() 285 return &database.Error{OrigErr: err, Query: []byte(query)} 286 } 287 } 288 289 if err := tx.Commit(); err != nil { 290 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 291 } 292 293 return nil 294 } 295 296 func (m *Mysql) Version() (version int, dirty bool, err error) { 297 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 298 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 299 switch { 300 case err == sql.ErrNoRows: 301 return database.NilVersion, false, nil 302 303 case err != nil: 304 if e, ok := err.(*mysql.MySQLError); ok { 305 if e.Number == 0 { 306 return database.NilVersion, false, nil 307 } 308 } 309 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 310 311 default: 312 return 
version, dirty, nil 313 } 314 } 315 316 func (m *Mysql) Drop() error { 317 // select all tables 318 query := `SHOW TABLES LIKE '%'` 319 tables, err := m.conn.QueryContext(context.Background(), query) 320 if err != nil { 321 return &database.Error{OrigErr: err, Query: []byte(query)} 322 } 323 defer tables.Close() 324 325 // delete one table after another 326 tableNames := make([]string, 0) 327 for tables.Next() { 328 var tableName string 329 if err := tables.Scan(&tableName); err != nil { 330 return err 331 } 332 if len(tableName) > 0 { 333 tableNames = append(tableNames, tableName) 334 } 335 } 336 337 if len(tableNames) > 0 { 338 // delete one by one ... 339 for _, t := range tableNames { 340 query = "DROP TABLE IF EXISTS `" + t + "` CASCADE" 341 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 342 return &database.Error{OrigErr: err, Query: []byte(query)} 343 } 344 } 345 if err := m.ensureVersionTable(); err != nil { 346 return err 347 } 348 } 349 350 return nil 351 } 352 353 func (m *Mysql) ensureVersionTable() error { 354 // check if migration table exists 355 var result string 356 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 357 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 358 if err != sql.ErrNoRows { 359 return &database.Error{OrigErr: err, Query: []byte(query)} 360 } 361 } else { 362 return nil 363 } 364 365 // if not, create the empty migration table 366 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 367 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 368 return &database.Error{OrigErr: err, Query: []byte(query)} 369 } 370 return nil 371 } 372 373 // Returns the bool value of the input. 
// The second result reports whether input is one of the spellings accepted
// below; for any other input both results are false.
// The accepted spellings are meant to mirror the driver helper at
// https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "1", "true", "TRUE", "True":
		value = true
		valid = true
	case "0", "false", "FALSE", "False":
		valid = true
	}
	return value, valid
}