github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/internal/cli/main.go (about) 1 package cli 2 3 import ( 4 "flag" 5 "fmt" 6 "os" 7 "os/signal" 8 "strconv" 9 "strings" 10 "syscall" 11 "time" 12 13 "github.com/golang-migrate/migrate/v4" 14 "github.com/golang-migrate/migrate/v4/database" 15 "github.com/golang-migrate/migrate/v4/source" 16 ) 17 18 const ( 19 defaultTimeFormat = "20060102150405" 20 defaultTimezone = "UTC" 21 createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME 22 Create a set of timestamped up/down migrations titled NAME, in directory D with extension E. 23 Use -seq option to generate sequential up/down migrations with N digits. 24 Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error. 25 Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC). 26 ` 27 gotoUsage = `goto V Migrate to version V` 28 upUsage = `up [N] Apply all or N up migrations` 29 downUsage = `down [N] [-all] Apply all or N down migrations 30 Use -all to apply all down migrations` 31 dropUsage = `drop [-f] Drop everything inside database 32 Use -f to bypass confirmation` 33 forceUsage = `force V Set version V but don't run migration (ignores dirty state)` 34 ) 35 36 func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) { 37 if help { 38 fmt.Fprintln(os.Stderr, usage) 39 flagSet.PrintDefaults() 40 os.Exit(0) 41 } 42 } 43 44 func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) { 45 flagSet := flag.NewFlagSet(name, flag.ExitOnError) 46 helpPtr := flagSet.Bool("help", false, "Print help information") 47 return flagSet, helpPtr 48 } 49 50 // set main log 51 var log = &Log{} 52 53 func printUsageAndExit() { 54 flag.Usage() 55 56 // If a command is not found we exit with a status 2 to match the behavior 57 // of flag.Parse() with flag.ExitOnError when parsing an invalid flag. 58 os.Exit(2) 59 } 60 61 // Main function of a cli application. It is public for backwards compatibility with `cli` package 62 func Main(version string) { 63 helpPtr := flag.Bool("help", false, "") 64 versionPtr := flag.Bool("version", false, "") 65 verbosePtr := flag.Bool("verbose", false, "") 66 prefetchPtr := flag.Uint("prefetch", 10, "") 67 lockTimeoutPtr := flag.Uint("lock-timeout", 15, "") 68 pathPtr := flag.String("path", "", "") 69 databasePtr := flag.String("database", "", "") 70 sourcePtr := flag.String("source", "", "") 71 72 flag.Usage = func() { 73 fmt.Fprintf(os.Stderr, 74 `Usage: migrate OPTIONS COMMAND [arg...] 75 migrate [ -version | -help ] 76 77 Options: 78 -source Location of the migrations (driver://url) 79 -path Shorthand for -source=file://path 80 -database Run migrations against this database (driver://url) 81 -prefetch N Number of migrations to load in advance before executing (default 10) 82 -lock-timeout N Allow N seconds to acquire database lock (default 15) 83 -verbose Print verbose logging 84 -version Print version 85 -help Print usage 86 87 Commands: 88 %s 89 %s 90 %s 91 %s 92 %s 93 %s 94 version Print current migration version 95 96 Source drivers: `+strings.Join(source.List(), ", ")+` 97 Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage) 98 } 99 100 flag.Parse() 101 102 // initialize logger 103 log.verbose = *verbosePtr 104 105 // show cli version 106 if *versionPtr { 107 fmt.Fprintln(os.Stderr, version) 108 os.Exit(0) 109 } 110 111 // show help 112 if *helpPtr { 113 flag.Usage() 114 os.Exit(0) 115 } 116 117 // translate -path into -source if given 118 if *sourcePtr == "" && *pathPtr != "" { 119 *sourcePtr = fmt.Sprintf("file://%v", *pathPtr) 120 } 121 122 // initialize migrate 123 // don't catch migraterErr here and let each command decide 124 // how it wants to handle the error 125 migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr) 126 defer func() { 127 if migraterErr == nil { 128 if _, err := migrater.Close(); err != nil { 129 log.Println(err) 130 } 131 } 132 }() 133 if migraterErr == nil { 134 migrater.Log = log 135 migrater.PrefetchMigrations = *prefetchPtr 136 migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second 137 138 // handle Ctrl+c 139 signals := make(chan os.Signal, 1) 140 signal.Notify(signals, syscall.SIGINT) 141 go func() { 142 for range signals { 143 log.Println("Stopping after this running migration ...") 144 migrater.GracefulStop <- true 145 return 146 } 147 }() 148 } 149 150 startTime := time.Now() 151 152 if len(flag.Args()) < 1 { 153 printUsageAndExit() 154 } 155 args := flag.Args()[1:] 156 157 switch flag.Arg(0) { 158 case "create": 159 160 seq := false 161 seqDigits := 6 162 163 createFlagSet, help := newFlagSetWithHelp("create") 164 extPtr := createFlagSet.String("ext", "", "File extension") 165 dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)") 166 formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`) 167 timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`) 168 createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)") 169 createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)") 170 171 if err := createFlagSet.Parse(args); err != nil { 172 log.fatalErr(err) 173 } 174 175 handleSubCmdHelp(*help, createUsage, createFlagSet) 176 177 if createFlagSet.NArg() == 0 { 178 log.fatal("error: please specify name") 179 } 180 name := createFlagSet.Arg(0) 181 182 if *extPtr == "" { 183 log.fatal("error: -ext flag must be specified") 184 } 185 186 timezone, err := time.LoadLocation(*timezoneName) 187 if err != nil { 188 log.fatal(err) 189 } 190 191 if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil { 192 log.fatalErr(err) 193 } 194 195 case "goto": 196 197 gotoSet, helpPtr := newFlagSetWithHelp("goto") 198 199 if err := gotoSet.Parse(args); err != nil { 200 log.fatalErr(err) 201 } 202 203 handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet) 204 205 if migraterErr != nil { 206 log.fatalErr(migraterErr) 207 } 208 209 if gotoSet.NArg() == 0 { 210 log.fatal("error: please specify version argument V") 211 } 212 213 v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64) 214 if err != nil { 215 log.fatal("error: can't read version argument V") 216 } 217 218 if err := gotoCmd(migrater, uint(v)); err != nil { 219 log.fatalErr(err) 220 } 221 222 if log.verbose { 223 log.Println("Finished after", time.Since(startTime)) 224 } 225 226 case "up": 227 upSet, helpPtr := newFlagSetWithHelp("up") 228 229 if err := upSet.Parse(args); err != nil { 230 log.fatalErr(err) 231 } 232 233 handleSubCmdHelp(*helpPtr, upUsage, upSet) 234 235 if migraterErr != nil { 236 log.fatalErr(migraterErr) 237 } 238 239 limit := -1 240 if upSet.NArg() > 0 { 241 n, err := strconv.ParseUint(upSet.Arg(0), 10, 64) 242 if err != nil { 243 log.fatal("error: can't read limit argument N") 244 } 245 limit = int(n) 246 } 247 248 if err := upCmd(migrater, limit); err != nil { 249 log.fatalErr(err) 250 } 251 252 if log.verbose { 253 log.Println("Finished after", time.Since(startTime)) 254 } 255 256 case "down": 257 downFlagSet, helpPtr := newFlagSetWithHelp("down") 258 applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") 259 260 if err := downFlagSet.Parse(args); err != nil { 261 log.fatalErr(err) 262 } 263 264 handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) 265 266 if migraterErr != nil { 267 log.fatalErr(migraterErr) 268 } 269 270 downArgs := downFlagSet.Args() 271 num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) 272 if err != nil { 273 log.fatalErr(err) 274 } 275 if needsConfirm { 276 log.Println("Are you sure you want to apply all down migrations? [y/N]") 277 var response string 278 fmt.Scanln(&response) 279 response = strings.ToLower(strings.TrimSpace(response)) 280 281 if response == "y" { 282 log.Println("Applying all down migrations") 283 } else { 284 log.fatal("Not applying all down migrations") 285 } 286 } 287 288 if err := downCmd(migrater, num); err != nil { 289 log.fatalErr(err) 290 } 291 292 if log.verbose { 293 log.Println("Finished after", time.Since(startTime)) 294 } 295 296 case "drop": 297 dropFlagSet, help := newFlagSetWithHelp("drop") 298 forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") 299 300 if err := dropFlagSet.Parse(args); err != nil { 301 log.fatalErr(err) 302 } 303 304 handleSubCmdHelp(*help, dropUsage, dropFlagSet) 305 306 if !*forceDrop { 307 log.Println("Are you sure you want to drop the entire database schema? [y/N]") 308 var response string 309 fmt.Scanln(&response) 310 response = strings.ToLower(strings.TrimSpace(response)) 311 312 if response == "y" { 313 log.Println("Dropping the entire database schema") 314 } else { 315 log.fatal("Aborted dropping the entire database schema") 316 } 317 } 318 319 if migraterErr != nil { 320 log.fatalErr(migraterErr) 321 } 322 323 if err := dropCmd(migrater); err != nil { 324 log.fatalErr(err) 325 } 326 327 if log.verbose { 328 log.Println("Finished after", time.Since(startTime)) 329 } 330 331 case "force": 332 forceSet, helpPtr := newFlagSetWithHelp("force") 333 334 if err := forceSet.Parse(args); err != nil { 335 log.fatalErr(err) 336 } 337 338 handleSubCmdHelp(*helpPtr, forceUsage, forceSet) 339 340 if migraterErr != nil { 341 log.fatalErr(migraterErr) 342 } 343 344 if forceSet.NArg() == 0 { 345 log.fatal("error: please specify version argument V") 346 } 347 348 v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64) 349 if err != nil { 350 log.fatal("error: can't read version argument V") 351 } 352 353 if v < -1 { 354 log.fatal("error: argument V must be >= -1") 355 } 356 357 if err := forceCmd(migrater, int(v)); err != nil { 358 log.fatalErr(err) 359 } 360 361 if log.verbose { 362 log.Println("Finished after", time.Since(startTime)) 363 } 364 365 case "version": 366 if migraterErr != nil { 367 log.fatalErr(migraterErr) 368 } 369 370 if err := versionCmd(migrater); err != nil { 371 log.fatalErr(err) 372 } 373 374 default: 375 printUsageAndExit() 376 } 377 }