github.com/morigs/migrate/v4@v4.15.2-0.20221123151732-2fdcfbe124f3/internal/cli/main.go (about)

     1  package cli
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"os/signal"
     8  	"strconv"
     9  	"strings"
    10  	"syscall"
    11  	"time"
    12  
    13  	"github.com/morigs/migrate/v4"
    14  	"github.com/morigs/migrate/v4/database"
    15  	"github.com/morigs/migrate/v4/source"
    16  )
    17  
    18  const (
    19  	defaultTimeFormat = "20060102150405"
    20  	defaultTimezone   = "UTC"
    21  	createUsage       = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME
    22  	   Create a set of timestamped up/down migrations titled NAME, in directory D with extension E.
    23  	   Use -seq option to generate sequential up/down migrations with N digits.
    24  	   Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error.
    25             Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC).
    26  `
    27  	gotoUsage = `goto V       Migrate to version V`
    28  	upUsage   = `up [N]       Apply all or N up migrations`
    29  	downUsage = `down [N] [-all]    Apply all or N down migrations
    30  	Use -all to apply all down migrations`
    31  	dropUsage = `drop [-f]    Drop everything inside database
    32  	Use -f to bypass confirmation`
    33  	forceUsage = `force V      Set version V but don't run migration (ignores dirty state)`
    34  )
    35  
    36  func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) {
    37  	if help {
    38  		fmt.Fprintln(os.Stderr, usage)
    39  		flagSet.PrintDefaults()
    40  		os.Exit(0)
    41  	}
    42  }
    43  
    44  func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) {
    45  	flagSet := flag.NewFlagSet(name, flag.ExitOnError)
    46  	helpPtr := flagSet.Bool("help", false, "Print help information")
    47  	return flagSet, helpPtr
    48  }
    49  
    50  // set main log
    51  var log = &Log{}
    52  
    53  func printUsageAndExit() {
    54  	flag.Usage()
    55  
    56  	// If a command is not found we exit with a status 2 to match the behavior
    57  	// of flag.Parse() with flag.ExitOnError when parsing an invalid flag.
    58  	os.Exit(2)
    59  }
    60  
    61  // Main function of a cli application. It is public for backwards compatibility with `cli` package
    62  func Main(version string) {
    63  	helpPtr := flag.Bool("help", false, "")
    64  	versionPtr := flag.Bool("version", false, "")
    65  	verbosePtr := flag.Bool("verbose", false, "")
    66  	prefetchPtr := flag.Uint("prefetch", 10, "")
    67  	lockTimeoutPtr := flag.Uint("lock-timeout", 15, "")
    68  	pathPtr := flag.String("path", "", "")
    69  	databasePtr := flag.String("database", "", "")
    70  	sourcePtr := flag.String("source", "", "")
    71  
    72  	flag.Usage = func() {
    73  		fmt.Fprintf(os.Stderr,
    74  			`Usage: migrate OPTIONS COMMAND [arg...]
    75         migrate [ -version | -help ]
    76  
    77  Options:
    78    -source          Location of the migrations (driver://url)
    79    -path            Shorthand for -source=file://path
    80    -database        Run migrations against this database (driver://url)
    81    -prefetch N      Number of migrations to load in advance before executing (default 10)
    82    -lock-timeout N  Allow N seconds to acquire database lock (default 15)
    83    -verbose         Print verbose logging
    84    -version         Print version
    85    -help            Print usage
    86  
    87  Commands:
    88    %s
    89    %s
    90    %s
    91    %s
    92    %s
    93    %s
    94    version      Print current migration version
    95  
    96  Source drivers: `+strings.Join(source.List(), ", ")+`
    97  Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage)
    98  	}
    99  
   100  	flag.Parse()
   101  
   102  	// initialize logger
   103  	log.verbose = *verbosePtr
   104  
   105  	// show cli version
   106  	if *versionPtr {
   107  		fmt.Fprintln(os.Stderr, version)
   108  		os.Exit(0)
   109  	}
   110  
   111  	// show help
   112  	if *helpPtr {
   113  		flag.Usage()
   114  		os.Exit(0)
   115  	}
   116  
   117  	// translate -path into -source if given
   118  	if *sourcePtr == "" && *pathPtr != "" {
   119  		*sourcePtr = fmt.Sprintf("file://%v", *pathPtr)
   120  	}
   121  
   122  	// initialize migrate
   123  	// don't catch migraterErr here and let each command decide
   124  	// how it wants to handle the error
   125  	migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr)
   126  	defer func() {
   127  		if migraterErr == nil {
   128  			if _, err := migrater.Close(); err != nil {
   129  				log.Println(err)
   130  			}
   131  		}
   132  	}()
   133  	if migraterErr == nil {
   134  		migrater.Log = log
   135  		migrater.PrefetchMigrations = *prefetchPtr
   136  		migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second
   137  
   138  		// handle Ctrl+c
   139  		signals := make(chan os.Signal, 1)
   140  		signal.Notify(signals, syscall.SIGINT)
   141  		go func() {
   142  			for range signals {
   143  				log.Println("Stopping after this running migration ...")
   144  				migrater.GracefulStop <- true
   145  				return
   146  			}
   147  		}()
   148  	}
   149  
   150  	startTime := time.Now()
   151  
   152  	if len(flag.Args()) < 1 {
   153  		printUsageAndExit()
   154  	}
   155  	args := flag.Args()[1:]
   156  
   157  	switch flag.Arg(0) {
   158  	case "create":
   159  
   160  		seq := false
   161  		seqDigits := 6
   162  
   163  		createFlagSet, help := newFlagSetWithHelp("create")
   164  		extPtr := createFlagSet.String("ext", "", "File extension")
   165  		dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)")
   166  		formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`)
   167  		timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`)
   168  		createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)")
   169  		createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)")
   170  
   171  		if err := createFlagSet.Parse(args); err != nil {
   172  			log.fatalErr(err)
   173  		}
   174  
   175  		handleSubCmdHelp(*help, createUsage, createFlagSet)
   176  
   177  		if createFlagSet.NArg() == 0 {
   178  			log.fatal("error: please specify name")
   179  		}
   180  		name := createFlagSet.Arg(0)
   181  
   182  		if *extPtr == "" {
   183  			log.fatal("error: -ext flag must be specified")
   184  		}
   185  
   186  		timezone, err := time.LoadLocation(*timezoneName)
   187  		if err != nil {
   188  			log.fatal(err)
   189  		}
   190  
   191  		if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil {
   192  			log.fatalErr(err)
   193  		}
   194  
   195  	case "goto":
   196  
   197  		gotoSet, helpPtr := newFlagSetWithHelp("goto")
   198  
   199  		if err := gotoSet.Parse(args); err != nil {
   200  			log.fatalErr(err)
   201  		}
   202  
   203  		handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet)
   204  
   205  		if migraterErr != nil {
   206  			log.fatalErr(migraterErr)
   207  		}
   208  
   209  		if gotoSet.NArg() == 0 {
   210  			log.fatal("error: please specify version argument V")
   211  		}
   212  
   213  		v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64)
   214  		if err != nil {
   215  			log.fatal("error: can't read version argument V")
   216  		}
   217  
   218  		if err := gotoCmd(migrater, uint(v)); err != nil {
   219  			log.fatalErr(err)
   220  		}
   221  
   222  		if log.verbose {
   223  			log.Println("Finished after", time.Since(startTime))
   224  		}
   225  
   226  	case "up":
   227  		upSet, helpPtr := newFlagSetWithHelp("up")
   228  
   229  		if err := upSet.Parse(args); err != nil {
   230  			log.fatalErr(err)
   231  		}
   232  
   233  		handleSubCmdHelp(*helpPtr, upUsage, upSet)
   234  
   235  		if migraterErr != nil {
   236  			log.fatalErr(migraterErr)
   237  		}
   238  
   239  		limit := -1
   240  		if upSet.NArg() > 0 {
   241  			n, err := strconv.ParseUint(upSet.Arg(0), 10, 64)
   242  			if err != nil {
   243  				log.fatal("error: can't read limit argument N")
   244  			}
   245  			limit = int(n)
   246  		}
   247  
   248  		if err := upCmd(migrater, limit); err != nil {
   249  			log.fatalErr(err)
   250  		}
   251  
   252  		if log.verbose {
   253  			log.Println("Finished after", time.Since(startTime))
   254  		}
   255  
   256  	case "down":
   257  		downFlagSet, helpPtr := newFlagSetWithHelp("down")
   258  		applyAll := downFlagSet.Bool("all", false, "Apply all down migrations")
   259  
   260  		if err := downFlagSet.Parse(args); err != nil {
   261  			log.fatalErr(err)
   262  		}
   263  
   264  		handleSubCmdHelp(*helpPtr, downUsage, downFlagSet)
   265  
   266  		if migraterErr != nil {
   267  			log.fatalErr(migraterErr)
   268  		}
   269  
   270  		downArgs := downFlagSet.Args()
   271  		num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs)
   272  		if err != nil {
   273  			log.fatalErr(err)
   274  		}
   275  		if needsConfirm {
   276  			log.Println("Are you sure you want to apply all down migrations? [y/N]")
   277  			var response string
   278  			fmt.Scanln(&response)
   279  			response = strings.ToLower(strings.TrimSpace(response))
   280  
   281  			if response == "y" {
   282  				log.Println("Applying all down migrations")
   283  			} else {
   284  				log.fatal("Not applying all down migrations")
   285  			}
   286  		}
   287  
   288  		if err := downCmd(migrater, num); err != nil {
   289  			log.fatalErr(err)
   290  		}
   291  
   292  		if log.verbose {
   293  			log.Println("Finished after", time.Since(startTime))
   294  		}
   295  
   296  	case "drop":
   297  		dropFlagSet, help := newFlagSetWithHelp("drop")
   298  		forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt")
   299  
   300  		if err := dropFlagSet.Parse(args); err != nil {
   301  			log.fatalErr(err)
   302  		}
   303  
   304  		handleSubCmdHelp(*help, dropUsage, dropFlagSet)
   305  
   306  		if !*forceDrop {
   307  			log.Println("Are you sure you want to drop the entire database schema? [y/N]")
   308  			var response string
   309  			fmt.Scanln(&response)
   310  			response = strings.ToLower(strings.TrimSpace(response))
   311  
   312  			if response == "y" {
   313  				log.Println("Dropping the entire database schema")
   314  			} else {
   315  				log.fatal("Aborted dropping the entire database schema")
   316  			}
   317  		}
   318  
   319  		if migraterErr != nil {
   320  			log.fatalErr(migraterErr)
   321  		}
   322  
   323  		if err := dropCmd(migrater); err != nil {
   324  			log.fatalErr(err)
   325  		}
   326  
   327  		if log.verbose {
   328  			log.Println("Finished after", time.Since(startTime))
   329  		}
   330  
   331  	case "force":
   332  		forceSet, helpPtr := newFlagSetWithHelp("force")
   333  
   334  		if err := forceSet.Parse(args); err != nil {
   335  			log.fatalErr(err)
   336  		}
   337  
   338  		handleSubCmdHelp(*helpPtr, forceUsage, forceSet)
   339  
   340  		if migraterErr != nil {
   341  			log.fatalErr(migraterErr)
   342  		}
   343  
   344  		if forceSet.NArg() == 0 {
   345  			log.fatal("error: please specify version argument V")
   346  		}
   347  
   348  		v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64)
   349  		if err != nil {
   350  			log.fatal("error: can't read version argument V")
   351  		}
   352  
   353  		if v < -1 {
   354  			log.fatal("error: argument V must be >= -1")
   355  		}
   356  
   357  		if err := forceCmd(migrater, int(v)); err != nil {
   358  			log.fatalErr(err)
   359  		}
   360  
   361  		if log.verbose {
   362  			log.Println("Finished after", time.Since(startTime))
   363  		}
   364  
   365  	case "version":
   366  		if migraterErr != nil {
   367  			log.fatalErr(migraterErr)
   368  		}
   369  
   370  		if err := versionCmd(migrater); err != nil {
   371  			log.fatalErr(err)
   372  		}
   373  
   374  	default:
   375  		printUsageAndExit()
   376  	}
   377  }