github.com/eatigo/migrate@v3.0.2-0.20210729130915-7610befb1b6b+incompatible/database/mysql/mysql.go (about) 1 package mysql 2 3 import ( 4 "crypto/tls" 5 "crypto/x509" 6 "database/sql" 7 "fmt" 8 "io" 9 "io/ioutil" 10 nurl "net/url" 11 "strconv" 12 "strings" 13 14 "github.com/go-sql-driver/mysql" 15 "github.com/eatigo/migrate" 16 "github.com/eatigo/migrate/database" 17 ) 18 19 func init() { 20 database.Register("mysql", &Mysql{}) 21 } 22 23 var DefaultMigrationsTable = "schema_migrations" 24 25 var ( 26 ErrDatabaseDirty = fmt.Errorf("database is dirty") 27 ErrNilConfig = fmt.Errorf("no config") 28 ErrNoDatabaseName = fmt.Errorf("no database name") 29 ErrAppendPEM = fmt.Errorf("failed to append PEM") 30 ) 31 32 type Config struct { 33 MigrationsTable string 34 DatabaseName string 35 } 36 37 type Mysql struct { 38 db *sql.DB 39 isLocked bool 40 41 config *Config 42 } 43 44 // instance must have `multiStatements` set to true 45 func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { 46 if config == nil { 47 return nil, ErrNilConfig 48 } 49 50 if err := instance.Ping(); err != nil { 51 return nil, err 52 } 53 54 query := `SELECT DATABASE()` 55 var databaseName sql.NullString 56 if err := instance.QueryRow(query).Scan(&databaseName); err != nil { 57 return nil, &database.Error{OrigErr: err, Query: []byte(query)} 58 } 59 60 if len(databaseName.String) == 0 { 61 return nil, ErrNoDatabaseName 62 } 63 64 config.DatabaseName = databaseName.String 65 66 if len(config.MigrationsTable) == 0 { 67 config.MigrationsTable = DefaultMigrationsTable 68 } 69 70 mx := &Mysql{ 71 db: instance, 72 config: config, 73 } 74 75 if err := mx.ensureVersionTable(); err != nil { 76 return nil, err 77 } 78 79 return mx, nil 80 } 81 82 func (m *Mysql) Open(url string) (database.Driver, error) { 83 purl, err := nurl.Parse(url) 84 if err != nil { 85 return nil, err 86 } 87 88 q := purl.Query() 89 q.Set("multiStatements", "true") 90 purl.RawQuery = q.Encode() 91 92 db, err := 
sql.Open("mysql", strings.Replace( 93 migrate.FilterCustomQuery(purl).String(), "mysql://", "", 1)) 94 if err != nil { 95 return nil, err 96 } 97 98 migrationsTable := purl.Query().Get("x-migrations-table") 99 if len(migrationsTable) == 0 { 100 migrationsTable = DefaultMigrationsTable 101 } 102 103 // use custom TLS? 104 ctls := purl.Query().Get("tls") 105 if len(ctls) > 0 { 106 if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" { 107 rootCertPool := x509.NewCertPool() 108 pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca")) 109 if err != nil { 110 return nil, err 111 } 112 113 if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { 114 return nil, ErrAppendPEM 115 } 116 117 certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key")) 118 if err != nil { 119 return nil, err 120 } 121 122 insecureSkipVerify := false 123 if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 { 124 x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify")) 125 if err != nil { 126 return nil, err 127 } 128 insecureSkipVerify = x 129 } 130 131 mysql.RegisterTLSConfig(ctls, &tls.Config{ 132 RootCAs: rootCertPool, 133 Certificates: []tls.Certificate{certs}, 134 InsecureSkipVerify: insecureSkipVerify, 135 }) 136 } 137 } 138 139 mx, err := WithInstance(db, &Config{ 140 DatabaseName: purl.Path, 141 MigrationsTable: migrationsTable, 142 }) 143 if err != nil { 144 return nil, err 145 } 146 147 return mx, nil 148 } 149 150 func (m *Mysql) Close() error { 151 return m.db.Close() 152 } 153 154 func (m *Mysql) Lock() error { 155 if m.isLocked { 156 return database.ErrLocked 157 } 158 159 aid, err := database.GenerateAdvisoryLockId( 160 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 161 if err != nil { 162 return err 163 } 164 165 query := "SELECT GET_LOCK(?, 1)" 166 var success bool 167 if err := m.db.QueryRow(query, aid).Scan(&success); err != nil { 168 return &database.Error{OrigErr: 
err, Err: "try lock failed", Query: []byte(query)} 169 } 170 171 if success { 172 m.isLocked = true 173 return nil 174 } 175 176 return database.ErrLocked 177 } 178 179 func (m *Mysql) Unlock() error { 180 if !m.isLocked { 181 return nil 182 } 183 184 aid, err := database.GenerateAdvisoryLockId( 185 fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable)) 186 if err != nil { 187 return err 188 } 189 190 query := `SELECT RELEASE_LOCK(?)` 191 if _, err := m.db.Exec(query, aid); err != nil { 192 return &database.Error{OrigErr: err, Query: []byte(query)} 193 } 194 195 m.isLocked = false 196 return nil 197 } 198 199 func (m *Mysql) Run(migration io.Reader) error { 200 migr, err := ioutil.ReadAll(migration) 201 if err != nil { 202 return err 203 } 204 205 query := string(migr[:]) 206 if _, err := m.db.Exec(query); err != nil { 207 return database.Error{OrigErr: err, Err: "migration failed", Query: migr} 208 } 209 210 return nil 211 } 212 213 func (m *Mysql) SetVersion(version int, dirty bool) error { 214 tx, err := m.db.Begin() 215 if err != nil { 216 return &database.Error{OrigErr: err, Err: "transaction start failed"} 217 } 218 219 query := "TRUNCATE `" + m.config.MigrationsTable + "`" 220 if _, err := m.db.Exec(query); err != nil { 221 return &database.Error{OrigErr: err, Query: []byte(query)} 222 } 223 224 if version >= 0 { 225 query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)" 226 if _, err := m.db.Exec(query, version, dirty); err != nil { 227 tx.Rollback() 228 return &database.Error{OrigErr: err, Query: []byte(query)} 229 } 230 } 231 232 if err := tx.Commit(); err != nil { 233 return &database.Error{OrigErr: err, Err: "transaction commit failed"} 234 } 235 236 return nil 237 } 238 239 func (m *Mysql) Version() (version int, dirty bool, err error) { 240 query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1" 241 err = m.db.QueryRow(query).Scan(&version, &dirty) 242 switch { 243 case err == 
sql.ErrNoRows: 244 return database.NilVersion, false, nil 245 246 case err != nil: 247 if e, ok := err.(*mysql.MySQLError); ok { 248 if e.Number == 0 { 249 return database.NilVersion, false, nil 250 } 251 } 252 return 0, false, &database.Error{OrigErr: err, Query: []byte(query)} 253 254 default: 255 return version, dirty, nil 256 } 257 } 258 259 func (m *Mysql) Drop() error { 260 // select all tables 261 query := `SHOW TABLES LIKE '%'` 262 tables, err := m.db.Query(query) 263 if err != nil { 264 return &database.Error{OrigErr: err, Query: []byte(query)} 265 } 266 defer tables.Close() 267 268 // delete one table after another 269 tableNames := make([]string, 0) 270 for tables.Next() { 271 var tableName string 272 if err := tables.Scan(&tableName); err != nil { 273 return err 274 } 275 if len(tableName) > 0 { 276 tableNames = append(tableNames, tableName) 277 } 278 } 279 280 if len(tableNames) > 0 { 281 // delete one by one ... 282 for _, t := range tableNames { 283 query = "SET FOREIGN_KEY_CHECKS=0; DROP TABLE IF EXISTS `" + t + "` CASCADE; SET FOREIGN_KEY_CHECKS=1;" 284 if _, err := m.db.Exec(query); err != nil { 285 return &database.Error{OrigErr: err, Query: []byte(query)} 286 } 287 } 288 if err := m.ensureVersionTable(); err != nil { 289 return err 290 } 291 } 292 293 return nil 294 } 295 296 func (m *Mysql) ensureVersionTable() error { 297 // check if migration table exists 298 var result string 299 query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` 300 if err := m.db.QueryRow(query).Scan(&result); err != nil { 301 if err != sql.ErrNoRows { 302 return &database.Error{OrigErr: err, Query: []byte(query)} 303 } 304 } else { 305 return nil 306 } 307 308 // if not, create the empty migration table 309 query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" 310 if _, err := m.db.Exec(query); err != nil { 311 return &database.Error{OrigErr: err, Query: []byte(query)} 312 } 313 return nil 314 } 
// readBool reports the boolean value of input and whether input is one of
// the recognised boolean spellings ("1", "true", "TRUE", "True" and their
// false counterparts). Any other input yields (false, false).
// Mirrors https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
func readBool(input string) (value bool, valid bool) {
	switch input {
	case "1", "true", "TRUE", "True":
		value = true
	case "0", "false", "FALSE", "False":
		// value keeps its zero value, false.
	default:
		return false, false
	}
	return value, true
}