github.com/bishtawi/migrate/v4@v4.8.11/internal/cli/main.go (about)

     1  package cli
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"os/signal"
     8  	"strconv"
     9  	"strings"
    10  	"syscall"
    11  	"time"
    12  
    13  	"github.com/bishtawi/migrate/v4"
    14  	"github.com/bishtawi/migrate/v4/database"
    15  	"github.com/bishtawi/migrate/v4/source"
    16  )
    17  
    18  const defaultTimeFormat = "20060102150405"
    19  
    20  // set main log
    21  var log = &Log{}
    22  
    23  func Main(version string) {
    24  	helpPtr := flag.Bool("help", false, "")
    25  	versionPtr := flag.Bool("version", false, "")
    26  	verbosePtr := flag.Bool("verbose", false, "")
    27  	prefetchPtr := flag.Uint("prefetch", 10, "")
    28  	lockTimeoutPtr := flag.Uint("lock-timeout", 15, "")
    29  	pathPtr := flag.String("path", "", "")
    30  	databasePtr := flag.String("database", "", "")
    31  	sourcePtr := flag.String("source", "", "")
    32  
    33  	flag.Usage = func() {
    34  		fmt.Fprint(os.Stderr,
    35  			`Usage: migrate OPTIONS COMMAND [arg...]
    36         migrate [ -version | -help ]
    37  
    38  Options:
    39    -source          Location of the migrations (driver://url)
    40    -path            Shorthand for -source=file://path
    41    -database        Run migrations against this database (driver://url)
    42    -prefetch N      Number of migrations to load in advance before executing (default 10)
    43    -lock-timeout N  Allow N seconds to acquire database lock (default 15)
    44    -verbose         Print verbose logging
    45    -version         Print version
    46    -help            Print usage
    47  
    48  Commands:
    49    create [-ext E] [-dir D] [-seq] [-digits N] [-format] NAME
    50  			   Create a set of timestamped up/down migrations titled NAME, in directory D with extension E.
    51  			   Use -seq option to generate sequential up/down migrations with N digits.
    52  			   Use -format option to specify a Go time format string.
    53    goto V       Migrate to version V
    54    up [N]       Apply all or N up migrations
    55    down [N]     Apply all or N down migrations
    56    drop         Drop everything inside database
    57    force V      Set version V but don't run migration (ignores dirty state)
    58    version      Print current migration version
    59  
    60  Source drivers: `+strings.Join(source.List(), ", ")+`
    61  Database drivers: `+strings.Join(database.List(), ", ")+"\n")
    62  	}
    63  
    64  	flag.Parse()
    65  
    66  	// initialize logger
    67  	log.verbose = *verbosePtr
    68  
    69  	// show cli version
    70  	if *versionPtr {
    71  		fmt.Fprintln(os.Stderr, version)
    72  		os.Exit(0)
    73  	}
    74  
    75  	// show help
    76  	if *helpPtr {
    77  		flag.Usage()
    78  		os.Exit(0)
    79  	}
    80  
    81  	// translate -path into -source if given
    82  	if *sourcePtr == "" && *pathPtr != "" {
    83  		*sourcePtr = fmt.Sprintf("file://%v", *pathPtr)
    84  	}
    85  
    86  	// initialize migrate
    87  	// don't catch migraterErr here and let each command decide
    88  	// how it wants to handle the error
    89  	migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr)
    90  	defer func() {
    91  		if migraterErr == nil {
    92  			if _, err := migrater.Close(); err != nil {
    93  				log.Println(err)
    94  			}
    95  		}
    96  	}()
    97  	if migraterErr == nil {
    98  		migrater.Log = log
    99  		migrater.PrefetchMigrations = *prefetchPtr
   100  		migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second
   101  
   102  		// handle Ctrl+c
   103  		signals := make(chan os.Signal, 1)
   104  		signal.Notify(signals, syscall.SIGINT)
   105  		go func() {
   106  			for range signals {
   107  				log.Println("Stopping after this running migration ...")
   108  				migrater.GracefulStop <- true
   109  				return
   110  			}
   111  		}()
   112  	}
   113  
   114  	startTime := time.Now()
   115  
   116  	switch flag.Arg(0) {
   117  	case "create":
   118  		args := flag.Args()[1:]
   119  		seq := false
   120  		seqDigits := 6
   121  
   122  		createFlagSet := flag.NewFlagSet("create", flag.ExitOnError)
   123  		extPtr := createFlagSet.String("ext", "", "File extension")
   124  		dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)")
   125  		formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`)
   126  		createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)")
   127  		createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)")
   128  		if err := createFlagSet.Parse(args); err != nil {
   129  			log.Println(err)
   130  		}
   131  
   132  		if createFlagSet.NArg() == 0 {
   133  			log.fatal("error: please specify name")
   134  		}
   135  		name := createFlagSet.Arg(0)
   136  
   137  		if *extPtr == "" {
   138  			log.fatal("error: -ext flag must be specified")
   139  		}
   140  		*extPtr = "." + strings.TrimPrefix(*extPtr, ".")
   141  
   142  		createCmd(*dirPtr, startTime, *formatPtr, name, *extPtr, seq, seqDigits)
   143  
   144  	case "goto":
   145  		if migraterErr != nil {
   146  			log.fatalErr(migraterErr)
   147  		}
   148  
   149  		if flag.Arg(1) == "" {
   150  			log.fatal("error: please specify version argument V")
   151  		}
   152  
   153  		v, err := strconv.ParseUint(flag.Arg(1), 10, 64)
   154  		if err != nil {
   155  			log.fatal("error: can't read version argument V")
   156  		}
   157  
   158  		gotoCmd(migrater, uint(v))
   159  
   160  		if log.verbose {
   161  			log.Println("Finished after", time.Since(startTime))
   162  		}
   163  
   164  	case "up":
   165  		if migraterErr != nil {
   166  			log.fatalErr(migraterErr)
   167  		}
   168  
   169  		limit := -1
   170  		if flag.Arg(1) != "" {
   171  			n, err := strconv.ParseUint(flag.Arg(1), 10, 64)
   172  			if err != nil {
   173  				log.fatal("error: can't read limit argument N")
   174  			}
   175  			limit = int(n)
   176  		}
   177  
   178  		upCmd(migrater, limit)
   179  
   180  		if log.verbose {
   181  			log.Println("Finished after", time.Since(startTime))
   182  		}
   183  
   184  	case "down":
   185  		if migraterErr != nil {
   186  			log.fatalErr(migraterErr)
   187  		}
   188  
   189  		downFlagSet := flag.NewFlagSet("down", flag.ExitOnError)
   190  		applyAll := downFlagSet.Bool("all", false, "Apply all down migrations")
   191  
   192  		args := flag.Args()[1:]
   193  		if err := downFlagSet.Parse(args); err != nil {
   194  			log.fatalErr(err)
   195  		}
   196  
   197  		downArgs := downFlagSet.Args()
   198  		num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs)
   199  		if err != nil {
   200  			log.fatalErr(err)
   201  		}
   202  		if needsConfirm {
   203  			log.Println("Are you sure you want to apply all down migrations? [y/N]")
   204  			var response string
   205  			fmt.Scanln(&response)
   206  			response = strings.ToLower(strings.TrimSpace(response))
   207  
   208  			if response == "y" {
   209  				log.Println("Applying all down migrations")
   210  			} else {
   211  				log.fatal("Not applying all down migrations")
   212  			}
   213  		}
   214  
   215  		downCmd(migrater, num)
   216  
   217  		if log.verbose {
   218  			log.Println("Finished after", time.Since(startTime))
   219  		}
   220  
   221  	case "drop":
   222  		if migraterErr != nil {
   223  			log.fatalErr(migraterErr)
   224  		}
   225  
   226  		dropCmd(migrater)
   227  
   228  		if log.verbose {
   229  			log.Println("Finished after", time.Since(startTime))
   230  		}
   231  
   232  	case "force":
   233  		if migraterErr != nil {
   234  			log.fatalErr(migraterErr)
   235  		}
   236  
   237  		if flag.Arg(1) == "" {
   238  			log.fatal("error: please specify version argument V")
   239  		}
   240  
   241  		v, err := strconv.ParseInt(flag.Arg(1), 10, 64)
   242  		if err != nil {
   243  			log.fatal("error: can't read version argument V")
   244  		}
   245  
   246  		if v < -1 {
   247  			log.fatal("error: argument V must be >= -1")
   248  		}
   249  
   250  		forceCmd(migrater, int(v))
   251  
   252  		if log.verbose {
   253  			log.Println("Finished after", time.Since(startTime))
   254  		}
   255  
   256  	case "version":
   257  		if migraterErr != nil {
   258  			log.fatalErr(migraterErr)
   259  		}
   260  
   261  		versionCmd(migrater)
   262  
   263  	default:
   264  		flag.Usage()
   265  
   266  		// If a command is not found we exit with a status 2 to match the behavior
   267  		// of flag.Parse() with flag.ExitOnError when parsing an invalid flag.
   268  		os.Exit(2)
   269  	}
   270  }