github.com/fr-nvriep/migrate/v4@v4.3.2/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 ) 17 18 import ( 19 "github.com/go-sql-driver/mysql" 20 "github.com/hashicorp/go-multierror" 21 ) 22 23 import ( 24 "github.com/fr-nvriep/migrate/v4" 25 "github.com/fr-nvriep/migrate/v4/database" 26 ) 27 28 func init() { 29 database.Register("mysql", &Mysql{}) 30 } 31 32 var DefaultMigrationsTable = "schema_migrations" 33 34 var ( 35 ErrDatabaseDirty = fmt.Errorf("database is dirty") 36 ErrNilConfig = fmt.Errorf("no config") 37 ErrNoDatabaseName = fmt.Errorf("no database name") 38 ErrAppendPEM = fmt.Errorf("failed to append PEM") 39 ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty") 40 ) 41 42 type Config struct { 43 MigrationsTable string 44 DatabaseName string 45 } 46 47 type Mysql struct { 48 // mysql RELEASE_LOCK must be called from the same conn, so 49 // just do everything over a single conn anyway. 50 conn *sql.Conn 51 db *sql.DB 52 isLocked bool 53 54 config *Config 55 } 56 57 // instance must have `multiStatements` set to true 58 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 59 if config == nil { 60 return nil, ErrNilConfig 61 } 62 63 if err := instance.Ping(); err != nil { 64 return nil, err 65 } 66 67 query := `SELECT DATABASE()` 68 var databaseName sql.NullString 69 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 70 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 71 } 72 73 if len(databaseName.String) == 0 { 74 return nil, ErrNoDatabaseName 75 } 76 77 config.DatabaseName = databaseName.String 78 79 if len(config.MigrationsTable) == 0 { 80 config.MigrationsTable = DefaultMigrationsTable 81 } 82 83 conn, err := instance.Conn(context.Background()) 84 if err != nil { 85 return nil, err 86 } 87 88 mx := &Mysql{ 89 conn: conn, 90 db: instance, 91 config: config, 92 } 93 94 if err := mx.ensureVersionTable(); err != nil { 95 return nil, err 96 } 97 98 return mx, nil 99 } 100 101 // urlToMySQLConfig takes a net/url URL and returns a go-sql-driver/mysql Config. 102 // Manually sets username and password to avoid net/url from url-encoding the reserved URL characters 103 func urlToMySQLConfig(u nurl.URL) (*mysql.Config, error) { 104 origUserInfo := u.User 105 u.User = nil 106 107 c, err := mysql.ParseDSN(strings.TrimPrefix(u.String(), "mysql://")) 108 if err != nil { 109 return nil, err 110 } 111 if origUserInfo != nil { 112 c.User = origUserInfo.Username() 113 if p, ok := origUserInfo.Password(); ok { 114 c.Passwd = p 115 } 116 } 117 return c, nil 118 } 119 120 func (m *Mysql) Open(url string) (database.Driver, error) { 121 purl, err := nurl.Parse(url) 122 if err != nil { 123 return nil, err 124 } 125 126 q := purl.Query() 127 q.Set("multiStatements", "true") 128 purl.RawQuery = q.Encode() 129 130 migrationsTable := purl.Query().Get("x-migrations-table") 131 132 // use custom TLS? 133 ctls := purl.Query().Get("tls") 134 if len(ctls) > 0 { 135 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 136 rootCertPool := x509.NewCertPool() 137 pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) 138 if err != nil { 139 return nil, err 140 } 141 142 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 143 return nil, ErrAppendPEM 144 } 145 146 clientCert := make([]tls.Certificate, 0, 1) 147 if ccert, ckey := purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"); ccert != "" || ckey != "" { 148 if ccert == "" || ckey == "" { 149 return nil, ErrTLSCertKeyConfig 150 } 151 certs, err := tls.LoadX509KeyPair(ccert, ckey) 152 if err != nil { 153 return nil, err 154 } 155 clientCert = append(clientCert, certs) 156 } 157 158 insecureSkipVerify := false 159 if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { 160 x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) 161 if err != nil { 162 return nil, err 163 } 164 insecureSkipVerify = x 165 } 166 167 err = mysql.RegisterTLSConfig(ctls, &tls.Config{ 168 RootCAs: rootCertPool, 169 Certificates: clientCert, 170 InsecureSkipVerify: insecureSkipVerify, 171 }) 172 if err != nil { 173 return nil, err 174 } 175 } 176 } 177 178 c, err := urlToMySQLConfig(*migrate.FilterCustomQuery(purl)) 179 if err != nil { 180 return nil, err 181 } 182 db, err := sql.Open("mysql", c.FormatDSN()) 183 if err != nil { 184 return nil, err 185 } 186 187 mx, err := WithInstance(db, &Config{ 188 DatabaseName: purl.Path, 189 MigrationsTable: migrationsTable, 190 }) 191 if err != nil { 192 return nil, err 193 } 194 195 return mx, nil 196 } 197 198 func (m *Mysql) Close() error { 199 connErr := m.conn.Close() 200 dbErr := m.db.Close() 201 if connErr != nil || dbErr != nil { 202 return fmt.Errorf("conn: %v, db: %v", connErr, dbErr) 203 } 204 return nil 205 } 206 207 func (m *Mysql) Lock() error { 208 if m.isLocked { 209 return database.ErrLocked 210 } 211 212 aid, err := database.GenerateAdvisoryLockId( 213 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 214 if err != nil { 215 return err 216 } 217 218 query := "SELECT GET_LOCK(?, 10)" 219 var success bool 220 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 221 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 222 } 223 224 if success { 225 m.isLocked = true 226 return nil 227 } 228 229 return database.ErrLocked 230 } 231 232 func (m *Mysql) Unlock() error { 233 if !m.isLocked { 234 return nil 235 } 236 237 aid, err := database.GenerateAdvisoryLockId( 238 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 239 if err != nil { 240 return err 241 } 242 243 query := `SELECT RELEASE_LOCK(?)` 244 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 245 return &database.Error{OrigErr: err, Query: []byte(query)} 246 } 247 248 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 249 // in which case isLocked should be true until the timeout expires -- synchronizing 250 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 251 252 m.isLocked = false 253 return nil 254 } 255 256 func (m *Mysql) Run(migration io.Reader) error { 257 migr, err := ioutil.ReadAll(migration) 258 if err != nil { 259 return err 260 } 261 262 query := string(migr[:]) 263 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 264 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 265 } 266 267 return nil 268 } 269 270 func (m *Mysql) SetVersion(version int, dirty bool) error { 271 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 272 if err != nil { 273 return &database.Error{OrigErr: err, Err: "transaction start failed"} 274 } 275 276 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 277 if _, err := tx.ExecContext(context.Background(), query); err != nil { 278 if errRollback := tx.Rollback(); errRollback != nil { 279 err = multierror.Append(err, errRollback) 280 } 281 return &database.Error{OrigErr: err, Query: []byte(query)} 282 } 283 284 if version >= 0 { 285 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 286 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 287 if errRollback := tx.Rollback(); errRollback != nil { 288 err = multierror.Append(err, errRollback) 289 } 290 return &database.Error{OrigErr: err, Query: []byte(query)} 291 } 292 } 293 294 if err := tx.Commit(); err != nil { 295 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 296 } 297 298 return nil 299 } 300 301 func (m *Mysql) Version() (version int, dirty bool, err error) { 302 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 303 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 304 switch { 305 case err == sql.ErrNoRows: 306 return database.NilVersion, false, nil 307 308 case err != nil: 309 if e, ok := err.(*mysql.MySQLError); ok { 310 if e.Number == 0 { 311 return database.NilVersion, false, nil 312 } 313 } 314 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 315 316 default: 317 return version, dirty, nil 318 } 319 } 320 321 func (m *Mysql) Drop() (err error) { 322 // select all tables 323 query := `SHOW TABLES LIKE '%'` 324 tables, err := m.conn.QueryContext(context.Background(), query) 325 if err != nil { 326 return &database.Error{OrigErr: err, Query: []byte(query)} 327 } 328 defer func() { 329 if errClose := tables.Close(); errClose != nil { 330 err = multierror.Append(err, errClose) 331 } 332 }() 333 334 // delete one table after another 335 tableNames := make([]string, 0) 336 for tables.Next() { 337 var tableName string 338 if err := tables.Scan(&tableName); err != nil { 339 return err 340 } 341 if len(tableName) > 0 { 342 tableNames = append(tableNames, tableName) 343 } 344 } 345 346 if len(tableNames) > 0 { 347 // delete one by one ... 348 for _, t := range tableNames { 349 query = "DROP TABLE IF EXISTS `" + t + "` CASCADE" 350 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 351 return &database.Error{OrigErr: err, Query: []byte(query)} 352 } 353 } 354 } 355 356 return nil 357 } 358 359 // ensureVersionTable checks if versions table exists and, if not, creates it. 360 // Note that this function locks the database, which deviates from the usual 361 // convention of "caller locks" in the Mysql type. 362 func (m *Mysql) ensureVersionTable() (err error) { 363 if err = m.Lock(); err != nil { 364 return err 365 } 366 367 defer func() { 368 if e := m.Unlock(); e != nil { 369 if err == nil { 370 err = e 371 } else { 372 err = multierror.Append(err, e) 373 } 374 } 375 }() 376 377 // check if migration table exists 378 var result string 379 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 380 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 381 if err != sql.ErrNoRows { 382 return &database.Error{OrigErr: err, Query: []byte(query)} 383 } 384 } else { 385 return nil 386 } 387 388 // if not, create the empty migration table 389 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 390 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 391 return &database.Error{OrigErr: err, Query: []byte(query)} 392 } 393 return nil 394 } 395 396 // Returns the bool value of the input. 397 // The 2nd return value indicates if the input was a valid bool value 398 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 399 func readBool(input string) (value bool, valid bool) { 400 switch input { 401 case "1", "true", "TRUE", "True": 402 return true, true 403 case "0", "false", "FALSE", "False": 404 return false, true 405 } 406 407 // Not a valid bool value 408 return 409 }