github.com/matcornic/migrate@v3.3.2-0.20180717234201-feea45c20506+incompatible/database/mysql/mysql.go (about) 1 // +build go1.9 2 3 package mysql 4 5 import ( 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "database/sql" 10 "fmt" 11 "io" 12 "io/ioutil" 13 nurl "net/url" 14 "strconv" 15 "strings" 16 17 "github.com/go-sql-driver/mysql" 18 "github.com/golang-migrate/migrate" 19 "github.com/golang-migrate/migrate/database" 20 ) 21 22 func init() { 23 database.Register("mysql", &Mysql{}) 24 } 25 26 var DefaultMigrationsTable = "schema_migrations" 27 28 var ( 29 ErrDatabaseDirty = fmt.Errorf("database is dirty") 30 ErrNilConfig = fmt.Errorf("no config") 31 ErrNoDatabaseName = fmt.Errorf("no database name") 32 ErrAppendPEM = fmt.Errorf("failed to append PEM") 33 ) 34 35 type Config struct { 36 MigrationsTable string 37 DatabaseName string 38 } 39 40 type Mysql struct { 41 // mysql RELEASE_LOCK must be called from the same conn, so 42 // just do everything over a single conn anyway. 43 conn *sql.Conn 44 isLocked bool 45 46 config *Config 47 } 48 49 // instance must have `multiStatements` set to true 50 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 51 if config == nil { 52 return nil, ErrNilConfig 53 } 54 55 if err := instance.Ping(); err != nil { 56 return nil, err 57 } 58 59 query := `SELECT DATABASE()` 60 var databaseName sql.NullString 61 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 62 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 63 } 64 65 if len(databaseName.String) == 0 { 66 return nil, ErrNoDatabaseName 67 } 68 69 config.DatabaseName = databaseName.String 70 71 if len(config.MigrationsTable) == 0 { 72 config.MigrationsTable = DefaultMigrationsTable 73 } 74 75 conn, err := instance.Conn(context.Background()) 76 if err != nil { 77 return nil, err 78 } 79 80 mx := &Mysql{ 81 conn: conn, 82 config: config, 83 } 84 85 if err := mx.ensureVersionTable(); err != nil { 86 return nil, err 87 } 88 89 return mx, nil 90 } 91 92 func (m *Mysql) Open(url string) (database.Driver, error) { 93 url = strings.TrimPrefix(url, "mysql://") 94 purl, err := nurl.Parse(url) 95 if err != nil { 96 return nil, err 97 } 98 99 q := purl.Query() 100 q.Set("multiStatements", "true") 101 purl.RawQuery = q.Encode() 102 103 db, err := sql.Open("mysql", migrate.FilterCustomQuery(purl).String()) 104 if err != nil { 105 return nil, err 106 } 107 108 migrationsTable := purl.Query().Get("x-migrations-table") 109 if len(migrationsTable) == 0 { 110 migrationsTable = DefaultMigrationsTable 111 } 112 113 // use custom TLS? 114 ctls := purl.Query().Get("tls") 115 if len(ctls) > 0 { 116 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 117 rootCertPool := x509.NewCertPool() 118 pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) 119 if err != nil { 120 return nil, err 121 } 122 123 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 124 return nil, ErrAppendPEM 125 } 126 127 certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key")) 128 if err != nil { 129 return nil, err 130 } 131 132 insecureSkipVerify := false 133 if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { 134 x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) 135 if err != nil { 136 return nil, err 137 } 138 insecureSkipVerify = x 139 } 140 141 mysql.RegisterTLSConfig(ctls, &tls.Config{ 142 RootCAs: rootCertPool, 143 Certificates: []tls.Certificate{certs}, 144 InsecureSkipVerify: insecureSkipVerify, 145 }) 146 } 147 } 148 149 mx, err := WithInstance(db, &Config{ 150 DatabaseName: purl.Path, 151 MigrationsTable: migrationsTable, 152 }) 153 if err != nil { 154 return nil, err 155 } 156 157 return mx, nil 158 } 159 160 func (m *Mysql) Close() error { 161 return m.conn.Close() 162 } 163 164 func (m *Mysql) Lock() error { 165 if m.isLocked { 166 return database.ErrLocked 167 } 168 169 aid, err := database.GenerateAdvisoryLockId( 170 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 171 if err != nil { 172 return err 173 } 174 175 query := "SELECT GET_LOCK(?, 10)" 176 var success bool 177 if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil { 178 return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} 179 } 180 181 if success { 182 m.isLocked = true 183 return nil 184 } 185 186 return database.ErrLocked 187 } 188 189 func (m *Mysql) Unlock() error { 190 if !m.isLocked { 191 return nil 192 } 193 194 aid, err := database.GenerateAdvisoryLockId( 195 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 196 if err != nil { 197 return err 198 } 199 200 query := `SELECT RELEASE_LOCK(?)` 201 if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil { 202 return &database.Error{OrigErr: err, Query: []byte(query)} 203 } 204 205 // NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed), 206 // in which case isLocked should be true until the timeout expires -- synchronizing 207 // these states is likely not worth trying to do; reconsider the necessity of isLocked. 208 209 m.isLocked = false 210 return nil 211 } 212 213 func (m *Mysql) Run(migration io.Reader) error { 214 migr, err := ioutil.ReadAll(migration) 215 if err != nil { 216 return err 217 } 218 219 query := string(migr[:]) 220 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 221 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 222 } 223 224 return nil 225 } 226 227 func (m *Mysql) SetVersion(version int, dirty bool) error { 228 tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{}) 229 if err != nil { 230 return &database.Error{OrigErr: err, Err: "transaction start failed"} 231 } 232 233 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 234 if _, err := tx.ExecContext(context.Background(), query); err != nil { 235 tx.Rollback() 236 return &database.Error{OrigErr: err, Query: []byte(query)} 237 } 238 239 if version >= 0 { 240 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 241 if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil { 242 tx.Rollback() 243 return &database.Error{OrigErr: err, Query: []byte(query)} 244 } 245 } 246 247 if err := tx.Commit(); err != nil { 248 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 249 } 250 251 return nil 252 } 253 254 func (m *Mysql) Version() (version int, dirty bool, err error) { 255 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 256 err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) 257 switch { 258 case err == sql.ErrNoRows: 259 return database.NilVersion, false, nil 260 261 case err != nil: 262 if e, ok := err.(*mysql.MySQLError); ok { 263 if e.Number == 0 { 264 return database.NilVersion, false, nil 265 } 266 } 267 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 268 269 default: 270 return version, dirty, nil 271 } 272 } 273 274 func (m *Mysql) Drop() error { 275 // select all tables 276 query := `SHOW TABLES LIKE '%'` 277 tables, err := m.conn.QueryContext(context.Background(), query) 278 if err != nil { 279 return &database.Error{OrigErr: err, Query: []byte(query)} 280 } 281 defer tables.Close() 282 283 // delete one table after another 284 tableNames := make([]string, 0) 285 for tables.Next() { 286 var tableName string 287 if err := tables.Scan(&tableName); err != nil { 288 return err 289 } 290 if len(tableName) > 0 { 291 tableNames = append(tableNames, tableName) 292 } 293 } 294 295 if len(tableNames) > 0 { 296 // delete one by one ... 297 for _, t := range tableNames { 298 query = "DROP TABLE IF EXISTS `" + t + "` CASCADE" 299 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 300 return &database.Error{OrigErr: err, Query: []byte(query)} 301 } 302 } 303 if err := m.ensureVersionTable(); err != nil { 304 return err 305 } 306 } 307 308 return nil 309 } 310 311 func (m *Mysql) ensureVersionTable() error { 312 // check if migration table exists 313 var result string 314 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 315 if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil { 316 if err != sql.ErrNoRows { 317 return &database.Error{OrigErr: err, Query: []byte(query)} 318 } 319 } else { 320 return nil 321 } 322 323 // if not, create the empty migration table 324 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 325 if _, err := m.conn.ExecContext(context.Background(), query); err != nil { 326 return &database.Error{OrigErr: err, Query: []byte(query)} 327 } 328 return nil 329 } 330 331 // Returns the bool value of the input. 332 // The 2nd return value indicates if the input was a valid bool value 333 // See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71 334 func readBool(input string) (value bool, valid bool) { 335 switch input { 336 case "1", "true", "TRUE", "True": 337 return true, true 338 case "0", "false", "FALSE", "False": 339 return false, true 340 } 341 342 // Not a valid bool value 343 return 344 }