github.com/bingtel/dbmate@v1.4.1/pkg/dbmate/db.go (about) 1 package dbmate 2 3 import ( 4 "database/sql" 5 "fmt" 6 "io/ioutil" 7 "net/url" 8 "os" 9 "path/filepath" 10 "regexp" 11 "sort" 12 "time" 13 ) 14 15 // DefaultMigrationsDir specifies default directory to find migration files 16 const DefaultMigrationsDir = "./db/migrations" 17 18 // DefaultSchemaFile specifies default location for schema.sql 19 const DefaultSchemaFile = "./db/schema.sql" 20 21 // DefaultWaitInterval specifies length of time between connection attempts 22 const DefaultWaitInterval = time.Second 23 24 // DefaultWaitTimeout specifies maximum time for connection attempts 25 const DefaultWaitTimeout = 60 * time.Second 26 27 // DB allows dbmate actions to be performed on a specified database 28 type DB struct { 29 AutoDumpSchema bool 30 DatabaseURL *url.URL 31 MigrationsDir string 32 SchemaFile string 33 WaitInterval time.Duration 34 WaitTimeout time.Duration 35 } 36 37 // New initializes a new dbmate database 38 func New(databaseURL *url.URL) *DB { 39 return &DB{ 40 AutoDumpSchema: true, 41 DatabaseURL: databaseURL, 42 MigrationsDir: DefaultMigrationsDir, 43 SchemaFile: DefaultSchemaFile, 44 WaitInterval: DefaultWaitInterval, 45 WaitTimeout: DefaultWaitTimeout, 46 } 47 } 48 49 // GetDriver loads the required database driver 50 func (db *DB) GetDriver() (Driver, error) { 51 return GetDriver(db.DatabaseURL.Scheme) 52 } 53 54 // Wait blocks until the database server is available. It does not verify that 55 // the specified database exists, only that the host is ready to accept connections. 56 func (db *DB) Wait() error { 57 drv, err := db.GetDriver() 58 if err != nil { 59 return err 60 } 61 62 // attempt connection to database server 63 err = drv.Ping(db.DatabaseURL) 64 if err == nil { 65 // connection successful 66 return nil 67 } 68 69 fmt.Print("Waiting for database") 70 for i := 0 * time.Second; i < db.WaitTimeout; i += db.WaitInterval { 71 fmt.Print(".") 72 time.Sleep(db.WaitInterval) 73 74 // attempt connection to database server 75 err = drv.Ping(db.DatabaseURL) 76 if err == nil { 77 // connection successful 78 fmt.Print("\n") 79 return nil 80 } 81 } 82 83 // if we find outselves here, we could not connect within the timeout 84 fmt.Print("\n") 85 return fmt.Errorf("unable to connect to database: %s", err) 86 } 87 88 // CreateAndMigrate creates the database (if necessary) and runs migrations 89 func (db *DB) CreateAndMigrate() error { 90 drv, err := db.GetDriver() 91 if err != nil { 92 return err 93 } 94 95 // create database if it does not already exist 96 // skip this step if we cannot determine status 97 // (e.g. user does not have list database permission) 98 exists, err := drv.DatabaseExists(db.DatabaseURL) 99 if err == nil && !exists { 100 if err := drv.CreateDatabase(db.DatabaseURL); err != nil { 101 return err 102 } 103 } 104 105 // migrate 106 return db.Migrate() 107 } 108 109 // Create creates the current database 110 func (db *DB) Create() error { 111 drv, err := db.GetDriver() 112 if err != nil { 113 return err 114 } 115 116 return drv.CreateDatabase(db.DatabaseURL) 117 } 118 119 // Drop drops the current database (if it exists) 120 func (db *DB) Drop() error { 121 drv, err := db.GetDriver() 122 if err != nil { 123 return err 124 } 125 126 return drv.DropDatabase(db.DatabaseURL) 127 } 128 129 // DumpSchema writes the current database schema to a file 130 func (db *DB) DumpSchema() error { 131 drv, sqlDB, err := db.openDatabaseForMigration() 132 if err != nil { 133 return err 134 } 135 defer mustClose(sqlDB) 136 137 schema, err := drv.DumpSchema(db.DatabaseURL, sqlDB) 138 if err != nil { 139 return err 140 } 141 142 fmt.Printf("Writing: %s\n", db.SchemaFile) 143 144 // ensure schema directory exists 145 if err = ensureDir(filepath.Dir(db.SchemaFile)); err != nil { 146 return err 147 } 148 149 // write schema to file 150 return ioutil.WriteFile(db.SchemaFile, schema, 0644) 151 } 152 153 const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n" 154 155 // NewMigration creates a new migration file 156 func (db *DB) NewMigration(name string) error { 157 // new migration name 158 timestamp := time.Now().UTC().Format("20060102150405") 159 if name == "" { 160 return fmt.Errorf("please specify a name for the new migration") 161 } 162 name = fmt.Sprintf("%s_%s.sql", timestamp, name) 163 164 // create migrations dir if missing 165 if err := ensureDir(db.MigrationsDir); err != nil { 166 return err 167 } 168 169 // check file does not already exist 170 path := filepath.Join(db.MigrationsDir, name) 171 fmt.Printf("Creating migration: %s\n", path) 172 173 if _, err := os.Stat(path); !os.IsNotExist(err) { 174 return fmt.Errorf("file already exists") 175 } 176 177 // write new migration 178 file, err := os.Create(path) 179 if err != nil { 180 return err 181 } 182 183 defer mustClose(file) 184 _, err = file.WriteString(migrationTemplate) 185 return err 186 } 187 188 func doTransaction(db *sql.DB, txFunc func(Transaction) error) error { 189 tx, err := db.Begin() 190 if err != nil { 191 return err 192 } 193 194 if err := txFunc(tx); err != nil { 195 if err1 := tx.Rollback(); err1 != nil { 196 return err1 197 } 198 199 return err 200 } 201 202 return tx.Commit() 203 } 204 205 func (db *DB) openDatabaseForMigration() (Driver, *sql.DB, error) { 206 drv, err := db.GetDriver() 207 if err != nil { 208 return nil, nil, err 209 } 210 211 sqlDB, err := drv.Open(db.DatabaseURL) 212 if err != nil { 213 return nil, nil, err 214 } 215 216 if err := drv.CreateMigrationsTable(sqlDB); err != nil { 217 mustClose(sqlDB) 218 return nil, nil, err 219 } 220 221 return drv, sqlDB, nil 222 } 223 224 // Migrate migrates database to the latest version 225 func (db *DB) Migrate() error { 226 re := regexp.MustCompile(`^\d.*\.sql$`) 227 files, err := findMigrationFiles(db.MigrationsDir, re) 228 if err != nil { 229 return err 230 } 231 232 if len(files) == 0 { 233 return fmt.Errorf("no migration files found") 234 } 235 236 drv, sqlDB, err := db.openDatabaseForMigration() 237 if err != nil { 238 return err 239 } 240 defer mustClose(sqlDB) 241 242 applied, err := drv.SelectMigrations(sqlDB, -1) 243 if err != nil { 244 return err 245 } 246 247 for _, filename := range files { 248 ver := migrationVersion(filename) 249 if ok := applied[ver]; ok { 250 // migration already applied 251 continue 252 } 253 254 fmt.Printf("Applying: %s\n", filename) 255 256 migration, err := parseMigration(filepath.Join(db.MigrationsDir, filename)) 257 if err != nil { 258 return err 259 } 260 261 // begin transaction 262 err = doTransaction(sqlDB, func(tx Transaction) error { 263 // run actual migration 264 if _, err := tx.Exec(migration["up"]); err != nil { 265 return err 266 } 267 268 // record migration 269 return drv.InsertMigration(tx, ver) 270 }) 271 if err != nil { 272 return err 273 } 274 275 } 276 277 // automatically update schema file, silence errors 278 if db.AutoDumpSchema { 279 _ = db.DumpSchema() 280 } 281 282 return nil 283 } 284 285 func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) { 286 files, err := ioutil.ReadDir(dir) 287 if err != nil { 288 return nil, fmt.Errorf("could not find migrations directory `%s`", dir) 289 } 290 291 matches := []string{} 292 for _, file := range files { 293 if file.IsDir() { 294 continue 295 } 296 297 name := file.Name() 298 if !re.MatchString(name) { 299 continue 300 } 301 302 matches = append(matches, name) 303 } 304 305 sort.Strings(matches) 306 307 return matches, nil 308 } 309 310 func findMigrationFile(dir string, ver string) (string, error) { 311 if ver == "" { 312 panic("migration version is required") 313 } 314 315 ver = regexp.QuoteMeta(ver) 316 re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver)) 317 318 files, err := findMigrationFiles(dir, re) 319 if err != nil { 320 return "", err 321 } 322 323 if len(files) == 0 { 324 return "", fmt.Errorf("can't find migration file: %s*.sql", ver) 325 } 326 327 return files[0], nil 328 } 329 330 func migrationVersion(filename string) string { 331 return regexp.MustCompile(`^\d+`).FindString(filename) 332 } 333 334 // parseMigration reads a migration file into a map with up/down keys 335 // implementation is similar to regexp.Split() 336 func parseMigration(path string) (map[string]string, error) { 337 // read migration file into string 338 data, err := ioutil.ReadFile(path) 339 if err != nil { 340 return nil, err 341 } 342 contents := string(data) 343 344 // split string on our trigger comment 345 separatorRegexp := regexp.MustCompile(`(?m)^-- migrate:(.*)$`) 346 matches := separatorRegexp.FindAllStringSubmatchIndex(contents, -1) 347 348 migrations := map[string]string{} 349 direction := "" 350 beg := 0 351 end := 0 352 353 for _, match := range matches { 354 end = match[0] 355 if direction != "" { 356 // write previous direction to output map 357 migrations[direction] = contents[beg:end] 358 } 359 360 // each match records the start of a new direction 361 direction = contents[match[2]:match[3]] 362 beg = match[1] 363 } 364 365 // write final direction to output map 366 migrations[direction] = contents[beg:] 367 368 return migrations, nil 369 } 370 371 // Rollback rolls back the most recent migration 372 func (db *DB) Rollback() error { 373 drv, sqlDB, err := db.openDatabaseForMigration() 374 if err != nil { 375 return err 376 } 377 defer mustClose(sqlDB) 378 379 applied, err := drv.SelectMigrations(sqlDB, 1) 380 if err != nil { 381 return err 382 } 383 384 // grab most recent applied migration (applied has len=1) 385 latest := "" 386 for ver := range applied { 387 latest = ver 388 } 389 if latest == "" { 390 return fmt.Errorf("can't rollback: no migrations have been applied") 391 } 392 393 filename, err := findMigrationFile(db.MigrationsDir, latest) 394 if err != nil { 395 return err 396 } 397 398 fmt.Printf("Rolling back: %s\n", filename) 399 400 migration, err := parseMigration(filepath.Join(db.MigrationsDir, filename)) 401 if err != nil { 402 return err 403 } 404 405 // begin transaction 406 err = doTransaction(sqlDB, func(tx Transaction) error { 407 // rollback migration 408 if _, err := tx.Exec(migration["down"]); err != nil { 409 return err 410 } 411 412 // remove migration record 413 return drv.DeleteMigration(tx, latest) 414 }) 415 if err != nil { 416 return err 417 } 418 419 // automatically update schema file, silence errors 420 if db.AutoDumpSchema { 421 _ = db.DumpSchema() 422 } 423 424 return nil 425 }