github.com/amacneil/dbmate@v1.16.3-0.20230225174651-ca89b10d75d7/pkg/dbmate/db.go (about) 1 package dbmate 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "io" 8 "io/fs" 9 "net/url" 10 "os" 11 "path/filepath" 12 "regexp" 13 "sort" 14 "time" 15 16 "github.com/amacneil/dbmate/pkg/dbutil" 17 ) 18 19 // Error codes 20 var ( 21 ErrNoMigrationFiles = errors.New("no migration files found") 22 ErrInvalidURL = errors.New("invalid url, have you set your --url flag or DATABASE_URL environment variable?") 23 ErrNoRollback = errors.New("can't rollback: no migrations have been applied") 24 ErrCantConnect = errors.New("unable to connect to database") 25 ErrUnsupportedDriver = errors.New("unsupported driver") 26 ErrNoMigrationName = errors.New("please specify a name for the new migration") 27 ErrMigrationAlreadyExist = errors.New("file already exists") 28 ErrMigrationDirNotFound = errors.New("could not find migrations directory") 29 ErrMigrationNotFound = errors.New("can't find migration file") 30 ErrCreateDirectory = errors.New("unable to create directory") 31 ) 32 33 // migrationFileRegexp pattern for valid migration files 34 var migrationFileRegexp = regexp.MustCompile(`^(\d+).*\.sql$`) 35 36 // DB allows dbmate actions to be performed on a specified database 37 type DB struct { 38 // AutoDumpSchema generates schema.sql after each action 39 AutoDumpSchema bool 40 // DatabaseURL is the database connection string 41 DatabaseURL *url.URL 42 // FS allows overriding the filesystem 43 FS fs.FS 44 // Log is the interface to write stdout 45 Log io.Writer 46 // MigrationsDir specifies the directory to find migration files 47 MigrationsDir string 48 // MigrationsTableName specifies the database table to record migrations in 49 MigrationsTableName string 50 // SchemaFile specifies the location for schema.sql file 51 SchemaFile string 52 // Verbose prints the result of each statement execution 53 Verbose bool 54 // WaitBefore will wait for database to become available before running any actions 55 WaitBefore bool 56 // WaitInterval specifies length of time between connection attempts 57 WaitInterval time.Duration 58 // WaitTimeout specifies maximum time for connection attempts 59 WaitTimeout time.Duration 60 } 61 62 // StatusResult represents an available migration status 63 type StatusResult struct { 64 Filename string 65 Applied bool 66 } 67 68 // New initializes a new dbmate database 69 func New(databaseURL *url.URL) *DB { 70 return &DB{ 71 AutoDumpSchema: true, 72 DatabaseURL: databaseURL, 73 FS: os.DirFS("."), 74 Log: os.Stdout, 75 MigrationsDir: "./db/migrations", 76 MigrationsTableName: "schema_migrations", 77 SchemaFile: "./db/schema.sql", 78 Verbose: false, 79 WaitBefore: false, 80 WaitInterval: time.Second, 81 WaitTimeout: 60 * time.Second, 82 } 83 } 84 85 // Driver initializes the appropriate database driver 86 func (db *DB) Driver() (Driver, error) { 87 if db.DatabaseURL == nil || db.DatabaseURL.Scheme == "" { 88 return nil, ErrInvalidURL 89 } 90 91 driverFunc := drivers[db.DatabaseURL.Scheme] 92 if driverFunc == nil { 93 return nil, fmt.Errorf("%w: %s", ErrUnsupportedDriver, db.DatabaseURL.Scheme) 94 } 95 96 config := DriverConfig{ 97 DatabaseURL: db.DatabaseURL, 98 Log: db.Log, 99 MigrationsTableName: db.MigrationsTableName, 100 } 101 drv := driverFunc(config) 102 103 if db.WaitBefore { 104 if err := db.wait(drv); err != nil { 105 return nil, err 106 } 107 } 108 109 return drv, nil 110 } 111 112 func (db *DB) wait(drv Driver) error { 113 // attempt connection to database server 114 err := drv.Ping() 115 if err == nil { 116 // connection successful 117 return nil 118 } 119 120 fmt.Fprint(db.Log, "Waiting for database") 121 for i := 0 * time.Second; i < db.WaitTimeout; i += db.WaitInterval { 122 fmt.Fprint(db.Log, ".") 123 time.Sleep(db.WaitInterval) 124 125 // attempt connection to database server 126 err = drv.Ping() 127 if err == nil { 128 // connection successful 129 fmt.Fprint(db.Log, "\n") 130 return nil 131 } 132 } 133 134 // if we find outselves here, we could not connect within the timeout 135 fmt.Fprint(db.Log, "\n") 136 return fmt.Errorf("%w: %s", ErrCantConnect, err) 137 } 138 139 // Wait blocks until the database server is available. It does not verify that 140 // the specified database exists, only that the host is ready to accept connections. 141 func (db *DB) Wait() error { 142 drv, err := db.Driver() 143 if err != nil { 144 return err 145 } 146 147 // if db.WaitBefore is true, wait() will get called twice, no harm 148 return db.wait(drv) 149 } 150 151 // CreateAndMigrate creates the database (if necessary) and runs migrations 152 func (db *DB) CreateAndMigrate() error { 153 drv, err := db.Driver() 154 if err != nil { 155 return err 156 } 157 158 // create database if it does not already exist 159 // skip this step if we cannot determine status 160 // (e.g. user does not have list database permission) 161 exists, err := drv.DatabaseExists() 162 if err == nil && !exists { 163 if err := drv.CreateDatabase(); err != nil { 164 return err 165 } 166 } 167 168 // migrate 169 return db.Migrate() 170 } 171 172 // Create creates the current database 173 func (db *DB) Create() error { 174 drv, err := db.Driver() 175 if err != nil { 176 return err 177 } 178 179 return drv.CreateDatabase() 180 } 181 182 // Drop drops the current database (if it exists) 183 func (db *DB) Drop() error { 184 drv, err := db.Driver() 185 if err != nil { 186 return err 187 } 188 189 return drv.DropDatabase() 190 } 191 192 // DumpSchema writes the current database schema to a file 193 func (db *DB) DumpSchema() error { 194 drv, err := db.Driver() 195 if err != nil { 196 return err 197 } 198 199 sqlDB, err := db.openDatabaseForMigration(drv) 200 if err != nil { 201 return err 202 } 203 defer dbutil.MustClose(sqlDB) 204 205 schema, err := drv.DumpSchema(sqlDB) 206 if err != nil { 207 return err 208 } 209 210 fmt.Fprintf(db.Log, "Writing: %s\n", db.SchemaFile) 211 212 // ensure schema directory exists 213 if err = ensureDir(filepath.Dir(db.SchemaFile)); err != nil { 214 return err 215 } 216 217 // write schema to file 218 return os.WriteFile(db.SchemaFile, schema, 0o644) 219 } 220 221 // ensureDir creates a directory if it does not already exist 222 func ensureDir(dir string) error { 223 if err := os.MkdirAll(dir, 0o755); err != nil { 224 return fmt.Errorf("%w `%s`", ErrCreateDirectory, dir) 225 } 226 227 return nil 228 } 229 230 const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n" 231 232 // NewMigration creates a new migration file 233 func (db *DB) NewMigration(name string) error { 234 // new migration name 235 timestamp := time.Now().UTC().Format("20060102150405") 236 if name == "" { 237 return ErrNoMigrationName 238 } 239 name = fmt.Sprintf("%s_%s.sql", timestamp, name) 240 241 // create migrations dir if missing 242 if err := ensureDir(db.MigrationsDir); err != nil { 243 return err 244 } 245 246 // check file does not already exist 247 path := filepath.Join(db.MigrationsDir, name) 248 fmt.Fprintf(db.Log, "Creating migration: %s\n", path) 249 250 if _, err := os.Stat(path); !os.IsNotExist(err) { 251 return ErrMigrationAlreadyExist 252 } 253 254 // write new migration 255 file, err := os.Create(path) 256 if err != nil { 257 return err 258 } 259 260 defer dbutil.MustClose(file) 261 _, err = file.WriteString(migrationTemplate) 262 return err 263 } 264 265 func doTransaction(sqlDB *sql.DB, txFunc func(dbutil.Transaction) error) error { 266 tx, err := sqlDB.Begin() 267 if err != nil { 268 return err 269 } 270 271 if err := txFunc(tx); err != nil { 272 if err1 := tx.Rollback(); err1 != nil { 273 return err1 274 } 275 276 return err 277 } 278 279 return tx.Commit() 280 } 281 282 func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) { 283 sqlDB, err := drv.Open() 284 if err != nil { 285 return nil, err 286 } 287 288 if err := drv.CreateMigrationsTable(sqlDB); err != nil { 289 dbutil.MustClose(sqlDB) 290 return nil, err 291 } 292 293 return sqlDB, nil 294 } 295 296 // Migrate migrates database to the latest version 297 func (db *DB) Migrate() error { 298 drv, err := db.Driver() 299 if err != nil { 300 return err 301 } 302 303 migrations, err := db.FindMigrations() 304 if err != nil { 305 return err 306 } 307 308 if len(migrations) == 0 { 309 return ErrNoMigrationFiles 310 } 311 312 sqlDB, err := db.openDatabaseForMigration(drv) 313 if err != nil { 314 return err 315 } 316 defer dbutil.MustClose(sqlDB) 317 318 for _, migration := range migrations { 319 if migration.Applied { 320 continue 321 } 322 323 fmt.Fprintf(db.Log, "Applying: %s\n", migration.FileName) 324 325 parsed, err := migration.Parse() 326 if err != nil { 327 return err 328 } 329 330 execMigration := func(tx dbutil.Transaction) error { 331 // run actual migration 332 result, err := tx.Exec(parsed.Up) 333 if err != nil { 334 return err 335 } else if db.Verbose { 336 db.printVerbose(result) 337 } 338 339 // record migration 340 return drv.InsertMigration(tx, migration.Version) 341 } 342 343 if parsed.UpOptions.Transaction() { 344 // begin transaction 345 err = doTransaction(sqlDB, execMigration) 346 } else { 347 // run outside of transaction 348 err = execMigration(sqlDB) 349 } 350 351 if err != nil { 352 return err 353 } 354 } 355 356 // automatically update schema file, silence errors 357 if db.AutoDumpSchema { 358 _ = db.DumpSchema() 359 } 360 361 return nil 362 } 363 364 func (db *DB) printVerbose(result sql.Result) { 365 lastInsertID, err := result.LastInsertId() 366 if err == nil { 367 fmt.Fprintf(db.Log, "Last insert ID: %d\n", lastInsertID) 368 } 369 rowsAffected, err := result.RowsAffected() 370 if err == nil { 371 fmt.Fprintf(db.Log, "Rows affected: %d\n", rowsAffected) 372 } 373 } 374 375 // FindMigrations lists all available migrations 376 func (db *DB) FindMigrations() ([]Migration, error) { 377 drv, err := db.Driver() 378 if err != nil { 379 return nil, err 380 } 381 382 sqlDB, err := drv.Open() 383 if err != nil { 384 return nil, err 385 } 386 defer dbutil.MustClose(sqlDB) 387 388 // find applied migrations 389 appliedMigrations := map[string]bool{} 390 migrationsTableExists, err := drv.MigrationsTableExists(sqlDB) 391 if err != nil { 392 return nil, err 393 } 394 395 if migrationsTableExists { 396 appliedMigrations, err = drv.SelectMigrations(sqlDB, -1) 397 if err != nil { 398 return nil, err 399 } 400 } 401 402 // find filesystem migrations 403 files, err := fs.ReadDir(db.FS, filepath.Clean(db.MigrationsDir)) 404 if err != nil { 405 return nil, fmt.Errorf("%w `%s`", ErrMigrationDirNotFound, db.MigrationsDir) 406 } 407 408 migrations := []Migration{} 409 for _, file := range files { 410 if file.IsDir() { 411 continue 412 } 413 414 matches := migrationFileRegexp.FindStringSubmatch(file.Name()) 415 if len(matches) < 2 { 416 continue 417 } 418 419 migration := Migration{ 420 Applied: false, 421 FileName: matches[0], 422 FilePath: filepath.Join(db.MigrationsDir, matches[0]), 423 FS: db.FS, 424 Version: matches[1], 425 } 426 if ok := appliedMigrations[migration.Version]; ok { 427 migration.Applied = true 428 } 429 430 migrations = append(migrations, migration) 431 } 432 433 sort.Slice(migrations, func(i, j int) bool { 434 return migrations[i].FileName < migrations[j].FileName 435 }) 436 437 return migrations, nil 438 } 439 440 // Rollback rolls back the most recent migration 441 func (db *DB) Rollback() error { 442 drv, err := db.Driver() 443 if err != nil { 444 return err 445 } 446 447 sqlDB, err := db.openDatabaseForMigration(drv) 448 if err != nil { 449 return err 450 } 451 defer dbutil.MustClose(sqlDB) 452 453 // find last applied migration 454 var latest *Migration 455 migrations, err := db.FindMigrations() 456 if err != nil { 457 return err 458 } 459 460 for _, migration := range migrations { 461 if migration.Applied { 462 latest = &migration 463 } 464 } 465 466 if latest == nil { 467 return ErrNoRollback 468 } 469 470 fmt.Fprintf(db.Log, "Rolling back: %s\n", latest.FileName) 471 472 parsed, err := latest.Parse() 473 if err != nil { 474 return err 475 } 476 477 execMigration := func(tx dbutil.Transaction) error { 478 // rollback migration 479 result, err := tx.Exec(parsed.Down) 480 if err != nil { 481 return err 482 } else if db.Verbose { 483 db.printVerbose(result) 484 } 485 486 // remove migration record 487 return drv.DeleteMigration(tx, latest.Version) 488 } 489 490 if parsed.DownOptions.Transaction() { 491 // begin transaction 492 err = doTransaction(sqlDB, execMigration) 493 } else { 494 // run outside of transaction 495 err = execMigration(sqlDB) 496 } 497 498 if err != nil { 499 return err 500 } 501 502 // automatically update schema file, silence errors 503 if db.AutoDumpSchema { 504 _ = db.DumpSchema() 505 } 506 507 return nil 508 } 509 510 // Status shows the status of all migrations 511 func (db *DB) Status(quiet bool) (int, error) { 512 results, err := db.FindMigrations() 513 if err != nil { 514 return -1, err 515 } 516 517 var totalApplied int 518 var line string 519 520 for _, res := range results { 521 if res.Applied { 522 line = fmt.Sprintf("[X] %s", res.FileName) 523 totalApplied++ 524 } else { 525 line = fmt.Sprintf("[ ] %s", res.FileName) 526 } 527 if !quiet { 528 fmt.Fprintln(db.Log, line) 529 } 530 } 531 532 totalPending := len(results) - totalApplied 533 if !quiet { 534 fmt.Fprintln(db.Log) 535 fmt.Fprintf(db.Log, "Applied: %d\n", totalApplied) 536 fmt.Fprintf(db.Log, "Pending: %d\n", totalPending) 537 } 538 539 return totalPending, nil 540 }